From 8e6fda10c29bbded70a8e0d67135bfad09a787ea Mon Sep 17 00:00:00 2001 From: Matt Dietz Date: Thu, 20 Nov 2014 22:24:31 +0000 Subject: [PATCH] Adds connection switching to the Nvp driver RM10638 Updates the NVP driver to switch controllers on any traceback from the underlying connection. It's still up to the calling code to retry, or not, as this simply iterates a connection index and reraises the exception from the connection. I've opted to opt any more handling in this particular patch, as it's already a large change. --- quark/drivers/nvp_driver.py | 380 ++++++++++++----------- quark/exceptions.py | 5 + quark/tests/test_nvp_driver.py | 127 +++++--- quark/tests/test_optimized_nvp_driver.py | 40 +-- 4 files changed, 307 insertions(+), 245 deletions(-) diff --git a/quark/drivers/nvp_driver.py b/quark/drivers/nvp_driver.py index 86be1c6..2af82d1 100644 --- a/quark/drivers/nvp_driver.py +++ b/quark/drivers/nvp_driver.py @@ -17,6 +17,8 @@ NVP client driver for Quark """ +import contextlib + import aiclib from neutron.extensions import securitygroup as sg_ext from neutron.openstack.common import log as logging @@ -89,6 +91,7 @@ class NVPDriver(base.BaseDriver): # NOTE(mdietz): What does default_tz actually mean? # We don't have one default. # NOTE(jkoelker): Transport Zone + # NOTE(mdietz): :-/ tz isn't the issue. default is default_tz = CONF.NVP.default_tz LOG.info("Loading NVP settings " + str(default_tz)) connections = CONF.NVP.controller_connection @@ -99,6 +102,7 @@ class NVPDriver(base.BaseDriver): 'max_rules_per_group': CONF.NVP.max_rules_per_group, 'max_rules_per_port': CONF.NVP.max_rules_per_port}) LOG.info("Loading NVP settings " + str(connections)) + for conn in connections: (ip, port, user, pw, req_timeout, http_timeout, retries, redirects) = conn.split(":") @@ -114,7 +118,11 @@ class NVPDriver(base.BaseDriver): default_tz=default_tz, backoff=backoff)) - def get_connection(self): + def _connection(self): + if len(self.nvp_connections) == 0: + raise exceptions.NoBackendConnectionsDefined( + msg="No NVP connections defined cannot continue") + conn = self.nvp_connections[self.conn_index] if "connection" not in conn: scheme = conn["port"] == "443" and "https" or "http" @@ -132,6 +140,22 @@ class NVPDriver(base.BaseDriver): backoff=backoff) return conn["connection"] + def _next_connection(self): + # TODO(anyone): Do we want to drop and create new connections at some + # point? What about recycling them after a certain + # number of usages or time, proactively? + conn_len = len(self.nvp_connections) + if conn_len: + self.conn_index = (self.conn_index + 1) % conn_len + + @contextlib.contextmanager + def get_connection(self): + try: + yield self._connection() + except Exception: + self._next_connection() + raise + def create_network(self, context, network_name, tags=None, network_id=None, **kwargs): return self._lswitch_create(context, network_name, tags, @@ -187,67 +211,67 @@ class NVPDriver(base.BaseDriver): security_groups = security_groups or [] tenant_id = context.tenant_id lswitch = self._create_or_choose_lswitch(context, network_id) - connection = self.get_connection() - port = connection.lswitch_port(lswitch) - port.admin_status_enabled(status) - nvp_group_ids = self._get_security_groups_for_port(context, - security_groups) - port.security_profiles(nvp_group_ids) - tags = [dict(tag=network_id, scope="neutron_net_id"), - dict(tag=port_id, scope="neutron_port_id"), - dict(tag=tenant_id, scope="os_tid"), - dict(tag=device_id, scope="vm_id")] - LOG.debug("Creating port on switch %s" % lswitch) - port.tags(tags) - res = port.create() - try: - """Catching odd NVP returns here will make it safe to assume that - NVP returned something correct.""" - res["lswitch"] = lswitch - except TypeError: - LOG.exception("Unexpected return from NVP: %s" % res) - raise - port = connection.lswitch_port(lswitch) - port.uuid = res["uuid"] - port.attachment_vif(port_id) - return res + with self.get_connection() as connection: + port = connection.lswitch_port(lswitch) + port.admin_status_enabled(status) + nvp_group_ids = self._get_security_groups_for_port(context, + security_groups) + port.security_profiles(nvp_group_ids) + tags = [dict(tag=network_id, scope="neutron_net_id"), + dict(tag=port_id, scope="neutron_port_id"), + dict(tag=tenant_id, scope="os_tid"), + dict(tag=device_id, scope="vm_id")] + LOG.debug("Creating port on switch %s" % lswitch) + port.tags(tags) + res = port.create() + try: + """Catching odd NVP returns here will make it safe to assume that + NVP returned something correct.""" + res["lswitch"] = lswitch + except TypeError: + LOG.exception("Unexpected return from NVP: %s" % res) + raise + port = connection.lswitch_port(lswitch) + port.uuid = res["uuid"] + port.attachment_vif(port_id) + return res def update_port(self, context, port_id, status=True, security_groups=None, **kwargs): security_groups = security_groups or [] - connection = self.get_connection() - lswitch_id = self._lswitch_from_port(context, port_id) - port = connection.lswitch_port(lswitch_id, port_id) - nvp_group_ids = self._get_security_groups_for_port(context, - security_groups) - if nvp_group_ids: - port.security_profiles(nvp_group_ids) - port.admin_status_enabled(status) - return port.update() + with self.get_connection() as connection: + lswitch_id = self._lswitch_from_port(context, port_id) + port = connection.lswitch_port(lswitch_id, port_id) + nvp_group_ids = self._get_security_groups_for_port(context, + security_groups) + if nvp_group_ids: + port.security_profiles(nvp_group_ids) + port.admin_status_enabled(status) + return port.update() def delete_port(self, context, port_id, **kwargs): - connection = self.get_connection() - lswitch_uuid = kwargs.get('lswitch_uuid', None) - try: - if not lswitch_uuid: - lswitch_uuid = self._lswitch_from_port(context, port_id) - LOG.debug("Deleting port %s from lswitch %s" - % (port_id, lswitch_uuid)) - connection.lswitch_port(lswitch_uuid, port_id).delete() - except aiclib.core.AICException as ae: - if ae.code == 404: - LOG.info("LSwitchPort/Port %s not found in NVP." - " Ignoring explicitly. Code: %s, Message: %s" - % (port_id, ae.code, ae.message)) - else: - LOG.info("AICException deleting LSwitchPort/Port %s in NVP." - " Ignoring explicitly. Code: %s, Message: %s" - % (port_id, ae.code, ae.message)) + with self.get_connection() as connection: + lswitch_uuid = kwargs.get('lswitch_uuid', None) + try: + if not lswitch_uuid: + lswitch_uuid = self._lswitch_from_port(context, port_id) + LOG.debug("Deleting port %s from lswitch %s" + % (port_id, lswitch_uuid)) + connection.lswitch_port(lswitch_uuid, port_id).delete() + except aiclib.core.AICException as ae: + if ae.code == 404: + LOG.info("LSwitchPort/Port %s not found in NVP." + " Ignoring explicitly. Code: %s, Message: %s" + % (port_id, ae.code, ae.message)) + else: + LOG.info("AICException deleting LSwitchPort/Port %s in " + "NVP. Ignoring explicitly. Code: %s, Message: %s" + % (port_id, ae.code, ae.message)) - except Exception as e: - LOG.info("Failed to delete LSwitchPort/Port %s in NVP." - " Ignoring explicitly. Message: %s" - % (port_id, e.args[0])) + except Exception as e: + LOG.info("Failed to delete LSwitchPort/Port %s in NVP." + " Ignoring explicitly. Message: %s" + % (port_id, e.args[0])) def _collect_lport_info(self, lport, get_status): info = { @@ -291,23 +315,23 @@ class NVPDriver(base.BaseDriver): return info def diag_port(self, context, port_id, get_status=False, **kwargs): - connection = self.get_connection() - lswitch_uuid = self._lswitch_from_port(context, port_id) - lswitch_port = connection.lswitch_port(lswitch_uuid, port_id) + with self.get_connection() as connection: + lswitch_uuid = self._lswitch_from_port(context, port_id) + lswitch_port = connection.lswitch_port(lswitch_uuid, port_id) - query = lswitch_port.query() - query.relations("LogicalPortAttachment") - results = query.results() - if results['result_count'] == 0: - return {'lport': "Logical port not found."} + query = lswitch_port.query() + query.relations("LogicalPortAttachment") + results = query.results() + if results['result_count'] == 0: + return {'lport': "Logical port not found."} - config = results['results'][0] - relations = config.pop('_relations') - config['attachment'] = relations['LogicalPortAttachment']['type'] - if get_status: - config['status'] = lswitch_port.status() - config['statistics'] = lswitch_port.statistics() - return {'lport': self._collect_lport_info(config, get_status)} + config = results['results'][0] + relations = config.pop('_relations') + config['attachment'] = relations['LogicalPortAttachment']['type'] + if get_status: + config['status'] = lswitch_port.status() + config['statistics'] = lswitch_port.statistics() + return {'lport': self._collect_lport_info(config, get_status)} def _get_network_details(self, context, network_id, switches): name, phys_net, phys_type, segment_id = None, None, None, None @@ -326,55 +350,55 @@ class NVPDriver(base.BaseDriver): def create_security_group(self, context, group_name, **group): tenant_id = context.tenant_id - connection = self.get_connection() - group_id = group.get('group_id') - profile = connection.securityprofile() - if group_name: - profile.display_name(group_name) - ingress_rules = group.get('port_ingress_rules', []) - egress_rules = group.get('port_egress_rules', []) + with self.get_connection() as connection: + group_id = group.get('group_id') + profile = connection.securityprofile() + if group_name: + profile.display_name(group_name) + ingress_rules = group.get('port_ingress_rules', []) + egress_rules = group.get('port_egress_rules', []) - if (len(ingress_rules) + len(egress_rules) > - self.limits['max_rules_per_group']): - raise exceptions.DriverLimitReached(limit="rules per group") + if (len(ingress_rules) + len(egress_rules) > + self.limits['max_rules_per_group']): + raise exceptions.DriverLimitReached(limit="rules per group") - if egress_rules: - profile.port_egress_rules(egress_rules) - if ingress_rules: - profile.port_ingress_rules(ingress_rules) - tags = [dict(tag=group_id, scope="neutron_group_id"), - dict(tag=tenant_id, scope="os_tid")] - LOG.debug("Creating security profile %s" % group_name) - profile.tags(tags) - return profile.create() + if egress_rules: + profile.port_egress_rules(egress_rules) + if ingress_rules: + profile.port_ingress_rules(ingress_rules) + tags = [dict(tag=group_id, scope="neutron_group_id"), + dict(tag=tenant_id, scope="os_tid")] + LOG.debug("Creating security profile %s" % group_name) + profile.tags(tags) + return profile.create() def delete_security_group(self, context, group_id, **kwargs): guuid = self._get_security_group_id(context, group_id) - connection = self.get_connection() - LOG.debug("Deleting security profile %s" % group_id) - connection.securityprofile(guuid).delete() + with self.get_connection() as connection: + LOG.debug("Deleting security profile %s" % group_id) + connection.securityprofile(guuid).delete() def update_security_group(self, context, group_id, **group): query = self._get_security_group(context, group_id) - connection = self.get_connection() - profile = connection.securityprofile(query.get('uuid')) + with self.get_connection() as connection: + profile = connection.securityprofile(query.get('uuid')) - ingress_rules = group.get('port_ingress_rules', - query.get('logical_port_ingress_rules')) - egress_rules = group.get('port_egress_rules', - query.get('logical_port_egress_rules')) + ingress_rules = group.get('port_ingress_rules', + query.get('logical_port_ingress_rules')) + egress_rules = group.get('port_egress_rules', + query.get('logical_port_egress_rules')) - if (len(ingress_rules) + len(egress_rules) > - self.limits['max_rules_per_group']): - raise exceptions.DriverLimitReached(limit="rules per group") + if (len(ingress_rules) + len(egress_rules) > + self.limits['max_rules_per_group']): + raise exceptions.DriverLimitReached(limit="rules per group") - if group.get('name', None): - profile.display_name(group['name']) - if group.get('port_ingress_rules', None) is not None: - profile.port_ingress_rules(ingress_rules) - if group.get('port_egress_rules', None) is not None: - profile.port_egress_rules(egress_rules) - return profile.update() + if group.get('name', None): + profile.display_name(group['name']) + if group.get('port_ingress_rules', None) is not None: + profile.port_ingress_rules(ingress_rules) + if group.get('port_egress_rules', None) is not None: + profile.port_egress_rules(egress_rules) + return profile.update() def _update_security_group_rules(self, context, group_id, rule, operation, checks): @@ -447,9 +471,9 @@ class NVPDriver(base.BaseDriver): return None def _lswitch_delete(self, context, lswitch_uuid): - connection = self.get_connection() - LOG.debug("Deleting lswitch %s" % lswitch_uuid) - connection.lswitch(lswitch_uuid).delete() + with self.get_connection() as connection: + LOG.debug("Deleting lswitch %s" % lswitch_uuid) + connection.lswitch(lswitch_uuid).delete() def _config_provider_attrs(self, connection, switch, phys_net, net_type, segment_id): @@ -491,64 +515,65 @@ class NVPDriver(base.BaseDriver): (context.tenant_id, network_name)) tenant_id = context.tenant_id - connection = self.get_connection() + with self.get_connection() as connection: + switch = connection.lswitch() + if network_name is None: + network_name = network_id + switch.display_name(network_name[:40]) + tags = tags or [] + tags.append({"tag": tenant_id, "scope": "os_tid"}) + if network_id: + tags.append({"tag": network_id, "scope": "neutron_net_id"}) + switch.tags(tags) + pnet = phys_net or CONF.NVP.default_tz + ptype = phys_type or CONF.NVP.default_tz_type + switch.transport_zone(pnet, ptype) + LOG.debug("Creating lswitch for network %s" % network_id) - switch = connection.lswitch() - if network_name is None: - network_name = network_id - switch.display_name(network_name[:40]) - tags = tags or [] - tags.append({"tag": tenant_id, "scope": "os_tid"}) - if network_id: - tags.append({"tag": network_id, "scope": "neutron_net_id"}) - switch.tags(tags) - pnet = phys_net or CONF.NVP.default_tz - ptype = phys_type or CONF.NVP.default_tz_type - switch.transport_zone(pnet, ptype) - LOG.debug("Creating lswitch for network %s" % network_id) - - # When connecting to public or snet, we need switches that are - # connected to their respective public/private transport zones - # using a "bridge" connector. Public uses no VLAN, whereas private - # uses VLAN 122 in netdev. Probably need this to be configurable - self._config_provider_attrs(connection, switch, phys_net, phys_type, - segment_id) - res = switch.create() - try: - uuid = res["uuid"] - return uuid - except TypeError: - LOG.exception("Unexpected return from NVP: %s" % res) - raise + # When connecting to public or snet, we need switches that are + # connected to their respective public/private transport zones + # using a "bridge" connector. Public uses no VLAN, whereas private + # uses VLAN 122 in netdev. Probably need this to be configurable + self._config_provider_attrs(connection, switch, phys_net, + phys_type, segment_id) + res = switch.create() + try: + uuid = res["uuid"] + return uuid + except TypeError: + LOG.exception("Unexpected return from NVP: %s" % res) + raise def _lswitches_for_network(self, context, network_id): - connection = self.get_connection() - query = connection.lswitch().query() - query.tagscopes(['os_tid', 'neutron_net_id']) - query.tags([context.tenant_id, network_id]) - return query + with self.get_connection() as connection: + query = connection.lswitch().query() + query.tagscopes(['os_tid', 'neutron_net_id']) + query.tags([context.tenant_id, network_id]) + return query def _lswitch_from_port(self, context, port_id): - connection = self.get_connection() - query = connection.lswitch_port("*").query() - query.relations("LogicalSwitchConfig") - query.uuid(port_id) - port = query.results() - if port['result_count'] > 1: - raise Exception("Could not identify lswitch for port %s" % port_id) - if port['result_count'] < 1: - raise Exception("No lswitch found for port %s" % port_id) - return port['results'][0]["_relations"]["LogicalSwitchConfig"]["uuid"] + with self.get_connection() as connection: + query = connection.lswitch_port("*").query() + query.relations("LogicalSwitchConfig") + query.uuid(port_id) + port = query.results() + if port['result_count'] > 1: + raise Exception("Could not identify lswitch for port %s" % + port_id) + if port['result_count'] < 1: + raise Exception("No lswitch found for port %s" % port_id) + cfg = port['results'][0]["_relations"]["LogicalSwitchConfig"] + return cfg["uuid"] def _get_security_group(self, context, group_id): - connection = self.get_connection() - query = connection.securityprofile().query() - query.tagscopes(['os_tid', 'neutron_group_id']) - query.tags([context.tenant_id, group_id]) - query = query.results() - if query['result_count'] != 1: - raise sg_ext.SecurityGroupNotFound(id=group_id) - return query['results'][0] + with self.get_connection() as connection: + query = connection.securityprofile().query() + query.tagscopes(['os_tid', 'neutron_group_id']) + query.tags([context.tenant_id, group_id]) + query = query.results() + if query['result_count'] != 1: + raise sg_ext.SecurityGroupNotFound(id=group_id) + return query['results'][0] def _get_security_group_id(self, context, group_id): return self._get_security_group(context, group_id)['uuid'] @@ -567,24 +592,25 @@ class NVPDriver(base.BaseDriver): if rule.get(key): rule_clone[key] = rule[key] - connection = self.get_connection() - secrule = connection.securityrule(ethertype, **rule_clone) + with self.get_connection() as connection: + secrule = connection.securityrule(ethertype, **rule_clone) - direction = rule.get('direction', '') - if direction not in ['ingress', 'egress']: - raise AttributeError( - "Direction not specified as 'ingress' or 'egress'.") - return (direction, secrule) + direction = rule.get('direction', '') + if direction not in ['ingress', 'egress']: + raise AttributeError( + "Direction not specified as 'ingress' or 'egress'.") + return (direction, secrule) def _check_rule_count_per_port(self, context, group_id): - connection = self.get_connection() - ports = connection.lswitch_port("*").query().security_profile_uuid( - '=', self._get_security_group_id( - context, group_id)).results().get('results', []) - groups = (port.get('security_profiles', []) for port in ports) - return max([self._check_rule_count_for_groups( - context, (connection.securityprofile(gp).read() for gp in group)) - for group in groups] or [0]) + with self.get_connection() as connection: + ports = connection.lswitch_port("*").query().security_profile_uuid( + '=', self._get_security_group_id( + context, group_id)).results().get('results', []) + groups = (port.get('security_profiles', []) for port in ports) + return max([self._check_rule_count_for_groups( + context, (connection.securityprofile(gp).read() + for gp in group)) + for group in groups] or [0]) def _check_rule_count_for_groups(self, context, groups): return sum(len(group['logical_port_ingress_rules']) + diff --git a/quark/exceptions.py b/quark/exceptions.py index d10c622..fc74863 100644 --- a/quark/exceptions.py +++ b/quark/exceptions.py @@ -123,3 +123,8 @@ class RedisConnectionFailure(exceptions.NeutronException): class RedisSlaveWritesForbidden(exceptions.NeutronException): message = _("No write actions can be applied to Slave redis nodes.") + + +class NoBackendConnectionsDefined(exceptions.NeutronException): + message = _("This driver cannot be used without a backend connection " + "definition. %(msg)") diff --git a/quark/tests/test_nvp_driver.py b/quark/tests/test_nvp_driver.py index 430af56..7b297b2 100644 --- a/quark/tests/test_nvp_driver.py +++ b/quark/tests/test_nvp_driver.py @@ -124,10 +124,10 @@ class TestNVPDriverCreateNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_create_network(self): @@ -151,8 +151,8 @@ class TestNVPDriverProviderNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self, tz): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection() switch = self._create_lswitch(1, False) switch.transport_zone = mock.Mock() @@ -161,7 +161,7 @@ class TestNVPDriverProviderNetwork(TestNVPDriver): tz_query = mock.Mock() tz_query.query = mock.Mock(return_value=tz_results) connection.transportzone = mock.Mock(return_value=tz_query) - get_connection.return_value = connection + conn.return_value = connection yield connection, switch def test_config_provider_attrs_flat_net(self): @@ -286,11 +286,11 @@ class TestNVPDriverDeleteNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self, network_exists=True): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), - ) as (get_connection, switch_list): + ) as (conn, switch_list): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if network_exists: ret = {"results": [{"uuid": self.lswitch_uuid}]} else: @@ -318,12 +318,12 @@ class TestNVPDriverDeleteNetworkWithExceptions(TestNVPDriver): @contextlib.contextmanager def _stubs(self, network_exists=True, exception=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._lswitch_delete" % self.d_pkg), - ) as (get_connection, switch_list, switch_delete): + ) as (conn, switch_list, switch_delete): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if network_exists: ret = {"results": [{"uuid": self.lswitch_uuid}]} else: @@ -372,13 +372,14 @@ class TestNVPDriverCreatePort(TestNVPDriver): @contextlib.contextmanager def _stubs(self, has_lswitch=True, maxed_ports=False, net_details=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._get_network_details" % self.d_pkg), - ) as (get_connection, get_switches, get_net_dets): + ) as (conn, next_conn, get_switches, get_net_dets): connection = self._create_connection(has_switches=has_lswitch, maxed_ports=maxed_ports) - get_connection.return_value = connection + conn.return_value = connection get_switches.return_value = connection.lswitch().query() get_net_dets.return_value = net_details yield connection @@ -517,11 +518,12 @@ class TestNVPDriverUpdatePort(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_update_port(self): @@ -550,10 +552,10 @@ class TestNVPDriverLswitchesForNetwork(TestNVPDriver): @contextlib.contextmanager def _stubs(self, single_switch=True): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection(switch_count=1) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_get_lswitches(self): @@ -606,10 +608,11 @@ class TestNVPDriverDeletePort(TestNVPDriver): @contextlib.contextmanager def _stubs(self, switch_count=1): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection(switch_count=switch_count) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_delete_port(self): @@ -645,11 +648,11 @@ class TestNVPDriverDeletePortWithExceptions(TestNVPDriver): @contextlib.contextmanager def _stubs(self, switch_exception=None, delete_exception=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_from_port" % self.d_pkg), - ) as (get_connection, switch): + ) as (conn, switch): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if switch_exception: switch.side_effect = switch_exception else: @@ -729,11 +732,12 @@ class TestNVPDriverCreateSecurityGroup(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_group_create(self): @@ -783,11 +787,12 @@ class TestNVPDriverDeleteSecurityGroup(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_group_delete(self): @@ -812,11 +817,12 @@ class TestNVPDriverUpdateSecurityGroup(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_group_update(self): @@ -872,14 +878,15 @@ class TestNVPDriverCreateSecurityGroupRule(TestNVPDriver): @contextlib.contextmanager def _stubs(self): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + mock.patch("%s._next_connection" % self.d_pkg), + ) as (conn, next_conn): connection = self._create_connection() connection.securityprofile = self._create_security_profile() connection.securityrule = self._create_security_rule() connection.lswitch_port().query.return_value = ( self._create_lport_query(1, [self.profile_id])) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_security_rule_create(self): @@ -955,13 +962,13 @@ class TestNVPDriverDeleteSecurityGroupRule(TestNVPDriver): rulelist['logical_port_%s_rules' % rule.pop('direction')].append( rule) with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), - ) as (get_connection,): + mock.patch("%s._connection" % self.d_pkg), + ) as (conn,): connection = self._create_connection() connection.securityprofile = self._create_security_profile() connection.securityrule = self._create_security_rule() connection.securityprofile().read().update(rulelist) - get_connection.return_value = connection + conn.return_value = connection yield connection def test_delete_security_group(self): @@ -1023,16 +1030,40 @@ class TestNVPGetConnection(TestNVPDriver): http_timeout=10, retries=1, backoff=0)) - with mock.patch("aiclib.nvp.Connection") as (aiclib_conn): - yield aiclib_conn + with contextlib.nested( + mock.patch("aiclib.nvp.Connection"), + mock.patch("%s._next_connection" % self.d_pkg) + ) as (aiclib_conn, next_conn): + yield aiclib_conn, next_conn cfg.CONF.clear_override("controller_connection", "NVP") def test_get_connection(self): - with self._stubs(has_conn=False) as aiclib_conn: - self.driver.get_connection() + with self._stubs(has_conn=False) as (aiclib_conn, next_conn): + with self.driver.get_connection(): + pass self.assertTrue(aiclib_conn.called) + self.assertFalse(next_conn.called) def test_get_connection_connection_defined(self): - with self._stubs(has_conn=True) as aiclib_conn: - self.driver.get_connection() + with self._stubs(has_conn=True) as (aiclib_conn, next_conn): + with self.driver.get_connection(): + pass self.assertFalse(aiclib_conn.called) + self.assertFalse(next_conn.called) + + def test_get_connection_iterates(self): + with self._stubs(has_conn=True) as (aiclib_conn, next_conn): + try: + with self.driver.get_connection(): + raise Exception("Failure") + except Exception: + pass + self.assertFalse(aiclib_conn.called) + self.assertTrue(next_conn.called) + + +class TestNVPGetConnectionNoneDefined(TestNVPDriver): + def test_get_connection(self): + with self.assertRaises(q_exc.NoBackendConnectionsDefined): + with self.driver.get_connection(): + pass diff --git a/quark/tests/test_optimized_nvp_driver.py b/quark/tests/test_optimized_nvp_driver.py index 0a14d4c..f6ba105 100644 --- a/quark/tests/test_optimized_nvp_driver.py +++ b/quark/tests/test_optimized_nvp_driver.py @@ -52,13 +52,13 @@ class TestOptimizedNVPDriverDeleteNetwork(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, switch_count=1): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), - ) as (get_connection, select_switch, get_switches): + ) as (conn, select_switch, get_switches): connection = self._create_connection() switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_switch.return_value = switch get_switches.return_value = [switch] * switch_count self.context.session.delete = mock.Mock(return_value=None) @@ -105,14 +105,14 @@ class TestOptimizedNVPDriverDeleteNetworkWithExceptions( @contextlib.contextmanager def _stubs(self, switch_count=1, error_code=500): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._lswitch_delete" % self.d_pkg) - ) as (get_connection, select_switch, get_switches, delete_switch): + ) as (conn, select_switch, get_switches, delete_switch): connection = self._create_connection() switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_switch.return_value = switch get_switches.return_value = [switch] * switch_count delete_switch.side_effect = aiclib.core.AICException( @@ -151,17 +151,17 @@ class TestOptimizedNVPDriverDeletePortMultiSwitch(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, port_count=2, exception=None): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lport_select_by_id" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), mock.patch("%s._lport_delete" % self.d_pkg), - ) as (get_connection, select_port, select_switch, + ) as (conn, select_port, select_switch, two_switch, port_delete): connection = self._create_connection() port = self._create_lport_mock(port_count) switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_port.return_value = port select_switch.return_value = switch two_switch.return_value = [switch, switch] @@ -244,15 +244,15 @@ class TestOptimizedNVPDriverDeletePortSingleSwitch(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, port_count=2): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lport_select_by_id" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitches_for_network" % self.d_pkg), - ) as (get_connection, select_port, select_switch, one_switch): + ) as (conn, select_port, select_switch, one_switch): connection = self._create_connection() port = self._create_lport_mock(port_count) switch = self._create_lswitch_mock() - get_connection.return_value = connection + conn.return_value = connection select_port.return_value = port select_switch.return_value = switch one_switch.return_value = [switch] @@ -274,16 +274,16 @@ class TestOptimizedNVPDriverCreatePort(TestOptimizedNVPDriver): @contextlib.contextmanager def _stubs(self, has_lswitch=True, maxed_ports=False): with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._lswitch_select_free" % self.d_pkg), mock.patch("%s._lswitch_select_first" % self.d_pkg), mock.patch("%s._lswitch_select_by_nvp_id" % self.d_pkg), mock.patch("%s._lswitch_create_optimized" % self.d_pkg), mock.patch("%s._get_network_details" % self.d_pkg) - ) as (get_connection, select_free, select_first, + ) as (conn, select_free, select_first, select_by_id, create_opt, get_net_dets): connection = self._create_connection() - get_connection.return_value = connection + conn.return_value = connection if has_lswitch: select_first.return_value = mock.Mock(nvp_id=self.lswitch_uuid) if not has_lswitch: @@ -427,7 +427,7 @@ class TestOptimizedNVPDriverUpdatePort(TestOptimizedNVPDriver): class TestCreateSecurityGroups(TestOptimizedNVPDriver): def test_create_security_group(self): - with mock.patch("%s.get_connection" % self.d_pkg): + with mock.patch("%s._connection" % self.d_pkg): self.driver.create_security_group(self.context, "newgroup") self.assertTrue(self.context.session.add.called) @@ -436,7 +436,7 @@ class TestDeleteSecurityGroups(TestOptimizedNVPDriver): def test_delete_security_group(self): mod_path = "quark.drivers.nvp_driver.NVPDriver" with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._query_security_group" % self.d_pkg), mock.patch("%s.delete_security_group" % mod_path)): @@ -452,10 +452,10 @@ class TestSecurityGroupRules(TestOptimizedNVPDriver): def _stubs(self, rules=None): rules = rules or [] with contextlib.nested( - mock.patch("%s.get_connection" % self.d_pkg), + mock.patch("%s._connection" % self.d_pkg), mock.patch("%s._query_security_group" % self.d_pkg), mock.patch("%s._check_rule_count_per_port" % self.d_pkg), - ) as (get_connection, query_sec_group, rule_count): + ) as (conn, query_sec_group, rule_count): query_sec_group.return_value = (quark.drivers.optimized_nvp_driver. SecurityProfile()) connection = self._create_connection() @@ -464,7 +464,7 @@ class TestSecurityGroupRules(TestOptimizedNVPDriver): connection.securityrule = self._create_security_rule() connection.lswitch_port().query.return_value = ( self._create_lport_query(1, [self.profile_id])) - get_connection.return_value = connection + conn.return_value = connection old_query = self.context.session.query sec_group = quark.db.models.SecurityGroup()