Refactor driver's listener interface

Current Listener interface has poll() method which return messages

To use it we need have poller thread which is located in MessageHandlerServer
But my investigations of existing driver's code shows that some implemetations have
its own thread inside for processing connection event loop. This event loop received
messages and store in queue object. And then our poller's thread reads this queue
This situation can be improved. we can remove poller's thread, remove queue object
and just call on_message server's callback from connection eventloop thread

This path provide posibility to do this for one of drivers and leave as is other drivers

Change-Id: I3e3d4369d8fdadcecf079d10af58b1e4f5616047
This commit is contained in:
Dmitriy Ukhlov 2016-04-02 14:58:29 +03:00
parent ee394d3c5b
commit 5d7d7253d1
21 changed files with 325 additions and 223 deletions

View File

@ -176,7 +176,7 @@ class ObsoleteReplyQueuesCache(object):
'msg_id': msg_id}) 'msg_id': msg_id})
class AMQPListener(base.Listener): class AMQPListener(base.PollStyleListener):
def __init__(self, driver, conn): def __init__(self, driver, conn):
super(AMQPListener, self).__init__(driver.prefetch_size) super(AMQPListener, self).__init__(driver.prefetch_size)
@ -473,7 +473,7 @@ class AMQPDriverBase(base.BaseDriver):
return self._send(target, ctxt, message, return self._send(target, ctxt, message,
envelope=(version == 2.0), notify=True, retry=retry) envelope=(version == 2.0), notify=True, retry=retry)
def listen(self, target): def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
conn = self._get_connection(rpc_common.PURPOSE_LISTEN) conn = self._get_connection(rpc_common.PURPOSE_LISTEN)
listener = AMQPListener(self, conn) listener = AMQPListener(self, conn)
@ -487,9 +487,12 @@ class AMQPDriverBase(base.BaseDriver):
callback=listener) callback=listener)
conn.declare_fanout_consumer(target.topic, listener) conn.declare_fanout_consumer(target.topic, listener)
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool): def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
conn = self._get_connection(rpc_common.PURPOSE_LISTEN) conn = self._get_connection(rpc_common.PURPOSE_LISTEN)
listener = AMQPListener(self, conn) listener = AMQPListener(self, conn)
@ -498,7 +501,8 @@ class AMQPDriverBase(base.BaseDriver):
exchange_name=self._get_exchange(target), exchange_name=self._get_exchange(target),
topic='%s.%s' % (target.topic, priority), topic='%s.%s' % (target.topic, priority),
callback=listener, queue_name=pool) callback=listener, queue_name=pool)
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self): def cleanup(self):
if self._connection_pool: if self._connection_pool:

View File

@ -14,12 +14,12 @@
# under the License. # under the License.
import abc import abc
import time import threading
from oslo_config import cfg from oslo_config import cfg
from oslo_utils import excutils
from oslo_utils import timeutils from oslo_utils import timeutils
import six import six
from six.moves import range as compat_range
from oslo_messaging import exceptions from oslo_messaging import exceptions
@ -38,21 +38,33 @@ def batch_poll_helper(func):
This decorator helps driver that polls message one by one, This decorator helps driver that polls message one by one,
to returns a list of message. to returns a list of message.
""" """
def wrapper(in_self, timeout=None, prefetch_size=1): def wrapper(in_self, timeout=None, batch_size=1, batch_timeout=None):
incomings = [] incomings = []
driver_prefetch = in_self.prefetch_size driver_prefetch = in_self.prefetch_size
if driver_prefetch > 0: if driver_prefetch > 0:
prefetch_size = min(prefetch_size, driver_prefetch) batch_size = min(batch_size, driver_prefetch)
watch = timeutils.StopWatch(duration=timeout)
with watch: with timeutils.StopWatch(timeout) as timeout_watch:
for __ in compat_range(prefetch_size): # poll first message
msg = func(in_self, timeout=watch.leftover(return_none=True)) msg = func(in_self, timeout=timeout_watch.leftover(True))
if msg is not None: if msg is not None:
incomings.append(msg) incomings.append(msg)
else: if batch_size == 1 or msg is None:
# timeout reached or listener stopped return incomings
break
time.sleep(0) # update batch_timeout according to timeout for whole operation
timeout_left = timeout_watch.leftover(True)
if timeout_left is not None and (
batch_timeout is None or timeout_left < batch_timeout):
batch_timeout = timeout_left
with timeutils.StopWatch(batch_timeout) as batch_timeout_watch:
# poll remained batch messages
while len(incomings) < batch_size and msg is not None:
msg = func(in_self, timeout=batch_timeout_watch.leftover(True))
if msg is not None:
incomings.append(msg)
return incomings return incomings
return wrapper return wrapper
@ -81,20 +93,22 @@ class RpcIncomingMessage(IncomingMessage):
@abc.abstractmethod @abc.abstractmethod
def reply(self, reply=None, failure=None, log_failure=True): def reply(self, reply=None, failure=None, log_failure=True):
"Send a reply or failure back to the client." """Send a reply or failure back to the client."""
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class Listener(object): class PollStyleListener(object):
def __init__(self, prefetch_size=-1): def __init__(self, prefetch_size=-1):
self.prefetch_size = prefetch_size self.prefetch_size = prefetch_size
@abc.abstractmethod @abc.abstractmethod
def poll(self, timeout=None, prefetch_size=1): def poll(self, timeout=None, batch_size=1, batch_timeout=None):
"""Blocking until 'prefetch_size' message is pending and return """Blocking until 'batch_size' message is pending and return
[IncomingMessage]. [IncomingMessage].
Return None after timeout seconds if timeout is set and no message is Waits for first message. Then waits for next batch_size-1 messages
ending or if the listener have been stopped. during batch window defined by batch_timeout
This method block current thread until message comes, stop() is
executed by another thread or timemout is elapsed.
""" """
def stop(self): def stop(self):
@ -112,6 +126,113 @@ class Listener(object):
pass pass
@six.add_metaclass(abc.ABCMeta)
class Listener(object):
def __init__(self, on_incoming_callback, batch_size, batch_timeout,
prefetch_size=-1):
"""Init Listener
:param on_incoming_callback: callback function to be executed when
listener received messages. Messages should be processed and
acked/nacked by callback
:param batch_size: desired number of messages passed to
single on_incoming_callback call
:param batch_timeout: defines how long should we wait for batch_size
messages if we already have some messages waiting for processing
:param prefetch_size: defines how many massages we want to prefetch
from backend (depend on driver type) by single request
"""
self.on_incoming_callback = on_incoming_callback
self.batch_timeout = batch_timeout
self.prefetch_size = prefetch_size
if prefetch_size > 0:
batch_size = min(batch_size, prefetch_size)
self.batch_size = batch_size
@abc.abstractmethod
def start(self):
"""Stop listener.
Stop the listener message polling
"""
@abc.abstractmethod
def wait(self):
"""Wait listener.
Wait for processing remained input after listener Stop
"""
@abc.abstractmethod
def stop(self):
"""Stop listener.
Stop the listener message polling
"""
@abc.abstractmethod
def cleanup(self):
"""Cleanup listener.
Close connection (socket) used by listener if any.
As this is listener specific method, overwrite it in to derived class
if cleanup of listener required.
"""
class PollStyleListenerAdapter(Listener):
def __init__(self, poll_style_listener, on_incoming_callback, batch_size,
batch_timeout):
super(PollStyleListenerAdapter, self).__init__(
on_incoming_callback, batch_size, batch_timeout,
poll_style_listener.prefetch_size
)
self._poll_style_listener = poll_style_listener
self._listen_thread = threading.Thread(target=self._runner)
self._listen_thread.daemon = True
self._started = False
def start(self):
"""Start listener.
Start the listener message polling
"""
self._started = True
self._listen_thread.start()
@excutils.forever_retry_uncaught_exceptions
def _runner(self):
while self._started:
incoming = self._poll_style_listener.poll(
batch_size=self.batch_size, batch_timeout=self.batch_timeout)
if incoming:
self.on_incoming_callback(incoming)
# listener is stopped but we need to process all already consumed
# messages
while True:
incoming = self._poll_style_listener.poll(
batch_size=self.batch_size, batch_timeout=self.batch_timeout)
if not incoming:
return
self.on_incoming_callback(incoming)
def stop(self):
"""Stop listener.
Stop the listener message polling
"""
self._started = False
self._poll_style_listener.stop()
def wait(self):
self._listen_thread.join()
def cleanup(self):
"""Cleanup listener.
Close connection (socket) used by listener if any.
As this is listener specific method, overwrite it in to derived class
if cleanup of listener required.
"""
self._poll_style_listener.cleanup()
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class BaseDriver(object): class BaseDriver(object):
prefetch_size = 0 prefetch_size = 0
@ -138,11 +259,13 @@ class BaseDriver(object):
"""Send a notification message to the given target.""" """Send a notification message to the given target."""
@abc.abstractmethod @abc.abstractmethod
def listen(self, target): def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
"""Construct a Listener for the given target.""" """Construct a Listener for the given target."""
@abc.abstractmethod @abc.abstractmethod
def listen_for_notifications(self, targets_and_priorities, pool): def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
"""Construct a notification Listener for the given list of """Construct a notification Listener for the given list of
tuple of (target, priority). tuple of (target, priority).
""" """

