anchor/anchor/validators/custom.py

314 lines
12 KiB
Python

#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import logging
import netaddr
from pyasn1.type import univ as pyasn1_univ
from pyasn1_modules import rfc2437 # PKCS#1
from pyasn1_modules import rfc2459
from anchor.validators import errors as v_errors
from anchor.validators import utils
from anchor.X509 import extension
from anchor.X509 import name as x509_name
logger = logging.getLogger(__name__)
def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs):
    """Check the CN entry is a known domain or address.

    Refuse requests for certificates if they contain multiple CN
    entries, if they carry neither a CN nor a subjectAltName extension,
    or if the CN is rejected by ``utils.check_networks`` (when it parses
    as an IP address) or by ``utils.check_domains`` (otherwise).

    The default lists are only read here, never modified. Extra keyword
    arguments are accepted and ignored.

    Raises ValidationError when one of the checks fails.
    """
    # "or []" mirrors the guard used by extensions() and key_usage(), so a
    # request whose extension list comes back empty/None is still handled.
    alt_present = any(ext.get_name() == "subjectAltName"
                      for ext in (csr.get_extensions() or []))

    CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName)

    if len(CNs) > 1:
        raise v_errors.ValidationError("Too many CNs in the request")

    if len(CNs) == 0:
        # rfc2459#section-4.2.1.6 says so
        if not alt_present:
            raise v_errors.ValidationError(
                "Alt subjects have to exist if the main subject doesn't")
        # nothing else to validate without a CN
        return

    cn = utils.csr_require_cn(csr)

    # is it an IP rather than domain?
    # Only the parse itself is guarded, so an error raised later by the
    # network check can never be mistaken for "the CN is not an address".
    # ValueError is caught as well because netaddr uses it (instead of
    # AddrFormatError) for values that carry a prefix or netmask.
    try:
        ip = netaddr.IPAddress(cn)
    except (netaddr.AddrFormatError, ValueError):
        ip = None

    if ip is None:
        if not utils.check_domains(cn, allowed_domains):
            raise v_errors.ValidationError(
                "Domain '%s' not allowed (does not match known domains)"
                % cn)
    elif not utils.check_networks(ip, allowed_networks):
        raise v_errors.ValidationError(
            "Address '%s' not allowed (does not match known networks)"
            % cn)
def alternative_names(csr, allowed_domains=[], **kwargs):
    """Validate the DNS alternative names of a request.

    Every entry yielded by ``utils.iter_alternative_names(csr, ['DNS'])``
    is passed to ``utils.check_domains`` together with allowed_domains;
    the first entry that is not accepted aborts the validation.

    Raises ValidationError naming the offending entry.
    """
    dns_entries = utils.iter_alternative_names(csr, ['DNS'])
    for _entry_type, dns_name in dns_entries:
        if utils.check_domains(dns_name, allowed_domains):
            continue
        raise v_errors.ValidationError(
            "Domain '%s' not allowed (doesn't match known domains)"
            % dns_name)
def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[],
                         **kwargs):
    """Validate the DNS and IP alternative names of a request.

    Entries of type 'DNS' are checked with ``utils.check_domains``
    against allowed_domains, entries of type 'IP Address' are checked
    with ``utils.check_networks`` against allowed_networks.

    Raises ValidationError for the first entry that is not accepted.
    """
    wanted_types = ['DNS', 'IP Address']
    for entry_type, value in utils.iter_alternative_names(csr, wanted_types):
        if entry_type == 'DNS':
            if not utils.check_domains(value, allowed_domains):
                raise v_errors.ValidationError(
                    "Domain '%s' not allowed (doesn't match known domains)"
                    % value)
        elif entry_type == 'IP Address':
            if not utils.check_networks(value, allowed_networks):
                raise v_errors.ValidationError(
                    "IP '%s' not allowed (doesn't match known networks)"
                    % value)
def blacklist_names(csr, domains=[], **kwargs):
    """Reject requests whose CN or DNS alternative names are blacklisted.

    A name counts as blacklisted when ``utils.check_domains`` accepts it
    for the configured domains. With an empty configuration the step
    only logs a warning and accepts the request.

    Raises ValidationError naming the first blacklisted entry.
    """
    if not domains:
        logger.warning("No domains were configured for the blacklist filter, "
                       "consider disabling the step or providing a list")
        return

    subject = csr.get_subject()
    if len(subject.get_entries_by_oid(x509_name.OID_commonName)) > 0:
        subject_cn = utils.csr_require_cn(csr)
        if utils.check_domains(subject_cn, domains):
            raise v_errors.ValidationError(
                "Domain '%s' not allowed (CN blacklisted)" % subject_cn)

    # other alternative name types are skipped rather than rejected here
    dns_entries = utils.iter_alternative_names(csr, ['DNS'],
                                               fail_other_types=False)
    for _entry_type, dns_name in dns_entries:
        if utils.check_domains(dns_name, domains):
            raise v_errors.ValidationError(
                "Domain '%s' not allowed (alt blacklisted)" % dns_name)
