Implement wait_for_reply timeout in rabbit driver

Note - the tests use timeout=0.01 because timeout=0 doesn't seem to be
working for some reason.

Change-Id: I814a3decdad5ddce0a1a2301ba2d59fa928b53a7
This commit is contained in:
Mark McLoughlin 2013-08-01 21:45:32 +01:00
parent cb2623f46e
commit 7c305150ff
4 changed files with 50 additions and 34 deletions

View File

@ -20,6 +20,7 @@ import Queue
import threading import threading
import uuid import uuid
from oslo import messaging
from oslo.messaging._drivers import amqp as rpc_amqp from oslo.messaging._drivers import amqp as rpc_amqp
from oslo.messaging._drivers import base from oslo.messaging._drivers import base
from oslo.messaging._drivers import common as rpc_common from oslo.messaging._drivers import common as rpc_common
@ -97,8 +98,6 @@ class AMQPListener(base.Listener):
while True: while True:
if self.incoming: if self.incoming:
return self.incoming.pop(0) return self.incoming.pop(0)
# FIXME(markmc): timeout?
self.conn.consume(limit=1) self.conn.consume(limit=1)
@ -108,8 +107,12 @@ class ReplyWaiters(object):
self._queues = {} self._queues = {}
self._wrn_threshhold = 10 self._wrn_threshhold = 10
def get(self, msg_id): def get(self, msg_id, timeout):
return self._queues[msg_id].get() try:
return self._queues[msg_id].get(block=True, timeout=timeout)
except Queue.Empty:
raise messaging.MessagingTimeout('Timed out waiting for a reply '
'to message ID %s' % msg_id)
def put(self, msg_id, message_data): def put(self, msg_id, message_data):
queue = self._queues.get(msg_id) queue = self._queues.get(msg_id)
@ -176,7 +179,7 @@ class ReplyWaiter(object):
result = data['result'] result = data['result']
return result, ending return result, ending
def _poll_connection(self, msg_id): def _poll_connection(self, msg_id, timeout):
while True: while True:
while self.incoming: while self.incoming:
message_data = self.incoming.pop(0) message_data = self.incoming.pop(0)
@ -187,20 +190,23 @@ class ReplyWaiter(object):
self.waiters.put(incoming_msg_id, message_data) self.waiters.put(incoming_msg_id, message_data)
# FIXME(markmc): timeout? try:
self.conn.consume(limit=1) self.conn.consume(limit=1, timeout=timeout)
except rpc_common.Timeout:
raise messaging.MessagingTimeout('Timed out waiting for a '
'reply to message ID %s'
% msg_id)
def _poll_queue(self, msg_id): def _poll_queue(self, msg_id, timeout):
while True: while True:
# FIXME(markmc): timeout? message = self.waiters.get(msg_id, timeout)
message = self.waiters.get(msg_id)
if message is None: if message is None:
return None, None, True # lock was released return None, None, True # lock was released
reply, ending = self._process_reply(message) reply, ending = self._process_reply(message)
return reply, ending, False return reply, ending, False
def wait(self, msg_id): def wait(self, msg_id, timeout):
# NOTE(markmc): multiple threads may call this # NOTE(markmc): multiple threads may call this
# First thread calls consume, when it gets its reply # First thread calls consume, when it gets its reply
# it wakes up other threads and they call consume # it wakes up other threads and they call consume
@ -211,7 +217,7 @@ class ReplyWaiter(object):
if self.conn_lock.acquire(False): if self.conn_lock.acquire(False):
try: try:
while True: while True:
reply, ending = self._poll_connection(msg_id) reply, ending = self._poll_connection(msg_id, timeout)
if reply: if reply:
final_reply = reply final_reply = reply
elif ending: elif ending:
@ -220,7 +226,7 @@ class ReplyWaiter(object):
self.conn_lock.release() self.conn_lock.release()
self.waiters.wake_all(msg_id) self.waiters.wake_all(msg_id)
else: else:
reply, ending, trylock = self._poll_queue(msg_id) reply, ending, trylock = self._poll_queue(msg_id, timeout)
if trylock: if trylock:
continue continue
if reply: if reply:
@ -308,8 +314,7 @@ class AMQPDriverBase(base.BaseDriver):
conn.topic_send(topic, msg, timeout=timeout) conn.topic_send(topic, msg, timeout=timeout)
if wait_for_reply: if wait_for_reply:
# FIXME(markmc): timeout? result = self._waiter.wait(msg_id, timeout)
result = self._waiter.wait(msg_id)
if isinstance(result, Exception): if isinstance(result, Exception):
raise result raise result
return result return result

