Integrate PyASN1 for certificate operations

Instead of relying on openssl code for certificate parsing, use the
ASN.1 representation directly. All previous features are supported. Not
all the extensions are full parsed yet, but the code doesn't require
them for now.

The code makes accessing and modifying the certificate structure simpler
and requires less error checking than the original version. The code
leaves few TODOs, but nothing that destroys previous behaviour.

It still uses the cryptography.io backend for loading keys and producing
signatures for the certificates.

Implements: blueprint direct-asn1
Change-Id: Ic555d3d056ca8da7016e2d8b434506cf214d06a1
This commit is contained in:
Stanisław Pitucha 2015-07-22 14:03:02 +10:00
parent 76c76e6ac5
commit ef390f5f54
19 changed files with 1340 additions and 1125 deletions

View File

@ -13,127 +13,108 @@
from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend
import base64
import binascii
import io
from cryptography.hazmat import backends as cio_backends
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import hashes
from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import pem
from pyasn1_modules import rfc2459 # X509v3
from anchor.X509 import errors
from anchor.X509 import message_digest
from anchor.X509 import extension
from anchor.X509 import name
from anchor.X509 import utils
SIGNING_ALGORITHMS = {
('RSA', 'MD5'): rfc2459.md5WithRSAEncryption,
('RSA', 'SHA1'): rfc2459.sha1WithRSAEncryption,
('RSA', 'SHA224'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.14'),
('RSA', 'SHA256'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.11'),
('RSA', 'SHA384'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.12'),
('RSA', 'SHA512'): asn1_univ.ObjectIdentifier('1.2.840.113549.1.1.13'),
('DSA', 'SHA1'): rfc2459.id_dsa_with_sha1,
('DSA', 'SHA224'): asn1_univ.ObjectIdentifier('2.16.840.1.101.3.4.3.1'),
('DSA', 'SHA256'): asn1_univ.ObjectIdentifier('2.16.840.1.101.3.4.3.2'),
}
class X509CertificateError(errors.X509Error):
"""Specific error for X509 certificate operations."""
def __init__(self, what):
super(X509CertificateError, self).__init__(what)
class X509Extension(object):
"""An X509 V3 Certificate extension."""
def __init__(self, ext):
self._lib = backend._lib
self._ffi = backend._ffi
self._ext = ext
def __str__(self):
return "%s %s" % (self.get_name(), self.get_value())
def get_name(self):
"""Get the extension name as a python string."""
ext_obj = self._lib.X509_EXTENSION_get_object(self._ext)
ext_nid = self._lib.OBJ_obj2nid(ext_obj)
ext_name_str = self._lib.OBJ_nid2sn(ext_nid)
return self._ffi.string(ext_name_str).decode('ascii')
def get_value(self):
"""Get the extension value as a python string."""
bio = self._lib.BIO_new(self._lib.BIO_s_mem())
bio = self._ffi.gc(bio, self._lib.BIO_free)
self._lib.X509V3_EXT_print(bio, self._ext, 0, 0)
size = 1024
data = self._ffi.new("char[]", size)
self._lib.BIO_gets(bio, data, size)
return self._ffi.string(data).decode('ascii')
pass
class X509Certificate(object):
"""X509 certificate class."""
def __init__(self):
self._lib = backend._lib
self._ffi = backend._ffi
certObj = self._lib.X509_new()
if certObj == self._ffi.NULL:
raise X509CertificateError("Could not create X509 certificate "
"object") # pragma: no cover
def __init__(self, certificate=None):
if certificate is None:
self._cert = rfc2459.Certificate()
self._cert['tbsCertificate'] = rfc2459.TBSCertificate()
else:
self._cert = certificate
self._certObj = certObj
@staticmethod
def from_open_file(f):
try:
der_content = pem.readPemFromFile(f)
certificate = decoder.decode(der_content,
asn1Spec=rfc2459.Certificate())[0]
return X509Certificate(certificate)
except Exception:
raise X509CertificateError("Could not read X509 certificate from "
"PEM data.")
def __del__(self):
if getattr(self, '_certObj', None):
self._lib.X509_free(self._certObj)
def from_buffer(self, data):
@staticmethod
def from_buffer(data):
"""Build this X509 object from a data buffer in memory.
:param data: A data buffer
"""
if type(data) != bytes:
data = data.encode('ascii')
bio = backend._bytes_to_bio(data)
return X509Certificate.from_open_file(io.StringIO(data))
# NOTE(tkelsey): some versions of OpenSSL dont re-use the cert object
# properly, so free it and use the new one
#
certObj = self._lib.PEM_read_bio_X509(bio[0],
self._ffi.NULL,
self._ffi.NULL,
self._ffi.NULL)
if certObj == self._ffi.NULL:
raise X509CertificateError("Could not read X509 certificate from "
"PEM data.")
self._lib.X509_free(self._certObj)
self._certObj = certObj
def from_file(self, path):
@staticmethod
def from_file(path):
"""Build this X509 certificate object from a data file on disk.
:param path: A data buffer
"""
data = None
with open(path, 'rb') as f:
data = f.read()
self.from_buffer(data)
with open(path, 'r') as f:
return X509Certificate.from_open_file(f)
def as_pem(self):
"""Serialise this X509 certificate object as PEM string."""
raw_bio = self._lib.BIO_new(self._lib.BIO_s_mem())
bio = self._ffi.gc(raw_bio, self._lib.BIO_free)
ret = self._lib.PEM_write_bio_X509(bio, self._certObj)
if ret == 0:
raise X509CertificateError("Could not write X509 certificate "
"as PEM data.") # pragma: no cover
buf = self._ffi.new("char**")
pem_len = self._lib.BIO_get_mem_data(bio, buf)
pem = self._ffi.string(buf[0], pem_len)
return pem
header = '-----BEGIN CERTIFICATE-----'
footer = '-----END CERTIFICATE-----'
der_cert = encoder.encode(self._cert)
b64_encoder = (base64.encodestring if str is bytes else
base64.encodebytes)
b64_cert = b64_encoder(der_cert).decode('ascii')
return "%s\n%s%s\n" % (header, b64_cert, footer)
def set_version(self, v):
"""Set the version of this X509 certificate object.
:param v: The version
"""
ret = self._lib.X509_set_version(self._certObj, v)
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"version.") # pragma: no cover
self._cert['tbsCertificate']['version'] = v
def get_version(self):
"""Get the version of this X509 certificate object."""
return self._lib.X509_get_version(self._certObj)
return self._cert['tbsCertificate']['version']
def get_validity(self):
if self._cert['tbsCertificate']['validity'] is None:
self._cert['tbsCertificate']['validity'] = None
return self._cert['tbsCertificate']['validity']
def set_not_before(self, t):
"""Set the 'not before' date field.
@ -141,15 +122,13 @@ class X509Certificate(object):
:param t: time in seconds since the epoch
"""
asn1_time = utils.timestamp_to_asn1_time(t)
ret = self._lib.X509_set_notBefore(self._certObj, asn1_time)
self._lib.ASN1_TIME_free(asn1_time)
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"not before time.") # pragma: no cover
validity = self.get_validity()
validity['notBefore'] = asn1_time
def get_not_before(self):
"""Get the 'not before' date field as seconds since the epoch."""
not_before = self._lib.X509_get_notBefore(self._certObj)
validity = self.get_validity()
not_before = validity['notBefore']
return utils.asn1_time_to_timestamp(not_before)
def set_not_after(self, t):
@ -158,37 +137,28 @@ class X509Certificate(object):
:param t: time in seconds since the epoch
"""
asn1_time = utils.timestamp_to_asn1_time(t)
ret = self._lib.X509_set_notAfter(self._certObj, asn1_time)
self._lib.ASN1_TIME_free(asn1_time)
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"not after time.") # pragma: no cover
validity = self.get_validity()
validity['notAfter'] = asn1_time
def get_not_after(self):
"""Get the 'not after' date field as seconds since the epoch."""
not_after = self._lib.X509_get_notAfter(self._certObj)
validity = self.get_validity()
not_after = validity['notAfter']
return utils.asn1_time_to_timestamp(not_after)
def set_pubkey(self, pkey):
"""Set the public key field.
:param pkey: The public key, an EVP_PKEY ssl type
:param pkey: The public key, rfc2459.SubjectPublicKeyInfo description
"""
ret = self._lib.X509_set_pubkey(self._certObj, pkey)
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"pubkey.") # pragma: no cover
self._cert['tbsCertificate']['subjectPublicKeyInfo'] = pkey
def get_subject(self):
"""Get the subject name field value.
:return: An X509Name object instance
"""
val = self._lib.X509_get_subject_name(self._certObj)
if val == self._ffi.NULL:
raise X509CertificateError("Could not get subject from X509 "
"certificate.") # pragma: no cover
val = self._cert['tbsCertificate']['subject'][0]
return name.X509Name(val)
def set_subject(self, subject):
@ -197,10 +167,9 @@ class X509Certificate(object):
:param subject: An X509Name object instance
"""
val = subject._name_obj
ret = self._lib.X509_set_subject_name(self._certObj, val)
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"subject.") # pragma: no cover
if self._cert['tbsCertificate']['subject'] is None:
self._cert['tbsCertificate']['subject'] = rfc2459.Name()
self._cert['tbsCertificate']['subject'][0] = val
def set_issuer(self, issuer):
"""Set the issuer name field value.
@ -208,20 +177,16 @@ class X509Certificate(object):
:param issuer: An X509Name object instance
"""
val = issuer._name_obj
ret = self._lib.X509_set_issuer_name(self._certObj, val)
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"issuer.") # pragma: no cover
if self._cert['tbsCertificate']['issuer'] is None:
self._cert['tbsCertificate']['issuer'] = rfc2459.Name()
self._cert['tbsCertificate']['issuer'][0] = val
def get_issuer(self):
"""Get the issuer name field value.
:return: An X509Name object instance
"""
val = self._lib.X509_get_issuer_name(self._certObj)
if val == self._ffi.NULL:
raise X509CertificateError("Could not get subject from X509 "
"certificate.") # pragma: no cover
val = self._cert['tbsCertificate']['issuer'][0]
return name.X509Name(val)
def set_serial_number(self, serial):
@ -232,14 +197,18 @@ class X509Certificate(object):
:param serial: The serial number, 32 bit integer
"""
asn1_int = self._lib.ASN1_INTEGER_new()
ret = self._lib.ASN1_INTEGER_set(asn1_int, serial)
if ret != 0:
ret = self._lib.X509_set_serialNumber(self._certObj, asn1_int)
self._lib.ASN1_INTEGER_free(asn1_int)
if ret == 0:
raise X509CertificateError("Could not set X509 certificate "
"serial number.") # pragma: no cover
self._cert['tbsCertificate']['serialNumber'] = serial
def _get_extensions(self):
if self._cert['tbsCertificate']['extensions'] is None:
# this actually initialises the extensions tag rather than
# assign None
self._cert['tbsCertificate']['extensions'] = None
return self._cert['tbsCertificate']['extensions']
def get_extensions(self):
extensions = self._get_extensions()
return [extension.construct_extension(e) for e in extensions]
def add_extension(self, ext, index):
"""Add an X509 V3 Certificate extension.
@ -247,10 +216,11 @@ class X509Certificate(object):
:param ext: An X509Extension instance
:param index: The index of the extension
"""
ret = self._lib.X509_add_ext(self._certObj, ext._ext, index)
if ret == 0:
raise X509CertificateError("Could not add X509 certificate "
"extension.") # pragma: no cover
if not isinstance(ext, extension.X509Extension):
raise errors.X509Error("ext needs to be a pyasn1 extension")
extensions = self._get_extensions()
extensions[index] = ext.as_asn1()
def sign(self, key, md='sha1'):
"""Sign the X509 certificate with a key using a message digest algorithm
@ -262,28 +232,44 @@ class X509Certificate(object):
- sha1
- sha256
"""
mda = getattr(self._lib, "EVP_%s" % md, None)
if mda is None:
msg = 'X509 signing error: Unknown algorithm {a}'.format(a=md)
raise X509CertificateError(msg)
ret = self._lib.X509_sign(self._certObj, key, mda())
if ret == 0:
raise X509CertificateError("X509 signing error: Could not sign "
" certificate.") # pragma: no cover
md = md.upper()
if isinstance(key, rsa.RSAPrivateKey):
encryption = 'RSA'
elif isinstance(key, dsa.DSAPrivateKey):
encryption = 'DSA'
else:
raise errors.X509Error("Unknown key type: %s" % (key.__class__,))
hash_class = utils.get_hash_class(md)
signature_type = SIGNING_ALGORITHMS.get((encryption, md))
if signature_type is None:
raise errors.X509Error(
"Unknown encryption/hash combination %s/%s" % (encryption, md))
algo_id = rfc2459.AlgorithmIdentifier()
algo_id['algorithm'] = signature_type
if encryption == 'RSA':
algo_id['parameters'] = encoder.encode(asn1_univ.Null())
elif encryption == 'DSA':
pass # parameters should be omitted, see RFC3279
self._cert['tbsCertificate']['signature'] = algo_id
to_sign = encoder.encode(self._cert['tbsCertificate'])
if encryption == 'RSA':
signer = key.signer(padding.PKCS1v15(), hash_class())
elif encryption == 'DSA':
signer = key.signer(hash_class())
signer.update(to_sign)
signature = signer.finalize()
self._cert['signatureValue'] = "'%s'B" % (
utils.bytes_to_bin(signature),)
self._cert['signatureAlgorithm'] = algo_id
def as_der(self):
"""Return this X509 certificate as DER encoded data."""
buf = None
num = self._lib.i2d_X509(self._certObj, self._ffi.NULL)
if num != 0:
buf = self._ffi.new("unsigned char[]", num + 1)
buf_ptr = self._ffi.new("unsigned char**")
buf_ptr[0] = buf
num = self._lib.i2d_X509(self._certObj, buf_ptr)
else:
raise X509CertificateError("Could not encode X509 certificate "
"as DER.") # pragma: no cover
return buf
return encoder.encode(self._cert)
def get_fingerprint(self, md='md5'):
"""Get the fingerprint of this X509 certificate.
@ -291,7 +277,11 @@ class X509Certificate(object):
:param md: The message digest algorthim used to compute the fingerprint
:return: The fingerprint encoded as a hex string
"""
der = self.as_der()
md = message_digest.MessageDigest(md)
md.update(der)
return md.final()
hash_class = utils.get_hash_class(md)
if hash_class is None:
raise errors.X509Error(
"Unknown hash %s" % (md,))
hasher = hashes.Hash(hash_class(),
backend=cio_backends.default_backend())
hasher.update(self.as_der())
return binascii.hexlify(hasher.finalize()).upper().decode('ascii')

318
anchor/X509/extension.py Normal file
View File

@ -0,0 +1,318 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
import functools
import netaddr
from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import constraint as asn1_constraint
from pyasn1.type import namedtype as asn1_namedtype
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import rfc2459 # X509v3
from anchor.X509 import errors
from anchor.X509 import utils
EXTENSION_NAMES = {
rfc2459.id_ce_policyConstraints: 'policyConstraints',
rfc2459.id_ce_basicConstraints: 'basicConstraints',
rfc2459.id_ce_subjectDirectoryAttributes: 'subjectDirectoryAttributes',
rfc2459.id_ce_deltaCRLIndicator: 'deltaCRLIndicator',
rfc2459.id_ce_cRLDistributionPoints: 'cRLDistributionPoints',
rfc2459.id_ce_issuingDistributionPoint: 'issuingDistributionPoint',
rfc2459.id_ce_nameConstraints: 'nameConstraints',
rfc2459.id_ce_certificatePolicies: 'certificatePolicies',
rfc2459.id_ce_policyMappings: 'policyMappings',
rfc2459.id_ce_privateKeyUsagePeriod: 'privateKeyUsagePeriod',
rfc2459.id_ce_keyUsage: 'keyUsage',
rfc2459.id_ce_authorityKeyIdentifier: 'authorityKeyIdentifier',
rfc2459.id_ce_subjectKeyIdentifier: 'subjectKeyIdentifier',
rfc2459.id_ce_certificateIssuer: 'certificateIssuer',
rfc2459.id_ce_subjectAltName: 'subjectAltName',
rfc2459.id_ce_issuerAltName: 'issuerAltName',
}
LONG_KEY_USAGE_NAMES = {
"Digital Signature": "digitalSignature",
"Non Repudiation": "nonRepudiation",
"Key Encipherment": "keyEncipherment",
"Data Encipherment": "dataEncipherment",
"Key Agreement": "keyAgreement",
"Certificate Sign": "keyCertSign",
"CRL Sign": "cRLSign",
"Encipher Only": "encipherOnly",
"Decipher Only": "decipherOnly",
}
def uses_ext_value(f):
"""Wrapper allowing reading of extension value.
Because the value is normally saved in a (double) serialised way, it's
not easily accessible to the member methods. This is made easier by
unpacking the extension value into an extra argument.
"""
@functools.wraps(f)
def ext_value_filled(self, *args, **kwargs):
kwargs['ext_value'] = self._get_value()
return f(self, *args, **kwargs)
return ext_value_filled
def modifies_ext_value(f):
"""Wrapper allowing modification of extension value.
Because the value is normally saved in a (double) serialised way, it's
not easily accessible to the member methods. This is made easier by
unpacking the extension value into an extra argument.
New value needs to be returned from the method.
"""
@functools.wraps(f)
def ext_value_filled(self, *args, **kwargs):
value = self._get_value()
kwargs['ext_value'] = value
# since some elements like NamedValue are pure value types, there is
# no interface to modify them and new versions have to be returned
value = f(self, *args, **kwargs)
self._set_value(value)
return ext_value_filled
class BasicConstraints(asn1_univ.Sequence):
"""Custom BasicConstraint implementation until pyasn1_modules is fixes."""
componentType = asn1_namedtype.NamedTypes(
asn1_namedtype.DefaultedNamedType('cA', asn1_univ.Boolean(False)),
asn1_namedtype.OptionalNamedType(
'pathLenConstraint',
asn1_univ.Integer().subtype(
subtypeSpec=asn1_constraint.ValueRangeConstraint(0, 64)))
)
class X509Extension(object):
"""Abstraction for the pyasn1 Extension structures.
The object should normally be constructed using `construct_extension`,
which will choose the right extension type based on the id.
Each extension has an immutable oid and a spec of the internal value
representation.
Unknown extension types can be still represented by the
X509Extension object and copied/serialised without understanding the
value details. The value will not be displayed properly in the logs
in the case.
"""
_oid = None
spec = None
"""An X509 V3 Certificate extension."""
def __init__(self, ext=None):
if ext is None:
if self.spec is None:
raise errors.X509Error("cannot create generic extension")
self._ext = rfc2459.Extension()
self._ext['extnID'] = self._oid
self._set_value(self._get_default_value())
else:
if not isinstance(ext, rfc2459.Extension):
raise errors.X509Error("extension has incorrect type")
self._ext = ext
@classmethod
def _get_default_value(cls):
# if there are any non-optional fields, this needs to be defined in
# the class
return cls.spec()
def __str__(self):
return "%s: %s" % (self.get_name(), self.get_value_as_str())
def get_value_as_str(self):
return "<unknown>"
def get_oid(self):
return self._ext['extnID']
def get_name(self):
"""Get the extension name as a python string."""
oid = self.get_oid()
return EXTENSION_NAMES.get(oid, oid)
def get_critical(self):
return self._ext['critical']
def set_critical(self, critical):
self._ext['critical'] = critical
def _get_value(self):
value_der = decoder.decode(self._ext['extnValue'])[0]
return decoder.decode(value_der, asn1Spec=self.spec())[0]
def _set_value(self, value):
if not isinstance(value, self.spec):
raise errors.X509Error("extension value has incorrect type")
self._ext['extnValue'] = encoder.encode(rfc2459.univ.OctetString(
encoder.encode(value)))
def as_der(self):
return encoder.encode(self._ext)
def as_asn1(self):
return self._ext
class X509ExtensionBasicConstraints(X509Extension):
spec = BasicConstraints
_oid = rfc2459.id_ce_basicConstraints
@uses_ext_value
def get_ca(self, ext_value=None):
return bool(ext_value['cA'])
@modifies_ext_value
def set_ca(self, ca, ext_value=None):
ext_value['cA'] = ca
return ext_value
@uses_ext_value
def get_path_len_constraint(self, ext_value=None):
return ext_value['pathLenConstraint']
@modifies_ext_value
def set_path_len_constraint(self, length, ext_value=None):
ext_value['pathLenConstraint'] = length
return ext_value
def __str__(self):
return "basicConstraints: CA: %s, pathLen: %s" % (
str(self.get_ca()).upper(), self.get_path_len_constraint())
class X509ExtensionKeyUsage(X509Extension):
spec = rfc2459.KeyUsage
_oid = rfc2459.id_ce_keyUsage
fields = dict(spec.namedValues.namedValues)
inv_fields = dict((v, k) for k, v in spec.namedValues.namedValues)
@classmethod
def _get_default_value(cls):
# if there are any non-optional fields, this needs to be defined in
# the class
return cls.spec("''B")
@uses_ext_value
def get_usage(self, usage, ext_value=None):
usage = LONG_KEY_USAGE_NAMES.get(usage, usage)
pos = self.fields[usage]
if pos >= len(ext_value):
return False
return bool(ext_value[pos])
@uses_ext_value
def get_all_usages(self, ext_value=None):
return [self.inv_fields[i] for i, enabled in enumerate(ext_value)
if enabled]
@modifies_ext_value
def set_usage(self, usage, state, ext_value=None):
usage = LONG_KEY_USAGE_NAMES.get(usage, usage)
pos = self.fields[usage]
values = [x for x in ext_value]
if state:
while pos >= len(values):
values.append(0)
values[pos] = 1
else:
if pos < len(values):
values[pos] = 0
bits = ''.join(str(x) for x in values)
return self.spec("'%s'B" % bits)
def __str__(self):
return "keyUsage: " + ", ".join(self.get_all_usages())
class X509ExtensionSubjectAltName(X509Extension):
spec = rfc2459.SubjectAltName
_oid = rfc2459.id_ce_subjectAltName
@uses_ext_value
def get_dns_ids(self, ext_value=None):
dns_ids = []
for name in ext_value:
if name.getName() != 'dNSName':
continue
component = name.getComponent()
dns_id = component.asOctets().decode(component.encoding)
dns_ids.append(dns_id)
return dns_ids
@uses_ext_value
def get_ips(self, ext_value=None):
ips = []
for name in ext_value:
if name.getName() != 'iPAddress':
continue
ips.append(utils.asn1_to_netaddr(name.getComponent()))
return ips
@modifies_ext_value
def add_dns_id(self, dns_id, ext_value=None):
# TODO(stan) validate dns_id
new_pos = len(ext_value)
ext_value[new_pos] = None
ext_value[new_pos]['dNSName'] = dns_id
return ext_value
@modifies_ext_value
def add_ip(self, ip, ext_value=None):
if not isinstance(ip, netaddr.IPAddress):
raise errors.X509Error("not a real ip address provided")
new_pos = len(ext_value)
ext_value[new_pos] = None
ext_value[new_pos]['iPAddress'] = utils.netaddr_to_asn1(ip)
return ext_value
@uses_ext_value
def __str__(self, ext_value=None):
entries = ["DNS:%s" % (x,) for x in self.get_dns_ids()]
entries += ["IP:%s" % (x,) for x in self.get_ips()]
return "subjectAltName: " + ", ".join(entries)
EXTENSION_CLASSES = {
rfc2459.id_ce_basicConstraints: X509ExtensionBasicConstraints,
rfc2459.id_ce_keyUsage: X509ExtensionKeyUsage,
rfc2459.id_ce_subjectAltName: X509ExtensionSubjectAltName,
}
def construct_extension(ext):
"""Construct an extension object of the right type.
While X509Extension can provide basic access to the extension elements,
it cannot parse details of extensions. This function detects which type
should be used based on the extension id.
If the type is unknown, generic X509Extension is used instead.
"""
if not isinstance(ext, rfc2459.Extension):
raise errors.X509Error("extension has incorrect type")
ext_class = EXTENSION_CLASSES.get(ext['extnID'], X509Extension)
return ext_class(ext)

View File

@ -1,93 +0,0 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend
import binascii
class MessageDigestError(Exception):
def __init__(self, what):
super(MessageDigestError, self).__init__(what)
class MessageDigest(object):
"""Compute a message digest from input data."""
@staticmethod
def getValidAlgorithms():
"""Get a list of available valid hash algorithms."""
algs = [
"md5",
"ripemd160",
"sha224",
"sha256",
"sha384",
"sha512"
]
ret = []
for alg in algs:
if getattr(backend._lib, "EVP_%s" % alg, None) is not None:
ret.append(alg)
return ret
def __init__(self, algo):
self._lib = backend._lib
self._ffi = backend._ffi
md = getattr(self._lib, "EVP_%s" % algo, None)
if md is None:
msg = 'MessageDigest error: unknown algorithm {a}'.format(a=algo)
raise MessageDigestError(msg)
ret = 0
ctx = self._lib.EVP_MD_CTX_create()
if ctx != self._ffi.NULL:
self.ctx = ctx
self.mda = md()
ret = self._lib.EVP_DigestInit_ex(self.ctx,
self.mda,
self._ffi.NULL)
if ret == 0:
raise MessageDigestError(
"Could not setup message digest context.") # pragma: no cover
def __del__(self):
if getattr(self, 'ctx', None):
self._lib.EVP_MD_CTX_cleanup(self.ctx)
self._lib.EVP_MD_CTX_destroy(self.ctx)
def update(self, data):
"""Add more data to the digest."""
ret = self._lib.EVP_DigestUpdate(self.ctx, data, len(data))
if ret == 0:
raise MessageDigestError(
"Failed to update message digest data.") # pragma: no cover
def final(self):
"""get the final resulting digest value.
Note that you should not call update() with additional data after using
final.
"""
sz = self._lib.EVP_MD_size(self.mda)
data = self._ffi.new("char[]", sz)
ret = self._lib.EVP_DigestFinal_ex(self.ctx, data, self._ffi.NULL)
if ret == 0:
raise MessageDigestError(
"Failed to get message digest.") # pragma: no cover
digest = self._ffi.string(data)
return binascii.hexlify(digest).decode('ascii').upper()

View File

@ -13,21 +13,64 @@
from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend
from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import error as asn1_error
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import rfc2459
from anchor.X509 import errors
from anchor.X509 import utils
OID_commonName = rfc2459.id_at_commonName
OID_localityName = rfc2459.id_at_localityName
OID_stateOrProvinceName = rfc2459.id_at_stateOrProvinceName
OID_organizationName = rfc2459.id_at_organizationName
OID_organizationalUnitName = rfc2459.id_at_organizationalUnitName
OID_countryName = rfc2459.id_at_countryName
OID_pkcs9_emailAddress = rfc2459.emailAddress
OID_surname = rfc2459.id_at_sutname
OID_givenName = rfc2459.id_at_givenName
NID_countryName = backend._lib.NID_countryName
NID_stateOrProvinceName = backend._lib.NID_stateOrProvinceName
NID_localityName = backend._lib.NID_localityName
NID_organizationName = backend._lib.NID_organizationName
NID_organizationalUnitName = backend._lib.NID_organizationalUnitName
NID_commonName = backend._lib.NID_commonName
NID_pkcs9_emailAddress = backend._lib.NID_pkcs9_emailAddress
NID_surname = backend._lib.NID_surname
NID_givenName = backend._lib.NID_givenName
name_oids = {
rfc2459.id_at_name: rfc2459.X520name,
rfc2459.id_at_sutname: rfc2459.X520name,
rfc2459.id_at_givenName: rfc2459.X520name,
rfc2459.id_at_initials: rfc2459.X520name,
rfc2459.id_at_generationQualifier: rfc2459.X520name,
rfc2459.id_at_commonName: rfc2459.X520CommonName,
rfc2459.id_at_localityName: rfc2459.X520LocalityName,
rfc2459.id_at_stateOrProvinceName: rfc2459.X520StateOrProvinceName,
rfc2459.id_at_organizationName: rfc2459.X520OrganizationName,
rfc2459.id_at_organizationalUnitName: rfc2459.X520OrganizationalUnitName,
rfc2459.id_at_title: rfc2459.X520Title,
rfc2459.id_at_dnQualifier: rfc2459.X520dnQualifier,
rfc2459.id_at_countryName: rfc2459.X520countryName,
rfc2459.emailAddress: rfc2459.Pkcs9email,
}
code_names = {
rfc2459.id_at_commonName: "CN",
rfc2459.id_at_localityName: "L",
rfc2459.id_at_stateOrProvinceName: "ST",
rfc2459.id_at_organizationName: "O",
rfc2459.id_at_organizationalUnitName: "OU",
rfc2459.id_at_countryName: "C",
rfc2459.id_at_givenName: "GN",
rfc2459.id_at_sutname: "SN",
rfc2459.emailAddress: "emailAddress",
}
short_names = {
rfc2459.id_at_commonName: "commonName",
rfc2459.id_at_localityName: "localityName",
rfc2459.id_at_stateOrProvinceName: "stateOrProvinceName",
rfc2459.id_at_organizationName: "organizationName",
rfc2459.id_at_organizationalUnitName: "organizationalUnitName",
rfc2459.id_at_countryName: "countryName",
rfc2459.id_at_givenName: "givenName",
rfc2459.id_at_sutname: "surname",
rfc2459.emailAddress: "emailAddress",
}
class X509Name(object):
@ -35,104 +78,90 @@ class X509Name(object):
class Entry():
"""An X509 Name sub-entry object."""
def __init__(self, obj, parent):
self._parent = parent
self._lib = backend._lib
self._ffi = backend._ffi
self._entry = obj
def __init__(self, obj):
self._obj = obj
def __str__(self):
return "%s: %s" % (self.get_name(), self.get_value())
def get_oid(self):
return self._obj[0]['type']
def get_name(self):
"""Get the name of this entry.
:return: entry name as a python string
"""
asn1_obj = self._lib.X509_NAME_ENTRY_get_object(self._entry)
buf = self._ffi.new('char[]', 1024)
ret = self._lib.OBJ_obj2txt(buf, 1024, asn1_obj, 0)
if ret == 0:
raise errors.X509Error("Could not convert ASN1_OBJECT to "
"string.") # pragma: no cover
return self._ffi.string(buf).decode('ascii')
oid = self.get_oid()
return short_names.get(oid, str(oid))
def get_code(self):
"""Get the name of this entry.
:return: entry name as a python string
"""
oid = self.get_oid()
return code_names.get(oid, str(oid))
def get_value(self):
"""Get the value of this entry.
:return: entry value as a python string
"""
val = self._lib.X509_NAME_ENTRY_get_data(self._entry)
return utils.asn1_string_to_utf8(val)
value = self._obj[0]['value']
der = value.asOctets()
name_spec = name_oids[self.get_oid()]()
value = decoder.decode(der, asn1Spec=name_spec)[0]
if hasattr(value, 'getComponent'):
value = value.getComponent()
return value.asOctets().decode(value.encoding)
def __init__(self, name_obj=None):
self._lib = backend._lib
self._ffi = backend._ffi
if name_obj is not None:
self._name_obj = self._lib.X509_NAME_dup(name_obj)
if self._name_obj == self._ffi.NULL:
raise errors.X509Error("Failed to copy X509_NAME "
"object.") # pragma: no cover
if not isinstance(name_obj, rfc2459.RDNSequence):
raise TypeError("name is not an RDNSequence")
# TODO(stan): actual copy
self._name_obj = name_obj
else:
self._name_obj = self._lib.X509_NAME_new()
if self._name_obj == self._ffi.NULL:
raise errors.X509Error("Failed to create "
"X509_NAME object.") # pragma: no cover
def __del__(self):
self._lib.X509_NAME_free(self._name_obj)
self._name_obj = rfc2459.RDNSequence()
def __str__(self):
# NOTE(tkelsey): we need to pass in a max size, so why not 1024
val = self._lib.X509_NAME_oneline(self._name_obj, self._ffi.NULL, 1024)
if val == self._ffi.NULL:
raise errors.X509Error("Could not convert"
" X509_NAME to string.") # pragma: no cover
val = self._ffi.gc(val, self._lib.OPENSSL_free)
return self._ffi.string(val).decode('ascii')
return '/' + '/'.join("%s=%s" % (e.get_code(), e.get_value())
for e in self)
def __len__(self):
return self._lib.X509_NAME_entry_count(self._name_obj)
return len(self._name_obj)
def __getitem__(self, idx):
if not (0 <= idx < self.entry_count()):
raise IndexError("index out of range")
ent = self._lib.X509_NAME_get_entry(self._name_obj, idx)
return X509Name.Entry(ent, self)
return X509Name.Entry(self._name_obj[idx])
def __iter__(self):
for i in range(self.entry_count()):
for i in range(len(self)):
yield self[i]
def add_name_entry(self, nid, text):
"""Add a name entry by its NID name."""
ret = self._lib.X509_NAME_add_entry_by_NID(
self._name_obj, nid,
self._lib.MBSTRING_UTF8,
text.encode('utf8'), -1, -1, 0)
def add_name_entry(self, oid, text):
if not isinstance(oid, asn1_univ.ObjectIdentifier):
raise errors.X509Error("oid '%s' is not valid" % (oid,))
entry = rfc2459.RelativeDistinguishedName()
entry[0] = rfc2459.AttributeTypeAndValue()
entry[0]['type'] = oid
name_type = name_oids[oid]
try:
if name_type in (rfc2459.X520countryName, rfc2459.Pkcs9email):
val = name_type(text)
else:
val = name_type()
val['utf8String'] = text
except asn1_error.ValueConstraintError:
raise errors.X509Error("Name '%s' is not valid" % text)
entry[0]['value'] = rfc2459.AttributeValue(encoder.encode(val))
self._name_obj[len(self)] = entry
if ret != 1:
raise errors.X509Error("Failed to add name entry: '%s' '%s'" % (
nid, text))
def entry_count(self):
"""Get the number of entries in the name object."""
return self._lib.X509_NAME_entry_count(self._name_obj)
def get_entries_by_nid(self, nid):
def get_entries_by_oid(self, oid):
"""Get a name entry corresponding to an NID name.
:param nid: an NID for the new name entry
:return: An X509Name.Entry object
"""
out = []
idx = self._lib.X509_NAME_get_index_by_NID(self._name_obj, nid, -1)
while idx != -1:
val = self._lib.X509_NAME_get_entry(self._name_obj, idx)
if val != self._ffi.NULL:
out.append(X509Name.Entry(val, self))
idx = self._lib.X509_NAME_get_index_by_NID(self._name_obj,
nid, idx)
return out
return [entry for entry in self if entry.get_oid() == oid]

View File

@ -13,13 +13,23 @@
from __future__ import absolute_import
from cryptography.hazmat.backends.openssl import backend
import io
from pyasn1.codec.der import decoder
from pyasn1.codec.der import encoder
from pyasn1.type import univ as asn1_univ
from pyasn1_modules import pem
from pyasn1_modules import rfc2314 # PKCS#10 / CSR
from pyasn1_modules import rfc2459 # X509
from anchor.X509 import certificate
from anchor.X509 import errors
from anchor.X509 import extension
from anchor.X509 import name
OID_extensionRequest = asn1_univ.ObjectIdentifier('1.2.840.113549.1.9.14')
class X509CsrError(errors.X509Error):
def __init__(self, what):
super(X509CsrError, self).__init__(what)
@ -27,84 +37,108 @@ class X509CsrError(errors.X509Error):
class X509Csr(object):
"""An X509 Certificate Signing Request."""
def __init__(self):
self._lib = backend._lib
self._ffi = backend._ffi
csrObj = self._lib.X509_REQ_new()
if csrObj == self._ffi.NULL:
raise X509CsrError(
"Could not create X509 CSR Object.") # pragma: no cover
def __init__(self, csr=None):
if csr is None:
self._csr = rfc2314.CertificationRequest()
else:
self._csr = csr
self._csrObj = csrObj
@staticmethod
def from_open_file(f):
try:
der_content = pem.readPemFromFile(
f, startMarker='-----BEGIN CERTIFICATE REQUEST-----',
endMarker='-----END CERTIFICATE REQUEST-----')
csr = decoder.decode(der_content,
asn1Spec=rfc2314.CertificationRequest())[0]
return X509Csr(csr)
except Exception:
raise X509CsrError("Could not read X509 certificate from "
"PEM data.")
def __del__(self):
if getattr(self, '_csrObj', None):
self._lib.X509_REQ_free(self._csrObj)
def from_buffer(self, data, password=None):
@staticmethod
def from_buffer(data):
"""Create this CSR from a buffer
:param data: The data buffer
:param password: decryption password, if needed
"""
if type(data) != bytes:
data = data.encode('ascii')
bio = backend._bytes_to_bio(data)
ptr = self._ffi.new("X509_REQ **")
ptr[0] = self._csrObj
ret = self._lib.PEM_read_bio_X509_REQ(bio[0], ptr,
self._ffi.NULL,
self._ffi.NULL)
if ret == self._ffi.NULL:
raise X509CsrError("Could not read X509 CSR from PEM data.")
return X509Csr.from_open_file(io.StringIO(data))
def from_file(self, path, password=None):
@staticmethod
def from_file(path):
"""Create this CSR from a file on disk
:param path: Path to the file on disk
:param password: decryption password, if needed
"""
data = None
with open(path, 'rb') as f:
data = f.read()
self.from_buffer(data, password)
with open(path, 'r') as f:
return X509Csr.from_open_file(f)
def get_pubkey(self):
"""Get the public key from the CSR
:return: an OpenSSL EVP_PKEY object
"""
pkey = self._lib.X509_REQ_get_pubkey(self._csrObj)
if pkey == self._ffi.NULL:
raise X509CsrError(
"Could not get pubkey from X509 CSR.") # pragma: no cover
return self._csr['certificationRequestInfo']['subjectPublicKeyInfo']
return pkey
def get_request_info(self):
if self._csr['certificationRequestInfo'] is None:
self._csr['certificationRequestInfo'] = None
return self._csr['certificationRequestInfo']
def get_subject(self):
"""Get the subject name field from the CSR
:return: an X509Name object
"""
subs = self._lib.X509_REQ_get_subject_name(self._csrObj)
if subs == self._ffi.NULL:
raise X509CsrError(
"Could not get subject from X509 CSR.") # pragma: no cover
ri = self.get_request_info()
if ri['subject'] is None:
ri['subject'] = None
# setup first RDN sequence
ri['subject'][0] = None
return name.X509Name(subs)
subject = ri['subject'][0]
return name.X509Name(subject)
def get_extensions(self):
def get_attributes(self):
ri = self.get_request_info()
if ri['attributes'] is None:
ri['attributes'] = None
return ri['attributes']
def get_extensions(self, ext_type=None):
"""Get the list of all X509 V3 Extensions on this CSR
:return: a list of X509Extension objects
"""
# TODO(tkelsey): I assume the ext list copies data and this is safe
# TODO(tkelsey): Error checking needed here
ret = []
exts = self._lib.X509_REQ_get_extensions(self._csrObj)
num = self._lib.sk_X509_EXTENSION_num(exts)
for i in range(0, num):
ext = self._lib.sk_X509_EXTENSION_value(exts, i)
ret.append(certificate.X509Extension(ext))
self._lib.sk_X509_EXTENSION_free(exts)
return ret
ext_attrs = [a for a in self.get_attributes()
if a['type'] == OID_extensionRequest]
if len(ext_attrs) == 0:
return []
else:
exts_der = ext_attrs[0]['vals'][0].asOctets()
exts = decoder.decode(exts_der, asn1Spec=rfc2459.Extensions())[0]
return [extension.construct_extension(e) for e in exts
if ext_type is None or e['extnID'] == ext_type._oid]
def add_extension(self, ext):
if not isinstance(ext, extension.X509Extension):
raise errors.X509Error("ext is not an anchor X509Extension")
attributes = self.get_attributes()
ext_attrs = [a for a in attributes
if a['type'] == OID_extensionRequest]
if not ext_attrs:
new_attr_index = len(attributes)
attributes[new_attr_index] = None
ext_attr = attributes[new_attr_index]
ext_attr['type'] = OID_extensionRequest
ext_attr['vals'] = None
exts = rfc2459.Extensions()
else:
ext_attr = ext_attrs[0]
exts = decoder.decode(ext_attr['vals'][0].asOctets(),
asn1Spec=rfc2459.Extensions())[0]
new_ext_index = len(exts)
exts[new_ext_index] = ext._ext
ext_attr['vals'][0] = encoder.encode(exts)

View File

@ -15,38 +15,18 @@ from __future__ import absolute_import
import calendar
import datetime
import struct
from cryptography.hazmat.backends.openssl import backend
from cryptography.hazmat import backends
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
import netaddr
from pyasn1.type import useful as asn1_useful
from pyasn1_modules import rfc2459
from anchor.X509 import errors
def load_pem_private_key(key_data, passwd=None):
"""Load and return an OpenSSL EVP_PKEY public key object from a data buffer
:param key_data: The data buffer
:param passwd: Decryption password if neded (not used for now)
:return: an OpenSSL EVP_PKEY public key object
"""
# TODO(tkelsey): look at using backend.read_private_key
#
if type(key_data) != bytes:
key_data = key_data.encode('ascii')
lib = backend._lib
ffi = backend._ffi
data = backend._bytes_to_bio(key_data)
evp_pkey = lib.EVP_PKEY_new()
evp_pkey_ptr = ffi.new("EVP_PKEY**")
evp_pkey_ptr[0] = evp_pkey
evp_pkey = lib.PEM_read_bio_PrivateKey(data[0], evp_pkey_ptr,
ffi.NULL, ffi.NULL)
evp_pkey = ffi.gc(evp_pkey, lib.EVP_PKEY_free)
return evp_pkey
def create_timezone(minute_offset):
"""Create a new timezone with a specified offset.
@ -80,18 +60,17 @@ def asn1_time_to_timestamp(t):
:param t: ASN1_TIME to convert
"""
gen_time = backend._lib.ASN1_TIME_to_generalizedtime(t, backend._ffi.NULL)
if gen_time == backend._ffi.NULL:
raise errors.ASN1TimeError("time conversion failure")
try:
return asn1_generalizedtime_to_timestamp(gen_time)
finally:
backend._lib.ASN1_GENERALIZEDTIME_free(gen_time)
component = t.getComponent()
timestring = component.asOctets().decode(component.encoding)
if isinstance(component, asn1_useful.UTCTime):
if int(timestring[0]) >= 5:
timestring = "19" + timestring
else:
timestring = "20" + timestring
return asn1_timestring_to_timestamp(timestring)
def asn1_generalizedtime_to_timestamp(gt):
def asn1_timestring_to_timestamp(timestring):
"""Convert from ASN1_GENERALIZEDTIME to UTC-based timestamp.
:param gt: ASN1_GENERALIZEDTIME to convert
@ -99,11 +78,8 @@ def asn1_generalizedtime_to_timestamp(gt):
# ASN1_GENERALIZEDTIME is actually a string in known formats,
# so the conversion can be done in this code
string_time = backend._ffi.cast("ASN1_STRING*", gt)
res = asn1_string_to_utf8(string_time)
before_tz = res[:14]
tz_str = res[14:]
before_tz = timestring[:14]
tz_str = timestring[14:]
d = datetime.datetime.strptime(before_tz, "%Y%m%d%H%M%S")
if tz_str == 'Z':
# YYYYMMDDhhmmssZ
@ -126,25 +102,85 @@ def timestamp_to_asn1_time(t):
"""
d = datetime.datetime.utcfromtimestamp(t)
# use the ASN1_GENERALIZEDTIME format
time_str = d.strftime("%Y%m%d%H%M%SZ").encode('ascii')
asn1_time = backend._lib.ASN1_STRING_type_new(
backend._lib.V_ASN1_GENERALIZEDTIME)
backend._lib.ASN1_STRING_set(asn1_time, time_str, len(time_str))
asn1_gentime = backend._ffi.cast("ASN1_GENERALIZEDTIME*", asn1_time)
if backend._lib.ASN1_GENERALIZEDTIME_check(asn1_gentime) == 0:
raise errors.ASN1TimeError("timestamp not accepted by ASN1 check")
# ASN1_GENERALIZEDTIME is a form of ASN1_TIME, so a pointer cast is valid
return backend._ffi.cast("ASN1_TIME*", asn1_time)
asn1time = rfc2459.Time()
if d.year <= 2049:
time_str = d.strftime("%y%m%d%H%M%SZ").encode('ascii')
asn1time['utcTime'] = time_str
else:
time_str = d.strftime("%Y%m%d%H%M%SZ").encode('ascii')
asn1time['generalTime'] = time_str
return asn1time
def asn1_string_to_utf8(asn1_string):
buf = backend._ffi.new("unsigned char **")
res = backend._lib.ASN1_STRING_to_UTF8(buf, asn1_string)
if res < 0 or buf[0] == backend._ffi.NULL:
raise errors.ASN1StringError("cannot convert asn1 to python string")
buf = backend._ffi.gc(
buf, lambda buffer: backend._lib.OPENSSL_free(buffer[0])
)
return backend._ffi.buffer(buf[0], res)[:].decode('utf8')
# functions needed for converting the pyasn1 signature fields
def bin_to_bytes(bits):
"""Convert bit string to byte string."""
bits = ''.join(str(b) for b in bits)
bits = _pad_byte(bits)
octets = [bits[8*i:8*(i+1)] for i in range(len(bits)/8)]
bytes = [chr(int(x, 2)) for x in octets]
return "".join(bytes)
# ord good for py2 and py3
local_ord = ord if str is bytes else lambda x: x
def _pad_byte(bits):
"""Pad a string of bits with zeros to make its length a multiple of 8."""
r = len(bits) % 8
return ((8-r) % 8)*'0' + bits
def bytes_to_bin(bytes):
"""Convert byte string to bit string."""
return "".join([_pad_byte(_int_to_bin(local_ord(byte))) for byte in bytes])
def _int_to_bin(n):
if n == 0 or n == 1:
return str(n)
elif n % 2 == 0:
return _int_to_bin(n // 2) + "0"
else:
return _int_to_bin(n // 2) + "1"
def get_hash_class(md):
return getattr(hashes, md.upper(), None)
def get_private_key_from_bytes(data):
key = serialization.load_pem_private_key(
data, None, backend=backends.default_backend())
return key
def get_private_key_from_file(path):
with open(path, 'rb') as f:
return get_private_key_from_bytes(f.read())
def asn1_to_netaddr(octet_string):
"""Translate the ASN1 IP format to netaddr object."""
if not isinstance(octet_string, rfc2459.univ.OctetString):
raise TypeError("not an OctetString")
ip_bytes = octet_string.asOctets()
if len(ip_bytes) == 4:
ip_num = struct.unpack(">I", ip_bytes)[0]
return netaddr.IPAddress(ip_num, 4)
elif len(ip_bytes) == 16:
ip_num_front, ip_num_back = struct.unpack(">QQ", ip_bytes)
ip_num = ip_num_front << 64 | ip_num_back
return netaddr.IPAddress(ip_num, 6)
else:
raise TypeError("ip address is neither v4 nor v6")
def netaddr_to_asn1(ip):
"""Translate the netaddr object to ASN1 IP format."""
if not isinstance(ip, netaddr.IPAddress):
raise errors.X509Error("not a real ip address provided")
return bytes(ip.packed)

View File

@ -25,7 +25,7 @@ from anchor import jsonloader
from anchor import validators
from anchor.X509 import certificate
from anchor.X509 import signing_request
from anchor.X509 import utils as X509_utils
from anchor.X509 import utils
logger = logging.getLogger(__name__)
@ -54,8 +54,7 @@ def parse_csr(csr, encoding):
# load the CSR into the backend X509 library
try:
out_req = signing_request.X509Csr()
out_req.from_buffer(csr)
out_req = signing_request.X509Csr.from_buffer(csr)
return out_req
except Exception as e:
logger.exception("Exception while parsing the CSR: %s", e)
@ -132,17 +131,14 @@ def sign(csr):
:param csr: X509 certificate signing request
"""
try:
ca = certificate.X509Certificate()
ca.from_file(jsonloader.conf.ca["cert_path"])
ca = certificate.X509Certificate.from_file(
jsonloader.conf.ca["cert_path"])
except Exception as e:
logger.exception("Cannot load the signing CA: %s", e)
pecan.abort(500, "certificate signing error")
try:
key_data = None
with open(jsonloader.conf.ca["key_path"]) as f:
key_data = f.read()
key = X509_utils.load_pem_private_key(key_data)
key = utils.get_private_key_from_file(jsonloader.conf.ca['key_path'])
except Exception as e:
logger.exception("Cannot load the signing CA key: %s", e)
pecan.abort(500, "certificate signing error")
@ -182,7 +178,7 @@ def sign(csr):
cert_pem = new_cert.as_pem()
with open(path, "wb") as f:
with open(path, "w") as f:
f.write(cert_pem)
return cert_pem

