Batch ports from security groups RPC handler

The security groups RPC handler calls get_port_from_device
individually for each device in a list it receives. Each
one of these results in a separate SQL query for the security
groups and port details. This becomes very inefficient as the
number of devices on a single node increases.

This patch adds logic to the RPC handler to see if the core
plugin has a method to lookup all of the device IDs at once.
If so, it uses that method, otherwise it continues as normal.

The ML2 plugin is modified to include the batch function, which
uses one SQL query regardless of the number of devices.

Closes-Bug: #1374556
Change-Id: I15d19c22e8c44577db190309b6636a3251a9c66a
This commit is contained in:
Kevin Benton 2014-09-26 09:40:44 -07:00
parent 9e1f4ae5d7
commit 6161ea4d53
5 changed files with 165 additions and 59 deletions

View File

@ -36,15 +36,11 @@ class SecurityGroupServerRpcCallback(n_rpc.RpcCallback):
return manager.NeutronManager.get_plugin() return manager.NeutronManager.get_plugin()
def _get_devices_info(self, devices): def _get_devices_info(self, devices):
devices_info = {} return dict(
for device in devices: (port['id'], port)
port = self.plugin.get_port_from_device(device) for port in self.plugin.get_ports_from_devices(devices)
if not port: if port and not port['device_owner'].startswith('network:')
continue )
if port['device_owner'].startswith('network:'):
continue
devices_info[port['id']] = port
return devices_info
def security_group_rules_for_devices(self, context, **kwargs): def security_group_rules_for_devices(self, context, **kwargs):
"""Callback method to return security group rules for each port. """Callback method to return security group rules for each port.

View File

@ -40,7 +40,7 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
def get_port_from_device(self, device): def get_port_from_device(self, device):
"""Get port dict from device name on an agent. """Get port dict from device name on an agent.
Subclass must provide this method. Subclass must provide this method or get_ports_from_devices.
:param device: device name which identifies a port on the agent side. :param device: device name which identifies a port on the agent side.
What is specified in "device" depends on a plugin agent implementation. What is specified in "device" depends on a plugin agent implementation.
@ -54,9 +54,18 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin):
- security_group_source_groups - security_group_source_groups
- fixed_ips - fixed_ips
""" """
raise NotImplementedError(_("%s must implement get_port_from_device.") raise NotImplementedError(_("%s must implement get_port_from_device "
"or get_ports_from_devices.")
% self.__class__.__name__) % self.__class__.__name__)
def get_ports_from_devices(self, devices):
"""Bulk method of get_port_from_device.
Subclasses may override this to provide better performance for DB
queries, backend calls, etc.
"""
return [self.get_port_from_device(device) for device in devices]
def create_security_group_rule(self, context, security_group_rule): def create_security_group_rule(self, context, security_group_rule):
bulk_rule = {'security_group_rules': [security_group_rule]} bulk_rule = {'security_group_rules': [security_group_rule]}
rule = self.create_security_group_rule_bulk_native(context, rule = self.create_security_group_rule_bulk_native(context,

View File

@ -13,6 +13,9 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import collections
from sqlalchemy import or_
from sqlalchemy.orm import exc from sqlalchemy.orm import exc
from oslo.db import exception as db_exc from oslo.db import exception as db_exc
@ -30,6 +33,9 @@ from neutron.plugins.ml2 import models
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
# limit the number of port OR LIKE statements in one query
MAX_PORTS_PER_QUERY = 500
def _make_segment_dict(record): def _make_segment_dict(record):
"""Make a segment dictionary out of a DB record.""" """Make a segment dictionary out of a DB record."""
@ -209,32 +215,64 @@ def get_port_from_device_mac(device_mac):
return qry.first() return qry.first()
def get_port_and_sgs(port_id): def get_ports_and_sgs(port_ids):
"""Get port from database with security group info.""" """Get ports from database with security group info."""
LOG.debug(_("get_port_and_sgs() called for port_id %s"), port_id) # break large queries into smaller parts
if len(port_ids) > MAX_PORTS_PER_QUERY:
LOG.debug("Number of ports %(pcount)s exceeds the maximum per "
"query %(maxp)s. Partitioning queries.",
{'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY})
return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) +
get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:]))
LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids)
if not port_ids:
# if port_ids is empty, avoid querying to DB to ask it for nothing
return []
ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids)
return [make_port_dict_with_security_groups(port, sec_groups)
for port, sec_groups in ports_to_sg_ids.iteritems()]
def get_sg_ids_grouped_by_port(port_ids):
sg_ids_grouped_by_port = collections.defaultdict(list)
session = db_api.get_session() session = db_api.get_session()
sg_binding_port = sg_db.SecurityGroupPortBinding.port_id sg_binding_port = sg_db.SecurityGroupPortBinding.port_id
with session.begin(subtransactions=True): with session.begin(subtransactions=True):
# partial UUIDs must be individually matched with startswith.
# full UUIDs may be matched directly in an IN statement
partial_uuids = set(port_id for port_id in port_ids
if not uuidutils.is_uuid_like(port_id))
full_uuids = set(port_ids) - partial_uuids
or_criteria = [models_v2.Port.id.startswith(port_id)
for port_id in partial_uuids]
if full_uuids:
or_criteria.append(models_v2.Port.id.in_(full_uuids))
query = session.query(models_v2.Port, query = session.query(models_v2.Port,
sg_db.SecurityGroupPortBinding.security_group_id) sg_db.SecurityGroupPortBinding.security_group_id)
query = query.outerjoin(sg_db.SecurityGroupPortBinding, query = query.outerjoin(sg_db.SecurityGroupPortBinding,
models_v2.Port.id == sg_binding_port) models_v2.Port.id == sg_binding_port)
query = query.filter(models_v2.Port.id.startswith(port_id)) query = query.filter(or_(*or_criteria))
port_and_sgs = query.all()
if not port_and_sgs: for port, sg_id in query:
return if sg_id:
port = port_and_sgs[0][0] sg_ids_grouped_by_port[port].append(sg_id)
plugin = manager.NeutronManager.get_plugin() return sg_ids_grouped_by_port
port_dict = plugin._make_port_dict(port)
port_dict['security_groups'] = [
sg_id for port_, sg_id in port_and_sgs if sg_id] def make_port_dict_with_security_groups(port, sec_groups):
port_dict['security_group_rules'] = [] plugin = manager.NeutronManager.get_plugin()
port_dict['security_group_source_groups'] = [] port_dict = plugin._make_port_dict(port)
port_dict['fixed_ips'] = [ip['ip_address'] port_dict['security_groups'] = sec_groups
for ip in port['fixed_ips']] port_dict['security_group_rules'] = []
return port_dict port_dict['security_group_source_groups'] = []
port_dict['fixed_ips'] = [ip['ip_address']
for ip in port['fixed_ips']]
return port_dict
def get_port_binding_host(port_id): def get_port_binding_host(port_id):