View File

@ -39,7 +39,7 @@ class FakeIncomingMessage(base.RpcIncomingMessage):
self.requeue_callback() self.requeue_callback()
class FakeListener(base.Listener): class FakeListener(base.PollStyleListener):
def __init__(self, exchange_manager, targets, pool=None): def __init__(self, exchange_manager, targets, pool=None):
super(FakeListener, self).__init__() super(FakeListener, self).__init__()
@ -222,7 +222,7 @@ class FakeDriver(base.BaseDriver):
# transport always works # transport always works
self._send(target, ctxt, message) self._send(target, ctxt, message)
def listen(self, target): def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
exchange = target.exchange or self._default_exchange exchange = target.exchange or self._default_exchange
listener = FakeListener(self._exchange_manager, listener = FakeListener(self._exchange_manager,
[oslo_messaging.Target( [oslo_messaging.Target(
@ -232,9 +232,12 @@ class FakeDriver(base.BaseDriver):
oslo_messaging.Target( oslo_messaging.Target(
topic=target.topic, topic=target.topic,
exchange=exchange)]) exchange=exchange)])
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool): def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
targets = [ targets = [
oslo_messaging.Target( oslo_messaging.Target(
topic='%s.%s' % (target.topic, priority), topic='%s.%s' % (target.topic, priority),
@ -242,7 +245,8 @@ class FakeDriver(base.BaseDriver):
for target, priority in targets_and_priorities] for target, priority in targets_and_priorities]
listener = FakeListener(self._exchange_manager, targets, pool) listener = FakeListener(self._exchange_manager, targets, pool)
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self): def cleanup(self):
pass pass

View File

@ -247,7 +247,7 @@ class OsloKafkaMessage(base.RpcIncomingMessage):
LOG.warning(_LW("reply is not supported")) LOG.warning(_LW("reply is not supported"))
class KafkaListener(base.Listener): class KafkaListener(base.PollStyleListener):
def __init__(self, conn): def __init__(self, conn):
super(KafkaListener, self).__init__() super(KafkaListener, self).__init__()
@ -342,7 +342,9 @@ class KafkaDriver(base.BaseDriver):
raise NotImplementedError( raise NotImplementedError(
'The RPC implementation for Kafka is not implemented') 'The RPC implementation for Kafka is not implemented')
def listen_for_notifications(self, targets_and_priorities, pool=None): def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
"""Listen to a specified list of targets on Kafka brokers """Listen to a specified list of targets on Kafka brokers
:param targets_and_priorities: List of pairs (target, priority) :param targets_and_priorities: List of pairs (target, priority)
@ -361,7 +363,8 @@ class KafkaDriver(base.BaseDriver):
conn.declare_topic_consumer(topics, pool) conn.declare_topic_consumer(topics, pool)
listener = KafkaListener(conn) listener = KafkaListener(conn)
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def _get_connection(self, purpose): def _get_connection(self, purpose):
return driver_common.ConnectionContext(self.connection_pool, purpose) return driver_common.ConnectionContext(self.connection_pool, purpose)

View File

@ -334,15 +334,18 @@ class PikaDriver(base.BaseDriver):
retrier=retrier retrier=retrier
) )
def listen(self, target): def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
listener = pika_drv_poller.RpcServicePikaPoller( listener = pika_drv_poller.RpcServicePikaPoller(
self._pika_engine, target, self._pika_engine, target,
prefetch_count=self._pika_engine.rpc_listener_prefetch_count prefetch_count=self._pika_engine.rpc_listener_prefetch_count
) )
listener.start() listener.start()
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool): def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback,
batch_size, batch_timeout):
listener = pika_drv_poller.NotificationPikaPoller( listener = pika_drv_poller.NotificationPikaPoller(
self._pika_engine, targets_and_priorities, self._pika_engine, targets_and_priorities,
prefetch_count=( prefetch_count=(
@ -351,7 +354,8 @@ class PikaDriver(base.BaseDriver):
queue_name=pool queue_name=pool
) )
listener.start() listener.start()
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self): def cleanup(self):
self._reply_listener.cleanup() self._reply_listener.cleanup()

View File

@ -251,15 +251,20 @@ class ZmqDriver(base.BaseDriver):
client = self.notifier.get() client = self.notifier.get()
client.send_notify(target, ctxt, message, version, retry) client.send_notify(target, ctxt, message, version, retry)
def listen(self, target): def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
"""Listen to a specified target on a server side """Listen to a specified target on a server side
:param target: Message destination target :param target: Message destination target
:type target: oslo_messaging.Target :type target: oslo_messaging.Target
""" """
return zmq_server.ZmqServer(self, self.conf, self.matchmaker, target) listener = zmq_server.ZmqServer(self, self.conf, self.matchmaker,
target)
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def listen_for_notifications(self, targets_and_priorities, pool): def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
"""Listen to a specified list of targets on a server side """Listen to a specified list of targets on a server side
:param targets_and_priorities: List of pairs (target, priority) :param targets_and_priorities: List of pairs (target, priority)
@ -267,8 +272,10 @@ class ZmqDriver(base.BaseDriver):
:param pool: Not used for zmq implementation :param pool: Not used for zmq implementation
:type pool: object :type pool: object
""" """
return zmq_server.ZmqNotificationServer( listener = zmq_server.ZmqNotificationServer(
self, self.conf, self.matchmaker, targets_and_priorities) self, self.conf, self.matchmaker, targets_and_priorities)
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self): def cleanup(self):
"""Cleanup all driver's connections finally """Cleanup all driver's connections finally