View File

@ -17,6 +17,7 @@ import logging
import netaddr
from anchor.X509 import extension
from anchor.X509 import name as x509_name
@ -29,7 +30,7 @@ class ValidationError(Exception):
def csr_get_cn(csr):
name = csr.get_subject()
data = name.get_entries_by_nid(x509_name.NID_commonName)
data = name.get_entries_by_oid(x509_name.OID_commonName)
if len(data) > 0:
return data[0].get_value()
else:
@ -51,19 +52,14 @@ def check_domains(domain, allowed_domains):
def iter_alternative_names(csr, types, fail_other_types=True):
for ext in csr.get_extensions():
if ext.get_name() == "subjectAltName":
alternatives = [alt.strip() for alt in ext.get_value().split(',')]
for alternative in alternatives:
parts = alternative.split(':', 1)
if len(parts) != 2:
# it has at least one part, so parts[0] is valid
raise ValidationError("Alt name should have 2 parts, but "
"found: '%s'" % parts[0])
if parts[0] in types:
yield parts
elif fail_other_types:
raise ValidationError("Alt name '%s' has unexpected type "
"'%s'" % (parts[1], parts[0]))
if isinstance(ext, extension.X509ExtensionSubjectAltName):
# TODO(stan): fail on other types
if 'DNS' in types:
for dns_id in ext.get_dns_ids():
yield ('DNS', dns_id)
if 'IP Address' in types:
for ip in ext.get_ips():
yield ('IP Address', ip)
def check_networks(ip, allowed_networks):
@ -91,15 +87,15 @@ def common_name(csr, allowed_domains=[], allowed_networks=[], **kwargs):
alt_present = any(ext.get_name() == "subjectAltName"
for ext in csr.get_extensions())
CNs = csr.get_subject().get_entries_by_nid(x509_name.NID_commonName)
CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName)
if len(CNs) > 1:
raise ValidationError("Too many CNs in the request")
if not alt_present:
# rfc5280#section-4.2.1.6 says so
if len(CNs) == 0:
raise ValidationError("Alt subjects have to exist if the main"
" subject doesn't")
# rfc5280#section-4.2.1.6 says so
if len(CNs) == 0 and not alt_present:
raise ValidationError("Alt subjects have to exist if the main"
" subject doesn't")
if len(CNs) > 0:
cn = csr_get_cn(csr)
@ -122,7 +118,7 @@ def alternative_names(csr, allowed_domains=[], **kwargs):
the list of known suffixes, or network ranges.
"""
for name_type, name in iter_alternative_names(csr, ['DNS']):
for _, name in iter_alternative_names(csr, ['DNS']):
if not check_domains(name, allowed_domains):
raise ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)"
@ -142,9 +138,8 @@ def alternative_names_ip(csr, allowed_domains=[], allowed_networks=[],
raise ValidationError("Domain '%s' not allowed (doesn't"
" match known domains)" % name)
if name_type == 'IP Address':
ip = netaddr.IPAddress(name)
if not check_networks(ip, allowed_networks):
raise ValidationError("Address '%s' not allowed (doesn't"
if not check_networks(name, allowed_networks):
raise ValidationError("IP '%s' not allowed (doesn't"
" match known networks)" % name)
@ -156,7 +151,7 @@ def blacklist_names(csr, domains=[], **kwargs):
"consider disabling the step or providing a list")
return
CNs = csr.get_subject().get_entries_by_nid(x509_name.NID_commonName)
CNs = csr.get_subject().get_entries_by_oid(x509_name.OID_commonName)
if len(CNs) > 0:
cn = csr_get_cn(csr)
if check_domains(cn, domains):
@ -198,45 +193,41 @@ def extensions(csr=None, allowed_extensions=[], **kwargs):
def key_usage(csr=None, allowed_usage=None, **kwargs):
"""Ensure only accepted key usages are specified."""
allowed = set(allowed_usage)
allowed = set(extension.LONG_KEY_USAGE_NAMES.get(x, x) for x in
allowed_usage)
denied = set()
for ext in (csr.get_extensions() or []):
if ext.get_name() == 'keyUsage':
usages = set(usage.strip() for usage in ext.get_value().split(','))
if usages & allowed != usages:
raise ValidationError("Found some not allowed key usages: %s"
% ', '.join(usages - allowed))
if isinstance(ext, extension.X509ExtensionKeyUsage):
usages = set(ext.get_all_usages())
denied = denied | (usages - allowed)
if denied:
raise ValidationError("Found some not allowed key usages: %s"
% ', '.join(denied))
def ca_status(csr=None, ca_requested=False, **kwargs):
"""Ensure the request has/hasn't got the CA flag."""
request_ca_flags = False
for ext in (csr.get_extensions() or []):
ext_name = ext.get_name()
if ext_name == 'basicConstraints':
options = [opt.strip() for opt in ext.get_value().split(",")]
for option in options:
parts = option.split(":")
if len(parts) != 2:
raise ValidationError("Invalid basic constraints flag")
if parts[0] == 'CA':
if parts[1] != str(ca_requested).upper():
raise ValidationError("Invalid CA status, 'CA:%s'"
" requested" % parts[1])
elif parts[0] == 'pathlen':
# errr.. it's ok, I guess
pass
else:
raise ValidationError("Invalid basic constraints option")
elif ext_name == 'keyUsage':
usages = set(usage.strip() for usage in ext.get_value().split(','))
has_cert_sign = ('Certificate Sign' in usages)
has_crl_sign = ('CRL Sign' in usages)
if ca_requested != has_cert_sign or ca_requested != has_crl_sign:
raise ValidationError("Key usage doesn't match requested CA"
" status (keyCertSign/cRLSign: %s/%s)"
% (has_cert_sign, has_crl_sign))
if isinstance(ext, extension.X509ExtensionBasicConstraints):
if ext.get_ca():
if not ca_requested:
raise ValidationError(
"CA status requested, but not allowed")
request_ca_flags = True
elif isinstance(ext, extension.X509ExtensionKeyUsage):
has_cert_sign = ext.get_usage('keyCertSign')
has_crl_sign = ext.get_usage('cRLSign')
if has_crl_sign or has_cert_sign:
if not ca_requested:
raise ValidationError(
"Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: %s/%s)"
% (has_cert_sign, has_crl_sign))
request_ca_flags = True
if ca_requested and not request_ca_flags:
raise ValidationError("CA flags required")
def source_cidrs(request=None, cidrs=None, **kwargs):

