diff --git a/anchor/validators.py b/anchor/validators.py index 679fc72..ea61950 100644 --- a/anchor/validators.py +++ b/anchor/validators.py @@ -14,7 +14,6 @@ from __future__ import absolute_import import logging -import socket import netaddr @@ -67,83 +66,53 @@ def iter_alternative_names(csr, types, fail_other_types=True): "'%s'" % (parts[1], parts[0])) -def check_networks(domain, allowed_networks): - """Check the domain resolves to an IP that is within an allowed network +def check_networks(ip, allowed_networks): + """Check the IP is within an allowed network.""" + if not isinstance(ip, netaddr.IPAddress): + raise TypeError("ip must be a netaddr ip address") - Resolve all of the IP addresses for 'domain' and ensure that - at least one of the IP addresses is listed in allowed_networks for the - deployment. - """ if not allowed_networks: # no valid networks were provided, so we can't make any assertions logger.warning("No valid network IP ranges were given, skipping") return True - try: - networks = socket.gethostbyname_ex(domain) - except socket.gaierror: - # the domain is not a valid ip address - return False - - for possible_network in networks[2]: - ip = netaddr.IPAddress(possible_network) - if any(ip in netaddr.IPNetwork(net) for net in allowed_networks): - return True + if any(ip in netaddr.IPNetwork(net) for net in allowed_networks): + return True return False -def check_networks_strict(domain, allowed_networks): - """Check the domain resolves to an IP that is within an allowed network - - Resolve all of the IP addresses for 'domain' and ensure that - at each of the IP addresses is listed in allowed_networks for the - deployment. This is the stricter form of check_networks. - """ - try: - networks = socket.gethostbyname_ex(domain)[2] - except socket.gaierror: - # the domain is not a valid ip address - return False - - for possible_network in networks: - ip = netaddr.IPAddress(possible_network) - if not any(ip in netaddr.IPNetwork(net) for net in allowed_networks): - return False - - return True - - def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs): """Check the CN entry is a known domain. Refuse requests for certificates if they contain multiple CN - entries, or the domain does not match the list of known suffixes - or network ranges. + entries, or the domain does not match the list of known suffixes. """ alt_present = any(ext.get_name() == "subjectAltName" for ext in csr.get_extensions()) CNs = csr.get_subject().get_entries_by_nid(x509_name.NID_commonName) - if alt_present: - if len(CNs) > 1: - raise ValidationError("Too many CNs in the request") - else: + if len(CNs) > 1: + raise ValidationError("Too many CNs in the request") + if not alt_present: # rfc5280#section-4.2.1.6 says so - if len(csr.get_subject()) == 0: + if len(CNs) == 0: raise ValidationError("Alt subjects have to exist if the main" " subject doesn't") if len(CNs) > 0: cn = csr_get_cn(csr) - if not (check_domains(cn, allowed_domains)): - raise ValidationError("Domain '%s' not allowed (does not match" - " known domains)" % cn) - - if not (check_networks(cn, allowed_networks)): - raise ValidationError("Network '%s' not allowed (does not match" - " known networks)" % cn) + try: + # is it an IP rather than domain? + ip = netaddr.IPAddress(cn) + if not (check_networks(ip, allowed_networks)): + raise ValidationError("Address '%s' not allowed (does not " + "match known networks)" % cn) + except netaddr.AddrFormatError: + if not (check_domains(cn, allowed_domains)): + raise ValidationError("Domain '%s' not allowed (does not " + "match known domains)" % cn) def alternative_names(csr, allowed_domains=[], **kwargs): @@ -156,7 +125,7 @@ def alternative_names(csr, allowed_domains=[], **kwargs): for name_type, name in iter_alternative_names(csr, ['DNS']): if not check_domains(name, allowed_domains): raise ValidationError("Domain '%s' not allowed (doesn't" - " match known domains or networks)" + " match known domains)" % name) @@ -169,11 +138,14 @@ def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[], """ for name_type, name in iter_alternative_names(csr, ['DNS', 'IP Address']): - if not (check_domains(name, allowed_domains) or - check_networks(name, allowed_networks)): + if name_type == 'DNS' and not check_domains(name, allowed_domains): raise ValidationError("Domain '%s' not allowed (doesn't" - " match known domains or networks)" - % name) + " match known domains)" % name) + if name_type == 'IP Address': + ip = netaddr.IPAddress(name) + if not check_networks(ip, allowed_networks): + raise ValidationError("Address '%s' not allowed (doesn't" + " match known networks)" % name) def blacklist_names(csr, domains=[], **kwargs): diff --git a/tests/validators/test_base_validation_functions.py b/tests/validators/test_base_validation_functions.py index ac69fb1..433d233 100644 --- a/tests/validators/test_base_validation_functions.py +++ b/tests/validators/test_base_validation_functions.py @@ -14,11 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. -import socket import textwrap import unittest -import mock +import netaddr from anchor import validators from anchor.X509 import signing_request @@ -100,58 +99,17 @@ class TestBaseValidators(unittest.TestCase): self.assertTrue(validators.check_domains(test_domain, test_allowed)) self.assertFalse(validators.check_domains('gmail.com', test_allowed)) - @mock.patch('socket.gethostbyname_ex') - def test_check_networks_bad_domain(self, gethostbyname_ex): - gethostbyname_ex.side_effect = socket.gaierror() - bad_domain = 'bad!$domain' - allowed_networks = ['127/8', '10/8'] - self.assertFalse(validators.check_networks( - bad_domain, allowed_networks)) + def test_check_networks(self): + good_ip = netaddr.IPAddress('10.2.3.4') + bad_ip = netaddr.IPAddress('88.2.3.4') + test_allowed = ['10/8'] + self.assertTrue(validators.check_networks(good_ip, test_allowed)) + self.assertFalse(validators.check_networks(bad_ip, test_allowed)) - @mock.patch('socket.gethostbyname_ex') - def test_check_networks_both(self, gethostbyname_ex): - allowed_networks = ['15/8', '74.125/16'] - gethostbyname_ex.return_value = ( - 'example.com', - [], - [ - '74.125.224.64', - '74.125.224.67', - '74.125.224.68', - '74.125.224.70', - ] - ) - self.assertTrue(validators.check_networks( - 'example.com', allowed_networks)) - self.assertTrue(validators.check_networks_strict( - 'example.com', allowed_networks)) + def test_check_networks_invalid(self): + with self.assertRaises(TypeError): + validators.check_networks('1.2.3.4', ['10/8']) - gethostbyname_ex.return_value = ('example.com', [], ['12.2.2.2']) - self.assertFalse(validators.check_networks( - 'example.com', allowed_networks)) - - gethostbyname_ex.return_value = ( - 'example.com', - [], - [ - '15.8.2.2', - '15.8.2.1', - '16.1.1.1', - ] - ) - self.assertFalse(validators.check_networks_strict( - 'example.com', allowed_networks)) - - @mock.patch('socket.gethostbyname_ex') - def test_check_networks_exception(self, gethostbyname_ex): - gethostbyname_ex.side_effect = socket.gaierror() - self.assertFalse( - validators.check_networks('mock', ['mock']), - ) - - @mock.patch('socket.gethostbyname_ex') - def test_check_networks_strict_exception(self, gethostbyname_ex): - gethostbyname_ex.side_effect = socket.gaierror() - self.assertFalse( - validators.check_networks_strict('mock', ['mock']), - ) + def test_check_networks_passthrough(self): + good_ip = netaddr.IPAddress('10.2.3.4') + self.assertTrue(validators.check_networks(good_ip, [])) diff --git a/tests/validators/test_callable_validators.py b/tests/validators/test_callable_validators.py index 019e5f8..fdddcce 100644 --- a/tests/validators/test_callable_validators.py +++ b/tests/validators/test_callable_validators.py @@ -17,6 +17,7 @@ import unittest import mock +import netaddr from anchor import validators from anchor.X509 import name as x509_name @@ -29,46 +30,17 @@ class TestValidators(unittest.TestCase): def tearDown(self): super(TestValidators, self).tearDown() - @mock.patch('socket.gethostbyname_ex') - def test_check_networks_good(self, gethostbyname_ex): + def test_check_networks_good(self): allowed_networks = ['15/8', '74.125/16'] - gethostbyname_ex.return_value = ( - 'example.com', - [], - [ - '74.125.224.64', - '74.125.224.67', - '74.125.224.68', - '74.125.224.70', - ] - ) self.assertTrue(validators.check_networks( - 'example.com', allowed_networks)) - self.assertTrue(validators.check_networks_strict( - 'example.com', allowed_networks)) + netaddr.IPAddress('74.125.224.64'), allowed_networks)) - @mock.patch('socket.gethostbyname_ex') - def test_check_networks_bad(self, gethostbyname_ex): + def test_check_networks_bad(self): allowed_networks = ['15/8', '74.125/16'] - gethostbyname_ex.return_value = ('example.com', [], ['12.2.2.2']) self.assertFalse(validators.check_networks( - 'example.com', allowed_networks)) + netaddr.IPAddress('12.2.2.2'), allowed_networks)) - gethostbyname_ex.return_value = ( - 'example.com', - ['mock.mock'], - [ - '15.8.2.2', - '15.8.2.1', - '16.1.1.1', - ] - ) - self.assertFalse(validators.check_networks_strict( - 'example.com', allowed_networks)) - - @mock.patch('socket.gethostbyname_ex') - def test_check_domains_empty(self, gethostbyname_ex): - gethostbyname_ex.return_value = ('example.com', [], ['12.2.2.2']) + def test_check_domains_empty(self): self.assertTrue(validators.check_domains( 'example.com', [])) @@ -106,10 +78,7 @@ class TestValidators(unittest.TestCase): self.assertEqual("Alt subjects have to exist if the main subject" " doesn't", str(e.exception)) - @mock.patch('socket.gethostbyname_ex') - def test_common_name_good_CN(self, gethostbyname_ex): - gethostbyname_ex.return_value = ('master.test.com', [], ['10.0.0.1']) - + def test_common_name_good_CN(self): cn_mock = mock.MagicMock() cn_mock.get_value.return_value = 'master.test.com' @@ -125,14 +94,10 @@ class TestValidators(unittest.TestCase): validators.common_name( csr=csr_mock, allowed_domains=['.test.com'], - allowed_networks=['10/8'] ) ) - @mock.patch('socket.gethostbyname_ex') - def test_common_name_bad_CN(self, gethostbyname_ex): - gethostbyname_ex.return_value = ('master.test.com', [], ['10.0.0.1']) - + def test_common_name_bad_CN(self): name = x509_name.X509Name() name.add_name_entry(x509_name.NID_commonName, 'test.baddomain.com') @@ -142,34 +107,29 @@ class TestValidators(unittest.TestCase): with self.assertRaises(validators.ValidationError) as e: validators.common_name( csr=csr_mock, - allowed_domains=['.test.com'], - allowed_networks=['10/8']) + allowed_domains=['.test.com']) self.assertEqual("Domain 'test.baddomain.com' not allowed (does not " "match known domains)", str(e.exception)) - def test_common_name_good_ip_CN(self): - cn_mock = mock.MagicMock() - cn_mock.get_value.return_value = '10.0.0.1' + def test_common_name_ip_good(self): + name = x509_name.X509Name() + name.add_name_entry(x509_name.NID_commonName, '10.1.1.1') - csr_config = { - 'get_subject.return_value.__len__.return_value': 1, - 'get_subject.return_value.get_entries_by_nid.return_value': - [cn_mock], - } - csr_mock = mock.MagicMock(**csr_config) + csr_mock = mock.MagicMock() + csr_mock.get_subject.return_value = name self.assertEqual( None, validators.common_name( csr=csr_mock, - allowed_domains=[], + allowed_domains=['.test.com'], allowed_networks=['10/8'] ) ) - def test_common_name_bad_ip_CN(self): + def test_common_name_ip_bad(self): name = x509_name.X509Name() - name.add_name_entry(x509_name.NID_commonName, '12.0.0.1') + name.add_name_entry(x509_name.NID_commonName, '15.1.1.1') csr_mock = mock.MagicMock() csr_mock.get_subject.return_value = name @@ -177,15 +137,12 @@ class TestValidators(unittest.TestCase): with self.assertRaises(validators.ValidationError) as e: validators.common_name( csr=csr_mock, - allowed_domains=[], + allowed_domains=['.test.com'], allowed_networks=['10/8']) - self.assertEqual("Network '12.0.0.1' not allowed (does not match " - "known networks)", str(e.exception)) - - @mock.patch('socket.gethostbyname_ex') - def test_alternative_names_good_domain(self, gethostbyname_ex): - gethostbyname_ex.return_value = ('master.test.com', [], ['10.0.0.1']) + self.assertEqual("Address '15.1.1.1' not allowed (does not " + "match known networks)", str(e.exception)) + def test_alternative_names_good_domain(self): ext_mock = mock.MagicMock() ext_mock.get_value.return_value = 'DNS:master.test.com' ext_mock.get_name.return_value = 'subjectAltName' @@ -200,10 +157,7 @@ class TestValidators(unittest.TestCase): ) ) - @mock.patch('socket.gethostbyname_ex') - def test_alternative_names_bad_domain(self, gethostbyname_ex): - gethostbyname_ex.return_value = ('master.test.com', [], ['10.0.0.1']) - + def test_alternative_names_bad_domain(self): ext_mock = mock.MagicMock() ext_mock.get_value.return_value = 'DNS:test.baddomain.com' ext_mock.get_name.return_value = 'subjectAltName' @@ -216,7 +170,7 @@ class TestValidators(unittest.TestCase): csr=csr_mock, allowed_domains=['.test.com']) self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't " - "match known domains or networks)", str(e.exception)) + "match known domains)", str(e.exception)) def test_alternative_names_ext(self): ext_mock = mock.MagicMock() @@ -233,10 +187,7 @@ class TestValidators(unittest.TestCase): self.assertEqual("Alt name should have 2 parts, but found: 'BAD'", str(e.exception)) - @mock.patch('socket.gethostbyname_ex') - def test_alternative_names_ip_good(self, gethostbyname_ex): - gethostbyname_ex.return_value = ('master.test.com', [], ['10.0.0.1']) - + def test_alternative_names_ip_good(self): ext_mock = mock.MagicMock() ext_mock.get_value.return_value = 'IP Address:10.1.1.1' ext_mock.get_name.return_value = 'subjectAltName' @@ -253,9 +204,7 @@ class TestValidators(unittest.TestCase): ) ) - @mock.patch('socket.gethostbyname_ex') - def test_alternative_names_ip_bad(self, gethostbyname_ex): - gethostbyname_ex.return_value = ('master.test.com', [], ['10.0.0.1']) + def test_alternative_names_ip_bad(self): ext_mock = mock.MagicMock() ext_mock.get_value.return_value = 'IP Address:10.1.1.1' @@ -269,8 +218,23 @@ class TestValidators(unittest.TestCase): csr=csr_mock, allowed_domains=['.test.com'], allowed_networks=['99/8']) - self.assertEqual("Domain '10.1.1.1' not allowed (doesn't match known " - "domains or networks)", str(e.exception)) + self.assertEqual("Address '10.1.1.1' not allowed (doesn't match known " + "networks)", str(e.exception)) + + def test_alternative_names_ip_bad_domain(self): + ext_mock = mock.MagicMock() + ext_mock.get_value.return_value = 'DNS:test.baddomain.com' + ext_mock.get_name.return_value = 'subjectAltName' + + csr_mock = mock.MagicMock() + csr_mock.get_extensions.return_value = [ext_mock] + + with self.assertRaises(validators.ValidationError) as e: + validators.alternative_names_ip( + csr=csr_mock, + allowed_domains=['.test.com']) + self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't " + "match known domains)", str(e.exception)) def test_alternative_names_ip_ext(self): ext_mock = mock.MagicMock()