From a3ff39641b0d9d4ebd58ef10a0560de5c2bcf9a7 Mon Sep 17 00:00:00 2001 From: Martin Kalcok Date: Fri, 11 Feb 2022 15:13:41 +0100 Subject: [PATCH] Implement cert cache for vault units (v4) This cache is used to store certificates and keys issued by the leader unit. Non-leader units read these certificates and keep data in their "tls-certificates" relations up to date. This ensures that charm units that receive certs from vault can read from relation data of any vault unit and receive correct data. This patch is mostly the same as I18aa6c9193379ea454851b6f60a8f331ef88a980 but improved to avoid LP#1896542 by removing the section where a certificate can be reused from cache during create_certs. Co-Authored-By: Rodrigo Barbieri Co-Authored-By: Alex Kavanagh func-test-pr: https://github.com/openstack-charmers/zaza-openstack-tests/pull/1153 Closes-Bug: #1940549 Closes-Bug: #1983269 Closes-Bug: #1845961 Related-Bug: #1896542 Change-Id: I0cca13d2042d61ffc6a7c13eccb0ec8c292020c9 (cherry picked from commit 1a1953b0ef23f724e9295505b100eca22ef9a6cd) (cherry picked from commit 56ca825332964a58961f6df3a1ca52df394f2d2c) --- src/lib/charm/vault_pki.py | 448 +++++++++++++++ src/reactive/vault_handlers.py | 104 +++- src/tests/tests.yaml | 5 +- unit_tests/test_lib_charm_vault_pki.py | 626 ++++++++++++++++++++- unit_tests/test_reactive_vault_handlers.py | 260 ++++++++- 5 files changed, 1415 insertions(+), 28 deletions(-) diff --git a/src/lib/charm/vault_pki.py b/src/lib/charm/vault_pki.py index 0403a8e..f0a4f99 100644 --- a/src/lib/charm/vault_pki.py +++ b/src/lib/charm/vault_pki.py @@ -1,3 +1,8 @@ +import json +import re +from subprocess import check_output, CalledProcessError +from tempfile import NamedTemporaryFile + import hvac import charmhelpers.contrib.network.ip as ch_ip @@ -337,3 +342,446 @@ def update_roles(**kwargs): local.update(**kwargs) del local['server_flag'] write_roles(client, **local) + + +def is_cert_from_vault(cert, name=None): + """Return True if the cert is issued by vault and not revoked. + + Looking at the cert, check to see if it was issued by Vault and not on the + revoked list. In order to do this, the cert must be in x509 format as + openssl is used to extract the ID of the cert. Then the certificate is + extracted from vault and the signatures compared. + + :param cert: the certificate in x509 form + :type cert: str + :param name: the mount point in value, default CHARM_PKI_MP + :type name: str + :returns: True if issued by vault, False if unknown. + :raises VaultDown: if vault is down. + :raises VaultNotReady: if vault is sealed. + :raises VaultError: for any other vault issue. + """ + # first get the ID from the client + serial = get_serial_number_from_cert(cert) + if serial is None: + return False + + try: + # now get a list of serial numbers from vault. + client = vault.get_local_client() + if not name: + name = CHARM_PKI_MP + vault_certs_response = client.secrets.pki.list_certificates( + mount_point=name) + vault_certs = [k.replace('-', '').upper() + for k in vault_certs_response['data']['keys']] + + if serial not in vault_certs: + hookenv.log("Certificate with serial {} not issed by vault." + .format(serial), level=hookenv.DEBUG) + return False + revoked_serials = get_revoked_serials_from_vault(name) + if serial in revoked_serials: + hookenv.log("Serial {} is revoked.".format(serial), + level=hookenv.DEBUG) + return False + return True + except ( + vault.hvac.exceptions.InvalidPath, + vault.hvac.exceptions.InternalServerError, + vault.hvac.exceptions.VaultDown, + vault.VaultNotReady, + ): + # vault is not available for some reason, return None, None as nothing + # else is particularly useful here. + return False + except Exception as e: + hookenv.log("General failure verifying cert: {}".format(str(e)), + level=hookenv.DEBUG) + return False + + +def get_serial_number_from_cert(cert, name=None): + """Extract the serial number from the cert, or return None. + + :param cert: the certificate in x509 form + :type cert: str + :returns: the cert serial number or None. + :rtype: str | None + """ + with NamedTemporaryFile() as f: + f.write(cert.encode()) + f.flush() + command = ["openssl", "x509", "-in", f.name, "-noout", "-serial"] + try: + # output in form of 'serial=xxxxx' + output = check_output(command).decode().strip() + serial = output.split("=")[1] + return serial + except CalledProcessError as e: + hookenv.log("Couldn't process certificate: reason: {}" + .format(str(e)), + level=hookenv.DEBUG) + except (TypeError, IndexError): + hookenv.log( + "Couldn't extract serial number from passed certificate", + level=hookenv.DEBUG) + return None + + +def get_revoked_serials_from_vault(name=None): + """Get a list of revoked serial numbers from vault. + + This fetches the CRL from vault; this is in PEM format. We ought to use + python cryptography.x509.load_pem_x509_crl(), but adding cryptography + requires converting the charm to binary, and seems a lot for one function. + + Thus, the format for no certificates revoked is: + + .. code-block:: text + + Certificate Revocation List (CRL): + Version 2 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN = Vault Intermediate Certificate Authority ... + Last Update: Jul 17 11:58:57 2023 GMT + Next Update: Jul 20 11:58:57 2023 GMT + No Revoked Certificates. + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + ... + + And for two (and the pattern repeats): + + .. code-block:: text + + Certificate Revocation List (CRL): + Version 2 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN = Vault Intermediate Certificate Authority ... + Last Update: Jul 18 11:38:17 2023 GMT + Next Update: Jul 21 11:38:17 2023 GMT + Revoked Certificates: + Serial Number: 6EAE52225CB7AB452F37D4FBAC127DDF9542D3DC + Revocation Date: Jul 18 11:38:17 2023 GMT + Serial Number: 78FBEEE4E419C5A335113E4F1EF41F463534B698 + Revocation Date: Jul 18 11:33:36 2023 GMT + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + + Thus we just need to grep the output for "Serial Number:" + + :param name: the mount point in value, default CHARM_PKI_MP + :type name: str + :returns: a list of serial numbers, uppercase, no hyphens + :rtype: List[str] + :raises VaultDown: if vault is down. + :raises VaultNotReady: if vault is sealed. + :raises VaultError: for any other vault issue. + :raises subprocess.CalledProcessError: if openssl command fails + """ + client = vault.get_local_client() + revoked_certs_response = client.secrets.pki.read_crl(mount_point=name) + with NamedTemporaryFile() as f: + f.write(revoked_certs_response.encode()) + f.flush() + command = ["openssl", "crl", "-in", f.name, "-noout", "-text"] + output = check_output(command).decode().strip() + pattern = re.compile(r"Serial Number: (\S+)$") + serials = [] + # for line in output.split("\n"): + for line in output.splitlines(): + match = pattern.match(line.strip()) + if match: + serials.append(match[1]) + return serials + + +class CertCache: + """A class to store the cert and key for a request. + + This class provides a mechanism to CRUD a cached pair of (cert, key) in + storage, which is as loosely coupled to leader storage as possible. + + As the key and cert is stored in leader settings, it's available across the + units and therefore, any unit can access the key and cert for any unit that + is related to the application. + + The actually storing of the key and cert is done in as flat a way as + possible in leader-settings. This is to minimise the size of the + get and store operations for units that might have many certificate + requests. The key and cert are stored as values to a key which is + constructed from the unit_name, publish_key, common_name and item. See + PUBLISH_KEY_FORMAT for details. + + Although, it has a dependency on the request (from tls_certificates), this + was deemed acceptable to keep the interface obvious and pleasing to use. + """ + PUBLISH_KEY_FORMAT = "pki:{unit_name}:{publish_key}:{common_name}:{item}" + PUBLISH_KEY_PREFIX = "pki:{unit_name}:" + TOP_LEVEL_PUBLISH_KEY = "top_level_publish_key" + + def __init__(self, request): + """Initialise a proxy for the the cert and key in leader-settings. + + :param request: the request from which the cert/cache is cached. + :type request: tls_certificates_common.CertificateRequest + """ + self._request = request + + def _cache_key_for(self, item): + """ + Return a cache key for the request by the item. + + :param item: the item to return a key for, either 'cert' or 'key' + :type item: str + :returns: the unique key for the unit, request, and item + :rtype: str + """ + assert item in ('cert', 'key'), "Error in argument passed" + if self._request._is_top_level_server_cert: + return self.PUBLISH_KEY_FORMAT.format( + unit_name=self._request.unit_name, + publish_key=self.TOP_LEVEL_PUBLISH_KEY, + common_name=self._request.common_name, + item=item) + else: + return self.PUBLISH_KEY_FORMAT.format( + unit_name=self._request.unit_name, + publish_key=self._request._publish_key, + common_name=self._request.common_name, + item=item) + + @staticmethod + def _fetch(key): + """Fetch from the storage using a store pre key and key. + + Note the _store() method dumps it as json so it is fetched as json. + + :param key: the key value to fetch from leader settings + :type key: str + :returns: the value from leader settings or "" + :rtype: str + """ + value = hookenv.leader_get(key) + if value: + return json.loads(value) + return "" + + @staticmethod + def _store(key, value): + """Store a value by key into the actual storage. + + :param key: the key value to set in leader settings + :type key: str + :param value: the value to store. + :type value: str + :raises: RuntimeError if not the leader + :raises: TypeError if value couldn't be converted. + """ + try: + hookenv.leader_set({key: json.dumps(value)}) + except TypeError: + raise + except Exception as e: + raise RuntimeError(str(e)) + + @staticmethod + def _clear(key): + """Explicitly clear a valye in the actual storage. + + :param key: the key value to clear. + :type key: str + :raises: RuntimeError if not the leader + :raises: TypeError if value couldn't be converted. + """ + try: + hookenv.leader_set({key: None}) + except Exception as e: + raise RuntimeError(str(e)) + + def clear(self): + self._clear(self._cache_key_for('key')) + self._clear(self._cache_key_for('cert')) + + @property + def key(self): + """Get the key.""" + return self._fetch(self._cache_key_for('key')) + + @key.setter + def key(self, key_value): + """Set the key value.""" + self._store(self._cache_key_for('key'), key_value) + + @property + def cert(self): + """The the cert.""" + return self._fetch(self._cache_key_for('cert')) + + @cert.setter + def cert(self, cert_value): + """Set the cert value.""" + self._store(self._cache_key_for('cert'), cert_value) + + @classmethod + def remove_all_for(cls, unit_name): + """Remove all the cached keys for a unit name. + + This is an awkward function, as the cache in leader settings is 'flat' + to ensure that the set payloads are as small as possible. + + This iterates through all the keys and if they match the prefix for the + unit_name, it clears them. + + :param unit_name: The unit_name to clear. + :type unit_name: str + """ + prefix = cls.PUBLISH_KEY_PREFIX.format(unit_name=unit_name) + leader_keys = (cls._fetch(None) or {}).keys() + for key in leader_keys: + if key.startswith(prefix): + cls._clear(key) + + +def find_cert_in_cache(request): + """Return certificate and key from cache that match the request. + + Returned certificate is validated against the current CA cert. If CA cert + is missing then the function returns (None, None). + + If the certificate can't be found in vault, then a warning is logged, but + the cert is still returned as it is in leader_settings; the leader may + decide to remove it at a later date. + + :param request: Request for certificate from "client" unit. + :type request: tls_certificates_common.CertificateRequest + :return: Certificate and private key from cache + :rtype: (str, str) | (None, None) + """ + request_pki_cache = CertCache(request) + cert = request_pki_cache.cert + key = request_pki_cache.key + if cert is None or key is None: + return None, None + + if not is_cert_from_vault(cert, name=CHARM_PKI_MP): + hookenv.log('Certificate from cache for "{}" (cn: "{}") was not found ' + 'in vault, but is in the cache. Using, but may not be ' + 'valid.'.format(request.unit_name, request.common_name), + level=hookenv.WARNING) + return cert, key + + +def update_cert_cache(request, cert, key): + """Store certificate and key in the cache. + + Stored values are associated with the request from "client" unit, + so it can be later retrieved when the request is handled again. + + :param request: Request for certificate from "client" unit. + :type request: tls_certificates_common.CertificateRequest + :param cert: Issued certificate for the "client" request (in PEM format) + :type cert: str + :param key: Issued private key from the "client" request (in PEM format) + :type key: str + :return: None + """ + request_pki_cache = CertCache(request) + hookenv.log('Saving certificate for "{}" ' + '(cn: "{}") into cache.'.format(request.unit_name, + request.common_name), + hookenv.DEBUG) + + request_pki_cache.key = key + request_pki_cache.cert = cert + + +def remove_unit_from_cache(unit_name): + """Clear certificates and keys related to the unit from the cache. + + :param unit_name: Name of the unit to be removed from the cache. + :type unit_name: str + :return: None + """ + hookenv.log('Removing certificates for unit "{}" from ' + 'cache.'.format(unit_name), hookenv.DEBUG) + CertCache.remove_all_for(unit_name) + + +def populate_cert_cache(tls_endpoint): + """Store previously issued certificates in the cache. + + This function is used when vault charm is upgraded from older version + that may not have a certificate cache to a version that has it. It + goes through all previously issued certificates and stores them in + cache. + + :param tls_endpoint: Endpoint of "certificates" relation + :type tls_endpoint: interface_tls_certificates.provides.TlsProvides + :return: None + """ + hookenv.log( + "Populating certificate cache with data from relations", hookenv.INFO + ) + + for request in tls_endpoint.all_requests: + try: + if request._is_top_level_server_cert: + relation_data = request._unit.relation.to_publish_raw + cert = relation_data[request._server_cert_key] + key = relation_data[request._server_key_key] + else: + relation_data = request._unit.relation.to_publish + cert = relation_data[request._publish_key][ + request.common_name + ]['cert'] + key = relation_data[request._publish_key][ + request.common_name + ]['key'] + except (KeyError, TypeError): + if request._is_top_level_server_cert: + cert_id = request._server_cert_key + else: + cert_id = request.common_name + hookenv.log( + 'Certificate "{}" (or associated key) issued for unit "{}" ' + 'not found in relation data.'.format( + cert_id, request._unit.unit_name + ), + hookenv.WARNING + ) + continue + + update_cert_cache(request, cert, key) + + +def set_global_client_cert(bundle): + """Set the global cert for all units in the app. + + :param bundle: the bundle returned from generate_certificates() + :type bundle: Dict[str, str] + :raises: RuntimeError if leader_set fails. + :raises: TypeError if the bundle can't be serialised. + """ + try: + hookenv.leader_set( + {'charm.vault.global-client-cert': json.dumps(bundle)}) + except TypeError: + raise + except Exception as e: + raise RuntimeError("Couldn't run leader_settings: {}".format(str(e))) + + +def get_global_client_cert(): + """Return the bundle returned from leader_settings. + + Will return an empty dictionary if key is not present. + + :returns: the bundle previously stored, or {} + :rtype: Dict[str, str] + """ + bundle = hookenv.leader_get('charm.vault.global-client-cert') + if bundle: + return json.loads(bundle) + return {} diff --git a/src/reactive/vault_handlers.py b/src/reactive/vault_handlers.py index 0d1e78b..150a412 100644 --- a/src/reactive/vault_handlers.py +++ b/src/reactive/vault_handlers.py @@ -35,6 +35,7 @@ from charmhelpers.core.hookenv import ( log, network_get_primary_address, open_port, + remote_unit, status_set, unit_private_ip, ) @@ -292,6 +293,11 @@ def upgrade_charm(): remove_state('vault.nrpe.configured') remove_state('vault.ssl.configured') remove_state('vault.requested-lb') + # When upgrading from version of a charm that did not have a certificate + # cache, we need to populate the cache with already issued certificates. + # Otherwise the non-leader units would not be able to sync their + # certificate data via cache. + set_flag('needs-cert-cache-repopulation') @when_not("is-update-status-hook") @@ -976,7 +982,8 @@ def publish_ca_info(): 'certificates.available') @when_not('config.changed') def publish_global_client_cert(): - """ + """publish the global certificate. + This is for backwards compatibility with older tls-certificate clients only. Obviously, it's not good security / design to have clients sharing a certificate, but it seems that there are clients that depend on this @@ -987,23 +994,104 @@ def publish_global_client_cert(): log("Vault not authorized: Skipping publish_global_client_cert", "WARNING") return - cert_created = is_flag_set('charm.vault.global-client-cert.created') reissue_requested = is_flag_set('certificates.reissue.global.requested') tls = endpoint_from_flag('certificates.available') - if not cert_created or reissue_requested: + bundle = vault_pki.get_global_client_cert() + certificate_present = "certificate" in bundle and "private_key" in bundle + if not certificate_present or reissue_requested: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] bundle = vault_pki.generate_certificate('client', 'global-client', [], ttl, max_ttl) - unitdata.kv().set('charm.vault.global-client-cert', bundle) + vault_pki.set_global_client_cert(bundle) set_flag('charm.vault.global-client-cert.created') clear_flag('certificates.reissue.global.requested') - else: - bundle = unitdata.kv().get('charm.vault.global-client-cert') + tls.set_client_cert(bundle['certificate'], bundle['private_key']) +@when_not("is-update-status-hook") +@when('certificates.available') +@when('charm.vault.ca.ready') +@when('leadership.is_leader') +@when('needs-cert-cache-repopulation') +def repopulate_cert_cache(): + """Force repopulation of cert cache on the leader. + + Certain circumstances such as 'upgrade-charm' hook should force the leader + to populate the cert cache, so then non-leaders will follow in the + 'sync_cert_from_cache' method.""" + tls = endpoint_from_flag('certificates.available') + if tls: + vault_pki.populate_cert_cache(tls) + clear_flag('needs-cert-cache-repopulation') + + +@when_not("is-update-status-hook") +@when("certificates.available") +@when_not('leadership.is_leader') +def sync_cert_from_cache(): + """Sync cert and key data in the tls-certificate relation. + + Non-leader units should keep the relation data up-to-date according + to the data from PKI cache that's maintained by the leader. This ensures + that "client" units can use data from any of the related vault units to + receive valid keys and certificates. + """ + tls = endpoint_from_flag('certificates.available') + + # propagate the ca stored by the leader if it can be obtained. + ca = vault_pki.get_ca() + if ca: + tls.set_ca(ca) + + ca_chain = None + # propagate the chain if it can be obtained. + try: + # this might fail if we were restarted and need to be unsealed + ca_chain = vault_pki.get_chain() + except ( + vault.hvac.exceptions.InvalidPath, + vault.hvac.exceptions.InternalServerError, + vault.hvac.exceptions.VaultDown, + vault.VaultNotReady, + ) as e: + log("Couldn't get the chain from vault. Reason: {}".format(str(e))) + else: + tls.set_chain(ca_chain) + + # propagate global client cert from cache + bundle = vault_pki.get_global_client_cert() + if bundle.get('certificate') and bundle.get('private_key'): + tls.set_client_cert(bundle['certificate'], bundle['private_key']) + + # update certificate data in relations + cert_requests = tls.all_requests + for request in cert_requests: + cache_cert, cache_key = vault_pki.find_cert_in_cache(request) + if cache_cert and cache_key: + request.set_cert(cache_cert, cache_key) + + +@hook('certificates-relation-departed') +def cert_client_leaving(relation): + """Remove certs and keys of the departing unit from cache. + + Note: this uses the hook as the interface code doesn't provide a mechanism + to notify departing units. + """ + if is_flag_set('leadership.is_leader'): + # Due to certificates requests replacing "/" in the unit + # name with "_" (see: tls_certificates_common.CertificateRequest), + # we must emulate the same behavior when removing unit certs from + # cache. + departing_unit = remote_unit() + log("Removing certificates for {} from cache.".format(departing_unit)) + unit_name = departing_unit.replace('/', '_') + vault_pki.remove_unit_from_cache(unit_name) + + @when_not("is-update-status-hook") @when('leadership.is_leader', 'charm.vault.ca.ready', @@ -1031,6 +1119,7 @@ def create_certs(): processed_applications.append(request.application_name) else: cert_type = request.cert_type + try: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] @@ -1038,6 +1127,9 @@ def create_certs(): request.common_name, request.sans, ttl, max_ttl) request.set_cert(bundle['certificate'], bundle['private_key']) + vault_pki.update_cert_cache(request, + bundle["certificate"], + bundle["private_key"]) except vault.VaultInvalidRequest as e: log(str(e), level=ERROR) continue # TODO: report failure back to client diff --git a/src/tests/tests.yaml b/src/tests/tests.yaml index edfcf4e..1eed128 100644 --- a/src/tests/tests.yaml +++ b/src/tests/tests.yaml @@ -31,10 +31,7 @@ target_deploy_status: tests: - zaza.openstack.charm_tests.vault.tests.VaultTest -# This second run of the tests is to ensure that Vault can handle updating the -# root CA in Vault with a refreshed CSR and won't end up in a hook-error -# state. (LP: #1866150). -- zaza.openstack.charm_tests.vault.tests.VaultTest +- zaza.openstack.charm_tests.vault.tests.VaultCacheTest tests_options: force_deploy: diff --git a/unit_tests/test_lib_charm_vault_pki.py b/unit_tests/test_lib_charm_vault_pki.py index 5b06e22..a141476 100644 --- a/unit_tests/test_lib_charm_vault_pki.py +++ b/unit_tests/test_lib_charm_vault_pki.py @@ -1,5 +1,7 @@ +import collections +import json from unittest import mock -from unittest.mock import patch +from unittest.mock import call, patch, MagicMock import hvac @@ -515,3 +517,625 @@ class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase): server_flag=False, client_flag=True), ]) + + @patch.object(vault_pki, 'get_serial_number_from_cert') + def test_is_cert_from_vault_no_serial( + self, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = None + self.assertFalse(vault_pki.is_cert_from_vault('the-cert')) + mock_get_serial_number_from_cert.assert_called_once_with('the-cert') + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_not_from_vault( + self, + mock_log, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": [] + } + } + + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + mock_get_serial_number_from_cert.assert_called_once_with('the-cert') + mock_client.secrets.pki.list_certificates.assert_called_once_with( + mount_point='a-name') + mock_log.assert_called_once_with( + "Certificate with serial 1234567890 not issed by vault.", + level=vault_pki.hookenv.DEBUG + ) + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki, 'get_revoked_serials_from_vault') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_not_revoked_serial( + self, + mock_log, + mock_get_revoked_serials_from_vault, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": ["1234567890"] + } + } + mock_get_revoked_serials_from_vault.return_value = [] + + self.assertTrue( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + + mock_get_revoked_serials_from_vault.assert_called_once_with('a-name') + mock_log.assert_not_called() + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki, 'get_revoked_serials_from_vault') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_revoked_serial( + self, + mock_log, + mock_get_revoked_serials_from_vault, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": ["12-34-56-78-90"] + } + } + mock_get_revoked_serials_from_vault.return_value = [ + "DEADBEEF", + "1234567890", + "notme", + ] + + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + + mock_log.assert_called_once_with( + "Serial 1234567890 is revoked.", level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki, 'get_revoked_serials_from_vault') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_raised_exceptions( + self, + mock_log, + mock_get_revoked_serials_from_vault, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": ["12-34-56-78-90"] + } + } + mock_get_revoked_serials_from_vault.return_value = [ + "DEADBEEF", + "1234567890", + "notme", + ] + + def make_raiser(exc): + def _raiser(*args, **kwargs): + raise exc + return _raiser + + exceptions = [ + vault_pki.vault.hvac.exceptions.InvalidPath('wrong-path'), + vault_pki.vault.hvac.exceptions.InternalServerError('bang'), + vault_pki.vault.hvac.exceptions.VaultDown(), + vault_pki.vault.VaultNotReady("really-not-ready"), + ] + + for exception in exceptions: + mock_get_local_client.side_effect = make_raiser(exception) + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + mock_log.assert_not_called() + + class OtherException(Exception): + pass + + mock_get_local_client.side_effect = make_raiser( + OtherException("on noes")) + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + mock_log.assert_called_once_with( + "General failure verifying cert: on noes", + level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.hookenv, 'log') + def test_get_serial_number_from_cert( + self, + mock_log, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b" serial=12345678 " + self.assertEqual(vault_pki.get_serial_number_from_cert( + "this is a cert"), "12345678") + mock_f.write.assert_called_once_with(b"this is a cert") + mock_f.flush.assert_called_once_with() + mock_check_output.assert_called_once_with( + ['openssl', 'x509', '-in', 'filename', '-noout', '-serial']) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.hookenv, 'log') + def test_get_serial_number_from_cert_subprocess_error( + self, + mock_log, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b" serial=12345678 " + + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_check_output.side_effect = _raise + + self.assertEqual(vault_pki.get_serial_number_from_cert( + "this is a cert"), None) + + mock_log.assert_called_once_with( + "Couldn't process certificate: reason: Command 'bang' returned " + "non-zero exit status 1.", + level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.hookenv, 'log') + def test_get_serial_number_from_cert_other_error( + self, + mock_log, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b"thing" + + self.assertEqual(vault_pki.get_serial_number_from_cert( + "this is a cert"), None) + + mock_log.assert_called_once_with( + "Couldn't extract serial number from passed certificate", + level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_revoked_serials_from_vault_no_serials( + self, + mock_get_local_client, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b"\n\n\n" + + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.read_crl.return_value = "the crl" + + self.assertEqual(vault_pki.get_revoked_serials_from_vault( + name=vault_pki.CHARM_PKI_MP), []) + + mock_check_output.assert_called_once_with( + ['openssl', 'crl', '-in', 'filename', '-noout', '-text']) + mock_f.write.assert_called_once_with(b"the crl") + mock_client.secrets.pki.read_crl.assert_called_once_with( + mount_point=vault_pki.CHARM_PKI_MP) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_revoked_serials_from_vault_some_serials( + self, + mock_get_local_client, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = "\n".join([ + "Some interesting line", + " Serial Number: DEADBEEF", + "another interesting line.", + " and another.", + " Serial Number: 1234567890", + " and finally this one." + ]).encode() + + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.read_crl.return_value = "the crl" + + self.assertEqual(vault_pki.get_revoked_serials_from_vault( + name=vault_pki.CHARM_PKI_MP), ['DEADBEEF', '1234567890']) + + def test_certcache__init(self): + item = vault_pki.CertCache('a-request') + self.assertEqual(item._request, 'a-request') + + class ReadOnlyDict(collections.OrderedDict): + """The ReadOnly dictionary accessible via attributes.""" + + def __init__(self, data): + for k, v in data.items(): + super().__setitem__(k, v) + + def __getitem__(self, key): + return super().__getitem__(key) + + def __setattr__(self, *_): + raise TypeError("{} does not allow setting of attributes" + .format(self.__class__.__name__)) + + def __setitem__(self, *_): + raise TypeError("{} does not allow setting of items" + .format(self.__class__.__name__)) + + __getattr__ = __getitem__ + + def _default_request(self): + return self.ReadOnlyDict({ + 'unit_name': 'the-name', + '_is_top_level_server_cert': False, + '_publish_key': "subbed", + 'common_name': 'cn1' + }) + + def test_certcache__cache_key_for(self): + + request = self.ReadOnlyDict({ + 'unit_name': 'the-name', + '_is_top_level_server_cert': True, + '_publish_key': None, + 'common_name': 'cn1' + }) + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('cert'), + "pki:the-name:top_level_publish_key:cn1:cert") + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('key'), + "pki:the-name:top_level_publish_key:cn1:key") + + request = self._default_request() + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('cert'), + "pki:the-name:subbed:cn1:cert") + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('key'), + "pki:the-name:subbed:cn1:key") + + with self.assertRaises(AssertionError): + vault_pki.CertCache(request)._cache_key_for('thing') + + @patch.object(vault_pki.hookenv, 'leader_get') + def test_certcache__fetch(self, mock_leader_get): + mock_leader_get.return_value = None + request = self._default_request() + self.assertEqual(vault_pki.CertCache(request)._fetch("mine"), "") + mock_leader_get.assert_called_once_with('mine') + + mock_leader_get.reset_mock() + mock_leader_get.return_value = '"the-value"' + self.assertEqual(vault_pki.CertCache(request)._fetch("mine"), + 'the-value') + + @patch.object(vault_pki.hookenv, 'leader_set') + def test_certcache__store(self, mock_leader_set): + request = self._default_request() + vault_pki.CertCache(request)._store("mine", "a value") + mock_leader_set.assert_called_once_with({"mine": '"a value"'}) + + # type error + class A: + pass + + with self.assertRaises(TypeError): + vault_pki.CertCache(request)._store("mine", A()) + + # leader-set failure (subprocess call!) + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_leader_set.side_effect = _raise + + with self.assertRaises(RuntimeError): + vault_pki.CertCache(request)._store("mine", "thing") + + @patch.object(vault_pki.hookenv, 'leader_set') + def test_certcache__clear(self, mock_leader_set): + request = self._default_request() + vault_pki.CertCache(request)._clear("mine") + mock_leader_set.assert_called_once_with({"mine": None}) + + # leader-set failure (subprocess call!) + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_leader_set.side_effect = _raise + + with self.assertRaises(RuntimeError): + vault_pki.CertCache(request)._clear("mine") + + @patch.object(vault_pki.CertCache, '_clear') + def test_certcache_clear(self, mock__clear): + request = self._default_request() + vault_pki.CertCache(request).clear() + mock__clear.assert_has_calls([ + call('pki:the-name:subbed:cn1:key'), + call('pki:the-name:subbed:cn1:cert'), + ]) + + # leader-set failure (subprocess call!) + def _raise(*args, **kwargs): + raise RuntimeError("bang") + + mock__clear.side_effect = _raise + + with self.assertRaises(RuntimeError): + vault_pki.CertCache(request).clear() + + @patch.object(vault_pki.CertCache, '_store') + @patch.object(vault_pki.CertCache, '_fetch') + @patch.object(vault_pki.CertCache, '_cache_key_for') + def test_certcache__key_property( + self, + mock__cache_key_for, + mock__fetch, + mock__store, + ): + request = self._default_request() + mock__cache_key_for.return_value = "cache-key" + mock__fetch.return_value = "the-value" + + # read + self.assertEqual(vault_pki.CertCache(request).key, "the-value") + mock__cache_key_for.assert_called_once_with('key') + mock__fetch.assert_called_once_with('cache-key') + + # write + vault_pki.CertCache(request).key = 'new-value' + mock__store.assert_called_once_with('cache-key', 'new-value') + + @patch.object(vault_pki.CertCache, '_store') + @patch.object(vault_pki.CertCache, '_fetch') + @patch.object(vault_pki.CertCache, '_cache_key_for') + def test_certcache__cert_property( + self, + mock__cache_key_for, + mock__fetch, + mock__store, + ): + request = self._default_request() + mock__cache_key_for.return_value = "cache-key" + mock__fetch.return_value = "the-value" + + # read + self.assertEqual(vault_pki.CertCache(request).cert, "the-value") + mock__cache_key_for.assert_called_once_with('cert') + mock__fetch.assert_called_once_with('cache-key') + + # write + vault_pki.CertCache(request).cert = 'new-value' + mock__store.assert_called_once_with('cache-key', 'new-value') + + @patch.object(vault_pki.CertCache, '_clear') + @patch.object(vault_pki.CertCache, '_fetch') + def test_certcache__remove_all_for( + self, + mock__fetch, + mock__clear, + ): + mock__fetch.return_value = { + 'pki:the-name:subbed:cn1:key': "thing1", + 'pki:the-name:subbed:cn1:cert': "thing2", + 'pki:the-name2:subbed:cn1:key': "thing3", + 'pki:the-name2:subbed:cn1:cert': "thing4", + } + vault_pki.CertCache.remove_all_for('the-name') + mock__clear.assert_has_calls([ + call('pki:the-name:subbed:cn1:key'), + call('pki:the-name:subbed:cn1:cert'), + ]) + + @patch.object(vault_pki.hookenv, 'log') + @patch.object(vault_pki, 'is_cert_from_vault') + @patch.object(vault_pki, 'CertCache') + def test_find_cert_in_cache(self, + mock_cert_cache, + mock_is_cert_from_vault, + mock_log): + mock_cert_cache_object = MagicMock() + mock_cert_cache_object.cert = "a-cert" + mock_cert_cache_object.key = "a-key" + mock_cert_cache.return_value = mock_cert_cache_object + mock_is_cert_from_vault.return_value = True + request = MagicMock() + + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), ("a-cert", "a-key")) + mock_cert_cache.assert_called_once_with(request) + + @patch.object(vault_pki.hookenv, 'log') + @patch.object(vault_pki, 'is_cert_from_vault') + @patch.object(vault_pki.CertCache, 'cert', new_callable=mock.PropertyMock) + @patch.object(vault_pki.CertCache, 'key', new_callable=mock.PropertyMock) + def test_find_cert_in_cache_not_found(self, + mock_key, mock_cert, + mock_is_cert_from_vault, + mock_log): + mock_cert.return_value = None + mock_key.return_value = "a-key" + mock_is_cert_from_vault.return_value = True + request = MagicMock() + + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), (None, None)) + + mock_cert.return_value = "a-cert" + mock_key.return_value = None + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), (None, None)) + + @patch.object(vault_pki.hookenv, 'log') + @patch.object(vault_pki, 'is_cert_from_vault') + @patch.object(vault_pki.CertCache, 'cert', new_callable=mock.PropertyMock) + @patch.object(vault_pki.CertCache, 'key', new_callable=mock.PropertyMock) + def test_find_cert_in_cache_not_in_vault(self, + mock_key, mock_cert, + mock_is_cert_from_vault, + mock_log): + mock_cert.return_value = "a-cert" + mock_key.return_value = "a-key" + mock_is_cert_from_vault.return_value = False + request = MagicMock() + + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), ("a-cert", "a-key")) + mock_is_cert_from_vault.assert_called_once_with( + 'a-cert', name=vault_pki.CHARM_PKI_MP) + + @patch.object(vault_pki.CertCache, 'cert', new_callable=mock.PropertyMock) + @patch.object(vault_pki.CertCache, 'key', new_callable=mock.PropertyMock) + def test_update_cert_cache_top_level_cert(self, mock_key, mock_cert): + """Test storing top-level cert in cache.""" + cert_data = "cert data" + key_data = "key data" + cert_name = "server.cert" + key_name = "server.key" + client_name = "client_unit_0" + + # setup cert request + request = MagicMock() + request.unit_name = client_name + request.common_name = client_name + request._is_top_level_server_cert = True + request._server_cert_key = cert_name + request._server_key_key = key_name + + vault_pki.update_cert_cache(request, cert_data, key_data) + mock_cert.assert_called_once_with(cert_data) + mock_key.assert_called_once_with(key_data) + + @patch.object(vault_pki.CertCache, 'remove_all_for') + def test_remove_unit_from_cache(self, mock_remove_all_for): + """Test removing unit certificates from cache.""" + vault_pki.remove_unit_from_cache('client_0') + mock_remove_all_for.assert_called_once_with('client_0') + + @patch.object(vault_pki, 'update_cert_cache') + def test_populate_cert_cache(self, update_cert_cache): + # Define data for top level certificate and key + top_level_cert_name = "server.crt" + top_level_key_name = "server.key" + top_level_cert_data = "top level cert" + top_level_key_data = "top level key" + + # Define data for non-top level certificate + processed_request_cn = "juju_unit_service.crt" + processed_request_publish_key = "juju_unit_service.processed" + processed_cert_data = "processed cert" + processed_key_data = "processed key" + + # Mock request for top level certificate + top_level_request = MagicMock() + top_level_request._is_top_level_server_cert = True + top_level_request._server_cert_key = top_level_cert_name + top_level_request._server_key_key = top_level_key_name + top_level_request._unit.relation.to_publish_raw = { + top_level_cert_name: top_level_cert_data, + top_level_key_name: top_level_key_data, + } + + # Mock request for non-top level certificate + processed_request = MagicMock() + processed_request._is_top_level_server_cert = False + processed_request.common_name = processed_request_cn + processed_request._publish_key = processed_request_publish_key + processed_request._unit.relation.to_publish = { + processed_request_publish_key: {processed_request_cn: { + "cert": processed_cert_data, + "key": processed_key_data + }} + } + + tls_endpoint = MagicMock() + tls_endpoint.all_requests = [top_level_request, processed_request] + + vault_pki.populate_cert_cache(tls_endpoint) + + expected_update_calls = [ + call(top_level_request, top_level_cert_data, top_level_key_data), + call(processed_request, processed_cert_data, processed_key_data), + ] + update_cert_cache.assert_has_calls(expected_update_calls) + + @patch.object(vault_pki.hookenv, 'leader_set') + def test_set_global_client_cert(self, mock_leader_set): + bundle = { + 'key1': 'value1', + 'key2': 'value2', + } + vault_pki.set_global_client_cert(bundle) + mock_leader_set.assert_called_once_with( + {'charm.vault.global-client-cert': mock.ANY}) + v = mock_leader_set.call_args[0][0]['charm.vault.global-client-cert'] + self.assertEqual(json.loads(v), bundle) + + # Type error + class A: + pass + + with self.assertRaises(TypeError): + vault_pki.set_global_client_cert(A()) + + # leader-set error. + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_leader_set.side_effect = _raise + with self.assertRaises(RuntimeError): + vault_pki.set_global_client_cert(bundle) + + @patch.object(vault_pki.hookenv, 'leader_get') + def test_get_global_client_cert(self, mock_leader_get): + mock_leader_get.return_value = '{"a":"a-value"}' + self.assertEqual(vault_pki.get_global_client_cert(), {'a': 'a-value'}) + mock_leader_get.return_value = None + self.assertEqual(vault_pki.get_global_client_cert(), {}) diff --git a/unit_tests/test_reactive_vault_handlers.py b/unit_tests/test_reactive_vault_handlers.py index 1945b08..ea6ecae 100644 --- a/unit_tests/test_reactive_vault_handlers.py +++ b/unit_tests/test_reactive_vault_handlers.py @@ -2,6 +2,7 @@ from unittest import mock from unittest.mock import patch, call import charms.reactive +import hvac # Mock out reactive decorators prior to importing reactive.vault dec_mock = mock.MagicMock() @@ -245,6 +246,16 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): mock.call('vault.ssl.configured')] handlers.upgrade_charm() self.remove_state.assert_has_calls(calls) + self.set_flag.assert_called_once_with( + 'needs-cert-cache-repopulation') + + @mock.patch.object(handlers, 'vault_pki') + def test_repopulate_cert_cache(self, mock_vault_pki): + handlers.repopulate_cert_cache() + mock_vault_pki.populate_cert_cache.assert_called_once_with( + self.endpoint_from_flag.return_value) + self.clear_flag.assert_called_once_with( + 'needs-cert-cache-repopulation') def test_request_db(self): psql = mock.MagicMock() @@ -936,18 +947,22 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): @mock.patch.object(handlers, 'client_approle_authorized') @mock.patch.object(handlers, 'vault_pki') - def test_publish_global_client_cert_already_gend( + def test_publish_global_client_cert_already_sent( self, vault_pki, _client_approle_authorized): _client_approle_authorized.return_value = True tls = self.endpoint_from_flag.return_value - self.is_flag_set.side_effect = [True, False] - self.unitdata.kv().get.return_value = {'certificate': 'crt', - 'private_key': 'key'} + self.is_flag_set.return_value = False + vault_pki.get_global_client_cert.return_value = { + 'certificate': 'crt', + 'private_key': 'key' + } + vault_pki.generate_certificate.return_value = "bundle" + handlers.publish_global_client_cert() + assert not vault_pki.generate_certificate.called assert not self.set_flag.called - self.unitdata.kv().get.assert_called_with('charm.vault.' - 'global-client-cert') + vault_pki.set_global_client_cert.assert_not_called() tls.set_client_cert.assert_called_with('crt', 'key') @mock.patch.object(handlers, 'client_approle_authorized') @@ -960,49 +975,57 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): 'max-ttl': '3456h', } - tls = self.endpoint_from_flag.return_value + vault_pki.get_global_client_cert.return_value = { + 'certificate': 'stale_cert', + 'private_key': 'stale_key' + } - self.is_flag_set.side_effect = [True, True] + tls = self.endpoint_from_flag.return_value + # the flag for re-issue return true. + self.is_flag_set.return_value = True bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle + handlers.publish_global_client_cert() + vault_pki.generate_certificate.assert_called_with('client', 'global-client', [], '3456h', '3456h') - self.unitdata.kv().set.assert_called_with('charm.vault.' - 'global-client-cert', - bundle) + # cluster_relation.set_global_client_cert.assert_called_with(bundle) + vault_pki.set_global_client_cert.assert_called_with(bundle) + self.is_flag_set.assert_called_once_with( + 'certificates.reissue.global.requested') self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') @mock.patch.object(handlers, 'client_approle_authorized') @mock.patch.object(handlers, 'vault_pki') - def test_publish_global_client_certe( + def test_publish_global_client_cert( self, vault_pki, _client_approle_authorized): _client_approle_authorized.return_value = True self.config.return_value = { 'default-ttl': '3456h', 'max-ttl': '3456h', } - + vault_pki.generate_certificate.return_value = {} tls = self.endpoint_from_flag.return_value - self.is_flag_set.side_effect = [False, False] + self.is_flag_set.return_value = False bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle + handlers.publish_global_client_cert() + vault_pki.generate_certificate.assert_called_with('client', 'global-client', [], '3456h', '3456h') - self.unitdata.kv().set.assert_called_with('charm.vault.' - 'global-client-cert', - bundle) + vault_pki.set_global_client_cert.assert_called_with(bundle) self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') @@ -1030,6 +1053,10 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): handlers.vault.VaultInvalidRequest, {'certificate': 'crt2', 'private_key': 'key2'}, ] + expected_cache_update_calls = [ + call(tls.new_requests[0], "crt1", "key1"), + call(tls.new_requests[2], "crt2", "key2"), + ] handlers.create_certs() vault_pki.generate_certificate.assert_has_calls([ mock.call('cert_type1', 'common_name1', 'sans1', @@ -1046,6 +1073,205 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): tls.new_requests[2].set_cert.assert_has_calls([ mock.call('crt2', 'key2'), ]) + vault_pki.update_cert_cache.assert_has_calls( + expected_cache_update_calls + ) + + @mock.patch.object(handlers, 'vault_pki') + def test_create_certs_reissue(self, vault_pki): + """Test that certificates are not served from cache on reissue. + + Even when certificates are available from cache, they should not + be reused if reissue was requested. + """ + self.config.return_value = { + 'default-ttl': '3456h', + 'max-ttl': '3456h', + } + cert_cache = ( + ("common_name1_cert", "common_name1_key"), + ("common_name2_cert", "common_name2_key"), + ) + new_certs = ( + {"certificate": "cn1_new_cert", "private_key": "cn1_new_key"}, + {"certificate": "cn2_new_cert", "private_key": "cn2_new_key"}, + ) + vault_pki.find_cert_in_cache.side_effect = cert_cache + vault_pki.generate_certificate.side_effect = new_certs + + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + expected_cache_update_calls = ( + call(tls.all_requests[0], + new_certs[0]["certificate"], + new_certs[0]["private_key"]), + call(tls.all_requests[1], + new_certs[1]["certificate"], + new_certs[1]["private_key"]), + ) + + handlers.create_certs() + + vault_pki.generate_certificate.assert_has_calls([ + mock.call('cert_type1', 'common_name1', 'sans1', + '3456h', '3456h'), + mock.call('cert_type2', 'common_name2', 'sans2', + '3456h', '3456h') + ]) + + for index, request in enumerate(tls.new_requests): + request.set_cert.assert_called_once_with( + new_certs[index]["certificate"], + new_certs[index]["private_key"], + ) + vault_pki.update_cert_cache.assert_has_calls( + expected_cache_update_calls + ) + + @mock.patch.object(handlers, 'vault_pki') + @mock.patch.object(handlers, 'remote_unit') + def test_cert_client_leaving(self, remote_unit, vault_pki): + """Test that certificates are removed from cache on unit departure.""" + # This should be performed only on leader unit + self.is_flag_set.return_value = True + unit_name = "client/0" + cache_unit_id = "client_0" + remote_unit.return_value = unit_name + + handlers.cert_client_leaving(mock.MagicMock()) + + vault_pki.remove_unit_from_cache.assert_called_once_with(cache_unit_id) + + # non-leaders should not perform this action + vault_pki.remove_unit_from_cache.reset_mock() + self.is_flag_set.return_value = False + + handlers.cert_client_leaving(mock.MagicMock()) + + vault_pki.remove_unit_from_cache.assert_not_called() + + @mock.patch.object(handlers.vault_pki, 'get_global_client_cert') + @mock.patch.object(handlers.vault_pki, 'find_cert_in_cache') + @mock.patch.object(handlers.vault_pki, 'get_chain') + @mock.patch.object(handlers.vault_pki, 'get_ca') + def test_sync_cert_from_cache(self, + mock_get_ca, + mock_get_chain, + mock_find_cert_in_cache, + mock_get_global_client_cert): + """Test that non-leaders copy data from cache to relations.""" + global_client_bundle = { + "certificate": "Global client cert", + "private_key": "Global client key", + } + mock_get_global_client_cert.return_value = ( + global_client_bundle + ) + + mock_get_chain.return_value = None + + certs_in_cache = ( + ("cn1_cert", "cn1_key"), + ("cn2_cert", "cn2_key"), + ) + mock_find_cert_in_cache.side_effect = certs_in_cache + + self.is_flag_set.return_value = False + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + + handlers.sync_cert_from_cache() + + tls.set_client_cert.assert_called_once_with( + global_client_bundle["certificate"], + global_client_bundle["private_key"], + ) + + for index, request in enumerate(tls.all_requests): + request.set_cert.assert_called_once_with( + certs_in_cache[index][0], + certs_in_cache[index][1], + ) + + @mock.patch.object(handlers, 'vault_pki') + def test_sync_cert_from_cache_no_ca(self, vault_pki): + """Test that non-leaders copy data from cache to relations.""" + vault_pki.get_ca.return_value = None + + handlers.sync_cert_from_cache() + + vault_pki.get_ca.assert_called_once_with() + tls = self.endpoint_from_flag.return_value + tls.set_ca.assert_not_called() + + @mock.patch.object(handlers, 'vault_pki') + def test_sync_cert_from_cache_no_chain_err(self, vault_pki): + """Test that non-leaders copy data from cache to relations.""" + vault_pki.get_chain.side_effect = hvac.exceptions.InternalServerError + + handlers.sync_cert_from_cache() + + vault_pki.get_ca.assert_called_once_with() + tls = self.endpoint_from_flag.return_value + tls.set_ca.assert_called_once_with(vault_pki.get_ca.return_value) + vault_pki.get_chain.assert_called_once_with() + tls.set_chain.assert_not_called() + + @mock.patch.object(handlers, 'vault_pki') + @mock.patch.object(handlers, 'leader_get') + def test_sync_cert_from_cache_err(self, leader_get, vault_pki): + """Test that it gracefully fails if get_chain doesn't succeed.""" + global_client_bundle = { + "certificate": "Global client cert", + "private_key": "Global client key", + } + + certs_in_cache = ( + ("cn1_cert", "cn1_key"), + ("cn2_cert", "cn2_key"), + ) + vault_pki.get_global_client_cert.return_value = global_client_bundle + vault_pki.find_cert_in_cache.side_effect = certs_in_cache + vault_pki.get_chain.side_effect = hvac.exceptions.InvalidPath + + self.is_flag_set.return_value = False + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.set_chain.assert_not_called() + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + + handlers.sync_cert_from_cache() + + tls.set_client_cert.assert_called_once_with( + global_client_bundle["certificate"], + global_client_bundle["private_key"], + ) + + for index, request in enumerate(tls.all_requests): + request.set_cert.assert_called_once_with( + certs_in_cache[index][0], + certs_in_cache[index][1], + ) @mock.patch.object(handlers, 'vault_pki') def test_tune_pki_backend(self, vault_pki):