diff --git a/vmware_nsx/db/extended_security_group.py b/vmware_nsx/db/extended_security_group.py index dfd11ca57d..bb3b346e0b 100644 --- a/vmware_nsx/db/extended_security_group.py +++ b/vmware_nsx/db/extended_security_group.py @@ -168,8 +168,8 @@ class ExtendedSecurityGroupPropertiesMixin(object): NsxExtendedSecurityGroupProperties.security_group_id ).join(securitygroups_db.SecurityGroup).filter( securitygroups_db.SecurityGroup.tenant_id == tenant_id, - NsxExtendedSecurityGroupProperties.provider == sa.true()).scalar() - return [res] if res else [] + NsxExtendedSecurityGroupProperties.provider == sa.true()).all() + return [r[0] for r in res] def _validate_security_group_properties_create(self, context, security_group, default_sg): diff --git a/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py b/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py index f1c97915f6..73f019e5c1 100644 --- a/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py +++ b/vmware_nsx/tests/unit/extensions/test_provider_security_groups.py @@ -168,12 +168,28 @@ class ProviderSecurityGroupExtTestCase( provider_secgroup = self._create_provider_security_group() with self.port(tenant_id=self._tenant_id) as p: # check that the provider security group is on port resource. + self.assertEqual(1, len(p['port']['provider_security_groups'])) self.assertEqual(provider_secgroup['security_group']['id'], p['port']['provider_security_groups'][0]) # confirm there is still a default security group. self.assertEqual(len(p['port']['security_groups']), 1) + def test_create_port_gets_multi_provider_sg(self): + # need to create provider security groups first. + provider_secgroup1 = self._create_provider_security_group() + provider_secgroup2 = self._create_provider_security_group() + with self.port(tenant_id=self._tenant_id) as p: + # check that the provider security group is on port resource. + self.assertEqual(2, len(p['port']['provider_security_groups'])) + self.assertIn(provider_secgroup1['security_group']['id'], + p['port']['provider_security_groups']) + self.assertIn(provider_secgroup2['security_group']['id'], + p['port']['provider_security_groups']) + + # confirm there is still a default security group. + self.assertEqual(len(p['port']['security_groups']), 1) + def test_create_port_with_no_provider_sg(self): self._create_provider_security_group() with self.port(tenant_id=self._tenant_id,