diff --git a/oslo_messaging/_utils.py b/oslo_messaging/_utils.py index 1bb20b089..cec94bb48 100644 --- a/oslo_messaging/_utils.py +++ b/oslo_messaging/_utils.py @@ -116,29 +116,6 @@ def fetch_current_thread_functor(): return lambda: threading.current_thread() -class DummyCondition(object): - def acquire(self): - pass - - def notify(self): - pass - - def notify_all(self): - pass - - def wait(self, timeout=None): - pass - - def release(self): - pass - - def __enter__(self): - self.acquire() - - def __exit__(self, type, value, traceback): - self.release() - - class DummyLock(object): def acquire(self): pass diff --git a/oslo_messaging/server.py b/oslo_messaging/server.py index 491ccbf52..f1739ad90 100644 --- a/oslo_messaging/server.py +++ b/oslo_messaging/server.py @@ -23,16 +23,17 @@ __all__ = [ 'ServerListenError', ] +import functools +import inspect import logging import threading +import traceback from oslo_service import service from oslo_utils import timeutils from stevedore import driver from oslo_messaging._drivers import base as driver_base -from oslo_messaging._i18n import _LW -from oslo_messaging import _utils from oslo_messaging import exceptions LOG = logging.getLogger(__name__) @@ -62,7 +63,170 @@ class ServerListenError(MessagingServerError): self.ex = ex -class MessageHandlingServer(service.ServiceBase): +class _OrderedTask(object): + """A task which must be executed in a particular order. + + A caller may wait for this task to complete by calling + `wait_for_completion`. + + A caller may run this task with `run_once`, which will ensure that however + many times the task is called it only runs once. Simultaneous callers will + block until the running task completes, which means that any caller can be + sure that the task has completed after run_once returns. + """ + + INIT = 0 # The task has not yet started + RUNNING = 1 # The task is running somewhere + COMPLETE = 2 # The task has run somewhere + + # We generate a log message if we wait for a lock longer than + # LOG_AFTER_WAIT_SECS seconds + LOG_AFTER_WAIT_SECS = 30 + + def __init__(self, name): + """Create a new _OrderedTask. + + :param name: The name of this task. Used in log messages. + """ + + super(_OrderedTask, self).__init__() + + self._name = name + self._cond = threading.Condition() + self._state = self.INIT + + def _wait(self, condition, warn_msg): + """Wait while condition() is true. Write a log message if condition() + has not become false within LOG_AFTER_WAIT_SECS. + """ + with timeutils.StopWatch(duration=self.LOG_AFTER_WAIT_SECS) as w: + logged = False + while condition(): + wait = None if logged else w.leftover() + self._cond.wait(wait) + + if not logged and w.expired(): + LOG.warn(warn_msg) + LOG.debug(''.join(traceback.format_stack())) + # Only log once. After than we wait indefinitely without + # logging. + logged = True + + def wait_for_completion(self, caller): + """Wait until this task has completed. + + :param caller: The name of the task which is waiting. + """ + with self._cond: + self._wait(lambda: self._state != self.COMPLETE, + '%s has been waiting for %s to complete for longer ' + 'than %i seconds' + % (caller, self._name, self.LOG_AFTER_WAIT_SECS)) + + def run_once(self, fn): + """Run a task exactly once. If it is currently running in another + thread, wait for it to complete. If it has already run, return + immediately without running it again. + + :param fn: The task to run. It must be a callable taking no arguments. + It may optionally return another callable, which also takes + no arguments, which will be executed after completion has + been signaled to other threads. + """ + with self._cond: + if self._state == self.INIT: + self._state = self.RUNNING + # Note that nothing waits on RUNNING, so no need to notify + + # We need to release the condition lock before calling out to + # prevent deadlocks. Reacquire it immediately afterwards. + self._cond.release() + try: + post_fn = fn() + finally: + self._cond.acquire() + self._state = self.COMPLETE + self._cond.notify_all() + + if post_fn is not None: + # Release the condition lock before calling out to prevent + # deadlocks. Reacquire it immediately afterwards. + self._cond.release() + try: + post_fn() + finally: + self._cond.acquire() + elif self._state == self.RUNNING: + self._wait(lambda: self._state == self.RUNNING, + '%s has been waiting on another thread to complete ' + 'for longer than %i seconds' + % (self._name, self.LOG_AFTER_WAIT_SECS)) + + +class _OrderedTaskRunner(object): + """Mixin for a class which executes ordered tasks.""" + + def __init__(self, *args, **kwargs): + super(_OrderedTaskRunner, self).__init__(*args, **kwargs) + + # Get a list of methods on this object which have the _ordered + # attribute + self._tasks = [name + for (name, member) in inspect.getmembers(self) + if inspect.ismethod(member) and + getattr(member, '_ordered', False)] + self.init_task_states() + + def init_task_states(self): + # Note that we don't need to lock this. Once created, the _states dict + # is immutable. Get and set are (individually) atomic operations in + # Python, and we only set after the dict is fully created. + self._states = {task: _OrderedTask(task) for task in self._tasks} + + @staticmethod + def decorate_ordered(fn, state, after): + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + # Store the states we started with in case the state wraps on us + # while we're sleeping. We must wait and run_once in the same + # epoch. If the epoch ended while we were sleeping, run_once will + # safely do nothing. + states = self._states + + # Wait for the given preceding state to complete + if after is not None: + states[after].wait_for_completion(state) + + # Run this state + states[state].run_once(lambda: fn(self, *args, **kwargs)) + return wrapper + + +def ordered(after=None): + """A method which will be executed as an ordered task. The method will be + called exactly once, however many times it is called. If it is called + multiple times simultaneously it will only be called once, but all callers + will wait until execution is complete. + + If `after` is given, this method will not run until `after` has completed. + + :param after: Optionally, another method decorated with `ordered`. Wait for + the completion of `after` before executing this method. + """ + if after is not None: + after = after.__name__ + + def _ordered(fn): + # Set an attribute on the method so we can find it later + setattr(fn, '_ordered', True) + state = fn.__name__ + + return _OrderedTaskRunner.decorate_ordered(fn, state, after) + return _ordered + + +class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): """Server for handling messages. Connect a transport to a dispatcher that knows how to process the @@ -94,29 +258,18 @@ class MessageHandlingServer(service.ServiceBase): self.dispatcher = dispatcher self.executor = executor - # NOTE(sileht): we use a lock to protect the state change of the - # server, we don't want to call stop until the transport driver - # is fully started. Except for the blocking executor that have - # start() that doesn't return - if self.executor != "blocking": - self._state_cond = threading.Condition() - self._dummy_cond = False - else: - self._state_cond = _utils.DummyCondition() - self._dummy_cond = True - try: mgr = driver.DriverManager('oslo.messaging.executors', self.executor) except RuntimeError as ex: raise ExecutorLoadFailure(self.executor, ex) - else: - self._executor_cls = mgr.driver - self._executor_obj = None - self._running = False + + self._executor_cls = mgr.driver + self._executor_obj = None super(MessageHandlingServer, self).__init__() + @ordered() def start(self): """Start handling incoming messages. @@ -131,24 +284,21 @@ class MessageHandlingServer(service.ServiceBase): choose to dispatch messages in a new thread, coroutine or simply the current thread. """ - if self._executor_obj is not None: - return - with self._state_cond: - if self._executor_obj is not None: - return - try: - listener = self.dispatcher._listen(self.transport) - except driver_base.TransportDriverError as ex: - raise ServerListenError(self.target, ex) - self._executor_obj = self._executor_cls(self.conf, listener, - self.dispatcher) - self._executor_obj.start() - self._running = True - self._state_cond.notify_all() + try: + listener = self.dispatcher._listen(self.transport) + except driver_base.TransportDriverError as ex: + raise ServerListenError(self.target, ex) + executor = self._executor_cls(self.conf, listener, self.dispatcher) + executor.start() + self._executor_obj = executor if self.executor == 'blocking': - self._executor_obj.execute() + # N.B. This will be executed unlocked and unordered, so + # we can't rely on the value of self._executor_obj when this runs. + # We explicitly pass the local variable. + return lambda: executor.execute() + @ordered(after=start) def stop(self): """Stop handling incoming messages. @@ -157,12 +307,9 @@ class MessageHandlingServer(service.ServiceBase): some messages, and underlying driver resources associated to this server are still in use. See 'wait' for more details. """ - with self._state_cond: - if self._executor_obj is not None: - self._running = False - self._executor_obj.stop() - self._state_cond.notify_all() + self._executor_obj.stop() + @ordered(after=stop) def wait(self): """Wait for message processing to complete. @@ -173,37 +320,14 @@ class MessageHandlingServer(service.ServiceBase): Once it's finished, the underlying driver resources associated to this server are released (like closing useless network connections). """ - with self._state_cond: - if self._running: - LOG.warn(_LW("wait() should be called after stop() as it " - "waits for existing messages to finish " - "processing")) - w = timeutils.StopWatch() - w.start() - while self._running: - # NOTE(harlowja): 1.0 seconds was mostly chosen at - # random, but it seems like a reasonable value to - # use to avoid spamming the logs with to much - # information. - self._state_cond.wait(1.0) - if self._running and not self._dummy_cond: - LOG.warn( - _LW("wait() should have been called" - " after stop() as wait() waits for existing" - " messages to finish processing, it has" - " been %0.2f seconds and stop() still has" - " not been called"), w.elapsed()) - executor = self._executor_obj + try: + self._executor_obj.wait() + finally: + # Close listener connection after processing all messages + self._executor_obj.listener.cleanup() self._executor_obj = None - if executor is not None: - # We are the lucky calling thread to wait on the executor to - # actually finish. - try: - executor.wait() - finally: - # Close listener connection after processing all messages - executor.listener.cleanup() - executor = None + + self.init_task_states() def reset(self): """Reset service. diff --git a/oslo_messaging/tests/rpc/test_server.py b/oslo_messaging/tests/rpc/test_server.py index 258dacb24..1a2d2aa63 100644 --- a/oslo_messaging/tests/rpc/test_server.py +++ b/oslo_messaging/tests/rpc/test_server.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +import eventlet +import time import threading from oslo_config import cfg @@ -20,6 +22,7 @@ import testscenarios import mock import oslo_messaging +from oslo_messaging import server as server_module from oslo_messaging.tests import utils as test_utils load_tests = testscenarios.load_tests_apply_scenarios @@ -528,3 +531,210 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin): TestMultipleServers.generate_scenarios() + +class TestServerLocking(test_utils.BaseTestCase): + def setUp(self): + super(TestServerLocking, self).setUp(conf=cfg.ConfigOpts()) + + def _logmethod(name): + def method(self): + with self._lock: + self._calls.append(name) + return method + + executors = [] + class FakeExecutor(object): + def __init__(self, *args, **kwargs): + self._lock = threading.Lock() + self._calls = [] + self.listener = mock.MagicMock() + executors.append(self) + + start = _logmethod('start') + stop = _logmethod('stop') + wait = _logmethod('wait') + execute = _logmethod('execute') + self.executors = executors + + self.server = oslo_messaging.MessageHandlingServer(mock.Mock(), + mock.Mock()) + self.server._executor_cls = FakeExecutor + + def test_start_stop_wait(self): + # Test a simple execution of start, stop, wait in order + + thread = eventlet.spawn(self.server.start) + self.server.stop() + self.server.wait() + + self.assertEqual(len(self.executors), 1) + executor = self.executors[0] + self.assertEqual(executor._calls, + ['start', 'execute', 'stop', 'wait']) + self.assertTrue(executor.listener.cleanup.called) + + def test_reversed_order(self): + # Test that if we call wait, stop, start, these will be correctly + # reordered + + wait = eventlet.spawn(self.server.wait) + # This is non-deterministic, but there's not a great deal we can do + # about that + eventlet.sleep(0) + + stop = eventlet.spawn(self.server.stop) + eventlet.sleep(0) + + start = eventlet.spawn(self.server.start) + + self.server.wait() + + self.assertEqual(len(self.executors), 1) + executor = self.executors[0] + self.assertEqual(executor._calls, + ['start', 'execute', 'stop', 'wait']) + + def test_wait_for_running_task(self): + # Test that if 2 threads call a method simultaneously, both will wait, + # but only 1 will call the underlying executor method. + + start_event = threading.Event() + finish_event = threading.Event() + + running_event = threading.Event() + done_event = threading.Event() + + runner = [None] + class SteppingFakeExecutor(self.server._executor_cls): + def start(self): + # Tell the test which thread won the race + runner[0] = eventlet.getcurrent() + running_event.set() + + start_event.wait() + super(SteppingFakeExecutor, self).start() + done_event.set() + + finish_event.wait() + self.server._executor_cls = SteppingFakeExecutor + + start1 = eventlet.spawn(self.server.start) + start2 = eventlet.spawn(self.server.start) + + # Wait until one of the threads starts running + running_event.wait() + runner = runner[0] + waiter = start2 if runner == start1 else start2 + + waiter_finished = threading.Event() + waiter.link(lambda _: waiter_finished.set()) + + # At this point, runner is running start(), and waiter() is waiting for + # it to complete. runner has not yet logged anything. + self.assertEqual(1, len(self.executors)) + executor = self.executors[0] + + self.assertEqual(executor._calls, []) + self.assertFalse(waiter_finished.is_set()) + + # Let the runner log the call + start_event.set() + done_event.wait() + + # We haven't signalled completion yet, so execute shouldn't have run + self.assertEqual(executor._calls, ['start']) + self.assertFalse(waiter_finished.is_set()) + + # Let the runner complete + finish_event.set() + waiter.wait() + runner.wait() + + # Check that both threads have finished, start was only called once, + # and execute ran + self.assertTrue(waiter_finished.is_set()) + self.assertEqual(executor._calls, ['start', 'execute']) + + def test_state_wrapping(self): + # Test that we behave correctly if a thread waits, and the server state + # has wrapped when it it next scheduled + + # Ensure that if 2 threads wait for the completion of 'start', the + # first will wait until complete_event is signalled, but the second + # will continue + complete_event = threading.Event() + complete_waiting_callback = threading.Event() + + start_state = self.server._states['start'] + old_wait_for_completion = start_state.wait_for_completion + waited = [False] + def new_wait_for_completion(*args, **kwargs): + if not waited[0]: + waited[0] = True + complete_waiting_callback.set() + complete_event.wait() + old_wait_for_completion(*args, **kwargs) + start_state.wait_for_completion = new_wait_for_completion + + # thread1 will wait for start to complete until we signal it + thread1 = eventlet.spawn(self.server.stop) + thread1_finished = threading.Event() + thread1.link(lambda _: thread1_finished.set()) + + self.server.start() + complete_waiting_callback.wait() + + # The server should have started, but stop should not have been called + self.assertEqual(1, len(self.executors)) + self.assertEqual(self.executors[0]._calls, ['start', 'execute']) + self.assertFalse(thread1_finished.is_set()) + + self.server.stop() + self.server.wait() + + # We should have gone through all the states, and thread1 should still + # be waiting + self.assertEqual(1, len(self.executors)) + self.assertEqual(self.executors[0]._calls, ['start', 'execute', + 'stop', 'wait']) + self.assertFalse(thread1_finished.is_set()) + + # Start again + self.server.start() + + # We should now record 2 executors + self.assertEqual(2, len(self.executors)) + self.assertEqual(self.executors[0]._calls, ['start', 'execute', + 'stop', 'wait']) + self.assertEqual(self.executors[1]._calls, ['start', 'execute']) + self.assertFalse(thread1_finished.is_set()) + + # Allow thread1 to complete + complete_event.set() + thread1_finished.wait() + + # thread1 should now have finished, and stop should not have been + # called again on either the first or second executor + self.assertEqual(2, len(self.executors)) + self.assertEqual(self.executors[0]._calls, ['start', 'execute', + 'stop', 'wait']) + self.assertEqual(self.executors[1]._calls, ['start', 'execute']) + self.assertTrue(thread1_finished.is_set()) + + @mock.patch.object(server_module._OrderedTask, + 'LOG_AFTER_WAIT_SECS', 1) + @mock.patch.object(server_module, 'LOG') + def test_timeout_logging(self, mock_log): + # Test that we generate a log message if we wait longer than + # LOG_AFTER_WAIT_SECS + + log_event = threading.Event() + mock_log.warn.side_effect = lambda _: log_event.set() + + # Call stop without calling start. We should log a wait after 1 second + thread = eventlet.spawn(self.server.stop) + log_event.wait() + + # Redundant given that we already waited, but it's nice to assert + self.assertTrue(mock_log.warn.called) + thread.kill()