Merge "Filter port-list based on security-group"

This commit is contained in:
Zuul 2018-03-28 11:11:19 +00:00 committed by Gerrit Code Review
commit 456ac69e49
5 changed files with 101 additions and 2 deletions

View File

@ -178,6 +178,21 @@ class NsxPluginBase(db_base_plugin_v2.NeutronDbPluginV2,
device_id=device_id, device_id=device_id,
device_owner=device_owner,).all() device_owner=device_owner,).all()
def _update_filters_with_sec_group(self, context, filters=None):
if filters is not None:
security_groups = filters.pop("security_groups", None)
if security_groups:
bindings = (
super(NsxPluginBase, self)
._get_port_security_group_bindings(context,
filters={'security_group_id': security_groups}))
if 'id' in filters:
filters['id'] = [entry['port_id'] for
entry in bindings
if entry['port_id'] in filters['id']]
else:
filters['id'] = [entry['port_id'] for entry in bindings]
def _find_router_subnets(self, context, router_id): def _find_router_subnets(self, context, router_id):
"""Retrieve subnets attached to the specified router.""" """Retrieve subnets attached to the specified router."""
ports = self._get_port_by_device_id(context, router_id, ports = self._get_port_by_device_id(context, router_id,

View File

@ -208,7 +208,8 @@ class NsxVPluginV2(addr_pair_db.AllowedAddressPairsMixin,
"flavors", "flavors",
"dhcp-mtu", "dhcp-mtu",
"mac-learning", "mac-learning",
"housekeeper"] "housekeeper",
"port-security-groups-filtering"]
__native_bulk_support = True __native_bulk_support = True
__native_pagination_support = True __native_pagination_support = True
@ -2449,6 +2450,19 @@ class NsxVPluginV2(addr_pair_db.AllowedAddressPairsMixin,
self._extend_get_port_dict_qos(context, port) self._extend_get_port_dict_qos(context, port)
return db_utils.resource_fields(port, fields) return db_utils.resource_fields(port, fields)
def get_ports(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
filters = filters or {}
self._update_filters_with_sec_group(context, filters)
with db_api.context_manager.reader.using(context):
ports = (
super(NsxVPluginV2, self).get_ports(
context, filters, fields, sorts,
limit, marker, page_reverse))
return (ports if not fields else
[db_utils.resource_fields(port, fields) for port in ports])
def delete_port(self, context, id, l3_port_check=True, def delete_port(self, context, id, l3_port_check=True,
nw_gw_port_check=True, force_delete_dhcp=False, nw_gw_port_check=True, force_delete_dhcp=False,
allow_delete_internal=False): allow_delete_internal=False):

View File

@ -207,7 +207,8 @@ class NsxV3Plugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
"router_availability_zone", "router_availability_zone",
"subnet_allocation", "subnet_allocation",
"security-group-logging", "security-group-logging",
"provider-security-group"] "provider-security-group",
"port-security-groups-filtering"]
@resource_registry.tracked_resources( @resource_registry.tracked_resources(
network=models_v2.Network, network=models_v2.Network,
@ -3331,6 +3332,7 @@ class NsxV3Plugin(agentschedulers_db.AZDhcpAgentSchedulerDbMixin,
sorts=None, limit=None, marker=None, sorts=None, limit=None, marker=None,
page_reverse=False): page_reverse=False):
filters = filters or {} filters = filters or {}
self._update_filters_with_sec_group(context, filters)
with db_api.context_manager.reader.using(context): with db_api.context_manager.reader.using(context):
ports = ( ports = (
super(NsxV3Plugin, self).get_ports( super(NsxV3Plugin, self).get_ports(

View File

@ -19,6 +19,7 @@ import copy
from eventlet import greenthread from eventlet import greenthread
import mock import mock
import netaddr import netaddr
from neutron.db import securitygroups_db as sg_db
from neutron.extensions import address_scope from neutron.extensions import address_scope
from neutron.extensions import l3 from neutron.extensions import l3
from neutron.extensions import securitygroup as secgrp from neutron.extensions import securitygroup as secgrp
@ -1069,6 +1070,39 @@ class TestPortsV2(NsxVPluginV2TestCase,
[('admin_state_up', 'asc'), [('admin_state_up', 'asc'),
('mac_address', 'desc')]) ('mac_address', 'desc')])
def test_list_ports_filtered_by_security_groups(self):
ctx = context.get_admin_context()
with self.port() as port1, self.port() as port2:
query_params = "security_groups=%s" % (
port1['port']['security_groups'][0])
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(set([port1['port']['id'], port2['port']['id']]),
set([port['id'] for port in ports_data['ports']]))
query_params = "security_groups=%s&id=%s" % (
port1['port']['security_groups'][0],
port1['port']['id'])
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id'])
self.assertEqual(1, len(ports_data['ports']))
temp_sg = {'security_group': {'tenant_id': 'some_tenant',
'name': '', 'description': 's'}}
sg_dbMixin = sg_db.SecurityGroupDbMixin()
sg = sg_dbMixin.create_security_group(ctx, temp_sg)
sg_dbMixin._delete_port_security_group_bindings(
ctx, port2['port']['id'])
sg_dbMixin._create_port_security_group_binding(
ctx, port2['port']['id'], sg['id'])
port2['port']['security_groups'][0] = sg['id']
query_params = "security_groups=%s" % (
port1['port']['security_groups'][0])
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id'])
self.assertEqual(1, len(ports_data['ports']))
query_params = "security_groups=%s" % (
(port2['port']['security_groups'][0]))
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(port2['port']['id'], ports_data['ports'][0]['id'])
def test_update_port_delete_ip(self): def test_update_port_delete_ip(self):
# This test case overrides the default because the nsx plugin # This test case overrides the default because the nsx plugin
# implements port_security/security groups and it is not allowed # implements port_security/security groups and it is not allowed

View File

@ -16,6 +16,7 @@
import mock import mock
import netaddr import netaddr
from neutron.db import models_v2 from neutron.db import models_v2
from neutron.db import securitygroups_db as sg_db
from neutron.extensions import address_scope from neutron.extensions import address_scope
from neutron.extensions import l3 from neutron.extensions import l3
from neutron.extensions import securitygroup as secgrp from neutron.extensions import securitygroup as secgrp
@ -912,6 +913,39 @@ class TestPortsV2(test_plugin.TestPortsV2, NsxV3PluginTestCaseMixin,
self._get_ports_with_fields(tenid, 'mac_address', 4) self._get_ports_with_fields(tenid, 'mac_address', 4)
self._get_ports_with_fields(tenid, 'network_id', 4) self._get_ports_with_fields(tenid, 'network_id', 4)
def test_list_ports_filtered_by_security_groups(self):
ctx = context.get_admin_context()
with self.port() as port1, self.port() as port2:
query_params = "security_groups=%s" % (
port1['port']['security_groups'][0])
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(set([port1['port']['id'], port2['port']['id']]),
set([port['id'] for port in ports_data['ports']]))
query_params = "security_groups=%s&id=%s" % (
port1['port']['security_groups'][0],
port1['port']['id'])
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id'])
self.assertEqual(1, len(ports_data['ports']))
temp_sg = {'security_group': {'tenant_id': 'some_tenant',
'name': '', 'description': 's'}}
sg_dbMixin = sg_db.SecurityGroupDbMixin()
sg = sg_dbMixin.create_security_group(ctx, temp_sg)
sg_dbMixin._delete_port_security_group_bindings(
ctx, port2['port']['id'])
sg_dbMixin._create_port_security_group_binding(
ctx, port2['port']['id'], sg['id'])
port2['port']['security_groups'][0] = sg['id']
query_params = "security_groups=%s" % (
port1['port']['security_groups'][0])
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(port1['port']['id'], ports_data['ports'][0]['id'])
self.assertEqual(1, len(ports_data['ports']))
query_params = "security_groups=%s" % (
(port2['port']['security_groups'][0]))
ports_data = self._list('ports', query_params=query_params)
self.assertEqual(port2['port']['id'], ports_data['ports'][0]['id'])
def test_port_failure_rollback_dhcp_exception(self): def test_port_failure_rollback_dhcp_exception(self):
cfg.CONF.set_override('native_dhcp_metadata', True, 'nsx_v3') cfg.CONF.set_override('native_dhcp_metadata', True, 'nsx_v3')
self.plugin = directory.get_plugin() self.plugin = directory.get_plugin()