Ensures that some assumptions are true.

It is documented that the application consumer must not call wait()
before stop(), but this is not enforced, so enforce it.

Also, the code assumes start/stop/wait are called from the same thread,
but this is not enforced, so enforce it.

A common broken usage is:

    server = oslo.messaging.get_rpc_server(..., executor='eventlet')
    t = threading.Thread(target=server.start)
    t.daemon = True
    t.start()
    ...foobar code...
    server.stop()
    server.wait()

With monkey patching, start() will do a context switch, so stop() can be
called while start() is still unfinished, which can cause unexpected behavior.

This patch fixes these issues by making all of this explicit.

Closes-bug: #1465850
Closes-bug: #1466001

Change-Id: I0fc1717e3118bc1cd7b9cd0ccc072251cfb2c038
This commit is contained in:
Mehdi Abaakouk 2015-06-16 23:15:20 +02:00
parent d1c546e5bb
commit 0dafde9407
7 changed files with 113 additions and 50 deletions

View File

@ -40,6 +40,7 @@ from oslo_messaging._i18n import _
from oslo_messaging._i18n import _LE from oslo_messaging._i18n import _LE
from oslo_messaging._i18n import _LI from oslo_messaging._i18n import _LI
from oslo_messaging._i18n import _LW from oslo_messaging._i18n import _LW
from oslo_messaging import _utils
from oslo_messaging import exceptions from oslo_messaging import exceptions
@ -309,7 +310,7 @@ class ConnectionLock(DummyConnectionLock):
self._monitor = threading.Lock() self._monitor = threading.Lock()
self._workers_locks = threading.Condition(self._monitor) self._workers_locks = threading.Condition(self._monitor)
self._heartbeat_lock = threading.Condition(self._monitor) self._heartbeat_lock = threading.Condition(self._monitor)
self._get_thread_id = self._fetch_current_thread_functor() self._get_thread_id = _utils.fetch_current_thread_functor()
def acquire(self): def acquire(self):
with self._monitor: with self._monitor:
@ -351,25 +352,6 @@ class ConnectionLock(DummyConnectionLock):
finally: finally:
self.release() self.release()
@staticmethod
def _fetch_current_thread_functor():
    """Build a zero-argument callable that identifies the running thread.

    When eventlet can be imported and reports that the 'thread' module is
    monkey patched, the callable returns ``eventlet.getcurrent()``;
    otherwise it returns ``threading.current_thread()``.
    """
    # Workaround until https://github.com/eventlet/eventlet/issues/172 is
    # resolved or addressed: we need an object that will not be recycled.
    # threading.current_thread() does not appear to be monkey patched at
    # the moment, so it is not reliable under eventlet (it breaks badly
    # because all threads share the same current_thread() object).
    use_greenlets = False
    try:
        import eventlet
        from eventlet import patcher
        use_greenlets = patcher.is_monkey_patched('thread')
    except ImportError:
        # eventlet is not installed: plain OS threads are in use.
        pass
    if not use_greenlets:
        return lambda: threading.current_thread()
    return lambda: eventlet.getcurrent()
class Connection(object): class Connection(object):
"""Connection object.""" """Connection object."""

View File

@ -14,6 +14,7 @@
# under the License. # under the License.
import logging import logging
import threading
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -94,3 +95,22 @@ class DispatcherExecutorContext(object):
# else # else
if self._post is not None: if self._post is not None:
self._post(self._incoming, self._result) self._post(self._incoming, self._result)
def fetch_current_thread_functor():
    """Build a zero-argument callable that identifies the running thread.

    When eventlet can be imported and reports that the 'thread' module is
    monkey patched, the callable returns ``eventlet.getcurrent()``;
    otherwise it returns ``threading.current_thread()``.
    """
    # Workaround until https://github.com/eventlet/eventlet/issues/172 is
    # resolved or addressed: we need an object that will not be recycled.
    # threading.current_thread() does not appear to be monkey patched at
    # the moment, so it is not reliable under eventlet (it breaks badly
    # because all threads share the same current_thread() object).
    use_greenlets = False
    try:
        import eventlet
        from eventlet import patcher
        use_greenlets = patcher.is_monkey_patched('thread')
    except ImportError:
        # eventlet is not installed: plain OS threads are in use.
        pass
    if not use_greenlets:
        return lambda: threading.current_thread()
    return lambda: eventlet.getcurrent()

