Fix race-condition in rabbit reply processing

Concurrency. Sigh.

A sequence of events like this is possible:

  - We send a request from thread A
  - Thread B, which is waiting for a response, gets scheduled
  - Thread B receives our response and queues it up
  - Thread B receives its own response and drops the connection lock
  - Thread A grabs the connection lock and waits for a response that will
    never arrive, because thread B has already received and queued it

The obvious solution is that when we grab the connection lock, we should
check whether a previous lock-holding thread had already received our
response and queued it up.

Change-Id: I88b0d55d5a40814a84d82ed4f42d5ba85d2ef7e0
This commit is contained in:
Mark McLoughlin 2013-08-26 11:10:57 +01:00
parent 361092a488
commit aebe53f242
2 changed files with 114 additions and 3 deletions

View File

@ -93,6 +93,8 @@ class AMQPListener(base.Listener):
class ReplyWaiters(object):
WAKE_UP = object()
def __init__(self):
self._queues = {}
self._wrn_threshhold = 10
@ -104,6 +106,12 @@ class ReplyWaiters(object):
raise messaging.MessagingTimeout('Timed out waiting for a reply '
'to message ID %s' % msg_id)
def check(self, msg_id):
    """Fetch an already-queued message for ``msg_id`` without blocking.

    Returns the next item from the queue registered under ``msg_id``, or
    ``None`` when that queue is currently empty. A ``KeyError`` propagates
    if no queue is registered under ``msg_id``.
    """
    try:
        queued = self._queues[msg_id].get(block=False)
    except Queue.Empty:
        # Nothing has been queued up for this waiter.
        queued = None
    return queued
def put(self, msg_id, message_data):
queue = self._queues.get(msg_id)
if not queue:
@ -117,7 +125,7 @@ class ReplyWaiters(object):
def wake_all(self, except_id):
    """Wake every registered waiter other than ``except_id``.

    Each other waiter's queue receives the ``WAKE_UP`` sentinel via
    ``put``. The sentinel is used rather than ``None`` because ``check``
    returns ``None`` to mean "queue is empty", so a queued ``None`` could
    not be told apart from an empty queue.
    """
    # Snapshot the ids first so the loop does not iterate the dict itself.
    msg_ids = [i for i in self._queues if i != except_id]
    for msg_id in msg_ids:
        self.put(msg_id, self.WAKE_UP)
def add(self, msg_id, queue):
self._queues[msg_id] = queue
@ -189,10 +197,20 @@ class ReplyWaiter(object):
% msg_id)
def _poll_queue(self, msg_id, timeout):
    """Wait on this waiter's queue for a message put there by another thread.

    Returns a ``(reply, ending, lock_released)`` tuple. When the item
    obtained from ``self.waiters.get`` is the ``WAKE_UP`` sentinel the
    result is ``(None, None, True)``; otherwise the item is passed to
    ``_process_reply`` and the third element is ``False``.
    """
    queued = self.waiters.get(msg_id, timeout)
    if queued is not self.waiters.WAKE_UP:
        reply, ending = self._process_reply(queued)
        return reply, ending, False
    # Woken up without a reply: the lock was released
    return None, None, True
def _check_queue(self, msg_id):
    """Look, without blocking, for a reply that was already queued for us.

    Returns a ``(reply, ending, empty)`` tuple. ``empty`` is ``True`` when
    ``self.waiters.check`` reports nothing queued; otherwise the queued
    message is passed to ``_process_reply`` and ``empty`` is ``False``.
    """
    while True:
        message = self.waiters.check(msg_id)
        if message is self.waiters.WAKE_UP:
            # Wake-up sentinels carry no reply; skip them and look again.
            continue
        if message is None:
            return None, None, True  # queue is empty
        reply, ending = self._process_reply(message)
        return reply, ending, False
@ -213,6 +231,18 @@ class ReplyWaiter(object):
if self.conn_lock.acquire(False):
# Ok, we're the thread responsible for polling the connection
try:
# Check the queue to see if a previous lock-holding thread
# queued up a reply already
while True:
reply, ending, empty = self._check_queue(msg_id)
if empty:
break
if not ending:
final_reply = reply
else:
return final_reply
# Now actually poll the connection
while True:
reply, ending = self._poll_connection(msg_id, timeout)
if not ending:

View File

@ -24,6 +24,7 @@ import kombu
import testscenarios
from oslo import messaging
from oslo.messaging._drivers import amqpdriver
from oslo.messaging._drivers import common as driver_common
from oslo.messaging._drivers import impl_rabbit as rabbit_driver
from oslo.messaging.openstack.common import jsonutils
@ -235,6 +236,86 @@ class TestSendReceive(test_utils.BaseTestCase):
TestSendReceive.generate_scenarios()
class TestRacyWaitForReply(test_utils.BaseTestCase):
    """Regression test: a reply is queued before its waiter starts polling.

    Two threads each send a message and wait for a reply. The first
    thread is held back before it starts waiting, so the second thread is
    the one that receives (and queues up) the first thread's reply. The
    first thread must then find that queued reply once it is released.
    """

    def setUp(self):
        super(TestRacyWaitForReply, self).setUp()
        self.messaging_conf.transport_driver = 'rabbit'
        self.messaging_conf.in_memory = True

    def test_send_receive(self):
        transport = messaging.get_transport(self.conf)
        self.addCleanup(transport.cleanup)

        driver = transport._driver
        target = messaging.Target(topic='testtopic')
        listener = driver.listen(target)

        senders = []
        replies = []
        msgs = []

        # An Event (rather than a bare Condition.wait/notify pair) cannot
        # lose the wake-up if the main thread signals before the first
        # sender has started waiting.
        release_first_sender = threading.Event()
        orig_reply_waiter = amqpdriver.ReplyWaiter.wait

        def reply_waiter(self, msg_id, timeout):
            # Only the first sender is delayed; identifying it by thread
            # means the second sender can never be held back by mistake.
            if threading.current_thread() is senders[0]:
                release_first_sender.wait()
            return orig_reply_waiter(self, msg_id, timeout)

        self.stubs.Set(amqpdriver.ReplyWaiter, 'wait', reply_waiter)

        def send_and_wait_for_reply(i):
            replies.append(driver.send(target,
                                       {},
                                       {'tx_id': i},
                                       wait_for_reply=True,
                                       timeout=None))

        while len(senders) < 2:
            t = threading.Thread(target=send_and_wait_for_reply,
                                 args=(len(senders), ))
            t.daemon = True
            senders.append(t)

        # Start the first sender and receive its message; its wait for
        # the reply stays blocked until release_first_sender is set
        senders[0].start()

        msgs.append(listener.poll())
        self.assertEqual(msgs[-1].message, {'tx_id': 0})

        # Start the second sender and receive its message
        senders[1].start()

        msgs.append(listener.poll())
        self.assertEqual(msgs[-1].message, {'tx_id': 1})

        # Reply to both in order, making the second thread queue
        # the reply meant for the first thread
        msgs[0].reply({'rx_id': 0})
        msgs[1].reply({'rx_id': 1})

        # Wait for the second thread to finish
        senders[1].join()

        # Let the first thread continue
        release_first_sender.set()

        # Wait for the first thread to finish
        senders[0].join()

        # Verify replies were received out of order
        self.assertEqual(len(replies), len(senders))
        self.assertEqual(replies[0], {'rx_id': 1})
        self.assertEqual(replies[1], {'rx_id': 0})
def _declare_queue(target):
connection = kombu.connection.BrokerConnection(transport='memory')