diff --git a/oslo/messaging/_executors/base.py b/oslo/messaging/_executors/base.py index 897097d8c..8019017b3 100644 --- a/oslo/messaging/_executors/base.py +++ b/oslo/messaging/_executors/base.py @@ -13,41 +13,17 @@ # under the License. import abc -import logging -import sys import six -from oslo import messaging - -_LOG = logging.getLogger(__name__) - @six.add_metaclass(abc.ABCMeta) class ExecutorBase(object): - def __init__(self, conf, listener, callback): + def __init__(self, conf, listener, dispatcher): self.conf = conf self.listener = listener - self.callback = callback - - def _dispatch(self, incoming): - try: - incoming.reply(self.callback(incoming.ctxt, incoming.message)) - except messaging.ExpectedException as e: - _LOG.debug('Expected exception during message handling (%s)' % - e.exc_info[1]) - incoming.reply(failure=e.exc_info, log_failure=False) - except Exception as e: - # sys.exc_info() is deleted by LOG.exception(). - exc_info = sys.exc_info() - _LOG.error('Exception during message handling: %s', e, - exc_info=exc_info) - incoming.reply(failure=exc_info) - # NOTE(dhellmann): Remove circular object reference - # between the current stack frame and the traceback in - # exc_info. - del exc_info + self.dispatcher = dispatcher @abc.abstractmethod def start(self): diff --git a/oslo/messaging/_executors/impl_blocking.py b/oslo/messaging/_executors/impl_blocking.py index 55506a133..8e463a0c7 100644 --- a/oslo/messaging/_executors/impl_blocking.py +++ b/oslo/messaging/_executors/impl_blocking.py @@ -29,14 +29,15 @@ class BlockingExecutor(base.ExecutorBase): for simple demo programs. """ - def __init__(self, conf, listener, callback): - super(BlockingExecutor, self).__init__(conf, listener, callback) + def __init__(self, conf, listener, dispatcher): + super(BlockingExecutor, self).__init__(conf, listener, dispatcher) self._running = False def start(self): self._running = True while self._running: - self._dispatch(self.listener.poll()) + with self.dispatcher(self.listener.poll()) as callback: + callback() def stop(self): self._running = False diff --git a/oslo/messaging/_executors/impl_eventlet.py b/oslo/messaging/_executors/impl_eventlet.py index 5c6b8f818..3b294584c 100644 --- a/oslo/messaging/_executors/impl_eventlet.py +++ b/oslo/messaging/_executors/impl_eventlet.py @@ -13,6 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. +import sys + import eventlet from eventlet import greenpool import greenlet @@ -29,6 +31,33 @@ _eventlet_opts = [ ] +def spawn_with(ctxt, pool): + """This is the equivalent of a with statement + but with the content of the BLOCK statement executed + into a greenthread + + exception path grab from: + http://www.python.org/dev/peps/pep-0343/ + """ + + def complete(thread, exit): + exc = True + try: + try: + thread.wait() + except Exception: + exc = False + if not exit(*sys.exc_info()): + raise + finally: + if exc: + exit(None, None, None) + + callback = ctxt.__enter__() + thread = pool.spawn(callback) + thread.link(complete, ctxt.__exit__) + + class EventletExecutor(base.ExecutorBase): """A message executor which integrates with eventlet. @@ -40,8 +69,8 @@ class EventletExecutor(base.ExecutorBase): method waits for all message dispatch greenthreads to complete. """ - def __init__(self, conf, listener, callback): - super(EventletExecutor, self).__init__(conf, listener, callback) + def __init__(self, conf, listener, dispatcher): + super(EventletExecutor, self).__init__(conf, listener, dispatcher) self.conf.register_opts(_eventlet_opts) self._thread = None self._greenpool = greenpool.GreenPool(self.conf.rpc_thread_pool_size) @@ -55,7 +84,8 @@ class EventletExecutor(base.ExecutorBase): try: while True: incoming = self.listener.poll() - self._greenpool.spawn_n(self._dispatch, incoming) + spawn_with(ctxt=self.dispatcher(incoming), + pool=self._greenpool) except greenlet.GreenletExit: return diff --git a/oslo/messaging/notify/dispatcher.py b/oslo/messaging/notify/dispatcher.py index f36c3925f..a79cc29e8 100644 --- a/oslo/messaging/notify/dispatcher.py +++ b/oslo/messaging/notify/dispatcher.py @@ -14,8 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. +import contextlib import itertools import logging +import sys from oslo.messaging import localcontext from oslo.messaging import serializer as msg_serializer @@ -55,7 +57,25 @@ class NotificationDispatcher(object): def _listen(self, transport): return transport._listen_for_notifications(self._targets_priorities) - def __call__(self, ctxt, message): + @contextlib.contextmanager + def __call__(self, incoming): + yield lambda: self._dispatch_and_handle_error(incoming) + + def _dispatch_and_handle_error(self, incoming): + """Dispatch a notification message to the appropriate endpoint method. + + :param incoming: the incoming notification message + :type ctxt: IncomingMessage + """ + try: + self._dispatch(incoming.ctxt, incoming.message) + except Exception: + # sys.exc_info() is deleted by LOG.exception(). + exc_info = sys.exc_info() + LOG.error('Exception during message handling', + exc_info=exc_info) + + def _dispatch(self, ctxt, message): """Dispatch an RPC message to the appropriate endpoint method. :param ctxt: the request context diff --git a/oslo/messaging/rpc/dispatcher.py b/oslo/messaging/rpc/dispatcher.py index a5bad113d..ecb1e8795 100644 --- a/oslo/messaging/rpc/dispatcher.py +++ b/oslo/messaging/rpc/dispatcher.py @@ -21,8 +21,13 @@ __all__ = [ 'RPCDispatcher', 'RPCDispatcherError', 'UnsupportedVersion', + 'ExpectedException', ] +import contextlib +import logging +import sys + import six from oslo.messaging import _utils as utils @@ -31,6 +36,19 @@ from oslo.messaging import serializer as msg_serializer from oslo.messaging import server as msg_server from oslo.messaging import target as msg_target +LOG = logging.getLogger(__name__) + + +class ExpectedException(Exception): + """Encapsulates an expected exception raised by an RPC endpoint + + Merely instantiating this exception records the current exception + information, which will be passed back to the RPC client without + exceptional logging. + """ + def __init__(self): + self.exc_info = sys.exc_info() + class RPCDispatcherError(msg_server.MessagingServerError): "A base class for all RPC dispatcher exceptions." @@ -96,7 +114,7 @@ class RPCDispatcher(object): endpoint_version = target.version or '1.0' return utils.version_is_compatible(endpoint_version, version) - def _dispatch(self, endpoint, method, ctxt, args): + def _do_dispatch(self, endpoint, method, ctxt, args): ctxt = self.serializer.deserialize_context(ctxt) new_args = dict() for argname, arg in six.iteritems(args): @@ -104,7 +122,30 @@ class RPCDispatcher(object): result = getattr(endpoint, method)(ctxt, **new_args) return self.serializer.serialize_entity(ctxt, result) - def __call__(self, ctxt, message): + @contextlib.contextmanager + def __call__(self, incoming): + yield lambda: self._dispatch_and_reply(incoming) + + def _dispatch_and_reply(self, incoming): + try: + incoming.reply(self._dispatch(incoming.ctxt, + incoming.message)) + except ExpectedException as e: + LOG.debug('Expected exception during message handling (%s)' % + e.exc_info[1]) + incoming.reply(failure=e.exc_info, log_failure=False) + except Exception as e: + # sys.exc_info() is deleted by LOG.exception(). + exc_info = sys.exc_info() + LOG.error('Exception during message handling: %s', e, + exc_info=exc_info) + incoming.reply(failure=exc_info) + # NOTE(dhellmann): Remove circular object reference + # between the current stack frame and the traceback in + # exc_info. + del exc_info + + def _dispatch(self, ctxt, message): """Dispatch an RPC message to the appropriate endpoint method. :param ctxt: the request context @@ -131,7 +172,7 @@ class RPCDispatcher(object): if hasattr(endpoint, method): localcontext.set_local_context(ctxt) try: - return self._dispatch(endpoint, method, ctxt, args) + return self._do_dispatch(endpoint, method, ctxt, args) finally: localcontext.clear_local_context() diff --git a/oslo/messaging/rpc/server.py b/oslo/messaging/rpc/server.py index 2a1f238c7..e1f700070 100644 --- a/oslo/messaging/rpc/server.py +++ b/oslo/messaging/rpc/server.py @@ -92,12 +92,9 @@ to - primitive types. __all__ = [ 'get_rpc_server', - 'ExpectedException', 'expected_exceptions', ] -import sys - from oslo.messaging.rpc import dispatcher as rpc_dispatcher from oslo.messaging import server as msg_server @@ -125,17 +122,6 @@ def get_rpc_server(transport, target, endpoints, return msg_server.MessageHandlingServer(transport, dispatcher, executor) -class ExpectedException(Exception): - """Encapsulates an expected exception raised by an RPC endpoint - - Merely instantiating this exception records the current exception - information, which will be passed back to the RPC client without - exceptional logging. - """ - def __init__(self): - self.exc_info = sys.exc_info() - - def expected_exceptions(*exceptions): """Decorator for RPC endpoint methods that raise expected exceptions. @@ -158,6 +144,6 @@ def expected_exceptions(*exceptions): # derived from the args passed to us will be # ignored and thrown as normal. except exceptions: - raise ExpectedException() + raise rpc_dispatcher.ExpectedException() return inner return outer diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 000000000..b6b043306 --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,131 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# Copyright 2013 eNovance +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import contextlib +import eventlet +import threading + +import mock +import testscenarios + +from oslo.messaging._executors import impl_blocking +from oslo.messaging._executors import impl_eventlet +from tests import utils as test_utils + +load_tests = testscenarios.load_tests_apply_scenarios + + +class TestExecutor(test_utils.BaseTestCase): + + _impl = [('blocking', dict(executor=impl_blocking.BlockingExecutor, + stop_before_return=True)), + ('eventlet', dict(executor=impl_eventlet.EventletExecutor, + stop_before_return=False))] + + @classmethod + def generate_scenarios(cls): + cls.scenarios = testscenarios.multiply_scenarios(cls._impl) + + @staticmethod + def _run_in_thread(executor): + def thread(): + executor.start() + executor.wait() + thread = threading.Thread(target=thread) + thread.daemon = True + thread.start() + thread.join(timeout=30) + + def test_executor_dispatch(self): + callback = mock.MagicMock(return_value='result') + + class Dispatcher(object): + @contextlib.contextmanager + def __call__(self, incoming): + yield lambda: callback(incoming.ctxt, incoming.message) + + listener = mock.Mock(spec=['poll']) + executor = self.executor(self.conf, listener, Dispatcher()) + + incoming_message = mock.MagicMock(ctxt={}, + message={'payload': 'data'}) + + def fake_poll(): + if self.stop_before_return: + executor.stop() + return incoming_message + else: + if listener.poll.call_count == 1: + return incoming_message + executor.stop() + + listener.poll.side_effect = fake_poll + + self._run_in_thread(executor) + + callback.assert_called_once_with({}, {'payload': 'data'}) + +TestExecutor.generate_scenarios() + + +class ExceptedException(Exception): + pass + + +class EventletContextManagerSpawnTest(test_utils.BaseTestCase): + def setUp(self): + super(EventletContextManagerSpawnTest, self).setUp() + self.before = mock.Mock() + self.callback = mock.Mock() + self.after = mock.Mock() + self.exception_call = mock.Mock() + + @contextlib.contextmanager + def context_mgr(): + self.before() + try: + yield lambda: self.callback() + except ExceptedException: + self.exception_call() + self.after() + + self.mgr = context_mgr() + + def test_normal_run(self): + impl_eventlet.spawn_with(self.mgr, pool=eventlet) + eventlet.sleep(0) + self.assertEqual(self.before.call_count, 1) + self.assertEqual(self.callback.call_count, 1) + self.assertEqual(self.after.call_count, 1) + self.assertEqual(self.exception_call.call_count, 0) + + def test_excepted_exception(self): + self.callback.side_effect = ExceptedException + impl_eventlet.spawn_with(self.mgr, pool=eventlet) + eventlet.sleep(0) + self.assertEqual(self.before.call_count, 1) + self.assertEqual(self.callback.call_count, 1) + self.assertEqual(self.after.call_count, 1) + self.assertEqual(self.exception_call.call_count, 1) + + def test_unexcepted_exception(self): + self.callback.side_effect = Exception + impl_eventlet.spawn_with(self.mgr, pool=eventlet) + eventlet.sleep(0) + self.assertEqual(self.before.call_count, 1) + self.assertEqual(self.callback.call_count, 1) + self.assertEqual(self.after.call_count, 0) + self.assertEqual(self.exception_call.call_count, 0) diff --git a/tests/test_notify_dispatcher.py b/tests/test_notify_dispatcher.py index e1c1f9abb..24b6e6b32 100644 --- a/tests/test_notify_dispatcher.py +++ b/tests/test_notify_dispatcher.py @@ -73,7 +73,9 @@ class TestDispatcher(test_utils.BaseTestCase): for prio in itertools.chain.from_iterable( self.endpoints)))) - dispatcher({}, msg) + incoming = mock.Mock(ctxt={}, message=msg) + with dispatcher(incoming) as callback: + callback() # check endpoint callbacks are called or not for i, endpoint_methods in enumerate(self.endpoints): @@ -94,5 +96,6 @@ class TestDispatcher(test_utils.BaseTestCase): dispatcher = notify_dispatcher.NotificationDispatcher([mock.Mock()], [mock.Mock()], None) - dispatcher({}, msg) + with dispatcher(mock.Mock(ctxt={}, message=msg)) as callback: + callback() mylog.warning.assert_called_once_with('Unknown priority "what???"') diff --git a/tests/test_rpc_dispatcher.py b/tests/test_rpc_dispatcher.py index 5d20813da..8d12278a2 100644 --- a/tests/test_rpc_dispatcher.py +++ b/tests/test_rpc_dispatcher.py @@ -13,6 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. +import mock import testscenarios from oslo import messaging @@ -91,38 +92,46 @@ class TestDispatcher(test_utils.BaseTestCase): ] def test_dispatcher(self): - endpoints = [] - for e in self.endpoints: - target = messaging.Target(**e) if e else None - endpoints.append(_FakeEndpoint(target)) + endpoints = [mock.Mock(spec=_FakeEndpoint, + target=messaging.Target(**e)) + for e in self.endpoints] serializer = None target = messaging.Target() dispatcher = messaging.RPCDispatcher(target, endpoints, serializer) - if self.dispatch_to is not None: - endpoint = endpoints[self.dispatch_to['endpoint']] - method = self.dispatch_to['method'] + def check_reply(reply=None, failure=None, log_failure=True): + if self.ex and failure is not None: + ex = failure[1] + self.assertFalse(self.success, ex) + self.assertIsNotNone(self.ex, ex) + self.assertIsInstance(ex, self.ex, ex) + if isinstance(ex, messaging.NoSuchMethod): + self.assertEqual(ex.method, self.msg.get('method')) + elif isinstance(ex, messaging.UnsupportedVersion): + self.assertEqual(ex.version, + self.msg.get('version', '1.0')) + else: + self.assertTrue(self.success, failure) + self.assertIsNone(failure) - self.mox.StubOutWithMock(endpoint, method) + incoming = mock.Mock(ctxt=self.ctxt, message=self.msg) + incoming.reply.side_effect = check_reply - method = getattr(endpoint, method) - method(self.ctxt, **self.msg.get('args', {})) + with dispatcher(incoming) as callback: + callback() - self.mox.ReplayAll() + for n, endpoint in enumerate(endpoints): + for method_name in ['foo', 'bar']: + method = getattr(endpoint, method_name) + if self.dispatch_to and n == self.dispatch_to['endpoint'] and \ + method_name == self.dispatch_to['method']: + method.assert_called_once_with( + self.ctxt, **self.msg.get('args', {})) + else: + self.assertEqual(method.call_count, 0) - try: - dispatcher(self.ctxt, self.msg) - except Exception as ex: - self.assertFalse(self.success, ex) - self.assertIsNotNone(self.ex, ex) - self.assertIsInstance(ex, self.ex, ex) - if isinstance(ex, messaging.NoSuchMethod): - self.assertEqual(ex.method, self.msg.get('method')) - elif isinstance(ex, messaging.UnsupportedVersion): - self.assertEqual(ex.version, self.msg.get('version', '1.0')) - else: - self.assertTrue(self.success) + self.assertEqual(incoming.reply.call_count, 1) class TestSerializer(test_utils.BaseTestCase): @@ -161,6 +170,7 @@ class TestSerializer(test_utils.BaseTestCase): self.mox.ReplayAll() - retval = dispatcher(self.ctxt, dict(method='foo', args=self.args)) + retval = dispatcher._dispatch(self.ctxt, dict(method='foo', + args=self.args)) if self.retval is not None: self.assertEqual(retval, 's' + self.retval)