View File

@ -27,7 +27,7 @@ from oslo_messaging._drivers.pika_driver import pika_message as pika_drv_msg
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class PikaPoller(base.Listener): class PikaPoller(base.PollStyleListener):
"""Provides user friendly functionality for RabbitMQ message consuming, """Provides user friendly functionality for RabbitMQ message consuming,
handles low level connectivity problems and restore connection if some handles low level connectivity problems and restore connection if some
connectivity related problem detected connectivity related problem detected
@ -43,8 +43,8 @@ class PikaPoller(base.Listener):
:param incoming_message_class: PikaIncomingMessage, wrapper for :param incoming_message_class: PikaIncomingMessage, wrapper for
consumed RabbitMQ message consumed RabbitMQ message
""" """
super(PikaPoller, self).__init__(prefetch_count)
self._pika_engine = pika_engine self._pika_engine = pika_engine
self._prefetch_count = prefetch_count
self._incoming_message_class = incoming_message_class self._incoming_message_class = incoming_message_class
self._connection = None self._connection = None
@ -65,7 +65,7 @@ class PikaPoller(base.Listener):
for_listening=True for_listening=True
) )
self._channel = self._connection.channel() self._channel = self._connection.channel()
self._channel.basic_qos(prefetch_count=self._prefetch_count) self._channel.basic_qos(prefetch_count=self.prefetch_size)
if self._queues_to_consume is None: if self._queues_to_consume is None:
self._queues_to_consume = self._declare_queue_binding() self._queues_to_consume = self._declare_queue_binding()
@ -161,27 +161,23 @@ class PikaPoller(base.Listener):
if message.need_ack(): if message.need_ack():
del self._message_queue[i] del self._message_queue[i]
def poll(self, timeout=None, prefetch_size=1): @base.batch_poll_helper
def poll(self, timeout=None):
"""Main method of this class - consumes message from RabbitMQ """Main method of this class - consumes message from RabbitMQ
:param: timeout: float, seconds, timeout for waiting new incoming :param: timeout: float, seconds, timeout for waiting new incoming
message, None means wait forever message, None means wait forever
:param: prefetch_size: Integer, count of messages which we are want to
poll. It blocks until prefetch_size messages are consumed or until
timeout gets expired
:return: list of PikaIncomingMessage, RabbitMQ messages :return: list of PikaIncomingMessage, RabbitMQ messages
""" """
with timeutils.StopWatch(timeout) as stop_watch: with timeutils.StopWatch(timeout) as stop_watch:
while True: while True:
with self._lock: with self._lock:
last_queue_size = len(self._message_queue) if self._message_queue:
return self._message_queue.pop(0)
if (last_queue_size >= prefetch_size if stop_watch.expired():
or stop_watch.expired()): return None
result = self._message_queue[:prefetch_size]
del self._message_queue[:prefetch_size]
return result
try: try:
if self._started: if self._started:
@ -202,11 +198,10 @@ class PikaPoller(base.Listener):
self._connection.process_data_events( self._connection.process_data_events(
time_limit=0 time_limit=0
) )
# and return result if we don't see new messages
if last_queue_size == len(self._message_queue): # and return if we don't see new messages
result = self._message_queue[:prefetch_size] if not self._message_queue:
del self._message_queue[:prefetch_size] return None
return result
except pika_drv_exc.EstablishConnectionException as e: except pika_drv_exc.EstablishConnectionException as e:
LOG.warning( LOG.warning(
"Problem during establishing connection for pika " "Problem during establishing connection for pika "

View File

@ -145,7 +145,7 @@ class Queue(object):
self._pop_wake_condition.notify_all() self._pop_wake_condition.notify_all()
class ProtonListener(base.Listener): class ProtonListener(base.PollStyleListener):
def __init__(self, driver): def __init__(self, driver):
super(ProtonListener, self).__init__(driver.prefetch_size) super(ProtonListener, self).__init__(driver.prefetch_size)
self.driver = driver self.driver = driver
@ -266,15 +266,19 @@ class ProtonDriver(base.BaseDriver):
return self.send(target, ctxt, message, envelope=(version == 2.0)) return self.send(target, ctxt, message, envelope=(version == 2.0))
@_ensure_connect_called @_ensure_connect_called
def listen(self, target): def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
"""Construct a Listener for the given target.""" """Construct a Listener for the given target."""
LOG.debug("Listen to %s", target) LOG.debug("Listen to %s", target)
listener = ProtonListener(self) listener = ProtonListener(self)
self._ctrl.add_task(drivertasks.ListenTask(target, listener)) self._ctrl.add_task(drivertasks.ListenTask(target, listener))
return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
return listener return listener
@_ensure_connect_called @_ensure_connect_called
def listen_for_notifications(self, targets_and_priorities, pool): def listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
LOG.debug("Listen for notifications %s", targets_and_priorities) LOG.debug("Listen for notifications %s", targets_and_priorities)
if pool: if pool:
raise NotImplementedError('"pool" not implemented by ' raise NotImplementedError('"pool" not implemented by '
@ -284,7 +288,8 @@ class ProtonDriver(base.BaseDriver):
topic = '%s.%s' % (target.topic, priority) topic = '%s.%s' % (target.topic, priority)
t = messaging_target.Target(topic=topic) t = messaging_target.Target(topic=topic)
self._ctrl.add_task(drivertasks.ListenTask(t, listener, True)) self._ctrl.add_task(drivertasks.ListenTask(t, listener, True))
return listener return base.PollStyleListenerAdapter(listener, on_incoming_callback,
batch_size, batch_timeout)
def cleanup(self): def cleanup(self):
"""Release all resources.""" """Release all resources."""

View File

@ -28,7 +28,7 @@ LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq() zmq = zmq_async.import_zmq()
class ZmqServer(base.Listener): class ZmqServer(base.PollStyleListener):
def __init__(self, driver, conf, matchmaker, target, poller=None): def __init__(self, driver, conf, matchmaker, target, poller=None):
super(ZmqServer, self).__init__() super(ZmqServer, self).__init__()
@ -47,7 +47,7 @@ class ZmqServer(base.Listener):
self.consumers.append(self.sub_consumer) self.consumers.append(self.sub_consumer)
@base.batch_poll_helper @base.batch_poll_helper
def poll(self, timeout=None, prefetch_size=1): def poll(self, timeout=None):
message, socket = self.poller.poll( message, socket = self.poller.poll(
timeout or self.conf.rpc_poll_timeout) timeout or self.conf.rpc_poll_timeout)
return message return message
@ -63,7 +63,7 @@ class ZmqServer(base.Listener):
consumer.cleanup() consumer.cleanup()
class ZmqNotificationServer(base.Listener): class ZmqNotificationServer(base.PollStyleListener):
def __init__(self, driver, conf, matchmaker, targets_and_priorities): def __init__(self, driver, conf, matchmaker, targets_and_priorities):
super(ZmqNotificationServer, self).__init__() super(ZmqNotificationServer, self).__init__()
@ -82,7 +82,7 @@ class ZmqNotificationServer(base.Listener):
self.driver, self.conf, self.matchmaker, t, self.poller)) self.driver, self.conf, self.matchmaker, t, self.poller))
@base.batch_poll_helper @base.batch_poll_helper
def poll(self, timeout=None, prefetch_size=1): def poll(self, timeout=None):
message, socket = self.poller.poll( message, socket = self.poller.poll(
timeout or self.conf.rpc_poll_timeout) timeout or self.conf.rpc_poll_timeout)
return message return message

