Ensures that some assumptions are true.

It's documented, the application consumer must not use wait before stop.
but this is not enforced, so enforce it

Also the code assume start/stop/wait are called from the same thread,
but this is not enforced, so enforce it.

A common broken usage is:

    server = oslo.messaging.get_rpc_server(..., executor='eventlet')
    t = threading.Thread(target=server.start)
    t.daemon = True
    t.start()
    ...foobar code...
    server.stop()
    server.wait()

With monkey patching, start() will do a context switch and then stop()
is called but start is unfinished, that can cause unexpected behavior.

This patch fixes these issues by making all of this explicit.

Closes-bug: #1465850
Closes-bug: #1466001

Change-Id: I0fc1717e3118bc1cd7b9cd0ccc072251cfb2c038
This commit is contained in:
Mehdi Abaakouk 2015-06-16 23:15:20 +02:00
parent d1c546e5bb
commit 0dafde9407
7 changed files with 113 additions and 50 deletions

View File

@ -40,6 +40,7 @@ from oslo_messaging._i18n import _
from oslo_messaging._i18n import _LE
from oslo_messaging._i18n import _LI
from oslo_messaging._i18n import _LW
from oslo_messaging import _utils
from oslo_messaging import exceptions
@ -309,7 +310,7 @@ class ConnectionLock(DummyConnectionLock):
self._monitor = threading.Lock()
self._workers_locks = threading.Condition(self._monitor)
self._heartbeat_lock = threading.Condition(self._monitor)
self._get_thread_id = self._fetch_current_thread_functor()
self._get_thread_id = _utils.fetch_current_thread_functor()
def acquire(self):
with self._monitor:
@ -351,25 +352,6 @@ class ConnectionLock(DummyConnectionLock):
finally:
self.release()
@staticmethod
def _fetch_current_thread_functor():
# Until https://github.com/eventlet/eventlet/issues/172 is resolved
# or addressed we have to use complicated workaround to get a object
# that will not be recycled; the usage of threading.current_thread()
# doesn't appear to currently be monkey patched and therefore isn't
# reliable to use (and breaks badly when used as all threads share
# the same current_thread() object)...
try:
import eventlet
from eventlet import patcher
green_threaded = patcher.is_monkey_patched('thread')
except ImportError:
green_threaded = False
if green_threaded:
return lambda: eventlet.getcurrent()
else:
return lambda: threading.current_thread()
class Connection(object):
"""Connection object."""

View File

@ -14,6 +14,7 @@
# under the License.
import logging
import threading
LOG = logging.getLogger(__name__)
@ -94,3 +95,22 @@ class DispatcherExecutorContext(object):
# else
if self._post is not None:
self._post(self._incoming, self._result)
def fetch_current_thread_functor():
# Until https://github.com/eventlet/eventlet/issues/172 is resolved
# or addressed we have to use complicated workaround to get a object
# that will not be recycled; the usage of threading.current_thread()
# doesn't appear to currently be monkey patched and therefore isn't
# reliable to use (and breaks badly when used as all threads share
# the same current_thread() object)...
try:
import eventlet
from eventlet import patcher
green_threaded = patcher.is_monkey_patched('thread')
except ImportError:
green_threaded = False
if green_threaded:
return lambda: eventlet.getcurrent()
else:
return lambda: threading.current_thread()

View File