View File

@ -600,9 +600,9 @@ class Connection(object):
"""Send a notify message on a topic.""" """Send a notify message on a topic."""
self.publisher_send(NotifyPublisher, topic, msg) self.publisher_send(NotifyPublisher, topic, msg)
def consume(self, limit=None): def consume(self, limit=None, timeout=None):
"""Consume from all queues/consumers.""" """Consume from all queues/consumers."""
it = self.iterconsume(limit=limit) it = self.iterconsume(limit=limit, timeout=timeout)
while True: while True:
try: try:
it.next() it.next()

View File

@ -743,9 +743,9 @@ class Connection(object):
"""Send a notify message on a topic.""" """Send a notify message on a topic."""
self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs) self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs)
def consume(self, limit=None): def consume(self, limit=None, timeout=None):
"""Consume from all queues/consumers.""" """Consume from all queues/consumers."""
it = self.iterconsume(limit=limit) it = self.iterconsume(limit=limit, timeout=timeout)
while True: while True:
try: try:
it.next() it.next()

View File

@ -63,11 +63,17 @@ class TestSendReceive(test_utils.BaseTestCase):
('failure', dict(failure=True)), ('failure', dict(failure=True)),
] ]
_timeout = [
('no_timeout', dict(timeout=None)),
('timeout', dict(timeout=0.01)), # FIXME(markmc): timeout=0 is broken?
]
@classmethod @classmethod
def generate_scenarios(cls): def generate_scenarios(cls):
cls.scenarios = testscenarios.multiply_scenarios(cls._n_senders, cls.scenarios = testscenarios.multiply_scenarios(cls._n_senders,
cls._context, cls._context,
cls._failure) cls._failure,
cls._timeout)
def setUp(self): def setUp(self):
super(TestSendReceive, self).setUp() super(TestSendReceive, self).setUp()
@ -95,11 +101,13 @@ class TestSendReceive(test_utils.BaseTestCase):
replies.append(driver.send(target, replies.append(driver.send(target,
self.ctxt, self.ctxt,
{'foo': i}, {'foo': i},
wait_for_reply=True)) wait_for_reply=True,
timeout=self.timeout))
self.assertFalse(self.failure) self.assertFalse(self.failure)
except ZeroDivisionError as e: self.assertIsNone(self.timeout)
except (ZeroDivisionError, messaging.MessagingTimeout) as e:
replies.append(e) replies.append(e)
self.assertTrue(self.failure) self.assertTrue(self.failure or self.timeout is not None)
while len(senders) < self.n_senders: while len(senders) < self.n_senders:
senders.append(threading.Thread(target=send_and_wait_for_reply, senders.append(threading.Thread(target=send_and_wait_for_reply,
@ -120,22 +128,25 @@ class TestSendReceive(test_utils.BaseTestCase):
order[-1], order[-2] = order[-2], order[-1] order[-1], order[-2] = order[-2], order[-1]
for i in order: for i in order:
if self.failure: if self.timeout is None:
try: if self.failure:
raise ZeroDivisionError try:
except Exception: raise ZeroDivisionError
failure = sys.exc_info() except Exception:
msgs[i].reply(failure=failure) failure = sys.exc_info()
else: msgs[i].reply(failure=failure)
msgs[i].reply({'bar': msgs[i].message['foo']}) else:
msgs[i].reply({'bar': msgs[i].message['foo']})
senders[i].join() senders[i].join()
self.assertEqual(len(replies), len(senders)) self.assertEqual(len(replies), len(senders))
for i, reply in enumerate(replies): for i, reply in enumerate(replies):
if not self.failure: if self.timeout is not None:
self.assertEqual(reply, {'bar': order[i]}) self.assertIsInstance(reply, messaging.MessagingTimeout)
else: elif self.failure:
self.assertIsInstance(reply, ZeroDivisionError) self.assertIsInstance(reply, ZeroDivisionError)
else:
self.assertEqual(reply, {'bar': order[i]})
TestSendReceive.generate_scenarios() TestSendReceive.generate_scenarios()