From 7c305150ffba2d31083eba42ee1bf6e3244befe7 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Thu, 1 Aug 2013 21:45:32 +0100 Subject: [PATCH] Implement wait_for_reply timeout in rabbit driver Note - the tests use timeout=0.01 because timeout=0 doesn't seem to be working for some reason. Change-Id: I814a3decdad5ddce0a1a2301ba2d59fa928b53a7 --- oslo/messaging/_drivers/amqpdriver.py | 35 ++++++++++++---------- oslo/messaging/_drivers/impl_qpid.py | 4 +-- oslo/messaging/_drivers/impl_rabbit.py | 4 +-- tests/test_rabbit.py | 41 ++++++++++++++++---------- 4 files changed, 50 insertions(+), 34 deletions(-) diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index dd1247fbe..2f1b33ef7 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -20,6 +20,7 @@ import Queue import threading import uuid +from oslo import messaging from oslo.messaging._drivers import amqp as rpc_amqp from oslo.messaging._drivers import base from oslo.messaging._drivers import common as rpc_common @@ -97,8 +98,6 @@ class AMQPListener(base.Listener): while True: if self.incoming: return self.incoming.pop(0) - - # FIXME(markmc): timeout? self.conn.consume(limit=1) @@ -108,8 +107,12 @@ class ReplyWaiters(object): self._queues = {} self._wrn_threshhold = 10 - def get(self, msg_id): - return self._queues[msg_id].get() + def get(self, msg_id, timeout): + try: + return self._queues[msg_id].get(block=True, timeout=timeout) + except Queue.Empty: + raise messaging.MessagingTimeout('Timed out waiting for a reply ' + 'to message ID %s' % msg_id) def put(self, msg_id, message_data): queue = self._queues.get(msg_id) @@ -176,7 +179,7 @@ class ReplyWaiter(object): result = data['result'] return result, ending - def _poll_connection(self, msg_id): + def _poll_connection(self, msg_id, timeout): while True: while self.incoming: message_data = self.incoming.pop(0) @@ -187,20 +190,23 @@ class ReplyWaiter(object): self.waiters.put(incoming_msg_id, message_data) - # FIXME(markmc): timeout? - self.conn.consume(limit=1) + try: + self.conn.consume(limit=1, timeout=timeout) + except rpc_common.Timeout: + raise messaging.MessagingTimeout('Timed out waiting for a ' + 'reply to message ID %s' + % msg_id) - def _poll_queue(self, msg_id): + def _poll_queue(self, msg_id, timeout): while True: - # FIXME(markmc): timeout? - message = self.waiters.get(msg_id) + message = self.waiters.get(msg_id, timeout) if message is None: return None, None, True # lock was released reply, ending = self._process_reply(message) return reply, ending, False - def wait(self, msg_id): + def wait(self, msg_id, timeout): # NOTE(markmc): multiple threads may call this # First thread calls consume, when it gets its reply # it wakes up other threads and they call consume @@ -211,7 +217,7 @@ class ReplyWaiter(object): if self.conn_lock.acquire(False): try: while True: - reply, ending = self._poll_connection(msg_id) + reply, ending = self._poll_connection(msg_id, timeout) if reply: final_reply = reply elif ending: @@ -220,7 +226,7 @@ class ReplyWaiter(object): self.conn_lock.release() self.waiters.wake_all(msg_id) else: - reply, ending, trylock = self._poll_queue(msg_id) + reply, ending, trylock = self._poll_queue(msg_id, timeout) if trylock: continue if reply: @@ -308,8 +314,7 @@ class AMQPDriverBase(base.BaseDriver): conn.topic_send(topic, msg, timeout=timeout) if wait_for_reply: - # FIXME(markmc): timeout? - result = self._waiter.wait(msg_id) + result = self._waiter.wait(msg_id, timeout) if isinstance(result, Exception): raise result return result diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py index aacb130c9..cdb77924a 100644 --- a/oslo/messaging/_drivers/impl_qpid.py +++ b/oslo/messaging/_drivers/impl_qpid.py @@ -600,9 +600,9 @@ class Connection(object): """Send a notify message on a topic.""" self.publisher_send(NotifyPublisher, topic, msg) - def consume(self, limit=None): + def consume(self, limit=None, timeout=None): """Consume from all queues/consumers.""" - it = self.iterconsume(limit=limit) + it = self.iterconsume(limit=limit, timeout=timeout) while True: try: it.next() diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index 943bac0b2..d997d78ec 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -743,9 +743,9 @@ class Connection(object): """Send a notify message on a topic.""" self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs) - def consume(self, limit=None): + def consume(self, limit=None, timeout=None): """Consume from all queues/consumers.""" - it = self.iterconsume(limit=limit) + it = self.iterconsume(limit=limit, timeout=timeout) while True: try: it.next() diff --git a/tests/test_rabbit.py b/tests/test_rabbit.py index 0116f2182..4f845b8be 100644 --- a/tests/test_rabbit.py +++ b/tests/test_rabbit.py @@ -63,11 +63,17 @@ class TestSendReceive(test_utils.BaseTestCase): ('failure', dict(failure=True)), ] + _timeout = [ + ('no_timeout', dict(timeout=None)), + ('timeout', dict(timeout=0.01)), # FIXME(markmc): timeout=0 is broken? + ] + @classmethod def generate_scenarios(cls): cls.scenarios = testscenarios.multiply_scenarios(cls._n_senders, cls._context, - cls._failure) + cls._failure, + cls._timeout) def setUp(self): super(TestSendReceive, self).setUp() @@ -95,11 +101,13 @@ class TestSendReceive(test_utils.BaseTestCase): replies.append(driver.send(target, self.ctxt, {'foo': i}, - wait_for_reply=True)) + wait_for_reply=True, + timeout=self.timeout)) self.assertFalse(self.failure) - except ZeroDivisionError as e: + self.assertIsNone(self.timeout) + except (ZeroDivisionError, messaging.MessagingTimeout) as e: replies.append(e) - self.assertTrue(self.failure) + self.assertTrue(self.failure or self.timeout is not None) while len(senders) < self.n_senders: senders.append(threading.Thread(target=send_and_wait_for_reply, @@ -120,22 +128,25 @@ class TestSendReceive(test_utils.BaseTestCase): order[-1], order[-2] = order[-2], order[-1] for i in order: - if self.failure: - try: - raise ZeroDivisionError - except Exception: - failure = sys.exc_info() - msgs[i].reply(failure=failure) - else: - msgs[i].reply({'bar': msgs[i].message['foo']}) + if self.timeout is None: + if self.failure: + try: + raise ZeroDivisionError + except Exception: + failure = sys.exc_info() + msgs[i].reply(failure=failure) + else: + msgs[i].reply({'bar': msgs[i].message['foo']}) senders[i].join() self.assertEqual(len(replies), len(senders)) for i, reply in enumerate(replies): - if not self.failure: - self.assertEqual(reply, {'bar': order[i]}) - else: + if self.timeout is not None: + self.assertIsInstance(reply, messaging.MessagingTimeout) + elif self.failure: self.assertIsInstance(reply, ZeroDivisionError) + else: + self.assertEqual(reply, {'bar': order[i]}) TestSendReceive.generate_scenarios()