From 9c110a4c94def6089ed7923dad89e763604eb794 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 12 Aug 2013 07:48:00 +0100 Subject: [PATCH] Add transport URL support to rabbit driver If a transport URL is supplied, transform it into the server_params format which was previously used for cast_to_server() etc. Change-Id: I453734a71748dc8d3ffc02ead7bfb92ffb0a6c7c --- oslo/messaging/_drivers/amqpdriver.py | 41 +++++++++++++++++- tests/test_rabbit.py | 60 +++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 089601fb1..c2f6c22a3 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -24,6 +24,7 @@ from oslo import messaging from oslo.messaging._drivers import amqp as rpc_amqp from oslo.messaging._drivers import base from oslo.messaging._drivers import common as rpc_common +from oslo.messaging import _urls as urls LOG = logging.getLogger(__name__) @@ -247,6 +248,8 @@ class AMQPDriverBase(base.BaseDriver): super(AMQPDriverBase, self).__init__(conf, url, default_exchange, allowed_remote_exmods) + self._server_params = self._parse_url(self._url) + self._default_exchange = default_exchange # FIXME(markmc): temp hack @@ -260,10 +263,46 @@ class AMQPDriverBase(base.BaseDriver): self._reply_q_conn = None self._waiter = None + @staticmethod + def _parse_url(url): + if url is None: + return None + + parsed = urls.parse_url(url) + + # Make sure there's not a query string; that could identify + # requirements we can't comply with (e.g., ssl), so reject it if + # it's present + if parsed['parameters']: + raise messaging.InvalidTransportURL( + url, "Cannot comply with query string in transport URL") + + if not parsed['hosts']: + return None + + sp = { + 'virtual_host': parsed['virtual_host'], + } + + # FIXME(markmc): support multiple hosts + host = parsed['hosts'][0] + + if ':' in host['host']: + (sp['hostname'], sp['port']) = host['host'].split(':', 1) + sp['port'] = int(sp['port']) + else: + sp['hostname'] = host['host'] + + sp['username'] = host['username'] + sp['password'] = host['password'] + + return sp + def _get_connection(self, pooled=True): return rpc_amqp.ConnectionContext(self.conf, self._connection_pool, - pooled=pooled) + pooled=pooled, + server_params=self._server_params) def _get_reply_q(self): with self._reply_q_lock: diff --git a/tests/test_rabbit.py b/tests/test_rabbit.py index 06d21a0da..4b6b28d26 100644 --- a/tests/test_rabbit.py +++ b/tests/test_rabbit.py @@ -47,6 +47,66 @@ class TestRabbitDriverLoad(test_utils.BaseTestCase): self.assertIsInstance(transport._driver, rabbit_driver.RabbitDriver) +class TestRabbitTransportURL(test_utils.BaseTestCase): + + scenarios = [ + ('none', dict(url=None, expected=None)), + ('empty', dict(url='rabbit:///', expected=None)), + ('localhost', + dict(url='rabbit://localhost/', + expected=dict(hostname='localhost', + username='', + password='', + virtual_host=''))), + ('no_creds', + dict(url='rabbit://host/virtual_host', + expected=dict(hostname='host', + username='', + password='', + virtual_host='virtual_host'))), + ('no_port', + dict(url='rabbit://user:password@host/virtual_host', + expected=dict(hostname='host', + username='user', + password='password', + virtual_host='virtual_host'))), + ('full_url', + dict(url='rabbit://user:password@host:10/virtual_host', + expected=dict(hostname='host', + port=10, + username='user', + password='password', + virtual_host='virtual_host'))), + ] + + def setUp(self): + super(TestRabbitTransportURL, self).setUp() + self.conf.register_opts(msg_transport._transport_opts) + self.conf.register_opts(rabbit_driver.rabbit_opts) + self.config(rpc_backend='rabbit') + self.config(fake_rabbit=True) + + def test_transport_url(self): + cnx_init = rabbit_driver.Connection.__init__ + passed_params = [] + + def record_params(self, conf, server_params=None): + passed_params.append(server_params) + return cnx_init(self, conf, server_params) + + self.stubs.Set(rabbit_driver.Connection, '__init__', record_params) + + transport = messaging.get_transport(self.conf, self.url) + + driver = transport._driver + + target = messaging.Target(topic='testtopic') + + driver.send(target, {}, {}) + + self.assertEquals(passed_params[0], self.expected) + + class TestSendReceive(test_utils.BaseTestCase): _n_senders = [