View File

@ -1176,12 +1176,18 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
port_host = db.get_port_binding_host(port_id) port_host = db.get_port_binding_host(port_id)
return (port_host == host) return (port_host == host)
def get_port_from_device(self, device): def get_ports_from_devices(self, devices):
port_id = self._device_to_port_id(device) port_ids_to_devices = dict((self._device_to_port_id(device), device)
port = db.get_port_and_sgs(port_id) for device in devices)
if port: port_ids = port_ids_to_devices.keys()
port['device'] = device ports = db.get_ports_and_sgs(port_ids)
return port for port in ports:
# map back to original requested id
port_id = next((port_id for port_id in port_ids
if port['id'].startswith(port_id)), None)
port['device'] = port_ids_to_devices.get(port_id)
return ports
def _device_to_port_id(self, device): def _device_to_port_id(self, device):
# REVISIT(rkukura): Consider calling into MechanismDrivers to # REVISIT(rkukura): Consider calling into MechanismDrivers to

View File

@ -14,11 +14,15 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import contextlib
import math
import mock import mock
from neutron.api.v2 import attributes from neutron.api.v2 import attributes
from neutron.common import constants as const
from neutron.extensions import securitygroup as ext_sg from neutron.extensions import securitygroup as ext_sg
from neutron import manager from neutron import manager
from neutron.tests.unit import test_api_v2
from neutron.tests.unit import test_extension_security_group as test_sg from neutron.tests.unit import test_extension_security_group as test_sg
from neutron.tests.unit import test_security_groups_rpc as test_sg_rpc from neutron.tests.unit import test_security_groups_rpc as test_sg_rpc
@ -55,38 +59,91 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase,
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
plugin.start_rpc_listeners() plugin.start_rpc_listeners()
def test_security_group_get_port_from_device(self): def _make_port_with_new_sec_group(self, net_id):
sg = self._make_security_group(self.fmt, 'name', 'desc')
port = self._make_port(
self.fmt, net_id, security_groups=[sg['security_group']['id']])
return port['port']
def test_security_group_get_ports_from_devices(self):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
with self.security_group() as sg: port1 = self._make_port_with_new_sec_group(n['network']['id'])
security_group_id = sg['security_group']['id'] port2 = self._make_port_with_new_sec_group(n['network']['id'])
res = self._create_port(self.fmt, n['network']['id']) plugin = manager.NeutronManager.get_plugin()
port = self.deserialize(self.fmt, res) # should match full ID and starting chars
fixed_ips = port['port']['fixed_ips'] ports = plugin.get_ports_from_devices(
data = {'port': {'fixed_ips': fixed_ips, [port1['id'], port2['id'][0:8]])
'name': port['port']['name'], self.assertEqual(2, len(ports))
ext_sg.SECURITYGROUPS: for port_dict in ports:
[security_group_id]}} p = port1 if port1['id'] == port_dict['id'] else port2
self.assertEqual(p['id'], port_dict['id'])
req = self.new_update_request('ports', data, self.assertEqual(p['security_groups'],
port['port']['id'])
res = self.deserialize(self.fmt,
req.get_response(self.api))
port_id = res['port']['id']
plugin = manager.NeutronManager.get_plugin()
port_dict = plugin.get_port_from_device(port_id)
self.assertEqual(port_id, port_dict['id'])
self.assertEqual([security_group_id],
port_dict[ext_sg.SECURITYGROUPS]) port_dict[ext_sg.SECURITYGROUPS])
self.assertEqual([], port_dict['security_group_rules']) self.assertEqual([], port_dict['security_group_rules'])
self.assertEqual([fixed_ips[0]['ip_address']], self.assertEqual([p['fixed_ips'][0]['ip_address']],
port_dict['fixed_ips']) port_dict['fixed_ips'])
self._delete('ports', port_id) self._delete('ports', p['id'])
def test_security_group_get_port_from_device_with_no_port(self): def test_security_group_get_ports_from_devices_with_bad_id(self):
plugin = manager.NeutronManager.get_plugin() plugin = manager.NeutronManager.get_plugin()
port_dict = plugin.get_port_from_device('bad_device_id') ports = plugin.get_ports_from_devices(['bad_device_id'])
self.assertIsNone(port_dict) self.assertFalse(ports)
def test_security_group_no_db_calls_with_no_ports(self):
plugin = manager.NeutronManager.get_plugin()
with mock.patch(
'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port'
) as get_mock:
self.assertFalse(plugin.get_ports_from_devices([]))
self.assertFalse(get_mock.called)
def test_large_port_count_broken_into_parts(self):
plugin = manager.NeutronManager.get_plugin()
max_ports_per_query = 5
ports_to_query = 73
for max_ports_per_query in (1, 2, 5, 7, 9, 31):
with contextlib.nested(
mock.patch('neutron.plugins.ml2.db.MAX_PORTS_PER_QUERY',
new=max_ports_per_query),
mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port',
return_value={}),
) as (max_mock, get_mock):
plugin.get_ports_from_devices(
['%s%s' % (const.TAP_DEVICE_PREFIX, i)
for i in range(ports_to_query)])
all_call_args = map(lambda x: x[1][0], get_mock.mock_calls)
last_call_args = all_call_args.pop()
# all but last should be getting MAX_PORTS_PER_QUERY ports
self.assertTrue(
all(map(lambda x: len(x) == max_ports_per_query,
all_call_args))
)
remaining = ports_to_query % max_ports_per_query
if remaining:
self.assertEqual(remaining, len(last_call_args))
# should be broken into ceil(total/MAX_PORTS_PER_QUERY) calls
self.assertEqual(
math.ceil(ports_to_query / float(max_ports_per_query)),
get_mock.call_count
)
def test_full_uuids_skip_port_id_lookup(self):
plugin = manager.NeutronManager.get_plugin()
# when full UUIDs are provided, the _or statement should only
# have one matching 'IN' critiera for all of the IDs
with contextlib.nested(
mock.patch('neutron.plugins.ml2.db.or_'),
mock.patch('neutron.plugins.ml2.db.db_api.get_session')
) as (or_mock, sess_mock):
fmock = sess_mock.query.return_value.outerjoin.return_value.filter
# return no ports to exit the method early since we are mocking
# the query
fmock.return_value.all.return_value = []
plugin.get_ports_from_devices([test_api_v2._uuid(),
test_api_v2._uuid()])
# the or_ function should only have one argument
or_mock.assert_called_once_with(mock.ANY)
class TestMl2SGServerRpcCallBack( class TestMl2SGServerRpcCallBack(