diff --git a/oslo_messaging/_drivers/impl_rabbit.py b/oslo_messaging/_drivers/impl_rabbit.py index 66ee83eee..af3f6540f 100644 --- a/oslo_messaging/_drivers/impl_rabbit.py +++ b/oslo_messaging/_drivers/impl_rabbit.py @@ -40,6 +40,7 @@ from oslo_messaging._i18n import _ from oslo_messaging._i18n import _LE from oslo_messaging._i18n import _LI from oslo_messaging._i18n import _LW +from oslo_messaging import _utils from oslo_messaging import exceptions @@ -309,7 +310,7 @@ class ConnectionLock(DummyConnectionLock): self._monitor = threading.Lock() self._workers_locks = threading.Condition(self._monitor) self._heartbeat_lock = threading.Condition(self._monitor) - self._get_thread_id = self._fetch_current_thread_functor() + self._get_thread_id = _utils.fetch_current_thread_functor() def acquire(self): with self._monitor: @@ -351,25 +352,6 @@ class ConnectionLock(DummyConnectionLock): finally: self.release() - @staticmethod - def _fetch_current_thread_functor(): - # Until https://github.com/eventlet/eventlet/issues/172 is resolved - # or addressed we have to use complicated workaround to get a object - # that will not be recycled; the usage of threading.current_thread() - # doesn't appear to currently be monkey patched and therefore isn't - # reliable to use (and breaks badly when used as all threads share - # the same current_thread() object)... - try: - import eventlet - from eventlet import patcher - green_threaded = patcher.is_monkey_patched('thread') - except ImportError: - green_threaded = False - if green_threaded: - return lambda: eventlet.getcurrent() - else: - return lambda: threading.current_thread() - class Connection(object): """Connection object.""" diff --git a/oslo_messaging/_utils.py b/oslo_messaging/_utils.py index ddec6d7a7..1c816de54 100644 --- a/oslo_messaging/_utils.py +++ b/oslo_messaging/_utils.py @@ -14,6 +14,7 @@ # under the License. import logging +import threading LOG = logging.getLogger(__name__) @@ -94,3 +95,22 @@ class DispatcherExecutorContext(object): # else if self._post is not None: self._post(self._incoming, self._result) + + +def fetch_current_thread_functor(): + # Until https://github.com/eventlet/eventlet/issues/172 is resolved + # or addressed we have to use complicated workaround to get a object + # that will not be recycled; the usage of threading.current_thread() + # doesn't appear to currently be monkey patched and therefore isn't + # reliable to use (and breaks badly when used as all threads share + # the same current_thread() object)... + try: + import eventlet + from eventlet import patcher + green_threaded = patcher.is_monkey_patched('thread') + except ImportError: + green_threaded = False + if green_threaded: + return lambda: eventlet.getcurrent() + else: + return lambda: threading.current_thread() diff --git a/oslo_messaging/server.py b/oslo_messaging/server.py index 087537185..c5b4093d4 100644 --- a/oslo_messaging/server.py +++ b/oslo_messaging/server.py @@ -27,6 +27,8 @@ from oslo_service import service from stevedore import driver from oslo_messaging._drivers import base as driver_base +from oslo_messaging._i18n import _ +from oslo_messaging import _utils from oslo_messaging import exceptions @@ -86,6 +88,8 @@ class MessageHandlingServer(service.ServiceBase): self.dispatcher = dispatcher self.executor = executor + self._get_thread_id = _utils.fetch_current_thread_functor() + try: mgr = driver.DriverManager('oslo.messaging.executors', self.executor) @@ -94,6 +98,8 @@ class MessageHandlingServer(service.ServiceBase): else: self._executor_cls = mgr.driver self._executor = None + self._running = False + self._thread_id = None super(MessageHandlingServer, self).__init__() @@ -111,6 +117,8 @@ class MessageHandlingServer(service.ServiceBase): choose to dispatch messages in a new thread, coroutine or simply the current thread. """ + self._check_same_thread_id() + if self._executor is not None: return try: @@ -118,10 +126,18 @@ class MessageHandlingServer(service.ServiceBase): except driver_base.TransportDriverError as ex: raise ServerListenError(self.target, ex) + self._running = True self._executor = self._executor_cls(self.conf, listener, self.dispatcher) self._executor.start() + def _check_same_thread_id(self): + if self._thread_id is None: + self._thread_id = self._get_thread_id() + elif self._thread_id != self._get_thread_id(): + raise RuntimeError(_("start/stop/wait must be called in the " + "same thread")) + def stop(self): """Stop handling incoming messages. @@ -130,7 +146,10 @@ class MessageHandlingServer(service.ServiceBase): some messages, and underlying driver resources associated to this server are still in use. See 'wait' for more details. """ + self._check_same_thread_id() + if self._executor is not None: + self._running = False self._executor.stop() def wait(self): @@ -143,12 +162,22 @@ class MessageHandlingServer(service.ServiceBase): Once it's finished, the underlying driver resources associated to this server are released (like closing useless network connections). """ + self._check_same_thread_id() + + if self._running: + raise RuntimeError(_("wait() should be called after stop() as it " + "waits for existing messages to finish " + "processing")) + if self._executor is not None: self._executor.wait() # Close listener connection after processing all messages self._executor.listener.cleanup() self._executor = None + # NOTE(sileht): executor/listener have been properly stopped + # allow to restart it into another thread + self._thread_id = None def reset(self): """Reset service. diff --git a/oslo_messaging/tests/functional/utils.py b/oslo_messaging/tests/functional/utils.py index 7c82939e8..ad9fd11a9 100644 --- a/oslo_messaging/tests/functional/utils.py +++ b/oslo_messaging/tests/functional/utils.py @@ -12,7 +12,6 @@ # under the License. import os -import threading import time import uuid @@ -76,7 +75,7 @@ class RpcServerFixture(fixtures.Fixture): """Fixture to setup the TestServerEndpoint.""" def __init__(self, url, target, endpoint=None, ctrl_target=None, - executor='blocking'): + executor='eventlet'): super(RpcServerFixture, self).__init__() self.url = url self.target = target @@ -104,14 +103,12 @@ class RpcServerFixture(fixtures.Fixture): super(RpcServerFixture, self).cleanUp() def _start(self): - self.thread = threading.Thread(target=self.server.start) - self.thread.daemon = True + self.thread = test_utils.ServerThreadHelper(self.server) self.thread.start() def _stop(self): - self.server.stop() + self.thread.stop() self._ctrl.cast({}, 'ping') - self.server.wait() self.thread.join() def ping(self, ctxt): @@ -308,7 +305,7 @@ class NotificationFixture(fixtures.Fixture): self.server = oslo_messaging.get_notification_listener( transport.transport, targets, - [self]) + [self], 'eventlet') self._ctrl = self.notifier('internal', topic=self.name) self._start() transport.wait() @@ -318,14 +315,12 @@ class NotificationFixture(fixtures.Fixture): super(NotificationFixture, self).cleanUp() def _start(self): - self.thread = threading.Thread(target=self.server.start) - self.thread.daemon = True + self.thread = test_utils.ServerThreadHelper(self.server) self.thread.start() def _stop(self): - self.server.stop() + self.thread.stop() self._ctrl.sample({}, 'shutdown', 'shutdown') - self.server.wait() self.thread.join() def notifier(self, publisher, topic=None): diff --git a/oslo_messaging/tests/notify/test_listener.py b/oslo_messaging/tests/notify/test_listener.py index 20f8911fc..1644936ad 100644 --- a/oslo_messaging/tests/notify/test_listener.py +++ b/oslo_messaging/tests/notify/test_listener.py @@ -31,32 +31,15 @@ class RestartableServerThread(object): def __init__(self, server): self.server = server self.thread = None - self._started = threading.Event() - self._tombstone = threading.Event() def start(self): - self._tombstone.clear() if self.thread is None: - self._started.clear() - self.thread = threading.Thread(target=self._target) - self.thread.daemon = True + self.thread = test_utils.ServerThreadHelper(self.server) self.thread.start() - self._started.wait() - - def _target(self): - self.server.start() - self._started.set() - self._tombstone.wait() def stop(self): if self.thread is not None: - # Check start() does nothing with a running listener - self.server.start() - - self._tombstone.set() - if self.thread is not None: - self.server.stop() - self.server.wait() + self.thread.stop() self.thread.join(timeout=15) ret = self.thread.isAlive() self.thread = None diff --git a/oslo_messaging/tests/rpc/test_server.py b/oslo_messaging/tests/rpc/test_server.py index 5c56285dc..2601f4f8a 100644 --- a/oslo_messaging/tests/rpc/test_server.py +++ b/oslo_messaging/tests/rpc/test_server.py @@ -130,6 +130,37 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin): self.assertIsNone(server._executor) self.assertEqual(1, listener.cleanup.call_count) + def test_server_invalid_wait_running_server(self): + transport = oslo_messaging.get_transport(self.conf, url='fake:') + target = oslo_messaging.Target(topic='foo', server='bar') + endpoints = [object()] + serializer = object() + + server = oslo_messaging.get_rpc_server(transport, target, endpoints, + serializer=serializer, + executor='eventlet') + self.addCleanup(server.wait) + self.addCleanup(server.stop) + server.start() + self.assertRaises(RuntimeError, server.wait) + + def test_server_invalid_stop_from_other_thread(self): + transport = oslo_messaging.get_transport(self.conf, url='fake:') + target = oslo_messaging.Target(topic='foo', server='bar') + endpoints = [object()] + serializer = object() + + server = oslo_messaging.get_rpc_server(transport, target, endpoints, + serializer=serializer, + executor='eventlet') + + t = test_utils.ServerThreadHelper(server) + t.start() + self.addCleanup(t.join) + self.addCleanup(t.stop) + self.assertRaises(RuntimeError, server.stop) + self.assertRaises(RuntimeError, server.wait) + def test_no_target_server(self): transport = oslo_messaging.get_transport(self.conf, url='fake:') diff --git a/oslo_messaging/tests/utils.py b/oslo_messaging/tests/utils.py index bfa73a8c3..8ea89c5ed 100644 --- a/oslo_messaging/tests/utils.py +++ b/oslo_messaging/tests/utils.py @@ -19,6 +19,8 @@ """Common utilities used in testing""" +import threading + from oslo_config import cfg from oslotest import base from oslotest import moxstubout @@ -56,3 +58,24 @@ class BaseTestCase(base.BaseTestCase): group = kw.pop('group', None) for k, v in six.iteritems(kw): self.conf.set_override(k, v, group) + + +class ServerThreadHelper(threading.Thread): + def __init__(self, server): + super(ServerThreadHelper, self).__init__() + self.daemon = True + self._server = server + self._stop_event = threading.Event() + self._wait_event = threading.Event() + + def run(self): + self._server.start() + self._stop_event.wait() + # Check start() does nothing with a running listener + self._server.start() + self._server.stop() + self._server.wait() + self._wait_event.set() + + def stop(self): + self._stop_event.set()