View File

@ -2,6 +2,8 @@
# of appearance. Changing the order has an impact on the overall integration
# process, which may cause wedges in the gate later.
cryptography>=0.9.1 # Apache-2.0
pyasn1
pyasn1_modules
pecan>=0.8.0
Paste
netaddr>=0.7.12

View File

@ -0,0 +1,144 @@
# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
import netaddr
from pyasn1_modules import rfc2459 # X509v3
from anchor.X509 import errors
from anchor.X509 import extension
class TestExtensionBase(unittest.TestCase):
def test_no_spec(self):
with self.assertRaises(errors.X509Error):
extension.X509Extension()
def test_invalid_asn(self):
with self.assertRaises(errors.X509Error):
extension.X509Extension("foobar")
def test_unknown_extension_str(self):
asn1 = rfc2459.Extension()
asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4')
asn1['critical'] = False
asn1['extnValue'] = "foobar"
ext = extension.X509Extension(asn1)
self.assertEqual("1.2.3.4: <unknown>", str(ext))
def test_construct(self):
asn1 = rfc2459.Extension()
asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4')
asn1['critical'] = False
asn1['extnValue'] = "foobar"
ext = extension.construct_extension(asn1)
self.assertIsInstance(ext, extension.X509Extension)
def test_construct_invalid_type(self):
with self.assertRaises(errors.X509Error):
extension.construct_extension("foobar")
def test_critical(self):
asn1 = rfc2459.Extension()
asn1['extnID'] = rfc2459.univ.ObjectIdentifier('1.2.3.4')
asn1['critical'] = False
asn1['extnValue'] = "foobar"
ext = extension.construct_extension(asn1)
self.assertFalse(ext.get_critical())
ext.set_critical(True)
self.assertTrue(ext.get_critical())
class TestBasicConstraints(unittest.TestCase):
def setUp(self):
self.ext = extension.X509ExtensionBasicConstraints()
def test_str(self):
self.assertEqual(str(self.ext),
"basicConstraints: CA: FALSE, pathLen: None")
def test_ca(self):
self.ext.set_ca(True)
self.assertTrue(self.ext.get_ca())
self.ext.set_ca(False)
self.assertFalse(self.ext.get_ca())
def test_pathlen(self):
self.ext.set_path_len_constraint(1)
self.assertEqual(1, self.ext.get_path_len_constraint())
class TestKeyUsage(unittest.TestCase):
def setUp(self):
self.ext = extension.X509ExtensionKeyUsage()
def test_usage_set(self):
self.ext.set_usage('digitalSignature', True)
self.ext.set_usage('keyAgreement', False)
self.assertTrue(self.ext.get_usage('digitalSignature'))
self.assertFalse(self.ext.get_usage('keyAgreement'))
def test_usage_reset(self):
self.ext.set_usage('digitalSignature', True)
self.ext.set_usage('digitalSignature', False)
self.assertFalse(self.ext.get_usage('digitalSignature'))
def test_usage_unset(self):
self.assertFalse(self.ext.get_usage('keyAgreement'))
def test_get_all_usage(self):
self.ext.set_usage('digitalSignature', True)
self.ext.set_usage('keyAgreement', False)
self.ext.set_usage('keyEncipherment', True)
self.assertEqual(set(['digitalSignature', 'keyEncipherment']),
set(self.ext.get_all_usages()))
def test_str(self):
self.ext.set_usage('digitalSignature', True)
self.assertEqual("keyUsage: digitalSignature", str(self.ext))
class TestSubjectAltName(unittest.TestCase):
def setUp(self):
self.ext = extension.X509ExtensionSubjectAltName()
self.domain = 'some.domain'
self.ip = netaddr.IPAddress('1.2.3.4')
self.ip6 = netaddr.IPAddress('::1')
def test_dns_ids(self):
self.ext.add_dns_id(self.domain)
self.ext.add_ip(self.ip)
self.assertEqual([self.domain], self.ext.get_dns_ids())
def test_ips(self):
self.ext.add_dns_id(self.domain)
self.ext.add_ip(self.ip)
self.assertEqual([self.ip], self.ext.get_ips())
def test_ipv6(self):
self.ext.add_ip(self.ip6)
self.assertEqual([self.ip6], self.ext.get_ips())
def test_add_ip_invalid(self):
with self.assertRaises(errors.X509Error):
self.ext.add_ip("abcdef")
def test_str(self):
self.ext.add_dns_id(self.domain)
self.ext.add_ip(self.ip)
self.assertEqual("subjectAltName: DNS:some.domain, IP:1.2.3.4",
str(self.ext))

