Move validators to separate modules

Separate the utils function so they can be used from other places
without circular dependencies.

Change-Id: I57b1a28926e67077c3d2207cdefabdb57692941a
This commit is contained in:
Stanisław Pitucha 2015-09-17 18:23:44 +10:00 committed by Tim Kelsey
parent 580d6edcce
commit cb86576afa
11 changed files with 313 additions and 263 deletions

View File

@ -22,7 +22,7 @@ import pecan
from webob import exc as http_status
from anchor import jsonloader
from anchor import validators
from anchor.validators import errors
from anchor.X509 import certificate
from anchor.X509 import extension
from anchor.X509 import signing_request
@ -85,7 +85,7 @@ def _run_validator(name, body, args):
validator(**new_kwargs)
logger.debug("_run_validator: success: <%s> ", name)
return True # validator passed b/c no exceptions
except validators.ValidationError as e:
except errors.ValidationError as e:
logger.error("_run_validator: FAILED: <%s> - %s", name, e)
return False

View File

@ -14,6 +14,7 @@
from __future__ import absolute_import
import hmac
import re
def constant_time_compare(val1, val2):
@ -44,3 +45,32 @@ def _constant_time_compare(val1, val2):
for x, y in zip(val1, val2):
result |= ord(x) ^ ord(y)
return result == 0
# RFC1034 allows a simple " " too, but it's not allowed in certificates, so it
# will not match
RE_DOMAIN_LABEL = re.compile("^[a-z](?:[-a-z0-9]*[a-z0-9])?$", re.IGNORECASE)
def verify_domain(domain, allow_wildcards=False):
labels = domain.split('.')
if labels[-1] == "":
# single trailing . is ok, ignore
labels.pop(-1)
for i, label in enumerate(labels):
if len(label) > 63:
raise ValueError(
"domain <%s> it too long (RFC5280/4.2.1.6)" % (domain,))
# check for wildcard labels, ignore partial-wildcard labels
if '*' == label and allow_wildcards:
if i != 0:
raise ValueError(
"domain <%s> has wildcard that's not in the "
"left-most label (RFC6125/6.4.3)" % (domain,))
else:
if RE_DOMAIN_LABEL.match(label) is None:
raise ValueError(
"domain <%s> contains invalid characters "
"(RFC1034/3.5)" % (domain,))

View File

View File