@ -27,6 +27,8 @@ from oslo_service import service
from stevedore import driver
from oslo_messaging._drivers import base as driver_base
from oslo_messaging._i18n import _
from oslo_messaging import _utils
from oslo_messaging import exceptions
@ -86,6 +88,8 @@ class MessageHandlingServer(service.ServiceBase):
self.dispatcher = dispatcher
self.executor = executor
self._get_thread_id = _utils.fetch_current_thread_functor()
try:
mgr = driver.DriverManager('oslo.messaging.executors',
self.executor)
@ -94,6 +98,8 @@ class MessageHandlingServer(service.ServiceBase):
else:
self._executor_cls = mgr.driver
self._executor = None
self._running = False
self._thread_id = None
super(MessageHandlingServer, self).__init__()
@ -111,6 +117,8 @@ class MessageHandlingServer(service.ServiceBase):
choose to dispatch messages in a new thread, coroutine or simply the
current thread.
"""
self._check_same_thread_id()
if self._executor is not None:
return
try:
@ -118,10 +126,18 @@ class MessageHandlingServer(service.ServiceBase):
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
self._running = True
self._executor = self._executor_cls(self.conf, listener,
self.dispatcher)
self._executor.start()
def _check_same_thread_id(self):
if self._thread_id is None:
self._thread_id = self._get_thread_id()
elif self._thread_id != self._get_thread_id():
raise RuntimeError(_("start/stop/wait must be called in the "
"same thread"))
def stop(self):
"""Stop handling incoming messages.
@ -130,7 +146,10 @@ class MessageHandlingServer(service.ServiceBase):
some messages, and underlying driver resources associated to this
server are still in use. See 'wait' for more details.
"""
self._check_same_thread_id()
if self._executor is not None:
self._running = False
self._executor.stop()
def wait(self):
@ -143,12 +162,22 @@ class MessageHandlingServer(service.ServiceBase):
Once it's finished, the underlying driver resources associated to this
server are released (like closing useless network connections).
"""
self._check_same_thread_id()
if self._running:
raise RuntimeError(_("wait() should be called after stop() as it "
"waits for existing messages to finish "
"processing"))
if self._executor is not None:
self._executor.wait()
# Close listener connection after processing all messages
self._executor.listener.cleanup()
self._executor = None
# NOTE(sileht): executor/listener have been properly stopped
# allow to restart it into another thread
self._thread_id = None
def reset(self):
"""Reset service.

View File

@ -12,7 +12,6 @@
# under the License.
import os
import threading
import time
import uuid
@ -76,7 +75,7 @@ class RpcServerFixture(fixtures.Fixture):
"""Fixture to setup the TestServerEndpoint."""
def __init__(self, url, target, endpoint=None, ctrl_target=None,
executor='blocking'):
executor='eventlet'):
super(RpcServerFixture, self).__init__()
self.url = url
self.target = target
@ -104,14 +103,12 @@ class RpcServerFixture(fixtures.Fixture):
super(RpcServerFixture, self).cleanUp()
def _start(self):
self.thread = threading.Thread(target=self.server.start)
self.thread.daemon = True
self.thread = test_utils.ServerThreadHelper(self.server)
self.thread.start()
def _stop(self):
self.server.stop()
self.thread.stop()
self._ctrl.cast({}, 'ping')
self.server.wait()
self.thread.join()
def ping(self, ctxt):
@ -308,7 +305,7 @@ class NotificationFixture(fixtures.Fixture):
self.server = oslo_messaging.get_notification_listener(
transport.transport,
targets,
[self])
[self], 'eventlet')
self._ctrl = self.notifier('internal', topic=self.name)
self._start()
transport.wait()
@ -318,14 +315,12 @@ class NotificationFixture(fixtures.Fixture):
super(NotificationFixture, self).cleanUp()
def _start(self):
self.thread = threading.Thread(target=self.server.start)
self.thread.daemon = True
self.thread = test_utils.ServerThreadHelper(self.server)
self.thread.start()
def _stop(self):
self.server.stop()
self.thread.stop()
self._ctrl.sample({}, 'shutdown', 'shutdown')
self.server.wait()
self.thread.join()
def notifier(self, publisher, topic=None):

View File

@ -31,32 +31,15 @@ class RestartableServerThread(object):
def __init__(self, server):
self.server = server
self.thread = None
self._started = threading.Event()
self._tombstone = threading.Event()
def start(self):
self._tombstone.clear()
if self.thread is None:
self._started.clear()
self.thread = threading.Thread(target=self._target)
self.thread.daemon = True
self.thread = test_utils.ServerThreadHelper(self.server)
self.thread.start()
self._started.wait()
def _target(self):
self.server.start()
self._started.set()
self._tombstone.wait()
def stop(self):
if self.thread is not None:
# Check start() does nothing with a running listener
self.server.start()
self._tombstone.set()
if self.thread is not None:
self.server.stop()
self.server.wait()
self.thread.stop()
self.thread.join(timeout=15)
ret = self.thread.isAlive()
self.thread = None

View File

@ -130,6 +130,37 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin):
self.assertIsNone(server._executor)
self.assertEqual(1, listener.cleanup.call_count)
def test_server_invalid_wait_running_server(self):
transport = oslo_messaging.get_transport(self.conf, url='fake:')
target = oslo_messaging.Target(topic='foo', server='bar')
endpoints = [object()]
serializer = object()
server = oslo_messaging.get_rpc_server(transport, target, endpoints,
serializer=serializer,
executor='eventlet')
self.addCleanup(server.wait)
self.addCleanup(server.stop)
server.start()
self.assertRaises(RuntimeError, server.wait)
def test_server_invalid_stop_from_other_thread(self):
transport = oslo_messaging.get_transport(self.conf, url='fake:')
target = oslo_messaging.Target(topic='foo', server='bar')
endpoints = [object()]
serializer = object()
server = oslo_messaging.get_rpc_server(transport, target, endpoints,
serializer=serializer,
executor='eventlet')
t = test_utils.ServerThreadHelper(server)
t.start()
self.addCleanup(t.join)
self.addCleanup(t.stop)
self.assertRaises(RuntimeError, server.stop)
self.assertRaises(RuntimeError, server.wait)
def test_no_target_server(self):
transport = oslo_messaging.get_transport(self.conf, url='fake:')

View File

@ -19,6 +19,8 @@
"""Common utilities used in testing"""
import threading
from oslo_config import cfg
from oslotest import base
from oslotest import moxstubout
@ -56,3 +58,24 @@ class BaseTestCase(base.BaseTestCase):
group = kw.pop('group', None)
for k, v in six.iteritems(kw):
self.conf.set_override(k, v, group)
class ServerThreadHelper(threading.Thread):
def __init__(self, server):
super(ServerThreadHelper, self).__init__()
self.daemon = True
self._server = server
self._stop_event = threading.Event()
self._wait_event = threading.Event()
def run(self):
self._server.start()
self._stop_event.wait()
# Check start() does nothing with a running listener
self._server.start()
self._server.stop()
self._server.wait()
self._wait_event.set()
def stop(self):
self._stop_event.set()