Merge "Limit min<=max port check to TCP/UDP in secgroup rule"
This commit is contained in:
commit
44f4a9810f
@ -34,7 +34,10 @@ INTERFACE_KEY = '_interfaces'
|
||||
IPv4 = 'IPv4'
|
||||
IPv6 = 'IPv6'
|
||||
|
||||
ICMP_PROTOCOL = 1
|
||||
TCP_PROTOCOL = 6
|
||||
UDP_PROTOCOL = 17
|
||||
|
||||
DHCP_RESPONSE_PORT = 68
|
||||
|
||||
MIN_VLAN_TAG = 1
|
||||
|
@ -22,6 +22,7 @@ from sqlalchemy.orm import exc
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
from neutron.api.v2 import attributes as attr
|
||||
from neutron.common import constants
|
||||
from neutron.db import db_base_plugin_v2
|
||||
from neutron.db import model_base
|
||||
from neutron.db import models_v2
|
||||
@ -29,6 +30,11 @@ from neutron.extensions import securitygroup as ext_sg
|
||||
from neutron.openstack.common import uuidutils
|
||||
|
||||
|
||||
IP_PROTOCOL_MAP = {'tcp': constants.TCP_PROTOCOL,
|
||||
'udp': constants.UDP_PROTOCOL,
|
||||
'icmp': constants.ICMP_PROTOCOL}
|
||||
|
||||
|
||||
class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):
|
||||
"""Represents a v2 neutron security group."""
|
||||
|
||||
@ -284,6 +290,32 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
||||
return self.create_security_group_rule_bulk_native(context,
|
||||
bulk_rule)[0]
|
||||
|
||||
def _get_ip_proto_number(self, protocol):
|
||||
if protocol is None:
|
||||
return
|
||||
return IP_PROTOCOL_MAP.get(protocol, protocol)
|
||||
|
||||
def _validate_port_range(self, rule):
|
||||
"""Check that port_range is valid."""
|
||||
if (rule['port_range_min'] is None and
|
||||
rule['port_range_max'] is None):
|
||||
return
|
||||
if not rule['protocol']:
|
||||
raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
|
||||
ip_proto = self._get_ip_proto_number(rule['protocol'])
|
||||
if ip_proto in [constants.TCP_PROTOCOL, constants.UDP_PROTOCOL]:
|
||||
if (rule['port_range_min'] is not None and
|
||||
rule['port_range_min'] <= rule['port_range_max']):
|
||||
pass
|
||||
else:
|
||||
raise ext_sg.SecurityGroupInvalidPortRange()
|
||||
elif ip_proto == constants.ICMP_PROTOCOL:
|
||||
for attr, field in [('port_range_min', 'type'),
|
||||
('port_range_max', 'code')]:
|
||||
if rule[attr] > 255:
|
||||
raise ext_sg.SecurityGroupInvalidIcmpValue(
|
||||
field=field, attr=attr, value=rule[attr])
|
||||
|
||||
def _validate_security_group_rules(self, context, security_group_rule):
|
||||
"""Check that rules being installed.
|
||||
|
||||
@ -297,16 +329,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
|
||||
rule = rules.get('security_group_rule')
|
||||
new_rules.add(rule['security_group_id'])
|
||||
|
||||
# Check that port_range's are valid
|
||||
if (rule['port_range_min'] is None and
|
||||
rule['port_range_max'] is None):
|
||||
pass
|
||||
elif (rule['port_range_min'] is not None and
|
||||
rule['port_range_min'] <= rule['port_range_max']):
|
||||
if not rule['protocol']:
|
||||
raise ext_sg.SecurityGroupProtocolRequiredWithPorts()
|
||||
else:
|
||||
raise ext_sg.SecurityGroupInvalidPortRange()
|
||||
self._validate_port_range(rule)
|
||||
|
||||
if rule['remote_ip_prefix'] and rule['remote_group_id']:
|
||||
raise ext_sg.SecurityGroupRemoteGroupAndRemoteIpPrefix()
|
||||
|
@ -39,6 +39,11 @@ class SecurityGroupInvalidPortValue(qexception.InvalidInput):
|
||||
message = _("Invalid value for port %(port)s")
|
||||
|
||||
|
||||
class SecurityGroupInvalidIcmpValue(qexception.InvalidInput):
|
||||
message = _("Invalid value for ICMP %(field)s (%(attr)s) "
|
||||
"%(value)s. It must be 0 to 255.")
|
||||
|
||||
|
||||
class SecurityGroupInUse(qexception.InUse):
|
||||
message = _("Security Group %(id)s in use.")
|
||||
|
||||
|
@ -633,6 +633,57 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
|
||||
for k, v, in keys:
|
||||
self.assertEqual(rule['security_group_rule'][k], v)
|
||||
|
||||
def test_create_security_group_rule_icmp_with_type_and_code(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
direction = "ingress"
|
||||
remote_ip_prefix = "10.0.0.0/24"
|
||||
protocol = 'icmp'
|
||||
# port_range_min (ICMP type) is greater than port_range_max
|
||||
# (ICMP code) in order to confirm min <= max port check is
|
||||
# not called for ICMP.
|
||||
port_range_min = 8
|
||||
port_range_max = 5
|
||||
keys = [('remote_ip_prefix', remote_ip_prefix),
|
||||
('security_group_id', security_group_id),
|
||||
('direction', direction),
|
||||
('protocol', protocol),
|
||||
('port_range_min', port_range_min),
|
||||
('port_range_max', port_range_max)]
|
||||
with self.security_group_rule(security_group_id, direction,
|
||||
protocol, port_range_min,
|
||||
port_range_max,
|
||||
remote_ip_prefix) as rule:
|
||||
for k, v, in keys:
|
||||
self.assertEqual(rule['security_group_rule'][k], v)
|
||||
|
||||
def test_create_security_group_rule_icmp_with_type_only(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
direction = "ingress"
|
||||
remote_ip_prefix = "10.0.0.0/24"
|
||||
protocol = 'icmp'
|
||||
# ICMP type
|
||||
port_range_min = 8
|
||||
# ICMP code
|
||||
port_range_max = None
|
||||
keys = [('remote_ip_prefix', remote_ip_prefix),
|
||||
('security_group_id', security_group_id),
|
||||
('direction', direction),
|
||||
('protocol', protocol),
|
||||
('port_range_min', port_range_min),
|
||||
('port_range_max', port_range_max)]
|
||||
with self.security_group_rule(security_group_id, direction,
|
||||
protocol, port_range_min,
|
||||
port_range_max,
|
||||
remote_ip_prefix) as rule:
|
||||
for k, v, in keys:
|
||||
self.assertEqual(rule['security_group_rule'][k], v)
|
||||
|
||||
def test_create_security_group_source_group_ip_and_ip_prefix(self):
|
||||
security_group_id = "4cd70774-cc67-4a87-9b39-7d1db38eb087"
|
||||
direction = "ingress"
|
||||
@ -757,8 +808,10 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
for protocol in ['tcp', 'udp', 6, 17]:
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress', 'tcp', '50', '22')
|
||||
sg['security_group']['id'],
|
||||
'ingress', protocol, '50', '22')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
@ -777,6 +830,58 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(res.status_int, 400)
|
||||
|
||||
def test_create_security_group_rule_port_range_min_only(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress', 'tcp', '22', None)
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(res.status_int, 400)
|
||||
|
||||
def test_create_security_group_rule_port_range_max_only(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress', 'tcp', None, '22')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(res.status_int, 400)
|
||||
|
||||
def test_create_security_group_rule_icmp_type_too_big(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress', 'icmp', '256', None)
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(res.status_int, 400)
|
||||
|
||||
def test_create_security_group_rule_icmp_code_too_big(self):
|
||||
name = 'webservers'
|
||||
description = 'my webservers'
|
||||
with self.security_group(name, description) as sg:
|
||||
security_group_id = sg['security_group']['id']
|
||||
with self.security_group_rule(security_group_id):
|
||||
rule = self._build_security_group_rule(
|
||||
sg['security_group']['id'], 'ingress', 'icmp', '8', '256')
|
||||
self._create_security_group_rule(self.fmt, rule)
|
||||
res = self._create_security_group_rule(self.fmt, rule)
|
||||
self.deserialize(self.fmt, res)
|
||||
self.assertEqual(res.status_int, 400)
|
||||
|
||||
def test_list_ports_security_group(self):
|
||||
with self.network() as n:
|
||||
with self.subnet(n):
|
||||
|
Loading…
Reference in New Issue
Block a user