diff --git a/src/lib/charm/vault_pki.py b/src/lib/charm/vault_pki.py index 0403a8e..f0a4f99 100644 --- a/src/lib/charm/vault_pki.py +++ b/src/lib/charm/vault_pki.py @@ -1,3 +1,8 @@ +import json +import re +from subprocess import check_output, CalledProcessError +from tempfile import NamedTemporaryFile + import hvac import charmhelpers.contrib.network.ip as ch_ip @@ -337,3 +342,446 @@ def update_roles(**kwargs): local.update(**kwargs) del local['server_flag'] write_roles(client, **local) + + +def is_cert_from_vault(cert, name=None): + """Return True if the cert is issued by vault and not revoked. + + Looking at the cert, check to see if it was issued by Vault and not on the + revoked list. In order to do this, the cert must be in x509 format as + openssl is used to extract the ID of the cert. Then the certificate is + extracted from vault and the signatures compared. + + :param cert: the certificate in x509 form + :type cert: str + :param name: the mount point in value, default CHARM_PKI_MP + :type name: str + :returns: True if issued by vault, False if unknown. + :raises VaultDown: if vault is down. + :raises VaultNotReady: if vault is sealed. + :raises VaultError: for any other vault issue. + """ + # first get the ID from the client + serial = get_serial_number_from_cert(cert) + if serial is None: + return False + + try: + # now get a list of serial numbers from vault. + client = vault.get_local_client() + if not name: + name = CHARM_PKI_MP + vault_certs_response = client.secrets.pki.list_certificates( + mount_point=name) + vault_certs = [k.replace('-', '').upper() + for k in vault_certs_response['data']['keys']] + + if serial not in vault_certs: + hookenv.log("Certificate with serial {} not issed by vault." + .format(serial), level=hookenv.DEBUG) + return False + revoked_serials = get_revoked_serials_from_vault(name) + if serial in revoked_serials: + hookenv.log("Serial {} is revoked.".format(serial), + level=hookenv.DEBUG) + return False + return True + except ( + vault.hvac.exceptions.InvalidPath, + vault.hvac.exceptions.InternalServerError, + vault.hvac.exceptions.VaultDown, + vault.VaultNotReady, + ): + # vault is not available for some reason, return None, None as nothing + # else is particularly useful here. + return False + except Exception as e: + hookenv.log("General failure verifying cert: {}".format(str(e)), + level=hookenv.DEBUG) + return False + + +def get_serial_number_from_cert(cert, name=None): + """Extract the serial number from the cert, or return None. + + :param cert: the certificate in x509 form + :type cert: str + :returns: the cert serial number or None. + :rtype: str | None + """ + with NamedTemporaryFile() as f: + f.write(cert.encode()) + f.flush() + command = ["openssl", "x509", "-in", f.name, "-noout", "-serial"] + try: + # output in form of 'serial=xxxxx' + output = check_output(command).decode().strip() + serial = output.split("=")[1] + return serial + except CalledProcessError as e: + hookenv.log("Couldn't process certificate: reason: {}" + .format(str(e)), + level=hookenv.DEBUG) + except (TypeError, IndexError): + hookenv.log( + "Couldn't extract serial number from passed certificate", + level=hookenv.DEBUG) + return None + + +def get_revoked_serials_from_vault(name=None): + """Get a list of revoked serial numbers from vault. + + This fetches the CRL from vault; this is in PEM format. We ought to use + python cryptography.x509.load_pem_x509_crl(), but adding cryptography + requires converting the charm to binary, and seems a lot for one function. + + Thus, the format for no certificates revoked is: + + .. code-block:: text + + Certificate Revocation List (CRL): + Version 2 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN = Vault Intermediate Certificate Authority ... + Last Update: Jul 17 11:58:57 2023 GMT + Next Update: Jul 20 11:58:57 2023 GMT + No Revoked Certificates. + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + ... + + And for two (and the pattern repeats): + + .. code-block:: text + + Certificate Revocation List (CRL): + Version 2 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN = Vault Intermediate Certificate Authority ... + Last Update: Jul 18 11:38:17 2023 GMT + Next Update: Jul 21 11:38:17 2023 GMT + Revoked Certificates: + Serial Number: 6EAE52225CB7AB452F37D4FBAC127DDF9542D3DC + Revocation Date: Jul 18 11:38:17 2023 GMT + Serial Number: 78FBEEE4E419C5A335113E4F1EF41F463534B698 + Revocation Date: Jul 18 11:33:36 2023 GMT + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + + Thus we just need to grep the output for "Serial Number:" + + :param name: the mount point in value, default CHARM_PKI_MP + :type name: str + :returns: a list of serial numbers, uppercase, no hyphens + :rtype: List[str] + :raises VaultDown: if vault is down. + :raises VaultNotReady: if vault is sealed. + :raises VaultError: for any other vault issue. + :raises subprocess.CalledProcessError: if openssl command fails + """ + client = vault.get_local_client() + revoked_certs_response = client.secrets.pki.read_crl(mount_point=name) + with NamedTemporaryFile() as f: + f.write(revoked_certs_response.encode()) + f.flush() + command = ["openssl", "crl", "-in", f.name, "-noout", "-text"] + output = check_output(command).decode().strip() + pattern = re.compile(r"Serial Number: (\S+)$") + serials = [] + # for line in output.split("\n"): + for line in output.splitlines(): + match = pattern.match(line.strip()) + if match: + serials.append(match[1]) + return serials + + +class CertCache: + """A class to store the cert and key for a request. + + This class provides a mechanism to CRUD a cached pair of (cert, key) in + storage, which is as loosely coupled to leader storage as possible. + + As the key and cert is stored in leader settings, it's available across the + units and therefore, any unit can access the key and cert for any unit that + is related to the application. + + The actually storing of the key and cert is done in as flat a way as + possible in leader-settings. This is to minimise the size of the + get and store operations for units that might have many certificate + requests. The key and cert are stored as values to a key which is + constructed from the unit_name, publish_key, common_name and item. See + PUBLISH_KEY_FORMAT for details. + + Although, it has a dependency on the request (from tls_certificates), this + was deemed acceptable to keep the interface obvious and pleasing to use. + """ + PUBLISH_KEY_FORMAT = "pki:{unit_name}:{publish_key}:{common_name}:{item}" + PUBLISH_KEY_PREFIX = "pki:{unit_name}:" + TOP_LEVEL_PUBLISH_KEY = "top_level_publish_key" + + def __init__(self, request): + """Initialise a proxy for the the cert and key in leader-settings. + + :param request: the request from which the cert/cache is cached. + :type request: tls_certificates_common.CertificateRequest + """ + self._request = request + + def _cache_key_for(self, item): + """ + Return a cache key for the request by the item. + + :param item: the item to return a key for, either 'cert' or 'key' + :type item: str + :returns: the unique key for the unit, request, and item + :rtype: str + """ + assert item in ('cert', 'key'), "Error in argument passed" + if self._request._is_top_level_server_cert: + return self.PUBLISH_KEY_FORMAT.format( + unit_name=self._request.unit_name, + publish_key=self.TOP_LEVEL_PUBLISH_KEY, + common_name=self._request.common_name, + item=item) + else: + return self.PUBLISH_KEY_FORMAT.format( + unit_name=self._request.unit_name, + publish_key=self._request._publish_key, + common_name=self._request.common_name, + item=item) + + @staticmethod + def _fetch(key): + """Fetch from the storage using a store pre key and key. + + Note the _store() method dumps it as json so it is fetched as json. + + :param key: the key value to fetch from leader settings + :type key: str + :returns: the value from leader settings or "" + :rtype: str + """ + value = hookenv.leader_get(key) + if value: + return json.loads(value) + return "" + + @staticmethod + def _store(key, value): + """Store a value by key into the actual storage. + + :param key: the key value to set in leader settings + :type key: str + :param value: the value to store. + :type value: str + :raises: RuntimeError if not the leader + :raises: TypeError if value couldn't be converted. + """ + try: + hookenv.leader_set({key: json.dumps(value)}) + except TypeError: + raise + except Exception as e: + raise RuntimeError(str(e)) + + @staticmethod + def _clear(key): + """Explicitly clear a valye in the actual storage. + + :param key: the key value to clear. + :type key: str + :raises: RuntimeError if not the leader + :raises: TypeError if value couldn't be converted. + """ + try: + hookenv.leader_set({key: None}) + except Exception as e: + raise RuntimeError(str(e)) + + def clear(self): + self._clear(self._cache_key_for('key')) + self._clear(self._cache_key_for('cert')) + + @property + def key(self): + """Get the key.""" + return self._fetch(self._cache_key_for('key')) + + @key.setter + def key(self, key_value): + """Set the key value.""" + self._store(self._cache_key_for('key'), key_value) + + @property + def cert(self): + """The the cert.""" + return self._fetch(self._cache_key_for('cert')) + + @cert.setter + def cert(self, cert_value): + """Set the cert value.""" + self._store(self._cache_key_for('cert'), cert_value) + + @classmethod + def remove_all_for(cls, unit_name): + """Remove all the cached keys for a unit name. + + This is an awkward function, as the cache in leader settings is 'flat' + to ensure that the set payloads are as small as possible. + + This iterates through all the keys and if they match the prefix for the + unit_name, it clears them. + + :param unit_name: The unit_name to clear. + :type unit_name: str + """ + prefix = cls.PUBLISH_KEY_PREFIX.format(unit_name=unit_name) + leader_keys = (cls._fetch(None) or {}).keys() + for key in leader_keys: + if key.startswith(prefix): + cls._clear(key) + + +def find_cert_in_cache(request): + """Return certificate and key from cache that match the request. + + Returned certificate is validated against the current CA cert. If CA cert + is missing then the function returns (None, None). + + If the certificate can't be found in vault, then a warning is logged, but + the cert is still returned as it is in leader_settings; the leader may + decide to remove it at a later date. + + :param request: Request for certificate from "client" unit. + :type request: tls_certificates_common.CertificateRequest + :return: Certificate and private key from cache + :rtype: (str, str) | (None, None) + """ + request_pki_cache = CertCache(request) + cert = request_pki_cache.cert + key = request_pki_cache.key + if cert is None or key is None: + return None, None + + if not is_cert_from_vault(cert, name=CHARM_PKI_MP): + hookenv.log('Certificate from cache for "{}" (cn: "{}") was not found ' + 'in vault, but is in the cache. Using, but may not be ' + 'valid.'.format(request.unit_name, request.common_name), + level=hookenv.WARNING) + return cert, key + + +def update_cert_cache(request, cert, key): + """Store certificate and key in the cache. + + Stored values are associated with the request from "client" unit, + so it can be later retrieved when the request is handled again. + + :param request: Request for certificate from "client" unit. + :type request: tls_certificates_common.CertificateRequest + :param cert: Issued certificate for the "client" request (in PEM format) + :type cert: str + :param key: Issued private key from the "client" request (in PEM format) + :type key: str + :return: None + """ + request_pki_cache = CertCache(request) + hookenv.log('Saving certificate for "{}" ' + '(cn: "{}") into cache.'.format(request.unit_name, + request.common_name), + hookenv.DEBUG) + + request_pki_cache.key = key + request_pki_cache.cert = cert + + +def remove_unit_from_cache(unit_name): + """Clear certificates and keys related to the unit from the cache. + + :param unit_name: Name of the unit to be removed from the cache. + :type unit_name: str + :return: None + """ + hookenv.log('Removing certificates for unit "{}" from ' + 'cache.'.format(unit_name), hookenv.DEBUG) + CertCache.remove_all_for(unit_name) + + +def populate_cert_cache(tls_endpoint): + """Store previously issued certificates in the cache. + + This function is used when vault charm is upgraded from older version + that may not have a certificate cache to a version that has it. It + goes through all previously issued certificates and stores them in + cache. + + :param tls_endpoint: Endpoint of "certificates" relation + :type tls_endpoint: interface_tls_certificates.provides.TlsProvides + :return: None + """ + hookenv.log( + "Populating certificate cache with data from relations", hookenv.INFO + ) + + for request in tls_endpoint.all_requests: + try: + if request._is_top_level_server_cert: + relation_data = request._unit.relation.to_publish_raw + cert = relation_data[request._server_cert_key] + key = relation_data[request._server_key_key] + else: + relation_data = request._unit.relation.to_publish + cert = relation_data[request._publish_key][ + request.common_name + ]['cert'] + key = relation_data[request._publish_key][ + request.common_name + ]['key'] + except (KeyError, TypeError): + if request._is_top_level_server_cert: + cert_id = request._server_cert_key + else: + cert_id = request.common_name + hookenv.log( + 'Certificate "{}" (or associated key) issued for unit "{}" ' + 'not found in relation data.'.format( + cert_id, request._unit.unit_name + ), + hookenv.WARNING + ) + continue + + update_cert_cache(request, cert, key) + + +def set_global_client_cert(bundle): + """Set the global cert for all units in the app. + + :param bundle: the bundle returned from generate_certificates() + :type bundle: Dict[str, str] + :raises: RuntimeError if leader_set fails. + :raises: TypeError if the bundle can't be serialised. + """ + try: + hookenv.leader_set( + {'charm.vault.global-client-cert': json.dumps(bundle)}) + except TypeError: + raise + except Exception as e: + raise RuntimeError("Couldn't run leader_settings: {}".format(str(e))) + + +def get_global_client_cert(): + """Return the bundle returned from leader_settings. + + Will return an empty dictionary if key is not present. + + :returns: the bundle previously stored, or {} + :rtype: Dict[str, str] + """ + bundle = hookenv.leader_get('charm.vault.global-client-cert') + if bundle: + return json.loads(bundle) + return {} diff --git a/src/reactive/vault_handlers.py b/src/reactive/vault_handlers.py index c6b51d7..dd2c951 100644 --- a/src/reactive/vault_handlers.py +++ b/src/reactive/vault_handlers.py @@ -35,6 +35,7 @@ from charmhelpers.core.hookenv import ( log, network_get_primary_address, open_port, + remote_unit, status_set, unit_private_ip, ) @@ -293,6 +294,11 @@ def upgrade_charm(): remove_state('vault.nrpe.configured') remove_state('vault.ssl.configured') remove_state('vault.requested-lb') + # When upgrading from version of a charm that did not have a certificate + # cache, we need to populate the cache with already issued certificates. + # Otherwise the non-leader units would not be able to sync their + # certificate data via cache. + set_flag('needs-cert-cache-repopulation') @when_not("is-update-status-hook") @@ -977,7 +983,8 @@ def publish_ca_info(): 'certificates.available') @when_not('config.changed') def publish_global_client_cert(): - """ + """publish the global certificate. + This is for backwards compatibility with older tls-certificate clients only. Obviously, it's not good security / design to have clients sharing a certificate, but it seems that there are clients that depend on this @@ -988,23 +995,104 @@ def publish_global_client_cert(): log("Vault not authorized: Skipping publish_global_client_cert", "WARNING") return - cert_created = is_flag_set('charm.vault.global-client-cert.created') reissue_requested = is_flag_set('certificates.reissue.global.requested') tls = endpoint_from_flag('certificates.available') - if not cert_created or reissue_requested: + bundle = vault_pki.get_global_client_cert() + certificate_present = "certificate" in bundle and "private_key" in bundle + if not certificate_present or reissue_requested: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] bundle = vault_pki.generate_certificate('client', 'global-client', [], ttl, max_ttl) - unitdata.kv().set('charm.vault.global-client-cert', bundle) + vault_pki.set_global_client_cert(bundle) set_flag('charm.vault.global-client-cert.created') clear_flag('certificates.reissue.global.requested') - else: - bundle = unitdata.kv().get('charm.vault.global-client-cert') + tls.set_client_cert(bundle['certificate'], bundle['private_key']) +@when_not("is-update-status-hook") +@when('certificates.available') +@when('charm.vault.ca.ready') +@when('leadership.is_leader') +@when('needs-cert-cache-repopulation') +def repopulate_cert_cache(): + """Force repopulation of cert cache on the leader. + + Certain circumstances such as 'upgrade-charm' hook should force the leader + to populate the cert cache, so then non-leaders will follow in the + 'sync_cert_from_cache' method.""" + tls = endpoint_from_flag('certificates.available') + if tls: + vault_pki.populate_cert_cache(tls) + clear_flag('needs-cert-cache-repopulation') + + +@when_not("is-update-status-hook") +@when("certificates.available") +@when_not('leadership.is_leader') +def sync_cert_from_cache(): + """Sync cert and key data in the tls-certificate relation. + + Non-leader units should keep the relation data up-to-date according + to the data from PKI cache that's maintained by the leader. This ensures + that "client" units can use data from any of the related vault units to + receive valid keys and certificates. + """ + tls = endpoint_from_flag('certificates.available') + + # propagate the ca stored by the leader if it can be obtained. + ca = vault_pki.get_ca() + if ca: + tls.set_ca(ca) + + ca_chain = None + # propagate the chain if it can be obtained. + try: + # this might fail if we were restarted and need to be unsealed + ca_chain = vault_pki.get_chain() + except ( + vault.hvac.exceptions.InvalidPath, + vault.hvac.exceptions.InternalServerError, + vault.hvac.exceptions.VaultDown, + vault.VaultNotReady, + ) as e: + log("Couldn't get the chain from vault. Reason: {}".format(str(e))) + else: + tls.set_chain(ca_chain) + + # propagate global client cert from cache + bundle = vault_pki.get_global_client_cert() + if bundle.get('certificate') and bundle.get('private_key'): + tls.set_client_cert(bundle['certificate'], bundle['private_key']) + + # update certificate data in relations + cert_requests = tls.all_requests + for request in cert_requests: + cache_cert, cache_key = vault_pki.find_cert_in_cache(request) + if cache_cert and cache_key: + request.set_cert(cache_cert, cache_key) + + +@hook('certificates-relation-departed') +def cert_client_leaving(relation): + """Remove certs and keys of the departing unit from cache. + + Note: this uses the hook as the interface code doesn't provide a mechanism + to notify departing units. + """ + if is_flag_set('leadership.is_leader'): + # Due to certificates requests replacing "/" in the unit + # name with "_" (see: tls_certificates_common.CertificateRequest), + # we must emulate the same behavior when removing unit certs from + # cache. + departing_unit = remote_unit() + log("Removing certificates for {} from cache.".format(departing_unit)) + unit_name = departing_unit.replace('/', '_') + vault_pki.remove_unit_from_cache(unit_name) + + @when_not("is-update-status-hook") @when('leadership.is_leader', 'charm.vault.ca.ready', @@ -1032,6 +1120,7 @@ def create_certs(): processed_applications.append(request.application_name) else: cert_type = request.cert_type + try: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] @@ -1039,6 +1128,9 @@ def create_certs(): request.common_name, request.sans, ttl, max_ttl) request.set_cert(bundle['certificate'], bundle['private_key']) + vault_pki.update_cert_cache(request, + bundle["certificate"], + bundle["private_key"]) except vault.VaultInvalidRequest as e: log(str(e), level=ERROR) continue # TODO: report failure back to client diff --git a/src/tests/tests.yaml b/src/tests/tests.yaml index edfcf4e..1eed128 100644 --- a/src/tests/tests.yaml +++ b/src/tests/tests.yaml @@ -31,10 +31,7 @@ target_deploy_status: tests: - zaza.openstack.charm_tests.vault.tests.VaultTest -# This second run of the tests is to ensure that Vault can handle updating the -# root CA in Vault with a refreshed CSR and won't end up in a hook-error -# state. (LP: #1866150). -- zaza.openstack.charm_tests.vault.tests.VaultTest +- zaza.openstack.charm_tests.vault.tests.VaultCacheTest tests_options: force_deploy: diff --git a/unit_tests/test_lib_charm_vault_pki.py b/unit_tests/test_lib_charm_vault_pki.py index 5b06e22..a141476 100644 --- a/unit_tests/test_lib_charm_vault_pki.py +++ b/unit_tests/test_lib_charm_vault_pki.py @@ -1,5 +1,7 @@ +import collections +import json from unittest import mock -from unittest.mock import patch +from unittest.mock import call, patch, MagicMock import hvac @@ -515,3 +517,625 @@ class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase): server_flag=False, client_flag=True), ]) + + @patch.object(vault_pki, 'get_serial_number_from_cert') + def test_is_cert_from_vault_no_serial( + self, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = None + self.assertFalse(vault_pki.is_cert_from_vault('the-cert')) + mock_get_serial_number_from_cert.assert_called_once_with('the-cert') + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_not_from_vault( + self, + mock_log, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": [] + } + } + + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + mock_get_serial_number_from_cert.assert_called_once_with('the-cert') + mock_client.secrets.pki.list_certificates.assert_called_once_with( + mount_point='a-name') + mock_log.assert_called_once_with( + "Certificate with serial 1234567890 not issed by vault.", + level=vault_pki.hookenv.DEBUG + ) + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki, 'get_revoked_serials_from_vault') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_not_revoked_serial( + self, + mock_log, + mock_get_revoked_serials_from_vault, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": ["1234567890"] + } + } + mock_get_revoked_serials_from_vault.return_value = [] + + self.assertTrue( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + + mock_get_revoked_serials_from_vault.assert_called_once_with('a-name') + mock_log.assert_not_called() + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki, 'get_revoked_serials_from_vault') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_revoked_serial( + self, + mock_log, + mock_get_revoked_serials_from_vault, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": ["12-34-56-78-90"] + } + } + mock_get_revoked_serials_from_vault.return_value = [ + "DEADBEEF", + "1234567890", + "notme", + ] + + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + + mock_log.assert_called_once_with( + "Serial 1234567890 is revoked.", level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'get_serial_number_from_cert') + @patch.object(vault_pki.vault, 'get_local_client') + @patch.object(vault_pki, 'get_revoked_serials_from_vault') + @patch.object(vault_pki.hookenv, 'log') + def test_is_cert_from_vault_raised_exceptions( + self, + mock_log, + mock_get_revoked_serials_from_vault, + mock_get_local_client, + mock_get_serial_number_from_cert, + ): + mock_get_serial_number_from_cert.return_value = "1234567890" + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.list_certificates.return_value = { + "data": { + "keys": ["12-34-56-78-90"] + } + } + mock_get_revoked_serials_from_vault.return_value = [ + "DEADBEEF", + "1234567890", + "notme", + ] + + def make_raiser(exc): + def _raiser(*args, **kwargs): + raise exc + return _raiser + + exceptions = [ + vault_pki.vault.hvac.exceptions.InvalidPath('wrong-path'), + vault_pki.vault.hvac.exceptions.InternalServerError('bang'), + vault_pki.vault.hvac.exceptions.VaultDown(), + vault_pki.vault.VaultNotReady("really-not-ready"), + ] + + for exception in exceptions: + mock_get_local_client.side_effect = make_raiser(exception) + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + mock_log.assert_not_called() + + class OtherException(Exception): + pass + + mock_get_local_client.side_effect = make_raiser( + OtherException("on noes")) + self.assertFalse( + vault_pki.is_cert_from_vault('the-cert', name='a-name')) + mock_log.assert_called_once_with( + "General failure verifying cert: on noes", + level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.hookenv, 'log') + def test_get_serial_number_from_cert( + self, + mock_log, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b" serial=12345678 " + self.assertEqual(vault_pki.get_serial_number_from_cert( + "this is a cert"), "12345678") + mock_f.write.assert_called_once_with(b"this is a cert") + mock_f.flush.assert_called_once_with() + mock_check_output.assert_called_once_with( + ['openssl', 'x509', '-in', 'filename', '-noout', '-serial']) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.hookenv, 'log') + def test_get_serial_number_from_cert_subprocess_error( + self, + mock_log, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b" serial=12345678 " + + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_check_output.side_effect = _raise + + self.assertEqual(vault_pki.get_serial_number_from_cert( + "this is a cert"), None) + + mock_log.assert_called_once_with( + "Couldn't process certificate: reason: Command 'bang' returned " + "non-zero exit status 1.", + level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.hookenv, 'log') + def test_get_serial_number_from_cert_other_error( + self, + mock_log, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b"thing" + + self.assertEqual(vault_pki.get_serial_number_from_cert( + "this is a cert"), None) + + mock_log.assert_called_once_with( + "Couldn't extract serial number from passed certificate", + level=vault_pki.hookenv.DEBUG) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_revoked_serials_from_vault_no_serials( + self, + mock_get_local_client, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = b"\n\n\n" + + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.read_crl.return_value = "the crl" + + self.assertEqual(vault_pki.get_revoked_serials_from_vault( + name=vault_pki.CHARM_PKI_MP), []) + + mock_check_output.assert_called_once_with( + ['openssl', 'crl', '-in', 'filename', '-noout', '-text']) + mock_f.write.assert_called_once_with(b"the crl") + mock_client.secrets.pki.read_crl.assert_called_once_with( + mount_point=vault_pki.CHARM_PKI_MP) + + @patch.object(vault_pki, 'check_output') + @patch.object(vault_pki, 'NamedTemporaryFile') + @patch.object(vault_pki.vault, 'get_local_client') + def test_get_revoked_serials_from_vault_some_serials( + self, + mock_get_local_client, + mock_named_temporary_file, + mock_check_output + ): + mock_f = MagicMock() + mock_f.name = "filename" + mock_named_temporary_file.return_value.__enter__.return_value = mock_f + mock_check_output.return_value = "\n".join([ + "Some interesting line", + " Serial Number: DEADBEEF", + "another interesting line.", + " and another.", + " Serial Number: 1234567890", + " and finally this one." + ]).encode() + + mock_client = MagicMock() + mock_get_local_client.return_value = mock_client + mock_client.secrets.pki.read_crl.return_value = "the crl" + + self.assertEqual(vault_pki.get_revoked_serials_from_vault( + name=vault_pki.CHARM_PKI_MP), ['DEADBEEF', '1234567890']) + + def test_certcache__init(self): + item = vault_pki.CertCache('a-request') + self.assertEqual(item._request, 'a-request') + + class ReadOnlyDict(collections.OrderedDict): + """The ReadOnly dictionary accessible via attributes.""" + + def __init__(self, data): + for k, v in data.items(): + super().__setitem__(k, v) + + def __getitem__(self, key): + return super().__getitem__(key) + + def __setattr__(self, *_): + raise TypeError("{} does not allow setting of attributes" + .format(self.__class__.__name__)) + + def __setitem__(self, *_): + raise TypeError("{} does not allow setting of items" + .format(self.__class__.__name__)) + + __getattr__ = __getitem__ + + def _default_request(self): + return self.ReadOnlyDict({ + 'unit_name': 'the-name', + '_is_top_level_server_cert': False, + '_publish_key': "subbed", + 'common_name': 'cn1' + }) + + def test_certcache__cache_key_for(self): + + request = self.ReadOnlyDict({ + 'unit_name': 'the-name', + '_is_top_level_server_cert': True, + '_publish_key': None, + 'common_name': 'cn1' + }) + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('cert'), + "pki:the-name:top_level_publish_key:cn1:cert") + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('key'), + "pki:the-name:top_level_publish_key:cn1:key") + + request = self._default_request() + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('cert'), + "pki:the-name:subbed:cn1:cert") + self.assertEqual(vault_pki.CertCache(request)._cache_key_for('key'), + "pki:the-name:subbed:cn1:key") + + with self.assertRaises(AssertionError): + vault_pki.CertCache(request)._cache_key_for('thing') + + @patch.object(vault_pki.hookenv, 'leader_get') + def test_certcache__fetch(self, mock_leader_get): + mock_leader_get.return_value = None + request = self._default_request() + self.assertEqual(vault_pki.CertCache(request)._fetch("mine"), "") + mock_leader_get.assert_called_once_with('mine') + + mock_leader_get.reset_mock() + mock_leader_get.return_value = '"the-value"' + self.assertEqual(vault_pki.CertCache(request)._fetch("mine"), + 'the-value') + + @patch.object(vault_pki.hookenv, 'leader_set') + def test_certcache__store(self, mock_leader_set): + request = self._default_request() + vault_pki.CertCache(request)._store("mine", "a value") + mock_leader_set.assert_called_once_with({"mine": '"a value"'}) + + # type error + class A: + pass + + with self.assertRaises(TypeError): + vault_pki.CertCache(request)._store("mine", A()) + + # leader-set failure (subprocess call!) + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_leader_set.side_effect = _raise + + with self.assertRaises(RuntimeError): + vault_pki.CertCache(request)._store("mine", "thing") + + @patch.object(vault_pki.hookenv, 'leader_set') + def test_certcache__clear(self, mock_leader_set): + request = self._default_request() + vault_pki.CertCache(request)._clear("mine") + mock_leader_set.assert_called_once_with({"mine": None}) + + # leader-set failure (subprocess call!) + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_leader_set.side_effect = _raise + + with self.assertRaises(RuntimeError): + vault_pki.CertCache(request)._clear("mine") + + @patch.object(vault_pki.CertCache, '_clear') + def test_certcache_clear(self, mock__clear): + request = self._default_request() + vault_pki.CertCache(request).clear() + mock__clear.assert_has_calls([ + call('pki:the-name:subbed:cn1:key'), + call('pki:the-name:subbed:cn1:cert'), + ]) + + # leader-set failure (subprocess call!) + def _raise(*args, **kwargs): + raise RuntimeError("bang") + + mock__clear.side_effect = _raise + + with self.assertRaises(RuntimeError): + vault_pki.CertCache(request).clear() + + @patch.object(vault_pki.CertCache, '_store') + @patch.object(vault_pki.CertCache, '_fetch') + @patch.object(vault_pki.CertCache, '_cache_key_for') + def test_certcache__key_property( + self, + mock__cache_key_for, + mock__fetch, + mock__store, + ): + request = self._default_request() + mock__cache_key_for.return_value = "cache-key" + mock__fetch.return_value = "the-value" + + # read + self.assertEqual(vault_pki.CertCache(request).key, "the-value") + mock__cache_key_for.assert_called_once_with('key') + mock__fetch.assert_called_once_with('cache-key') + + # write + vault_pki.CertCache(request).key = 'new-value' + mock__store.assert_called_once_with('cache-key', 'new-value') + + @patch.object(vault_pki.CertCache, '_store') + @patch.object(vault_pki.CertCache, '_fetch') + @patch.object(vault_pki.CertCache, '_cache_key_for') + def test_certcache__cert_property( + self, + mock__cache_key_for, + mock__fetch, + mock__store, + ): + request = self._default_request() + mock__cache_key_for.return_value = "cache-key" + mock__fetch.return_value = "the-value" + + # read + self.assertEqual(vault_pki.CertCache(request).cert, "the-value") + mock__cache_key_for.assert_called_once_with('cert') + mock__fetch.assert_called_once_with('cache-key') + + # write + vault_pki.CertCache(request).cert = 'new-value' + mock__store.assert_called_once_with('cache-key', 'new-value') + + @patch.object(vault_pki.CertCache, '_clear') + @patch.object(vault_pki.CertCache, '_fetch') + def test_certcache__remove_all_for( + self, + mock__fetch, + mock__clear, + ): + mock__fetch.return_value = { + 'pki:the-name:subbed:cn1:key': "thing1", + 'pki:the-name:subbed:cn1:cert': "thing2", + 'pki:the-name2:subbed:cn1:key': "thing3", + 'pki:the-name2:subbed:cn1:cert': "thing4", + } + vault_pki.CertCache.remove_all_for('the-name') + mock__clear.assert_has_calls([ + call('pki:the-name:subbed:cn1:key'), + call('pki:the-name:subbed:cn1:cert'), + ]) + + @patch.object(vault_pki.hookenv, 'log') + @patch.object(vault_pki, 'is_cert_from_vault') + @patch.object(vault_pki, 'CertCache') + def test_find_cert_in_cache(self, + mock_cert_cache, + mock_is_cert_from_vault, + mock_log): + mock_cert_cache_object = MagicMock() + mock_cert_cache_object.cert = "a-cert" + mock_cert_cache_object.key = "a-key" + mock_cert_cache.return_value = mock_cert_cache_object + mock_is_cert_from_vault.return_value = True + request = MagicMock() + + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), ("a-cert", "a-key")) + mock_cert_cache.assert_called_once_with(request) + + @patch.object(vault_pki.hookenv, 'log') + @patch.object(vault_pki, 'is_cert_from_vault') + @patch.object(vault_pki.CertCache, 'cert', new_callable=mock.PropertyMock) + @patch.object(vault_pki.CertCache, 'key', new_callable=mock.PropertyMock) + def test_find_cert_in_cache_not_found(self, + mock_key, mock_cert, + mock_is_cert_from_vault, + mock_log): + mock_cert.return_value = None + mock_key.return_value = "a-key" + mock_is_cert_from_vault.return_value = True + request = MagicMock() + + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), (None, None)) + + mock_cert.return_value = "a-cert" + mock_key.return_value = None + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), (None, None)) + + @patch.object(vault_pki.hookenv, 'log') + @patch.object(vault_pki, 'is_cert_from_vault') + @patch.object(vault_pki.CertCache, 'cert', new_callable=mock.PropertyMock) + @patch.object(vault_pki.CertCache, 'key', new_callable=mock.PropertyMock) + def test_find_cert_in_cache_not_in_vault(self, + mock_key, mock_cert, + mock_is_cert_from_vault, + mock_log): + mock_cert.return_value = "a-cert" + mock_key.return_value = "a-key" + mock_is_cert_from_vault.return_value = False + request = MagicMock() + + cert, key = vault_pki.find_cert_in_cache(request) + self.assertEqual((cert, key), ("a-cert", "a-key")) + mock_is_cert_from_vault.assert_called_once_with( + 'a-cert', name=vault_pki.CHARM_PKI_MP) + + @patch.object(vault_pki.CertCache, 'cert', new_callable=mock.PropertyMock) + @patch.object(vault_pki.CertCache, 'key', new_callable=mock.PropertyMock) + def test_update_cert_cache_top_level_cert(self, mock_key, mock_cert): + """Test storing top-level cert in cache.""" + cert_data = "cert data" + key_data = "key data" + cert_name = "server.cert" + key_name = "server.key" + client_name = "client_unit_0" + + # setup cert request + request = MagicMock() + request.unit_name = client_name + request.common_name = client_name + request._is_top_level_server_cert = True + request._server_cert_key = cert_name + request._server_key_key = key_name + + vault_pki.update_cert_cache(request, cert_data, key_data) + mock_cert.assert_called_once_with(cert_data) + mock_key.assert_called_once_with(key_data) + + @patch.object(vault_pki.CertCache, 'remove_all_for') + def test_remove_unit_from_cache(self, mock_remove_all_for): + """Test removing unit certificates from cache.""" + vault_pki.remove_unit_from_cache('client_0') + mock_remove_all_for.assert_called_once_with('client_0') + + @patch.object(vault_pki, 'update_cert_cache') + def test_populate_cert_cache(self, update_cert_cache): + # Define data for top level certificate and key + top_level_cert_name = "server.crt" + top_level_key_name = "server.key" + top_level_cert_data = "top level cert" + top_level_key_data = "top level key" + + # Define data for non-top level certificate + processed_request_cn = "juju_unit_service.crt" + processed_request_publish_key = "juju_unit_service.processed" + processed_cert_data = "processed cert" + processed_key_data = "processed key" + + # Mock request for top level certificate + top_level_request = MagicMock() + top_level_request._is_top_level_server_cert = True + top_level_request._server_cert_key = top_level_cert_name + top_level_request._server_key_key = top_level_key_name + top_level_request._unit.relation.to_publish_raw = { + top_level_cert_name: top_level_cert_data, + top_level_key_name: top_level_key_data, + } + + # Mock request for non-top level certificate + processed_request = MagicMock() + processed_request._is_top_level_server_cert = False + processed_request.common_name = processed_request_cn + processed_request._publish_key = processed_request_publish_key + processed_request._unit.relation.to_publish = { + processed_request_publish_key: {processed_request_cn: { + "cert": processed_cert_data, + "key": processed_key_data + }} + } + + tls_endpoint = MagicMock() + tls_endpoint.all_requests = [top_level_request, processed_request] + + vault_pki.populate_cert_cache(tls_endpoint) + + expected_update_calls = [ + call(top_level_request, top_level_cert_data, top_level_key_data), + call(processed_request, processed_cert_data, processed_key_data), + ] + update_cert_cache.assert_has_calls(expected_update_calls) + + @patch.object(vault_pki.hookenv, 'leader_set') + def test_set_global_client_cert(self, mock_leader_set): + bundle = { + 'key1': 'value1', + 'key2': 'value2', + } + vault_pki.set_global_client_cert(bundle) + mock_leader_set.assert_called_once_with( + {'charm.vault.global-client-cert': mock.ANY}) + v = mock_leader_set.call_args[0][0]['charm.vault.global-client-cert'] + self.assertEqual(json.loads(v), bundle) + + # Type error + class A: + pass + + with self.assertRaises(TypeError): + vault_pki.set_global_client_cert(A()) + + # leader-set error. + def _raise(*args, **kwargs): + raise vault_pki.CalledProcessError(cmd="bang", returncode=1) + + mock_leader_set.side_effect = _raise + with self.assertRaises(RuntimeError): + vault_pki.set_global_client_cert(bundle) + + @patch.object(vault_pki.hookenv, 'leader_get') + def test_get_global_client_cert(self, mock_leader_get): + mock_leader_get.return_value = '{"a":"a-value"}' + self.assertEqual(vault_pki.get_global_client_cert(), {'a': 'a-value'}) + mock_leader_get.return_value = None + self.assertEqual(vault_pki.get_global_client_cert(), {}) diff --git a/unit_tests/test_reactive_vault_handlers.py b/unit_tests/test_reactive_vault_handlers.py index 9714dff..2e4e6bc 100644 --- a/unit_tests/test_reactive_vault_handlers.py +++ b/unit_tests/test_reactive_vault_handlers.py @@ -2,6 +2,7 @@ from unittest import mock from unittest.mock import patch, call import charms.reactive +import hvac # Mock out reactive decorators prior to importing reactive.vault dec_mock = mock.MagicMock() @@ -245,6 +246,16 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): mock.call('vault.ssl.configured')] handlers.upgrade_charm() self.remove_state.assert_has_calls(calls) + self.set_flag.assert_called_once_with( + 'needs-cert-cache-repopulation') + + @mock.patch.object(handlers, 'vault_pki') + def test_repopulate_cert_cache(self, mock_vault_pki): + handlers.repopulate_cert_cache() + mock_vault_pki.populate_cert_cache.assert_called_once_with( + self.endpoint_from_flag.return_value) + self.clear_flag.assert_called_once_with( + 'needs-cert-cache-repopulation') def test_request_db(self): psql = mock.MagicMock() @@ -933,18 +944,22 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): @mock.patch.object(handlers, 'client_approle_authorized') @mock.patch.object(handlers, 'vault_pki') - def test_publish_global_client_cert_already_gend( + def test_publish_global_client_cert_already_sent( self, vault_pki, _client_approle_authorized): _client_approle_authorized.return_value = True tls = self.endpoint_from_flag.return_value - self.is_flag_set.side_effect = [True, False] - self.unitdata.kv().get.return_value = {'certificate': 'crt', - 'private_key': 'key'} + self.is_flag_set.return_value = False + vault_pki.get_global_client_cert.return_value = { + 'certificate': 'crt', + 'private_key': 'key' + } + vault_pki.generate_certificate.return_value = "bundle" + handlers.publish_global_client_cert() + assert not vault_pki.generate_certificate.called assert not self.set_flag.called - self.unitdata.kv().get.assert_called_with('charm.vault.' - 'global-client-cert') + vault_pki.set_global_client_cert.assert_not_called() tls.set_client_cert.assert_called_with('crt', 'key') @mock.patch.object(handlers, 'client_approle_authorized') @@ -957,49 +972,57 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): 'max-ttl': '3456h', } - tls = self.endpoint_from_flag.return_value + vault_pki.get_global_client_cert.return_value = { + 'certificate': 'stale_cert', + 'private_key': 'stale_key' + } - self.is_flag_set.side_effect = [True, True] + tls = self.endpoint_from_flag.return_value + # the flag for re-issue return true. + self.is_flag_set.return_value = True bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle + handlers.publish_global_client_cert() + vault_pki.generate_certificate.assert_called_with('client', 'global-client', [], '3456h', '3456h') - self.unitdata.kv().set.assert_called_with('charm.vault.' - 'global-client-cert', - bundle) + # cluster_relation.set_global_client_cert.assert_called_with(bundle) + vault_pki.set_global_client_cert.assert_called_with(bundle) + self.is_flag_set.assert_called_once_with( + 'certificates.reissue.global.requested') self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') @mock.patch.object(handlers, 'client_approle_authorized') @mock.patch.object(handlers, 'vault_pki') - def test_publish_global_client_certe( + def test_publish_global_client_cert( self, vault_pki, _client_approle_authorized): _client_approle_authorized.return_value = True self.config.return_value = { 'default-ttl': '3456h', 'max-ttl': '3456h', } - + vault_pki.generate_certificate.return_value = {} tls = self.endpoint_from_flag.return_value - self.is_flag_set.side_effect = [False, False] + self.is_flag_set.return_value = False bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle + handlers.publish_global_client_cert() + vault_pki.generate_certificate.assert_called_with('client', 'global-client', [], '3456h', '3456h') - self.unitdata.kv().set.assert_called_with('charm.vault.' - 'global-client-cert', - bundle) + vault_pki.set_global_client_cert.assert_called_with(bundle) self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') @@ -1027,6 +1050,10 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): handlers.vault.VaultInvalidRequest, {'certificate': 'crt2', 'private_key': 'key2'}, ] + expected_cache_update_calls = [ + call(tls.new_requests[0], "crt1", "key1"), + call(tls.new_requests[2], "crt2", "key2"), + ] handlers.create_certs() vault_pki.generate_certificate.assert_has_calls([ mock.call('cert_type1', 'common_name1', 'sans1', @@ -1043,6 +1070,205 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): tls.new_requests[2].set_cert.assert_has_calls([ mock.call('crt2', 'key2'), ]) + vault_pki.update_cert_cache.assert_has_calls( + expected_cache_update_calls + ) + + @mock.patch.object(handlers, 'vault_pki') + def test_create_certs_reissue(self, vault_pki): + """Test that certificates are not served from cache on reissue. + + Even when certificates are available from cache, they should not + be reused if reissue was requested. + """ + self.config.return_value = { + 'default-ttl': '3456h', + 'max-ttl': '3456h', + } + cert_cache = ( + ("common_name1_cert", "common_name1_key"), + ("common_name2_cert", "common_name2_key"), + ) + new_certs = ( + {"certificate": "cn1_new_cert", "private_key": "cn1_new_key"}, + {"certificate": "cn2_new_cert", "private_key": "cn2_new_key"}, + ) + vault_pki.find_cert_in_cache.side_effect = cert_cache + vault_pki.generate_certificate.side_effect = new_certs + + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + expected_cache_update_calls = ( + call(tls.all_requests[0], + new_certs[0]["certificate"], + new_certs[0]["private_key"]), + call(tls.all_requests[1], + new_certs[1]["certificate"], + new_certs[1]["private_key"]), + ) + + handlers.create_certs() + + vault_pki.generate_certificate.assert_has_calls([ + mock.call('cert_type1', 'common_name1', 'sans1', + '3456h', '3456h'), + mock.call('cert_type2', 'common_name2', 'sans2', + '3456h', '3456h') + ]) + + for index, request in enumerate(tls.new_requests): + request.set_cert.assert_called_once_with( + new_certs[index]["certificate"], + new_certs[index]["private_key"], + ) + vault_pki.update_cert_cache.assert_has_calls( + expected_cache_update_calls + ) + + @mock.patch.object(handlers, 'vault_pki') + @mock.patch.object(handlers, 'remote_unit') + def test_cert_client_leaving(self, remote_unit, vault_pki): + """Test that certificates are removed from cache on unit departure.""" + # This should be performed only on leader unit + self.is_flag_set.return_value = True + unit_name = "client/0" + cache_unit_id = "client_0" + remote_unit.return_value = unit_name + + handlers.cert_client_leaving(mock.MagicMock()) + + vault_pki.remove_unit_from_cache.assert_called_once_with(cache_unit_id) + + # non-leaders should not perform this action + vault_pki.remove_unit_from_cache.reset_mock() + self.is_flag_set.return_value = False + + handlers.cert_client_leaving(mock.MagicMock()) + + vault_pki.remove_unit_from_cache.assert_not_called() + + @mock.patch.object(handlers.vault_pki, 'get_global_client_cert') + @mock.patch.object(handlers.vault_pki, 'find_cert_in_cache') + @mock.patch.object(handlers.vault_pki, 'get_chain') + @mock.patch.object(handlers.vault_pki, 'get_ca') + def test_sync_cert_from_cache(self, + mock_get_ca, + mock_get_chain, + mock_find_cert_in_cache, + mock_get_global_client_cert): + """Test that non-leaders copy data from cache to relations.""" + global_client_bundle = { + "certificate": "Global client cert", + "private_key": "Global client key", + } + mock_get_global_client_cert.return_value = ( + global_client_bundle + ) + + mock_get_chain.return_value = None + + certs_in_cache = ( + ("cn1_cert", "cn1_key"), + ("cn2_cert", "cn2_key"), + ) + mock_find_cert_in_cache.side_effect = certs_in_cache + + self.is_flag_set.return_value = False + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + + handlers.sync_cert_from_cache() + + tls.set_client_cert.assert_called_once_with( + global_client_bundle["certificate"], + global_client_bundle["private_key"], + ) + + for index, request in enumerate(tls.all_requests): + request.set_cert.assert_called_once_with( + certs_in_cache[index][0], + certs_in_cache[index][1], + ) + + @mock.patch.object(handlers, 'vault_pki') + def test_sync_cert_from_cache_no_ca(self, vault_pki): + """Test that non-leaders copy data from cache to relations.""" + vault_pki.get_ca.return_value = None + + handlers.sync_cert_from_cache() + + vault_pki.get_ca.assert_called_once_with() + tls = self.endpoint_from_flag.return_value + tls.set_ca.assert_not_called() + + @mock.patch.object(handlers, 'vault_pki') + def test_sync_cert_from_cache_no_chain_err(self, vault_pki): + """Test that non-leaders copy data from cache to relations.""" + vault_pki.get_chain.side_effect = hvac.exceptions.InternalServerError + + handlers.sync_cert_from_cache() + + vault_pki.get_ca.assert_called_once_with() + tls = self.endpoint_from_flag.return_value + tls.set_ca.assert_called_once_with(vault_pki.get_ca.return_value) + vault_pki.get_chain.assert_called_once_with() + tls.set_chain.assert_not_called() + + @mock.patch.object(handlers, 'vault_pki') + @mock.patch.object(handlers, 'leader_get') + def test_sync_cert_from_cache_err(self, leader_get, vault_pki): + """Test that it gracefully fails if get_chain doesn't succeed.""" + global_client_bundle = { + "certificate": "Global client cert", + "private_key": "Global client key", + } + + certs_in_cache = ( + ("cn1_cert", "cn1_key"), + ("cn2_cert", "cn2_key"), + ) + vault_pki.get_global_client_cert.return_value = global_client_bundle + vault_pki.find_cert_in_cache.side_effect = certs_in_cache + vault_pki.get_chain.side_effect = hvac.exceptions.InvalidPath + + self.is_flag_set.return_value = False + tls = self.endpoint_from_flag.return_value + self.is_flag_set.return_value = True + tls.set_chain.assert_not_called() + tls.all_requests = [mock.Mock(cert_type='cert_type1', + common_name='common_name1', + sans='sans1'), + mock.Mock(cert_type='cert_type2', + common_name='common_name2', + sans='sans2'), + ] + + handlers.sync_cert_from_cache() + + tls.set_client_cert.assert_called_once_with( + global_client_bundle["certificate"], + global_client_bundle["private_key"], + ) + + for index, request in enumerate(tls.all_requests): + request.set_cert.assert_called_once_with( + certs_in_cache[index][0], + certs_in_cache[index][1], + ) @mock.patch.object(handlers, 'vault_pki') def test_tune_pki_backend(self, vault_pki):