diff --git a/oslo_messaging/_drivers/common.py b/oslo_messaging/_drivers/common.py index 7ac1e4dff..539998bc9 100644 --- a/oslo_messaging/_drivers/common.py +++ b/oslo_messaging/_drivers/common.py @@ -15,6 +15,7 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import copy import logging import sys @@ -442,3 +443,67 @@ class ConnectionContext(Connection): return getattr(self.connection, key) else: raise InvalidRPCConnectionReuse() + + +class ConfigOptsProxy(collections.Mapping): + """Proxy for oslo_config.cfg.ConfigOpts. + + Values from the query part of the transport url (if they are both present + and valid) override corresponding values from the configuration. + """ + + def __init__(self, conf, url): + self._conf = conf + self._url = url + + def __getattr__(self, name): + value = getattr(self._conf, name) + if isinstance(value, self._conf.GroupAttr): + return self.GroupAttrProxy(self._conf, name, value, self._url) + return value + + def __getitem__(self, name): + return self.__getattr__(name) + + def __contains__(self, name): + return name in self._conf + + def __iter__(self): + return iter(self._conf) + + def __len__(self): + return len(self._conf) + + class GroupAttrProxy(collections.Mapping): + """Internal helper proxy for oslo_config.cfg.ConfigOpts.GroupAttr.""" + + _VOID_MARKER = object() + + def __init__(self, conf, group_name, group, url): + self._conf = conf + self._group_name = group_name + self._group = group + self._url = url + + def __getattr__(self, opt_name): + # Make sure that the group has this specific option + opt_value_conf = getattr(self._group, opt_name) + # If the option is also present in the url and has a valid + # (i.e. convertible) value type, then try to override it + opt_value_url = self._url.query.get(opt_name, self._VOID_MARKER) + if opt_value_url is self._VOID_MARKER: + return opt_value_conf + opt_info = self._conf._get_opt_info(opt_name, self._group_name) + return opt_info['opt'].type(opt_value_url) + + def __getitem__(self, opt_name): + return self.__getattr__(opt_name) + + def __contains__(self, opt_name): + return opt_name in self._group + + def __iter__(self): + return iter(self._group) + + def __len__(self): + return len(self._group) diff --git a/oslo_messaging/_drivers/pika_driver/pika_engine.py b/oslo_messaging/_drivers/pika_driver/pika_engine.py index 8ff4e9103..97b6792d2 100644 --- a/oslo_messaging/_drivers/pika_driver/pika_engine.py +++ b/oslo_messaging/_drivers/pika_driver/pika_engine.py @@ -20,6 +20,7 @@ from oslo_utils import eventletutils import pika_pool from stevedore import driver +from oslo_messaging._drivers import common as drv_cmn from oslo_messaging._drivers.pika_driver import pika_commons as pika_drv_cmns from oslo_messaging._drivers.pika_driver import pika_exceptions as pika_drv_exc @@ -47,6 +48,7 @@ class PikaEngine(object): def __init__(self, conf, url, default_exchange=None, allowed_remote_exmods=None): + conf = drv_cmn.ConfigOptsProxy(conf, url) self.conf = conf self.url = url diff --git a/oslo_messaging/tests/test_config_opts_proxy.py b/oslo_messaging/tests/test_config_opts_proxy.py new file mode 100644 index 000000000..6d51716e3 --- /dev/null +++ b/oslo_messaging/tests/test_config_opts_proxy.py @@ -0,0 +1,77 @@ +# Copyright 2016 Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_config import cfg +from oslo_config import types + +from oslo_messaging._drivers import common as drv_cmn +from oslo_messaging.tests import utils as test_utils +from oslo_messaging import transport + + +class TestConfigOptsProxy(test_utils.BaseTestCase): + + def test_rabbit(self): + group = 'oslo_messaging_rabbit' + self.config(rabbit_retry_interval=1, + rabbit_qos_prefetch_count=0, + rabbit_max_retries=3, + kombu_reconnect_delay=5.0, + group=group) + dummy_opts = [cfg.ListOpt('list_str', item_type=types.String(), + default=[]), + cfg.ListOpt('list_int', item_type=types.Integer(), + default=[]), + cfg.DictOpt('dict', default={}), + cfg.BoolOpt('bool', default=False), + cfg.StrOpt('str', default='default')] + self.conf.register_opts(dummy_opts, group=group) + url = transport.TransportURL.parse( + self.conf, "rabbit:///" + "?rabbit_qos_prefetch_count=2" + "&unknown_opt=4" + "&kombu_reconnect_delay=invalid_value" + "&list_str=1&list_str=2&list_str=3" + "&list_int=1&list_int=2&list_int=3" + "&dict=x:1&dict=y:2&dict=z:3" + "&bool=True" + ) + conf = drv_cmn.ConfigOptsProxy(self.conf, url) + self.assertRaises(cfg.NoSuchOptError, + conf.__getattr__, + 'unknown_group') + self.assertTrue(isinstance(getattr(conf, group), + conf.GroupAttrProxy)) + self.assertEqual(conf.oslo_messaging_rabbit.rabbit_retry_interval, + 1) + self.assertEqual(conf.oslo_messaging_rabbit.rabbit_qos_prefetch_count, + 2) + self.assertEqual(conf.oslo_messaging_rabbit.rabbit_max_retries, + 3) + self.assertRaises(cfg.NoSuchOptError, + conf.oslo_messaging_rabbit.__getattr__, + 'unknown_opt') + self.assertRaises(ValueError, + conf.oslo_messaging_rabbit.__getattr__, + 'kombu_reconnect_delay') + self.assertEqual(conf.oslo_messaging_rabbit.list_str, + ['1', '2', '3']) + self.assertEqual(conf.oslo_messaging_rabbit.list_int, + [1, 2, 3]) + self.assertEqual(conf.oslo_messaging_rabbit.dict, + {'x': '1', 'y': '2', 'z': '3'}) + self.assertEqual(conf.oslo_messaging_rabbit.bool, + True) + self.assertEqual(conf.oslo_messaging_rabbit.str, + 'default') diff --git a/oslo_messaging/tests/test_transport.py b/oslo_messaging/tests/test_transport.py index b8561f116..89619cf72 100644 --- a/oslo_messaging/tests/test_transport.py +++ b/oslo_messaging/tests/test_transport.py @@ -331,20 +331,33 @@ class TestTransportMethodArgs(test_utils.BaseTestCase): class TestTransportUrlCustomisation(test_utils.BaseTestCase): def setUp(self): super(TestTransportUrlCustomisation, self).setUp() - self.url1 = transport.TransportURL.parse(self.conf, "fake://vhost1") - self.url2 = transport.TransportURL.parse(self.conf, "fake://vhost2") - self.url3 = transport.TransportURL.parse(self.conf, "fake://vhost1") + + def transport_url_parse(url): + return transport.TransportURL.parse(self.conf, url) + + self.url1 = transport_url_parse("fake://vhost1?x=1&y=2&z=3") + self.url2 = transport_url_parse("fake://vhost2?foo=bar") + self.url3 = transport_url_parse("fake://vhost1?l=1&l=2&l=3") + self.url4 = transport_url_parse("fake://vhost2?d=x:1&d=y:2&d=z:3") def test_hash(self): urls = {} urls[self.url1] = self.url1 urls[self.url2] = self.url2 urls[self.url3] = self.url3 + urls[self.url4] = self.url4 self.assertEqual(2, len(urls)) def test_eq(self): self.assertEqual(self.url1, self.url3) - self.assertNotEqual(self.url1, self.url2) + self.assertEqual(self.url2, self.url4) + self.assertNotEqual(self.url1, self.url4) + + def test_query(self): + self.assertEqual(self.url1.query, {'x': '1', 'y': '2', 'z': '3'}) + self.assertEqual(self.url2.query, {'foo': 'bar'}) + self.assertEqual(self.url3.query, {'l': '1,2,3'}) + self.assertEqual(self.url4.query, {'d': 'x:1,y:2,z:3'}) class TestTransportHostCustomisation(test_utils.BaseTestCase): diff --git a/oslo_messaging/transport.py b/oslo_messaging/transport.py index 2a3c9aaf7..703c8aabf 100644 --- a/oslo_messaging/transport.py +++ b/oslo_messaging/transport.py @@ -230,10 +230,12 @@ class TransportURL(object): Transport URLs take the form:: - transport://user:pass@host:port[,userN:passN@hostN:portN]/virtual_host + transport://user:pass@host:port[,userN:passN@hostN:portN]/virtual_host?query i.e. the scheme selects the transport driver, you may include multiple - hosts in netloc and the path part is a "virtual host" partition path. + hosts in netloc, the path part is a "virtual host" partition path and + the query part contains some driver-specific options which may override + corresponding values from a static configuration. :param conf: a ConfigOpts instance :type conf: oslo.config.cfg.ConfigOpts @@ -243,12 +245,14 @@ class TransportURL(object): :type virtual_host: str :param hosts: a list of TransportHost objects :type hosts: list - :param aliases: DEPRECATED: A map of transport alias to transport name + :param aliases: DEPRECATED: a map of transport alias to transport name :type aliases: dict + :param query: a dictionary of URL query parameters + :type query: dict """ def __init__(self, conf, transport=None, virtual_host=None, hosts=None, - aliases=None): + aliases=None, query=None): self.conf = conf self.conf.register_opts(_transport_opts) self._transport = transport @@ -261,6 +265,10 @@ class TransportURL(object): self.aliases = {} else: self.aliases = aliases + if query is None: + self.query = {} + else: + self.query = query self._deprecation_logged = False @@ -346,6 +354,9 @@ class TransportURL(object): if self.virtual_host: url += parse.quote(self.virtual_host) + if self.query: + url += '?' + parse.urlencode(self.query, doseq=True) + return url @classmethod @@ -354,7 +365,7 @@ class TransportURL(object): Assuming a URL takes the form of:: - transport://user:pass@host:port[,userN:passN@hostN:portN]/virtual_host + transport://user:pass@host:port[,userN:passN@hostN:portN]/virtual_host?query then parse the URL and return a TransportURL object. @@ -371,7 +382,7 @@ class TransportURL(object): {"host": "host2:port2"} ] - If the url is not provided conf.transport_url is parsed intead. + If the url is not provided conf.transport_url is parsed instead. :param conf: a ConfigOpts instance :type conf: oslo.config.cfg.ConfigOpts @@ -394,13 +405,12 @@ class TransportURL(object): if not url.scheme: raise InvalidTransportURL(url.geturl(), 'No scheme specified') - # Make sure there's not a query string; that could identify - # requirements we can't comply with (for example ssl), so reject it if - # it's present - if '?' in url.path or url.query: - raise InvalidTransportURL(url.geturl(), - "Cannot comply with query string in " - "transport URL") + transport = url.scheme + + query = {} + if url.query: + for key, values in six.iteritems(parse.parse_qs(url.query)): + query[key] = ','.join(values) virtual_host = None if url.path.startswith('/'): @@ -430,7 +440,7 @@ class TransportURL(object): if host_end < 0: # NOTE(Vek): Identical to what Python 2.7's # urlparse.urlparse() raises in this case - raise ValueError("Invalid IPv6 URL") + raise ValueError('Invalid IPv6 URL') port_text = hostname[host_end:] hostname = hostname[1:host_end] @@ -449,4 +459,4 @@ class TransportURL(object): username=username, password=password)) - return cls(conf, url.scheme, virtual_host, hosts, aliases) + return cls(conf, transport, virtual_host, hosts, aliases, query)