View File

@ -127,10 +127,9 @@ class NotificationServer(msg_server.MessageHandlingServer):
) )
def _create_listener(self): def _create_listener(self):
return msg_server.SingleMessageListenerAdapter( return self.transport._listen_for_notifications(
self.transport._listen_for_notifications( self._targets_priorities, self._pool,
self._targets_priorities, self._pool lambda incoming: self._on_incoming(incoming[0]), 1, None
)
) )
def _process_incoming(self, incoming): def _process_incoming(self, incoming):
@ -163,12 +162,9 @@ class BatchNotificationServer(NotificationServer):
self._batch_timeout = batch_timeout self._batch_timeout = batch_timeout
def _create_listener(self): def _create_listener(self):
return msg_server.BatchMessageListenerAdapter( return self.transport._listen_for_notifications(
self.transport._listen_for_notifications( self._targets_priorities, self._pool, self._on_incoming,
self._targets_priorities, self._pool self._batch_size, self._batch_timeout,
),
timeout=self._batch_timeout,
batch_size=self._batch_size
) )
def _process_incoming(self, incoming): def _process_incoming(self, incoming):

View File

@ -118,8 +118,9 @@ class RPCServer(msg_server.MessageHandlingServer):
self._target = target self._target = target
def _create_listener(self): def _create_listener(self):
return msg_server.SingleMessageListenerAdapter( return self.transport._listen(
self.transport._listen(self._target) self._target,
lambda incoming: self._on_incoming(incoming[0]), 1, None
) )
def _process_incoming(self, incoming): def _process_incoming(self, incoming):

View File

@ -33,7 +33,6 @@ import traceback
from oslo_config import cfg from oslo_config import cfg
from oslo_service import service from oslo_service import service
from oslo_utils import eventletutils from oslo_utils import eventletutils
from oslo_utils import excutils
from oslo_utils import timeutils from oslo_utils import timeutils
import six import six
from stevedore import driver from stevedore import driver
@ -297,41 +296,6 @@ def ordered(after=None, reset_after=None):
return _ordered return _ordered
@six.add_metaclass(abc.ABCMeta)
class MessageListenerAdapter(object):
def __init__(self, driver_listener, timeout=None):
self._driver_listener = driver_listener
self._timeout = timeout
@abc.abstractmethod
def poll(self):
"""Poll incoming and return incoming request"""
def stop(self):
self._driver_listener.stop()
def cleanup(self):
self._driver_listener.cleanup()
class SingleMessageListenerAdapter(MessageListenerAdapter):
def poll(self):
msgs = self._driver_listener.poll(prefetch_size=1,
timeout=self._timeout)
return msgs[0] if msgs else None
class BatchMessageListenerAdapter(MessageListenerAdapter):
def __init__(self, driver_listener, timeout=None, batch_size=1):
super(BatchMessageListenerAdapter, self).__init__(driver_listener,
timeout)
self._batch_size = batch_size
def poll(self):
return self._driver_listener.poll(prefetch_size=self._batch_size,
timeout=self._timeout)
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner): class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
"""Server for handling messages. """Server for handling messages.
@ -377,15 +341,21 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
self._executor_cls = mgr.driver self._executor_cls = mgr.driver
self._work_executor = None self._work_executor = None
self._poll_executor = None
self._started = False self._started = False
super(MessageHandlingServer, self).__init__() super(MessageHandlingServer, self).__init__()
def _on_incoming(self, incoming):
"""Hanles on_incoming event
:param incoming: incoming request.
"""
self._work_executor.submit(self._process_incoming, incoming)
@abc.abstractmethod @abc.abstractmethod
def _process_incoming(self, incoming): def _process_incoming(self, incoming):
"""Process incoming request """Perform processing incoming request
:param incoming: incoming request. :param incoming: incoming request.
""" """
@ -420,11 +390,6 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
'instantiate a new object.')) 'instantiate a new object.'))
self._started = True self._started = True
try:
self.listener = self._create_listener()
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
executor_opts = {} executor_opts = {}
if self.executor_type == "threading": if self.executor_type == "threading":
@ -440,9 +405,13 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
) )
self._work_executor = self._executor_cls(**executor_opts) self._work_executor = self._executor_cls(**executor_opts)
self._poll_executor = self._executor_cls(**executor_opts)
return lambda: self._poll_executor.submit(self._runner) try:
self.listener = self._create_listener()
except driver_base.TransportDriverError as ex:
raise ServerListenError(self.target, ex)
return self.listener.start
@ordered(after='start') @ordered(after='start')
def stop(self): def stop(self):
@ -456,24 +425,6 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
self.listener.stop() self.listener.stop()
self._started = False self._started = False
@excutils.forever_retry_uncaught_exceptions
def _runner(self):
while self._started:
incoming = self.listener.poll()
if incoming:
self._work_executor.submit(self._process_incoming, incoming)
# listener is stopped but we need to process all already consumed
# messages
while True:
incoming = self.listener.poll()
if incoming:
self._work_executor.submit(self._process_incoming, incoming)
else:
return
@ordered(after='stop') @ordered(after='stop')
def wait(self): def wait(self):
"""Wait for message processing to complete. """Wait for message processing to complete.
@ -485,7 +436,7 @@ class MessageHandlingServer(service.ServiceBase, _OrderedTaskRunner):
Once it's finished, the underlying driver resources associated to this Once it's finished, the underlying driver resources associated to this
server are released (like closing useless network connections). server are released (like closing useless network connections).
""" """
self._poll_executor.shutdown(wait=True) self.listener.wait()
self._work_executor.shutdown(wait=True) self._work_executor.shutdown(wait=True)
# Close listener connection after processing all messages # Close listener connection after processing all messages