View File

@ -1,92 +0,0 @@
# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
from anchor.X509 import message_digest
class TestMessageDigest(unittest.TestCase):
data = b"this is test data to test with"
def setUp(self):
super(TestMessageDigest, self).setUp()
def tearDown(self):
super(TestMessageDigest, self).tearDown()
def test_bad_algo(self):
self.assertRaises(message_digest.MessageDigestError,
message_digest.MessageDigest,
'BAD')
def test_md5(self):
v = "B2F81E9F287884AF6A8B3E8EFB96C711"
md = message_digest.MessageDigest("md5")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_ripmed160(self):
v = "BA5CCC4574D676266D821269CA77BFFD7FD9FCB0"
md = message_digest.MessageDigest("ripemd160")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha224(self):
v = "675170C12E88D549DB0F608AD6857103D7B792F29FACFCC53173F178"
md = message_digest.MessageDigest("sha224")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha256(self):
v = "91F672E796E84BECC6F051A47D7392BD789AEA7D55090588F212CF041C862678"
md = message_digest.MessageDigest("sha256")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha384(self):
v = ("9667AF42DF2E6B81EE679757BB207A3F9BB7CED49CF838FF3ED8237C9B15291B"
"15")
md = message_digest.MessageDigest("sha384")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_sha512(self):
v = ("283B3ECD8AE687226C3EA46B59F65E5CA50A11735C9C14BED11F0CCB515707B5"
"1031145ED8AE4B35B24B91F26E70AC0ACAC37B5BEE933B28834FE6447D1298CB"
)
md = message_digest.MessageDigest("sha512")
md.update(TestMessageDigest.data)
ret = md.final()
self.assertEqual(ret, v)
def test_algorithms(self):
algs = [
"md5",
"ripemd160",
"sha224",
"sha256",
"sha384",
"sha512"
]
valid = message_digest.MessageDigest.getValidAlgorithms()
for alg in algs:
self.assertTrue(alg in valid)