View File

@ -27,6 +27,8 @@ from oslo_service import service
from stevedore import driver from stevedore import driver
from oslo_messaging._drivers import base as driver_base from oslo_messaging._drivers import base as driver_base
from oslo_messaging._i18n import _
from oslo_messaging import _utils
from oslo_messaging import exceptions from oslo_messaging import exceptions
@ -86,6 +88,8 @@ class MessageHandlingServer(service.ServiceBase):
self.dispatcher = dispatcher self.dispatcher = dispatcher
self.executor = executor self.executor = executor
self._get_thread_id = _utils.fetch_current_thread_functor()
try: try:
mgr = driver.DriverManager('oslo.messaging.executors', mgr = driver.DriverManager('oslo.messaging.executors',
self.executor) self.executor)
@ -94,6 +98,8 @@ class MessageHandlingServer(service.ServiceBase):
else: else:
self._executor_cls = mgr.driver self._executor_cls = mgr.driver
self._executor = None self._executor = None
self._running = False
self._thread_id = None
super(MessageHandlingServer, self).__init__() super(MessageHandlingServer, self).__init__()
@ -111,6 +117,8 @@ class MessageHandlingServer(service.ServiceBase):
choose to dispatch messages in a new thread, coroutine or simply the choose to dispatch messages in a new thread, coroutine or simply the
current thread. current thread.
""" """
self._check_same_thread_id()
if self._executor is not None: if self._executor is not None:
return return
try: try:
@ -118,10 +126,18 @@ class MessageHandlingServer(service.ServiceBase):
except driver_base.TransportDriverError as ex: except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex) raise ServerListenError(self.target, ex)
self._running = True
self._executor = self._executor_cls(self.conf, listener, self._executor = self._executor_cls(self.conf, listener,
self.dispatcher) self.dispatcher)
self._executor.start() self._executor.start()
def _check_same_thread_id(self):
    """Bind the server to its first calling thread and reject any other.

    The first call records the identifier returned by
    ``self._get_thread_id()`` in ``self._thread_id``.  Later calls compare
    the recorded identifier with the current one.

    :raises RuntimeError: if the current identifier differs from the
                          recorded one.
    """
    caller = self._get_thread_id()
    if self._thread_id is None:
        # Nobody owns the server yet: the caller becomes the owner.
        self._thread_id = caller
        return
    if self._thread_id != caller:
        raise RuntimeError(_("start/stop/wait must be called in the "
                             "same thread"))
def stop(self): def stop(self):
"""Stop handling incoming messages. """Stop handling incoming messages.
@ -130,7 +146,10 @@ class MessageHandlingServer(service.ServiceBase):
some messages, and underlying driver resources associated to this some messages, and underlying driver resources associated to this
server are still in use. See 'wait' for more details. server are still in use. See 'wait' for more details.
""" """
self._check_same_thread_id()
if self._executor is not None: if self._executor is not None:
self._running = False
self._executor.stop() self._executor.stop()
def wait(self): def wait(self):
@ -143,12 +162,22 @@ class MessageHandlingServer(service.ServiceBase):
Once it's finished, the underlying driver resources associated to this Once it's finished, the underlying driver resources associated to this
server are released (like closing useless network connections). server are released (like closing useless network connections).
""" """
self._check_same_thread_id()
if self._running:
raise RuntimeError(_("wait() should be called after stop() as it "
"waits for existing messages to finish "
"processing"))
if self._executor is not None: if self._executor is not None:
self._executor.wait() self._executor.wait()
# Close listener connection after processing all messages # Close listener connection after processing all messages
self._executor.listener.cleanup() self._executor.listener.cleanup()
self._executor = None self._executor = None
# NOTE(sileht): executor/listener have been properly stopped
# allow to restart it into another thread
self._thread_id = None
def reset(self): def reset(self):
"""Reset service. """Reset service.

