Merge "Fix 500 error on invalid security-group-rule creation for NVP"

This commit is contained in:
Jenkins 2013-09-01 04:46:29 +00:00 committed by Gerrit Code Review
commit 6c11b959d9
5 changed files with 55 additions and 8 deletions

View File

@ -38,6 +38,7 @@ IPv6 = 'IPv6'
ICMP_PROTOCOL = 1 ICMP_PROTOCOL = 1
TCP_PROTOCOL = 6 TCP_PROTOCOL = 6
UDP_PROTOCOL = 17 UDP_PROTOCOL = 17
ICMPv6_PROTOCOL = 58
DHCP_RESPONSE_PORT = 68 DHCP_RESPONSE_PORT = 68

View File

@ -32,7 +32,8 @@ from neutron.openstack.common import uuidutils
IP_PROTOCOL_MAP = {'tcp': constants.TCP_PROTOCOL, IP_PROTOCOL_MAP = {'tcp': constants.TCP_PROTOCOL,
'udp': constants.UDP_PROTOCOL, 'udp': constants.UDP_PROTOCOL,
'icmp': constants.ICMP_PROTOCOL} 'icmp': constants.ICMP_PROTOCOL,
'icmpv6': constants.ICMPv6_PROTOCOL}
class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant): class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):

View File

@ -2251,6 +2251,21 @@ class NvpPluginV2(db_base_plugin_v2.NeutronDbPluginV2,
return super(NvpPluginV2, self).delete_security_group( return super(NvpPluginV2, self).delete_security_group(
context, security_group_id) context, security_group_id)
def _validate_security_group_rules(self, context, rules):
for rule in rules['security_group_rules']:
r = rule.get('security_group_rule')
port_based_proto = (self._get_ip_proto_number(r['protocol'])
in securitygroups_db.IP_PROTOCOL_MAP.values())
if (not port_based_proto and
(r['port_range_min'] is not None or
r['port_range_max'] is not None)):
msg = (_("Port values not valid for "
"protocol: %s") % r['protocol'])
raise q_exc.BadRequest(resource='security_group_rule',
msg=msg)
return super(NvpPluginV2, self)._validate_security_group_rules(context,
rules)
def create_security_group_rule(self, context, security_group_rule): def create_security_group_rule(self, context, security_group_rule):
"""Create a single security group rule.""" """Create a single security group rule."""
bulk_rule = {'security_group_rules': [security_group_rule]} bulk_rule = {'security_group_rules': [security_group_rule]}

View File

@ -370,6 +370,21 @@ class TestNiciraSecurityGroup(ext_sg.TestSecurityGroups,
# Assert Neutron name is not truncated # Assert Neutron name is not truncated
self.assertEqual(sg['security_group']['name'], name) self.assertEqual(sg['security_group']['name'], name)
def test_create_security_group_rule_bad_input(self):
name = 'foo security group'
description = 'foo description'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
protocol = 200
min_range = 32
max_range = 4343
rule = self._build_security_group_rule(
security_group_id, 'ingress', protocol,
min_range, max_range)
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400)
class TestNiciraL3ExtensionManager(object): class TestNiciraL3ExtensionManager(object):

View File

@ -71,20 +71,24 @@ class SecurityGroupsTestCase(test_db_plugin.NeutronDbPluginV2TestCase):
context.Context('', kwargs['tenant_id'])) context.Context('', kwargs['tenant_id']))
return security_group_req.get_response(self.ext_api) return security_group_req.get_response(self.ext_api)
def _build_security_group_rule(self, security_group_id, direction, def _build_security_group_rule(self, security_group_id, direction, proto,
protocol, port_range_min, port_range_max, port_range_min=None, port_range_max=None,
remote_ip_prefix=None, remote_group_id=None, remote_ip_prefix=None, remote_group_id=None,
tenant_id='test_tenant', tenant_id='test_tenant',
ethertype='IPv4'): ethertype='IPv4'):
data = {'security_group_rule': {'security_group_id': security_group_id, data = {'security_group_rule': {'security_group_id': security_group_id,
'direction': direction, 'direction': direction,
'protocol': protocol, 'protocol': proto,
'ethertype': ethertype, 'ethertype': ethertype,
'port_range_min': port_range_min,
'port_range_max': port_range_max,
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'ethertype': ethertype}} 'ethertype': ethertype}}
if port_range_min:
data['security_group_rule']['port_range_min'] = port_range_min
if port_range_max:
data['security_group_rule']['port_range_max'] = port_range_max
if remote_ip_prefix: if remote_ip_prefix:
data['security_group_rule']['remote_ip_prefix'] = remote_ip_prefix data['security_group_rule']['remote_ip_prefix'] = remote_ip_prefix
@ -408,6 +412,18 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
self.deserialize(self.fmt, res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_tcp_protocol_as_number(self):
name = 'webservers'
description = 'my webservers'
with self.security_group(name, description) as sg:
security_group_id = sg['security_group']['id']
protocol = 6 # TCP
rule = self._build_security_group_rule(
security_group_id, 'ingress', protocol, '22', '22')
res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 201)
def test_create_security_group_rule_protocol_as_number(self): def test_create_security_group_rule_protocol_as_number(self):
name = 'webservers' name = 'webservers'
description = 'my webservers' description = 'my webservers'
@ -415,8 +431,7 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
security_group_id = sg['security_group']['id'] security_group_id = sg['security_group']['id']
protocol = 2 protocol = 2
rule = self._build_security_group_rule( rule = self._build_security_group_rule(
security_group_id, 'ingress', protocol, '22', '22', security_group_id, 'ingress', protocol)
None, None)
res = self._create_security_group_rule(self.fmt, rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize(self.fmt, res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 201) self.assertEqual(res.status_int, 201)