View File

@ -106,7 +106,7 @@ class PikaPollerTestCase(unittest.TestCase):
self._poller_connection_mock.process_data_events.side_effect = f self._poller_connection_mock.process_data_events.side_effect = f
poller.start() poller.start()
res = poller.poll(prefetch_size=1) res = poller.poll(batch_size=1)
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
self.assertEqual(res[0], incoming_message_class_mock.return_value) self.assertEqual(res[0], incoming_message_class_mock.return_value)
self.assertEqual( self.assertEqual(
@ -116,7 +116,7 @@ class PikaPollerTestCase(unittest.TestCase):
poller.stop() poller.stop()
res2 = poller.poll(prefetch_size=n) res2 = poller.poll(batch_size=n)
self.assertEqual(len(res2), n - 1) self.assertEqual(len(res2), n - 1)
self.assertEqual(incoming_message_class_mock.call_count, n) self.assertEqual(incoming_message_class_mock.call_count, n)
@ -162,7 +162,7 @@ class PikaPollerTestCase(unittest.TestCase):
self._poller_connection_mock.process_data_events.side_effect = f self._poller_connection_mock.process_data_events.side_effect = f
poller.start() poller.start()
res = poller.poll(prefetch_size=n) res = poller.poll(batch_size=n)
self.assertEqual(len(res), n) self.assertEqual(len(res), n)
self.assertEqual(incoming_message_class_mock.call_count, n) self.assertEqual(incoming_message_class_mock.call_count, n)
@ -210,7 +210,7 @@ class PikaPollerTestCase(unittest.TestCase):
self._poller_connection_mock.process_data_events.side_effect = f self._poller_connection_mock.process_data_events.side_effect = f
poller.start() poller.start()
res = poller.poll(prefetch_size=n, timeout=timeout) res = poller.poll(batch_size=n, timeout=timeout)
self.assertEqual(len(res), success_count) self.assertEqual(len(res), success_count)
self.assertEqual(incoming_message_class_mock.call_count, success_count) self.assertEqual(incoming_message_class_mock.call_count, success_count)

View File

@ -203,7 +203,8 @@ class TestKafkaListener(test_utils.BaseTestCase):
def test_create_listener(self, fake_consumer, fake_ensure_connection): def test_create_listener(self, fake_consumer, fake_ensure_connection):
fake_target = oslo_messaging.Target(topic='fake_topic') fake_target = oslo_messaging.Target(topic='fake_topic')
fake_targets_and_priorities = [(fake_target, 'info')] fake_targets_and_priorities = [(fake_target, 'info')]
self.driver.listen_for_notifications(fake_targets_and_priorities) self.driver.listen_for_notifications(fake_targets_and_priorities, None,
None, None, None)
self.assertEqual(1, len(fake_consumer.mock_calls)) self.assertEqual(1, len(fake_consumer.mock_calls))
@mock.patch.object(kafka_driver.Connection, '_ensure_connection') @mock.patch.object(kafka_driver.Connection, '_ensure_connection')
@ -220,7 +221,8 @@ class TestKafkaListener(test_utils.BaseTestCase):
(oslo_messaging.Target(topic="fake_topic", (oslo_messaging.Target(topic="fake_topic",
exchange="test3"), 'error'), exchange="test3"), 'error'),
] ]
self.driver.listen_for_notifications(fake_targets_and_priorities) self.driver.listen_for_notifications(fake_targets_and_priorities, None,
None, None, None)
self.assertEqual(1, len(fake_consumer.mock_calls)) self.assertEqual(1, len(fake_consumer.mock_calls))
fake_consumer.assert_called_once_with(set(['fake_topic.error', fake_consumer.assert_called_once_with(set(['fake_topic.error',
'fake_topic.info']), 'fake_topic.info']),
@ -232,7 +234,8 @@ class TestKafkaListener(test_utils.BaseTestCase):
fake_target = oslo_messaging.Target(topic='fake_topic') fake_target = oslo_messaging.Target(topic='fake_topic')
fake_targets_and_priorities = [(fake_target, 'info')] fake_targets_and_priorities = [(fake_target, 'info')]
listener = self.driver.listen_for_notifications( listener = self.driver.listen_for_notifications(
fake_targets_and_priorities) fake_targets_and_priorities, None, None, None,
None)._poll_style_listener
listener.conn.consume = mock.MagicMock() listener.conn.consume = mock.MagicMock()
listener.conn.consume.return_value = ( listener.conn.consume.return_value = (
iter([kafka.common.KafkaMessage( iter([kafka.common.KafkaMessage(
@ -264,7 +267,8 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
targets_and_priorities = [(target, 'fake_info')] targets_and_priorities = [(target, 'fake_info')]
listener = self.driver.listen_for_notifications( listener = self.driver.listen_for_notifications(
targets_and_priorities) targets_and_priorities, None, None, None,
None)._poll_style_listener
fake_context = {"fake_context_key": "fake_context_value"} fake_context = {"fake_context_key": "fake_context_value"}
fake_message = {"fake_message_key": "fake_message_value"} fake_message = {"fake_message_key": "fake_message_value"}
self.driver.send_notification( self.driver.send_notification(
@ -281,7 +285,8 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
targets_and_priorities = [(target, 'fake_info')] targets_and_priorities = [(target, 'fake_info')]
listener = self.driver.listen_for_notifications( listener = self.driver.listen_for_notifications(
targets_and_priorities) targets_and_priorities, None, None, None,
None)._poll_style_listener
fake_context = {"fake_context_key": "fake_context_value"} fake_context = {"fake_context_key": "fake_context_value"}
fake_message = {"fake_message_key": "fake_message_value"} fake_message = {"fake_message_key": "fake_message_value"}
self.driver.send_notification( self.driver.send_notification(
@ -299,9 +304,10 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
targets_and_priorities = [(target, 'fake_info')] targets_and_priorities = [(target, 'fake_info')]
listener = self.driver.listen_for_notifications( listener = self.driver.listen_for_notifications(
targets_and_priorities) targets_and_priorities, None, None, None,
None)._poll_style_listener
deadline = time.time() + 3 deadline = time.time() + 3
received_message = listener.poll(timeout=3) received_message = listener.poll(batch_timeout=3)
self.assertEqual(0, int(deadline - time.time())) self.assertEqual(0, int(deadline - time.time()))
self.assertEqual([], received_message) self.assertEqual([], received_message)

View File

@ -435,7 +435,7 @@ class TestSendReceive(test_utils.BaseTestCase):
target = oslo_messaging.Target(topic='testtopic') target = oslo_messaging.Target(topic='testtopic')
listener = driver.listen(target) listener = driver.listen(target, None, None, None)._poll_style_listener
senders = [] senders = []
replies = [] replies = []
@ -525,7 +525,7 @@ class TestPollAsync(test_utils.BaseTestCase):
self.addCleanup(transport.cleanup) self.addCleanup(transport.cleanup)
driver = transport._driver driver = transport._driver
target = oslo_messaging.Target(topic='testtopic') target = oslo_messaging.Target(topic='testtopic')
listener = driver.listen(target) listener = driver.listen(target, None, None, None)._poll_style_listener
received = listener.poll(timeout=0.050) received = listener.poll(timeout=0.050)
self.assertEqual([], received) self.assertEqual([], received)
@ -541,8 +541,7 @@ class TestRacyWaitForReply(test_utils.BaseTestCase):
target = oslo_messaging.Target(topic='testtopic') target = oslo_messaging.Target(topic='testtopic')
listener = driver.listen(target) listener = driver.listen(target, None, None, None)._poll_style_listener
senders = [] senders = []
replies = [] replies = []
msgs = [] msgs = []
@ -878,7 +877,7 @@ class TestReplyWireFormat(test_utils.BaseTestCase):
server=self.server, server=self.server,
fanout=self.fanout) fanout=self.fanout)
listener = driver.listen(target) listener = driver.listen(target, None, None, None)._poll_style_listener
connection, producer = _create_producer(target) connection, producer = _create_producer(target)
self.addCleanup(connection.release) self.addCleanup(connection.release)

View File

@ -42,7 +42,7 @@ class ZmqTestPortsRange(zmq_common.ZmqBaseTestCase):
for i in range(10): for i in range(10):
try: try:
target = oslo_messaging.Target(topic='testtopic_' + str(i)) target = oslo_messaging.Target(topic='testtopic_' + str(i))
new_listener = self.driver.listen(target) new_listener = self.driver.listen(target, None, None, None)
listeners.append(new_listener) listeners.append(new_listener)
except zmq_socket.ZmqPortRangeExceededException: except zmq_socket.ZmqPortRangeExceededException:
pass pass

View File

@ -39,12 +39,14 @@ class TestServerListener(object):
self.message = None self.message = None
def listen(self, target): def listen(self, target):
self.listener = self.driver.listen(target) self.listener = self.driver.listen(target, None, None,
None)._poll_style_listener
self.executor.execute() self.executor.execute()
def listen_notifications(self, targets_and_priorities): def listen_notifications(self, targets_and_priorities):
self.listener = self.driver.listen_for_notifications( self.listener = self.driver.listen_for_notifications(
targets_and_priorities, {}) targets_and_priorities, None, None, None,
None)._poll_style_listener
self.executor.execute() self.executor.execute()
def _run(self): def _run(self):

View File

@ -29,7 +29,7 @@ load_tests = testscenarios.load_tests_apply_scenarios
class ServerSetupMixin(object): class ServerSetupMixin(object):
class Server(threading.Thread): class Server(object):
def __init__(self, transport, topic, server, endpoint, serializer): def __init__(self, transport, topic, server, endpoint, serializer):
self.controller = ServerSetupMixin.ServerController() self.controller = ServerSetupMixin.ServerController()
target = oslo_messaging.Target(topic=topic, server=server) target = oslo_messaging.Target(topic=topic, server=server)
@ -39,9 +39,6 @@ class ServerSetupMixin(object):
self.controller], self.controller],
serializer=serializer) serializer=serializer)
super(ServerSetupMixin.Server, self).__init__()
self.daemon = True
def wait(self): def wait(self):
# Wait for the executor to process the stop message, indicating all # Wait for the executor to process the stop message, indicating all
# test messages have been processed # test messages have been processed
@ -52,7 +49,7 @@ class ServerSetupMixin(object):
self.server.stop() self.server.stop()
self.server.wait() self.server.wait()
def run(self): def start(self):
self.server.start() self.server.start()
class ServerController(object): class ServerController(object):
@ -86,10 +83,7 @@ class ServerSetupMixin(object):
endpoint=endpoint, endpoint=endpoint,
serializer=self.serializer) serializer=self.serializer)
thread = threading.Thread(target=server.start) server.start()
thread.daemon = True
thread.start()
return server return server
def _stop_server(self, client, server, topic=None): def _stop_server(self, client, server, topic=None):
@ -492,9 +486,9 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
else: else:
endpoint1 = endpoint2 = TestEndpoint() endpoint1 = endpoint2 = TestEndpoint()
thread1 = self._setup_server(transport1, endpoint1, server1 = self._setup_server(transport1, endpoint1,
topic=self.topic1, server=self.server1) topic=self.topic1, server=self.server1)
thread2 = self._setup_server(transport2, endpoint2, server2 = self._setup_server(transport2, endpoint2,
topic=self.topic2, server=self.server2) topic=self.topic2, server=self.server2)
client1 = self._setup_client(transport1, topic=self.topic1) client1 = self._setup_client(transport1, topic=self.topic1)
@ -513,12 +507,10 @@ class TestMultipleServers(test_utils.BaseTestCase, ServerSetupMixin):
(client1.call if self.call1 else client1.cast)({}, 'ping', arg='1') (client1.call if self.call1 else client1.cast)({}, 'ping', arg='1')
(client2.call if self.call2 else client2.cast)({}, 'ping', arg='2') (client2.call if self.call2 else client2.cast)({}, 'ping', arg='2')
self.assertTrue(thread1.isAlive())
self._stop_server(client1.prepare(fanout=None), self._stop_server(client1.prepare(fanout=None),
thread1, topic=self.topic1) server1, topic=self.topic1)
self.assertTrue(thread2.isAlive())
self._stop_server(client2.prepare(fanout=None), self._stop_server(client2.prepare(fanout=None),
thread2, topic=self.topic2) server2, topic=self.topic2)
def check(pings, expect): def check(pings, expect):
self.assertEqual(len(expect), len(pings)) self.assertEqual(len(expect), len(pings))
@ -560,14 +552,13 @@ class TestServerLocking(test_utils.BaseTestCase):
class MessageHandlingServerImpl(oslo_messaging.MessageHandlingServer): class MessageHandlingServerImpl(oslo_messaging.MessageHandlingServer):
def _create_listener(self): def _create_listener(self):
pass return mock.Mock()
def _process_incoming(self, incoming): def _process_incoming(self, incoming):
pass pass
self.server = MessageHandlingServerImpl(mock.Mock(), mock.Mock()) self.server = MessageHandlingServerImpl(mock.Mock(), mock.Mock())
self.server._executor_cls = FakeExecutor self.server._executor_cls = FakeExecutor
self.server._create_listener = mock.Mock()
def test_start_stop_wait(self): def test_start_stop_wait(self):
# Test a simple execution of start, stop, wait in order # Test a simple execution of start, stop, wait in order
@ -576,9 +567,8 @@ class TestServerLocking(test_utils.BaseTestCase):
self.server.stop() self.server.stop()
self.server.wait() self.server.wait()
self.assertEqual(len(self.executors), 2) self.assertEqual(len(self.executors), 1)
self.assertEqual(self.executors[0]._calls, ['shutdown']) self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertTrue(self.server.listener.cleanup.called) self.assertTrue(self.server.listener.cleanup.called)
def test_reversed_order(self): def test_reversed_order(self):
@ -597,9 +587,8 @@ class TestServerLocking(test_utils.BaseTestCase):
self.server.wait() self.server.wait()
self.assertEqual(len(self.executors), 2) self.assertEqual(len(self.executors), 1)
self.assertEqual(self.executors[0]._calls, ['shutdown']) self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
def test_wait_for_running_task(self): def test_wait_for_running_task(self):
# Test that if 2 threads call a method simultaneously, both will wait, # Test that if 2 threads call a method simultaneously, both will wait,
@ -660,9 +649,8 @@ class TestServerLocking(test_utils.BaseTestCase):
# Check that both threads have finished, start was only called once, # Check that both threads have finished, start was only called once,
# and execute ran # and execute ran
self.assertTrue(waiter_finished.is_set()) self.assertTrue(waiter_finished.is_set())
self.assertEqual(2, len(self.executors)) self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, []) self.assertEqual(self.executors[0]._calls, [])
self.assertEqual(self.executors[1]._calls, ['submit'])
def test_start_stop_wait_stop_wait(self): def test_start_stop_wait_stop_wait(self):
# Test that we behave correctly when calling stop/wait more than once. # Test that we behave correctly when calling stop/wait more than once.
@ -674,9 +662,8 @@ class TestServerLocking(test_utils.BaseTestCase):
self.server.stop() self.server.stop()
self.server.wait() self.server.wait()
self.assertEqual(len(self.executors), 2) self.assertEqual(len(self.executors), 1)
self.assertEqual(self.executors[0]._calls, ['shutdown']) self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertTrue(self.server.listener.cleanup.called) self.assertTrue(self.server.listener.cleanup.called)
def test_state_wrapping(self): def test_state_wrapping(self):
@ -711,9 +698,8 @@ class TestServerLocking(test_utils.BaseTestCase):
complete_waiting_callback.wait() complete_waiting_callback.wait()
# The server should have started, but stop should not have been called # The server should have started, but stop should not have been called
self.assertEqual(2, len(self.executors)) self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, []) self.assertEqual(self.executors[0]._calls, [])
self.assertEqual(self.executors[1]._calls, ['submit'])
self.assertFalse(thread1_finished.is_set()) self.assertFalse(thread1_finished.is_set())
self.server.stop() self.server.stop()
@ -721,20 +707,17 @@ class TestServerLocking(test_utils.BaseTestCase):
# We should have gone through all the states, and thread1 should still # We should have gone through all the states, and thread1 should still
# be waiting # be waiting
self.assertEqual(2, len(self.executors)) self.assertEqual(1, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['shutdown']) self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown'])
self.assertFalse(thread1_finished.is_set()) self.assertFalse(thread1_finished.is_set())
# Start again # Start again
self.server.start() self.server.start()
# We should now record 4 executors (2 for each server) # We should now record 4 executors (2 for each server)
self.assertEqual(4, len(self.executors)) self.assertEqual(2, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['shutdown']) self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown']) self.assertEqual(self.executors[1]._calls, [])
self.assertEqual(self.executors[2]._calls, [])
self.assertEqual(self.executors[3]._calls, ['submit'])
self.assertFalse(thread1_finished.is_set()) self.assertFalse(thread1_finished.is_set())
# Allow thread1 to complete # Allow thread1 to complete
@ -743,11 +726,9 @@ class TestServerLocking(test_utils.BaseTestCase):
# thread1 should now have finished, and stop should not have been # thread1 should now have finished, and stop should not have been
# called again on either the first or second executor # called again on either the first or second executor
self.assertEqual(4, len(self.executors)) self.assertEqual(2, len(self.executors))
self.assertEqual(self.executors[0]._calls, ['shutdown']) self.assertEqual(self.executors[0]._calls, ['shutdown'])
self.assertEqual(self.executors[1]._calls, ['submit', 'shutdown']) self.assertEqual(self.executors[1]._calls, [])
self.assertEqual(self.executors[2]._calls, [])
self.assertEqual(self.executors[3]._calls, ['submit'])
self.assertTrue(thread1_finished.is_set()) self.assertTrue(thread1_finished.is_set())
@mock.patch.object(server_module, 'DEFAULT_LOG_AFTER', 1) @mock.patch.object(server_module, 'DEFAULT_LOG_AFTER', 1)

View File

@ -131,14 +131,15 @@ class TestAmqpSend(_AmqpBrokerTestCase):
"""Verify unused listener can cleanly shutdown.""" """Verify unused listener can cleanly shutdown."""
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url) driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
listener = driver.listen(target) listener = driver.listen(target, None, None, None)._poll_style_listener
self.assertIsInstance(listener, amqp_driver.ProtonListener) self.assertIsInstance(listener, amqp_driver.ProtonListener)
driver.cleanup() driver.cleanup()
def test_send_no_reply(self): def test_send_no_reply(self):
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url) driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1) listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True}, rc = driver.send(target, {"context": True},
{"msg": "value"}, wait_for_reply=False) {"msg": "value"}, wait_for_reply=False)
self.assertIsNone(rc) self.assertIsNone(rc)
@ -150,9 +151,11 @@ class TestAmqpSend(_AmqpBrokerTestCase):
def test_send_exchange_with_reply(self): def test_send_exchange_with_reply(self):
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url) driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target1 = oslo_messaging.Target(topic="test-topic", exchange="e1") target1 = oslo_messaging.Target(topic="test-topic", exchange="e1")
listener1 = _ListenerThread(driver.listen(target1), 1) listener1 = _ListenerThread(
driver.listen(target1, None, None, None)._poll_style_listener, 1)
target2 = oslo_messaging.Target(topic="test-topic", exchange="e2") target2 = oslo_messaging.Target(topic="test-topic", exchange="e2")
listener2 = _ListenerThread(driver.listen(target2), 1) listener2 = _ListenerThread(
driver.listen(target2, None, None, None)._poll_style_listener, 1)
rc = driver.send(target1, {"context": "whatever"}, rc = driver.send(target1, {"context": "whatever"},
{"method": "echo", "id": "e1"}, {"method": "echo", "id": "e1"},
@ -178,9 +181,11 @@ class TestAmqpSend(_AmqpBrokerTestCase):
"""Verify the direct, shared, and fanout message patterns work.""" """Verify the direct, shared, and fanout message patterns work."""
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url) driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target1 = oslo_messaging.Target(topic="test-topic", server="server1") target1 = oslo_messaging.Target(topic="test-topic", server="server1")
listener1 = _ListenerThread(driver.listen(target1), 4) listener1 = _ListenerThread(
driver.listen(target1, None, None, None)._poll_style_listener, 4)
target2 = oslo_messaging.Target(topic="test-topic", server="server2") target2 = oslo_messaging.Target(topic="test-topic", server="server2")
listener2 = _ListenerThread(driver.listen(target2), 3) listener2 = _ListenerThread(
driver.listen(target2, None, None, None)._poll_style_listener, 3)
shared_target = oslo_messaging.Target(topic="test-topic") shared_target = oslo_messaging.Target(topic="test-topic")
fanout_target = oslo_messaging.Target(topic="test-topic", fanout_target = oslo_messaging.Target(topic="test-topic",
@ -250,7 +255,8 @@ class TestAmqpSend(_AmqpBrokerTestCase):
"""Verify send timeout.""" """Verify send timeout."""
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url) driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1) listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
# the listener will drop this message: # the listener will drop this message:
try: try:
@ -276,7 +282,8 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
notifications = [(oslo_messaging.Target(topic="topic-1"), 'info'), notifications = [(oslo_messaging.Target(topic="topic-1"), 'info'),
(oslo_messaging.Target(topic="topic-1"), 'error'), (oslo_messaging.Target(topic="topic-1"), 'error'),
(oslo_messaging.Target(topic="topic-2"), 'debug')] (oslo_messaging.Target(topic="topic-2"), 'debug')]
nl = driver.listen_for_notifications(notifications, None) nl = driver.listen_for_notifications(
notifications, None, None, None, None)._poll_style_listener
# send one for each support version: # send one for each support version:
msg_count = len(notifications) * 2 msg_count = len(notifications) * 2
@ -340,7 +347,8 @@ class TestAuthentication(test_utils.BaseTestCase):
url = oslo_messaging.TransportURL.parse(self.conf, addr) url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url) driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1) listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True}, rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True) {"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc) self.assertIsNotNone(rc)
@ -358,7 +366,8 @@ class TestAuthentication(test_utils.BaseTestCase):
url = oslo_messaging.TransportURL.parse(self.conf, addr) url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url) driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1) _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
self.assertRaises(oslo_messaging.MessagingTimeout, self.assertRaises(oslo_messaging.MessagingTimeout,
driver.send, driver.send,
target, {"context": True}, target, {"context": True},
@ -429,7 +438,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr) url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url) driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1) listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True}, rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True) {"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc) self.assertIsNotNone(rc)
@ -447,7 +457,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr) url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url) driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1) _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
self.assertRaises(oslo_messaging.MessagingTimeout, self.assertRaises(oslo_messaging.MessagingTimeout,
driver.send, driver.send,
target, {"context": True}, target, {"context": True},
@ -467,7 +478,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr) url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url) driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
_ListenerThread(driver.listen(target), 1) _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
self.assertRaises(oslo_messaging.MessagingTimeout, self.assertRaises(oslo_messaging.MessagingTimeout,
driver.send, driver.send,
target, {"context": True}, target, {"context": True},
@ -487,7 +499,8 @@ mech_list: ${mechs}
url = oslo_messaging.TransportURL.parse(self.conf, addr) url = oslo_messaging.TransportURL.parse(self.conf, addr)
driver = amqp_driver.ProtonDriver(self.conf, url) driver = amqp_driver.ProtonDriver(self.conf, url)
target = oslo_messaging.Target(topic="test-topic") target = oslo_messaging.Target(topic="test-topic")
listener = _ListenerThread(driver.listen(target), 1) listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 1)
rc = driver.send(target, {"context": True}, rc = driver.send(target, {"context": True},
{"method": "echo"}, wait_for_reply=True) {"method": "echo"}, wait_for_reply=True)
self.assertIsNotNone(rc) self.assertIsNotNone(rc)
@ -522,7 +535,8 @@ class TestFailover(test_utils.BaseTestCase):
driver = amqp_driver.ProtonDriver(self.conf, self._broker_url) driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
target = oslo_messaging.Target(topic="my-topic") target = oslo_messaging.Target(topic="my-topic")
listener = _ListenerThread(driver.listen(target), 2) listener = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 2)
# wait for listener links to come up # wait for listener links to come up
# 4 == 3 links per listener + 1 for the global reply queue # 4 == 3 links per listener + 1 for the global reply queue
@ -608,8 +622,10 @@ class TestFailover(test_utils.BaseTestCase):
target = oslo_messaging.Target(topic="my-topic") target = oslo_messaging.Target(topic="my-topic")
bcast = oslo_messaging.Target(topic="my-topic", fanout=True) bcast = oslo_messaging.Target(topic="my-topic", fanout=True)
listener1 = _ListenerThread(driver.listen(target), 2) listener1 = _ListenerThread(
listener2 = _ListenerThread(driver.listen(target), 2) driver.listen(target, None, None, None)._poll_style_listener, 2)
listener2 = _ListenerThread(
driver.listen(target, None, None, None)._poll_style_listener, 2)
# wait for 7 sending links to become active on the broker. # wait for 7 sending links to become active on the broker.
# 7 = 3 links per Listener + 1 global reply link # 7 = 3 links per Listener + 1 global reply link