View File

@ -12,7 +12,6 @@
# under the License. # under the License.
import os import os
import threading
import time import time
import uuid import uuid
@ -76,7 +75,7 @@ class RpcServerFixture(fixtures.Fixture):
"""Fixture to setup the TestServerEndpoint.""" """Fixture to setup the TestServerEndpoint."""
def __init__(self, url, target, endpoint=None, ctrl_target=None, def __init__(self, url, target, endpoint=None, ctrl_target=None,
executor='blocking'): executor='eventlet'):
super(RpcServerFixture, self).__init__() super(RpcServerFixture, self).__init__()
self.url = url self.url = url
self.target = target self.target = target
@ -104,14 +103,12 @@ class RpcServerFixture(fixtures.Fixture):
super(RpcServerFixture, self).cleanUp() super(RpcServerFixture, self).cleanUp()
def _start(self): def _start(self):
self.thread = threading.Thread(target=self.server.start) self.thread = test_utils.ServerThreadHelper(self.server)
self.thread.daemon = True
self.thread.start() self.thread.start()
def _stop(self): def _stop(self):
self.server.stop() self.thread.stop()
self._ctrl.cast({}, 'ping') self._ctrl.cast({}, 'ping')
self.server.wait()
self.thread.join() self.thread.join()
def ping(self, ctxt): def ping(self, ctxt):
@ -308,7 +305,7 @@ class NotificationFixture(fixtures.Fixture):
self.server = oslo_messaging.get_notification_listener( self.server = oslo_messaging.get_notification_listener(
transport.transport, transport.transport,
targets, targets,
[self]) [self], 'eventlet')
self._ctrl = self.notifier('internal', topic=self.name) self._ctrl = self.notifier('internal', topic=self.name)
self._start() self._start()
transport.wait() transport.wait()
@ -318,14 +315,12 @@ class NotificationFixture(fixtures.Fixture):
super(NotificationFixture, self).cleanUp() super(NotificationFixture, self).cleanUp()
def _start(self): def _start(self):
self.thread = threading.Thread(target=self.server.start) self.thread = test_utils.ServerThreadHelper(self.server)
self.thread.daemon = True
self.thread.start() self.thread.start()
def _stop(self): def _stop(self):
self.server.stop() self.thread.stop()
self._ctrl.sample({}, 'shutdown', 'shutdown') self._ctrl.sample({}, 'shutdown', 'shutdown')
self.server.wait()
self.thread.join() self.thread.join()
def notifier(self, publisher, topic=None): def notifier(self, publisher, topic=None):

View File

@ -31,32 +31,15 @@ class RestartableServerThread(object):
def __init__(self, server): def __init__(self, server):
self.server = server self.server = server
self.thread = None self.thread = None
self._started = threading.Event()
self._tombstone = threading.Event()
def start(self): def start(self):
self._tombstone.clear()
if self.thread is None: if self.thread is None:
self._started.clear() self.thread = test_utils.ServerThreadHelper(self.server)
self.thread = threading.Thread(target=self._target)
self.thread.daemon = True
self.thread.start() self.thread.start()
self._started.wait()
def _target(self):
self.server.start()
self._started.set()
self._tombstone.wait()
def stop(self): def stop(self):
if self.thread is not None: if self.thread is not None:
# Check start() does nothing with a running listener self.thread.stop()
self.server.start()
self._tombstone.set()
if self.thread is not None:
self.server.stop()
self.server.wait()
self.thread.join(timeout=15) self.thread.join(timeout=15)
ret = self.thread.isAlive() ret = self.thread.isAlive()
self.thread = None self.thread = None

View File