View File

@ -14,70 +14,21 @@
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import unittest
import mock
from anchor.X509 import errors
from anchor.X509 import utils
from cryptography.hazmat.backends.openssl import backend
class TestASN1String(unittest.TestCase):
# missing in cryptography.io
V_ASN1_UTF8STRING = 12
def test_utf8_string(self):
orig = u"test \u2603 snowman"
encoded = orig.encode('utf-8')
asn1string = backend._lib.ASN1_STRING_type_new(self.V_ASN1_UTF8STRING)
backend._lib.ASN1_STRING_set(asn1string, encoded, len(encoded))
res = utils.asn1_string_to_utf8(asn1string)
self.assertEqual(res, orig)
def test_invalid_string(self):
encoded = b"\xff"
asn1string = backend._lib.ASN1_STRING_type_new(self.V_ASN1_UTF8STRING)
backend._lib.ASN1_STRING_set(asn1string, encoded, len(encoded))
self.assertRaises(errors.ASN1StringError, utils.asn1_string_to_utf8,
asn1string)
class TestASN1Time(unittest.TestCase):
def test_conversion_failure(self):
with mock.patch.object(backend._lib, "ASN1_TIME_to_generalizedtime",
return_value=backend._ffi.NULL):
t = utils.timestamp_to_asn1_time(0)
self.assertRaises(errors.ASN1TimeError,
utils.asn1_time_to_timestamp, t)
def test_round_check(self):
t = 0
asn1_time = utils.timestamp_to_asn1_time(t)
res = utils.asn1_time_to_timestamp(asn1_time)
self.assertEqual(t, res)
def test_generalizedtime_check_failure(self):
with mock.patch.object(backend._lib, "ASN1_GENERALIZEDTIME_check",
return_value=0):
self.assertRaises(errors.ASN1TimeError,
utils.timestamp_to_asn1_time, 0)
class TestTimezone(unittest.TestCase):
def test_utcoffset(self):
tz = utils.create_timezone(1234)
offset = tz.utcoffset(datetime.datetime.now())
self.assertEqual(datetime.timedelta(minutes=1234), offset)
def test_dst(self):
tz = utils.create_timezone(1234)
offset = tz.dst(datetime.datetime.now())
self.assertEqual(datetime.timedelta(0), offset)
def test_name(self):
tz = utils.create_timezone(1234)
name = tz.tzname(datetime.datetime.now())
self.assertIsNone(name)
def test_repr(self):
tz = utils.create_timezone(1234)
self.assertEqual("Timezone +2034", repr(tz))
def test_post_2050(self):
"""Test date post 2050, which causes different encoding."""
t = 2600000000
asn1_time = utils.timestamp_to_asn1_time(t)
res = utils.asn1_time_to_timestamp(asn1_time)
self.assertEqual(t, res)

View File

