From 38e00f460d6248627b49a53916fad76769a9a6b3 Mon Sep 17 00:00:00 2001 From: Alex Kavanagh Date: Fri, 14 Apr 2023 18:03:42 +0000 Subject: [PATCH] Revert "Implement cert cache for vault units (v3)" This reverts commit 04a237660b0e1aaa8d35f7c110c8f4fa2c38621d. Reason for revert: The bug in [1] caused all the yoga tests to fail in integration testing. Testing with a version of the charm without this commit allowed tests to complete. Thus reverting this until a more complete solution can be found to the original bug(s) [2..4] [1] https://bugs.launchpad.net/charm-keystone/+bug/2015103 [2] LP #1940549 [3] LP #1983269 [4] LP #1845961 Change-Id: I8a794fbb30e921e5322e9023b891d5e17e0e6e8b --- osci.yaml | 2 +- src/layer.yaml | 1 - src/lib/charm/vault_pki.py | 201 ------------ src/reactive/vault_handlers.py | 110 +------ unit_tests/test_lib_charm_vault_pki.py | 342 +-------------------- unit_tests/test_reactive_vault_handlers.py | 282 +---------------- 6 files changed, 22 insertions(+), 916 deletions(-) diff --git a/osci.yaml b/osci.yaml index 162ccb7..45c085c 100644 --- a/osci.yaml +++ b/osci.yaml @@ -16,7 +16,7 @@ needs_charm_build: true charm_build_name: vault build_type: charmcraft - charmcraft_channel: 2.1/stable + charmcraft_channel: 2.0/stable - job: name: jammy-mysql8 diff --git a/src/layer.yaml b/src/layer.yaml index b459cca..035e028 100644 --- a/src/layer.yaml +++ b/src/layer.yaml @@ -9,7 +9,6 @@ includes: - interface:hacluster - interface:vault-kv - interface:tls-certificates - - interface:vault-ha options: basic: use_venv: True diff --git a/src/lib/charm/vault_pki.py b/src/lib/charm/vault_pki.py index 49c343a..4a939ab 100644 --- a/src/lib/charm/vault_pki.py +++ b/src/lib/charm/vault_pki.py @@ -1,11 +1,7 @@ import hvac -from subprocess import check_output, CalledProcessError -from tempfile import NamedTemporaryFile - import charmhelpers.contrib.network.ip as ch_ip import charmhelpers.core.hookenv as hookenv -from charms.reactive.relations import endpoint_from_name from . import vault @@ -13,9 +9,6 @@ CHARM_PKI_MP = "charm-pki-local" CHARM_PKI_ROLE = "local" CHARM_PKI_ROLE_CLIENT = "local-client" -PKI_CACHE_KEY = "pki" -TOP_LEVEL_CERT_KEY = "top_level" - def configure_pki_backend(client, name, ttl=None, max_ttl=None): """Ensure a pki backend is enabled @@ -377,197 +370,3 @@ def update_roles(**kwargs): local.update(**kwargs) del local['server_flag'] write_roles(client, **local) - - -def verify_cert(ca_cert, untrusted_cert): - """Verify that the 'untrusted_cert' is signed by the 'ca_cert'. - - :param ca_cert: CA certificate that should sign the untrusted cert. - :param untrusted_cert: Certificate that is verified by the CA cert. - :return: True if CA cert can verify the untrusted cert - :rtype: bool - """ - with NamedTemporaryFile() as ca_file, NamedTemporaryFile() as cert_file: - ca_file.write(ca_cert.encode("UTF-8")) - ca_file.flush() - - cert_file.write(untrusted_cert.encode("UTF-8")) - cert_file.flush() - - try: - verify_cmd = ['openssl', 'verify', '-CAfile', - ca_file.name, cert_file.name] - check_output(verify_cmd) - except CalledProcessError as exc: - hookenv.log( - "Certificate verification failed: {}".format(exc.output), - hookenv.WARNING - ) - return False - else: - return True - - -def get_pki_cache(unit_name): - """Fetch and parse PKI from the leader storage. - - Returned dictionary contains certificates and keys issued by the vault - leader unit as a response to requests from other charms. The structure - loosely matches the format in which the certificates are shared via data - in the `tls-certificates` relation. - See `tls_certificates_common.CertificateRequest.set_cert()` for more info - on the structure. - - :return: Dictionary containing certs and keys generated by the leader unit - :rtype: dict - """ - unit_pki_cache_key = "{}_{}".format(PKI_CACHE_KEY, unit_name) - cluster = endpoint_from_name('cluster') - return cluster.get_unit_pki(unit_pki_cache_key) - - -def find_cert_in_cache(request): - """Return certificate and key from cache that match the request. - - Returned certificate is validated against the current CA cert. If CA cert - is missing, or certificate fails validation or it's simply not found, - returned value is None, None - - :param request: Request for certificate from "client" unit. - :type request: tls_certificates_common.CertificateRequest - :return: Certificate and private key from cache - :rtype: Union[(str, str), (None, None)] - """ - try: - ca_chain = get_chain() - except (hvac.exceptions.VaultDown, hvac.exceptions.InvalidPath): - # Fetching CA chain may fail - ca_chain = None - - ca_cert = ca_chain or get_ca() - if not ca_cert: - hookenv.log('CA cert not found. Skipping certificate cache lookup.', - hookenv.DEBUG) - return None, None - - unit_data = get_pki_cache(request.unit_name) - - try: - if request._is_top_level_server_cert: - cert = unit_data[TOP_LEVEL_CERT_KEY][request._server_cert_key] - key = unit_data[TOP_LEVEL_CERT_KEY][request._server_key_key] - else: - cert = unit_data[request._publish_key][request.common_name]['cert'] - key = unit_data[request._publish_key][request.common_name]['key'] - except (KeyError, TypeError): - hookenv.log('Certificate for "{}" (cn: "{}") not found in ' - 'cache.'.format(request.unit_name, request.common_name), - hookenv.DEBUG) - return None, None - - if verify_cert(ca_cert, cert): - return cert, key - else: - hookenv.log('Certificate from cache for "{}" (cn: "{}") is no longer' - 'valid and wont be reused.'.format(request.unit_name, - request.common_name)) - return None, None - - -def update_cert_cache(request, cert, key): - """Store certificate and key in the cache. - - Stored values are associated with the request from "client" unit, - so it can be later retrieved when the request is handled again. - - :param request: Request for certificate from "client" unit. - :type request: tls_certificates_common.CertificateRequest - :param cert: Issued certificate for the "client" request (in PEM format) - :type cert: str - :param key: Issued private key from the "client" request (in PEM format) - :type key: str - :return: None - """ - unit_cache = get_pki_cache(request.unit_name) - - if request._is_top_level_server_cert: - unit_cache[TOP_LEVEL_CERT_KEY] = { - request._server_cert_key: cert, - request._server_key_key: key, - } - else: - structured_certs = unit_cache.get(request._publish_key, {}) - structured_certs[request.common_name] = { - 'cert': cert, - 'key': key, - } - unit_cache[request._publish_key] = structured_certs - - hookenv.log('Saving certificate for "{}" ' - '(cn: "{}") into cache.'.format(request.unit_name, - request.common_name), - hookenv.DEBUG) - unit_pki_cache_key = "{}_{}".format(PKI_CACHE_KEY, request.unit_name) - cluster = endpoint_from_name('cluster') - cluster.set_unit_pki(unit_pki_cache_key, unit_cache) - - -def remove_unit_from_cache(unit_name): - """Clear certificates and keys related to the unit from the cache. - - :param unit_name: Name of the unit to be removed from the cache. - :type unit_name: str - :return: None - """ - hookenv.log('Removing certificates for unit "{}" from ' - 'cache.'.format(unit_name), hookenv.DEBUG) - unit_pki_cache_key = "{}_{}".format(PKI_CACHE_KEY, unit_name) - cluster = endpoint_from_name('cluster') - cluster.set_unit_pki(unit_pki_cache_key, None) - - -def populate_cert_cache(tls_endpoint): - """Store previously issued certificates in the cache. - - This function is used when vault charm is upgraded from older version - that may not have a certificate cache to a version that has it. It - goes through all previously issued certificates and stores them in - cache. - - :param tls_endpoint: Endpoint of "certificates" relation - :type tls_endpoint: interface_tls_certificates.provides.TlsProvides - :return: None - """ - hookenv.log( - "Populating certificate cache with data from relations", hookenv.INFO - ) - - for request in tls_endpoint.all_requests: - try: - if request._is_top_level_server_cert: - relation_data = request._unit.relation.to_publish_raw - cert = relation_data[request._server_cert_key] - key = relation_data[request._server_key_key] - else: - relation_data = request._unit.relation.to_publish - cert = relation_data[request._publish_key][ - request.common_name - ]['cert'] - key = relation_data[request._publish_key][ - request.common_name - ]['key'] - except (KeyError, TypeError): - if request._is_top_level_server_cert: - cert_id = request._server_cert_key - else: - cert_id = request.common_name - hookenv.log( - 'Certificate "{}" (or associated key) issued for unit "{}" ' - 'not found in relation data.'.format( - cert_id, request._unit.unit_name - ), - hookenv.WARNING - ) - continue - - update_cert_cache(request, cert, key) diff --git a/src/reactive/vault_handlers.py b/src/reactive/vault_handlers.py index 1406237..13b3a84 100644 --- a/src/reactive/vault_handlers.py +++ b/src/reactive/vault_handlers.py @@ -38,7 +38,6 @@ from charmhelpers.core.hookenv import ( log, network_get_primary_address, open_port, - remote_unit, status_set, unit_private_ip, ) @@ -343,11 +342,6 @@ def upgrade_charm(): remove_state('vault.nrpe.configured') remove_state('vault.ssl.configured') remove_state('vault.requested-lb') - # mkalcok: When upgrading from version of a charm that did not have a - # certificate cache, we need to populate the cache with already issued - # certificates. Otherwise the non-leader units would not be able to sync - # their certificate data via cache. - set_flag('needs-cert-cache-repopulation') @when_not("is-update-status-hook") @@ -1081,102 +1075,21 @@ def publish_global_client_cert(): log("Vault not authorized: Skipping publish_global_client_cert", "WARNING") return + cert_created = is_flag_set('charm.vault.global-client-cert.created') reissue_requested = is_flag_set('certificates.reissue.global.requested') tls = endpoint_from_flag('certificates.available') - cluster = endpoint_from_name('cluster') - bundle = cluster.get_global_client_cert() - certificate_present = "certificate" in bundle and "private_key" in bundle - if not certificate_present or reissue_requested: + if not cert_created or reissue_requested: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] bundle = vault_pki.generate_certificate('client', 'global-client', [], ttl, max_ttl) - cluster.set_global_client_cert(bundle) + unitdata.kv().set('charm.vault.global-client-cert', bundle) set_flag('charm.vault.global-client-cert.created') clear_flag('certificates.reissue.global.requested') - - tls.set_client_cert(bundle['certificate'], bundle['private_key']) - - -@when('certificates.available') -@when('charm.vault.ca.ready') -@when('leadership.is_leader') -@when('needs-cert-cache-repopulation') -def repopulate_cert_cache(): - """Force repopulation of cert cache on the leader. - - Certain circumstances such as 'upgrade-charm' hook should force the leader - to populate the cert cache, so then non-leaders will follow in the - 'sync_cert_from_cache' method.""" - tls = endpoint_from_flag('certificates.available') - if tls: - vault_pki.populate_cert_cache(tls) - clear_flag('needs-cert-cache-repopulation') - - -@when("certificates.available") -@when_not('leadership.is_leader') -def sync_cert_from_cache(): - """Sync cert and key data in the tls-certificate relation. - - Non-leader units should keep the relation data up-to-date according - to the data from PKI cache that's maintained by the leader. This ensures - that "client" units can use data from any of the related vault units to - receive valid keys and certificates. - """ - tls = endpoint_from_flag('certificates.available') - cert_requests = tls.all_requests - ca = vault_pki.get_ca() - if not ca: - # Don't bother syncing now if we are in a state - # which we don't have a CA, defer syncing to later - return - # propagate CA cert - tls.set_ca(ca) - try: - # this might fail if we were restarted and need to be unsealed - chain = vault_pki.get_chain() - except ( - vault.hvac.exceptions.VaultDown, - vault.hvac.exceptions.InvalidPath, - ): - pass - except vault.VaultNotReady: - # With Vault not being ready, there's no sense in continuing - return - except vault.hvac.exceptions.InternalServerError: - # We either cannot communicate with Vault or - # parse a CA/Chain in this state, defer syncing to later - return else: - tls.set_chain(chain) - - # propagate global client cert from cache - cluster = endpoint_from_name('cluster') - bundle = cluster.get_global_client_cert() - if bundle.get('certificate') and bundle.get('private_key'): - tls.set_client_cert(bundle['certificate'], bundle['private_key']) - - # update certificate data in relations - for request in cert_requests: - cache_cert, cache_key = vault_pki.find_cert_in_cache(request) - if cache_cert and cache_key: - request.set_cert(cache_cert, cache_key) - - -@hook('certificates-relation-departed') -def cert_client_leaving(relation): - """Remove certs and keys of the departing unit from cache.""" - if is_flag_set('leadership.is_leader'): - # mkalcok: Due to certificates requests replacing "/" in the unit - # name with "_" (see: tls_certificates_common.CertificateRequest), - # we must emulate the same behavior when removing unit certs from - # cache. - departing_unit = remote_unit() - log("Removing certificates for {} from cache.".format(departing_unit)) - unit_name = departing_unit.replace('/', '_') - vault_pki.remove_unit_from_cache(unit_name) + bundle = unitdata.kv().get('charm.vault.global-client-cert') + tls.set_client_cert(bundle['certificate'], bundle['private_key']) @when_not("is-update-status-hook") @@ -1206,16 +1119,6 @@ def create_certs(): processed_applications.append(request.application_name) else: cert_type = request.cert_type - - cache_cert, cache_key = vault_pki.find_cert_in_cache(request) - if not reissue_requested and cache_cert and cache_key: - # If valid certificates are in cache, and re-issue was not - # requested, reuse them. - log("Reusing certificate for unit '{}' and CN '{}' from " - "cache.".format(request.unit_name, request.common_name)) - request.set_cert(cache_cert, cache_key) - continue - try: ttl = config()['default-ttl'] max_ttl = config()['max-ttl'] @@ -1223,9 +1126,6 @@ def create_certs(): request.common_name, request.sans, ttl, max_ttl) request.set_cert(bundle['certificate'], bundle['private_key']) - vault_pki.update_cert_cache(request, - bundle["certificate"], - bundle["private_key"]) except vault.VaultInvalidRequest as e: log(str(e), level=ERROR) continue # TODO: report failure back to client diff --git a/unit_tests/test_lib_charm_vault_pki.py b/unit_tests/test_lib_charm_vault_pki.py index 24a6970..624d431 100644 --- a/unit_tests/test_lib_charm_vault_pki.py +++ b/unit_tests/test_lib_charm_vault_pki.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import call, patch, MagicMock +from unittest.mock import patch import hvac @@ -12,9 +12,7 @@ class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase): def setUp(self): super(TestLibCharmVaultPKI, self).setUp() self.obj = vault_pki - self.patches = [ - 'endpoint_from_name', - ] + self.patches = [] self.patch_all() @patch.object(vault_pki.vault, 'is_backend_mounted') @@ -461,339 +459,3 @@ class TestLibCharmVaultPKI(unit_tests.test_utils.CharmTestCase): client_flag=True) ), ]) - - def test_get_pki_cache(self): - """Test retrieving PKI from cache.""" - expected_pki = { - vault_pki.TOP_LEVEL_CERT_KEY: { - "client_unit_0.server.cert": "cert_data", - "client_unit_0.server.key": "key_data", - } - } - cluster_relation = MagicMock() - self.endpoint_from_name.return_value = cluster_relation - cluster_relation.get_unit_pki.return_value = expected_pki - - pki = vault_pki.get_pki_cache('client_unit_0') - cluster_relation.get_unit_pki.assert_called_once_with( - 'pki_client_unit_0') - self.assertEqual(pki, expected_pki) - - # test retrieval if the PKI is not set - cluster_relation.get_unit_pki.return_value = {} - cluster_relation.get_unit_pki.reset_mock() - - pki = vault_pki.get_pki_cache('client_unit_0') - cluster_relation.get_unit_pki.assert_called_once_with( - 'pki_client_unit_0') - self.assertEqual(pki, {}) - - @patch.object(vault_pki, 'get_pki_cache') - @patch.object(vault_pki, 'get_chain') - @patch.object(vault_pki, 'get_ca') - def test_find_cert_in_cache_no_ca(self, get_ca, get_chain, get_pki_cache): - """Test getting cert from cache when CA is missing.""" - get_ca.return_value = None - get_chain.return_value = None - - cert, key = vault_pki.find_cert_in_cache(MagicMock()) - - # assert that CA cert or chain was retrieved - get_ca.assert_called_once_with() - get_chain.assert_called_once_with() - # assert that function does not proceed due to the missing CA - get_pki_cache.assert_not_called() - - self.assertIsNone(cert) - self.assertIsNone(key) - - @patch.object(vault_pki, 'verify_cert') - @patch.object(vault_pki, 'get_pki_cache') - @patch.object(vault_pki, 'get_chain') - @patch.object(vault_pki, 'get_ca') - def test_find_cert_in_cache_missing(self, get_ca, get_chain, - get_pki_cache, verify_cache): - """Test use case when searched certificate is not in cache.""" - request = MagicMock() - request.unit_name = "client_unit_0" - request._is_top_level_server_cert = True - - get_ca.return_value = MagicMock() - get_pki_cache.return_value = {} - - cert, key = vault_pki.find_cert_in_cache(request) - - # assert that verification of cert is not attempted when - # cert is not found - verify_cache.assert_not_called() - - self.assertIsNone(cert) - self.assertIsNone(key) - - # Same scenario, but with non-top-level certificate - request._is_top_level_server_cert = False - - cert, key = vault_pki.find_cert_in_cache(request) - - verify_cache.assert_not_called() - self.assertIsNone(cert) - self.assertIsNone(key) - - @patch.object(vault_pki, 'get_pki_cache') - @patch.object(vault_pki, 'get_chain') - @patch.object(vault_pki, 'get_ca') - def test_find_cert_in_cache_err(self, get_ca, get_chain, get_pki_cache): - """Test getting cert from cache when CA is missing.""" - get_ca.return_value = None - get_chain.side_effect = hvac.exceptions.InvalidPath - - cert, key = vault_pki.find_cert_in_cache(MagicMock()) - - # assert that CA cert or chain was retrieved - get_ca.assert_called_once_with() - get_chain.assert_called_once_with() - # assert that function does not proceed due to the missing CA - get_pki_cache.assert_not_called() - - self.assertIsNone(cert) - self.assertIsNone(key) - - @patch.object(vault_pki, 'verify_cert') - @patch.object(vault_pki, 'get_pki_cache') - @patch.object(vault_pki, 'get_chain') - @patch.object(vault_pki, 'get_ca') - def test_find_cert_in_cache_top_level(self, get_ca, get_chain, - get_pki_cache, verify_cache): - """Test fetching top level cert from cache. - - Additional test scenario: Test that nothing is returned if cert fails - CA verification. - """ - ca_cert = "CA cert data" - expected_cert = "cert data" - expected_key = "key data" - cert_name = "server.cert" - key_name = "server.key" - client_name = "client_unit_0" - - # setup cert request - request = MagicMock() - request.unit_name = client_name - request._is_top_level_server_cert = True - request._server_cert_key = cert_name - request._server_key_key = key_name - - # PKI cache content - pki = { - vault_pki.TOP_LEVEL_CERT_KEY: { - cert_name: expected_cert, - key_name: expected_key - } - } - - get_ca.return_value = ca_cert - get_chain.return_value = ca_cert - get_pki_cache.return_value = pki - verify_cache.return_value = True - - cert, key = vault_pki.find_cert_in_cache(request) - - verify_cache.assert_called_once_with(ca_cert, expected_cert) - self.assertEqual(cert, expected_cert) - self.assertEqual(key, expected_key) - - # Additional test: Nothing should be returned if cert failed - # CA verification. - verify_cache.reset_mock() - verify_cache.return_value = False - - cert, key = vault_pki.find_cert_in_cache(request) - - verify_cache.assert_called_once_with(ca_cert, expected_cert) - self.assertIsNone(cert) - self.assertIsNone(key) - - @patch.object(vault_pki, 'verify_cert') - @patch.object(vault_pki, 'get_pki_cache') - @patch.object(vault_pki, 'get_chain') - @patch.object(vault_pki, 'get_ca') - def test_find_cert_in_cache_not_top_level(self, get_ca, get_chain, - get_pki_cache, verify_cache): - """Test fetching non-top level cert from cache. - - Additional test scenario: Test that nothing is returned if cert fails - CA verification. - """ - ca_cert = "CA cert data" - expected_cert = "cert data" - expected_key = "key data" - client_name = "client_unit_0" - publish_key = client_name + ".processed_client_requests" - common_name = "client.0" - - # setup cert request - request = MagicMock() - request.unit_name = client_name - request._is_top_level_server_cert = False - request._publish_key = publish_key - request.common_name = common_name - - # PKI cache content - pki = { - publish_key: { - common_name: { - "cert": expected_cert, - "key": expected_key, - } - } - } - - get_ca.return_value = ca_cert - get_chain.return_value = ca_cert - get_pki_cache.return_value = pki - verify_cache.return_value = True - - cert, key = vault_pki.find_cert_in_cache(request) - - verify_cache.assert_called_once_with(ca_cert, expected_cert) - self.assertEqual(cert, expected_cert) - self.assertEqual(key, expected_key) - - # Additional test: Nothing should be returned if cert failed - # CA verification. - verify_cache.reset_mock() - verify_cache.return_value = False - - cert, key = vault_pki.find_cert_in_cache(request) - - verify_cache.assert_called_once_with(ca_cert, expected_cert) - self.assertIsNone(cert) - self.assertIsNone(key) - - @patch.object(vault_pki, 'get_pki_cache') - def test_update_cert_cache_top_level_cert(self, get_pki_cache): - """Test storing top-level cert in cache.""" - cert_data = "cert data" - key_data = "key data" - cert_name = "server.cert" - key_name = "server.key" - client_name = "client_unit_0" - - # setup cert request - request = MagicMock() - request.unit_name = client_name - request.common_name = client_name - request._is_top_level_server_cert = True - request._server_cert_key = cert_name - request._server_key_key = key_name - - cluster_relation = MagicMock() - self.endpoint_from_name.return_value = cluster_relation - - # PKI structure - initial_pki = {} - expected_pki = { - vault_pki.TOP_LEVEL_CERT_KEY: { - cert_name: cert_data, - key_name: key_data - } - } - - get_pki_cache.return_value = initial_pki - - vault_pki.update_cert_cache(request, cert_data, key_data) - key = "{}_{}".format(vault_pki.PKI_CACHE_KEY, client_name) - cluster_relation.set_unit_pki.assert_called_once_with( - key, expected_pki) - - @patch.object(vault_pki, 'get_pki_cache') - def test_update_cert_cache_non_top_level_cert(self, get_pki_cache): - """Test storing non-top-level cert in cache.""" - cert_data = "cert data" - key_data = "key data" - client_name = "client_unit_0" - publish_key = client_name + ".processed_client_requests" - common_name = "client.0" - - cluster_relation = MagicMock() - self.endpoint_from_name.return_value = cluster_relation - - # setup cert request - request = MagicMock() - request.unit_name = client_name - request._is_top_level_server_cert = False - request._publish_key = publish_key - request.common_name = common_name - - # PKI structure - initial_pki = {} - expected_pki = { - publish_key: { - common_name: { - "cert": cert_data, - "key": key_data, - } - } - } - - get_pki_cache.return_value = initial_pki - - vault_pki.update_cert_cache(request, cert_data, key_data) - key = "{}_{}".format(vault_pki.PKI_CACHE_KEY, client_name) - cluster_relation.set_unit_pki.assert_called_once_with( - key, expected_pki) - - def test_remove_unit_from_cache(self): - """Test removing unit certificates from cache.""" - cluster_relation = MagicMock() - self.endpoint_from_name.return_value = cluster_relation - vault_pki.remove_unit_from_cache('client_0') - key = "{}_{}".format(vault_pki.PKI_CACHE_KEY, 'client_0') - cluster_relation.set_unit_pki.assert_called_once_with(key, None) - - @patch.object(vault_pki, 'update_cert_cache') - def test_populate_cert_cache(self, update_cert_cache): - # Define data for top level certificate and key - top_level_cert_name = "server.crt" - top_level_key_name = "server.key" - top_level_cert_data = "top level cert" - top_level_key_data = "top level key" - - # Define data for non-top level certificate - processed_request_cn = "juju_unit_service.crt" - processed_request_publish_key = "juju_unit_service.processed" - processed_cert_data = "processed cert" - processed_key_data = "processed key" - - # Mock request for top level certificate - top_level_request = MagicMock() - top_level_request._is_top_level_server_cert = True - top_level_request._server_cert_key = top_level_cert_name - top_level_request._server_key_key = top_level_key_name - top_level_request._unit.relation.to_publish_raw = { - top_level_cert_name: top_level_cert_data, - top_level_key_name: top_level_key_data, - } - - # Mock request for non-top level certificate - processed_request = MagicMock() - processed_request._is_top_level_server_cert = False - processed_request.common_name = processed_request_cn - processed_request._publish_key = processed_request_publish_key - processed_request._unit.relation.to_publish = { - processed_request_publish_key: {processed_request_cn: { - "cert": processed_cert_data, - "key": processed_key_data - }} - } - - tls_endpoint = MagicMock() - tls_endpoint.all_requests = [top_level_request, processed_request] - - vault_pki.populate_cert_cache(tls_endpoint) - - expected_update_calls = [ - call(top_level_request, top_level_cert_data, top_level_key_data), - call(processed_request, processed_cert_data, processed_key_data), - ] - update_cert_cache.assert_has_calls(expected_update_calls) diff --git a/unit_tests/test_reactive_vault_handlers.py b/unit_tests/test_reactive_vault_handlers.py index d07dc7b..63490f7 100644 --- a/unit_tests/test_reactive_vault_handlers.py +++ b/unit_tests/test_reactive_vault_handlers.py @@ -2,7 +2,6 @@ from unittest import mock from unittest.mock import patch, call import charms.reactive -import hvac # Mock out reactive decorators prior to importing reactive.vault dec_mock = mock.MagicMock() @@ -248,16 +247,6 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): mock.call('vault.ssl.configured')] handlers.upgrade_charm() self.remove_state.assert_has_calls(calls) - self.set_flag.assert_called_once_with( - 'needs-cert-cache-repopulation') - - @mock.patch.object(handlers, 'vault_pki') - def test_repopulate_cert_cache(self, mock_vault_pki): - handlers.repopulate_cert_cache() - mock_vault_pki.populate_cert_cache.assert_called_once_with( - self.endpoint_from_flag.return_value) - self.clear_flag.assert_called_once_with( - 'needs-cert-cache-repopulation') def test_request_db(self): psql = mock.MagicMock() @@ -990,18 +979,14 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): self, vault_pki, _client_approle_authorized): _client_approle_authorized.return_value = True tls = self.endpoint_from_flag.return_value - self.is_flag_set.return_value = False - cluster_relation = mock.MagicMock() - self.endpoint_from_name.return_value = cluster_relation - - cluster_relation.get_global_client_cert.return_value = { - 'certificate': 'crt', - 'private_key': 'key' - } + self.is_flag_set.side_effect = [True, False] + self.unitdata.kv().get.return_value = {'certificate': 'crt', + 'private_key': 'key'} handlers.publish_global_client_cert() assert not vault_pki.generate_certificate.called assert not self.set_flag.called - cluster_relation.get_global_client_cert.assert_called_with() + self.unitdata.kv().get.assert_called_with('charm.vault.' + 'global-client-cert') tls.set_client_cert.assert_called_with('crt', 'key') @mock.patch.object(handlers, 'client_approle_authorized') @@ -1014,16 +999,9 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): 'max-ttl': '3456h', } - cluster_relation = mock.MagicMock() - self.endpoint_from_name.return_value = cluster_relation - - cluster_relation.get_global_client_cert.return_value = { - 'certificate': 'stale_cert', - 'private_key': 'stale_key' - } - tls = self.endpoint_from_flag.return_value - self.is_flag_set.return_value = True + + self.is_flag_set.side_effect = [True, True] bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle @@ -1033,7 +1011,9 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): [], '3456h', '3456h') - cluster_relation.set_global_client_cert.assert_called_with(bundle) + self.unitdata.kv().set.assert_called_with('charm.vault.' + 'global-client-cert', + bundle) self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') @@ -1048,13 +1028,8 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): 'max-ttl': '3456h', } - cluster_relation = mock.MagicMock() - self.endpoint_from_name.return_value = cluster_relation - - cluster_relation.get_global_client_cert.return_value = {} - tls = self.endpoint_from_flag.return_value - self.is_flag_set.return_value = False + self.is_flag_set.side_effect = [False, False] bundle = {'certificate': 'crt', 'private_key': 'key'} vault_pki.generate_certificate.return_value = bundle @@ -1064,15 +1039,15 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): [], '3456h', '3456h') - cluster_relation.set_global_client_cert.assert_called_with( - bundle) + self.unitdata.kv().set.assert_called_with('charm.vault.' + 'global-client-cert', + bundle) self.set_flag.assert_called_with('charm.vault.' 'global-client-cert.created') tls.set_client_cert.assert_called_with('crt', 'key') @mock.patch.object(handlers, 'vault_pki') def test_create_certs(self, vault_pki): - vault_pki.find_cert_in_cache.return_value = (None, None) self.config.return_value = { 'default-ttl': '3456h', 'max-ttl': '3456h', @@ -1089,18 +1064,12 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): mock.Mock(cert_type='cert_type2', common_name='common_name2', sans='sans2')] - expected_cache_calls = [call(request) for request in tls.new_requests] vault_pki.generate_certificate.side_effect = [ {'certificate': 'crt1', 'private_key': 'key1'}, handlers.vault.VaultInvalidRequest, {'certificate': 'crt2', 'private_key': 'key2'}, ] - expected_cache_update_calls = [ - call(tls.new_requests[0], "crt1", "key1"), - call(tls.new_requests[2], "crt2", "key2"), - ] handlers.create_certs() - vault_pki.find_cert_in_cache.assert_has_calls(expected_cache_calls) vault_pki.generate_certificate.assert_has_calls([ mock.call('cert_type1', 'common_name1', 'sans1', '3456h', '3456h'), @@ -1116,229 +1085,6 @@ class TestHandlers(unit_tests.test_utils.CharmTestCase): tls.new_requests[2].set_cert.assert_has_calls([ mock.call('crt2', 'key2'), ]) - vault_pki.update_cert_cache.assert_has_calls( - expected_cache_update_calls - ) - - @mock.patch.object(handlers, 'vault_pki') - def test_create_certs_from_cache(self, vault_pki): - """Serve certificates from cache if they are available.""" - cert_cache = ( - ("common_name1_cert", "common_name1_key"), - ("common_name2_cert", "common_name2_key"), - ) - vault_pki.find_cert_in_cache.side_effect = cert_cache - tls = self.endpoint_from_flag.return_value - self.is_flag_set.return_value = False - tls.new_requests = [mock.Mock(cert_type='cert_type1', - common_name='common_name1', - sans='sans1'), - mock.Mock(cert_type='cert_type2', - common_name='common_name2', - sans='sans2'), - ] - - handlers.create_certs() - - vault_pki.generate_certificate.assert_not_called() - for index, request in enumerate(tls.new_requests): - request.set_cert.assert_called_once_with(*cert_cache[index]) - - @mock.patch.object(handlers, 'vault_pki') - def test_create_certs_reissue(self, vault_pki): - """Test that certificates are not served from cache on reissue. - - Even when certificates are available from cache, they should not - be reused if reissue was requested. - """ - self.config.return_value = { - 'default-ttl': '3456h', - 'max-ttl': '3456h', - } - cert_cache = ( - ("common_name1_cert", "common_name1_key"), - ("common_name2_cert", "common_name2_key"), - ) - new_certs = ( - {"certificate": "cn1_new_cert", "private_key": "cn1_new_key"}, - {"certificate": "cn2_new_cert", "private_key": "cn2_new_key"}, - ) - vault_pki.find_cert_in_cache.side_effect = cert_cache - vault_pki.generate_certificate.side_effect = new_certs - - tls = self.endpoint_from_flag.return_value - self.is_flag_set.return_value = True - tls.all_requests = [mock.Mock(cert_type='cert_type1', - common_name='common_name1', - sans='sans1'), - mock.Mock(cert_type='cert_type2', - common_name='common_name2', - sans='sans2'), - ] - expected_cache_update_calls = ( - call(tls.all_requests[0], - new_certs[0]["certificate"], - new_certs[0]["private_key"]), - call(tls.all_requests[1], - new_certs[1]["certificate"], - new_certs[1]["private_key"]), - ) - - handlers.create_certs() - - vault_pki.generate_certificate.assert_has_calls([ - mock.call('cert_type1', 'common_name1', 'sans1', - '3456h', '3456h'), - mock.call('cert_type2', 'common_name2', 'sans2', - '3456h', '3456h') - ]) - - for index, request in enumerate(tls.new_requests): - request.set_cert.assert_called_once_with( - new_certs[index]["certificate"], - new_certs[index]["private_key"], - ) - vault_pki.update_cert_cache.assert_has_calls( - expected_cache_update_calls - ) - - @mock.patch.object(handlers, 'vault_pki') - @mock.patch.object(handlers, 'remote_unit') - def test_cert_client_leaving(self, remote_unit, vault_pki): - """Test that certificates are removed from cache on unit departure.""" - # This should be performed only on leader unit - self.is_flag_set.return_value = True - unit_name = "client/0" - cache_unit_id = "client_0" - remote_unit.return_value = unit_name - - handlers.cert_client_leaving(mock.MagicMock()) - - vault_pki.remove_unit_from_cache.assert_called_once_with(cache_unit_id) - - # non-leaders should not perform this action - vault_pki.remove_unit_from_cache.reset_mock() - self.is_flag_set.return_value = False - - handlers.cert_client_leaving(mock.MagicMock()) - - vault_pki.remove_unit_from_cache.assert_not_called() - - @mock.patch.object(handlers, 'vault_pki') - def test_sync_cert_from_cache(self, vault_pki): - """Test that non-leaders copy data from cache to relations.""" - global_client_bundle = { - "certificate": "Global client cert", - "private_key": "Global client key", - } - cluster_relation = mock.MagicMock() - self.endpoint_from_name.return_value = cluster_relation - - cluster_relation.get_global_client_cert.return_value = ( - global_client_bundle - ) - - certs_in_cache = ( - ("cn1_cert", "cn1_key"), - ("cn2_cert", "cn2_key"), - ) - vault_pki.find_cert_in_cache.side_effect = certs_in_cache - - self.is_flag_set.return_value = False - tls = self.endpoint_from_flag.return_value - self.is_flag_set.return_value = True - tls.all_requests = [mock.Mock(cert_type='cert_type1', - common_name='common_name1', - sans='sans1'), - mock.Mock(cert_type='cert_type2', - common_name='common_name2', - sans='sans2'), - ] - - handlers.sync_cert_from_cache() - - tls.set_client_cert.assert_called_once_with( - global_client_bundle["certificate"], - global_client_bundle["private_key"], - ) - - for index, request in enumerate(tls.all_requests): - request.set_cert.assert_called_once_with( - certs_in_cache[index][0], - certs_in_cache[index][1], - ) - - @mock.patch.object(handlers, 'vault_pki') - def test_sync_cert_from_cache_no_ca(self, vault_pki): - """Test that non-leaders copy data from cache to relations.""" - vault_pki.get_ca.return_value = None - - handlers.sync_cert_from_cache() - - vault_pki.get_ca.assert_called_once_with() - tls = self.endpoint_from_flag.return_value - tls.set_ca.assert_not_called() - - @mock.patch.object(handlers, 'vault_pki') - def test_sync_cert_from_cache_no_chain_err(self, vault_pki): - """Test that non-leaders copy data from cache to relations.""" - vault_pki.get_chain.side_effect = hvac.exceptions.InternalServerError - - handlers.sync_cert_from_cache() - - vault_pki.get_ca.assert_called_once_with() - tls = self.endpoint_from_flag.return_value - tls.set_ca.assert_called_once_with(vault_pki.get_ca.return_value) - vault_pki.get_chain.assert_called_once_with() - tls.set_chain.assert_not_called() - - @mock.patch.object(handlers, 'vault_pki') - @mock.patch.object(handlers, 'leader_get') - def test_sync_cert_from_cache_err(self, leader_get, vault_pki): - """Test that it gracefully fails if get_chain doesn't succeed.""" - global_client_bundle = { - "certificate": "Global client cert", - "private_key": "Global client key", - } - - cluster_relation = mock.MagicMock() - self.endpoint_from_name.return_value = cluster_relation - - cluster_relation.get_global_client_cert.return_value = ( - global_client_bundle - ) - - certs_in_cache = ( - ("cn1_cert", "cn1_key"), - ("cn2_cert", "cn2_key"), - ) - vault_pki.find_cert_in_cache.side_effect = certs_in_cache - vault_pki.get_chain.side_effect = hvac.exceptions.InvalidPath - - self.is_flag_set.return_value = False - tls = self.endpoint_from_flag.return_value - self.is_flag_set.return_value = True - tls.set_chain.assert_not_called() - tls.all_requests = [mock.Mock(cert_type='cert_type1', - common_name='common_name1', - sans='sans1'), - mock.Mock(cert_type='cert_type2', - common_name='common_name2', - sans='sans2'), - ] - - handlers.sync_cert_from_cache() - - tls.set_client_cert.assert_called_once_with( - global_client_bundle["certificate"], - global_client_bundle["private_key"], - ) - - for index, request in enumerate(tls.all_requests): - request.set_cert.assert_called_once_with( - certs_in_cache[index][0], - certs_in_cache[index][1], - ) @mock.patch.object(handlers, 'vault_pki') def test_tune_pki_backend(self, vault_pki):