Implement wait_for_reply timeout in rabbit driver
Note - the tests use timeout=0.01 because timeout=0 doesn't seem to be working for some reason. Change-Id: I814a3decdad5ddce0a1a2301ba2d59fa928b53a7
This commit is contained in:
parent
cb2623f46e
commit
7c305150ff
@ -20,6 +20,7 @@ import Queue
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from oslo import messaging
|
||||
from oslo.messaging._drivers import amqp as rpc_amqp
|
||||
from oslo.messaging._drivers import base
|
||||
from oslo.messaging._drivers import common as rpc_common
|
||||
@ -97,8 +98,6 @@ class AMQPListener(base.Listener):
|
||||
while True:
|
||||
if self.incoming:
|
||||
return self.incoming.pop(0)
|
||||
|
||||
# FIXME(markmc): timeout?
|
||||
self.conn.consume(limit=1)
|
||||
|
||||
|
||||
@ -108,8 +107,12 @@ class ReplyWaiters(object):
|
||||
self._queues = {}
|
||||
self._wrn_threshhold = 10
|
||||
|
||||
def get(self, msg_id):
|
||||
return self._queues[msg_id].get()
|
||||
def get(self, msg_id, timeout):
|
||||
try:
|
||||
return self._queues[msg_id].get(block=True, timeout=timeout)
|
||||
except Queue.Empty:
|
||||
raise messaging.MessagingTimeout('Timed out waiting for a reply '
|
||||
'to message ID %s' % msg_id)
|
||||
|
||||
def put(self, msg_id, message_data):
|
||||
queue = self._queues.get(msg_id)
|
||||
@ -176,7 +179,7 @@ class ReplyWaiter(object):
|
||||
result = data['result']
|
||||
return result, ending
|
||||
|
||||
def _poll_connection(self, msg_id):
|
||||
def _poll_connection(self, msg_id, timeout):
|
||||
while True:
|
||||
while self.incoming:
|
||||
message_data = self.incoming.pop(0)
|
||||
@ -187,20 +190,23 @@ class ReplyWaiter(object):
|
||||
|
||||
self.waiters.put(incoming_msg_id, message_data)
|
||||
|
||||
# FIXME(markmc): timeout?
|
||||
self.conn.consume(limit=1)
|
||||
try:
|
||||
self.conn.consume(limit=1, timeout=timeout)
|
||||
except rpc_common.Timeout:
|
||||
raise messaging.MessagingTimeout('Timed out waiting for a '
|
||||
'reply to message ID %s'
|
||||
% msg_id)
|
||||
|
||||
def _poll_queue(self, msg_id):
|
||||
def _poll_queue(self, msg_id, timeout):
|
||||
while True:
|
||||
# FIXME(markmc): timeout?
|
||||
message = self.waiters.get(msg_id)
|
||||
message = self.waiters.get(msg_id, timeout)
|
||||
if message is None:
|
||||
return None, None, True # lock was released
|
||||
|
||||
reply, ending = self._process_reply(message)
|
||||
return reply, ending, False
|
||||
|
||||
def wait(self, msg_id):
|
||||
def wait(self, msg_id, timeout):
|
||||
# NOTE(markmc): multiple threads may call this
|
||||
# First thread calls consume, when it gets its reply
|
||||
# it wakes up other threads and they call consume
|
||||
@ -211,7 +217,7 @@ class ReplyWaiter(object):
|
||||
if self.conn_lock.acquire(False):
|
||||
try:
|
||||
while True:
|
||||
reply, ending = self._poll_connection(msg_id)
|
||||
reply, ending = self._poll_connection(msg_id, timeout)
|
||||
if reply:
|
||||
final_reply = reply
|
||||
elif ending:
|
||||
@ -220,7 +226,7 @@ class ReplyWaiter(object):
|
||||
self.conn_lock.release()
|
||||
self.waiters.wake_all(msg_id)
|
||||
else:
|
||||
reply, ending, trylock = self._poll_queue(msg_id)
|
||||
reply, ending, trylock = self._poll_queue(msg_id, timeout)
|
||||
if trylock:
|
||||
continue
|
||||
if reply:
|
||||
@ -308,8 +314,7 @@ class AMQPDriverBase(base.BaseDriver):
|
||||
conn.topic_send(topic, msg, timeout=timeout)
|
||||
|
||||
if wait_for_reply:
|
||||
# FIXME(markmc): timeout?
|
||||
result = self._waiter.wait(msg_id)
|
||||
result = self._waiter.wait(msg_id, timeout)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
@ -600,9 +600,9 @@ class Connection(object):
|
||||
"""Send a notify message on a topic."""
|
||||
self.publisher_send(NotifyPublisher, topic, msg)
|
||||
|
||||
def consume(self, limit=None):
|
||||
def consume(self, limit=None, timeout=None):
|
||||
"""Consume from all queues/consumers."""
|
||||
it = self.iterconsume(limit=limit)
|
||||
it = self.iterconsume(limit=limit, timeout=timeout)
|
||||
while True:
|
||||
try:
|
||||
it.next()
|
||||
|
@ -743,9 +743,9 @@ class Connection(object):
|
||||
"""Send a notify message on a topic."""
|
||||
self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs)
|
||||
|
||||
def consume(self, limit=None):
|
||||
def consume(self, limit=None, timeout=None):
|
||||
"""Consume from all queues/consumers."""
|
||||
it = self.iterconsume(limit=limit)
|
||||
it = self.iterconsume(limit=limit, timeout=timeout)
|
||||
while True:
|
||||
try:
|
||||
it.next()
|
||||
|
@ -63,11 +63,17 @@ class TestSendReceive(test_utils.BaseTestCase):
|
||||
('failure', dict(failure=True)),
|
||||
]
|
||||
|
||||
_timeout = [
|
||||
('no_timeout', dict(timeout=None)),
|
||||
('timeout', dict(timeout=0.01)), # FIXME(markmc): timeout=0 is broken?
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def generate_scenarios(cls):
|
||||
cls.scenarios = testscenarios.multiply_scenarios(cls._n_senders,
|
||||
cls._context,
|
||||
cls._failure)
|
||||
cls._failure,
|
||||
cls._timeout)
|
||||
|
||||
def setUp(self):
|
||||
super(TestSendReceive, self).setUp()
|
||||
@ -95,11 +101,13 @@ class TestSendReceive(test_utils.BaseTestCase):
|
||||
replies.append(driver.send(target,
|
||||
self.ctxt,
|
||||
{'foo': i},
|
||||
wait_for_reply=True))
|
||||
wait_for_reply=True,
|
||||
timeout=self.timeout))
|
||||
self.assertFalse(self.failure)
|
||||
except ZeroDivisionError as e:
|
||||
self.assertIsNone(self.timeout)
|
||||
except (ZeroDivisionError, messaging.MessagingTimeout) as e:
|
||||
replies.append(e)
|
||||
self.assertTrue(self.failure)
|
||||
self.assertTrue(self.failure or self.timeout is not None)
|
||||
|
||||
while len(senders) < self.n_senders:
|
||||
senders.append(threading.Thread(target=send_and_wait_for_reply,
|
||||
@ -120,22 +128,25 @@ class TestSendReceive(test_utils.BaseTestCase):
|
||||
order[-1], order[-2] = order[-2], order[-1]
|
||||
|
||||
for i in order:
|
||||
if self.failure:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except Exception:
|
||||
failure = sys.exc_info()
|
||||
msgs[i].reply(failure=failure)
|
||||
else:
|
||||
msgs[i].reply({'bar': msgs[i].message['foo']})
|
||||
if self.timeout is None:
|
||||
if self.failure:
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except Exception:
|
||||
failure = sys.exc_info()
|
||||
msgs[i].reply(failure=failure)
|
||||
else:
|
||||
msgs[i].reply({'bar': msgs[i].message['foo']})
|
||||
senders[i].join()
|
||||
|
||||
self.assertEqual(len(replies), len(senders))
|
||||
for i, reply in enumerate(replies):
|
||||
if not self.failure:
|
||||
self.assertEqual(reply, {'bar': order[i]})
|
||||
else:
|
||||
if self.timeout is not None:
|
||||
self.assertIsInstance(reply, messaging.MessagingTimeout)
|
||||
elif self.failure:
|
||||
self.assertIsInstance(reply, ZeroDivisionError)
|
||||
else:
|
||||
self.assertEqual(reply, {'bar': order[i]})
|
||||
|
||||
|
||||
TestSendReceive.generate_scenarios()
|
||||
|
Loading…
x
Reference in New Issue
Block a user