diff --git a/anchor/certificate_ops.py b/anchor/certificate_ops.py index 83baff0..ec65ed3 100644 --- a/anchor/certificate_ops.py +++ b/anchor/certificate_ops.py @@ -22,7 +22,7 @@ import pecan from webob import exc as http_status from anchor import jsonloader -from anchor import validators +from anchor.validators import errors from anchor.X509 import certificate from anchor.X509 import extension from anchor.X509 import signing_request @@ -85,7 +85,7 @@ def _run_validator(name, body, args): validator(**new_kwargs) logger.debug("_run_validator: success: <%s> ", name) return True # validator passed b/c no exceptions - except validators.ValidationError as e: + except errors.ValidationError as e: logger.error("_run_validator: FAILED: <%s> - %s", name, e) return False diff --git a/anchor/util.py b/anchor/util.py index eba8a51..14f2318 100644 --- a/anchor/util.py +++ b/anchor/util.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import hmac +import re def constant_time_compare(val1, val2): @@ -44,3 +45,32 @@ def _constant_time_compare(val1, val2): for x, y in zip(val1, val2): result |= ord(x) ^ ord(y) return result == 0 + + +# RFC1034 allows a simple " " too, but it's not allowed in certificates, so it +# will not match +RE_DOMAIN_LABEL = re.compile("^[a-z](?:[-a-z0-9]*[a-z0-9])?$", re.IGNORECASE) + + +def verify_domain(domain, allow_wildcards=False): + labels = domain.split('.') + if labels[-1] == "": + # single trailing . is ok, ignore + labels.pop(-1) + + for i, label in enumerate(labels): + if len(label) > 63: + raise ValueError( + "domain <%s> it too long (RFC5280/4.2.1.6)" % (domain,)) + + # check for wildcard labels, ignore partial-wildcard labels + if '*' == label and allow_wildcards: + if i != 0: + raise ValueError( + "domain <%s> has wildcard that's not in the " + "left-most label (RFC6125/6.4.3)" % (domain,)) + else: + if RE_DOMAIN_LABEL.match(label) is None: + raise ValueError( + "domain <%s> contains invalid characters " + "(RFC1034/3.5)" % (domain,)) diff --git a/anchor/validators/__init__.py b/anchor/validators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/anchor/validators.py b/anchor/validators/custom.py similarity index 56% rename from anchor/validators.py rename to anchor/validators/custom.py index 6a78142..16eb83c 100644 --- a/anchor/validators.py +++ b/anchor/validators/custom.py @@ -18,6 +18,8 @@ import logging import netaddr from pyasn1.type import univ as pyasn1_univ +from anchor.validators import errors as v_errors +from anchor.validators import utils from anchor.X509 import errors from anchor.X509 import extension from anchor.X509 import name as x509_name @@ -26,60 +28,6 @@ from anchor.X509 import name as x509_name logger = logging.getLogger(__name__) -class ValidationError(Exception): - pass - - -def csr_require_cn(csr): - cns = csr.get_subject_cn() - if not cns: - raise ValidationError("CSR is lacking a CN in the Subject") - if len(cns) > 1: - raise ValidationError("CSR has too many CN entries") - return cns[0] - - -def check_domains(domain, allowed_domains): - if allowed_domains: - if not any(domain.endswith(suffix) for suffix in allowed_domains): - # no domain matched - return False - else: - # no valid domains were provided, so we can't make any assertions - logger.warning("No domains were configured for validation. Anchor " - "will issue certificates for any domain, this is not a " - "recommended configuration for production environments") - return True - - -def iter_alternative_names(csr, types, fail_other_types=True): - for ext in csr.get_extensions(): - if isinstance(ext, extension.X509ExtensionSubjectAltName): - # TODO(stan): fail on other types - if 'DNS' in types: - for dns_id in ext.get_dns_ids(): - yield ('DNS', dns_id) - if 'IP Address' in types: - for ip in ext.get_ips(): - yield ('IP Address', ip) - - -def check_networks(ip, allowed_networks): - """Check the IP is within an allowed network.""" - if not isinstance(ip, netaddr.IPAddress): - raise TypeError("ip must be a netaddr ip address") - - if not allowed_networks: - # no valid networks were provided, so we can't make any assertions - logger.warning("No valid network IP ranges were given, skipping") - return True - - if any(ip in netaddr.IPNetwork(net) for net in allowed_networks): - return True - - return False - - def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs): """Check the CN entry is a known domain. @@ -92,25 +40,27 @@ def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs): CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName) if len(CNs) > 1: - raise ValidationError("Too many CNs in the request") + raise v_errors.ValidationError("Too many CNs in the request") # rfc5280#section-4.2.1.6 says so if len(CNs) == 0 and not alt_present: - raise ValidationError("Alt subjects have to exist if the main" - " subject doesn't") + raise v_errors.ValidationError("Alt subjects have to exist if the main" + " subject doesn't") if len(CNs) > 0: - cn = csr_require_cn(csr) + cn = utils.csr_require_cn(csr) try: # is it an IP rather than domain? ip = netaddr.IPAddress(cn) - if not (check_networks(ip, allowed_networks)): - raise ValidationError("Address '%s' not allowed (does not " - "match known networks)" % cn) + if not (utils.check_networks(ip, allowed_networks)): + raise v_errors.ValidationError( + "Address '%s' not allowed (does not match known networks)" + % cn) except netaddr.AddrFormatError: - if not (check_domains(cn, allowed_domains)): - raise ValidationError("Domain '%s' not allowed (does not " - "match known domains)" % cn) + if not (utils.check_domains(cn, allowed_domains)): + raise v_errors.ValidationError( + "Domain '%s' not allowed (does not match known domains)" + % cn) def alternative_names(csr, allowed_domains=[], **kwargs): @@ -120,11 +70,10 @@ def alternative_names(csr, allowed_domains=[], **kwargs): the list of known suffixes, or network ranges. """ - for _, name in iter_alternative_names(csr, ['DNS']): - if not check_domains(name, allowed_domains): - raise ValidationError("Domain '%s' not allowed (doesn't" - " match known domains)" - % name) + for _, name in utils.iter_alternative_names(csr, ['DNS']): + if not utils.check_domains(name, allowed_domains): + raise v_errors.ValidationError("Domain '%s' not allowed (doesn't" + " match known domains)" % name) def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[], @@ -135,14 +84,16 @@ def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[], the list of known suffixes, or network ranges. """ - for name_type, name in iter_alternative_names(csr, ['DNS', 'IP Address']): - if name_type == 'DNS' and not check_domains(name, allowed_domains): - raise ValidationError("Domain '%s' not allowed (doesn't" - " match known domains)" % name) + for name_type, name in utils.iter_alternative_names(csr, + ['DNS', 'IP Address']): + if name_type == 'DNS' and not utils.check_domains(name, + allowed_domains): + raise v_errors.ValidationError("Domain '%s' not allowed (doesn't" + " match known domains)" % name) if name_type == 'IP Address': - if not check_networks(name, allowed_networks): - raise ValidationError("IP '%s' not allowed (doesn't" - " match known networks)" % name) + if not utils.check_networks(name, allowed_networks): + raise v_errors.ValidationError("IP '%s' not allowed (doesn't" + " match known networks)" % name) def blacklist_names(csr, domains=[], **kwargs): @@ -155,16 +106,16 @@ def blacklist_names(csr, domains=[], **kwargs): CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName) if len(CNs) > 0: - cn = csr_require_cn(csr) - if check_domains(cn, domains): - raise ValidationError("Domain '%s' not allowed " - "(CN blacklisted)" % cn) + cn = utils.csr_require_cn(csr) + if utils.check_domains(cn, domains): + raise v_errors.ValidationError("Domain '%s' not allowed " + "(CN blacklisted)" % cn) - for _, name in iter_alternative_names(csr, ['DNS'], - fail_other_types=False): - if check_domains(name, domains): - raise ValidationError("Domain '%s' not allowed " - "(alt blacklisted)" % name) + for _, name in utils.iter_alternative_names(csr, ['DNS'], + fail_other_types=False): + if utils.check_domains(name, domains): + raise v_errors.ValidationError("Domain '%s' not allowed " + "(alt blacklisted)" % name) def server_group(auth_result=None, csr=None, group_prefixes={}, **kwargs): @@ -174,14 +125,15 @@ def server_group(auth_result=None, csr=None, group_prefixes={}, **kwargs): verified against the groups the user is a member of. """ - cn = csr_require_cn(csr) + cn = utils.csr_require_cn(csr) parts = cn.split('-') if len(parts) == 1 or '.' in parts[0]: return # no prefix if parts[0] in group_prefixes: if group_prefixes[parts[0]] not in auth_result.groups: - raise ValidationError("Server prefix doesn't match user groups") + raise v_errors.ValidationError( + "Server prefix doesn't match user groups") def extensions(csr=None, allowed_extensions=[], **kwargs): @@ -190,8 +142,8 @@ def extensions(csr=None, allowed_extensions=[], **kwargs): for ext in exts: if (ext.get_name() not in allowed_extensions and str(ext.get_oid()) not in allowed_extensions): - raise ValidationError("Extension '%s' not allowed" - % ext.get_name()) + raise v_errors.ValidationError("Extension '%s' not allowed" + % ext.get_name()) def key_usage(csr=None, allowed_usage=None, **kwargs): @@ -205,8 +157,8 @@ def key_usage(csr=None, allowed_usage=None, **kwargs): usages = set(ext.get_all_usages()) denied = denied | (usages - allowed) if denied: - raise ValidationError("Found some prohibited key usages: %s" - % ', '.join(denied)) + raise v_errors.ValidationError("Found some prohibited key usages: %s" + % ', '.join(denied)) def ext_key_usage(csr=None, allowed_usage=None, **kwargs): @@ -223,7 +175,7 @@ def ext_key_usage(csr=None, allowed_usage=None, **kwargs): oid = pyasn1_univ.ObjectIdentifier(usage) allowed_usage[i] = oid except Exception: - raise ValidationError("Unknown usage: %s" % (usage,)) + raise v_errors.ValidationError("Unknown usage: %s" % (usage,)) allowed = set(allowed_usage) denied = set() @@ -234,8 +186,8 @@ def ext_key_usage(csr=None, allowed_usage=None, **kwargs): if denied: text_denied = [extension.EXT_KEY_USAGE_SHORT_NAMES.get(x) for x in denied] - raise ValidationError("Found some prohibited key usages: %s" - % ', '.join(text_denied)) + raise v_errors.ValidationError("Found some prohibited key usages: %s" + % ', '.join(text_denied)) def ca_status(csr=None, ca_requested=False, **kwargs): @@ -245,7 +197,7 @@ def ca_status(csr=None, ca_requested=False, **kwargs): if isinstance(ext, extension.X509ExtensionBasicConstraints): if ext.get_ca(): if not ca_requested: - raise ValidationError( + raise v_errors.ValidationError( "CA status requested, but not allowed") request_ca_flags = True elif isinstance(ext, extension.X509ExtensionKeyUsage): @@ -253,13 +205,13 @@ def ca_status(csr=None, ca_requested=False, **kwargs): has_crl_sign = ext.get_usage('cRLSign') if has_crl_sign or has_cert_sign: if not ca_requested: - raise ValidationError( + raise v_errors.ValidationError( "Key usage doesn't match requested CA status " "(keyCertSign/cRLSign: %s/%s)" % (has_cert_sign, has_crl_sign)) request_ca_flags = True if ca_requested and not request_ca_flags: - raise ValidationError("CA flags required") + raise v_errors.ValidationError("CA flags required") def source_cidrs(request=None, cidrs=None, **kwargs): @@ -270,16 +222,17 @@ def source_cidrs(request=None, cidrs=None, **kwargs): if request.client_addr in r: return except netaddr.AddrFormatError: - raise ValidationError("Cidr '%s' does not describe a valid" - " network" % cidr) - raise ValidationError("No network matched the request source '%s'" % - request.client_addr) + raise v_errors.ValidationError( + "Cidr '%s' does not describe a valid network" % cidr) + raise v_errors.ValidationError( + "No network matched the request source '%s'" % + request.client_addr) def csr_signature(csr=None, **kwargs): """Ensure that the CSR has a valid self-signature.""" try: if not csr.verify(): - raise ValidationError("Signature on the CSR is not valid") + raise v_errors.ValidationError("Signature on the CSR is not valid") except errors.X509Error: - raise ValidationError("Signature on the CSR is not valid") + raise v_errors.ValidationError("Signature on the CSR is not valid") diff --git a/anchor/validators/errors.py b/anchor/validators/errors.py new file mode 100644 index 0000000..7918d24 --- /dev/null +++ b/anchor/validators/errors.py @@ -0,0 +1,16 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +class ValidationError(Exception): + pass diff --git a/anchor/validators_standards.py b/anchor/validators/standards.py similarity index 64% rename from anchor/validators_standards.py rename to anchor/validators/standards.py index 3c51326..726b767 100644 --- a/anchor/validators_standards.py +++ b/anchor/validators/standards.py @@ -23,9 +23,8 @@ All the rules are pulled into a single validator: ``standards_compliance``. from __future__ import absolute_import -import re - -from anchor import validators +from anchor import util +from anchor.validators import errors from anchor.X509 import extension @@ -46,7 +45,7 @@ def _no_extension_duplicates(csr): for ext in csr.get_extensions(): oid = ext.get_oid() if oid in seen_oids: - raise validators.ValidationError( + raise errors.ValidationError( "Duplicate extension with oid %s (RFC5280/4.2)" % oid) seen_oids.add(oid) @@ -56,52 +55,28 @@ def _critical_flags(csr): for ext in csr.get_extensions(): if isinstance(ext, extension.X509ExtensionSubjectAltName): if len(csr.get_subject()) == 0 and not ext.get_critical(): - raise validators.ValidationError( + raise errors.ValidationError( "SAN must be critical if subject is empty " "(RFC5280/4.1.2.6)") if isinstance(ext, extension.X509ExtensionBasicConstraints): if not ext.get_critical(): - raise validators.ValidationError( + raise errors.ValidationError( "Basic constraints has to be marked critical " "(RFC5280/4.1.2.9)") -# RFC1034 allows a simple " " too, but it's not allowed in certificates, so it -# will not match -RE_DOMAIN_LABEL = re.compile("^[a-z](?:[-a-z0-9]*[a-z0-9])?$", re.IGNORECASE) - - def _valid_domains(csr): """Format of the domin names See RFC5280 section 4.2.1.6 / RFC6125 / RFC1034 """ - def verify_domain(domain): - labels = domain.split('.') - if labels[-1] == "": - # single trailing . is ok, ignore - labels.pop(-1) - for i, label in enumerate(labels): - if len(label) > 63: - raise validators.ValidationError( - "SAN entry <%s> it too long (RFC5280/4.2.1.6)" % (domain,)) - - # check for wildcard labels, ignore partial-wildcard labels - if '*' == label: - if i != 0: - raise validators.ValidationError( - "SAN entry <%s> has wildcard that's not in the " - "left-most label (RFC6125/6.4.3)" % (domain,)) - else: - if RE_DOMAIN_LABEL.match(label) is None: - raise validators.ValidationError( - "SAN entry <%s> contains invalid characters " - "(RFC1034/3.5)" % (domain,)) - sans = csr.get_extensions(extension.X509ExtensionSubjectAltName) if not sans: return ext = sans[0] for domain in ext.get_dns_ids(): - verify_domain(domain) + try: + util.verify_domain(domain, allow_wildcards=True) + except ValueError as e: + raise errors.ValidationError(str(e)) diff --git a/anchor/validators/utils.py b/anchor/validators/utils.py new file mode 100644 index 0000000..188b434 --- /dev/null +++ b/anchor/validators/utils.py @@ -0,0 +1,74 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import + +import logging + +import netaddr + +from anchor.validators import errors +from anchor.X509 import extension + + +logger = logging.getLogger(__name__) + + +def csr_require_cn(csr): + cns = csr.get_subject_cn() + if not cns: + raise errors.ValidationError("CSR is lacking a CN in the Subject") + if len(cns) > 1: + raise errors.ValidationError("CSR has too many CN entries") + return cns[0] + + +def check_domains(domain, allowed_domains): + if allowed_domains: + if not any(domain.endswith(suffix) for suffix in allowed_domains): + # no domain matched + return False + else: + # no valid domains were provided, so we can't make any assertions + logger.warning("No domains were configured for validation. Anchor " + "will issue certificates for any domain, this is not a " + "recommended configuration for production environments") + return True + + +def iter_alternative_names(csr, types, fail_other_types=True): + for ext in csr.get_extensions(): + if isinstance(ext, extension.X509ExtensionSubjectAltName): + # TODO(stan): fail on other types + if 'DNS' in types: + for dns_id in ext.get_dns_ids(): + yield ('DNS', dns_id) + if 'IP Address' in types: + for ip in ext.get_ips(): + yield ('IP Address', ip) + + +def check_networks(ip, allowed_networks): + """Check the IP is within an allowed network.""" + if not isinstance(ip, netaddr.IPAddress): + raise TypeError("ip must be a netaddr ip address") + + if not allowed_networks: + # no valid networks were provided, so we can't make any assertions + logger.warning("No valid network IP ranges were given, skipping") + return True + + if any(ip in netaddr.IPNetwork(net) for net in allowed_networks): + return True + + return False diff --git a/setup.cfg b/setup.cfg index 3d05c2a..fe12837 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,20 +31,19 @@ anchor.signing_backends = anchor = anchor.certificate_ops:sign anchor.validators = - check_domains = anchor.validators:check_domains - iter_alternative_names = anchor.validators:iter_alternative_names - check_networks = anchor.validators:check_networks - common_name = anchor.validators:common_name - alternative_names = anchor.validators:alternative_names - alternative_names_ip = anchor.validators:alternative_names_ip - blacklist_names = anchor.validators:blacklist_names - server_group = anchor.validators:server_group - extensions = anchor.validators:extensions - key_usage = anchor.validators:key_usage - ext_key_usage = anchor.validators:ext_key_usage - ca_status = anchor.validators:ca_status - source_cidrs = anchor.validators:source_cidrs - standards_compliance = anchor.validators_standards:standards_compliance + check_domains = anchor.validators.custom:check_domains + iter_alternative_names = anchor.validators.custom:iter_alternative_names + check_networks = anchor.validators.custom:check_networks + common_name = anchor.validators.custom:common_name + alternative_names = anchor.validators.custom:alternative_names + alternative_names_ip = anchor.validators.custom:alternative_names_ip + blacklist_names = anchor.validators.custom:blacklist_names + server_group = anchor.validators.custom:server_group + extensions = anchor.validators.custom:extensions + key_usage = anchor.validators.custom:key_usage + ca_status = anchor.validators.custom:ca_status + source_cidrs = anchor.validators.custom:source_cidrs + standards_compliance = anchor.validators.standards:standards_compliance anchor.authentication = keystone = anchor.auth.keystone:login diff --git a/tests/validators/test_base_validation_functions.py b/tests/validators/test_base_validation_functions.py index 3a98464..880dc6f 100644 --- a/tests/validators/test_base_validation_functions.py +++ b/tests/validators/test_base_validation_functions.py @@ -19,7 +19,8 @@ import unittest import netaddr -from anchor import validators +from anchor.validators import errors +from anchor.validators import utils from anchor.X509 import signing_request @@ -86,32 +87,32 @@ class TestBaseValidators(unittest.TestCase): super(TestBaseValidators, self).tearDown() def test_csr_require_cn(self): - name = validators.csr_require_cn(self.csr) + name = utils.csr_require_cn(self.csr) self.assertEqual(name, "ossg.test.com") self.csr = signing_request.X509Csr.from_buffer( TestBaseValidators.csr_data_without_cn) - with self.assertRaises(validators.ValidationError): - validators.csr_require_cn(self.csr) + with self.assertRaises(errors.ValidationError): + utils.csr_require_cn(self.csr) def test_check_domains(self): test_domain = 'good.example.com' test_allowed = ['.example.com', '.example.net'] - self.assertTrue(validators.check_domains(test_domain, test_allowed)) - self.assertFalse(validators.check_domains('bad.example.org', - test_allowed)) + self.assertTrue(utils.check_domains(test_domain, test_allowed)) + self.assertFalse(utils.check_domains('bad.example.org', + test_allowed)) def test_check_networks(self): good_ip = netaddr.IPAddress('10.2.3.4') bad_ip = netaddr.IPAddress('88.2.3.4') test_allowed = ['10/8'] - self.assertTrue(validators.check_networks(good_ip, test_allowed)) - self.assertFalse(validators.check_networks(bad_ip, test_allowed)) + self.assertTrue(utils.check_networks(good_ip, test_allowed)) + self.assertFalse(utils.check_networks(bad_ip, test_allowed)) def test_check_networks_invalid(self): with self.assertRaises(TypeError): - validators.check_networks('1.2.3.4', ['10/8']) + utils.check_networks('1.2.3.4', ['10/8']) def test_check_networks_passthrough(self): good_ip = netaddr.IPAddress('10.2.3.4') - self.assertTrue(validators.check_networks(good_ip, [])) + self.assertTrue(utils.check_networks(good_ip, [])) diff --git a/tests/validators/test_callable_validators.py b/tests/validators/test_callable_validators.py index 8c2839b..bf7df95 100644 --- a/tests/validators/test_callable_validators.py +++ b/tests/validators/test_callable_validators.py @@ -21,7 +21,9 @@ import mock import netaddr from pyasn1_modules import rfc2459 -from anchor import validators +from anchor.validators import custom +from anchor.validators import errors +from anchor.validators import utils from anchor.X509 import extension as x509_ext from anchor.X509 import name as x509_name from anchor.X509 import signing_request as x509_csr @@ -50,16 +52,16 @@ class TestValidators(unittest.TestCase): def test_check_networks_good(self): allowed_networks = ['15/8', '74.125/16'] - self.assertTrue(validators.check_networks( + self.assertTrue(utils.check_networks( netaddr.IPAddress('74.125.224.64'), allowed_networks)) def test_check_networks_bad(self): allowed_networks = ['15/8', '74.125/16'] - self.assertFalse(validators.check_networks( + self.assertFalse(utils.check_networks( netaddr.IPAddress('12.2.2.2'), allowed_networks)) def test_check_domains_empty(self): - self.assertTrue(validators.check_domains( + self.assertTrue(utils.check_domains( 'example.com', [])) def test_common_name_with_two_CN(self): @@ -68,8 +70,8 @@ class TestValidators(unittest.TestCase): name.add_name_entry(x509_name.OID_commonName, "dummy_value") name.add_name_entry(x509_name.OID_commonName, "dummy_value") - with self.assertRaises(validators.ValidationError) as e: - validators.common_name( + with self.assertRaises(errors.ValidationError) as e: + custom.common_name( csr=csr, allowed_domains=[], allowed_networks=[]) @@ -78,8 +80,8 @@ class TestValidators(unittest.TestCase): def test_common_name_no_CN(self): csr = x509_csr.X509Csr() - with self.assertRaises(validators.ValidationError) as e: - validators.common_name( + with self.assertRaises(errors.ValidationError) as e: + custom.common_name( csr=csr, allowed_domains=[], allowed_networks=[]) @@ -93,7 +95,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.common_name( + custom.common_name( csr=csr, allowed_domains=['.example.com'], ) @@ -104,8 +106,8 @@ class TestValidators(unittest.TestCase): name = csr.get_subject() name.add_name_entry(x509_name.OID_commonName, 'bad.example.org') - with self.assertRaises(validators.ValidationError) as e: - validators.common_name( + with self.assertRaises(errors.ValidationError) as e: + custom.common_name( csr=csr, allowed_domains=['.example.com']) self.assertEqual("Domain 'bad.example.org' not allowed (does not " @@ -118,7 +120,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.common_name( + custom.common_name( csr=csr, allowed_domains=['.example.com'], allowed_networks=['10/8'] @@ -130,8 +132,8 @@ class TestValidators(unittest.TestCase): name = csr.get_subject() name.add_name_entry(x509_name.OID_commonName, '15.1.1.1') - with self.assertRaises(validators.ValidationError) as e: - validators.common_name( + with self.assertRaises(errors.ValidationError) as e: + custom.common_name( csr=csr, allowed_domains=['.example.com'], allowed_networks=['10/8']) @@ -146,7 +148,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.alternative_names( + custom.alternative_names( csr=csr, allowed_domains=['.example.com'], ) @@ -158,8 +160,8 @@ class TestValidators(unittest.TestCase): ext.add_dns_id('bad.example.org') csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.alternative_names( + with self.assertRaises(errors.ValidationError) as e: + custom.alternative_names( csr=csr, allowed_domains=['.example.com']) self.assertEqual("Domain 'bad.example.org' not allowed (doesn't " @@ -173,7 +175,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.alternative_names_ip( + custom.alternative_names_ip( csr=csr, allowed_domains=['.example.com'], allowed_networks=['10/8'] @@ -186,8 +188,8 @@ class TestValidators(unittest.TestCase): ext.add_ip(netaddr.IPAddress('10.1.1.1')) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.alternative_names_ip( + with self.assertRaises(errors.ValidationError) as e: + custom.alternative_names_ip( csr=csr, allowed_domains=['.example.com'], allowed_networks=['99/8']) @@ -200,8 +202,8 @@ class TestValidators(unittest.TestCase): ext.add_dns_id('bad.example.org') csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.alternative_names_ip( + with self.assertRaises(errors.ValidationError) as e: + custom.alternative_names_ip( csr=csr, allowed_domains=['.example.com']) self.assertEqual("Domain 'bad.example.org' not allowed (doesn't " @@ -214,7 +216,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.server_group( + custom.server_group( auth_result=None, csr=csr, group_prefixes={} @@ -228,7 +230,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.server_group( + custom.server_group( auth_result=None, csr=csr, group_prefixes={} @@ -246,7 +248,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.server_group( + custom.server_group( auth_result=auth_result, csr=csr, group_prefixes={'nv': 'nova', 'sw': 'swift'} @@ -261,8 +263,8 @@ class TestValidators(unittest.TestCase): name = csr.get_subject() name.add_name_entry(x509_name.OID_commonName, "nv-master.example.com") - with self.assertRaises(validators.ValidationError) as e: - validators.server_group( + with self.assertRaises(errors.ValidationError) as e: + custom.server_group( auth_result=auth_result, csr=csr, group_prefixes={'nv': 'nova', 'sw': 'swift'}) @@ -275,8 +277,8 @@ class TestValidators(unittest.TestCase): ext.set_usage('keyCertSign', True) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.extensions( + with self.assertRaises(errors.ValidationError) as e: + custom.extensions( csr=csr, allowed_extensions=['basicConstraints', 'nameConstraints']) self.assertEqual("Extension 'keyUsage' not allowed", str(e.exception)) @@ -289,7 +291,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.extensions( + custom.extensions( csr=csr, allowed_extensions=['basicConstraints', 'keyUsage'] ) @@ -303,7 +305,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.extensions( + custom.extensions( csr=csr, allowed_extensions=['basicConstraints', '2.5.29.15'] ) @@ -319,8 +321,8 @@ class TestValidators(unittest.TestCase): ext.set_usage('keyCertSign', True) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.key_usage( + with self.assertRaises(errors.ValidationError) as e: + custom.key_usage( csr=csr, allowed_usage=allowed_usage) self.assertEqual("Found some prohibited key usages: " @@ -339,7 +341,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.key_usage( + custom.key_usage( csr=csr, allowed_usage=allowed_usage ) @@ -355,7 +357,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ext_key_usage( + custom.ext_key_usage( csr=csr, allowed_usage=allowed_usage ) @@ -371,7 +373,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ext_key_usage( + custom.ext_key_usage( csr=csr, allowed_usage=allowed_usage ) @@ -387,7 +389,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ext_key_usage( + custom.ext_key_usage( csr=csr, allowed_usage=allowed_usage ) @@ -401,8 +403,8 @@ class TestValidators(unittest.TestCase): ext.set_usage(rfc2459.id_kp_clientAuth, True) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.ext_key_usage( + with self.assertRaises(errors.ValidationError) as e: + custom.ext_key_usage( csr=csr, allowed_usage=allowed_usage) self.assertEqual("Found some prohibited key usages: " @@ -416,7 +418,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ca_status( + custom.ca_status( csr=csr, ca_requested=True ) @@ -430,7 +432,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ca_status( + custom.ca_status( csr=csr, ca_requested=False ) @@ -442,8 +444,8 @@ class TestValidators(unittest.TestCase): ext.set_ca(True) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.ca_status( + with self.assertRaises(errors.ValidationError) as e: + custom.ca_status( csr=csr, ca_requested=False) self.assertEqual("CA status requested, but not allowed", @@ -455,8 +457,8 @@ class TestValidators(unittest.TestCase): ext.set_ca(False) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.ca_status( + with self.assertRaises(errors.ValidationError) as e: + custom.ca_status( csr=csr, ca_requested=True) self.assertEqual("CA flags required", @@ -470,7 +472,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ca_status( + custom.ca_status( csr=csr, ca_requested=False ) @@ -482,8 +484,8 @@ class TestValidators(unittest.TestCase): ext.set_usage('keyCertSign', True) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.ca_status( + with self.assertRaises(errors.ValidationError) as e: + custom.ca_status( csr=csr, ca_requested=False) self.assertEqual("Key usage doesn't match requested CA status " @@ -497,7 +499,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ca_status( + custom.ca_status( csr=csr, ca_requested=True ) @@ -509,8 +511,8 @@ class TestValidators(unittest.TestCase): ext.set_usage('cRLSign', True) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError) as e: - validators.ca_status( + with self.assertRaises(errors.ValidationError) as e: + custom.ca_status( csr=csr, ca_requested=False) self.assertEqual("Key usage doesn't match requested CA status " @@ -524,7 +526,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.ca_status( + custom.ca_status( csr=csr, ca_requested=True ) @@ -534,7 +536,7 @@ class TestValidators(unittest.TestCase): request = mock.Mock(client_addr='127.0.0.1') self.assertEqual( None, - validators.source_cidrs( + custom.source_cidrs( request=request, cidrs=['127/8', '10/8'] ) @@ -542,8 +544,8 @@ class TestValidators(unittest.TestCase): def test_source_cidrs_out_of_range(self): request = mock.Mock(client_addr='99.0.0.1') - with self.assertRaises(validators.ValidationError) as e: - validators.source_cidrs( + with self.assertRaises(errors.ValidationError) as e: + custom.source_cidrs( request=request, cidrs=['127/8', '10/8']) self.assertEqual("No network matched the request source '99.0.0.1'", @@ -551,8 +553,8 @@ class TestValidators(unittest.TestCase): def test_source_cidrs_bad_cidr(self): request = mock.Mock(client_addr='127.0.0.1') - with self.assertRaises(validators.ValidationError) as e: - validators.source_cidrs( + with self.assertRaises(errors.ValidationError) as e: + custom.source_cidrs( request=request, cidrs=['bad']) self.assertEqual("Cidr 'bad' does not describe a valid network", @@ -566,7 +568,7 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.blacklist_names( + custom.blacklist_names( csr=csr, domains=['.example.org'], ) @@ -578,8 +580,8 @@ class TestValidators(unittest.TestCase): ext.add_dns_id('bad.example.com') csr.add_extension(ext) - with self.assertRaises(validators.ValidationError): - validators.blacklist_names( + with self.assertRaises(errors.ValidationError): + custom.blacklist_names( csr=csr, domains=['.example.com'], ) @@ -589,8 +591,8 @@ class TestValidators(unittest.TestCase): name = csr.get_subject() name.add_name_entry(x509_name.OID_commonName, "bad.example.com") - with self.assertRaises(validators.ValidationError): - validators.blacklist_names( + with self.assertRaises(errors.ValidationError): + custom.blacklist_names( csr=csr, domains=['.example.com'], ) @@ -602,8 +604,8 @@ class TestValidators(unittest.TestCase): ext.add_dns_id('good.example.com') csr.add_extension(ext) - with self.assertRaises(validators.ValidationError): - validators.blacklist_names( + with self.assertRaises(errors.ValidationError): + custom.blacklist_names( csr=csr, domains=['.example.org'], ) @@ -617,27 +619,27 @@ class TestValidators(unittest.TestCase): self.assertEqual( None, - validators.blacklist_names( + custom.blacklist_names( csr=csr, ) ) def test_csr_signature(self): csr = x509_csr.X509Csr.from_buffer(self.csr_data) - self.assertEqual(None, validators.csr_signature(csr=csr)) + self.assertEqual(None, custom.csr_signature(csr=csr)) def test_csr_signature_bad_sig(self): csr = x509_csr.X509Csr.from_buffer(self.csr_data) with mock.patch.object(x509_csr.X509Csr, '_get_signature', return_value=(b'A'*49)): - with self.assertRaisesRegexp(validators.ValidationError, + with self.assertRaisesRegexp(errors.ValidationError, "Signature on the CSR is not valid"): - validators.csr_signature(csr=csr) + custom.csr_signature(csr=csr) def test_csr_signature_bad_algo(self): csr = x509_csr.X509Csr.from_buffer(self.csr_data) with mock.patch.object(x509_csr.X509Csr, '_get_signing_algorithm', return_value=rfc2459.id_dsa_with_sha1): - with self.assertRaisesRegexp(validators.ValidationError, + with self.assertRaisesRegexp(errors.ValidationError, "Signature on the CSR is not valid"): - validators.csr_signature(csr=csr) + custom.csr_signature(csr=csr) diff --git a/tests/validators/test_standards_validator.py b/tests/validators/test_standards_validator.py index 0b0febe..186ba95 100644 --- a/tests/validators/test_standards_validator.py +++ b/tests/validators/test_standards_validator.py @@ -20,8 +20,8 @@ import unittest from pyasn1.codec.der import encoder from pyasn1_modules import rfc2459 -from anchor import validators -from anchor import validators_standards +from anchor.validators import errors +from anchor.validators import standards from anchor.X509 import extension from anchor.X509 import name from anchor.X509 import signing_request @@ -44,19 +44,19 @@ class TestStandardsValidator(unittest.TestCase): def test_passing(self): csr = signing_request.X509Csr.from_buffer(self.csr_data) - validators_standards.standards_compliance(csr=csr) + standards.standards_compliance(csr=csr) class TestExtensionDuplicates(unittest.TestCase): def test_no_extensions(self): csr = signing_request.X509Csr() - validators_standards._no_extension_duplicates(csr) + standards._no_extension_duplicates(csr) def test_no_duplicates(self): csr = signing_request.X509Csr() ext = extension.X509ExtensionSubjectAltName() csr.add_extension(ext) - validators_standards._no_extension_duplicates(csr) + standards._no_extension_duplicates(csr) def test_with_duplicates(self): csr = signing_request.X509Csr() @@ -71,8 +71,8 @@ class TestExtensionDuplicates(unittest.TestCase): attrs[0]['type'] = signing_request.OID_extensionRequest attrs[0]['vals'] = None attrs[0]['vals'][0] = encoder.encode(exts) - with self.assertRaises(validators.ValidationError): - validators_standards._no_extension_duplicates(csr) + with self.assertRaises(errors.ValidationError): + standards._no_extension_duplicates(csr) class TestExtensionCriticalFlags(unittest.TestCase): @@ -82,8 +82,8 @@ class TestExtensionCriticalFlags(unittest.TestCase): ext.set_critical(False) ext.add_dns_id('example.com') csr.add_extension(ext) - with self.assertRaises(validators.ValidationError): - validators_standards._critical_flags(csr) + with self.assertRaises(errors.ValidationError): + standards._critical_flags(csr) def test_no_subject_san_critical(self): csr = signing_request.X509Csr() @@ -91,7 +91,7 @@ class TestExtensionCriticalFlags(unittest.TestCase): ext.set_critical(True) ext.add_dns_id('example.com') csr.add_extension(ext) - validators_standards._critical_flags(csr) + standards._critical_flags(csr) def test_with_subject_san_not_critical(self): csr = signing_request.X509Csr() @@ -102,22 +102,22 @@ class TestExtensionCriticalFlags(unittest.TestCase): ext.set_critical(False) ext.add_dns_id('example.com') csr.add_extension(ext) - validators_standards._critical_flags(csr) + standards._critical_flags(csr) def test_basic_constraints_not_critical(self): csr = signing_request.X509Csr() ext = extension.X509ExtensionBasicConstraints() ext.set_critical(False) csr.add_extension(ext) - with self.assertRaises(validators.ValidationError): - validators_standards._critical_flags(csr) + with self.assertRaises(errors.ValidationError): + standards._critical_flags(csr) def test_basic_constraints_critical(self): csr = signing_request.X509Csr() ext = extension.X509ExtensionBasicConstraints() ext.set_critical(True) csr.add_extension(ext) - validators_standards._critical_flags(csr) + standards._critical_flags(csr) class TestValidDomains(unittest.TestCase): @@ -130,45 +130,45 @@ class TestValidDomains(unittest.TestCase): def test_all_valid(self): csr = self._create_csr_with_domain_san('a-123.example.com') - validators_standards._valid_domains(csr) + standards._valid_domains(csr) def test_all_valid_trailing_dot(self): csr = self._create_csr_with_domain_san('a-123.example.com.') - validators_standards._valid_domains(csr) + standards._valid_domains(csr) def test_too_long(self): csr = self._create_csr_with_domain_san( 'very-long-label-over-63-characters-' 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.example.com') - with self.assertRaises(validators.ValidationError): - validators_standards._valid_domains(csr) + with self.assertRaises(errors.ValidationError): + standards._valid_domains(csr) def test_beginning_hyphen(self): csr = self._create_csr_with_domain_san('-label.example.com.') - with self.assertRaises(validators.ValidationError): - validators_standards._valid_domains(csr) + with self.assertRaises(errors.ValidationError): + standards._valid_domains(csr) def test_trailing_hyphen(self): csr = self._create_csr_with_domain_san('label-.example.com.') - with self.assertRaises(validators.ValidationError): - validators_standards._valid_domains(csr) + with self.assertRaises(errors.ValidationError): + standards._valid_domains(csr) def test_san_space(self): # valid domain, but not in CSRs csr = self._create_csr_with_domain_san(' ') - with self.assertRaises(validators.ValidationError): - validators_standards._valid_domains(csr) + with self.assertRaises(errors.ValidationError): + standards._valid_domains(csr) def test_wildcard(self): csr = self._create_csr_with_domain_san('*.example.com') - validators_standards._valid_domains(csr) + standards._valid_domains(csr) def test_wildcard_middle(self): csr = self._create_csr_with_domain_san('foo.*.example.com') - with self.assertRaises(validators.ValidationError): - validators_standards._valid_domains(csr) + with self.assertRaises(errors.ValidationError): + standards._valid_domains(csr) def test_wildcard_partial(self): csr = self._create_csr_with_domain_san('foo*.example.com') - with self.assertRaises(validators.ValidationError): - validators_standards._valid_domains(csr) + with self.assertRaises(errors.ValidationError): + standards._valid_domains(csr)