From c5dd15ef8bf864c0e05acd0a7426a6bdad7dc930 Mon Sep 17 00:00:00 2001 From: Aaron Rosen Date: Thu, 24 Jan 2013 15:45:04 -0800 Subject: [PATCH] Make protocol and ethertype case insensitive for security groups Fixes bug 1104495 Change-Id: I0d93f5e849ebe0be72fff8c1d82f5825540df338 --- quantum/extensions/securitygroup.py | 25 ++++++++-- .../unit/test_extension_security_group.py | 49 +++++++++++++++++++ 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/quantum/extensions/securitygroup.py b/quantum/extensions/securitygroup.py index 488890abe5..c48e7f890e 100644 --- a/quantum/extensions/securitygroup.py +++ b/quantum/extensions/securitygroup.py @@ -54,9 +54,9 @@ class SecurityGroupDefaultAlreadyExists(qexception.InUse): message = _("Default security group already exists.") -class SecurityGroupRuleInvalidProtocol(qexception.InUse): - message = _("Security group rule protocol %(protocol)s not supported " - "only protocol values %(values)s supported.") +class SecurityGroupRuleInvalidProtocol(qexception.InvalidInput): + message = _("Security group rule protocol %(protocol)s not supported. " + "Only protocol values %(values)s supported.") class SecurityGroupRulesNotSingleTenant(qexception.InvalidInput): @@ -114,6 +114,23 @@ class SecurityGroupInvalidExternalID(qexception.InvalidInput): message = _("external_id wrong type %(data)s") +def convert_protocol_to_case_insensitive(value): + if value is None: + return value + try: + return value.lower() + except AttributeError: + raise SecurityGroupRuleInvalidProtocol( + protocol=value, values=sg_supported_protocols) + + +def convert_ethertype_to_case_insensitive(value): + if isinstance(value, basestring): + for ethertype in sg_supported_ethertypes: + if ethertype.lower() == value.lower(): + return ethertype + + def convert_validate_port_value(port): if port is None: return port @@ -199,6 +216,7 @@ RESOURCE_ATTRIBUTE_MAP = { 'validate': {'type:values': ['ingress', 'egress']}}, 'protocol': {'allow_post': True, 'allow_put': False, 'is_visible': True, 'default': None, + 'convert_to': convert_protocol_to_case_insensitive, 'validate': {'type:values': sg_supported_protocols}}, 'port_range_min': {'allow_post': True, 'allow_put': False, 'convert_to': convert_validate_port_value, @@ -208,6 +226,7 @@ RESOURCE_ATTRIBUTE_MAP = { 'default': None, 'is_visible': True}, 'ethertype': {'allow_post': True, 'allow_put': False, 'is_visible': True, 'default': 'IPv4', + 'convert_to': convert_ethertype_to_case_insensitive, 'validate': {'type:values': sg_supported_ethertypes}}, 'source_ip_prefix': {'allow_post': True, 'allow_put': False, 'default': None, 'is_visible': True}, diff --git a/quantum/tests/unit/test_extension_security_group.py b/quantum/tests/unit/test_extension_security_group.py index 879bee8c2a..f24c01b456 100644 --- a/quantum/tests/unit/test_extension_security_group.py +++ b/quantum/tests/unit/test_extension_security_group.py @@ -287,6 +287,55 @@ class TestSecurityGroups(SecurityGroupDBTestCase): else: self.assertEquals(len(group['security_group_rules']), 0) + def test_create_security_group_rule_ethertype_invalid_as_number(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + ethertype = 2 + rule = self._build_security_group_rule( + security_group_id, 'ingress', 'tcp', '22', '22', None, None, + ethertype=ethertype) + res = self._create_security_group_rule('json', rule) + self.deserialize('json', res) + self.assertEqual(res.status_int, 400) + + def test_create_security_group_rule_protocol_invalid_as_number(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + protocol = 2 + rule = self._build_security_group_rule( + security_group_id, 'ingress', protocol, '22', '22', + None, None) + res = self._create_security_group_rule('json', rule) + self.deserialize('json', res) + self.assertEqual(res.status_int, 400) + + def test_create_security_group_rule_case_insensitive(self): + name = 'webservers' + description = 'my webservers' + with self.security_group(name, description) as sg: + security_group_id = sg['security_group']['id'] + direction = "ingress" + source_ip_prefix = "10.0.0.0/24" + protocol = 'TCP' + port_range_min = 22 + port_range_max = 22 + ethertype = 'ipV4' + with self.security_group_rule(security_group_id, direction, + protocol, port_range_min, + port_range_max, + source_ip_prefix, + ethertype=ethertype) as rule: + + # the lower case value will be return + self.assertEquals(rule['security_group_rule']['protocol'], + protocol.lower()) + self.assertEquals(rule['security_group_rule']['ethertype'], + 'IPv4') + def test_get_security_group(self): name = 'webservers' description = 'my webservers'