diff --git a/vmware_nsx/common/exceptions.py b/vmware_nsx/common/exceptions.py index da87113e4b..9bb4e9037b 100644 --- a/vmware_nsx/common/exceptions.py +++ b/vmware_nsx/common/exceptions.py @@ -157,3 +157,7 @@ class NsxL2GWDeviceNotFound(n_exc.NotFound): class NsxL2GWInUse(n_exc.InUse): message = _("L2 Gateway '%(gateway_id)s' has been used") + + +class InvalidIPAddress(n_exc.InvalidInput): + message = _("'%(ip_address)s' must be a /32 CIDR based IPv4 address") diff --git a/vmware_nsx/common/utils.py b/vmware_nsx/common/utils.py index 078e47dc16..0a26991ad2 100644 --- a/vmware_nsx/common/utils.py +++ b/vmware_nsx/common/utils.py @@ -236,3 +236,24 @@ def get_name_and_uuid(name, uuid, tag=None, maxlen=80): return name[:maxlen] + '_' + tag + short_uuid else: return name[:maxlen] + short_uuid + + +def is_ipv4_ip_address(addr): + + def _valid_part(part): + try: + int_part = int(part) + if int_part < 0 or int_part > 255: + return False + return True + except ValueError: + return False + + parts = str(addr).split('.') + if len(parts) != 4: + return False + + for ip_part in parts: + if not _valid_part(ip_part): + return False + return True diff --git a/vmware_nsx/plugins/nsx_v3/plugin.py b/vmware_nsx/plugins/nsx_v3/plugin.py index a0d7d0318e..d95a302a15 100644 --- a/vmware_nsx/plugins/nsx_v3/plugin.py +++ b/vmware_nsx/plugins/nsx_v3/plugin.py @@ -657,6 +657,12 @@ class NsxV3Plugin(addr_pair_db.AllowedAddressPairsMixin, neutron_db['network_id'], result['id']) return result + def _validate_address_pairs(self, address_pairs): + for pair in address_pairs: + ip = pair.get('ip_address') + if not utils.is_ipv4_ip_address(ip): + raise nsx_exc.InvalidIPAddress(ip_address=ip) + def _create_port_preprocess_security( self, context, port, port_data, neutron_db): (port_security, has_ip) = self._determine_port_security_and_has_ip( @@ -665,13 +671,15 @@ class NsxV3Plugin(addr_pair_db.AllowedAddressPairsMixin, self._process_port_port_security_create( context, port_data, neutron_db) # allowed address pair checks - if attributes.is_attr_set(port_data.get(addr_pair.ADDRESS_PAIRS)): + address_pairs = port_data.get(addr_pair.ADDRESS_PAIRS) + if attributes.is_attr_set(address_pairs): if not port_security: raise addr_pair.AddressPairAndPortSecurityRequired() else: + self._validate_address_pairs(address_pairs) self._process_create_allowed_address_pairs( context, neutron_db, - port_data[addr_pair.ADDRESS_PAIRS]) + address_pairs) else: # remove ATTR_NOT_SPECIFIED port_data[addr_pair.ADDRESS_PAIRS] = [] @@ -787,6 +795,8 @@ class NsxV3Plugin(addr_pair_db.AllowedAddressPairsMixin, raise addr_pair.AddressPairAndPortSecurityRequired() if delete_addr_pairs or has_addr_pairs: + self._validate_address_pairs( + updated_port[addr_pair.ADDRESS_PAIRS]) # delete address pairs and read them in self._delete_allowed_address_pairs(context, id) self._process_create_allowed_address_pairs( @@ -929,6 +939,18 @@ class NsxV3Plugin(addr_pair_db.AllowedAddressPairsMixin, with context.session.begin(subtransactions=True): super(NsxV3Plugin, self).update_port( context, id, {'port': original_port}) + + # revert allowed address pairs + if port_security: + orig_pair = original_port.get(addr_pair.ADDRESS_PAIRS) + updated_pair = updated_port.get( + addr_pair.ADDRESS_PAIRS) + if orig_pair != updated_pair: + self._delete_allowed_address_pairs(context, id) + if orig_pair: + self._process_create_allowed_address_pairs( + context, original_port, orig_pair) + if sec_grp_updated: self.update_security_group_on_port( context, id, {'port': original_port}, updated_port, diff --git a/vmware_nsx/tests/unit/extensions/test_addresspairs.py b/vmware_nsx/tests/unit/extensions/test_addresspairs.py index e8e8b0d414..7c256f0f5c 100644 --- a/vmware_nsx/tests/unit/extensions/test_addresspairs.py +++ b/vmware_nsx/tests/unit/extensions/test_addresspairs.py @@ -49,6 +49,25 @@ class TestAllowedAddressPairsNSXv3(test_v3_plugin.NsxV3PluginTestCaseMixin, super(TestAllowedAddressPairsNSXv3, self).setUp( plugin=plugin, ext_mgr=ext_mgr, service_plugins=service_plugins) + def test_create_bad_address_pairs_with_cidr(self): + address_pairs = [{'mac_address': '00:00:00:00:00:01', + 'ip_address': '10.0.0.1/24'}] + self._create_port_with_address_pairs(address_pairs, 400) + + def test_update_add_bad_address_pairs_with_cidr(self): + with self.network() as net: + res = self._create_port(self.fmt, net['network']['id']) + port = self.deserialize(self.fmt, res) + address_pairs = [{'mac_address': '00:00:00:00:00:01', + 'ip_address': '10.0.0.1/24'}] + update_port = {'port': {addr_pair.ADDRESS_PAIRS: + address_pairs}} + req = self.new_update_request('ports', update_port, + port['port']['id']) + res = req.get_response(self.api) + self.assertEqual(res.status_int, 400) + self._delete('ports', port['port']['id']) + def test_create_port_security_false_allowed_address_pairs(self): self.skipTest('TBD')