@ -18,25 +18,18 @@ import unittest
import mock
import sys
import io
import textwrap
from anchor.X509 import certificate
from anchor.X509 import errors as x509_errors
from anchor.X509 import extension
from anchor.X509 import name as x509_name
# find the class representing an open file; it depends on the python version
# it's used later for mocking
if sys.version_info[0] < 3:
file_class = file # noqa
else:
import _io
file_class = _io.TextIOWrapper
from anchor.X509 import utils
class TestX509Cert(unittest.TestCase):
cert_data = textwrap.dedent("""
cert_data = textwrap.dedent(u"""
-----BEGIN CERTIFICATE-----
MIICKjCCAZOgAwIBAgIIfeW6dwGe6wMwDQYJKoZIhvcNAQEFBQAwUjELMAkGA1UE
BhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxFjAUBgNVBAoTDUhlcnAgRGVycCBw
@ -52,17 +45,70 @@ class TestX509Cert(unittest.TestCase):
gTLni27WuVJFVBNoTU1JfoxBSm/RBLdTj92g9N5g
-----END CERTIFICATE-----""")
key_dsa_data = textwrap.dedent("""
-----BEGIN DSA PARAMETERS-----
MIICLAKCAQEA59W1OsK9Tv7DRbxzibGVpBAL2Oz8JhbV3ii7WAat+UfTBLAnfdva
7UE8odu1l8p41N/8H/tDWgPh6tOgdX0YT9HDsILymQxzUEscliFZKmYg7YdSH3Zd
6DglOT7CqYxX0r9gK/BOh8ESe3gqKncnThHnO8Eu9wP8HNcrN00EOqP+fJpbS0lu
iifD9JdFY5YpCsLDIvpPbM0NCDuANPo10N3qqC8BuNiu0VfZpRSBcqzU1kwABT5n
y7+8RMh5Xaa7xnhGctJ9s9n+QfWcF/vbgiDOBttb3d8r8Pqvoou8v7Q38Q6zILhf
hajevqjGqZwodbvbHGfFbWapgBjpBIr4zwIhAOq6uryEHQglirWCGFJLQlkzxghy
ctHBRXGuKYb+ltRTAoIBAHRUFxzd1vhjKQ5atIdG0AiXUNm7/uboe21EJDLf4lkE
7UHDZfwsHXxQHfozzIsp7gHcw7F6AVCgiNRi9vBYOemPswevoWiVKqLTVt1wMogD
EJI6VAQEbBmSrtvyuClCkEAlIY6daX9EV9KqbnetS4/xv4WFQ9FPE47VyQ50vvxK
JSyNZnJ1lN6FUD9R5YYfwERgND8EYJBD10UBKIvtORICTJUfaDAweTWhaVcXUID7
VGNGPauOdVQzWsWTrQn/f/hbXCB/KXgv1l92D6rEoT2j2YrqIv/qD/ZxPwhBfLdr
W241Cb+LT05LVCokRbWUdjfuO8SdSBAIvT9P6umG/uQ=
-----END DSA PARAMETERS-----
-----BEGIN DSA PRIVATE KEY-----
MIIDVwIBAAKCAQEA59W1OsK9Tv7DRbxzibGVpBAL2Oz8JhbV3ii7WAat+UfTBLAn
fdva7UE8odu1l8p41N/8H/tDWgPh6tOgdX0YT9HDsILymQxzUEscliFZKmYg7YdS
H3Zd6DglOT7CqYxX0r9gK/BOh8ESe3gqKncnThHnO8Eu9wP8HNcrN00EOqP+fJpb
S0luiifD9JdFY5YpCsLDIvpPbM0NCDuANPo10N3qqC8BuNiu0VfZpRSBcqzU1kwA
BT5ny7+8RMh5Xaa7xnhGctJ9s9n+QfWcF/vbgiDOBttb3d8r8Pqvoou8v7Q38Q6z
ILhfhajevqjGqZwodbvbHGfFbWapgBjpBIr4zwIhAOq6uryEHQglirWCGFJLQlkz
xghyctHBRXGuKYb+ltRTAoIBAHRUFxzd1vhjKQ5atIdG0AiXUNm7/uboe21EJDLf
4lkE7UHDZfwsHXxQHfozzIsp7gHcw7F6AVCgiNRi9vBYOemPswevoWiVKqLTVt1w
MogDEJI6VAQEbBmSrtvyuClCkEAlIY6daX9EV9KqbnetS4/xv4WFQ9FPE47VyQ50
vvxKJSyNZnJ1lN6FUD9R5YYfwERgND8EYJBD10UBKIvtORICTJUfaDAweTWhaVcX
UID7VGNGPauOdVQzWsWTrQn/f/hbXCB/KXgv1l92D6rEoT2j2YrqIv/qD/ZxPwhB
fLdrW241Cb+LT05LVCokRbWUdjfuO8SdSBAIvT9P6umG/uQCggEBAKrZAppbnKf1
pzSvE3gTaloitAJG+79BML5h1n67EWuv0i+Fq4eUAVJ23R8GR1HrYw6utZoYbu8u
k8eHrArMfTfbFaLwK/Nv33Hfm3aTTXnY6auLNkpbiZXuCQjWBFhb6F+B42V9/JJ8
RJ1UV6Y2ajjjMvpeh0cPlARw5UpKBgQ933DhefCWyFBPsPToFvd3uPO+GUN6VpNY
iR7G0AH3/LSVJRuz5/QCp86uLIoU3fBEf1KGYJrkVKlc9DtcNmDXgpP0d3fK+4Jw
bGvi5AD1sQOWryNujyS/d2K/PAagsD0M6XJFgkEV592OSlygbYtuo3t4AtAy8F0f
VHNXq2l01FMCIQCrkk1749eQg4W6j7HfLFvjbDcuIFTw98IKyEZuZ93cdA==
-----END DSA PRIVATE KEY-----""").encode('ascii')
key_rsa_data = textwrap.dedent("""
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCeeqg1Qeccv8hqj1BP9KEJX5QsFCxR62M8plPb5t4sLo8UYfZd
6kFLcOP8xzwwvx/eFY6Sux52enQ197o8aMwyP77hMhZqtd8NCgLJMVlUbRhwLti0
SkHFPic0wAg+esfXa6yhd5TxC+bti7MgV/ljA80XQxHH8xOjdOoGN0DHfQIDAQAB
AoGBAJ2ozJpe+7qgGJPaCz3f0izvBwtq7kR49fqqRZbo8HHnx7OxWVVI7LhOkKEy
2/Bq0xsvOu1CdiXL4LynvIDIiQqLaeINzG48Rbk+0HadbXblt3nDkIWdYII6zHKI
W9ewX4KpHEPbrlEO9BjAlAcYsDIvFIMYpQhtQ+0R/gmZ99WJAkEAz5C2a6FIcMbE
o3aTc9ECq99zY7lxh+6aLpUdIeeHyb/QzfGDBdlbpBAkA6EcxSqp0aqH4xIQnYHa
3P5ZCShqSwJBAMN1sb76xq94xkg2cxShPFPAE6xKRFyKqLgsBYVtulOdfOtOnjh9
1SK2XQQfBRIRdG4Q/gDoCP8XQHpJcWMk+FcCQDnuJqulaOVo5GrG5mJ1nCxCAh98
G06X7lo/7dCPoRtSuMExvaK9RlFk29hTeAcjYCAPWzupyA9dtarmJg1jRT8CQCKf
gYnb8D/6+9yk0IPR/9ayCooVacCeyz48hgnZowzWs98WwQ4utAd/GED3obVOpDov
Bl9wus889i3zPoOac+cCQCZHredQcJGd4dlthbVtP2NhuPXz33JuETGR9pXtsDUZ
uX/nSq1oo9kUh/dPOz6aP5Ues1YVe3LExmExPBQfwIE=
-----END RSA PRIVATE KEY-----""").encode('ascii')
def setUp(self):
super(TestX509Cert, self).setUp()
self.cert = certificate.X509Certificate()
self.cert.from_buffer(TestX509Cert.cert_data)
self.cert = certificate.X509Certificate.from_buffer(
TestX509Cert.cert_data)
def tearDown(self):
pass
def test_bad_data_throws(self):
bad_data = (
"some bad data is "
u"some bad data is "
"EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m")
cert = certificate.X509Certificate()
@ -72,119 +118,121 @@ class TestX509Cert(unittest.TestCase):
def test_get_subject_countryName(self):
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName)
entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK")
def test_get_subject_stateOrProvinceName(self):
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName)
entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "Narnia")
def test_get_subject_localityName(self):
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_localityName)
entries = name.get_entries_by_oid(x509_name.OID_localityName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "localityName")
self.assertEqual(entries[0].get_value(), "Funkytown")
def test_get_subject_organizationName(self):
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationName)
entries = name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "Anchor Testing")
def test_get_subject_organizationUnitName(self):
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationalUnitName)
entries = name.get_entries_by_oid(x509_name.OID_organizationalUnitName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationalUnitName")
self.assertEqual(entries[0].get_value(), "testing")
def test_get_subject_commonName(self):
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_commonName)
entries = name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "anchor.test")
def test_get_subject_emailAddress(self):
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_pkcs9_emailAddress)
entries = name.get_entries_by_oid(x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "emailAddress")
self.assertEqual(entries[0].get_value(), "test@anchor.test")
def test_get_issuer_countryName(self):
name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_countryName)
entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "AU")
def test_get_issuer_stateOrProvinceName(self):
name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName)
entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "Some-State")
def test_get_issuer_organizationName(self):
name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_organizationName)
entries = name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "Herp Derp plc")
def test_get_issuer_commonName(self):
name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_commonName)
entries = name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "herp.derp.plc")
def test_set_subject(self):
name = x509_name.X509Name()
name.add_name_entry(x509_name.NID_countryName, 'UK')
name.add_name_entry(x509_name.OID_countryName, 'UK')
self.cert.set_subject(name)
name = self.cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName)
entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK")
def test_set_issuer(self):
name = x509_name.X509Name()
name.add_name_entry(x509_name.NID_countryName, 'UK')
name.add_name_entry(x509_name.OID_countryName, 'UK')
self.cert.set_issuer(name)
name = self.cert.get_issuer()
entries = name.get_entries_by_nid(x509_name.NID_countryName)
entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK")
def test_read_from_file(self):
open_name = 'anchor.X509.certificate.open'
f = io.StringIO(TestX509Cert.cert_data)
with mock.patch(open_name, create=True) as mock_open:
mock_open.return_value = mock.MagicMock(spec=file_class)
m_file = mock_open.return_value.__enter__.return_value
m_file.read.return_value = TestX509Cert.cert_data
mock_open.return_value = f
cert = certificate.X509Certificate()
cert.from_file("some_path")
cert = certificate.X509Certificate.from_file("some_path")
name = cert.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName)
entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(entries[0].get_value(), "UK")
def test_get_fingerprint(self):
fp = self.cert.get_fingerprint()
self.assertEqual(fp, "56D61AC583BDDD4B44EEB479EF6C998F")
self.assertEqual(fp, "634A8CD10C81F1CD7A7E140921B4D9CA")
def test_get_fingerprint_invalid_hash(self):
with self.assertRaises(x509_errors.X509Error):
self.cert.get_fingerprint('no_such_hash')
def test_sign_bad_md(self):
self.assertRaises(x509_errors.X509Error,
@ -194,7 +242,7 @@ class TestX509Cert(unittest.TestCase):
def test_sign_bad_key(self):
self.assertRaises(x509_errors.X509Error,
self.cert.sign,
self.cert._ffi.NULL)
None)
def test_get_version(self):
v = self.cert.get_version()
@ -222,3 +270,40 @@ class TestX509Cert(unittest.TestCase):
self.cert.set_not_after(0) # seconds since epoch
val = self.cert.get_not_after()
self.assertEqual(0, val)
def test_get_extensions(self):
exts = self.cert.get_extensions()
self.assertEqual(2, len(exts))
def test_add_extensions(self):
bc = extension.X509ExtensionBasicConstraints()
self.cert.add_extension(bc, 2)
exts = self.cert.get_extensions()
self.assertEqual(3, len(exts))
def test_add_extensions_invalid(self):
with self.assertRaises(x509_errors.X509Error):
self.cert.add_extension("abcdef", 2)
def test_sign_rsa_sha1(self):
key = utils.get_private_key_from_bytes(self.key_rsa_data)
self.cert.sign(key, 'sha1')
self.assertEqual(self.cert.get_fingerprint(),
"BA1B5C97D68EAE738FD10657E6F0B143")
def test_sign_dsa_sha1(self):
key = utils.get_private_key_from_bytes(self.key_dsa_data)
self.cert.sign(key, 'sha1')
# TODO(stan): add verification; DSA signatures are not
# deterministic which means right now we can only make sure it
# doesn't raise exceptions
def test_sign_unknown_key(self):
key = object()
with self.assertRaises(x509_errors.X509Error):
self.cert.sign(key, 'sha1')
def test_sign_unknown_hash(self):
key = utils.get_private_key_from_bytes(self.key_rsa_data)
with self.assertRaises(x509_errors.X509Error):
self.cert.sign(key, 'no_such_hash')

View File

@ -14,29 +14,21 @@
# License for the specific language governing permissions and limitations
# under the License.
import sys
import io
import textwrap
import unittest
from cryptography.hazmat.backends.openssl import backend
import mock
from pyasn1_modules import rfc2459
from anchor.X509 import errors as x509_errors
from anchor.X509 import extension
from anchor.X509 import name as x509_name
from anchor.X509 import signing_request
# find the class representing an open file; it depends on the python version
# it's used later for mocking
if sys.version_info[0] < 3:
file_class = file # noqa
else:
import _io
file_class = _io.TextIOWrapper
class TestX509Csr(unittest.TestCase):
csr_data = textwrap.dedent("""
csr_data = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST-----
MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ
BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV
@ -50,41 +42,53 @@ class TestX509Csr(unittest.TestCase):
def setUp(self):
super(TestX509Csr, self).setUp()
self.csr = signing_request.X509Csr()
self.csr.from_buffer(TestX509Csr.csr_data)
self.csr = signing_request.X509Csr.from_buffer(TestX509Csr.csr_data)
def tearDown(self):
pass
def test_get_pubkey_bits(self):
# some OpenSSL gumph to test a reasonable attribute of the pubkey
def test_get_pubkey(self):
pubkey = self.csr.get_pubkey()
size = backend._lib.EVP_PKEY_bits(pubkey)
self.assertEqual(size, 384)
self.assertEqual(pubkey['algorithm']['algorithm'],
rfc2459.rsaEncryption)
def test_get_extensions(self):
exts = self.csr.get_extensions()
self.assertEqual(len(exts), 2)
self.assertEqual(str(exts[0]), "basicConstraints CA:FALSE")
self.assertEqual(str(exts[1]), ("keyUsage Digital Signature, Non "
"Repudiation, Key Encipherment"))
self.assertFalse(exts[0].get_ca())
self.assertIsNone(exts[0].get_path_len_constraint())
self.assertTrue(exts[1].get_usage('digitalSignature'))
self.assertTrue(exts[1].get_usage('nonRepudiation'))
self.assertTrue(exts[1].get_usage('keyEncipherment'))
self.assertFalse(exts[1].get_usage('cRLSign'))
def test_add_extension(self):
csr = signing_request.X509Csr()
bc = extension.X509ExtensionBasicConstraints()
csr.add_extension(bc)
self.assertEqual(1, len(csr.get_extensions()))
csr.add_extension(bc)
self.assertEqual(2, len(csr.get_extensions()))
def test_add_extension_invalid_type(self):
csr = signing_request.X509Csr()
with self.assertRaises(x509_errors.X509Error):
csr.add_extension(1234)
def test_read_from_file(self):
open_name = 'anchor.X509.signing_request.open'
f = io.StringIO(TestX509Csr.csr_data)
with mock.patch(open_name, create=True) as mock_open:
mock_open.return_value = mock.MagicMock(spec=file_class)
m_file = mock_open.return_value.__enter__.return_value
m_file.read.return_value = TestX509Csr.csr_data
csr = signing_request.X509Csr()
csr.from_file("some_path")
mock_open.return_value = f
csr = signing_request.X509Csr.from_file("some_path")
name = csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName)
entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(entries[0].get_value(), "UK")
def test_bad_data_throws(self):
bad_data = (
"some bad data is "
u"some bad data is "
"EHRlc3RAYW5jaG9yLnRlc3QwTDANBgkqhkiG9w0BAQEFAAM7ADA4AjEA6m")
csr = signing_request.X509Csr()
@ -94,49 +98,49 @@ class TestX509Csr(unittest.TestCase):
def test_get_subject_countryName(self):
name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_countryName)
entries = name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK")
def test_get_subject_stateOrProvinceName(self):
name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_stateOrProvinceName)
entries = name.get_entries_by_oid(x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "Narnia")
def test_get_subject_localityName(self):
name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_localityName)
entries = name.get_entries_by_oid(x509_name.OID_localityName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "localityName")
self.assertEqual(entries[0].get_value(), "Funkytown")
def test_get_subject_organizationName(self):
name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationName)
entries = name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "Anchor Testing")
def test_get_subject_organizationUnitName(self):
name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_organizationalUnitName)
entries = name.get_entries_by_oid(x509_name.OID_organizationalUnitName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationalUnitName")
self.assertEqual(entries[0].get_value(), "testing")
def test_get_subject_commonName(self):
name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_commonName)
entries = name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "anchor.test")
def test_get_subject_emailAddress(self):
name = self.csr.get_subject()
entries = name.get_entries_by_nid(x509_name.NID_pkcs9_emailAddress)
entries = name.get_entries_by_oid(x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "emailAddress")
self.assertEqual(entries[0].get_value(), "test@anchor.test")