def server_group(auth_result=None, csr=None, group_prefixes={}, **kwargs):
    """Check the team prefix of the server name.

    The text before the first '-' in the CN is treated as a team prefix,
    unless there is no '-' at all or that text contains a '.'. When the
    prefix is a key of group_prefixes, the mapped group has to be one of
    ``auth_result.groups``.

    Raises ValidationError when the mapped group is missing.
    """
    cn = utils.csr_require_cn(csr)
    prefix, dash, _remainder = cn.partition('-')

    if not dash or '.' in prefix:
        # no prefix
        return
    if prefix not in group_prefixes:
        # prefixes without a configured group are not restricted
        return

    required_group = group_prefixes[prefix]
    if required_group not in auth_result.groups:
        raise v_errors.ValidationError(
            "Server prefix doesn't match user groups")
def extensions(csr=None, allowed_extensions=[], **kwargs):
    """Ensure only accepted extensions are used.

    An extension is accepted when either its name or the string form of
    its OID is listed in allowed_extensions. A request without any
    extensions always passes.

    Raises ValidationError naming the first extension that is not listed.
    """
    for requested in (csr.get_extensions() or []):
        listed_by_name = requested.get_name() in allowed_extensions
        if listed_by_name or str(requested.get_oid()) in allowed_extensions:
            continue
        raise v_errors.ValidationError(
            "Extension '%s' not allowed" % requested.get_name())
def key_usage(csr=None, allowed_usage=None, **kwargs):
    """Ensure only accepted key usages are specified.

    allowed_usage is an iterable of usage names. Each entry is translated
    through ``extension.LONG_KEY_USAGE_NAMES`` when it is a key there and
    kept as given otherwise. When it is omitted (None) no usage is
    accepted, so only requests that ask for no key usage pass.

    Raises ValidationError listing every usage requested by an
    X509ExtensionKeyUsage extension that is not in the allowed set.
    """
    # "or ()" keeps the documented default (None) from failing on iteration
    allowed = set(extension.LONG_KEY_USAGE_NAMES.get(x, x)
                  for x in (allowed_usage or ()))

    denied = set()
    for ext in (csr.get_extensions() or []):
        if isinstance(ext, extension.X509ExtensionKeyUsage):
            denied.update(set(ext.get_all_usages()) - allowed)

    if denied:
        # sorted so the message does not depend on set iteration order
        raise v_errors.ValidationError("Found some prohibited key usages: %s"
                                       % ', '.join(sorted(denied)))
def ext_key_usage(csr=None, allowed_usage=None, **kwargs):
    """Ensure only accepted extended key usages are specified.

    Each entry of allowed_usage may be a key of
    ``extension.EXT_KEY_USAGE_NAMES_INV``, a key of
    ``extension.EXT_KEY_USAGE_SHORT_NAMES_INV``, or any value accepted by
    pyasn1's ObjectIdentifier. The caller's sequence is left untouched;
    when it is omitted (None) no usage is accepted.

    Raises ValidationError for a configuration entry that cannot be
    turned into an OID, or listing every requested usage that is not in
    the allowed set.
    """
    # transform all possible names into oids we actually check; the result
    # goes into a fresh set instead of being written back into the
    # (possibly shared or immutable) configuration sequence
    allowed = set()
    for usage in (allowed_usage or ()):
        if usage in extension.EXT_KEY_USAGE_NAMES_INV:
            allowed.add(extension.EXT_KEY_USAGE_NAMES_INV[usage])
        elif usage in extension.EXT_KEY_USAGE_SHORT_NAMES_INV:
            allowed.add(extension.EXT_KEY_USAGE_SHORT_NAMES_INV[usage])
        else:
            try:
                oid = pyasn1_univ.ObjectIdentifier(usage)
            except Exception:
                raise v_errors.ValidationError("Unknown usage: %s" % (usage,))
            allowed.add(oid)

    denied = set()
    requested = csr.get_extensions(extension.X509ExtensionExtendedKeyUsage)
    for ext in (requested or []):
        denied.update(set(ext.get_all_usages()) - allowed)

    if denied:
        # usages without a short name are shown in their string form, so
        # the join below never receives None; sorted for a stable message
        text_denied = sorted(
            extension.EXT_KEY_USAGE_SHORT_NAMES.get(x, str(x))
            for x in denied)
        raise v_errors.ValidationError("Found some prohibited key usages: %s"
                                       % ', '.join(text_denied))
def source_cidrs(request=None, cidrs=None, **kwargs):
    """Ensure that the request comes from a known source.

    ``request.client_addr`` is tested for membership in a
    ``netaddr.IPNetwork`` built from each entry of cidrs, in order; the
    first match accepts the request. When cidrs is omitted (None) or
    empty there is nothing to match, so the request is rejected.

    Raises ValidationError when an entry does not describe a valid
    network, or when no network contains the request source.
    """
    # "or ()" keeps the documented default (None) from failing on iteration
    for cidr in (cidrs or ()):
        try:
            network = netaddr.IPNetwork(cidr)
            # the membership test sits inside the same try, so an
            # AddrFormatError it raises is reported the same way
            if request.client_addr in network:
                return
        except netaddr.AddrFormatError:
            raise v_errors.ValidationError(
                "Cidr '%s' does not describe a valid network" % cidr)

    raise v_errors.ValidationError(
        "No network matched the request source '%s'" %
        request.client_addr)
