diff --git a/neutron/common/constants.py b/neutron/common/constants.py index a6a5e99e42..a644022d4d 100644 --- a/neutron/common/constants.py +++ b/neutron/common/constants.py @@ -34,7 +34,10 @@ INTERFACE_KEY = '_interfaces' IPv4 = 'IPv4' IPv6 = 'IPv6' +ICMP_PROTOCOL = 1 +TCP_PROTOCOL = 6 UDP_PROTOCOL = 17 + DHCP_RESPONSE_PORT = 68 MIN_VLAN_TAG = 1 diff --git a/neutron/db/securitygroups_db.py b/neutron/db/securitygroups_db.py index 4201d42a38..198c231a8d 100644 --- a/neutron/db/securitygroups_db.py +++ b/neutron/db/securitygroups_db.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import exc from sqlalchemy.orm import scoped_session from neutron.api.v2 import attributes as attr +from neutron.common import constants from neutron.db import db_base_plugin_v2 from neutron.db import model_base from neutron.db import models_v2 @@ -29,6 +30,11 @@ from neutron.extensions import securitygroup as ext_sg from neutron.openstack.common import uuidutils +IP_PROTOCOL_MAP = {'tcp': constants.TCP_PROTOCOL, + 'udp': constants.UDP_PROTOCOL, + 'icmp': constants.ICMP_PROTOCOL} + + class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant): """Represents a v2 neutron security group.""" @@ -284,6 +290,32 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): return self.create_security_group_rule_bulk_native(context, bulk_rule)[0] + def _get_ip_proto_number(self, protocol): + if protocol is None: + return + return IP_PROTOCOL_MAP.get(protocol, protocol) + + def _validate_port_range(self, rule): + """Check that port_range is valid.""" + if (rule['port_range_min'] is None and + rule['port_range_max'] is None): + return + if not rule['protocol']: + raise ext_sg.SecurityGroupProtocolRequiredWithPorts() + ip_proto = self._get_ip_proto_number(rule['protocol']) + if ip_proto in [constants.TCP_PROTOCOL, constants.UDP_PROTOCOL]: + if (rule['port_range_min'] is not None and + rule['port_range_min'] <= rule['port_range_max']): + pass + else: + raise ext_sg.SecurityGroupInvalidPortRange() + elif ip_proto == constants.ICMP_PROTOCOL: + for attr, field in [('port_range_min', 'type'), + ('port_range_max', 'code')]: + if rule[attr] > 255: + raise ext_sg.SecurityGroupInvalidIcmpValue( + field=field, attr=attr, value=rule[attr]) + def _validate_security_group_rules(self, context, security_group_rule): """Check that rules being installed. @@ -297,16 +329,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase): rule = rules.get('security_group_rule') new_rules.add(rule['security_group_id']) - # Check that port_range's are valid - if (rule['port_range_min'] is None and - rule['port_range_max'] is None): - pass - elif (rule['port_range_min'] is not None and - rule['port_range_min'] <= rule['port_range_max']): - if not rule['protocol']: - raise ext_sg.SecurityGroupProtocolRequiredWithPorts() - else: - raise ext_sg.SecurityGroupInvalidPortRange() + self._validate_port_range(rule) if rule['remote_ip_prefix'] and rule['remote_group_id']: raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix() diff --git a/neutron/extensions/securitygroup.py b/neutron/extensions/securitygroup.py index 9fd4c95788..ebc1f780bc 100644 --- a/neutron/extensions/securitygroup.py +++ b/neutron/extensions/securitygroup.py @@ -39,6 +39,11 @@ class SecurityGroupInvalidPortValue(qexception.InvalidInput): message = _("Invalid value for port %(port)s") +class SecurityGroupInvalidIcmpValue(qexception.InvalidInput): + message = _("Invalid value for ICMP %(field)s (%(attr)s) " + "%(value)s. It must be 0 to 255.") + + class SecurityGroupInUse(qexception.InUse): message = _("Security Group %(id)s in use.") diff --git a/neutron/tests/unit/test_extension_security_group.py b/neutron/tests/unit/test_extension_security_group.py index a1df601cba..a0d3979638 100644 --- a/neutron/tests/unit/test_extension_security_group.py +++ b/neutron/tests/unit/test_extension_security_group.py @@ -633,6 +633,57 @@ class TestSecurityGroups(SecurityGroupDBTestCase): for k, v, in keys: self.assertEqual(rule['security_group_rule'][k], v) + def test_create_security_group_rule_icmp_with_type_and_code(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + direction = "ingress" + remote_ip_prefix = "10.0.0.0/24" + protocol = 'icmp' + # port_range_min (ICMP type) is greater than port_range_max + # (ICMP code) in order to confirm min <= max port check is + # not called for ICMP. + port_range_min = 8 + port_range_max = 5 + keys = [('remote_ip_prefix', remote_ip_prefix), + ('security_group_id', security_group_id), + ('direction', direction), + ('protocol', protocol), + ('port_range_min', port_range_min), + ('port_range_max', port_range_max)] + with self.security_group_rule(security_group_id, direction, + protocol, port_range_min, + port_range_max, + remote_ip_prefix) as rule: + for k, v, in keys: + self.assertEqual(rule['security_group_rule'][k], v) + + def test_create_security_group_rule_icmp_with_type_only(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + direction = "ingress" + remote_ip_prefix = "10.0.0.0/24" + protocol = 'icmp' + # ICMP type + port_range_min = 8 + # ICMP code + port_range_max = None + keys = [('remote_ip_prefix', remote_ip_prefix), + ('security_group_id', security_group_id), + ('direction', direction), + ('protocol', protocol), + ('port_range_min', port_range_min), + ('port_range_max', port_range_max)] + with self.security_group_rule(security_group_id, direction, + protocol, port_range_min, + port_range_max, + remote_ip_prefix) as rule: + for k, v, in keys: + self.assertEqual(rule['security_group_rule'][k], v) + def test_create_security_group_source_group_ip_and_ip_prefix(self): security_group_id = "4cd70774-cc67-4a87-9b39-7d1db38eb087" direction = "ingress" @@ -757,12 +808,14 @@ class TestSecurityGroups(SecurityGroupDBTestCase): with self.security_group(name, description) as sg: security_group_id = sg['security_group']['id'] with self.security_group_rule(security_group_id): - rule = self._build_security_group_rule( - sg['security_group']['id'], 'ingress', 'tcp', '50', '22') - self._create_security_group_rule(self.fmt, rule) - res = self._create_security_group_rule(self.fmt, rule) - self.deserialize(self.fmt, res) - self.assertEqual(res.status_int, 400) + for protocol in ['tcp', 'udp', 6, 17]: + rule = self._build_security_group_rule( + sg['security_group']['id'], + 'ingress', protocol, '50', '22') + self._create_security_group_rule(self.fmt, rule) + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(res.status_int, 400) def test_create_security_group_rule_ports_but_no_protocol(self): name = 'webservers' @@ -777,6 +830,58 @@ class TestSecurityGroups(SecurityGroupDBTestCase): self.deserialize(self.fmt, res) self.assertEqual(res.status_int, 400) + def test_create_security_group_rule_port_range_min_only(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + with self.security_group_rule(security_group_id): + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', 'tcp', '22', None) + self._create_security_group_rule(self.fmt, rule) + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(res.status_int, 400) + + def test_create_security_group_rule_port_range_max_only(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + with self.security_group_rule(security_group_id): + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', 'tcp', None, '22') + self._create_security_group_rule(self.fmt, rule) + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(res.status_int, 400) + + def test_create_security_group_rule_icmp_type_too_big(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + with self.security_group_rule(security_group_id): + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', 'icmp', '256', None) + self._create_security_group_rule(self.fmt, rule) + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(res.status_int, 400) + + def test_create_security_group_rule_icmp_code_too_big(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + with self.security_group_rule(security_group_id): + rule = self._build_security_group_rule( + sg['security_group']['id'], 'ingress', 'icmp', '8', '256') + self._create_security_group_rule(self.fmt, rule) + res = self._create_security_group_rule(self.fmt, rule) + self.deserialize(self.fmt, res) + self.assertEqual(res.status_int, 400) + def test_list_ports_security_group(self): with self.network() as n: with self.subnet(n):