# Copyright 2013 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

__all__ = ['AMQPDriverBase']

import logging
import threading
import uuid

from six import moves

from oslo import messaging
from oslo.messaging._drivers import amqp as rpc_amqp
from oslo.messaging._drivers import base
from oslo.messaging._drivers import common as rpc_common

LOG = logging.getLogger(__name__)


class AMQPIncomingMessage(base.IncomingMessage):
    """An incoming message that knows where to send its reply.

    In addition to what the base class stores, this keeps the ``msg_id``
    and ``reply_q`` that accompanied the message so that ``reply()`` can
    address the response.
    """

    def __init__(self, listener, ctxt, message, msg_id, reply_q):
        super(AMQPIncomingMessage, self).__init__(listener, ctxt, message)

        self.msg_id = msg_id
        self.reply_q = reply_q

    def _send_reply(self, conn, reply=None, failure=None,
                    ending=False, log_failure=True):
        """Send one reply message over ``conn``.

        :param conn: connection object providing ``direct_send()``
        :param reply: value placed under the ``'result'`` key
        :param failure: if truthy, serialized with
                        ``rpc_common.serialize_remote_exception()`` and
                        placed under the ``'failure'`` key
        :param ending: if true, the message carries ``'ending': True``
        :param log_failure: forwarded to ``serialize_remote_exception()``
        """
        if failure:
            failure = rpc_common.serialize_remote_exception(failure,
                                                            log_failure)

        msg = {'result': reply, 'failure': failure}
        if ending:
            msg['ending'] = True

        rpc_amqp._add_unique_id(msg)

        # If a reply_q exists, add the msg_id to the reply and pass the
        # reply_q to direct_send() to use it as the response queue.
        # Otherwise use the msg_id for backward compatibility.
        if self.reply_q:
            msg['_msg_id'] = self.msg_id
            conn.direct_send(self.reply_q, rpc_common.serialize_msg(msg))
        else:
            conn.direct_send(self.msg_id, rpc_common.serialize_msg(msg))

    def reply(self, reply=None, failure=None, log_failure=True):
        """Send the reply (or failure), followed by an 'ending' message.

        Both messages are sent over a single connection obtained from the
        listener's driver.
        """
        with self.listener.driver._get_connection() as conn:
            self._send_reply(conn, reply, failure, log_failure=log_failure)
            self._send_reply(conn, ending=True)


class AMQPListener(base.Listener):
    """Callable consumer that buffers received messages until polled."""

    def __init__(self, driver, target, conn):
        super(AMQPListener, self).__init__(driver, target)
        self.conn = conn
        self.msg_id_cache = rpc_amqp._MsgIdCache()
        # Messages received via __call__() but not yet returned by poll().
        self.incoming = []

    def __call__(self, message):
        """Wrap a received message and append it to ``self.incoming``."""
        # FIXME(markmc): logging isn't driver specific
        rpc_common._safe_log(LOG.debug, 'received %s', message)

        self.msg_id_cache.check_duplicate_message(message)
        ctxt = rpc_amqp.unpack_context(self.conf, message)

        self.incoming.append(AMQPIncomingMessage(self,
                                                 ctxt.to_dict(),
                                                 message,
                                                 ctxt.msg_id,
                                                 ctxt.reply_q))

    def poll(self):
        """Return the oldest buffered message, consuming until one exists."""
        while True:
            if self.incoming:
                return self.incoming.pop(0)
            self.conn.consume(limit=1)