View File

@ -24,18 +24,18 @@ class TestX509Name(unittest.TestCase):
def setUp(self):
super(TestX509Name, self).setUp()
self.name = x509_name.X509Name()
self.name.add_name_entry(x509_name.NID_countryName,
self.name.add_name_entry(x509_name.OID_countryName,
"UK") # must be 2 chars
self.name.add_name_entry(x509_name.NID_stateOrProvinceName, "test_ST")
self.name.add_name_entry(x509_name.NID_localityName, "test_L")
self.name.add_name_entry(x509_name.NID_organizationName, "test_O")
self.name.add_name_entry(x509_name.NID_organizationalUnitName,
self.name.add_name_entry(x509_name.OID_stateOrProvinceName, "test_ST")
self.name.add_name_entry(x509_name.OID_localityName, "test_L")
self.name.add_name_entry(x509_name.OID_organizationName, "test_O")
self.name.add_name_entry(x509_name.OID_organizationalUnitName,
"test_OU")
self.name.add_name_entry(x509_name.NID_commonName, "test_CN")
self.name.add_name_entry(x509_name.NID_pkcs9_emailAddress,
self.name.add_name_entry(x509_name.OID_commonName, "test_CN")
self.name.add_name_entry(x509_name.OID_pkcs9_emailAddress,
"test_Email")
self.name.add_name_entry(x509_name.NID_surname, "test_SN")
self.name.add_name_entry(x509_name.NID_givenName, "test_GN")
self.name.add_name_entry(x509_name.OID_surname, "test_SN")
self.name.add_name_entry(x509_name.OID_givenName, "test_GN")
def tearDown(self):
pass
@ -48,7 +48,7 @@ class TestX509Name(unittest.TestCase):
def test_set_bad_c_throws(self):
self.assertRaises(x509_errors.X509Error,
self.name.add_name_entry,
x509_name.NID_countryName, "BAD_WRONG")
x509_name.OID_countryName, "BAD_WRONG")
def test_name_to_string(self):
val = str(self.name)
@ -57,53 +57,53 @@ class TestX509Name(unittest.TestCase):
"SN=test_SN/GN=test_GN"))
def test_get_countryName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_countryName)
entries = self.name.get_entries_by_oid(x509_name.OID_countryName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "countryName")
self.assertEqual(entries[0].get_value(), "UK")
def test_get_stateOrProvinceName(self):
entries = self.name.get_entries_by_nid(
x509_name.NID_stateOrProvinceName)
entries = self.name.get_entries_by_oid(
x509_name.OID_stateOrProvinceName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "stateOrProvinceName")
self.assertEqual(entries[0].get_value(), "test_ST")
def test_get_subject_localityName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_localityName)
entries = self.name.get_entries_by_oid(x509_name.OID_localityName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "localityName")
self.assertEqual(entries[0].get_value(), "test_L")
def test_get_organizationName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_organizationName)
entries = self.name.get_entries_by_oid(x509_name.OID_organizationName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationName")
self.assertEqual(entries[0].get_value(), "test_O")
def test_get_organizationUnitName(self):
entries = self.name.get_entries_by_nid(
x509_name.NID_organizationalUnitName)
entries = self.name.get_entries_by_oid(
x509_name.OID_organizationalUnitName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "organizationalUnitName")
self.assertEqual(entries[0].get_value(), "test_OU")
def test_get_commonName(self):
entries = self.name.get_entries_by_nid(x509_name.NID_commonName)
entries = self.name.get_entries_by_oid(x509_name.OID_commonName)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "commonName")
self.assertEqual(entries[0].get_value(), "test_CN")
def test_get_emailAddress(self):
entries = self.name.get_entries_by_nid(
x509_name.NID_pkcs9_emailAddress)
entries = self.name.get_entries_by_oid(
x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1)
self.assertEqual(entries[0].get_name(), "emailAddress")
self.assertEqual(entries[0].get_value(), "test_Email")
def test_entry_to_string(self):
entries = self.name.get_entries_by_nid(
x509_name.NID_pkcs9_emailAddress)
entries = self.name.get_entries_by_oid(
x509_name.OID_pkcs9_emailAddress)
self.assertEqual(len(entries), 1)
self.assertEqual(str(entries[0]), "emailAddress: test_Email")

View File