def public_key(csr=None, allowed_keys=None, **kwargs):
    """Ensure the public key has a known type and an acceptable size.

    allowed_keys is a dictionary mapping a key type name ('RSA' or
    'DSA') to the minimum size accepted for that type; a minimum of 0
    disables the size check for the type.

    Raises ValidationError when the configuration is not a dictionary,
    the key algorithm is not recognised, the key type is not listed in
    allowed_keys, or the key is smaller than the configured minimum.
    """
    # None is not a dict either, so this single test covers both cases
    if not isinstance(allowed_keys, dict):
        raise v_errors.ValidationError("Allowed keys configuration missing")

    requested_algo = csr.get_public_key_algo()
    known_algorithms = {
        rfc2437.rsaEncryption: 'RSA',
        rfc2459.id_dsa: 'DSA',
    }
    key_type = known_algorithms.get(requested_algo)
    if key_type is None:
        raise v_errors.ValidationError("Unknown public key type")

    minimum = allowed_keys.get(key_type)
    if minimum is None:
        raise v_errors.ValidationError(
            "Key type not allowed (%s)" % (key_type,))

    # key size is not enforced when the configured minimum is 0
    if minimum != 0 and csr.get_public_key_size() < minimum:
        raise v_errors.ValidationError("Key size too small")
def _split_names_by_type(names):
    """Sort a list of strings into domains, ips and network ranges.

    Each entry is first offered to ``utils.maybe_ip`` and then to
    ``utils.maybe_range``; the first truthy result is collected in the
    matching list. Entries neither helper recognises are kept unchanged
    as domains.

    Returns a ``(domains, ips, ranges)`` tuple of lists.
    """
    domains, ips, ranges = [], [], []
    for entry in names:
        as_ip = utils.maybe_ip(entry)
        if as_ip:
            ips.append(as_ip)
        else:
            as_range = utils.maybe_range(entry)
            if as_range:
                ranges.append(as_range)
            else:
                domains.append(entry)
    return (domains, ips, ranges)
def whitelist_names(csr=None, names=[], allow_cn_id=False, allow_dns_id=False,
                    allow_ip_id=False, allow_wildcard=False, **kwargs):
    """Ensure names match the whitelist in the allowed name slots.

    names is split by ``_split_names_by_type`` into domains, ip addresses
    and network ranges. Every DNS-ID, IP-ID and CN-ID of the request has
    to be enabled through the matching ``allow_*_id`` flag and has to
    match the whitelist:

    * domains are compared with ``utils.compare_name_pattern``, which
      also receives allow_wildcard;
    * addresses have to equal a whitelisted ip or be contained in a
      whitelisted range;
    * a CN is treated as an address when ``utils.maybe_ip`` recognises
      it and as a domain otherwise.

    The names list is only read, never modified.

    Raises ValidationError for the first identifier that is disabled or
    not whitelisted, or when the request contains unknown SAN entries.
    """
    allowed_domains, allowed_ips, allowed_ranges = _split_names_by_type(names)

    def domain_allowed(candidate):
        # True when any whitelisted domain pattern matches the candidate
        return any(utils.compare_name_pattern(candidate, allowed_domain,
                                              allow_wildcard)
                   for allowed_domain in allowed_domains)

    def address_allowed(candidate):
        # True for an exact whitelisted ip or membership in any range
        return (candidate in allowed_ips or
                any(candidate in net for net in allowed_ranges))

    for dns_id in csr.get_subject_dns_ids():
        if not allow_dns_id:
            raise v_errors.ValidationError("DNS-ID not allowed")
        if not domain_allowed(dns_id):
            raise v_errors.ValidationError(
                "Value `%s` not allowed in DNS-ID" % (dns_id,))

    for ip_id in csr.get_subject_ip_ids():
        if not allow_ip_id:
            raise v_errors.ValidationError("IP-ID not allowed")
        # an address inside a whitelisted range is accepted here, the same
        # way the CN-ID check below accepts it
        if not address_allowed(ip_id):
            raise v_errors.ValidationError(
                "Value `%s` not allowed in IP-ID" % (ip_id,))

    for cn_id in csr.get_subject_cn():
        if not allow_cn_id:
            raise v_errors.ValidationError("CN-ID not allowed")
        ip = utils.maybe_ip(cn_id)
        if ip:
            # current CN is an ip address
            valid = address_allowed(ip)
        else:
            # current CN is a domain
            valid = domain_allowed(cn_id)
        if not valid:
            raise v_errors.ValidationError(
                "Value `%s` not allowed in CN-ID" % (cn_id,))

    if csr.has_unknown_san_entries():
        raise v_errors.ValidationError("Request contains unknown SAN entries")