diff --git a/oslo/messaging/_drivers/amqp.py b/oslo/messaging/_drivers/amqp.py index 74b671baf..b7ec5945a 100644 --- a/oslo/messaging/_drivers/amqp.py +++ b/oslo/messaging/_drivers/amqp.py @@ -59,16 +59,17 @@ LOG = logging.getLogger(__name__) class ConnectionPool(pool.Pool): """Class that implements a Pool of Connections.""" - def __init__(self, conf, connection_cls): + def __init__(self, conf, url, connection_cls): self.connection_cls = connection_cls self.conf = conf + self.url = url super(ConnectionPool, self).__init__(self.conf.rpc_conn_pool_size) self.reply_proxy = None # TODO(comstud): Timeout connections not used in a while def create(self): LOG.debug(_('Pool creating new connection')) - return self.connection_cls(self.conf) + return self.connection_cls(self.conf, self.url) def empty(self): for item in self.iter_free(): @@ -82,18 +83,19 @@ class ConnectionPool(pool.Pool): # time code, it gets here via cleanup() and only appears in service.py # just before doing a sys.exit(), so cleanup() only happens once and # the leakage is not a problem. - self.connection_cls.pool = None + del self.connection_cls.pools[self.url] _pool_create_sem = threading.Lock() -def get_connection_pool(conf, connection_cls): +def get_connection_pool(conf, url, connection_cls): with _pool_create_sem: # Make sure only one thread tries to create the connection pool. - if not connection_cls.pool: - connection_cls.pool = ConnectionPool(conf, connection_cls) - return connection_cls.pool + if url not in connection_cls.pools: + connection_cls.pools[url] = ConnectionPool(conf, url, + connection_cls) + return connection_cls.pools[url] class ConnectionContext(rpc_common.Connection): @@ -108,17 +110,16 @@ class ConnectionContext(rpc_common.Connection): If possible the function makes sure to return a connection to the pool. """ - def __init__(self, conf, connection_pool, pooled=True, server_params=None): + def __init__(self, conf, url, connection_pool, pooled=True): """Create a new connection, or get one from the pool.""" self.connection = None self.conf = conf + self.url = url self.connection_pool = connection_pool if pooled: self.connection = connection_pool.get() else: - self.connection = connection_pool.connection_cls( - conf, - server_params=server_params) + self.connection = connection_pool.connection_cls(conf, url) self.pooled = pooled def __enter__(self): diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 3a2303d39..fedbb7c4c 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -295,8 +295,6 @@ class AMQPDriverBase(base.BaseDriver): super(AMQPDriverBase, self).__init__(conf, url, default_exchange, allowed_remote_exmods) - self._server_params = self._server_params_from_url(self._url) - self._default_exchange = default_exchange # FIXME(markmc): temp hack @@ -310,35 +308,11 @@ class AMQPDriverBase(base.BaseDriver): self._reply_q_conn = None self._waiter = None - def _server_params_from_url(self, url): - sp = {} - - if url.virtual_host is not None: - sp['virtual_host'] = url.virtual_host - - if url.hosts: - # FIXME(markmc): support multiple hosts - host = url.hosts[0] - - sp['hostname'] = host.hostname - if host.port is not None: - sp['port'] = host.port - sp['username'] = host.username or '' - sp['password'] = host.password or '' - - return sp - def _get_connection(self, pooled=True): - # FIXME(markmc): we don't yet have a connection pool for each - # Transport instance, so we'll only use the pool with the - # transport configuration from the config file - server_params = self._server_params or None - if server_params: - pooled = False return rpc_amqp.ConnectionContext(self.conf, + self._url, self._connection_pool, - pooled=pooled, - server_params=server_params) + pooled=pooled) def _get_reply_q(self): with self._reply_q_lock: diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py index 39f8136c1..0c169ae77 100644 --- a/oslo/messaging/_drivers/impl_qpid.py +++ b/oslo/messaging/_drivers/impl_qpid.py @@ -27,6 +27,7 @@ from oslo.messaging._drivers import amqpdriver from oslo.messaging._drivers import common as rpc_common from oslo.messaging.openstack.common import importutils from oslo.messaging.openstack.common import jsonutils +from oslo.messaging.openstack.common import network_utils # FIXME(markmc): remove this _ = lambda s: s @@ -449,9 +450,9 @@ class NotifyPublisher(Publisher): class Connection(object): """Connection object.""" - pool = None + pools = {} - def __init__(self, conf, server_params=None): + def __init__(self, conf, url): if not qpid_messaging: raise ImportError("Failed to import qpid.messaging") @@ -460,35 +461,44 @@ class Connection(object): self.consumers = {} self.conf = conf - if server_params and 'hostname' in server_params: - # NOTE(russellb) This enables support for cast_to_server. - server_params['qpid_hosts'] = [ - '%s:%d' % (server_params['hostname'], - server_params.get('port', 5672)) - ] + self.brokers_params = [] + if url.hosts: + for host in url.hosts: + params = { + 'username': host.username or '', + 'password': host.password or '', + } + if host.port is not None: + params['host'] = '%s:%d' % (host.hostname, host.port) + else: + params['host'] = host.hostname + self.brokers_params.append(params) + else: + # Old configuration format + for adr in self.conf.qpid_hosts: + hostname, port = network_utils.parse_host_port( + adr, default_port=5672) - params = { - 'qpid_hosts': self.conf.qpid_hosts[:], - 'username': self.conf.qpid_username, - 'password': self.conf.qpid_password, - } - params.update(server_params or {}) + params = { + 'host': '%s:%d' % (hostname, port), + 'username': self.conf.qpid_username, + 'password': self.conf.qpid_password, + } + self.brokers_params.append(params) - random.shuffle(params['qpid_hosts']) - self.brokers = itertools.cycle(params['qpid_hosts']) + random.shuffle(self.brokers_params) + self.brokers = itertools.cycle(self.brokers_params) - self.username = params['username'] - self.password = params['password'] self.reconnect() def connection_create(self, broker): # Create the connection - this does not open the connection - self.connection = qpid_messaging.Connection(broker) + self.connection = qpid_messaging.Connection(broker['host']) # Check if flags are set and if so set them for the connection # before we call open - self.connection.username = self.username - self.connection.password = self.password + self.connection.username = broker['username'] + self.connection.password = broker['password'] self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms # Reconnection is done by self.reconnect() @@ -520,14 +530,14 @@ class Connection(object): self.connection_create(broker) self.connection.open() except qpid_exceptions.MessagingError as e: - msg_dict = dict(e=e, delay=delay) - msg = _("Unable to connect to AMQP server: %(e)s. " - "Sleeping %(delay)s seconds") % msg_dict + msg_dict = dict(e=e, delay=delay, broker=broker['host']) + msg = _("Unable to connect to AMQP server on %(broker)s: " + "%(e)s. Sleeping %(delay)s seconds") % msg_dict LOG.error(msg) time.sleep(delay) delay = min(delay + 1, 5) else: - LOG.info(_('Connected to AMQP server on %s'), broker) + LOG.info(_('Connected to AMQP server on %s'), broker['host']) break self.session = self.connection.session() @@ -687,7 +697,7 @@ class QpidDriver(amqpdriver.AMQPDriverBase): conf.register_opts(qpid_opts) conf.register_opts(rpc_amqp.amqp_opts) - connection_pool = rpc_amqp.get_connection_pool(conf, Connection) + connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection) super(QpidDriver, self).__init__(conf, url, connection_pool, diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index b2cf460b9..7c5d1b3fa 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -421,9 +421,9 @@ class NotifyPublisher(TopicPublisher): class Connection(object): """Connection object.""" - pool = None + pools = {} - def __init__(self, conf, server_params=None): + def __init__(self, conf, url): self.consumers = [] self.conf = conf self.max_retries = self.conf.rabbit_max_retries @@ -436,39 +436,54 @@ class Connection(object): self.interval_max = 30 self.memory_transport = False - if server_params is None: - server_params = {} - # Keys to translate from server_params to kombu params - server_params_to_kombu_params = {'username': 'userid'} - ssl_params = self._fetch_ssl_params() - params_list = [] - for adr in self.conf.rabbit_hosts: - hostname, port = network_utils.parse_host_port( - adr, default_port=self.conf.rabbit_port) - params = { - 'hostname': hostname, - 'port': port, - 'userid': self.conf.rabbit_userid, - 'password': self.conf.rabbit_password, - 'login_method': self.conf.rabbit_login_method, - 'virtual_host': self.conf.rabbit_virtual_host, - } + if url.virtual_host is not None: + virtual_host = url.virtual_host + else: + virtual_host = self.conf.rabbit_virtual_host - for sp_key, value in six.iteritems(server_params): - p_key = server_params_to_kombu_params.get(sp_key, sp_key) - params[p_key] = value + self.brokers_params = [] + if url.hosts: + for host in url.hosts: + params = { + 'hostname': host.hostname, + 'port': host.port or 5672, + 'userid': host.username or '', + 'password': host.password or '', + 'login_method': self.conf.rabbit_login_method, + 'virtual_host': virtual_host + } + if self.conf.fake_rabbit: + params['transport'] = 'memory' + if self.conf.rabbit_use_ssl: + params['ssl'] = ssl_params - if self.conf.fake_rabbit: - params['transport'] = 'memory' - if self.conf.rabbit_use_ssl: - params['ssl'] = ssl_params + self.brokers_params.append(params) + else: + # Old configuration format + for adr in self.conf.rabbit_hosts: + hostname, port = network_utils.parse_host_port( + adr, default_port=self.conf.rabbit_port) - params_list.append(params) + params = { + 'hostname': hostname, + 'port': port, + 'userid': self.conf.rabbit_userid, + 'password': self.conf.rabbit_password, + 'login_method': self.conf.rabbit_login_method, + 'virtual_host': virtual_host + } - random.shuffle(params_list) - self.params_list = itertools.cycle(params_list) + if self.conf.fake_rabbit: + params['transport'] = 'memory' + if self.conf.rabbit_use_ssl: + params['ssl'] = ssl_params + + self.brokers_params.append(params) + + random.shuffle(self.brokers_params) + self.brokers = itertools.cycle(self.brokers_params) self.memory_transport = self.conf.fake_rabbit @@ -519,14 +534,14 @@ class Connection(object): # Return the extended behavior or just have the default behavior return ssl_params or True - def _connect(self, params): + def _connect(self, broker): """Connect to rabbit. Re-establish any queues that may have been declared before if we are reconnecting. Exceptions should be handled by the caller. """ if self.connection: LOG.info(_("Reconnecting to AMQP server on " - "%(hostname)s:%(port)d") % params) + "%(hostname)s:%(port)d") % broker) try: # XXX(nic): when reconnecting to a RabbitMQ cluster # with mirrored queues in use, the attempt to release the @@ -545,7 +560,7 @@ class Connection(object): # Setting this in case the next statement fails, though # it shouldn't be doing any network operations, yet. self.connection = None - self.connection = kombu.connection.BrokerConnection(**params) + self.connection = kombu.connection.BrokerConnection(**broker) self.connection_errors = self.connection.connection_errors self.channel_errors = self.connection.channel_errors if self.memory_transport: @@ -561,7 +576,7 @@ class Connection(object): for consumer in self.consumers: consumer.reconnect(self.channel) LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') % - params) + broker) def reconnect(self): """Handles reconnecting and re-establishing queues. @@ -574,10 +589,10 @@ class Connection(object): attempt = 0 while True: - params = six.next(self.params_list) + broker = six.next(self.brokers) attempt += 1 try: - self._connect(params) + self._connect(broker) return except IOError as e: pass @@ -596,7 +611,7 @@ class Connection(object): log_info = {} log_info['err_str'] = str(e) log_info['max_retries'] = self.max_retries - log_info.update(params) + log_info.update(broker) if self.max_retries and attempt == self.max_retries: msg = _('Unable to connect to AMQP server on ' @@ -775,7 +790,7 @@ class RabbitDriver(amqpdriver.AMQPDriverBase): conf.register_opts(rabbit_opts) conf.register_opts(rpc_amqp.amqp_opts) - connection_pool = rpc_amqp.get_connection_pool(conf, Connection) + connection_pool = rpc_amqp.get_connection_pool(conf, url, Connection) super(RabbitDriver, self).__init__(conf, url, connection_pool, diff --git a/tests/test_qpid.py b/tests/test_qpid.py index 419d9dd15..23145518a 100644 --- a/tests/test_qpid.py +++ b/tests/test_qpid.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import operator import random import thread import threading @@ -102,6 +103,65 @@ class _QpidBaseTestCase(test_utils.BaseTestCase): self.con_send.close() +class TestQpidTransportURL(_QpidBaseTestCase): + + scenarios = [ + ('none', dict(url=None, + expected=[dict(host='localhost:5672', + username='', + password='')])), + ('empty', + dict(url='qpid:///', + expected=[dict(host='localhost:5672', + username='', + password='')])), + ('localhost', + dict(url='qpid://localhost/', + expected=[dict(host='localhost', + username='', + password='')])), + ('no_creds', + dict(url='qpid://host/', + expected=[dict(host='host', + username='', + password='')])), + ('no_port', + dict(url='qpid://user:password@host/', + expected=[dict(host='host', + username='user', + password='password')])), + ('full_url', + dict(url='qpid://user:password@host:10/', + expected=[dict(host='host:10', + username='user', + password='password')])), + ('full_two_url', + dict(url='qpid://user:password@host:10,' + 'user2:password2@host2:12/', + expected=[dict(host='host:10', + username='user', + password='password'), + dict(host='host2:12', + username='user2', + password='password2') + ] + )), + + ] + + @mock.patch.object(qpid_driver.Connection, 'reconnect') + def test_transport_url(self, *args): + transport = messaging.get_transport(self.conf, self.url) + self.addCleanup(transport.cleanup) + driver = transport._driver + + brokers_params = driver._get_connection().brokers_params + self.assertEqual(sorted(self.expected, + key=operator.itemgetter('host')), + sorted(brokers_params, + key=operator.itemgetter('host'))) + + class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): """Unit test cases to test invalid qpid topology version.""" @@ -398,11 +458,12 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase): brokers = ['host1', 'host2', 'host3', 'host4', 'host5'] brokers_count = len(brokers) - self.messaging_conf.conf.qpid_hosts = brokers + self.config(qpid_hosts=brokers) with mock.patch('qpid.messaging.Connection') as conn_mock: # starting from the first broker in the list - connection = qpid_driver.Connection(self.messaging_conf.conf) + url = messaging.TransportURL.parse(self.conf, None) + connection = qpid_driver.Connection(self.conf, url) # reconnect will advance to the next broker, one broker per # attempt, and then wrap to the start of the list once the end is @@ -412,7 +473,7 @@ class TestQpidReconnectOrder(test_utils.BaseTestCase): expected = [] for broker in brokers: - expected.extend([mock.call(broker), + expected.extend([mock.call("%s:5672" % broker), mock.call().open(), mock.call().session(), mock.call().opened(), @@ -601,6 +662,9 @@ class FakeQpidSession(object): key = slash_split[-1] return key.strip() + def close(self): + pass + _fake_session = FakeQpidSession() diff --git a/tests/test_rabbit.py b/tests/test_rabbit.py index f08d57e38..d42a0f507 100644 --- a/tests/test_rabbit.py +++ b/tests/test_rabbit.py @@ -13,6 +13,7 @@ # under the License. import datetime +import operator import sys import threading import uuid @@ -46,74 +47,88 @@ class TestRabbitDriverLoad(test_utils.BaseTestCase): class TestRabbitTransportURL(test_utils.BaseTestCase): scenarios = [ - ('none', dict(url=None, expected=None)), + ('none', dict(url=None, + expected=[dict(hostname='localhost', + port=5672, + userid='guest', + password='guest', + virtual_host='/')])), ('empty', dict(url='rabbit:///', - expected=dict(virtual_host=''))), + expected=[dict(hostname='localhost', + port=5672, + userid='guest', + password='guest', + virtual_host='')])), ('localhost', dict(url='rabbit://localhost/', - expected=dict(hostname='localhost', - username='', - password='', - virtual_host=''))), + expected=[dict(hostname='localhost', + port=5672, + userid='', + password='', + virtual_host='')])), ('virtual_host', dict(url='rabbit:///vhost', - expected=dict(virtual_host='vhost'))), + expected=[dict(hostname='localhost', + port=5672, + userid='guest', + password='guest', + virtual_host='vhost')])), ('no_creds', dict(url='rabbit://host/virtual_host', - expected=dict(hostname='host', - username='', - password='', - virtual_host='virtual_host'))), + expected=[dict(hostname='host', + port=5672, + userid='', + password='', + virtual_host='virtual_host')])), ('no_port', dict(url='rabbit://user:password@host/virtual_host', - expected=dict(hostname='host', - username='user', - password='password', - virtual_host='virtual_host'))), + expected=[dict(hostname='host', + port=5672, + userid='user', + password='password', + virtual_host='virtual_host')])), ('full_url', dict(url='rabbit://user:password@host:10/virtual_host', - expected=dict(hostname='host', - port=10, - username='user', - password='password', - virtual_host='virtual_host'))), + expected=[dict(hostname='host', + port=10, + userid='user', + password='password', + virtual_host='virtual_host')])), + ('full_two_url', + dict(url='rabbit://user:password@host:10,' + 'user2:password2@host2:12/virtual_host', + expected=[dict(hostname='host', + port=10, + userid='user', + password='password', + virtual_host='virtual_host'), + dict(hostname='host2', + port=12, + userid='user2', + password='password2', + virtual_host='virtual_host') + ] + )), + ] - def setUp(self): - super(TestRabbitTransportURL, self).setUp() - - self.messaging_conf.transport_driver = 'rabbit' + def test_transport_url(self): self.messaging_conf.in_memory = True - self._server_params = [] - cnx_init = rabbit_driver.Connection.__init__ + transport = messaging.get_transport(self.conf, self.url) + self.addCleanup(transport.cleanup) + driver = transport._driver - def record_params(cnx, conf, server_params=None): - self._server_params.append(server_params) - return cnx_init(cnx, conf, server_params) + brokers_params = driver._get_connection().brokers_params[:] + brokers_params = [dict((k, v) for k, v in broker.items() + if k not in ['transport', 'login_method']) + for broker in brokers_params] - def dummy_send(cnx, topic, msg, timeout=None): - pass - - self.stubs.Set(rabbit_driver.Connection, '__init__', record_params) - self.stubs.Set(rabbit_driver.Connection, 'topic_send', dummy_send) - - self._driver = messaging.get_transport(self.conf, self.url)._driver - self._target = messaging.Target(topic='testtopic') - - def test_transport_url_listen(self): - self._driver.listen(self._target) - self.assertEqual(self.expected, self._server_params[0]) - - def test_transport_url_listen_for_notification(self): - self._driver.listen_for_notifications( - [(messaging.Target(topic='topic'), 'info')]) - self.assertEqual(self.expected, self._server_params[0]) - - def test_transport_url_send(self): - self._driver.send(self._target, {}, {}) - self.assertEqual(self.expected, self._server_params[0]) + self.assertEqual(sorted(self.expected, + key=operator.itemgetter('hostname')), + sorted(brokers_params, + key=operator.itemgetter('hostname'))) class TestSendReceive(test_utils.BaseTestCase): @@ -619,8 +634,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase): brokers = ['host1', 'host2', 'host3', 'host4', 'host5'] brokers_count = len(brokers) - self.conf.rabbit_hosts = brokers - self.conf.rabbit_max_retries = 1 + self.config(rabbit_hosts=brokers, + rabbit_max_retries=1) hostname_sets = set() @@ -639,7 +654,8 @@ class RpcKombuHATestCase(test_utils.BaseTestCase): self.stubs.Set(rabbit_driver.Connection, '_connect', _connect) # starting from the first broker in the list - connection = rabbit_driver.Connection(self.conf) + url = messaging.TransportURL.parse(self.conf, None) + connection = rabbit_driver.Connection(self.conf, url) # now that we have connection object, revert to the real 'connect' # implementation