diff --git a/neutron/plugins/openvswitch/ovs_neutron_plugin.py b/neutron/plugins/openvswitch/ovs_neutron_plugin.py index 005fb308b7..3b86238c74 100644 --- a/neutron/plugins/openvswitch/ovs_neutron_plugin.py +++ b/neutron/plugins/openvswitch/ovs_neutron_plugin.py @@ -44,6 +44,7 @@ from neutron.extensions import allowedaddresspairs as addr_pair from neutron.extensions import extra_dhcp_opt as edo_ext from neutron.extensions import portbindings from neutron.extensions import providernet as provider +from neutron.extensions import securitygroup as ext_sg from neutron import manager from neutron.openstack.common import importutils from neutron.openstack.common import log as logging @@ -603,8 +604,9 @@ class OVSNeutronPluginV2(db_base_plugin_v2.NeutronDbPluginV2, need_port_update_notify |= self._update_extra_dhcp_opts_on_port( context, id, port, updated_port) - need_port_update_notify |= self.is_security_group_member_updated( + secgrp_member_updated = self.is_security_group_member_updated( context, original_port, updated_port) + need_port_update_notify |= secgrp_member_updated if original_port['admin_state_up'] != updated_port['admin_state_up']: need_port_update_notify = True @@ -615,6 +617,14 @@ class OVSNeutronPluginV2(db_base_plugin_v2.NeutronDbPluginV2, binding.network_type, binding.segmentation_id, binding.physical_network) + + if secgrp_member_updated: + old_set = set(original_port.get(ext_sg.SECURITYGROUPS)) + new_set = set(updated_port.get(ext_sg.SECURITYGROUPS)) + self.notifier.security_groups_member_updated( + context, + old_set ^ new_set) + return updated_port def delete_port(self, context, id, l3_port_check=True): diff --git a/neutron/tests/unit/openvswitch/test_openvswitch_plugin.py b/neutron/tests/unit/openvswitch/test_openvswitch_plugin.py index 234a8feb82..af1c1d0425 100644 --- a/neutron/tests/unit/openvswitch/test_openvswitch_plugin.py +++ b/neutron/tests/unit/openvswitch/test_openvswitch_plugin.py @@ -15,12 +15,17 @@ from oslo.config import cfg +from neutron import context from neutron.extensions import portbindings +from neutron.extensions import securitygroup as ext_sg +from neutron.plugins.openvswitch import ovs_neutron_plugin from neutron.tests.unit import _test_extension_portbindings as test_bindings from neutron.tests.unit import test_db_plugin as test_plugin from neutron.tests.unit import test_extension_allowedaddresspairs as test_pair from neutron.tests.unit import test_security_groups_rpc as test_sg_rpc +import mock + class OpenvswitchPluginV2TestCase(test_plugin.NeutronDbPluginV2TestCase): @@ -86,3 +91,69 @@ class TestOpenvswitchPortBindingHost( class TestOpenvswitchAllowedAddressPairs(OpenvswitchPluginV2TestCase, test_pair.TestAllowedAddressPairs): pass + + +class TestOpenvswitchUpdatePort(OpenvswitchPluginV2TestCase, + ovs_neutron_plugin.OVSNeutronPluginV2): + + def test_update_port_add_remove_security_group(self): + get_port_func = ( + 'neutron.db.db_base_plugin_v2.' + 'NeutronDbPluginV2.get_port' + ) + with mock.patch(get_port_func) as mock_get_port: + mock_get_port.return_value = { + ext_sg.SECURITYGROUPS: ["sg1", "sg2"], + "admin_state_up": True, + "fixed_ips": "fake_ip", + "network_id": "fake_id"} + + update_port_func = ( + 'neutron.db.db_base_plugin_v2.' + 'NeutronDbPluginV2.update_port' + ) + with mock.patch(update_port_func) as mock_update_port: + mock_update_port.return_value = { + ext_sg.SECURITYGROUPS: ["sg2", "sg3"], + "admin_state_up": True, + "fixed_ips": "fake_ip", + "network_id": "fake_id"} + + fake_func = ( + 'neutron.plugins.openvswitch.' + 'ovs_db_v2.get_network_binding' + ) + with mock.patch(fake_func) as mock_func: + class MockBinding: + network_type = "fake" + segmentation_id = "fake" + physical_network = "fake" + + mock_func.return_value = MockBinding() + + ctx = context.Context('', 'somebody') + self.update_port(ctx, "id", { + "port": { + ext_sg.SECURITYGROUPS: [ + "sg2", "sg3"]}}) + + sgmu = self.notifier.security_groups_member_updated + sgmu.assert_called_with(ctx, set(['sg1', 'sg3'])) + + def setUp(self): + super(TestOpenvswitchUpdatePort, self).setUp() + self.update_security_group_on_port = mock.MagicMock(return_value=True) + self._process_portbindings_create_and_update = mock.MagicMock( + return_value=True) + self._update_extra_dhcp_opts_on_port = mock.MagicMock( + return_value=True) + self.update_address_pairs_on_port = mock.MagicMock( + return_value=True) + + class MockNotifier: + def __init__(self): + self.port_update = mock.MagicMock(return_value=True) + self.security_groups_member_updated = mock.MagicMock( + return_value=True) + + self.notifier = MockNotifier()