class ReplyWaiters(object):
    """Registry mapping a msg_id to the queue its waiting caller reads."""

    # Sentinel placed on a waiter's queue by wake_all().
    WAKE_UP = object()

    def __init__(self):
        self._queues = {}
        # Number of registered queues above which add() logs a warning;
        # doubled each time the warning fires.
        self._wrn_threshold = 10

    def get(self, msg_id, timeout):
        """Block for up to ``timeout`` for the next item queued for msg_id.

        :raises: messaging.MessagingTimeout if nothing arrives in time
        """
        try:
            return self._queues[msg_id].get(block=True, timeout=timeout)
        except moves.queue.Empty:
            raise messaging.MessagingTimeout('Timed out waiting for a reply '
                                             'to message ID %s' % msg_id)

    def check(self, msg_id):
        """Return the next item queued for msg_id, or None if it is empty."""
        try:
            return self._queues[msg_id].get(block=False)
        except moves.queue.Empty:
            return None

    def put(self, msg_id, message_data):
        """Queue ``message_data`` for msg_id; warn if nobody is registered."""
        queue = self._queues.get(msg_id)
        # Compare against None explicitly: "no queue registered" is the
        # condition of interest, not the truthiness of the queue object.
        if queue is None:
            LOG.warning('No calling threads waiting for msg_id : %(msg_id)s'
                        ', message : %(data)s', {'msg_id': msg_id,
                                                 'data': message_data})
            LOG.warning('_queues: %s', self._queues)
        else:
            queue.put(message_data)

    def wake_all(self, except_id):
        """Put WAKE_UP on every registered queue except ``except_id``'s."""
        msg_ids = [i for i in self._queues.keys() if i != except_id]
        for msg_id in msg_ids:
            self.put(msg_id, self.WAKE_UP)

    def add(self, msg_id, queue):
        """Register ``queue`` for msg_id, warning if the registry grows."""
        self._queues[msg_id] = queue
        if len(self._queues) > self._wrn_threshold:
            LOG.warning('Number of call queues is greater than warning '
                        'threshold: %d. There could be a leak.',
                        self._wrn_threshold)
            self._wrn_threshold *= 2

    def remove(self, msg_id):
        """Unregister msg_id; raises KeyError if it was never added."""
        del self._queues[msg_id]


class ReplyWaiter(object):
    """Receives replies on one reply queue and hands them to callers.

    The instance registers itself as the direct consumer for ``reply_q``
    on ``conn``; see ``wait()`` for how multiple calling threads share
    that single connection.
    """

    def __init__(self, conf, reply_q, conn, allowed_remote_exmods):
        self.conf = conf
        self.conn = conn
        self.reply_q = reply_q
        self.allowed_remote_exmods = allowed_remote_exmods

        # Held by whichever thread is currently polling the connection.
        self.conn_lock = threading.Lock()
        # Raw replies delivered via __call__() and not yet dispatched.
        self.incoming = []
        self.msg_id_cache = rpc_amqp._MsgIdCache()
        self.waiters = ReplyWaiters()

        conn.declare_direct_consumer(reply_q, self)

    def __call__(self, message):
        """Consumer callback: buffer the received reply."""
        self.incoming.append(message)

    def listen(self, msg_id):
        """Register a fresh queue on which replies to msg_id are parked."""
        queue = moves.queue.Queue()
        self.waiters.add(msg_id, queue)

    def unlisten(self, msg_id):
        """Drop the queue registered for msg_id."""
        self.waiters.remove(msg_id)

    def _process_reply(self, data):
        """Decode one reply dict into a ``(result, ending)`` pair.

        A reply with a truthy ``'failure'`` yields the deserialized remote
        exception as ``result``; a reply flagged ``'ending'`` yields
        ``(None, True)``; otherwise ``result`` is ``data['result']``.
        """
        result = None
        ending = False
        self.msg_id_cache.check_duplicate_message(data)
        if data['failure']:
            failure = data['failure']
            result = rpc_common.deserialize_remote_exception(
                failure, self.allowed_remote_exmods)
        elif data.get('ending', False):
            ending = True
        else:
            result = data['result']
        return result, ending

    def _poll_connection(self, msg_id, timeout):
        """Consume from the connection until a reply for msg_id arrives.

        Replies addressed to other message IDs are handed to their
        waiters' queues.

        :returns: ``(result, ending)`` as produced by ``_process_reply()``
        :raises: messaging.MessagingTimeout if consuming times out
        """
        while True:
            while self.incoming:
                message_data = self.incoming.pop(0)

                incoming_msg_id = message_data.pop('_msg_id', None)
                if incoming_msg_id == msg_id:
                    return self._process_reply(message_data)

                self.waiters.put(incoming_msg_id, message_data)

            try:
                self.conn.consume(limit=1, timeout=timeout)
            except rpc_common.Timeout:
                raise messaging.MessagingTimeout('Timed out waiting for a '
                                                 'reply to message ID %s'
                                                 % msg_id)

    def _poll_queue(self, msg_id, timeout):
        """Block on msg_id's queue for a reply handed over by the poller.

        :returns: ``(reply, ending, trylock)``; ``trylock`` is True when
                  the item was the WAKE_UP sentinel
        """
        message = self.waiters.get(msg_id, timeout)
        if message is self.waiters.WAKE_UP:
            return None, None, True  # lock was released

        reply, ending = self._process_reply(message)
        return reply, ending, False

    def _check_queue(self, msg_id):
        """Non-blocking read of msg_id's queue, skipping WAKE_UP markers.

        :returns: ``(reply, ending, empty)``; ``empty`` is True when
                  nothing is queued
        """
        while True:
            message = self.waiters.check(msg_id)
            if message is self.waiters.WAKE_UP:
                continue
            if message is None:
                return None, None, True  # queue is empty

            reply, ending = self._process_reply(message)
            return reply, ending, False

    def wait(self, msg_id, timeout):
        """Wait for the complete reply to msg_id and return it.

        Returns the last non-ending reply received before the 'ending'
        marker (None if there was none).
        """
        #
        # NOTE(markmc): we're waiting for a reply for msg_id to come in on
        # the reply_q, but there may be other threads also waiting for replies
        # to other msg_ids
        #
        # Only one thread can be consuming from the queue using this connection
        # and we don't want to hold open a connection per thread, so instead we
        # have the first thread take responsibility for passing replies not
        # intended for itself to the appropriate thread.
        #
        final_reply = None
        while True:
            if self.conn_lock.acquire(False):
                # Ok, we're the thread responsible for polling the connection
                try:
                    # Check the queue to see if a previous lock-holding thread
                    # queued up a reply already
                    while True:
                        reply, ending, empty = self._check_queue(msg_id)
                        if empty:
                            break
                        if not ending:
                            final_reply = reply
                        else:
                            return final_reply

                    # Now actually poll the connection
                    while True:
                        reply, ending = self._poll_connection(msg_id, timeout)
                        if not ending:
                            final_reply = reply
                        else:
                            return final_reply
                finally:
                    self.conn_lock.release()
                    # We've got our reply, tell the other threads to wake up
                    # so that one of them will take over the responsibility for
                    # polling the connection
                    self.waiters.wake_all(msg_id)
            else:
                # We're going to wait for the first thread to pass us our reply
                reply, ending, trylock = self._poll_queue(msg_id, timeout)
                if trylock:
                    # The first thread got its reply, let's try and take over
                    # the responsibility for polling
                    continue
                if not ending:
                    final_reply = reply
                else:
                    return final_reply


