diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index cfbd16108..8c361627c 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -93,6 +93,8 @@ class AMQPListener(base.Listener): class ReplyWaiters(object): + WAKE_UP = object() + def __init__(self): self._queues = {} self._wrn_threshhold = 10 @@ -104,6 +106,12 @@ class ReplyWaiters(object): raise messaging.MessagingTimeout('Timed out waiting for a reply ' 'to message ID %s' % msg_id) + def check(self, msg_id): + try: + return self._queues[msg_id].get(block=False) + except Queue.Empty: + return None + def put(self, msg_id, message_data): queue = self._queues.get(msg_id) if not queue: @@ -117,7 +125,7 @@ class ReplyWaiters(object): def wake_all(self, except_id): msg_ids = [i for i in self._queues.keys() if i != except_id] for msg_id in msg_ids: - self.put(msg_id, None) + self.put(msg_id, self.WAKE_UP) def add(self, msg_id, queue): self._queues[msg_id] = queue @@ -189,10 +197,20 @@ class ReplyWaiter(object): % msg_id) def _poll_queue(self, msg_id, timeout): + message = self.waiters.get(msg_id, timeout) + if message is self.waiters.WAKE_UP: + return None, None, True # lock was released + + reply, ending = self._process_reply(message) + return reply, ending, False + + def _check_queue(self, msg_id): while True: - message = self.waiters.get(msg_id, timeout) + message = self.waiters.check(msg_id) + if message is self.waiters.WAKE_UP: + continue if message is None: - return None, None, True # lock was released + return None, None, True # queue is empty reply, ending = self._process_reply(message) return reply, ending, False @@ -213,6 +231,18 @@ class ReplyWaiter(object): if self.conn_lock.acquire(False): # Ok, we're the thread responsible for polling the connection try: + # Check the queue to see if a previous lock-holding thread + # queued up a reply already + while True: + reply, ending, empty = self._check_queue(msg_id) + if empty: + break + if not ending: + final_reply = reply + else: + return final_reply + + # Now actually poll the connection while True: reply, ending = self._poll_connection(msg_id, timeout) if not ending: diff --git a/tests/test_rabbit.py b/tests/test_rabbit.py index 142252ca9..3ee243ef2 100644 --- a/tests/test_rabbit.py +++ b/tests/test_rabbit.py @@ -24,6 +24,7 @@ import kombu import testscenarios from oslo import messaging +from oslo.messaging._drivers import amqpdriver from oslo.messaging._drivers import common as driver_common from oslo.messaging._drivers import impl_rabbit as rabbit_driver from oslo.messaging.openstack.common import jsonutils @@ -235,6 +236,86 @@ class TestSendReceive(test_utils.BaseTestCase): TestSendReceive.generate_scenarios() +class TestRacyWaitForReply(test_utils.BaseTestCase): + + def setUp(self): + super(TestRacyWaitForReply, self).setUp() + self.messaging_conf.transport_driver = 'rabbit' + self.messaging_conf.in_memory = True + + def test_send_receive(self): + transport = messaging.get_transport(self.conf) + self.addCleanup(transport.cleanup) + + driver = transport._driver + + target = messaging.Target(topic='testtopic') + + listener = driver.listen(target) + + senders = [] + replies = [] + msgs = [] + + wait_conditions = [] + orig_reply_waiter = amqpdriver.ReplyWaiter.wait + + def reply_waiter(self, msg_id, timeout): + if wait_conditions: + with wait_conditions[0]: + wait_conditions.pop().wait() + return orig_reply_waiter(self, msg_id, timeout) + + self.stubs.Set(amqpdriver.ReplyWaiter, 'wait', reply_waiter) + + def send_and_wait_for_reply(i): + replies.append(driver.send(target, + {}, + {'tx_id': i}, + wait_for_reply=True, + timeout=None)) + + while len(senders) < 2: + t = threading.Thread(target=send_and_wait_for_reply, + args=(len(senders), )) + t.daemon = True + senders.append(t) + + # Start the first guy, receive his message, but delay his polling + notify_condition = threading.Condition() + wait_conditions.append(notify_condition) + senders[0].start() + + msgs.append(listener.poll()) + self.assertEqual(msgs[-1].message, {'tx_id': 0}) + + # Start the second guy, receive his message + senders[1].start() + + msgs.append(listener.poll()) + self.assertEqual(msgs[-1].message, {'tx_id': 1}) + + # Reply to both in order, making the second thread queue + # the reply meant for the first thread + msgs[0].reply({'rx_id': 0}) + msgs[1].reply({'rx_id': 1}) + + # Wait for the second thread to finish + senders[1].join() + + # Let the first thread continue + with notify_condition: + notify_condition.notify() + + # Wait for the first thread to finish + senders[0].join() + + # Verify replies were received out of order + self.assertEqual(len(replies), len(senders)) + self.assertEqual(replies[0], {'rx_id': 1}) + self.assertEqual(replies[1], {'rx_id': 0}) + + def _declare_queue(target): connection = kombu.connection.BrokerConnection(transport='memory')