Implement wait_for_reply timeout in rabbit driver

Note - the tests use timeout=0.01 because timeout=0 doesn't seem to be
working for some reason.

Change-Id: I814a3decdad5ddce0a1a2301ba2d59fa928b53a7
This commit is contained in:
Mark McLoughlin 2013-08-01 21:45:32 +01:00
parent cb2623f46e
commit 7c305150ff
4 changed files with 50 additions and 34 deletions

View File

@ -20,6 +20,7 @@ import Queue
import threading
import uuid
from oslo import messaging
from oslo.messaging._drivers import amqp as rpc_amqp
from oslo.messaging._drivers import base
from oslo.messaging._drivers import common as rpc_common
@ -97,8 +98,6 @@ class AMQPListener(base.Listener):
while True:
if self.incoming:
return self.incoming.pop(0)
# FIXME(markmc): timeout?
self.conn.consume(limit=1)
@ -108,8 +107,12 @@ class ReplyWaiters(object):
self._queues = {}
self._wrn_threshhold = 10
def get(self, msg_id):
return self._queues[msg_id].get()
def get(self, msg_id, timeout):
try:
return self._queues[msg_id].get(block=True, timeout=timeout)
except Queue.Empty:
raise messaging.MessagingTimeout('Timed out waiting for a reply '
'to message ID %s' % msg_id)
def put(self, msg_id, message_data):
queue = self._queues.get(msg_id)
@ -176,7 +179,7 @@ class ReplyWaiter(object):
result = data['result']
return result, ending
def _poll_connection(self, msg_id):
def _poll_connection(self, msg_id, timeout):
while True:
while self.incoming:
message_data = self.incoming.pop(0)
@ -187,20 +190,23 @@ class ReplyWaiter(object):
self.waiters.put(incoming_msg_id, message_data)
# FIXME(markmc): timeout?
self.conn.consume(limit=1)
try:
self.conn.consume(limit=1, timeout=timeout)
except rpc_common.Timeout:
raise messaging.MessagingTimeout('Timed out waiting for a '
'reply to message ID %s'
% msg_id)
def _poll_queue(self, msg_id):
def _poll_queue(self, msg_id, timeout):
while True:
# FIXME(markmc): timeout?
message = self.waiters.get(msg_id)
message = self.waiters.get(msg_id, timeout)
if message is None:
return None, None, True # lock was released
reply, ending = self._process_reply(message)
return reply, ending, False
def wait(self, msg_id):
def wait(self, msg_id, timeout):
# NOTE(markmc): multiple threads may call this
# First thread calls consume, when it gets its reply
# it wakes up other threads and they call consume
@ -211,7 +217,7 @@ class ReplyWaiter(object):
if self.conn_lock.acquire(False):
try:
while True:
reply, ending = self._poll_connection(msg_id)
reply, ending = self._poll_connection(msg_id, timeout)
if reply:
final_reply = reply
elif ending:
@ -220,7 +226,7 @@ class ReplyWaiter(object):
self.conn_lock.release()
self.waiters.wake_all(msg_id)
else:
reply, ending, trylock = self._poll_queue(msg_id)
reply, ending, trylock = self._poll_queue(msg_id, timeout)
if trylock:
continue
if reply:
@ -308,8 +314,7 @@ class AMQPDriverBase(base.BaseDriver):
conn.topic_send(topic, msg, timeout=timeout)
if wait_for_reply:
# FIXME(markmc): timeout?
result = self._waiter.wait(msg_id)
result = self._waiter.wait(msg_id, timeout)
if isinstance(result, Exception):
raise result
return result

View File

@ -600,9 +600,9 @@ class Connection(object):
"""Send a notify message on a topic."""
self.publisher_send(NotifyPublisher, topic, msg)
def consume(self, limit=None):
def consume(self, limit=None, timeout=None):
"""Consume from all queues/consumers."""
it = self.iterconsume(limit=limit)
it = self.iterconsume(limit=limit, timeout=timeout)
while True:
try:
it.next()

View File

@ -743,9 +743,9 @@ class Connection(object):
"""Send a notify message on a topic."""
self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs)
def consume(self, limit=None):
def consume(self, limit=None, timeout=None):
"""Consume from all queues/consumers."""
it = self.iterconsume(limit=limit)
it = self.iterconsume(limit=limit, timeout=timeout)
while True:
try:
it.next()

View File

@ -63,11 +63,17 @@ class TestSendReceive(test_utils.BaseTestCase):
('failure', dict(failure=True)),
]
_timeout = [
('no_timeout', dict(timeout=None)),
('timeout', dict(timeout=0.01)), # FIXME(markmc): timeout=0 is broken?
]
@classmethod
def generate_scenarios(cls):
cls.scenarios = testscenarios.multiply_scenarios(cls._n_senders,
cls._context,
cls._failure)
cls._failure,
cls._timeout)
def setUp(self):
super(TestSendReceive, self).setUp()
@ -95,11 +101,13 @@ class TestSendReceive(test_utils.BaseTestCase):
replies.append(driver.send(target,
self.ctxt,
{'foo': i},
wait_for_reply=True))
wait_for_reply=True,
timeout=self.timeout))
self.assertFalse(self.failure)
except ZeroDivisionError as e:
self.assertIsNone(self.timeout)
except (ZeroDivisionError, messaging.MessagingTimeout) as e:
replies.append(e)
self.assertTrue(self.failure)
self.assertTrue(self.failure or self.timeout is not None)
while len(senders) < self.n_senders:
senders.append(threading.Thread(target=send_and_wait_for_reply,
@ -120,22 +128,25 @@ class TestSendReceive(test_utils.BaseTestCase):
order[-1], order[-2] = order[-2], order[-1]
for i in order:
if self.failure:
try:
raise ZeroDivisionError
except Exception:
failure = sys.exc_info()
msgs[i].reply(failure=failure)
else:
msgs[i].reply({'bar': msgs[i].message['foo']})
if self.timeout is None:
if self.failure:
try:
raise ZeroDivisionError
except Exception:
failure = sys.exc_info()
msgs[i].reply(failure=failure)
else:
msgs[i].reply({'bar': msgs[i].message['foo']})
senders[i].join()
self.assertEqual(len(replies), len(senders))
for i, reply in enumerate(replies):
if not self.failure:
self.assertEqual(reply, {'bar': order[i]})
else:
if self.timeout is not None:
self.assertIsInstance(reply, messaging.MessagingTimeout)
elif self.failure:
self.assertIsInstance(reply, ZeroDivisionError)
else:
self.assertEqual(reply, {'bar': order[i]})
TestSendReceive.generate_scenarios()