View File

@ -38,7 +38,7 @@ class _FakeDriver(object):
def send_notification(self, *args, **kwargs): def send_notification(self, *args, **kwargs):
pass pass
def listen(self, target): def listen(self, target, on_incoming_callback, batch_size, batch_timeout):
pass pass
@ -314,10 +314,10 @@ class TestTransportMethodArgs(test_utils.BaseTestCase):
t = transport.Transport(_FakeDriver(cfg.CONF)) t = transport.Transport(_FakeDriver(cfg.CONF))
self.mox.StubOutWithMock(t._driver, 'listen') self.mox.StubOutWithMock(t._driver, 'listen')
t._driver.listen(self._target) t._driver.listen(self._target, None, 1, None)
self.mox.ReplayAll() self.mox.ReplayAll()
t._listen(self._target) t._listen(self._target, None, 1, None)
class TestTransportUrlCustomisation(test_utils.BaseTestCase): class TestTransportUrlCustomisation(test_utils.BaseTestCase):

View File

@ -96,21 +96,26 @@ class Transport(object):
self._driver.send_notification(target, ctxt, message, version, self._driver.send_notification(target, ctxt, message, version,
retry=retry) retry=retry)
def _listen(self, target): def _listen(self, target, on_incoming_callback, batch_size, batch_timeout):
if not (target.topic and target.server): if not (target.topic and target.server):
raise exceptions.InvalidTarget('A server\'s target must have ' raise exceptions.InvalidTarget('A server\'s target must have '
'topic and server names specified', 'topic and server names specified',
target) target)
return self._driver.listen(target) return self._driver.listen(target, on_incoming_callback, batch_size,
batch_timeout)
def _listen_for_notifications(self, targets_and_priorities, pool): def _listen_for_notifications(self, targets_and_priorities, pool,
on_incoming_callback, batch_size,
batch_timeout):
for target, priority in targets_and_priorities: for target, priority in targets_and_priorities:
if not target.topic: if not target.topic:
raise exceptions.InvalidTarget('A target must have ' raise exceptions.InvalidTarget('A target must have '
'topic specified', 'topic specified',
target) target)
return self._driver.listen_for_notifications( return self._driver.listen_for_notifications(
targets_and_priorities, pool) targets_and_priorities, pool, on_incoming_callback, batch_size,
batch_timeout
)
def cleanup(self): def cleanup(self):
"""Release all resources associated with this transport.""" """Release all resources associated with this transport."""