# vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2012 Nicira Networks, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. # # @author: Aaron Rosen, Nicira, Inc import sqlalchemy as sa from sqlalchemy import orm from sqlalchemy.orm import exc from sqlalchemy.orm import scoped_session from neutron.api.v2 import attributes as attr from neutron.db import db_base_plugin_v2 from neutron.db import model_base from neutron.db import models_v2 from neutron.extensions import securitygroup as ext_sg from neutron.openstack.common import uuidutils class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant): """Represents a v2 neutron security group.""" name = sa.Column(sa.String(255)) description = sa.Column(sa.String(255)) class SecurityGroupPortBinding(model_base.BASEV2): """Represents binding between neutron ports and security profiles.""" port_id = sa.Column(sa.String(36), sa.ForeignKey("ports.id", ondelete='CASCADE'), primary_key=True) security_group_id = sa.Column(sa.String(36), sa.ForeignKey("securitygroups.id"), primary_key=True) # Add a relationship to the Port model in order to instruct SQLAlchemy to # eagerly load security group bindings ports = orm.relationship( models_v2.Port, backref=orm.backref("security_groups", lazy='joined', cascade='delete')) class SecurityGroupRule(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant): """Represents a v2 neutron security group rule.""" security_group_id = sa.Column(sa.String(36), sa.ForeignKey("securitygroups.id", ondelete="CASCADE"), nullable=False) remote_group_id = sa.Column(sa.String(36), sa.ForeignKey("securitygroups.id", ondelete="CASCADE"), nullable=True) direction = sa.Column(sa.Enum('ingress', 'egress', name='securitygrouprules_direction')) ethertype = sa.Column(sa.String(40)) protocol = sa.Column(sa.String(40)) port_range_min = sa.Column(sa.Integer) port_range_max = sa.Column(sa.Integer) remote_ip_prefix = sa.Column(sa.String(255)) security_group = orm.relationship( SecurityGroup, backref=orm.backref('rules', cascade='all,delete'), primaryjoin="SecurityGroup.id==SecurityGroupRule.security_group_id") source_group = orm.relationship( SecurityGroup, backref=orm.backref('source_rules', cascade='all,delete'), primaryjoin="SecurityGroup.id==SecurityGroupRule.remote_group_id") class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): """Mixin class to add security group to db_plugin_base_v2.""" __native_bulk_support = True def create_security_group_bulk(self, context, security_group_rule): return self._create_bulk('security_group', context, security_group_rule) def create_security_group(self, context, security_group, default_sg=False): """Create security group. If default_sg is true that means we are a default security group for a given tenant if it does not exist. """ s = security_group['security_group'] tenant_id = self._get_tenant_id_for_create(context, s) if not default_sg: self._ensure_default_security_group(context, tenant_id) with context.session.begin(subtransactions=True): security_group_db = SecurityGroup(id=s.get('id') or ( uuidutils.generate_uuid()), description=s['description'], tenant_id=tenant_id, name=s['name']) context.session.add(security_group_db) for ethertype in ext_sg.sg_supported_ethertypes: if s.get('name') == 'default': # Allow intercommunication ingress_rule = SecurityGroupRule( id=uuidutils.generate_uuid(), tenant_id=tenant_id, security_group=security_group_db, direction='ingress', ethertype=ethertype, source_group=security_group_db) context.session.add(ingress_rule) egress_rule = SecurityGroupRule( id=uuidutils.generate_uuid(), tenant_id=tenant_id, security_group=security_group_db, direction='egress', ethertype=ethertype) context.session.add(egress_rule) return self._make_security_group_dict(security_group_db) def get_security_groups(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, page_reverse=False, default_sg=False): # If default_sg is True do not call _ensure_default_security_group() # so this can be done recursively. Context.tenant_id is checked # because all the unit tests do not explicitly set the context on # GETS. TODO(arosen) context handling can probably be improved here. if not default_sg and context.tenant_id: self._ensure_default_security_group(context, context.tenant_id) marker_obj = self._get_marker_obj(context, 'security_group', limit, marker) return self._get_collection(context, SecurityGroup, self._make_security_group_dict, filters=filters, fields=fields, sorts=sorts, limit=limit, marker_obj=marker_obj, page_reverse=page_reverse) def get_security_groups_count(self, context, filters=None): return self._get_collection_count(context, SecurityGroup, filters=filters) def get_security_group(self, context, id, fields=None, tenant_id=None): """Tenant id is given to handle the case when creating a security group rule on behalf of another use. """ if tenant_id: tmp_context_tenant_id = context.tenant_id context.tenant_id = tenant_id try: with context.session.begin(subtransactions=True): ret = self._make_security_group_dict(self._get_security_group( context, id), fields) ret['security_group_rules'] = self.get_security_group_rules( context, {'security_group_id': [id]}) finally: if tenant_id: context.tenant_id = tmp_context_tenant_id return ret def _get_security_group(self, context, id): try: query = self._model_query(context, SecurityGroup) sg = query.filter(SecurityGroup.id == id).one() except exc.NoResultFound: raise ext_sg.SecurityGroupNotFound(id=id) return sg def delete_security_group(self, context, id): filters = {'security_group_id': [id]} ports = self._get_port_security_group_bindings(context, filters) if ports: raise ext_sg.SecurityGroupInUse(id=id) # confirm security group exists sg = self._get_security_group(context, id) if sg['name'] == 'default' and not context.is_admin: raise ext_sg.SecurityGroupCannotRemoveDefault() with context.session.begin(subtransactions=True): context.session.delete(sg) def update_security_group(self, context, id, security_group): s = security_group['security_group'] with context.session.begin(subtransactions=True): sg = self._get_security_group(context, id) if sg['name'] == 'default' and 'name' in s: raise ext_sg.SecurityGroupCannotUpdateDefault() sg.update(s) return self._make_security_group_dict(sg) def _make_security_group_dict(self, security_group, fields=None): res = {'id': security_group['id'], 'name': security_group['name'], 'tenant_id': security_group['tenant_id'], 'description': security_group['description']} res['security_group_rules'] = [self._make_security_group_rule_dict(r) for r in security_group.rules] return self._fields(res, fields) def _make_security_group_binding_dict(self, security_group, fields=None): res = {'port_id': security_group['port_id'], 'security_group_id': security_group['security_group_id']} return self._fields(res, fields) def _create_port_security_group_binding(self, context, port_id, security_group_id): with context.session.begin(subtransactions=True): db = SecurityGroupPortBinding(port_id=port_id, security_group_id=security_group_id) context.session.add(db) def _get_port_security_group_bindings(self, context, filters=None, fields=None): return self._get_collection(context, SecurityGroupPortBinding, self._make_security_group_binding_dict, filters=filters, fields=fields) def _delete_port_security_group_bindings(self, context, port_id): query = self._model_query(context, SecurityGroupPortBinding) bindings = query.filter( SecurityGroupPortBinding.port_id == port_id) with context.session.begin(subtransactions=True): for binding in bindings: context.session.delete(binding) def create_security_group_rule_bulk(self, context, security_group_rule): return self._create_bulk('security_group_rule', context, security_group_rule) def create_security_group_rule_bulk_native(self, context, security_group_rule): r = security_group_rule['security_group_rules'] scoped_session(context.session) security_group_id = self._validate_security_group_rules( context, security_group_rule) with context.session.begin(subtransactions=True): if not self.get_security_group(context, security_group_id): raise ext_sg.SecurityGroupNotFound(id=security_group_id) self._check_for_duplicate_rules(context, r) ret = [] for rule_dict in r: rule = rule_dict['security_group_rule'] tenant_id = self._get_tenant_id_for_create(context, rule) db = SecurityGroupRule( id=uuidutils.generate_uuid(), tenant_id=tenant_id, security_group_id=rule['security_group_id'], direction=rule['direction'], remote_group_id=rule.get('remote_group_id'), ethertype=rule['ethertype'], protocol=rule['protocol'], port_range_min=rule['port_range_min'], port_range_max=rule['port_range_max'], remote_ip_prefix=rule.get('remote_ip_prefix')) context.session.add(db) ret.append(self._make_security_group_rule_dict(db)) return ret def create_security_group_rule(self, context, security_group_rule): bulk_rule = {'security_group_rules': [security_group_rule]} return self.create_security_group_rule_bulk_native(context, bulk_rule)[0] def _validate_security_group_rules(self, context, security_group_rule): """Check that rules being installed. Check that all rules belong to the same security group, remote_group_id/security_group_id belong to the same tenant, and rules are valid. """ new_rules = set() tenant_ids = set() for rules in security_group_rule['security_group_rules']: rule = rules.get('security_group_rule') new_rules.add(rule['security_group_id']) # Check that port_range's are valid if (rule['port_range_min'] is None and rule['port_range_max'] is None): pass elif (rule['port_range_min'] is not None and rule['port_range_min'] <= rule['port_range_max']): if not rule['protocol']: raise ext_sg.SecurityGroupProtocolRequiredWithPorts() else: raise ext_sg.SecurityGroupInvalidPortRange() if rule['remote_ip_prefix'] and rule['remote_group_id']: raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix() if rule['tenant_id'] not in tenant_ids: tenant_ids.add(rule['tenant_id']) remote_group_id = rule.get('remote_group_id') # Check that remote_group_id exists for tenant if remote_group_id: self.get_security_group(context, remote_group_id, tenant_id=rule['tenant_id']) if len(new_rules) > 1: raise ext_sg.SecurityGroupNotSingleGroupRules() security_group_id = new_rules.pop() # Confirm single tenant and that the tenant has permission # to add rules to this security group. if len(tenant_ids) > 1: raise ext_sg.SecurityGroupRulesNotSingleTenant() for tenant_id in tenant_ids: self.get_security_group(context, security_group_id, tenant_id=tenant_id) return security_group_id def _make_security_group_rule_dict(self, security_group_rule, fields=None): res = {'id': security_group_rule['id'], 'tenant_id': security_group_rule['tenant_id'], 'security_group_id': security_group_rule['security_group_id'], 'ethertype': security_group_rule['ethertype'], 'direction': security_group_rule['direction'], 'protocol': security_group_rule['protocol'], 'port_range_min': security_group_rule['port_range_min'], 'port_range_max': security_group_rule['port_range_max'], 'remote_ip_prefix': security_group_rule['remote_ip_prefix'], 'remote_group_id': security_group_rule['remote_group_id']} return self._fields(res, fields) def _make_security_group_rule_filter_dict(self, security_group_rule): sgr = security_group_rule['security_group_rule'] res = {'tenant_id': [sgr['tenant_id']], 'security_group_id': [sgr['security_group_id']], 'direction': [sgr['direction']]} include_if_present = ['protocol', 'port_range_max', 'port_range_min', 'ethertype', 'remote_ip_prefix', 'remote_group_id'] for key in include_if_present: value = sgr.get(key) if value: res[key] = [value] return res def _check_for_duplicate_rules(self, context, security_group_rules): for i in security_group_rules: found_self = False for j in security_group_rules: if i['security_group_rule'] == j['security_group_rule']: if found_self: raise ext_sg.DuplicateSecurityGroupRuleInPost(rule=i) found_self = True # Check in database if rule exists filters = self._make_security_group_rule_filter_dict(i) rules = self.get_security_group_rules(context, filters) if rules: raise ext_sg.SecurityGroupRuleExists(id=str(rules[0]['id'])) def get_security_group_rules(self, context, filters=None, fields=None, sorts=None, limit=None, marker=None, page_reverse=False): marker_obj = self._get_marker_obj(context, 'security_group_rule', limit, marker) return self._get_collection(context, SecurityGroupRule, self._make_security_group_rule_dict, filters=filters, fields=fields, sorts=sorts, limit=limit, marker_obj=marker_obj, page_reverse=page_reverse) def get_security_group_rules_count(self, context, filters=None): return self._get_collection_count(context, SecurityGroupRule, filters=filters) def get_security_group_rule(self, context, id, fields=None): security_group_rule = self._get_security_group_rule(context, id) return self._make_security_group_rule_dict(security_group_rule, fields) def _get_security_group_rule(self, context, id): try: query = self._model_query(context, SecurityGroupRule) sgr = query.filter(SecurityGroupRule.id == id).one() except exc.NoResultFound: raise ext_sg.SecurityGroupRuleNotFound(id=id) return sgr def delete_security_group_rule(self, context, id): with context.session.begin(subtransactions=True): rule = self._get_security_group_rule(context, id) context.session.delete(rule) def _extend_port_dict_security_group(self, port_res, port_db): # If port_db is provided, security groups will be accessed via # sqlalchemy models. As they're loaded together with ports this # will not cause an extra query. security_group_ids = [sec_group_mapping['security_group_id'] for sec_group_mapping in port_db.security_groups] port_res[ext_sg.SECURITYGROUPS] = security_group_ids return port_res # Register dict extend functions for ports db_base_plugin_v2.NeutronDbPluginV2.register_dict_extend_funcs( attr.PORTS, [_extend_port_dict_security_group]) def _process_port_create_security_group(self, context, port, security_group_ids): if attr.is_attr_set(security_group_ids): for security_group_id in security_group_ids: self._create_port_security_group_binding(context, port['id'], security_group_id) # Convert to list as a set might be passed here and # this has to be serialized port[ext_sg.SECURITYGROUPS] = (security_group_ids and list(security_group_ids) or []) def _ensure_default_security_group(self, context, tenant_id): """Create a default security group if one doesn't exist. :returns: the default security group id. """ filters = {'name': ['default'], 'tenant_id': [tenant_id]} default_group = self.get_security_groups(context, filters, default_sg=True) if not default_group: security_group = {'security_group': {'name': 'default', 'tenant_id': tenant_id, 'description': 'default'}} ret = self.create_security_group(context, security_group, True) return ret['id'] else: return default_group[0]['id'] def _get_security_groups_on_port(self, context, port): """Check that all security groups on port belong to tenant. :returns: all security groups IDs on port belonging to tenant. """ p = port['port'] if not attr.is_attr_set(p.get(ext_sg.SECURITYGROUPS)): return if p.get('device_owner') and p['device_owner'].startswith('network:'): return valid_groups = self.get_security_groups(context, fields=['id']) valid_group_map = dict((g['id'], g['id']) for g in valid_groups) try: return set([valid_group_map[sg_id] for sg_id in p.get(ext_sg.SECURITYGROUPS, [])]) except KeyError as e: raise ext_sg.SecurityGroupNotFound(id=str(e)) def _ensure_default_security_group_on_port(self, context, port): # we don't apply security groups for dhcp, router if (port['port'].get('device_owner') and port['port']['device_owner'].startswith('network:')): return tenant_id = self._get_tenant_id_for_create(context, port['port']) default_sg = self._ensure_default_security_group(context, tenant_id) if attr.is_attr_set(port['port'].get(ext_sg.SECURITYGROUPS)): sgids = port['port'].get(ext_sg.SECURITYGROUPS) else: sgids = [default_sg] port['port'][ext_sg.SECURITYGROUPS] = sgids def _check_update_deletes_security_groups(self, port): """Return True if port has as a security group and it's value is either [] or not is_attr_set, otherwise return False """ if (ext_sg.SECURITYGROUPS in port['port'] and not (attr.is_attr_set(port['port'][ext_sg.SECURITYGROUPS]) and port['port'][ext_sg.SECURITYGROUPS] != [])): return True return False def _check_update_has_security_groups(self, port): """Return True if port has as a security group and False if the security_group field is is_attr_set or []. """ if (ext_sg.SECURITYGROUPS in port['port'] and (attr.is_attr_set(port['port'][ext_sg.SECURITYGROUPS]) and port['port'][ext_sg.SECURITYGROUPS] != [])): return True return False