@ -29,7 +29,7 @@ class CertificateOpsTests(unittest.TestCase):
def setUp(self):
# This is a CSR with CN=anchor-test.example.com
self.expected_cn = "anchor-test.example.com"
self.csr = textwrap.dedent("""
self.csr = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST-----
MIIEsDCCApgCAQAwazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEx
FjAUBgNVBAcTDU1vdW50YWluIFZpZXcxDTALBgNVBAoTBEFjbWUxIDAeBgNVBAMT
@ -67,16 +67,16 @@ class CertificateOpsTests(unittest.TestCase):
"""Test basic success path for parse_csr."""
result = certificate_ops.parse_csr(self.csr, 'pem')
subject = result.get_subject()
actual_cn = subject.get_entries_by_nid(
x509_name.NID_commonName)[0].get_value()
actual_cn = subject.get_entries_by_oid(
x509_name.OID_commonName)[0].get_value()
self.assertEqual(actual_cn, self.expected_cn)
def test_parse_csr_success2(self):
"""Test basic success path for parse_csr."""
result = certificate_ops.parse_csr(self.csr, 'PEM')
subject = result.get_subject()
actual_cn = subject.get_entries_by_nid(
x509_name.NID_commonName)[0].get_value()
actual_cn = subject.get_entries_by_oid(
x509_name.OID_commonName)[0].get_value()
self.assertEqual(actual_cn, self.expected_cn)
def test_parse_csr_fail1(self):

View File

@ -58,7 +58,7 @@ class TestFunctional(unittest.TestCase):
}
"""
csr_good = textwrap.dedent("""
csr_good = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST-----
MIIEDzCCAncCAQAwcjELMAkGA1UEBhMCR0IxEzARBgNVBAgTCkNhbGlmb3JuaWEx
FjAUBgNVBAcTDVNhbiBGcmFuY3NpY28xDTALBgNVBAoTBE9TU0cxDTALBgNVBAsT
@ -84,7 +84,7 @@ class TestFunctional(unittest.TestCase):
tR7XqQGqJKca/vRTfJ+zIAxMEeH1N9Lx7YBO6VdVja+yG1E=
-----END CERTIFICATE REQUEST-----""")
csr_bad = textwrap.dedent("""
csr_bad = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST-----
MIIBWTCCARMCAQAwgZQxCzAJBgNVBAYTAlVLMQ8wDQYDVQQIEwZOYXJuaWExEjAQ
BgNVBAcTCUZ1bmt5dG93bjEXMBUGA1UEChMOQW5jaG9yIFRlc3RpbmcxEDAOBgNV
@ -149,8 +149,7 @@ class TestFunctional(unittest.TestCase):
resp = self.app.post('/sign', data, expect_errors=False)
self.assertEqual(200, resp.status_int)
cert = X509_cert.X509Certificate()
cert.from_buffer(resp.text)
cert = X509_cert.X509Certificate.from_buffer(resp.text)
# make sure the cert is what we asked for
self.assertEqual(("/C=GB/ST=California/L=San Francsico/O=OSSG"

View File

@ -24,7 +24,7 @@ from anchor.X509 import signing_request
class TestBaseValidators(unittest.TestCase):
csr_data_with_cn = textwrap.dedent("""
csr_data_with_cn = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST-----
MIIDBTCCAe0CAQAwgb8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlh
MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKExhPcGVuU3RhY2sgU2Vj
@ -51,7 +51,7 @@ class TestBaseValidators(unittest.TestCase):
CN=ossg.test.com/emailAddress=openstack-security@lists.openstack.org
"""
csr_data_without_cn = textwrap.dedent("""
csr_data_without_cn = textwrap.dedent(u"""
-----BEGIN CERTIFICATE REQUEST-----
MIIC7TCCAdUCAQAwgacxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlh
MRYwFAYDVQQHDA1TYW4gRnJhbmNpc2NvMSEwHwYDVQQKDBhPcGVuU3RhY2sgU2Vj
@ -79,8 +79,8 @@ class TestBaseValidators(unittest.TestCase):
def setUp(self):
super(TestBaseValidators, self).setUp()
self.csr = signing_request.X509Csr()
self.csr.from_buffer(TestBaseValidators.csr_data_with_cn)
self.csr = signing_request.X509Csr.from_buffer(
TestBaseValidators.csr_data_with_cn)
def tearDown(self):
super(TestBaseValidators, self).tearDown()
@ -89,7 +89,8 @@ class TestBaseValidators(unittest.TestCase):
name = validators.csr_get_cn(self.csr)
self.assertEqual(name, "ossg.test.com")
self.csr.from_buffer(TestBaseValidators.csr_data_without_cn)
self.csr = signing_request.X509Csr.from_buffer(
TestBaseValidators.csr_data_without_cn)
with self.assertRaises(validators.ValidationError):
validators.csr_get_cn(self.csr)

View File

@ -20,7 +20,9 @@ import mock
import netaddr
from anchor import validators
from anchor.X509 import extension as x509_ext
from anchor.X509 import name as x509_name
from anchor.X509 import signing_request as x509_csr
class TestValidators(unittest.TestCase):
@ -45,262 +47,174 @@ class TestValidators(unittest.TestCase):
'example.com', []))
def test_common_name_with_two_CN(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = "subjectAltName"
csr_config = {
'get_extensions.return_value': [ext_mock],
'get_subject.return_value.get_entries_by_nid.return_value':
['dummy_value', 'dummy_value'],
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "dummy_value")
name.add_name_entry(x509_name.OID_commonName, "dummy_value")
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
csr=csr_mock,
csr=csr,
allowed_domains=[],
allowed_networks=[])
self.assertEqual("Too many CNs in the request", str(e.exception))
def test_common_name_no_CN(self):
csr_config = {
'get_subject.return_value.__len__.return_value': 0,
'get_subject.return_value.get_entries_by_nid.return_value':
[]
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
csr=csr_mock,
csr=csr,
allowed_domains=[],
allowed_networks=[])
self.assertEqual("Alt subjects have to exist if the main subject"
" doesn't", str(e.exception))
def test_common_name_good_CN(self):
cn_mock = mock.MagicMock()
cn_mock.get_value.return_value = 'master.test.com'
csr_config = {
'get_subject.return_value.__len__.return_value': 1,
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "master.test.com")
self.assertEqual(
None,
validators.common_name(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'],
)
)
def test_common_name_bad_CN(self):
name = x509_name.X509Name()
name.add_name_entry(x509_name.NID_commonName, 'test.baddomain.com')
csr_mock = mock.MagicMock()
csr_mock.get_subject.return_value = name
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, 'test.baddomain.com')
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'])
self.assertEqual("Domain 'test.baddomain.com' not allowed (does not "
"match known domains)", str(e.exception))
def test_common_name_ip_good(self):
name = x509_name.X509Name()
name.add_name_entry(x509_name.NID_commonName, '10.1.1.1')
csr_mock = mock.MagicMock()
csr_mock.get_subject.return_value = name
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, '10.1.1.1')
self.assertEqual(
None,
validators.common_name(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'],
allowed_networks=['10/8']
)
)
def test_common_name_ip_bad(self):
name = x509_name.X509Name()
name.add_name_entry(x509_name.NID_commonName, '15.1.1.1')
csr_mock = mock.MagicMock()
csr_mock.get_subject.return_value = name
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, '15.1.1.1')
with self.assertRaises(validators.ValidationError) as e:
validators.common_name(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'],
allowed_networks=['10/8'])
self.assertEqual("Address '15.1.1.1' not allowed (does not "
"match known networks)", str(e.exception))
def test_alternative_names_good_domain(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'DNS:master.test.com'
ext_mock.get_name.return_value = 'subjectAltName'
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_dns_id('master.test.com')
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual(
None,
validators.alternative_names(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'],
)
)
def test_alternative_names_bad_domain(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'DNS:test.baddomain.com'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_dns_id('test.baddomain.com')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'])
self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't "
"match known domains)", str(e.exception))
def test_alternative_names_ext(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'BAD,10.1.1.1'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names(
csr=csr_mock,
allowed_domains=['.test.com'])
self.assertEqual("Alt name should have 2 parts, but found: 'BAD'",
str(e.exception))
def test_alternative_names_ip_good(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'IP Address:10.1.1.1'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_ip(netaddr.IPAddress('10.1.1.1'))
csr.add_extension(ext)
self.assertEqual(
None,
validators.alternative_names_ip(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'],
allowed_networks=['10/8']
)
)
def test_alternative_names_ip_bad(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'IP Address:10.1.1.1'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_ip(netaddr.IPAddress('10.1.1.1'))
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'],
allowed_networks=['99/8'])
self.assertEqual("Address '10.1.1.1' not allowed (doesn't match known "
self.assertEqual("IP '10.1.1.1' not allowed (doesn't match known "
"networks)", str(e.exception))
def test_alternative_names_ip_bad_domain(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'DNS:test.baddomain.com'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_dns_id('test.baddomain.com')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
csr=csr_mock,
csr=csr,
allowed_domains=['.test.com'])
self.assertEqual("Domain 'test.baddomain.com' not allowed (doesn't "
"match known domains)", str(e.exception))
def test_alternative_names_ip_ext(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'BAD,10.1.1.1'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
csr=csr_mock,
allowed_domains=['.test.com'])
self.assertEqual("Alt name should have 2 parts, but found: 'BAD'",
str(e.exception))
def test_alternative_names_ip_bad_ext(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'BAD:VALUE'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.alternative_names_ip(
csr=csr_mock,
allowed_domains=['.test.com'],
allowed_networks=['99/8'])
self.assertEqual("Alt name 'VALUE' has unexpected type 'BAD'",
str(e.exception))
def test_server_group_no_prefix1(self):
cn_mock = mock.MagicMock()
cn_mock.get_value.return_value = 'master.test.com'
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "master.test.com")
self.assertEqual(
None,
validators.server_group(
auth_result=None,
csr=csr_mock,
csr=csr,
group_prefixes={}
)
)
def test_server_group_no_prefix2(self):
cn_mock = mock.MagicMock()
cn_mock.get_value.return_value = 'nv_master.test.com'
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "nv_master.test.com")
self.assertEqual(
None,
validators.server_group(
auth_result=None,
csr=csr_mock,
csr=csr,
group_prefixes={}
)
)
@ -310,20 +224,15 @@ class TestValidators(unittest.TestCase):
auth_result = mock.Mock()
auth_result.groups = ['nova']
cn_mock = mock.MagicMock()
cn_mock.get_value.return_value = 'nv_master.test.com'
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "nv_master.test.com")
self.assertEqual(
None,
validators.server_group(
auth_result=auth_result,
csr=csr_mock,
csr=csr,
group_prefixes={'nv': 'nova', 'sw': 'swift'}
)
)
@ -332,50 +241,41 @@ class TestValidators(unittest.TestCase):
auth_result = mock.Mock()
auth_result.groups = ['glance']
cn_mock = mock.MagicMock()
cn_mock.get_value.return_value = 'nv-master.test.com'
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "nv-master.test.com")
with self.assertRaises(validators.ValidationError) as e:
validators.server_group(
auth_result=auth_result,
csr=csr_mock,
csr=csr,
group_prefixes={'nv': 'nova', 'sw': 'swift'})
self.assertEqual("Server prefix doesn't match user groups",
str(e.exception))
def test_extensions_bad(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'BAD'
ext_mock.get_value.return_value = 'BAD'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.extensions(
csr=csr_mock,
allowed_extensions=['GOOD-1', 'GOOD-2'])
self.assertEqual("Extension 'BAD' not allowed", str(e.exception))
csr=csr,
allowed_extensions=['basicConstraints', 'nameConstraints'])
self.assertEqual("Extension 'keyUsage' not allowed", str(e.exception))
def test_extensions_good(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'GOOD-1'
ext_mock.get_value.return_value = 'GOOD-1'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
self.assertEqual(
None,
validators.extensions(
csr=csr_mock,
allowed_extensions=['GOOD-1', 'GOOD-2']
csr=csr,
allowed_extensions=['basicConstraints', 'keyUsage']
)
)
@ -384,204 +284,158 @@ class TestValidators(unittest.TestCase):
'Non Repudiation',
'Key Encipherment']
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'keyUsage'
ext_mock.get_value.return_value = 'Domination'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.key_usage(
csr=csr_mock,
csr=csr,
allowed_usage=allowed_usage)
self.assertEqual("Found some not allowed key usages: "
"Domination", str(e.exception))
"keyCertSign", str(e.exception))
def test_key_usage_good(self):
allowed_usage = ['Digital Signature',
'Non Repudiation',
'Key Encipherment']
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'keyUsage'
ext_mock.get_value.return_value = 'Key Encipherment, Digital Signature'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('keyEncipherment', True)
ext.set_usage('digitalSignature', True)
csr.add_extension(ext)
self.assertEqual(
None,
validators.key_usage(
csr=csr_mock,
csr=csr,
allowed_usage=allowed_usage
)
)
def test_ca_status_good1(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'CA:TRUE'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionBasicConstraints()
ext.set_ca(True)
csr.add_extension(ext)
self.assertEqual(
None,
validators.ca_status(
csr=csr_mock,
csr=csr,
ca_requested=True
)
)
def test_ca_status_good2(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'CA:FALSE'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionBasicConstraints()
ext.set_ca(False)
csr.add_extension(ext)
self.assertEqual(
None,
validators.ca_status(
csr=csr_mock,
csr=csr,
ca_requested=False
)
)
def test_ca_status_bad(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'CA:FALSE'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
def test_ca_status_forbidden(self):
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionBasicConstraints()
ext.set_ca(True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
ca_requested=True)
self.assertEqual("Invalid CA status, 'CA:FALSE' requested",
csr=csr,
ca_requested=False)
self.assertEqual("CA status requested, but not allowed",
str(e.exception))
def test_ca_status_bad_format1(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'CA~FALSE'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
def test_ca_status_bad(self):
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionBasicConstraints()
ext.set_ca(False)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
ca_requested=False)
self.assertEqual("Invalid basic constraints flag", str(e.exception))
def test_ca_status_bad_format2(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'CA:FALSE:DERP'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
ca_requested=False)
self.assertEqual("Invalid basic constraints flag", str(e.exception))
csr=csr,
ca_requested=True)
self.assertEqual("CA flags required",
str(e.exception))
def test_ca_status_pathlen(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'pathlen:somthing'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionBasicConstraints()
ext.set_path_len_constraint(1)
csr.add_extension(ext)
self.assertEqual(
None,
validators.ca_status(
csr=csr_mock,
csr=csr,
ca_requested=False
)
)
def test_ca_status_bad_value(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'basicConstraints'
ext_mock.get_value.return_value = 'BAD:VALUE'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
ca_requested=False)
self.assertEqual("Invalid basic constraints option", str(e.exception))
def test_ca_status_key_usage_bad1(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'keyUsage'
ext_mock.get_value.return_value = 'Certificate Sign'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
csr=csr,
ca_requested=False)
self.assertEqual("Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: True/False)", str(e.exception))
def test_ca_status_key_usage_good1(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'keyUsage'
ext_mock.get_value.return_value = 'Certificate Sign'
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('keyCertSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
self.assertEqual(
None,
validators.ca_status(
csr=csr_mock,
ca_requested=True)
self.assertEqual("Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: True/False)", str(e.exception))
csr=csr,
ca_requested=True
)
)
def test_ca_status_key_usage_bad2(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'keyUsage'
ext_mock.get_value.return_value = 'CRL Sign'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('cRLSign', True)
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError) as e:
validators.ca_status(
csr=csr_mock,
csr=csr,
ca_requested=False)
self.assertEqual("Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: False/True)", str(e.exception))
def test_ca_status_key_usage_good2(self):
ext_mock = mock.MagicMock()
ext_mock.get_name.return_value = 'keyUsage'
ext_mock.get_value.return_value = 'CRL Sign'
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionKeyUsage()
ext.set_usage('cRLSign', True)
csr.add_extension(ext)
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
with self.assertRaises(validators.ValidationError) as e:
self.assertEqual(
None,
validators.ca_status(
csr=csr_mock,
ca_requested=True)
self.assertEqual("Key usage doesn't match requested CA status "
"(keyCertSign/cRLSign: False/True)", str(e.exception))
csr=csr,
ca_requested=True
)
)
def test_source_cidrs_good(self):
request = mock.Mock(client_addr='127.0.0.1')
@ -612,99 +466,65 @@ class TestValidators(unittest.TestCase):
str(e.exception))
def test_blacklist_names_good(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'DNS:blah.good'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_dns_id('blah.good')
csr.add_extension(ext)
self.assertEqual(
None,
validators.blacklist_names(
csr=csr_mock,
csr=csr,
domains=['.bad'],
)
)
def test_blacklist_names_bad(self):
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'DNS:blah.bad'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_dns_id('blah.bad')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError):
validators.blacklist_names(
csr=csr_mock,
csr=csr,
domains=['.bad'],
)
def test_blacklist_names_bad_cn(self):
cn_mock = mock.MagicMock()
cn_mock.get_value.return_value = 'blah.bad'
csr_config = {
'get_subject.return_value.get_entries_by_nid.return_value':
[cn_mock],
}
csr_mock = mock.MagicMock(**csr_config)
csr = x509_csr.X509Csr()
name = csr.get_subject()
name.add_name_entry(x509_name.OID_commonName, "blah.bad")
with self.assertRaises(validators.ValidationError):
validators.blacklist_names(
csr=csr_mock,
csr=csr,
domains=['.bad'],
)
def test_blacklist_names_mix(self):
ext1_mock = mock.MagicMock()
ext1_mock.get_value.return_value = 'DNS:blah.good'
ext1_mock.get_name.return_value = 'subjectAltName'
ext2_mock = mock.MagicMock()
ext2_mock.get_value.return_value = 'DNS:blah.bad'
ext2_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext1_mock, ext2_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_dns_id('blah.bad')
ext.add_dns_id('blah.good')
csr.add_extension(ext)
with self.assertRaises(validators.ValidationError):
validators.blacklist_names(
csr=csr_mock,
csr=csr,
domains=['.bad'],
)
def test_blacklist_names_ignore_unknown(self):
# only validate the DNS type - other types may look like domains
# by accident
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'RANDOM_TYPE:random.bad'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
self.assertEqual(
None,
validators.blacklist_names(
csr=csr_mock,
domains=['.bad'],
)
)
def test_blacklist_names_empty_list(self):
# empty blacklist should pass everything through
ext_mock = mock.MagicMock()
ext_mock.get_value.return_value = 'DNS:some.name'
ext_mock.get_name.return_value = 'subjectAltName'
csr_mock = mock.MagicMock()
csr_mock.get_extensions.return_value = [ext_mock]
csr = x509_csr.X509Csr()
ext = x509_ext.X509ExtensionSubjectAltName()
ext.add_dns_id('blah.good')
csr.add_extension(ext)
self.assertEqual(
None,
validators.blacklist_names(
csr=csr_mock,
csr=csr,
)
)