@ -18,6 +18,8 @@ import logging
import netaddr
from pyasn1.type import univ as pyasn1_univ
from anchor.validators import errors as v_errors
from anchor.validators import utils
from anchor.X509 import errors
from anchor.X509 import extension
from anchor.X509 import name as x509_name
@ -26,60 +28,6 @@ from anchor.X509 import name as x509_name
logger = logging.getLogger(__name__)
class ValidationError(Exception):
pass
def csr_require_cn(csr):
cns = csr.get_subject_cn()
if not cns:
raise ValidationError("CSR is lacking a CN in the Subject")
if len(cns) > 1:
raise ValidationError("CSR has too many CN entries")
return cns[0]
def check_domains(domain, allowed_domains):
if allowed_domains:
if not any(domain.endswith(suffix) for suffix in allowed_domains):
# no domain matched
return False
else:
# no valid domains were provided, so we can't make any assertions
logger.warning("No domains were configured for validation. Anchor "
"will issue certificates for any domain, this is not a "
"recommended configuration for production environments")
return True
def iter_alternative_names(csr, types, fail_other_types=True):
for ext in csr.get_extensions():
if isinstance(ext, extension.X509ExtensionSubjectAltName):
# TODO(stan): fail on other types
if 'DNS' in types:
for dns_id in ext.get_dns_ids():
yield ('DNS', dns_id)
if 'IP Address' in types:
for ip in ext.get_ips():
yield ('IP Address', ip)
def check_networks(ip, allowed_networks):
"""Check the IP is within an allowed network."""
if not isinstance(ip, netaddr.IPAddress):
raise TypeError("ip must be a netaddr ip address")
if not allowed_networks:
# no valid networks were provided, so we can't make any assertions
logger.warning("No valid network IP ranges were given, skipping")
return True
if any(ip in netaddr.IPNetwork(net) for net in allowed_networks):
return True
return False
def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs):
"""Check the CN entry is a known domain.
@ -92,25 +40,27 @@ def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs):
CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName)
if len(CNs) > 1:
raise ValidationError("Too many CNs in the request")
raise v_errors.ValidationError("Too many CNs in the request")
# rfc5280#section-4.2.1.6 says so
if len(CNs) == 0 and not alt_present:
raise ValidationError("Alt subjects have to exist if the main"
" subject doesn't")
raise v_errors.ValidationError("Alt subjects have to exist if the main"
" subject doesn't")
if len(CNs) > 0:
cn = csr_require_cn(csr)
cn = utils.csr_require_cn(csr)
try:
# is it an IP rather than domain?
ip = netaddr.IPAddress(cn)
if not (check_networks(ip, allowed_networks)):
raise ValidationError("Address '%s' not allowed (does not "
"match known networks)" % cn)
if not (utils.check_networks(ip, allowed_networks)):
raise v_errors.ValidationError(
"Address '%s' not allowed (does not match known networks)"
% cn)
except netaddr.AddrFormatError:
if not (check_domains(cn, allowed_domains)):
raise ValidationError("Domain '%s' not allowed (does not "
"match known domains)" % cn)
if not (utils.check_domains(cn, allowed_domains)):
raise v_errors.ValidationError(
"Domain '%s' not allowed (does not match known domains)"
% cn)
def alternative_names(csr, allowed_domains=[], **kwargs):
@ -120,11 +70,10 @@ def alternative_names(csr, allowed_domains=[], **kwargs):
the list of known suffixes, or network ranges.
"""
for _, name in iter_alternative_names(csr, ['DNS']):
if not check_domains(name, allowed_domains):
raise ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)"
% name)
for _, name in utils.iter_alternative_names(csr, ['DNS']):
if not utils.check_domains(name, allowed_domains):
raise v_errors.ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)" % name)
def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[],
@ -135,14 +84,16 @@ def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[],
the list of known suffixes, or network ranges.
"""
for name_type, name in iter_alternative_names(csr, ['DNS', 'IP Address']):
if name_type == 'DNS' and not check_domains(name, allowed_domains):
raise ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)" % name)
for name_type, name in utils.iter_alternative_names(csr,
['DNS', 'IP Address']):
if name_type == 'DNS' and not utils.check_domains(name,
allowed_domains):
raise v_errors.ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)" % name)
if name_type == 'IP Address':
if not check_networks(name, allowed_networks):
raise ValidationError("IP '%s' not allowed (doesn't"
" match known networks)" % name)
if not utils.check_networks(name, allowed_networks):
raise v_errors.ValidationError("IP '%s' not allowed (doesn't"
" match known networks)" % name)
def blacklist_names(csr, domains=[], **kwargs):
@ -155,16 +106,16 @@ def blacklist_names(csr, domains=[], **kwargs):
CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName)
if len(CNs) > 0:
cn = csr_require_cn(csr)
if check_domains(cn, domains):
raise ValidationError("Domain '%s' not allowed "
"(CN blacklisted)" % cn)
cn = utils.csr_require_cn(csr)
if utils.check_domains(cn, domains):
raise v_errors.ValidationError("Domain '%s' not allowed "
"(CN blacklisted)" % cn)
for _, name in iter_alternative_names(csr, ['DNS'],
fail_other_types=False):
if check_domains(name, domains):
raise ValidationError("Domain '%s' not allowed "
"(alt blacklisted)" % name)
for _, name in utils.iter_alternative_names(csr, ['DNS'],
fail_other_types=False):
if utils.check_domains(name, domains):
raise v_errors.ValidationError("Domain '%s' not allowed "
"(alt blacklisted)" % name)
def server_group(auth_result=None, csr=None, group_prefixes={}, **kwargs):
@ -174,14 +125,15 @@ def server_group(auth_result=None, csr=None, group_prefixes={}, **kwargs):
verified against the groups the user is a member of.
"""
cn = csr_require_cn(csr)
cn = utils.csr_require_cn(csr)
parts = cn.split('-')
if len(parts) == 1 or '.' in parts[0]:
return # no prefix
if parts[0] in group_prefixes:
if group_prefixes[parts[0]] not in auth_result.groups:
raise ValidationError("Server prefix doesn't match user groups")
raise v_errors.ValidationError(
"Server prefix doesn't match user groups")
def extensions(csr=None, allowed_extensions=[], **kwargs):
@ -190,8 +142,8 @@ def extensions(csr=None, allowed_extensions=[], **kwargs):
for ext in exts:
if (ext.get_name() not in allowed_extensions and
str(ext.get_oid()) not in allowed_extensions):
raise ValidationError("Extension '%s' not allowed"
% ext.get_name())
raise v_errors.ValidationError("Extension '%s' not allowed"
% ext.get_name())
def key_usage(csr=None, allowed_usage=None, **kwargs):
@ -205,8 +157,8 @@ def key_usage(csr=None, allowed_usage=None, **kwargs):
usages = set(ext.get_all_usages())
denied = denied | (usages - allowed)
if denied:
raise ValidationError("Found some prohibited key usages: %s"
% ', '.join(denied))
raise v_errors.ValidationError("Found some prohibited key usages: %s"
% ', '.join(denied))
def ext_key_usage(csr=None, allowed_usage=None, **kwargs):
@ -223,7 +175,7 @@ def ext_key_usage(csr=None, allowed_usage=None, **kwargs):
oid = pyasn1_univ.ObjectIdentifier(usage)
allowed_usage[i] = oid
except Exception:
raise ValidationError("Unknown usage: %s" % (usage,))
raise v_errors.ValidationError("Unknown usage: %s" % (usage,))
allowed = set(allowed_usage)
denied = set()
@ -234,8 +186,8 @@ def ext_key_usage(csr=None, allowed_usage=None, **kwargs):
if denied:
text_denied = [extension.EXT_KEY_USAGE_SHORT_NAMES.get(x)
for x in denied]
raise ValidationError("Found some prohibited key usages: %s"
% ', '.join(text_denied))
raise v_errors.ValidationError("Found some prohibited key usages: %s"
% ', '.join(text_denied))
def ca_status(csr=None, ca_requested=False, **kwargs):
@ -245,7 +197,7 @@ def ca_status(csr=None, ca_requested=False, **kwargs):
if isinstance(ext, extension.X509ExtensionBasicConstraints):
if ext.get_ca():
if not ca_requested:
raise ValidationError(
raise v_errors.ValidationError(
"CA status requested, but not allowed")
request_ca_flags = True
elif isinstance(ext, extension.X509ExtensionKeyUsage):
@ -253,13 +205,13 @@ def ca_status(csr=None, ca_requested=False, **kwargs):
has_crl_sign = ext.get_usage('cRLSign')
if has_crl_sign or has_cert_sign:
if not ca_requested:
raise ValidationError(
raise v_errors.ValidationError(
"Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: %s/%s)"
% (has_cert_sign, has_crl_sign))
request_ca_flags = True
if ca_requested and not request_ca_flags:
raise ValidationError("CA flags required")
raise v_errors.ValidationError("CA flags required")
def source_cidrs(request=None, cidrs=None, **kwargs):
@ -270,16 +222,17 @@ def source_cidrs(request=None, cidrs=None, **kwargs):
if request.client_addr in r:
return
except netaddr.AddrFormatError:
raise ValidationError("Cidr '%s' does not describe a valid"
" network" % cidr)
raise ValidationError("No network matched the request source '%s'" %
request.client_addr)
raise v_errors.ValidationError(
"Cidr '%s' does not describe a valid network" % cidr)
raise v_errors.ValidationError(
"No network matched the request source '%s'" %
request.client_addr)
def csr_signature(csr=None, **kwargs):
"""Ensure that the CSR has a valid self-signature."""
try:
if not csr.verify():
raise ValidationError("Signature on the CSR is not valid")
raise v_errors.ValidationError("Signature on the CSR is not valid")
except errors.X509Error:
raise ValidationError("Signature on the CSR is not valid")
raise v_errors.ValidationError("Signature on the CSR is not valid")

View File

@ -0,0 +1,16 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
class ValidationError(Exception):
pass

View File

@ -23,9 +23,8 @@ All the rules are pulled into a single validator: ``standards_compliance``.
from __future__ import absolute_import
import re
from anchor import validators
from anchor import util
from anchor.validators import errors
from anchor.X509 import extension
@ -46,7 +45,7 @@ def _no_extension_duplicates(csr):
for ext in csr.get_extensions():
oid = ext.get_oid()
if oid in seen_oids:
raise validators.ValidationError(
raise errors.ValidationError(
"Duplicate extension with oid %s (RFC5280/4.2)" % oid)
seen_oids.add(oid)
@ -56,52 +55,28 @@ def _critical_flags(csr):
for ext in csr.get_extensions():
if isinstance(ext, extension.X509ExtensionSubjectAltName):
if len(csr.get_subject()) == 0 and not ext.get_critical():
raise validators.ValidationError(
raise errors.ValidationError(
"SAN must be critical if subject is empty "
"(RFC5280/4.1.2.6)")
if isinstance(ext, extension.X509ExtensionBasicConstraints):
if not ext.get_critical():
raise validators.ValidationError(
raise errors.ValidationError(
"Basic constraints has to be marked critical "
"(RFC5280/4.1.2.9)")
# RFC1034 allows a simple " " too, but it's not allowed in certificates, so it
# will not match
RE_DOMAIN_LABEL = re.compile("^[a-z](?:[-a-z0-9]*[a-z0-9])?$", re.IGNORECASE)
def _valid_domains(csr):
"""Format of the domin names
See RFC5280 section 4.2.1.6 / RFC6125 / RFC1034
"""
def verify_domain(domain):
labels = domain.split('.')
if labels[-1] == "":
# single trailing . is ok, ignore
labels.pop(-1)
for i, label in enumerate(labels):
if len(label) > 63:
raise validators.ValidationError(
"SAN entry <%s> it too long (RFC5280/4.2.1.6)" % (domain,))
# check for wildcard labels, ignore partial-wildcard labels
if '*' == label:
if i != 0:
raise validators.ValidationError(
"SAN entry <%s> has wildcard that's not in the "
"left-most label (RFC6125/6.4.3)" % (domain,))
else:
if RE_DOMAIN_LABEL.match(label) is None:
raise validators.ValidationError(
"SAN entry <%s> contains invalid characters "
"(RFC1034/3.5)" % (domain,))
sans = csr.get_extensions(extension.X509ExtensionSubjectAltName)
if not sans:
return
ext = sans[0]
for domain in ext.get_dns_ids():
verify_domain(domain)
try:
util.verify_domain(domain, allow_wildcards=True)
except ValueError as e:
raise errors.ValidationError(str(e))

View File

@ -0,0 +1,74 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import logging
import netaddr
from anchor.validators import errors
from anchor.X509 import extension
logger = logging.getLogger(__name__)
def csr_require_cn(csr):
cns = csr.get_subject_cn()
if not cns:
raise errors.ValidationError("CSR is lacking a CN in the Subject")
if len(cns) > 1:
raise errors.ValidationError("CSR has too many CN entries")
return cns[0]
def check_domains(domain, allowed_domains):
if allowed_domains:
if not any(domain.endswith(suffix) for suffix in allowed_domains):
# no domain matched
return False
else:
# no valid domains were provided, so we can't make any assertions
logger.warning("No domains were configured for validation. Anchor "
"will issue certificates for any domain, this is not a "
"recommended configuration for production environments")
return True
def iter_alternative_names(csr, types, fail_other_types=True):
for ext in csr.get_extensions():
if isinstance(ext, extension.X509ExtensionSubjectAltName):
# TODO(stan): fail on other types
if 'DNS' in types:
for dns_id in ext.get_dns_ids():
yield ('DNS', dns_id)
if 'IP Address' in types:
for ip in ext.get_ips():
yield ('IP Address', ip)
def check_networks(ip, allowed_networks):
"""Check the IP is within an allowed network."""
if not isinstance(ip, netaddr.IPAddress):
raise TypeError("ip must be a netaddr ip address")
if not allowed_networks:
# no valid networks were provided, so we can't make any assertions
logger.warning("No valid network IP ranges were given, skipping")
return True
if any(ip in netaddr.IPNetwork(net) for net in allowed_networks):
return True
return False

View File

@ -31,20 +31,19 @@ anchor.signing_backends =
anchor = anchor.certificate_ops:sign
anchor.validators =
check_domains = anchor.validators:check_domains
iter_alternative_names = anchor.validators:iter_alternative_names
check_networks = anchor.validators:check_networks
common_name = anchor.validators:common_name
alternative_names = anchor.validators:alternative_names
alternative_names_ip = anchor.validators:alternative_names_ip
blacklist_names = anchor.validators:blacklist_names
server_group = anchor.validators:server_group
extensions = anchor.validators:extensions
key_usage = anchor.validators:key_usage
ext_key_usage = anchor.validators:ext_key_usage
ca_status = anchor.validators:ca_status
source_cidrs = anchor.validators:source_cidrs
standards_compliance = anchor.validators_standards:standards_compliance
check_domains = anchor.validators.custom:check_domains
iter_alternative_names = anchor.validators.custom:iter_alternative_names
check_networks = anchor.validators.custom:check_networks
common_name = anchor.validators.custom:common_name
alternative_names = anchor.validators.custom:alternative_names
alternative_names_ip = anchor.validators.custom:alternative_names_ip
blacklist_names = anchor.validators.custom:blacklist_names
server_group = anchor.validators.custom:server_group
extensions = anchor.validators.custom:extensions
key_usage = anchor.validators.custom:key_usage
ca_status = anchor.validators.custom:ca_status
source_cidrs = anchor.validators.custom:source_cidrs
standards_compliance = anchor.validators.standards:standards_compliance
anchor.authentication =
keystone = anchor.auth.keystone:login

View File

@ -19,7 +19,8 @@ import unittest
import netaddr
from anchor import validators
from anchor.validators import errors
from anchor.validators import utils
from anchor.X509 import signing_request
@ -86,32 +87,32 @@ class TestBaseValidators(unittest.TestCase):
super(TestBaseValidators, self).tearDown()
def test_csr_require_cn(self):
name = validators.csr_require_cn(self.csr)
name = utils.csr_require_cn(self.csr)
self.assertEqual(name, "ossg.test.com")
self.csr = signing_request.X509Csr.from_buffer(
TestBaseValidators.csr_data_without_cn)
with self.assertRaises(validators.ValidationError):
validators.csr_require_cn(self.csr)
with self.assertRaises(errors.ValidationError):
utils.csr_require_cn(self.csr)
def test_check_domains(self):
test_domain = 'good.example.com'
test_allowed = ['.example.com', '.example.net']
self.assertTrue(validators.check_domains(test_domain, test_allowed))
self.assertFalse(validators.check_domains('bad.example.org',
test_allowed))
self.assertTrue(utils.check_domains(test_domain, test_allowed))
self.assertFalse(utils.check_domains('bad.example.org',
test_allowed))
def test_check_networks(self):
good_ip = netaddr.IPAddress('10.2.3.4')
bad_ip = netaddr.IPAddress('88.2.3.4')
test_allowed = ['10/8']
self.assertTrue(validators.check_networks(good_ip, test_allowed))
self.assertFalse(validators.check_networks(bad_ip, test_allowed))
self.assertTrue(utils.check_networks(good_ip, test_allowed))
self.assertFalse(utils.check_networks(bad_ip, test_allowed))
def test_check_networks_invalid(self):
with self.assertRaises(TypeError):
validators.check_networks('1.2.3.4', ['10/8'])
utils.check_networks('1.2.3.4', ['10/8'])
def test_check_networks_passthrough(self):
good_ip = netaddr.IPAddress('10.2.3.4')
self.assertTrue(validators.check_networks(good_ip, []))
self.assertTrue(utils.check_networks(good_ip, []))

View File

@ -21,7 +21,9 @@ import mock
import netaddr
from pyasn1_modules import rfc2459
from anchor import validators
from anchor.validators import custom
from anchor.validators import errors
from anchor.validators import utils
from anchor.X509 import extension as x509_ext
from anchor.X509 import name as x509_name
from anchor.X509 import signing_request as x509_csr
@ -50,16 +52,16 @@ class TestValidators(unittest.TestCase):
def test_check_networks_good(self):
allowed_networks = ['15/8', '74.125/16']
self.assertTrue(validators.check_networks(
self.assertTrue(utils.check_networks(
netaddr.IPAddress('74.125.224.64'), allowed_networks))
def test_check_networks_bad(self):
allowed_networks = ['15/8', '74.125/16']
self.assertFalse(validators.check_networks(
self.assertFalse(utils.check_networks(
netaddr.IPAddress('12.2.2.2'), allowed_networks))
def test_check_domains_empty(self):
self.assertTrue(validators.check_domains(
self.assertTrue(utils.check_domains(
'example.com', []))
def test_common_name_with_two_CN(self):
@ -68,8 +70,8 @@ class TestValidators(unittest.TestCase):
name.add_name_entry(x509_name.OID_commonName, "dummy_value")
name.add_name_entry(x509_name.OID_commonName, "dummy_value")
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
with self.assertRaises(errors.ValidationError) as e:
custom.common_name(
csr=csr,
allowed_domains=[],
allowed_networks=[])
@ -78,8 +80,8 @@ class TestValidators(unittest.TestCase):
def test_common_name_no_CN(self):
csr = x509_csr.X509Csr()
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
with self.assertRaises(errors.ValidationError) as e:
custom.common_name(
csr=csr,
allowed_domains=[],
allowed_networks=[])
@ -93,7 +95,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.common_name(
custom.common_name(
csr=csr,
allowed_domains=['.example.com'],
)
@ -104,8 +106,8 @@ class TestValidators(unittest.TestCase):
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, 'bad.example.org')
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
with self.assertRaises(errors.ValidationError) as e:
custom.common_name(
csr=csr,
allowed_domains=['.example.com'])
self.assertEqual("Domain 'bad.example.org' not allowed (does not "
@ -118,7 +120,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.common_name(
custom.common_name(
csr=csr,
allowed_domains=['.example.com'],
allowed_networks=['10/8']
@ -130,8 +132,8 @@ class TestValidators(unittest.TestCase):
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, '15.1.1.1')
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
with self.assertRaises(errors.ValidationError) as e:
custom.common_name(
csr=csr,
allowed_domains=['.example.com'],
allowed_networks=['10/8'])
@ -146,7 +148,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.alternative_names(
custom.alternative_names(
csr=csr,
allowed_domains=['.example.com'],
)
@ -158,8 +160,8 @@ class TestValidators(unittest.TestCase):
ext.add_dns_id('bad.example.org')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names(
with self.assertRaises(errors.ValidationError) as e:
custom.alternative_names(
csr=csr,
allowed_domains=['.example.com'])
self.assertEqual("Domain 'bad.example.org' not allowed (doesn't "
@ -173,7 +175,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.alternative_names_ip(
custom.alternative_names_ip(
csr=csr,
allowed_domains=['.example.com'],
allowed_networks=['10/8']
@ -186,8 +188,8 @@ class TestValidators(unittest.TestCase):
ext.add_ip(netaddr.IPAddress('10.1.1.1'))
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
with self.assertRaises(errors.ValidationError) as e:
custom.alternative_names_ip(
csr=csr,
allowed_domains=['.example.com'],
allowed_networks=['99/8'])
@ -200,8 +202,8 @@ class TestValidators(unittest.TestCase):
ext.add_dns_id('bad.example.org')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
with self.assertRaises(errors.ValidationError) as e:
custom.alternative_names_ip(
csr=csr,
allowed_domains=['.example.com'])
self.assertEqual("Domain 'bad.example.org' not allowed (doesn't "
@ -214,7 +216,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.server_group(
custom.server_group(
auth_result=None,
csr=csr,
group_prefixes={}
@ -228,7 +230,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.server_group(
custom.server_group(
auth_result=None,
csr=csr,
group_prefixes={}
@ -246,7 +248,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.server_group(
custom.server_group(
auth_result=auth_result,
csr=csr,
group_prefixes={'nv': 'nova', 'sw': 'swift'}
@ -261,8 +263,8 @@ class TestValidators(unittest.TestCase):
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "nv-master.example.com")
with self.assertRaises(validators.ValidationError) as e:
validators.server_group(
with self.assertRaises(errors.ValidationError) as e:
custom.server_group(
auth_result=auth_result,
csr=csr,
group_prefixes={'nv': 'nova', 'sw': 'swift'})
@ -275,8 +277,8 @@ class TestValidators(unittest.TestCase):
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.extensions(
with self.assertRaises(errors.ValidationError) as e:
custom.extensions(
csr=csr,
allowed_extensions=['basicConstraints', 'nameConstraints'])
self.assertEqual("Extension 'keyUsage' not allowed", str(e.exception))
@ -289,7 +291,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.extensions(
custom.extensions(
csr=csr,
allowed_extensions=['basicConstraints', 'keyUsage']
)
@ -303,7 +305,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.extensions(
custom.extensions(
csr=csr,
allowed_extensions=['basicConstraints', '2.5.29.15']
)
@ -319,8 +321,8 @@ class TestValidators(unittest.TestCase):
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.key_usage(
with self.assertRaises(errors.ValidationError) as e:
custom.key_usage(
csr=csr,
allowed_usage=allowed_usage)
self.assertEqual("Found some prohibited key usages: "
@ -339,7 +341,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.key_usage(
custom.key_usage(
csr=csr,
allowed_usage=allowed_usage
)
@ -355,7 +357,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ext_key_usage(
custom.ext_key_usage(
csr=csr,
allowed_usage=allowed_usage
)
@ -371,7 +373,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ext_key_usage(
custom.ext_key_usage(
csr=csr,
allowed_usage=allowed_usage
)
@ -387,7 +389,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ext_key_usage(
custom.ext_key_usage(
csr=csr,
allowed_usage=allowed_usage
)
@ -401,8 +403,8 @@ class TestValidators(unittest.TestCase):
ext.set_usage(rfc2459.id_kp_clientAuth, True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ext_key_usage(
with self.assertRaises(errors.ValidationError) as e:
custom.ext_key_usage(
csr=csr,
allowed_usage=allowed_usage)
self.assertEqual("Found some prohibited key usages: "
@ -416,7 +418,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ca_status(
custom.ca_status(
csr=csr,
ca_requested=True
)
@ -430,7 +432,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ca_status(
custom.ca_status(
csr=csr,
ca_requested=False
)
@ -442,8 +444,8 @@ class TestValidators(unittest.TestCase):
ext.set_ca(True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
with self.assertRaises(errors.ValidationError) as e:
custom.ca_status(
csr=csr,
ca_requested=False)
self.assertEqual("CA status requested, but not allowed",
@ -455,8 +457,8 @@ class TestValidators(unittest.TestCase):
ext.set_ca(False)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
with self.assertRaises(errors.ValidationError) as e:
custom.ca_status(
csr=csr,
ca_requested=True)
self.assertEqual("CA flags required",
@ -470,7 +472,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ca_status(
custom.ca_status(
csr=csr,
ca_requested=False
)
@ -482,8 +484,8 @@ class TestValidators(unittest.TestCase):
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
with self.assertRaises(errors.ValidationError) as e:
custom.ca_status(
csr=csr,
ca_requested=False)
self.assertEqual("Key usage doesn't match requested CA status "
@ -497,7 +499,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ca_status(
custom.ca_status(
csr=csr,
ca_requested=True
)
@ -509,8 +511,8 @@ class TestValidators(unittest.TestCase):
ext.set_usage('cRLSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
with self.assertRaises(errors.ValidationError) as e:
custom.ca_status(
csr=csr,
ca_requested=False)
self.assertEqual("Key usage doesn't match requested CA status "
@ -524,7 +526,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.ca_status(
custom.ca_status(
csr=csr,
ca_requested=True
)
@ -534,7 +536,7 @@ class TestValidators(unittest.TestCase):
request = mock.Mock(client_addr='127.0.0.1')
self.assertEqual(
None,
validators.source_cidrs(
custom.source_cidrs(
request=request,
cidrs=['127/8', '10/8']
)
@ -542,8 +544,8 @@ class TestValidators(unittest.TestCase):
def test_source_cidrs_out_of_range(self):
request = mock.Mock(client_addr='99.0.0.1')
with self.assertRaises(validators.ValidationError) as e:
validators.source_cidrs(
with self.assertRaises(errors.ValidationError) as e:
custom.source_cidrs(
request=request,
cidrs=['127/8', '10/8'])
self.assertEqual("No network matched the request source '99.0.0.1'",
@ -551,8 +553,8 @@ class TestValidators(unittest.TestCase):
def test_source_cidrs_bad_cidr(self):
request = mock.Mock(client_addr='127.0.0.1')
with self.assertRaises(validators.ValidationError) as e:
validators.source_cidrs(
with self.assertRaises(errors.ValidationError) as e:
custom.source_cidrs(
request=request,
cidrs=['bad'])
self.assertEqual("Cidr 'bad' does not describe a valid network",
@ -566,7 +568,7 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.blacklist_names(
custom.blacklist_names(
csr=csr,
domains=['.example.org'],
)
@ -578,8 +580,8 @@ class TestValidators(unittest.TestCase):
ext.add_dns_id('bad.example.com')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError):
validators.blacklist_names(
with self.assertRaises(errors.ValidationError):
custom.blacklist_names(
csr=csr,
domains=['.example.com'],
)
@ -589,8 +591,8 @@ class TestValidators(unittest.TestCase):
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "bad.example.com")
with self.assertRaises(validators.ValidationError):
validators.blacklist_names(
with self.assertRaises(errors.ValidationError):
custom.blacklist_names(
csr=csr,
domains=['.example.com'],
)
@ -602,8 +604,8 @@ class TestValidators(unittest.TestCase):
ext.add_dns_id('good.example.com')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError):
validators.blacklist_names(
with self.assertRaises(errors.ValidationError):
custom.blacklist_names(
csr=csr,
domains=['.example.org'],
)
@ -617,27 +619,27 @@ class TestValidators(unittest.TestCase):
self.assertEqual(
None,
validators.blacklist_names(
custom.blacklist_names(
csr=csr,
)
)
def test_csr_signature(self):
csr = x509_csr.X509Csr.from_buffer(self.csr_data)
self.assertEqual(None, validators.csr_signature(csr=csr))
self.assertEqual(None, custom.csr_signature(csr=csr))
def test_csr_signature_bad_sig(self):
csr = x509_csr.X509Csr.from_buffer(self.csr_data)
with mock.patch.object(x509_csr.X509Csr, '_get_signature',
return_value=(b'A'*49)):
with self.assertRaisesRegexp(validators.ValidationError,
with self.assertRaisesRegexp(errors.ValidationError,
"Signature on the CSR is not valid"):
validators.csr_signature(csr=csr)
custom.csr_signature(csr=csr)
def test_csr_signature_bad_algo(self):
csr = x509_csr.X509Csr.from_buffer(self.csr_data)
with mock.patch.object(x509_csr.X509Csr, '_get_signing_algorithm',
return_value=rfc2459.id_dsa_with_sha1):
with self.assertRaisesRegexp(validators.ValidationError,
with self.assertRaisesRegexp(errors.ValidationError,
"Signature on the CSR is not valid"):
validators.csr_signature(csr=csr)
custom.csr_signature(csr=csr)

View File

@ -20,8 +20,8 @@ import unittest
from pyasn1.codec.der import encoder
from pyasn1_modules import rfc2459
from anchor import validators
from anchor import validators_standards
from anchor.validators import errors
from anchor.validators import standards
from anchor.X509 import extension
from anchor.X509 import name
from anchor.X509 import signing_request
@ -44,19 +44,19 @@ class TestStandardsValidator(unittest.TestCase):
def test_passing(self):
csr = signing_request.X509Csr.from_buffer(self.csr_data)
validators_standards.standards_compliance(csr=csr)
standards.standards_compliance(csr=csr)
class TestExtensionDuplicates(unittest.TestCase):
def test_no_extensions(self):
csr = signing_request.X509Csr()
validators_standards._no_extension_duplicates(csr)
standards._no_extension_duplicates(csr)
def test_no_duplicates(self):
csr = signing_request.X509Csr()
ext = extension.X509ExtensionSubjectAltName()
csr.add_extension(ext)
validators_standards._no_extension_duplicates(csr)
standards._no_extension_duplicates(csr)
def test_with_duplicates(self):
csr = signing_request.X509Csr()
@ -71,8 +71,8 @@ class TestExtensionDuplicates(unittest.TestCase):
attrs[0]['type'] = signing_request.OID_extensionRequest
attrs[0]['vals'] = None
attrs[0]['vals'][0] = encoder.encode(exts)
with self.assertRaises(validators.ValidationError):
validators_standards._no_extension_duplicates(csr)
with self.assertRaises(errors.ValidationError):
standards._no_extension_duplicates(csr)
class TestExtensionCriticalFlags(unittest.TestCase):
@ -82,8 +82,8 @@ class TestExtensionCriticalFlags(unittest.TestCase):
ext.set_critical(False)
ext.add_dns_id('example.com')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError):
validators_standards._critical_flags(csr)
with self.assertRaises(errors.ValidationError):
standards._critical_flags(csr)
def test_no_subject_san_critical(self):
csr = signing_request.X509Csr()
@ -91,7 +91,7 @@ class TestExtensionCriticalFlags(unittest.TestCase):
ext.set_critical(True)
ext.add_dns_id('example.com')
csr.add_extension(ext)
validators_standards._critical_flags(csr)
standards._critical_flags(csr)
def test_with_subject_san_not_critical(self):
csr = signing_request.X509Csr()
@ -102,22 +102,22 @@ class TestExtensionCriticalFlags(unittest.TestCase):
ext.set_critical(False)
ext.add_dns_id('example.com')
csr.add_extension(ext)
validators_standards._critical_flags(csr)
standards._critical_flags(csr)
def test_basic_constraints_not_critical(self):
csr = signing_request.X509Csr()
ext = extension.X509ExtensionBasicConstraints()
ext.set_critical(False)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError):
validators_standards._critical_flags(csr)
with self.assertRaises(errors.ValidationError):
standards._critical_flags(csr)
def test_basic_constraints_critical(self):
csr = signing_request.X509Csr()
ext = extension.X509ExtensionBasicConstraints()
ext.set_critical(True)
csr.add_extension(ext)
validators_standards._critical_flags(csr)
standards._critical_flags(csr)
class TestValidDomains(unittest.TestCase):
@ -130,45 +130,45 @@ class TestValidDomains(unittest.TestCase):
def test_all_valid(self):
csr = self._create_csr_with_domain_san('a-123.example.com')
validators_standards._valid_domains(csr)
standards._valid_domains(csr)
def test_all_valid_trailing_dot(self):
csr = self._create_csr_with_domain_san('a-123.example.com.')
validators_standards._valid_domains(csr)
standards._valid_domains(csr)
def test_too_long(self):
csr = self._create_csr_with_domain_san(
'very-long-label-over-63-characters-'
'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.example.com')
with self.assertRaises(validators.ValidationError):
validators_standards._valid_domains(csr)
with self.assertRaises(errors.ValidationError):
standards._valid_domains(csr)
def test_beginning_hyphen(self):
csr = self._create_csr_with_domain_san('-label.example.com.')
with self.assertRaises(validators.ValidationError):
validators_standards._valid_domains(csr)
with self.assertRaises(errors.ValidationError):
standards._valid_domains(csr)
def test_trailing_hyphen(self):
csr = self._create_csr_with_domain_san('label-.example.com.')
with self.assertRaises(validators.ValidationError):
validators_standards._valid_domains(csr)
with self.assertRaises(errors.ValidationError):
standards._valid_domains(csr)
def test_san_space(self):
# valid domain, but not in CSRs
csr = self._create_csr_with_domain_san(' ')
with self.assertRaises(validators.ValidationError):
validators_standards._valid_domains(csr)
with self.assertRaises(errors.ValidationError):
standards._valid_domains(csr)
def test_wildcard(self):
csr = self._create_csr_with_domain_san('*.example.com')
validators_standards._valid_domains(csr)
standards._valid_domains(csr)
def test_wildcard_middle(self):
csr = self._create_csr_with_domain_san('foo.*.example.com')
with self.assertRaises(validators.ValidationError):
validators_standards._valid_domains(csr)
with self.assertRaises(errors.ValidationError):
standards._valid_domains(csr)
def test_wildcard_partial(self):
csr = self._create_csr_with_domain_san('foo*.example.com')
with self.assertRaises(validators.ValidationError):
validators_standards._valid_domains(csr)
with self.assertRaises(errors.ValidationError):
standards._valid_domains(csr)