diff --git a/src/actions.yaml b/src/actions.yaml index 60226fc..f331ccf 100644 --- a/src/actions.yaml +++ b/src/actions.yaml @@ -8,3 +8,60 @@ authorize-charm: - token refresh-secrets: description: Refresh secret_id's and re-issue retrieval tokens for secrets endpoints +get-csr: + description: Get intermediate CA csr + properties: + # Depending on the configuration of CA that will sign the CSRs it + # may be necessary to ensure these fields match the CA + country: + type: string + description: >- + The C (Country) values in the subject field of the CSR + province: + type: string + description: >- + The ST (Province) values in the subject field of the CSR. + organization: + type: string + description: >- + The O (Organization) values in the subject field of the CSR. + organizational-unit: + type: string + description: >- + The OU (OrganizationalUnit) values in the subject field of the CSR. +upload-signed-csr: + description: Upload a signed csr to vault + properties: + pem: + type: string + description: base64 encoded certificate + allow-subdomains: + type: boolean + default: True + description: >- + Specifies if clients can request certificates with + enforce-hostnames: + type: boolean + default: False + description: >- + Specifies if only valid host names are allowed + for CNs, DNS SANs, and the host part of email addresses. + allow-any-name: + type: boolean + default: True + description: >- + Specifies if clients can request any CN + max-ttl: + type: string + default: '87598h' + description: >- + Specifies the maximum Time To Live + root-ca: + type: string + description: >- + The certificate of the root CA which will be passed out to client on + the certificate relation along with the intermediate CA cert + required: + - pem +reissue-certificates: + description: Reissue certificates to all clients diff --git a/src/actions/actions.py b/src/actions/actions.py index 0425536..8c62d67 100755 --- a/src/actions/actions.py +++ b/src/actions/actions.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import os import sys @@ -26,7 +27,7 @@ basic.init_config_states() import charmhelpers.core.hookenv as hookenv import charm.vault as vault - +import charm.vault_pki as vault_pki import charms.reactive from charms.reactive.flags import set_flag @@ -50,11 +51,49 @@ def refresh_secrets(*args): set_flag('secrets.refresh') +def get_intermediate_csrs(*args): + if not hookenv.is_leader(): + hookenv.action_fail('Please run action on lead unit') + action_config = hookenv.action_get() or {} + csrs = vault_pki.get_csr( + ttl=action_config.get('ttl'), + country=action_config.get('country'), + province=action_config.get('province'), + organization=action_config.get('organization'), + organizational_unit=action_config.get('organizational-unit')) + hookenv.action_set({'output': csrs}) + + +def upload_signed_csr(*args): + if not hookenv.is_leader(): + hookenv.action_fail('Please run action on lead unit') + return + + action_config = hookenv.action_get() + root_ca = action_config.get('root-ca') + if root_ca: + hookenv.leader_set( + {'root-ca': base64.b64decode(root_ca).decode("utf-8")}) + vault_pki.upload_signed_csr( + base64.b64decode(action_config['pem']).decode("utf-8"), + allowed_domains=action_config.get('allowed-domains'), + allow_subdomains=action_config.get('allow-subdomains'), + enforce_hostnames=action_config.get('enforce-hostnames'), + allow_any_name=action_config.get('allow-any-name'), + max_ttl=action_config.get('max-ttl')) + + +def reissue_certificates(*args): + charms.reactive.set_flag('certificates.reissue.requested') + # Actions to function mapping, to allow for illegal python action names that # can map to a python function. ACTIONS = { "authorize-charm": authorize_charm_action, "refresh-secrets": refresh_secrets, + "get-csr": get_intermediate_csrs, + "upload-signed-csr": upload_signed_csr, + "reissue-certificates": reissue_certificates, } diff --git a/src/actions/get-csr b/src/actions/get-csr new file mode 120000 index 0000000..405a394 --- /dev/null +++ b/src/actions/get-csr @@ -0,0 +1 @@ +actions.py \ No newline at end of file diff --git a/src/actions/reissue-certificates b/src/actions/reissue-certificates new file mode 120000 index 0000000..405a394 --- /dev/null +++ b/src/actions/reissue-certificates @@ -0,0 +1 @@ +actions.py \ No newline at end of file diff --git a/src/actions/upload-signed-csr b/src/actions/upload-signed-csr new file mode 120000 index 0000000..405a394 --- /dev/null +++ b/src/actions/upload-signed-csr @@ -0,0 +1 @@ +actions.py \ No newline at end of file diff --git a/src/layer.yaml b/src/layer.yaml index 1ac370a..ffe3c45 100644 --- a/src/layer.yaml +++ b/src/layer.yaml @@ -8,6 +8,7 @@ includes: - interface:etcd - interface:hacluster - interface:vault-kv + - interface:tls-certificates options: basic: packages: diff --git a/src/lib/charm/vault.py b/src/lib/charm/vault.py index 10f179a..fe80fcd 100644 --- a/src/lib/charm/vault.py +++ b/src/lib/charm/vault.py @@ -55,6 +55,11 @@ path "sys/mounts/charm-*" { capabilities = ["create", "read", "update", "delete", "sudo"] } +# Allow charm- prefixes pki backends to be used +path "charm-pki-*" { + capabilities = ["create", "read", "update", "delete", "list", "sudo"] +} + # Allow discovery of secrets backends path "sys/mounts" { capabilities = ["read"] @@ -63,9 +68,6 @@ path "sys/mounts/" { capabilities = ["list"] }""" -VAULT_HEALTH_URL = '{vault_addr}/v1/sys/health' -VAULT_LOCALHOST_URL = "http://127.0.0.1:8220" - SECRET_BACKEND_HCL = """ path "{backend}/{hostname}/*" {{ capabilities = ["create", "read", "update", "delete", "list"] @@ -77,6 +79,17 @@ path "{backend}/*" {{ capabilities = ["create", "read", "update", "delete", "list"] }} """ +VAULT_LOCALHOST_URL = "http://127.0.0.1:8220" +VAULT_HEALTH_URL = '{vault_addr}/v1/sys/health' + + +class VaultNotReady(Exception): + """Exception raised for units in error state + """ + + def __init__(self, reason): + message = "Vault is not ready ({})".format(reason) + super(VaultNotReady, self).__init__(message) def binding_address(binding): @@ -100,6 +113,16 @@ get_cluster_url = functools.partial(get_vault_url, binding='cluster', port=8201) +def get_access_address(): + protocol = 'http' + addr = hookenv.config('dns-ha-access-record') + addr = addr or hookenv.config('vip') + addr = addr or binding_address('access') + if charms.reactive.is_state('vault.ssl.available'): + protocol = 'https' + return '{}://{}:{}'.format(protocol, addr, 8200) + + def enable_approle_auth(client): """Enable the approle auth method within vault @@ -164,6 +187,22 @@ def get_client(url=None): return hvac.Client(url=url or get_api_url()) +def get_local_client(): + """Provide a client for talking to the vault api + + :returns: vault client + :rtype: hvac.Client + """ + client = get_client(url=VAULT_LOCALHOST_URL) + app_role_id = get_local_charm_access_role_id() + if not app_role_id: + hookenv.log('Could not retrieve app_role_id', level=hookenv.DEBUG) + raise VaultNotReady("Cannot initialise local client") + client = hvac.Client(url=VAULT_LOCALHOST_URL) + client.auth_approle(app_role_id) + return client + + @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(10), reraise=True) @@ -319,3 +358,35 @@ def generate_role_secret_id(client, name, cidr): response = client.write('auth/approle/role/{}/secret-id'.format(name), wrap_ttl='1h', cidr_list=cidr) return response['wrap_info']['token'] + + +def is_backend_mounted(client, name): + """Check if the supplied backend is mounted + + :returns: Whether mount point is in use + :rtype: bool + """ + return '{}/'.format(name) in client.list_secret_backends() + + +def vault_ready_for_clients(): + """Check if vault is ready to recieve client requests""" + @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_attempt(10), + reraise=True) + def _check_vault_status(client): + if (not host.service_running('vault') or + not client.is_initialized() or + client.is_sealed()): + return False + return True + + # NOTE: use localhost listener as policy only allows 127.0.0.1 to + # administer the local vault instances via the charm + client = get_client(url=VAULT_LOCALHOST_URL) + + status_ok = _check_vault_status(client) + if status_ok: + return True + else: + return False diff --git a/src/lib/charm/vault_pki.py b/src/lib/charm/vault_pki.py new file mode 100644 index 0000000..e3d8c50 --- /dev/null +++ b/src/lib/charm/vault_pki.py @@ -0,0 +1,373 @@ +import datetime +import json +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.x509.extensions import ExtensionNotFound +from cryptography.x509.oid import NameOID, ExtensionOID + +import charmhelpers.contrib.network.ip as ch_ip +import charmhelpers.core.hookenv as hookenv + +from . import vault + +CHARM_PKI_MP = "charm-pki-local" +CHARM_PKI_ROLE = "local" + + +def configure_pki_backend(client, name, ttl=None): + """Ensure a pki backend is enabled + + :param client: Vault client + :type client: hvac.Client + :param name: Name of backend to enable + :type name: str + :param ttl: TTL + :type ttl: str + """ + if not vault.is_backend_mounted(client, name): + client.enable_secret_backend( + backend_type='pki', + description='Charm created PKI backend', + mount_point=name, + # Default ttl to 1 Year + config={'max-lease-ttl': ttl or '87600h'}) + + +def is_ca_ready(client, name, role): + """Check if CA is ready for use + + :returns: Whether CA is ready + :rtype: bool + """ + return client.read('{}/roles/{}'.format(name, role)) is not None + + +def get_chain(name=None): + """Check if CA is ready for use + + :returns: Whether CA is ready + :rtype: bool + """ + client = vault.get_local_client() + if not name: + name = CHARM_PKI_MP + return client.read('{}/cert/ca_chain'.format(name))['data']['certificate'] + + +def get_ca(): + """Check if CA is ready for use + + :returns: Whether CA is ready + :rtype: bool + """ + return hookenv.leader_get('root-ca') + + +def get_server_certificate(cn, ip_sans=None, alt_names=None): + """Create a certificate and key for the given cn inc sans if requested + + :param cn: Common name to use for certifcate + :type cn: string + :param ip_sans: List of IP address to create san records for + :type ip_sans: [str1,...] + :param alt_names: List of names to create san records for + :type alt_names: [str1,...] + :raises: vault.VaultNotReady + :returns: The newly created cert, issuing ca and key + :rtype: tuple + """ + client = vault.get_local_client() + configure_pki_backend(client, CHARM_PKI_MP) + if is_ca_ready(client, CHARM_PKI_MP, CHARM_PKI_ROLE): + config = { + 'common_name': cn} + if ip_sans: + config['ip_sans'] = ','.join(ip_sans) + if alt_names: + config['alt_names'] = ','.join(alt_names) + bundle = client.write( + '{}/issue/{}'.format(CHARM_PKI_MP, CHARM_PKI_ROLE), + **config)['data'] + else: + raise vault.VaultNotReady("CA not ready") + return bundle + + +def get_csr(ttl=None, country=None, province=None, + organization=None, organizational_unit=None): + """Generate a csr for the vault Intermediate Authority + + Depending on the configuration of the CA signing this CR some of the + fields embedded in the CSR may have to match the CA. + + :param ttl: TTL + :type ttl: string + :param country: The C (Country) values in the subject field of the CSR + :type country: string + :param province: The ST (Province) values in the subject field of the CSR. + :type province: string + :param organization: The O (Organization) values in the subject field of + the CSR + :type organization: string + :param organizational_unit: The OU (OrganizationalUnit) values in the + subject field of the CSR. + :type organizational_unit: string + :returns: Certificate signing request + :rtype: string + """ + client = vault.get_local_client() + if not vault.is_backend_mounted(client, CHARM_PKI_MP): + configure_pki_backend(client, CHARM_PKI_MP) + config = { + 'common_name': ("Vault Intermediate Certificate Authority " + "({})".format(CHARM_PKI_MP)), + # Year - 1 hour + 'ttl': ttl or '87599h', + 'country': country, + 'province': province, + 'ou': organizational_unit, + 'organization': organization} + config = {k: v for k, v in config.items() if v} + csr_info = client.write( + '{}/intermediate/generate/internal'.format(CHARM_PKI_MP), + **config) + return csr_info['data']['csr'] + + +def upload_signed_csr(pem, allowed_domains, allow_subdomains=True, + enforce_hostnames=False, allow_any_name=True, + max_ttl=None): + """Upload signed csr to intermediate pki + + :param pem: signed csr in pem format + :type pem: string + :param allow_subdomains: Specifies if clients can request certificates with + CNs that are subdomains of the CNs: + :type allow_subdomains: bool + :param enforce_hostnames: Specifies if only valid host names are allowed + for CNs, DNS SANs, and the host part of email + addresses. + :type enforce_hostnames: bool + :param allow_any_name: Specifies if clients can request any CN + :type allow_any_name: bool + :param max_ttl: Specifies the maximum Time To Live + :type max_ttl: str + """ + client = vault.get_local_client() + # Set the intermediate certificate authorities signing certificate to the + # signed certificate. + # (hvac module doesn't expose a method for this, hence the _post call) + client._post( + 'v1/{}/intermediate/set-signed'.format(CHARM_PKI_MP), + json={'certificate': pem}) + # Generated certificates can have the CRL location and the location of the + # issuing certificate encoded. + addr = vault.get_access_address() + client.write( + '{}/config/urls'.format(CHARM_PKI_MP), + issuing_certificates="{}/v1/{}/ca".format(addr, CHARM_PKI_MP), + crl_distribution_points="{}/v1/{}/crl".format(addr, CHARM_PKI_MP) + ) + # Configure a role which maps to a policy for accessing this pki + if not max_ttl: + max_ttl = '87598h' + client.write( + '{}/roles/{}'.format(CHARM_PKI_MP, CHARM_PKI_ROLE), + allowed_domains=allowed_domains, + allow_subdomains=allow_subdomains, + enforce_hostnames=enforce_hostnames, + allow_any_name=allow_any_name, + max_ttl=max_ttl) + + +def sort_sans(sans): + """Split SANS into IP sans and name SANS + + :param sans: List of SANS + :type sans: list + :returns: List of IP sans and list of Name SANS + :rtype: ([], []) + """ + ip_sans = {s for s in sans if ch_ip.is_ip(s)} + alt_names = set(sans).difference(ip_sans) + return sorted(list(ip_sans)), sorted(list(alt_names)) + + +def get_vault_units(): + """Return all vault units related to this one + + :returns: List of vault units + :rtype: [] + """ + peer_rid = hookenv.relation_ids('cluster')[0] + vault_units = [hookenv.local_unit()] + vault_units.extend(hookenv.related_units(relid=peer_rid)) + return vault_units + + +def get_matching_cert_from_relation(unit_name, cn, ip_sans, alt_names): + """Scan vault units relation data for a cert that matches + + Scan the relation data that each vault unit has sent to the clients + to find a cert that matchs the cn and sans. If one exists return it. + If mutliple are found then return the one with the lastest valid_to + date + + :param unit_name: Return the unit_name to look for serts for. + :type unit_name: string + :param cn: Common name to use for certifcate + :type cn: string + :param ip_sans: List of IP address to create san records for + :type ip_sans: [str1,...] + :param alt_names: List of names to create san records for + :type alt_names: [str1,...] + :returns: Cert and key if found + :rtype: {} + """ + vault_units = get_vault_units() + rid = hookenv.relation_id('certificates', unit_name) + match = [] + for vunit in vault_units: + sent_data = hookenv.relation_get(unit=vunit, rid=rid) + name = unit_name.replace('/', '_') + cert_name = '{}.server.cert'.format(name) + cert_key = '{}.server.key'.format(name) + candidate_cert = sent_data.get(cert_name) + if candidate_cert and cert_matches_request(candidate_cert, cn, + ip_sans, alt_names): + match.append({ + 'certificate': sent_data.get(cert_name), + 'private_key': sent_data.get(cert_key)}) + batch_request_raw = sent_data.get('processed_requests') + if batch_request_raw: + batch_request = json.loads(batch_request_raw) + for sent_cn in batch_request.keys(): + if sent_cn == cn: + candidate_cert = batch_request[cn]['cert'] + candidate_key = batch_request[cn]['key'] + if cert_matches_request(candidate_cert, cn, ip_sans, + alt_names): + match.append({ + 'certificate': candidate_cert, + 'private_key': candidate_key}) + return select_newest(match) + + +def cert_matches_request(cert_pem, cn, ip_sans, alt_names): + """Test if the cert matches the supplied attributes + + If the cn is duplicated in either the cert or the supplied alt_names + it is removed before performing the check. + + :param cert_pem: Certificate in pem format to check + :type cert_pem: string + :param cn: Common name to use for certifcate + :type cn: string + :param ip_sans: List of IP address to create san records for + :type ip_sans: [str1,...] + :param alt_names: List of names to create san records for + :type alt_names: [str1,...] + :returns: Whether cert matches criteria + :rtype: bool + """ + cert_data = certificate_information(cert_pem) + if cn == cert_data['cn']: + try: + cert_data['alt_names'].remove(cn) + except ValueError: + pass + try: + alt_names.remove(cn) + except ValueError: + pass + else: + return False + if sorted(cert_data['alt_names']) == sorted(alt_names) and \ + sorted(cert_data['ip_sans']) == sorted(ip_sans): + return True + else: + return False + + +def certificate_information(cert_pem): + """Extract cn, sans and expiration info from certificate + + :param cert_pem: Certificate in pem format to check + :type cert_pem: string + :returns: Certificate information in a dictionary + :rtype: {} + """ + cert = x509.load_pem_x509_certificate(cert_pem.encode(), default_backend()) + bundle = { + 'cn': cert.subject.get_attributes_for_oid( + NameOID.COMMON_NAME)[0].value, + 'not_valid_after': cert.not_valid_after} + try: + sans = cert.extensions.get_extension_for_oid( + ExtensionOID.SUBJECT_ALTERNATIVE_NAME) + alt_names = sans.value.get_values_for_type(x509.DNSName) + ip_sans = sans.value.get_values_for_type(x509.IPAddress) + ip_sans = [str(ip) for ip in ip_sans] + except ExtensionNotFound: + alt_names = ip_sans = [] + bundle['ip_sans'] = ip_sans + bundle['alt_names'] = alt_names + return bundle + + +def select_newest(certs): + """Iterate over the certificate bundle and return the one with the latest + not_valid_after date + + :returns: Certificate bundle + :rtype: {} + """ + latest = datetime.datetime.utcfromtimestamp(0) + candidate = None + for bundle in certs: + cert = x509.load_pem_x509_certificate( + bundle['certificate'].encode(), + default_backend()) + not_valid_after = cert.not_valid_after + if not_valid_after > latest: + latest = not_valid_after + candidate = bundle + return candidate + + +def process_cert_request(cn, sans, unit_name, reissue_requested): + """Return a certificate and key matching the requeest + + Return a certificate and key matching the request. This may be an existing + certificate and key if one exists and reissue_requested is False. + + :param cn: Common name to use for certifcate + :type cn: string + :param sans: List of SANS + :type sans: list + :param unit_name: Return the unit_name to look for serts for. + :type unit_name: string + :returns: Cert and key + :rtype: {} + """ + bundle = {} + ip_sans, alt_names = sort_sans(sans) + if not reissue_requested: + bundle = get_matching_cert_from_relation( + unit_name, + cn, + list(ip_sans), + list(alt_names)) + hookenv.log( + "Found existing cert for {}, reusing".format(cn), + level=hookenv.DEBUG) + if not bundle: + hookenv.log( + "Requesting new cert for {}".format(cn), + level=hookenv.DEBUG) + # Create the server certificate based on the info in request. + bundle = get_server_certificate( + cn, + ip_sans=ip_sans, + alt_names=alt_names) + return bundle diff --git a/src/metadata.yaml b/src/metadata.yaml index 9b8bfd0..2e7012d 100644 --- a/src/metadata.yaml +++ b/src/metadata.yaml @@ -34,6 +34,8 @@ provides: scope: container secrets: interface: vault-kv + certificates: + interface: tls-certificates peers: cluster: interface: vault-ha diff --git a/src/reactive/vault_handlers.py b/src/reactive/vault_handlers.py index 6c42431..e367eb8 100644 --- a/src/reactive/vault_handlers.py +++ b/src/reactive/vault_handlers.py @@ -61,6 +61,7 @@ from charms.reactive.flags import ( from charms.layer import snap import lib.charm.vault as vault +import lib.charm.vault_pki as vault_pki # See https://www.vaultproject.io/docs/configuration/storage/postgresql.html @@ -622,3 +623,69 @@ def _assess_status(): 'disabled' if mlock_disabled else 'enabled' ) ) + + +@when('leadership.is_leader') +@when_any('certificates.server.cert.requested', + 'certificates.reissue.requested') +def create_server_cert(): + if not vault.vault_ready_for_clients(): + log('Unable to process new secret backend requests,' + ' deferring until vault is fully configured', level=DEBUG) + return + reissue_requested = is_flag_set('certificates.reissue.requested') + tls = endpoint_from_flag('certificates.available') + server_requests = tls.get_server_requests() + for unit_name, request in server_requests.items(): + log( + 'Processing certificate requests from {}'.format(unit_name), + level=DEBUG) + # Process request for a single certificate + cn = request.get('common_name') + sans = request.get('sans') + if cn and sans: + log( + 'Processing single certificate requests for {}'.format(cn), + level=DEBUG) + try: + bundle = vault_pki.process_cert_request( + cn, + sans, + unit_name, + reissue_requested) + except vault.VaultNotReady: + # Cannot continue if vault is not ready + return + # Set the certificate and key for the unit on the relationship. + tls.set_server_cert( + unit_name, + bundle['certificate'], + bundle['private_key']) + # Process request for a batch of certificates + cert_requests = request.get('cert_requests') + if cert_requests: + log( + 'Processing batch of requests from {}'.format(unit_name), + level=DEBUG) + for cn, crequest in cert_requests.items(): + log('Processing requests for {}'.format(cn), level=DEBUG) + try: + bundle = vault_pki.process_cert_request( + cn, + crequest.get('sans'), + unit_name, + reissue_requested) + except vault.VaultNotReady: + # Cannot continue if vault is not ready + return + tls.add_server_cert( + unit_name, + cn, + bundle['certificate'], + bundle['private_key']) + tls.set_server_multicerts(unit_name) + tls.set_ca(vault_pki.get_ca()) + chain = vault_pki.get_chain() + if chain: + tls.set_chain(chain) + clear_flag('certificates.reissue.requested') diff --git a/src/tests/bundles/overlays/local-charm-overlay.yaml.j2 b/src/tests/bundles/overlays/local-charm-overlay.yaml.j2 deleted file mode 100644 index 2411ac2..0000000 --- a/src/tests/bundles/overlays/local-charm-overlay.yaml.j2 +++ /dev/null @@ -1,3 +0,0 @@ -applications: - vault: - charm: ../../../vault diff --git a/src/tests/bundles/xenial-mysql.yaml b/src/tests/bundles/xenial-mysql.yaml index 1a81223..0a85520 100644 --- a/src/tests/bundles/xenial-mysql.yaml +++ b/src/tests/bundles/xenial-mysql.yaml @@ -3,10 +3,20 @@ services: vault: num_units: 1 series: xenial - charm: ../../../vault + charm: vault mysql: charm: cs:mysql num_units: 1 + keystone: + charm: cs:~openstack-charmers-next/keystone + num_units: 1 + options: + admin-password: openstack + openstack-origin: cloud:xenial-queens relations: - - vault:shared-db - mysql:shared-db +- - keystone:shared-db + - mysql:shared-db +- - vault:certificates + - keystone:certificates diff --git a/src/wheelhouse.txt b/src/wheelhouse.txt index 97cb110..fde4600 100644 --- a/src/wheelhouse.txt +++ b/src/wheelhouse.txt @@ -1,3 +1,4 @@ +netifaces hvac tenacity pbr diff --git a/test-requirements.txt b/test-requirements.txt index 0cf669c..41a3cfe 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,4 +1,5 @@ # Unit test requirements +netifaces hvac flake8>=2.2.4,<=2.4.1 os-testr>=0.4.1 @@ -9,3 +10,4 @@ psycopg2 git+https://github.com/openstack/charms.openstack#egg=charms.openstack tenacity pbr +cryptography diff --git a/unit_tests/test_lib_charm_vault_pki.py b/unit_tests/test_lib_charm_vault_pki.py new file mode 100644 index 0000000..0701b2b --- /dev/null +++ b/unit_tests/test_lib_charm_vault_pki.py @@ -0,0 +1,549 @@ +import datetime +import json +import mock +from unittest.mock import patch +from cryptography.x509.extensions import ExtensionNotFound + +import lib.charm.vault_pki as vault_pki +import unit_tests.test_utils + + +class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase): + + def setUp(self): + super(TestLibCharmVaultPKI, self).setUp() + self.obj = vault_pki + self.patches = [] + self.patch_all() + + @patch.object(vault_pki.vault, 'is_backend_mounted') + def test_configure_pki_backend(self, is_backend_mounted): + client_mock = mock.MagicMock() + is_backend_mounted.return_value = False + vault_pki.configure_pki_backend( + client_mock, + 'my_backend', + ttl=42) + client_mock.enable_secret_backend.assert_called_once_with( + backend_type='pki', + config={'max-lease-ttl': 42}, + description='Charm created PKI backend', + mount_point='my_backend') + + @patch.object(vault_pki.vault, 'is_backend_mounted') + def test_configure_pki_backend_default_ttl(self, is_backend_mounted): + client_mock = mock.MagicMock() + is_backend_mounted.return_value = False + vault_pki.configure_pki_backend( + client_mock, + 'my_backend') + client_mock.enable_secret_backend.assert_called_once_with( + backend_type='pki', + config={'max-lease-ttl': '87600h'}, + description='Charm created PKI backend', + mount_point='my_backend') + + @patch.object(vault_pki.vault, 'is_backend_mounted') + def test_configure_pki_backend_noop(self, is_backend_mounted): + client_mock = mock.MagicMock() + is_backend_mounted.return_value = True + vault_pki.configure_pki_backend( + client_mock, + 'my_backend', + ttl=42) + self.assertFalse(client_mock.enable_secret_backend.called) + + def test_is_ca_ready(self): + client_mock = mock.MagicMock() + vault_pki.is_ca_ready(client_mock, 'my_backend', 'local') + client_mock.read.assert_called_once_with('my_backend/roles/local') + + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_chain(self, get_local_client): + client_mock = mock.MagicMock() + client_mock.read.return_value = { + 'data': { + 'certificate': 'somecert'}} + get_local_client.return_value = client_mock + self.assertEqual( + vault_pki.get_chain('my_backend'), + 'somecert') + client_mock.read.assert_called_once_with( + 'my_backend/cert/ca_chain') + + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_chain_default_pki(self, get_local_client): + client_mock = mock.MagicMock() + client_mock.read.return_value = { + 'data': { + 'certificate': 'somecert'}} + get_local_client.return_value = client_mock + self.assertEqual( + vault_pki.get_chain(), + 'somecert') + client_mock.read.assert_called_once_with( + 'charm-pki-local/cert/ca_chain') + + @patch.object(vault_pki.hookenv, 'leader_get') + def test_get_ca(self, leader_get): + leader_get.return_value = 'ROOTCA' + self.assertEqual(vault_pki.get_ca(), 'ROOTCA') + + @patch.object(vault_pki, 'is_ca_ready') + @patch.object(vault_pki, 'configure_pki_backend') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_server_certificate(self, get_local_client, + configure_pki_backend, is_ca_ready): + client_mock = mock.MagicMock() + get_local_client.return_value = client_mock + is_ca_ready.return_value = True + vault_pki.get_server_certificate('bob.example.com') + client_mock.write.assert_called_once_with( + 'charm-pki-local/issue/local', + common_name='bob.example.com' + ) + + @patch.object(vault_pki, 'is_ca_ready') + @patch.object(vault_pki, 'configure_pki_backend') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_server_certificate_sans(self, get_local_client, + configure_pki_backend, + is_ca_ready): + client_mock = mock.MagicMock() + get_local_client.return_value = client_mock + is_ca_ready.return_value = True + vault_pki.get_server_certificate( + 'bob.example.com', + ip_sans=['10.10.10.10', '192.197.45.23'], + alt_names=['localunit', 'public.bob.example.com']) + client_mock.write.assert_called_once_with( + 'charm-pki-local/issue/local', + alt_names='localunit,public.bob.example.com', + common_name='bob.example.com', + ip_sans='10.10.10.10,192.197.45.23' + ) + + @patch.object(vault_pki.vault, 'is_backend_mounted') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_csr(self, get_local_client, is_backend_mounted): + is_backend_mounted.return_value = True + client_mock = mock.MagicMock() + get_local_client.return_value = client_mock + client_mock.write.return_value = { + 'data': { + 'csr': 'somecert'}} + self.assertEqual(vault_pki.get_csr(), 'somecert') + client_mock.write.assert_called_once_with( + 'charm-pki-local/intermediate/generate/internal', + common_name=('Vault Intermediate Certificate Authority' + ' (charm-pki-local)'), + ttl='87599h') + + @patch.object(vault_pki, 'configure_pki_backend') + @patch.object(vault_pki.vault, 'is_backend_mounted') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_csr_config_backend(self, get_local_client, is_backend_mounted, + configure_pki_backend): + is_backend_mounted.return_value = False + client_mock = mock.MagicMock() + get_local_client.return_value = client_mock + client_mock.write.return_value = { + 'data': { + 'csr': 'somecert'}} + self.assertEqual(vault_pki.get_csr(), 'somecert') + client_mock.write.assert_called_once_with( + 'charm-pki-local/intermediate/generate/internal', + common_name=('Vault Intermediate Certificate Authority' + ' (charm-pki-local)'), + ttl='87599h') + configure_pki_backend.assert_called_once_with( + client_mock, + 'charm-pki-local') + + @patch.object(vault_pki.vault, 'is_backend_mounted') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_csr_explicit(self, get_local_client, is_backend_mounted): + is_backend_mounted.return_value = False + client_mock = mock.MagicMock() + get_local_client.return_value = client_mock + client_mock.write.return_value = { + 'data': { + 'csr': 'somecert'}} + self.assertEqual( + vault_pki.get_csr( + ttl='2h', + country='GB', + province='Kent', + organizational_unit='My Department', + organization='My Company'), + 'somecert') + client_mock.write.assert_called_once_with( + 'charm-pki-local/intermediate/generate/internal', + common_name=('Vault Intermediate Certificate Authority ' + '(charm-pki-local)'), + country='GB', + organization='My Company', + ou='My Department', + province='Kent', + ttl='2h') + + @patch.object(vault_pki.vault, 'get_access_address') + @patch.object(vault_pki.vault, 'get_local_client') + def test_upload_signed_csr(self, get_local_client, get_access_address): + get_access_address.return_value = 'https://vault.local:8200' + client_mock = mock.MagicMock() + get_local_client.return_value = client_mock + local_url = 'https://vault.local:8200/v1/charm-pki-local' + write_calls = [ + mock.call( + 'charm-pki-local/config/urls', + issuing_certificates='{}/ca'.format(local_url), + crl_distribution_points='{}/crl'.format(local_url)), + mock.call( + 'charm-pki-local/roles/local', + allowed_domains='exmaple.com', + allow_subdomains=True, + enforce_hostnames=False, + allow_any_name=True, + max_ttl='87598h') + ] + vault_pki.upload_signed_csr('MYPEM', 'exmaple.com') + client_mock._post.assert_called_once_with( + 'v1/charm-pki-local/intermediate/set-signed', + json={'certificate': 'MYPEM'}) + client_mock.write.assert_has_calls(write_calls) + + @patch.object(vault_pki.vault, 'get_access_address') + @patch.object(vault_pki.vault, 'get_local_client') + def test_upload_signed_csr_explicit(self, get_local_client, + get_access_address): + client_mock = mock.MagicMock() + get_access_address.return_value = 'https://vault.local:8200' + get_local_client.return_value = client_mock + local_url = 'https://vault.local:8200/v1/charm-pki-local' + write_calls = [ + mock.call( + 'charm-pki-local/config/urls', + issuing_certificates='{}/ca'.format(local_url), + crl_distribution_points='{}/crl'.format(local_url)), + mock.call( + 'charm-pki-local/roles/local', + allowed_domains='exmaple.com', + allow_subdomains=False, + enforce_hostnames=True, + allow_any_name=False, + max_ttl='42h') + ] + vault_pki.upload_signed_csr( + 'MYPEM', + 'exmaple.com', + allow_subdomains=False, + enforce_hostnames=True, + allow_any_name=False, + max_ttl='42h') + client_mock._post.assert_called_once_with( + 'v1/charm-pki-local/intermediate/set-signed', + json={'certificate': 'MYPEM'}) + client_mock.write.assert_has_calls(write_calls) + + def test_sort_sans(self): + self.assertEqual( + vault_pki.sort_sans([ + '10.0.0.10', + '10.0.0.20', + '10.0.0.10', + 'admin.local', + 'admin.local', + 'public.local']), + (['10.0.0.10', '10.0.0.20'], ['admin.local', 'public.local'])) + + @patch.object(vault_pki.hookenv, 'related_units') + @patch.object(vault_pki.hookenv, 'relation_ids') + @patch.object(vault_pki.hookenv, 'local_unit') + def test_get_vault_units(self, local_unit, relation_ids, related_units): + local_unit.return_value = 'vault/3' + relation_ids.return_value = 'certificates:34' + related_units.return_value = ['vault/1', 'vault/5'] + self.assertEqual( + vault_pki.get_vault_units(), + ['vault/3', 'vault/1', 'vault/5']) + + def _get_matching_cert_from_relation(self, vault_relation, cert_match, + func_args, + expected_bundle, + expected_newest_calls): + self.patch_object(vault_pki.hookenv, 'relation_get') + self.patch_object(vault_pki.hookenv, 'relation_id') + self.patch_object(vault_pki, 'select_newest') + self.patch_object(vault_pki, 'cert_matches_request') + self.patch_object(vault_pki, 'get_vault_units') + self.relation_get.side_effect = lambda unit, rid: vault_relation[unit] + self.cert_matches_request.side_effect = \ + lambda w, x, y, z: cert_match[w] + self.get_vault_units.return_value = ['vault/3', 'vault/1', 'vault/5'] + self.relation_id.return_value = 'certificates:23' + self.select_newest.side_effect = lambda x: x[0] + rget_calls = [ + mock.call(unit='vault/3', rid='certificates:23'), + mock.call(unit='vault/1', rid='certificates:23'), + mock.call(unit='vault/5', rid='certificates:23')] + self.assertEqual( + vault_pki.get_matching_cert_from_relation(*func_args), + expected_bundle) + self.relation_get.assert_has_calls(rget_calls) + self.select_newest.assert_called_once_with(expected_newest_calls) + + def test_get_matching_cert_from_relation(self): + _rinfo = { + 'vault/1': { + 'keystone_0.server.cert': 'V1CERT', + 'keystone_0.server.key': 'V1KEY'}, + 'vault/3': {}, + 'vault/5': {}, + } + _cmatch = { + 'V1CERT': True + } + self._get_matching_cert_from_relation( + _rinfo, + _cmatch, + ('keystone/0', 'ks.bob.com', ['10.0.0.23'], ['junit1.maas.local']), + {'private_key': 'V1KEY', 'certificate': 'V1CERT'}, + [{'private_key': 'V1KEY', 'certificate': 'V1CERT'}]) + + def test_get_matching_cert_from_relation_batch_single(self): + _rinfo = { + 'vault/1': {}, + 'vault/3': { + 'processed_requests': json.dumps({ + 'ks.bob.com': { + 'cert': 'V3CERT', + 'key': 'V3KEY'}})}, + 'vault/5': {}, + } + _cmatch = { + 'V3CERT': True + } + self._get_matching_cert_from_relation( + _rinfo, + _cmatch, + ('keystone/0', 'ks.bob.com', ['10.0.0.23'], ['junit1.maas.local']), + {'private_key': 'V3KEY', 'certificate': 'V3CERT'}, + [{'private_key': 'V3KEY', 'certificate': 'V3CERT'}]) + + def test_get_matching_cert_from_relation_batch_multi_one_match(self): + _rinfo = { + 'vault/1': {}, + 'vault/3': { + 'processed_requests': json.dumps({ + 'ks.bob.com': { + 'cert': 'V3CERT', + 'key': 'V3KEY'}})}, + 'vault/5': { + 'processed_requests': json.dumps({ + 'glance.bob.com': { + 'cert': 'V5CERT', + 'key': 'V5KEY'}})}, + } + _cmatch = { + 'V3CERT': True + } + self._get_matching_cert_from_relation( + _rinfo, + _cmatch, + ('keystone/0', 'ks.bob.com', ['10.0.0.23'], ['junit1.maas.local']), + {'private_key': 'V3KEY', 'certificate': 'V3CERT'}, + [{'private_key': 'V3KEY', 'certificate': 'V3CERT'}]) + + def test_get_matching_cert_from_relation_batch_multi_two_match(self): + _rinfo = { + 'vault/1': {}, + 'vault/3': { + 'processed_requests': json.dumps({ + 'ks.bob.com': { + 'cert': 'V3CERT', + 'key': 'V3KEY'}})}, + 'vault/5': { + 'processed_requests': json.dumps({ + 'ks.bob.com': { + 'cert': 'V5CERT', + 'key': 'V5KEY'}})}, + } + _cmatch = { + 'V3CERT': True, + 'V5CERT': True + } + self._get_matching_cert_from_relation( + _rinfo, + _cmatch, + ('keystone/0', 'ks.bob.com', ['10.0.0.23'], ['junit1.maas.local']), + {'private_key': 'V3KEY', 'certificate': 'V3CERT'}, + [ + {'private_key': 'V3KEY', 'certificate': 'V3CERT'}, + {'private_key': 'V5KEY', 'certificate': 'V5CERT'}]) + + def test_get_matching_cert_from_relation_batch_multi_sans_mismatch(self): + _rinfo = { + 'vault/1': {}, + 'vault/3': { + 'processed_requests': json.dumps({ + 'ks.bob.com': { + 'cert': 'V3CERT', + 'key': 'V3KEY'}})}, + 'vault/5': { + 'processed_requests': json.dumps({ + 'ks.bob.com': { + 'cert': 'V5CERT', + 'key': 'V5KEY'}})}, + } + _cmatch = { + 'V3CERT': False, + 'V5CERT': True + } + self._get_matching_cert_from_relation( + _rinfo, + _cmatch, + ('keystone/0', 'ks.bob.com', ['10.0.0.23'], ['junit1.maas.local']), + {'private_key': 'V5KEY', 'certificate': 'V5CERT'}, + [{'private_key': 'V5KEY', 'certificate': 'V5CERT'}]) + + @patch.object(vault_pki, 'certificate_information') + def test_cert_matches_request(self, certificate_information): + certificate_information.return_value = { + 'cn': 'ks.bob.com', + 'ip_sans': ['10.0.0.10'], + 'alt_names': ['unit1.bob.com']} + self.assertTrue( + vault_pki.cert_matches_request( + 'pem', 'ks.bob.com', ['10.0.0.10'], ['unit1.bob.com'])) + + @patch.object(vault_pki, 'certificate_information') + def test_cert_matches_request_mismatch_cn(self, certificate_information): + certificate_information.return_value = { + 'cn': 'glance.bob.com', + 'ip_sans': ['10.0.0.10'], + 'alt_names': ['unit1.bob.com']} + self.assertFalse( + vault_pki.cert_matches_request( + 'pem', 'ks.bob.com', ['10.0.0.10'], ['unit1.bob.com'])) + + @patch.object(vault_pki, 'certificate_information') + def test_cert_matches_request_mismatch_ipsan(self, + certificate_information): + certificate_information.return_value = { + 'cn': 'glance.bob.com', + 'ip_sans': ['10.0.0.10', '10.0.0.20'], + 'alt_names': ['unit1.bob.com']} + self.assertFalse( + vault_pki.cert_matches_request( + 'pem', 'ks.bob.com', ['10.0.0.10'], ['unit1.bob.com'])) + + @patch.object(vault_pki, 'certificate_information') + def test_cert_matches_request_cn_in_san(self, certificate_information): + certificate_information.return_value = { + 'cn': 'ks.bob.com', + 'ip_sans': ['10.0.0.10'], + 'alt_names': ['ks.bob.com', 'unit1.bob.com']} + self.assertTrue( + vault_pki.cert_matches_request( + 'pem', 'ks.bob.com', ['10.0.0.10'], ['unit1.bob.com'])) + + @patch.object(vault_pki.x509, 'load_pem_x509_certificate') + def test_certificate_information(self, load_pem_x509_certificate): + x509_mock = mock.MagicMock(not_valid_after="10 Mar 1976") + x509_name_mock = mock.MagicMock(value='ks.bob.com') + x509_mock.subject.get_attributes_for_oid.return_value = [ + x509_name_mock] + x509_sans_mock = mock.MagicMock() + sans = [ + ['10.0.0.0.10'], + ['sans1.bob.com']] + x509_sans_mock.value.get_values_for_type = lambda x: sans.pop() + x509_mock.extensions.get_extension_for_oid.return_value = \ + x509_sans_mock + load_pem_x509_certificate.return_value = x509_mock + self.assertEqual( + vault_pki.certificate_information('pem'), + { + 'cn': 'ks.bob.com', + 'not_valid_after': '10 Mar 1976', + 'ip_sans': ['10.0.0.0.10'], + 'alt_names': ['sans1.bob.com']}) + + @patch.object(vault_pki.x509, 'load_pem_x509_certificate') + def test_certificate_information_no_sans(self, load_pem_x509_certificate): + x509_mock = mock.MagicMock(not_valid_after="10 Mar 1976") + x509_name_mock = mock.MagicMock(value='ks.bob.com') + x509_mock.subject.get_attributes_for_oid.return_value = [ + x509_name_mock] + x509_mock.extensions.get_extension_for_oid.side_effect = \ + ExtensionNotFound('msg', 'oid') + load_pem_x509_certificate.return_value = x509_mock + self.assertEqual( + vault_pki.certificate_information('pem'), + { + 'cn': 'ks.bob.com', + 'not_valid_after': '10 Mar 1976', + 'ip_sans': [], + 'alt_names': []}) + + @patch.object(vault_pki.x509, 'load_pem_x509_certificate') + def test_select_newest(self, load_pem_x509_certificate): + def _load_pem_x509(pem): + pem = pem.decode() + cmock1 = mock.MagicMock( + not_valid_after=datetime.datetime(2018, 5, 3)) + cmock2 = mock.MagicMock( + not_valid_after=datetime.datetime(2018, 5, 4)) + cmock3 = mock.MagicMock( + not_valid_after=datetime.datetime(2018, 5, 5)) + certs = { + 'cert1': cmock1, + 'cert2': cmock2, + 'cert3': cmock3} + return certs[pem] + load_pem_x509_certificate.side_effect = lambda x, y: _load_pem_x509(x) + certs = [ + {'certificate': 'cert1'}, + {'certificate': 'cert2'}, + {'certificate': 'cert3'}] + self.assertEqual( + vault_pki.select_newest(certs), + {'certificate': 'cert3'}) + + @patch.object(vault_pki, 'get_matching_cert_from_relation') + @patch.object(vault_pki, 'get_server_certificate') + def test_process_cert_request(self, get_server_certificate, + get_matching_cert_from_relation): + get_matching_cert_from_relation.return_value = 'cached_bundle' + self.assertEqual( + vault_pki.process_cert_request( + 'ks.bob.com', + ['10.0.0.10', 'sans1.bob.com'], + 'keystone_0', + False), + 'cached_bundle') + get_matching_cert_from_relation.assert_called_once_with( + 'keystone_0', + 'ks.bob.com', + ['10.0.0.10'], + ['sans1.bob.com']) + get_server_certificate.assert_not_called() + + @patch.object(vault_pki, 'get_matching_cert_from_relation') + @patch.object(vault_pki, 'get_server_certificate') + def test_process_cert_request_reissue(self, get_server_certificate, + get_matching_cert_from_relation): + get_server_certificate.return_value = 'new_bundle' + self.assertEqual( + vault_pki.process_cert_request( + 'ks.bob.com', + ['10.0.0.10', 'sans1.bob.com'], + 'keystone_0', + True), + 'new_bundle') + get_matching_cert_from_relation.assert_not_called() + get_server_certificate.assert_called_once_with( + 'ks.bob.com', + ip_sans=['10.0.0.10'], + alt_names=['sans1.bob.com']) diff --git a/unit_tests/test_reactive_vault_handlers.py b/unit_tests/test_reactive_vault_handlers.py index f2e9b18..bcb81af 100644 --- a/unit_tests/test_reactive_vault_handlers.py +++ b/unit_tests/test_reactive_vault_handlers.py @@ -70,6 +70,7 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): 'set_flag', 'clear_flag', 'is_container', + 'endpoint_from_flag', ] self.patch_all() self.is_container.return_value = False @@ -617,3 +618,123 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): mock_secrets.publish_ca.assert_called_once_with( vault_ca='test-ca' ) + + @mock.patch.object(handlers.vault_pki, 'get_ca') + @mock.patch.object(handlers.vault_pki, 'get_chain') + @mock.patch.object(handlers.vault_pki, 'process_cert_request') + @mock.patch.object(handlers, 'vault') + def test_create_server_cert(self, _vault, process_cert_request, + get_chain, get_ca): + tls_mock = mock.MagicMock() + tls_mock.get_server_requests.return_value = { + 'keystone_0': { + 'common_name': 'public.openstack.local', + 'sans': ['10.0.0.10', 'admin.public.openstack.local']} + } + _vault.vault_ready_for_clients.return_value = True + process_cert_request.return_value = { + 'certificate': 'CERT', + 'private_key': 'KEY'} + get_ca.return_value = 'CA' + get_chain.return_value = 'CHAIN' + self.endpoint_from_flag.return_value = tls_mock + self.is_flag_set.return_value = False + handlers.create_server_cert() + process_cert_request.assert_called_once_with( + 'public.openstack.local', + ['10.0.0.10', 'admin.public.openstack.local'], + 'keystone_0', + False) + tls_mock.set_server_cert.assert_called_once_with( + 'keystone_0', + 'CERT', + 'KEY') + tls_mock.set_ca.assert_called_once_with('CA') + tls_mock.set_chain.assert_called_once_with('CHAIN') + + @mock.patch.object(handlers.vault_pki, 'get_ca') + @mock.patch.object(handlers.vault_pki, 'get_chain') + @mock.patch.object(handlers.vault_pki, 'process_cert_request') + @mock.patch.object(handlers, 'vault') + def test_create_server_cert_batch(self, _vault, process_cert_request, + get_chain, get_ca): + + def _certs(cn, ip_sans, alt_names, reissue_requested=False): + data = { + 'admin.openstack.local': { + 'certificate': 'ADMINCERT', + 'private_key': 'ADMINKEY'}, + 'public.openstack.local': { + 'certificate': 'PUBLICCERT', + 'private_key': 'PUBLICKEY'}, + 'internal.openstack.local': { + 'certificate': 'INTCERT', + 'private_key': 'INTKEY'}} + return data[cn] + + tls_mock = mock.MagicMock() + tls_mock.get_server_requests.return_value = { + 'keystone_0': { + 'common_name': 'admin.openstack.local', + 'sans': ['10.0.0.10', 'flump.openstack.local'], + 'cert_requests': { + 'public.openstack.local': { + 'sans': ['10.10.0.10', 'unit_name.openstack.local']}, + 'internal.openstack.local': { + 'sans': ['10.20.0.10']}}}} + _vault.vault_ready_for_clients.return_value = True + process_cert_request.side_effect = _certs + get_ca.return_value = 'CA' + get_chain.return_value = 'CHAIN' + create_calls = [ + mock.call( + 'admin.openstack.local', + ['10.0.0.10', 'flump.openstack.local'], + 'keystone_0', + False), + mock.call( + 'public.openstack.local', + ['10.10.0.10', 'unit_name.openstack.local'], + 'keystone_0', + False), + mock.call( + 'internal.openstack.local', + ['10.20.0.10'], + 'keystone_0', + False)] + add_server_calls = [ + mock.call( + 'keystone_0', + 'public.openstack.local', + 'PUBLICCERT', + 'PUBLICKEY'), + mock.call( + 'keystone_0', + 'internal.openstack.local', + 'INTCERT', + 'INTKEY') + ] + self.endpoint_from_flag.return_value = tls_mock + self.is_flag_set.return_value = False + handlers.create_server_cert() + print(process_cert_request.call_args_list) + process_cert_request.assert_has_calls( + create_calls, + any_order=True) + tls_mock.set_server_cert.assert_called_once_with( + 'keystone_0', + 'ADMINCERT', + 'ADMINKEY') + tls_mock.add_server_cert.assert_has_calls( + add_server_calls, + any_order=True) + tls_mock.set_ca.assert_called_once_with('CA') + tls_mock.set_chain.assert_called_once_with('CHAIN') + + @mock.patch.object(handlers, 'vault') + def test_create_server_cert_vault_not_ready(self, _vault): + _vault.vault_ready_for_clients.return_value = False + tls_mock = mock.MagicMock() + self.endpoint_from_flag.return_value = tls_mock + handlers.create_server_cert() + self.assertFalse(tls_mock.get_server_requests.called) diff --git a/unit_tests/test_utils.py b/unit_tests/test_utils.py index 15fe317..5f268d8 100644 --- a/unit_tests/test_utils.py +++ b/unit_tests/test_utils.py @@ -1,8 +1,20 @@ +import mock import unittest class CharmTestCase(unittest.TestCase): + def setUp(self): + self._patches = {} + self._patches_start = {} + + def tearDown(self): + for k, v in self._patches.items(): + v.stop() + setattr(self, k, None) + self._patches = None + self._patches_start = None + def _patch(self, method): _m = unittest.mock.patch.object(self.obj, method) mock = _m.start() @@ -12,3 +24,31 @@ class CharmTestCase(unittest.TestCase): def patch_all(self): for method in self.patches: setattr(self, method, self._patch(method)) + + def patch_object(self, obj, attr, return_value=None, name=None, new=None, + **kwargs): + if name is None: + name = attr + if new is not None: + mocked = mock.patch.object(obj, attr, new=new, **kwargs) + else: + mocked = mock.patch.object(obj, attr, **kwargs) + self._patches[name] = mocked + started = mocked.start() + if new is None: + started.return_value = return_value + self._patches_start[name] = started + setattr(self, name, started) + + def patch(self, item, return_value=None, name=None, new=None, **kwargs): + if name is None: + raise RuntimeError("Must pass 'name' to .patch()") + if new is not None: + mocked = mock.patch(item, new=new, **kwargs) + else: + mocked = mock.patch(item, **kwargs) + self._patches[name] = mocked + started = mocked.start() + if new is None: + started.return_value = return_value + self._patches_start[name] = started