diff --git a/src/reactive/vault.py b/src/reactive/vault.py index ba7697f..a1c5867 100644 --- a/src/reactive/vault.py +++ b/src/reactive/vault.py @@ -96,6 +96,14 @@ def ssl_available(config): return True +def save_etcd_client_credentials(etcd, key, cert, ca): + """Save etcd TLS key, cert and ca to disk""" + credentials = etcd.get_client_credentials() + write_file(key, credentials['client_key'], perms=0o600) + write_file(cert, credentials['client_cert'], perms=0o600) + write_file(ca, credentials['client_ca'], perms=0o600) + + def configure_vault(context): context['disable_mlock'] = config()['disable-mlock'] context['ssl_available'] = is_state('vault.ssl.available') @@ -109,10 +117,10 @@ def configure_vault(context): context['etcd_tls_ca_file'] = '/var/snap/vault/common/etcd-ca.pem' context['etcd_tls_cert_file'] = '/var/snap/vault/common/etcd-cert.pem' context['etcd_tls_key_file'] = '/var/snap/vault/common/etcd.key' - etcd.save_client_credentials( - context['etcd_tls_key_file'], - context['etcd_tls_cert_file'], - context['etcd_tls_ca_file']) + save_etcd_client_credentials(etcd, + key=context['etcd_tls_key_file'], + cert=context['etcd_tls_cert_file'], + ca=context['etcd_tls_ca_file']) context['vault_api_url'] = get_api_url() log("Etcd detected, setting vault_api_url to {}".format( context['vault_api_url'])) diff --git a/unit_tests/test_vault.py b/unit_tests/test_vault.py index c18458c..e33fd7b 100644 --- a/unit_tests/test_vault.py +++ b/unit_tests/test_vault.py @@ -186,9 +186,30 @@ class TestHandlers(unittest.TestCase): handlers.database_not_ready() self.remove_state.assert_called_once_with('vault.schema.created') + @patch.object(handlers, 'write_file') + def test_save_etcd_client_credentials(self, write_file): + etcd_mock = mock.MagicMock() + etcd_mock.get_client_credentials.return_value = { + 'client_cert': 'test-cert', + 'client_key': 'test-key', + 'client_ca': 'test-ca', + } + handlers.save_etcd_client_credentials(etcd_mock, + key='key', + cert='cert', + ca='ca') + etcd_mock.get_client_credentials.assert_called_once_with() + write_file.assert_has_calls([ + mock.call('key', 'test-key', perms=0o600), + mock.call('cert', 'test-cert', perms=0o600), + mock.call('ca', 'test-ca', perms=0o600), + ]) + + @patch.object(handlers, 'save_etcd_client_credentials') @patch.object(handlers, 'can_restart') @patch.object(handlers, 'get_api_url') - def test_configure_vault_etcd(self, get_api_url, can_restart): + def test_configure_vault_etcd(self, get_api_url, can_restart, + save_etcd_client_credentials): can_restart.return_value = True get_api_url.return_value = 'http://this-unit' self.config.return_value = {'disable-mlock': False} @@ -218,6 +239,12 @@ class TestHandlers(unittest.TestCase): perms=0o644) ] self.render.assert_has_calls(render_calls) + save_etcd_client_credentials.assert_called_with( + etcd_mock, + key=expected_context['etcd_tls_key_file'], + cert=expected_context['etcd_tls_cert_file'], + ca=expected_context['etcd_tls_ca_file'], + ) @patch.object(handlers.hvac, 'Client') @patch.object(handlers, 'get_api_url')