diff --git a/neutron/db/portbindings_db.py b/neutron/db/portbindings_db.py index 678f743345..c7ab4d37e2 100644 --- a/neutron/db/portbindings_db.py +++ b/neutron/db/portbindings_db.py @@ -73,17 +73,17 @@ class PortBindingMixin(portbindings_base.PortBindingBaseMixin): del port[portbindings.PROFILE] host = port_data.get(portbindings.HOST_ID) host_set = attributes.is_attr_set(host) - if not host_set: - self._extend_port_dict_binding_host(port, None) - return with context.session.begin(subtransactions=True): bind_port = context.session.query( PortBindingPort).filter_by(port_id=port['id']).first() - if not bind_port: - context.session.add(PortBindingPort(port_id=port['id'], - host=host)) + if host_set: + if not bind_port: + context.session.add(PortBindingPort(port_id=port['id'], + host=host)) + else: + bind_port.host = host else: - bind_port.host = host + host = (bind_port and bind_port.host or None) self._extend_port_dict_binding_host(port, host) def get_port_host(self, context, port_id): diff --git a/neutron/tests/unit/_test_extension_portbindings.py b/neutron/tests/unit/_test_extension_portbindings.py index 85df2f7cac..f362f92ead 100644 --- a/neutron/tests/unit/_test_extension_portbindings.py +++ b/neutron/tests/unit/_test_extension_portbindings.py @@ -247,6 +247,24 @@ class PortBindingsHostTestCaseMixin(object): for port in ports: self.assertEqual('testhosttemp', port[portbindings.HOST_ID]) + def test_ports_vif_non_host_update(self): + host_arg = {portbindings.HOST_ID: self.hostname} + with self.port(name='name', arg_list=(portbindings.HOST_ID,), + **host_arg) as port: + data = {'port': {'admin_state_up': False}} + req = self.new_update_request('ports', data, port['port']['id']) + res = self.deserialize(self.fmt, req.get_response(self.api)) + self.assertEqual(port['port'][portbindings.HOST_ID], + res['port'][portbindings.HOST_ID]) + + def test_ports_vif_non_host_update_when_host_null(self): + with self.port() as port: + data = {'port': {'admin_state_up': False}} + req = self.new_update_request('ports', data, port['port']['id']) + res = self.deserialize(self.fmt, req.get_response(self.api)) + self.assertEqual(port['port'][portbindings.HOST_ID], + res['port'][portbindings.HOST_ID]) + def test_ports_vif_host_list(self): cfg.CONF.set_default('allow_overlapping_ips', True) host_arg = {portbindings.HOST_ID: self.hostname}