class AMQPDriverBase(base.BaseDriver):
    """Driver base implementing send/listen on top of a connection pool."""

    def __init__(self, conf, url, connection_pool,
                 default_exchange=None, allowed_remote_exmods=None):
        """Initialise the driver.

        :param conf: configuration object, passed to the base class
        :param url: transport URL, passed to the base class
        :param connection_pool: pool used when creating connections
        :param default_exchange: if truthy, overrides ``control_exchange``
        :param allowed_remote_exmods: modules from which remote exceptions
                                      may be deserialized; defaults to a
                                      new empty list
        """
        # A literal [] default would be one list shared by every instance
        # constructed without the argument; build a fresh one instead.
        if allowed_remote_exmods is None:
            allowed_remote_exmods = []

        super(AMQPDriverBase, self).__init__(conf, url, default_exchange,
                                             allowed_remote_exmods)

        self._server_params = self._server_params_from_url(self._url)

        self._default_exchange = default_exchange

        # FIXME(markmc): temp hack
        if self._default_exchange:
            self.conf.set_override('control_exchange',
                                   self._default_exchange)

        self._connection_pool = connection_pool

        # Guards lazy creation of the reply queue in _get_reply_q().
        self._reply_q_lock = threading.Lock()
        self._reply_q = None
        self._reply_q_conn = None
        self._waiter = None

    def _server_params_from_url(self, url):
        """Build a server-params dict from the first host of ``url``.

        Returns an empty dict when the URL has neither a virtual host nor
        any hosts.
        """
        sp = {}

        if url.virtual_host is not None:
            sp['virtual_host'] = url.virtual_host

        if url.hosts:
            # FIXME(markmc): support multiple hosts
            host = url.hosts[0]

            sp['hostname'] = host.hostname
            if host.port is not None:
                sp['port'] = host.port
            sp['username'] = host.username or ''
            sp['password'] = host.password or ''

        return sp

    def _get_connection(self, pooled=True):
        """Return a ConnectionContext for this driver's configuration."""
        # FIXME(markmc): we don't yet have a connection pool for each
        # Transport instance, so we'll only use the pool with the
        # transport configuration from the config file
        server_params = self._server_params or None
        if server_params:
            pooled = False
        return rpc_amqp.ConnectionContext(self.conf,
                                          self._connection_pool,
                                          pooled=pooled,
                                          server_params=server_params)

    def _get_reply_q(self):
        """Return the reply queue name, creating queue and waiter once."""
        with self._reply_q_lock:
            if self._reply_q is not None:
                return self._reply_q

            reply_q = 'reply_' + uuid.uuid4().hex

            conn = self._get_connection(pooled=False)

            self._waiter = ReplyWaiter(self.conf, reply_q, conn,
                                       self._allowed_remote_exmods)

            self._reply_q = reply_q
            self._reply_q_conn = conn

        return self._reply_q

    def _send(self, target, ctxt, message,
              wait_for_reply=None, timeout=None, envelope=True):
        """Send ``message`` to ``target``, optionally waiting for a reply.

        Note that ``message`` is updated in place with the bookkeeping
        keys added here and by the ``rpc_amqp`` helpers.

        :param envelope: if true, the message is passed through
                         ``rpc_common.serialize_msg()`` before sending
        :returns: the reply when ``wait_for_reply`` is true, else None
        :raises: the reply itself, if it is an Exception instance
        """

        # FIXME(markmc): remove this temporary hack
        class Context(object):
            def __init__(self, d):
                self.d = d

            def to_dict(self):
                return self.d

        context = Context(ctxt)
        msg = message

        if wait_for_reply:
            msg_id = uuid.uuid4().hex
            msg.update({'_msg_id': msg_id})
            LOG.debug('MSG_ID is %s', msg_id)
            msg.update({'_reply_q': self._get_reply_q()})

        rpc_amqp._add_unique_id(msg)
        rpc_amqp.pack_context(msg, context)

        if envelope:
            msg = rpc_common.serialize_msg(msg)

        if wait_for_reply:
            # Register before sending so that a reply cannot arrive before
            # there is a queue to park it on.
            self._waiter.listen(msg_id)

        try:
            with self._get_connection() as conn:
                if target.fanout:
                    conn.fanout_send(target.topic, msg)
                else:
                    topic = target.topic
                    if target.server:
                        topic = '%s.%s' % (target.topic, target.server)
                    conn.topic_send(topic, msg, timeout=timeout)

            if wait_for_reply:
                result = self._waiter.wait(msg_id, timeout)
                if isinstance(result, Exception):
                    raise result
                return result
        finally:
            if wait_for_reply:
                self._waiter.unlisten(msg_id)

    def send(self, target, ctxt, message,
             wait_for_reply=None, timeout=None):
        """Send an enveloped message; see ``_send()``."""
        return self._send(target, ctxt, message, wait_for_reply, timeout)

    def send_notification(self, target, ctxt, message, version):
        """Send a notification; enveloped only when ``version == 2.0``."""
        return self._send(target, ctxt, message, envelope=(version == 2.0))

    def listen(self, target):
        """Create a listener consuming the target's topic queues.

        Declares consumers for ``topic``, ``topic.server`` and the fanout
        for ``topic`` on a dedicated (non-pooled) connection.
        """
        conn = self._get_connection(pooled=False)

        listener = AMQPListener(self, target, conn)

        conn.declare_topic_consumer(target.topic, listener)
        conn.declare_topic_consumer('%s.%s' % (target.topic, target.server),
                                    listener)
        conn.declare_fanout_consumer(target.topic, listener)

        return listener

    def cleanup(self):
        """Empty the connection pool and drop the reference to it."""
        if self._connection_pool:
            self._connection_pool.empty()
        self._connection_pool = None