From 00d07f5205c758757bb854372c5576e62a5f57d6 Mon Sep 17 00:00:00 2001
From: Matthew Booth
Date: Mon, 19 Oct 2015 14:11:23 +0100
Subject: [PATCH] Robustify locking in MessageHandlingServer

This change formalises locking in MessageHandlingServer, which closes
several bugs:

* It adds locking for internal state when using the blocking executor,
  which closes a number of races.

* It does not hold a lock while executing server functions, which
  removes a potential cause of deadlock if the server does its own
  locking.

* It fixes a regression introduced in change
  gI3cfbe1bf02d451e379b1dcc23dacb0139c03be76. If multiple threads called
  wait() simultaneously, only 1 of them would wait and the others would
  return immediately, despite message handling not having completed.
  With this change only 1 will call the underlying wait, but all will
  wait on its completion.

Additionally, it introduces some new functionality:

* It allows the user to make calls in any order and it will ensure,
  with locking, that these will be reordered appropriately.

* The caller can pass a `timeout` argument to any server method, which
  will cause it to raise an exception if it waits too long.

* The caller can pass a `log_after` argument to any server method,
  which will cause it to emit a log message if it waits too long. It
  can also be used to disable logging when waiting is intentional.

We remove DummyCondition as it no longer has any users.

This change was originally committed as change
I9d516b208446963dcd80b75e2d5a2cecb1187efa, but was reverted as it
caused a hang in a Nova test. This was caused by the locking behaviour
for handling restarting a previously stopped server. The original patch
caused the state to 'wrap' immediately after the user called wait().
This caused a hang in tests which redundantly called stop() and wait()
multiple times.

This new patch only wraps when the user calls start() again. Callers
who do not restart a server will therefore not be affected by the
wrapping behaviour. Callers who do restart a server will be no worse
off than before. We add a deprecation warning on restart, as this
operation is inherently racy with this API and there is a simple, safe
alternative.

This new version has been successfully tested against the unit and
functional tests of nova, cinder, glance, and ceilometer.
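As a usage sketch only (not part of the patch itself): with this change a
caller can bound how long the ordering logic blocks and control the hang
warning. The transport URL, target and endpoint below are illustrative
placeholders, and the kwargs shown are the new `log_after`/`timeout`
arguments added by this patch.

import oslo_messaging
from oslo_config import cfg
from oslo_messaging import server as server_module


class EchoEndpoint(object):
    def echo(self, ctxt, arg):
        return arg


transport = oslo_messaging.get_transport(cfg.CONF, url='fake://')
target = oslo_messaging.Target(topic='demo', server='server1')
server = oslo_messaging.get_rpc_server(transport, target, [EchoEndpoint()],
                                       executor='eventlet')

server.start()
# ... handle requests ...

try:
    # Emit a 'Possible hang' warning only after 10s of waiting, and give up
    # with TaskTimeout after 30s if the preceding state is still incomplete.
    server.stop(log_after=10, timeout=30)
    # log_after=0 disables the warning when waiting is intentional.
    server.wait(log_after=0, timeout=30)
except server_module.TaskTimeout:
    pass  # the wait exceeded the timeout; handle as appropriate

# Restarting a stopped server is deprecated by this change; instantiate a
# new MessageHandlingServer instead.
server = oslo_messaging.get_rpc_server(transport, target, [EchoEndpoint()],
                                       executor='eventlet')
server.start()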
Change-Id: Ic79f87e7b069c1f62d6121486fd6cafd732fdde7
---
 oslo_messaging/_utils.py                |  23 --
 oslo_messaging/server.py                | 352 +++++++++++++++++++-----
 oslo_messaging/tests/rpc/test_server.py | 302 ++++++++++++++++++++
 3 files changed, 585 insertions(+), 92 deletions(-)

diff --git a/oslo_messaging/_utils.py b/oslo_messaging/_utils.py
index 1bb20b089..cec94bb48 100644
--- a/oslo_messaging/_utils.py
+++ b/oslo_messaging/_utils.py
@@ -116,29 +116,6 @@ def fetch_current_thread_functor():
     return lambda: threading.current_thread()
 
 
-class DummyCondition(object):
-    def acquire(self):
-        pass
-
-    def notify(self):
-        pass
-
-    def notify_all(self):
-        pass
-
-    def wait(self, timeout=None):
-        pass
-
-    def release(self):
-        pass
-
-    def __enter__(self):
-        self.acquire()
-
-    def __exit__(self, type, value, traceback):
-        self.release()
-
-
 class DummyLock(object):
     def acquire(self):
         pass
diff --git a/oslo_messaging/server.py b/oslo_messaging/server.py
index 491ccbf52..804bcf105 100644
--- a/oslo_messaging/server.py
+++ b/oslo_messaging/server.py
@@ -23,20 +23,25 @@ __all__ = [
     'ServerListenError',
 ]
 
+import functools
+import inspect
 import logging
 import threading
+import traceback
 
 from oslo_service import service
 from oslo_utils import timeutils
 from stevedore import driver
 
 from oslo_messaging._drivers import base as driver_base
-from oslo_messaging._i18n import _LW
-from oslo_messaging import _utils
 from oslo_messaging import exceptions
 
 LOG = logging.getLogger(__name__)
 
+# The default number of seconds of waiting after which we will emit a log
+# message
+DEFAULT_LOG_AFTER = 30
+
 
 class MessagingServerError(exceptions.MessagingException):
     """Base class for all MessageHandlingServer exceptions."""
@@ -62,7 +67,223 @@ class ServerListenError(MessagingServerError):
         self.ex = ex
 
 
-class MessageHandlingServer(service.ServiceBase):
+class TaskTimeout(MessagingServerError):
+    """Raised if we timed out waiting for a task to complete."""
+
+
+class _OrderedTask(object):
+    """A task which must be executed in a particular order.
+
+    A caller may wait for this task to complete by calling
+    `wait_for_completion`.
+
+    A caller may run this task with `run_once`, which will ensure that however
+    many times the task is called it only runs once. Simultaneous callers will
+    block until the running task completes, which means that any caller can be
+    sure that the task has completed after run_once returns.
+    """
+
+    INIT = 0      # The task has not yet started
+    RUNNING = 1   # The task is running somewhere
+    COMPLETE = 2  # The task has run somewhere
+
+    def __init__(self, name):
+        """Create a new _OrderedTask.
+
+        :param name: The name of this task. Used in log messages.
+        """
+        super(_OrderedTask, self).__init__()
+
+        self._name = name
+        self._cond = threading.Condition()
+        self._state = self.INIT
+
+    def _wait(self, condition, msg, log_after, timeout_timer):
+        """Wait while condition() is true. Write a log message if condition()
+        has not become false within `log_after` seconds. Raise TaskTimeout if
+        timeout_timer expires while waiting.
+        """
+
+        log_timer = None
+        if log_after != 0:
+            log_timer = timeutils.StopWatch(duration=log_after)
+            log_timer.start()
+
+        while condition():
+            if log_timer is not None and log_timer.expired():
+                LOG.warn('Possible hang: %s' % msg)
+                LOG.debug(''.join(traceback.format_stack()))
+                # Only log once. After that we wait indefinitely without
+                # logging.
+                log_timer = None
+
+            if timeout_timer is not None and timeout_timer.expired():
+                raise TaskTimeout(msg)
+
+            timeouts = []
+            if log_timer is not None:
+                timeouts.append(log_timer.leftover())
+            if timeout_timer is not None:
+                timeouts.append(timeout_timer.leftover())
+
+            wait = None
+            if timeouts:
+                wait = min(timeouts)
+            self._cond.wait(wait)
+
+    @property
+    def complete(self):
+        return self._state == self.COMPLETE
+
+    def wait_for_completion(self, caller, log_after, timeout_timer):
+        """Wait until this task has completed.
+
+        :param caller: The name of the task which is waiting.
+        :param log_after: Emit a log message if waiting longer than
+                          `log_after` seconds.
+        :param timeout_timer: Raise TaskTimeout if StopWatch object
+                              `timeout_timer` expires while waiting.
+        """
+        with self._cond:
+            msg = '%s is waiting for %s to complete' % (caller, self._name)
+            self._wait(lambda: not self.complete,
+                       msg, log_after, timeout_timer)
+
+    def run_once(self, fn, log_after, timeout_timer):
+        """Run a task exactly once. If it is currently running in another
+        thread, wait for it to complete. If it has already run, return
+        immediately without running it again.
+
+        :param fn: The task to run. It must be a callable taking no arguments.
+                   It may optionally return another callable, which also takes
+                   no arguments, which will be executed after completion has
+                   been signaled to other threads.
+        :param log_after: Emit a log message if waiting longer than
+                          `log_after` seconds.
+        :param timeout_timer: Raise TaskTimeout if StopWatch object
+                              `timeout_timer` expires while waiting.
+        """
+        with self._cond:
+            if self._state == self.INIT:
+                self._state = self.RUNNING
+                # Note that nothing waits on RUNNING, so no need to notify
+
+                # We need to release the condition lock before calling out to
+                # prevent deadlocks. Reacquire it immediately afterwards.
+                self._cond.release()
+                try:
+                    post_fn = fn()
+                finally:
+                    self._cond.acquire()
+                    self._state = self.COMPLETE
+                    self._cond.notify_all()
+
+                if post_fn is not None:
+                    # Release the condition lock before calling out to prevent
+                    # deadlocks. Reacquire it immediately afterwards.
+                    self._cond.release()
+                    try:
+                        post_fn()
+                    finally:
+                        self._cond.acquire()
+            elif self._state == self.RUNNING:
+                msg = ('%s is waiting for another thread to complete'
+                       % self._name)
+                self._wait(lambda: self._state == self.RUNNING,
+                           msg, log_after, timeout_timer)
+
+
+class _OrderedTaskRunner(object):
+    """Mixin for a class which executes ordered tasks."""
+
+    def __init__(self, *args, **kwargs):
+        super(_OrderedTaskRunner, self).__init__(*args, **kwargs)
+
+        # Get a list of methods on this object which have the _ordered
+        # attribute
+        self._tasks = [name
+                       for (name, member) in inspect.getmembers(self)
+                       if inspect.ismethod(member) and
+                       getattr(member, '_ordered', False)]
+        self.reset_states()
+
+        self._reset_lock = threading.Lock()
+
+    def reset_states(self):
+        # Create new task states for tasks in reset
+        self._states = {task: _OrderedTask(task) for task in self._tasks}
+
+    @staticmethod
+    def decorate_ordered(fn, state, after, reset_after):
+
+        @functools.wraps(fn)
+        def wrapper(self, *args, **kwargs):
+            # If the reset_after state has already completed, reset state so
+            # we can run again.
+            # NOTE(mdbooth): This is ugly and requires external locking to be
+            # deterministic when using multiple threads. Consider a thread that
+            # does: server.stop(), server.wait(). If another thread causes a
+            # reset between stop() and wait(), this will not have the intended
+            # behaviour. It is safe without external locking, if the caller
+            # instantiates a new object.
+            with self._reset_lock:
+                if (reset_after is not None and
+                        self._states[reset_after].complete):
+                    self.reset_states()
+
+            # Store the states we started with in case the state wraps on us
+            # while we're sleeping. We must wait and run_once in the same
+            # epoch. If the epoch ended while we were sleeping, run_once will
+            # safely do nothing.
+            states = self._states
+
+            log_after = kwargs.pop('log_after', DEFAULT_LOG_AFTER)
+            timeout = kwargs.pop('timeout', None)
+
+            timeout_timer = None
+            if timeout is not None:
+                timeout_timer = timeutils.StopWatch(duration=timeout)
+                timeout_timer.start()
+
+            # Wait for the given preceding state to complete
+            if after is not None:
+                states[after].wait_for_completion(state,
+                                                  log_after, timeout_timer)
+
+            # Run this state
+            states[state].run_once(lambda: fn(self, *args, **kwargs),
+                                   log_after, timeout_timer)
+        return wrapper
+
+
+def ordered(after=None, reset_after=None):
+    """A method which will be executed as an ordered task. The method will be
+    called exactly once, however many times it is called. If it is called
+    multiple times simultaneously it will only be called once, but all callers
+    will wait until execution is complete.
+
+    If `after` is given, this method will not run until `after` has completed.
+
+    If `reset_after` is given and the target method has completed, allow this
+    task to run again by resetting all task states.
+
+    :param after: Optionally, the name of another `ordered` method. Wait for
+                  the completion of `after` before executing this method.
+    :param reset_after: Optionally, the name of another `ordered` method. Reset
+                        all states when calling this method if `reset_after`
+                        has completed.
+    """
+    def _ordered(fn):
+        # Set an attribute on the method so we can find it later
+        setattr(fn, '_ordered', True)
+        state = fn.__name__
+
+        return _OrderedTaskRunner.decorate_ordered(fn, state, after,
+                                                   reset_after)
+    return _ordered
+
+
+class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
     """Server for handling messages.
 
     Connect a transport to a dispatcher that knows how to process the
@@ -94,29 +315,20 @@ class MessageHandlingServer(service.ServiceBase):
         self.dispatcher = dispatcher
         self.executor = executor
 
-        # NOTE(sileht): we use a lock to protect the state change of the
-        # server, we don't want to call stop until the transport driver
-        # is fully started. Except for the blocking executor that have
-        # start() that doesn't return
-        if self.executor != "blocking":
-            self._state_cond = threading.Condition()
-            self._dummy_cond = False
-        else:
-            self._state_cond = _utils.DummyCondition()
-            self._dummy_cond = True
-
         try:
             mgr = driver.DriverManager('oslo.messaging.executors',
                                        self.executor)
         except RuntimeError as ex:
             raise ExecutorLoadFailure(self.executor, ex)
-        else:
-            self._executor_cls = mgr.driver
-            self._executor_obj = None
-        self._running = False
+
+        self._executor_cls = mgr.driver
+        self._executor_obj = None
+
+        self._started = False
 
         super(MessageHandlingServer, self).__init__()
 
+    @ordered(reset_after='stop')
     def start(self):
         """Start handling incoming messages.
 
@@ -130,25 +342,39 @@ class MessageHandlingServer(service.ServiceBase):
         registering a callback with an event loop. Similarly, the executor
         may choose to dispatch messages in a new thread, coroutine or simply
         the current thread.
+
+        :param log_after: Emit a log message if waiting longer than `log_after`
+                          seconds to run this task. If set to zero, no log
+                          message will be emitted. Defaults to 30 seconds.
+        :type log_after: int
+        :param timeout: Raise `TaskTimeout` if the task has to wait longer than
+                        `timeout` seconds before executing.
+        :type timeout: int
         """
-        if self._executor_obj is not None:
-            return
-        with self._state_cond:
-            if self._executor_obj is not None:
-                return
-            try:
-                listener = self.dispatcher._listen(self.transport)
-            except driver_base.TransportDriverError as ex:
-                raise ServerListenError(self.target, ex)
-            self._executor_obj = self._executor_cls(self.conf, listener,
-                                                    self.dispatcher)
-            self._executor_obj.start()
-            self._running = True
-            self._state_cond.notify_all()
+        # Warn that restarting will be deprecated
+        if self._started:
+            LOG.warn('Restarting a MessageHandlingServer is inherently racy. '
+                     'It is deprecated, and will become a noop in a future '
+                     'release of oslo.messaging. If you need to restart '
+                     'MessageHandlingServer you should instantiate a new '
+                     'object.')
+        self._started = True
+
+        try:
+            listener = self.dispatcher._listen(self.transport)
+        except driver_base.TransportDriverError as ex:
+            raise ServerListenError(self.target, ex)
+        executor = self._executor_cls(self.conf, listener, self.dispatcher)
+        executor.start()
+        self._executor_obj = executor
 
         if self.executor == 'blocking':
-            self._executor_obj.execute()
+            # N.B. This will be executed unlocked and unordered, so
+            # we can't rely on the value of self._executor_obj when this runs.
+            # We explicitly pass the local variable.
+            return lambda: executor.execute()
 
+    @ordered(after='start')
     def stop(self):
         """Stop handling incoming messages.
 
@@ -156,13 +382,18 @@ class MessageHandlingServer(service.ServiceBase):
         the server. However, the server may still be in the process of
         handling some messages, and underlying driver resources associated to
         this server are still in use. See 'wait' for more details.
-        """
-        with self._state_cond:
-            if self._executor_obj is not None:
-                self._running = False
-                self._executor_obj.stop()
-                self._state_cond.notify_all()
+        :param log_after: Emit a log message if waiting longer than `log_after`
+                          seconds to run this task. If set to zero, no log
+                          message will be emitted. Defaults to 30 seconds.
+        :type log_after: int
+        :param timeout: Raise `TaskTimeout` if the task has to wait longer than
+                        `timeout` seconds before executing.
+        :type timeout: int
+        """
+        self._executor_obj.stop()
 
+    @ordered(after='stop')
     def wait(self):
         """Wait for message processing to complete.
 
@@ -172,38 +403,21 @@ class MessageHandlingServer(service.ServiceBase):
 
         Once it's finished, the underlying driver resources associated to this
         server are released (like closing useless network connections).
+
+        :param log_after: Emit a log message if waiting longer than `log_after`
+                          seconds to run this task. If set to zero, no log
+                          message will be emitted. Defaults to 30 seconds.
+        :type log_after: int
+        :param timeout: Raise `TaskTimeout` if the task has to wait longer than
+                        `timeout` seconds before executing.
+        :type timeout: int
         """
-        with self._state_cond:
-            if self._running:
-                LOG.warn(_LW("wait() should be called after stop() as it "
-                             "waits for existing messages to finish "
-                             "processing"))
-                w = timeutils.StopWatch()
-                w.start()
-                while self._running:
-                    # NOTE(harlowja): 1.0 seconds was mostly chosen at
-                    # random, but it seems like a reasonable value to
-                    # use to avoid spamming the logs with to much
-                    # information.
-                    self._state_cond.wait(1.0)
-                if self._running and not self._dummy_cond:
-                    LOG.warn(
-                        _LW("wait() should have been called"
-                            " after stop() as wait() waits for existing"
-                            " messages to finish processing, it has"
-                            " been %0.2f seconds and stop() still has"
-                            " not been called"), w.elapsed())
-            executor = self._executor_obj
+        try:
+            self._executor_obj.wait()
+        finally:
+            # Close listener connection after processing all messages
+            self._executor_obj.listener.cleanup()
             self._executor_obj = None
-        if executor is not None:
-            # We are the lucky calling thread to wait on the executor to
-            # actually finish.
-            try:
-                executor.wait()
-            finally:
-                # Close listener connection after processing all messages
-                executor.listener.cleanup()
-                executor = None
 
     def reset(self):
         """Reset service.
diff --git a/oslo_messaging/tests/rpc/test_server.py b/oslo_messaging/tests/rpc/test_server.py
index 258dacb24..846ea86e2 100644
--- a/oslo_messaging/tests/rpc/test_server.py
+++ b/oslo_messaging/tests/rpc/test_server.py
@@ -13,6 +13,8 @@
 # License for the specific language governing permissions and limitations
 # under the License.
 
+import eventlet
+import time
 import threading
 
 from oslo_config import cfg
@@ -20,6 +22,7 @@ import testscenarios
 
 import mock
 import oslo_messaging
+from oslo_messaging import server as server_module
 from oslo_messaging.tests import utils as test_utils
 
 load_tests = testscenarios.load_tests_apply_scenarios
@@ -528,3 +531,302 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
 
 
 TestMultipleServers.generate_scenarios()
+
+
+class TestServerLocking(test_utils.BaseTestCase):
+    def setUp(self):
+        super(TestServerLocking, self).setUp(conf=cfg.ConfigOpts())
+
+        def _logmethod(name):
+            def method(self):
+                with self._lock:
+                    self._calls.append(name)
+            return method
+
+        executors = []
+
+        class FakeExecutor(object):
+            def __init__(self, *args, **kwargs):
+                self._lock = threading.Lock()
+                self._calls = []
+                self.listener = mock.MagicMock()
+                executors.append(self)
+
+            start = _logmethod('start')
+            stop = _logmethod('stop')
+            wait = _logmethod('wait')
+            execute = _logmethod('execute')
+        self.executors = executors
+
+        self.server = oslo_messaging.MessageHandlingServer(mock.Mock(),
+                                                           mock.Mock())
+        self.server._executor_cls = FakeExecutor
+
+    def test_start_stop_wait(self):
+        # Test a simple execution of start, stop, wait in order
+
+        thread = eventlet.spawn(self.server.start)
+        self.server.stop()
+        self.server.wait()
+
+        self.assertEqual(len(self.executors), 1)
+        executor = self.executors[0]
+        self.assertEqual(executor._calls,
+                         ['start', 'execute', 'stop', 'wait'])
+        self.assertTrue(executor.listener.cleanup.called)
+
+    def test_reversed_order(self):
+        # Test that if we call wait, stop, start, these will be correctly
+        # reordered
+
+        wait = eventlet.spawn(self.server.wait)
+        # This is non-deterministic, but there's not a great deal we can do
+        # about that
+        eventlet.sleep(0)
+
+        stop = eventlet.spawn(self.server.stop)
+        eventlet.sleep(0)
+
+        start = eventlet.spawn(self.server.start)
+
+        self.server.wait()
+
+        self.assertEqual(len(self.executors), 1)
+        executor = self.executors[0]
+        self.assertEqual(executor._calls,
+                         ['start', 'execute', 'stop', 'wait'])
+
+    def test_wait_for_running_task(self):
+        # Test that if 2 threads call a method simultaneously, both will wait,
+        # but only 1 will call the underlying executor method.
+
+        start_event = threading.Event()
+        finish_event = threading.Event()
+
+        running_event = threading.Event()
+        done_event = threading.Event()
+
+        runner = [None]
+
+        class SteppingFakeExecutor(self.server._executor_cls):
+            def start(self):
+                # Tell the test which thread won the race
+                runner[0] = eventlet.getcurrent()
+                running_event.set()
+
+                start_event.wait()
+                super(SteppingFakeExecutor, self).start()
+                done_event.set()
+
+                finish_event.wait()
+        self.server._executor_cls = SteppingFakeExecutor
+
+        start1 = eventlet.spawn(self.server.start)
+        start2 = eventlet.spawn(self.server.start)
+
+        # Wait until one of the threads starts running
+        running_event.wait()
+        runner = runner[0]
+        waiter = start2 if runner == start1 else start1
+
+        waiter_finished = threading.Event()
+        waiter.link(lambda _: waiter_finished.set())
+
+        # At this point, runner is running start(), and waiter is waiting for
+        # it to complete. runner has not yet logged anything.
+        self.assertEqual(1, len(self.executors))
+        executor = self.executors[0]
+
+        self.assertEqual(executor._calls, [])
+        self.assertFalse(waiter_finished.is_set())
+
+        # Let the runner log the call
+        start_event.set()
+        done_event.wait()
+
+        # We haven't signalled completion yet, so execute shouldn't have run
+        self.assertEqual(executor._calls, ['start'])
+        self.assertFalse(waiter_finished.is_set())
+
+        # Let the runner complete
+        finish_event.set()
+        waiter.wait()
+        runner.wait()
+
+        # Check that both threads have finished, start was only called once,
+        # and execute ran
+        self.assertTrue(waiter_finished.is_set())
+        self.assertEqual(executor._calls, ['start', 'execute'])
+
+    def test_start_stop_wait_stop_wait(self):
+        # Test that we behave correctly when calling stop/wait more than once.
+        # Subsequent calls should be noops.
+
+        self.server.start()
+        self.server.stop()
+        self.server.wait()
+        self.server.stop()
+        self.server.wait()
+
+        self.assertEqual(len(self.executors), 1)
+        executor = self.executors[0]
+        self.assertEqual(executor._calls,
+                         ['start', 'execute', 'stop', 'wait'])
+        self.assertTrue(executor.listener.cleanup.called)
+
+    def test_state_wrapping(self):
+        # Test that we behave correctly if a thread waits, and the server
+        # state has wrapped when it is next scheduled
+
+        # Ensure that if 2 threads wait for the completion of 'start', the
+        # first will wait until complete_event is signalled, but the second
+        # will continue
+        complete_event = threading.Event()
+        complete_waiting_callback = threading.Event()
+
+        start_state = self.server._states['start']
+        old_wait_for_completion = start_state.wait_for_completion
+        waited = [False]
+
+        def new_wait_for_completion(*args, **kwargs):
+            if not waited[0]:
+                waited[0] = True
+                complete_waiting_callback.set()
+                complete_event.wait()
+            old_wait_for_completion(*args, **kwargs)
+        start_state.wait_for_completion = new_wait_for_completion
+
+        # thread1 will wait for start to complete until we signal it
+        thread1 = eventlet.spawn(self.server.stop)
+        thread1_finished = threading.Event()
+        thread1.link(lambda _: thread1_finished.set())
+
+        self.server.start()
+        complete_waiting_callback.wait()
+
+        # The server should have started, but stop should not have been called
+        self.assertEqual(1, len(self.executors))
+        self.assertEqual(self.executors[0]._calls, ['start', 'execute'])
+        self.assertFalse(thread1_finished.is_set())
+
+        self.server.stop()
+        self.server.wait()
+
+        # We should have gone through all the states, and thread1 should still
+        # be waiting
+        self.assertEqual(1, len(self.executors))
+        self.assertEqual(self.executors[0]._calls, ['start', 'execute',
+                                                    'stop', 'wait'])
+        self.assertFalse(thread1_finished.is_set())
+
+        # Start again
+        self.server.start()
+
+        # We should now record 2 executors
+        self.assertEqual(2, len(self.executors))
+        self.assertEqual(self.executors[0]._calls, ['start', 'execute',
+                                                    'stop', 'wait'])
+        self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
+        self.assertFalse(thread1_finished.is_set())
+
+        # Allow thread1 to complete
+        complete_event.set()
+        thread1_finished.wait()
+
+        # thread1 should now have finished, and stop should not have been
+        # called again on either the first or second executor
+        self.assertEqual(2, len(self.executors))
+        self.assertEqual(self.executors[0]._calls, ['start', 'execute',
+                                                    'stop', 'wait'])
+        self.assertEqual(self.executors[1]._calls, ['start', 'execute'])
+        self.assertTrue(thread1_finished.is_set())
+
+    @mock.patch.object(server_module, 'DEFAULT_LOG_AFTER', 1)
+    @mock.patch.object(server_module, 'LOG')
+    def test_logging(self, mock_log):
+        # Test that we generate a log message if we wait longer than
+        # DEFAULT_LOG_AFTER
+
+        log_event = threading.Event()
+        mock_log.warn.side_effect = lambda _: log_event.set()
+
+        # Call stop without calling start. We should log a wait after 1 second
+        thread = eventlet.spawn(self.server.stop)
+        log_event.wait()
+
+        # Redundant given that we already waited, but it's nice to assert
+        self.assertTrue(mock_log.warn.called)
+        thread.kill()
+
+    @mock.patch.object(server_module, 'LOG')
+    def test_logging_explicit_wait(self, mock_log):
+        # Test that we generate a log message if we wait longer than
+        # the number of seconds passed to log_after
+
+        log_event = threading.Event()
+        mock_log.warn.side_effect = lambda _: log_event.set()
+
+        # Call stop without calling start. We should log a wait after 1 second
+        thread = eventlet.spawn(self.server.stop, log_after=1)
+        log_event.wait()
+
+        # Redundant given that we already waited, but it's nice to assert
+        self.assertTrue(mock_log.warn.called)
+        thread.kill()
+
+    @mock.patch.object(server_module, 'LOG')
+    def test_logging_with_timeout(self, mock_log):
+        # Test that we log a message after log_after seconds if we've also
+        # specified an absolute timeout
+
+        log_event = threading.Event()
+        mock_log.warn.side_effect = lambda _: log_event.set()
+
+        # Call stop without calling start. We should log a wait after 1 second
+        thread = eventlet.spawn(self.server.stop, log_after=1, timeout=2)
+        log_event.wait()
+
+        # Redundant given that we already waited, but it's nice to assert
+        self.assertTrue(mock_log.warn.called)
+        thread.kill()
+
+    def test_timeout_wait(self):
+        # Test that we will eventually timeout when passing the timeout option
+        # if a preceding condition is not satisfied.
+
+        self.assertRaises(server_module.TaskTimeout,
+                          self.server.stop, timeout=1)
+
+    def test_timeout_running(self):
+        # Test that we will eventually timeout if we're waiting for another
+        # thread to complete this task
+
+        # Start the server, which will also instantiate an executor
+        self.server.start()
+
+        stop_called = threading.Event()
+
+        # Patch the executor's stop method to be very slow
+        def slow_stop():
+            stop_called.set()
+            eventlet.sleep(10)
+        self.executors[0].stop = slow_stop
+
+        # Call stop in a new thread
+        thread = eventlet.spawn(self.server.stop)
+
+        # Wait until the thread is in the slow stop method
+        stop_called.wait()
+
+        # Call stop again in the main thread with a timeout
+        self.assertRaises(server_module.TaskTimeout,
+                          self.server.stop, timeout=1)
+        thread.kill()
+
+    @mock.patch.object(server_module, 'LOG')
+    def test_log_after_zero(self, mock_log):
+        # Test that we do not log a message after DEFAULT_LOG_AFTER if the
+        # caller gave log_after=0
+
+        # Call stop without calling start.
+        self.assertRaises(server_module.TaskTimeout,
+                          self.server.stop, log_after=0, timeout=2)
+
+        # We timed out. Ensure we didn't log anything.
+        self.assertFalse(mock_log.warn.called)
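
The heart of the change is the run-exactly-once, everyone-else-waits
behaviour of _OrderedTask.run_once in server.py above. The following
standalone sketch (illustrative names, simplified: no log_after, timeout or
post_fn handling, and not code from this patch) shows the same
condition-variable pattern:

import threading


class RunOnce(object):
    """Simplified sketch of the _OrderedTask.run_once pattern."""

    INIT, RUNNING, COMPLETE = range(3)

    def __init__(self):
        self._cond = threading.Condition()
        self._state = self.INIT

    def run_once(self, fn):
        with self._cond:
            if self._state == self.INIT:
                self._state = self.RUNNING
                # Drop the lock while calling out so fn() cannot deadlock
                # against other callers of this object.
                self._cond.release()
                try:
                    fn()
                finally:
                    self._cond.acquire()
                    self._state = self.COMPLETE
                    self._cond.notify_all()
            else:
                # Another thread ran, or is running, the task: wait for it.
                while self._state != self.COMPLETE:
                    self._cond.wait()


if __name__ == '__main__':
    once = RunOnce()
    calls = []
    threads = [threading.Thread(target=once.run_once,
                                args=(lambda: calls.append(1),))
               for _ in range(5)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert calls == [1]  # fn ran exactly once; all five callers returned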