@ -130,6 +130,37 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin):
self.assertIsNone(server._executor) self.assertIsNone(server._executor)
self.assertEqual(1, listener.cleanup.call_count) self.assertEqual(1, listener.cleanup.call_count)
def test_server_invalid_wait_running_server(self):
    """wait() on a started, not yet stopped server raises RuntimeError."""
    rpc_server = oslo_messaging.get_rpc_server(
        oslo_messaging.get_transport(self.conf, url='fake:'),
        oslo_messaging.Target(topic='foo', server='bar'),
        [object()],
        serializer=object(),
        executor='eventlet')
    # Cleanups run in reverse registration order, so stop() is executed
    # before wait() when the test finishes.
    self.addCleanup(rpc_server.wait)
    self.addCleanup(rpc_server.stop)
    rpc_server.start()
    self.assertRaises(RuntimeError, rpc_server.wait)
def test_server_invalid_stop_from_other_thread(self):
    """stop()/wait() from a thread other than the starter raise."""
    rpc_server = oslo_messaging.get_rpc_server(
        oslo_messaging.get_transport(self.conf, url='fake:'),
        oslo_messaging.Target(topic='foo', server='bar'),
        [object()],
        serializer=object(),
        executor='eventlet')

    helper = test_utils.ServerThreadHelper(rpc_server)
    helper.start()
    # Cleanups run in reverse registration order: the helper is asked to
    # stop first and is joined afterwards.
    self.addCleanup(helper.join)
    self.addCleanup(helper.stop)

    for forbidden_call in (rpc_server.stop, rpc_server.wait):
        self.assertRaises(RuntimeError, forbidden_call)
def test_no_target_server(self): def test_no_target_server(self):
transport = oslo_messaging.get_transport(self.conf, url='fake:') transport = oslo_messaging.get_transport(self.conf, url='fake:')

View File

@ -19,6 +19,8 @@
"""Common utilities used in testing""" """Common utilities used in testing"""
import threading
from oslo_config import cfg from oslo_config import cfg
from oslotest import base from oslotest import base
from oslotest import moxstubout from oslotest import moxstubout
@ -56,3 +58,24 @@ class BaseTestCase(base.BaseTestCase):
group = kw.pop('group', None) group = kw.pop('group', None)
for k, v in six.iteritems(kw): for k, v in six.iteritems(kw):
self.conf.set_override(k, v, group) self.conf.set_override(k, v, group)
class ServerThreadHelper(threading.Thread):
    """Daemon thread that drives a server's start/stop/wait lifecycle.

    The server's start(), stop() and wait() are all invoked from run(),
    so they always execute in this helper's thread.  Other threads only
    signal the helper through stop().
    """

    def __init__(self, server):
        """Prepare the helper.

        :param server: object exposing start(), stop() and wait().
        """
        super(ServerThreadHelper, self).__init__()
        self.daemon = True
        self._server = server
        # Set once the server start attempt made by run() has finished.
        self._start_event = threading.Event()
        # Set by stop() to ask run() to shut the server down.
        self._stop_event = threading.Event()
        # Set once run() has finished stopping and waiting for the server.
        self._wait_event = threading.Event()

    def start(self):
        """Launch the thread and block until the server start attempt ends.

        Without this wait the caller could reach server.stop() or
        server.wait() before run() has called server.start(), racing with
        the helper thread for the first call on the server.
        """
        super(ServerThreadHelper, self).start()
        self._start_event.wait()

    def run(self):
        """Start the server, wait for stop(), then stop and wait for it."""
        try:
            self._server.start()
        finally:
            # Always release the caller blocked in start(), even when the
            # server failed to start, so it cannot hang forever.
            self._start_event.set()
        self._stop_event.wait()
        # Check start() does nothing with a running listener
        self._server.start()
        self._server.stop()
        self._server.wait()
        self._wait_event.set()

    def stop(self):
        """Ask the helper thread to shut the server down (non-blocking)."""
        self._stop_event.set()