diff --git a/src/config.yaml b/src/config.yaml index 5176008..c585659 100644 --- a/src/config.yaml +++ b/src/config.yaml @@ -47,3 +47,8 @@ options: default: description: | Virtual IP to use api traffic + channel: + type: string + default: stable + description: >- + The snap channel to install from. \ No newline at end of file diff --git a/src/layer.yaml b/src/layer.yaml index c9aace3..daf6a37 100644 --- a/src/layer.yaml +++ b/src/layer.yaml @@ -11,9 +11,6 @@ options: packages: - python3-psycopg2 - libffi-dev - snap: - vault: - channel: stable resources: vault: type: file diff --git a/src/reactive/vault.py b/src/reactive/vault.py index e004a83..eb01e34 100644 --- a/src/reactive/vault.py +++ b/src/reactive/vault.py @@ -52,9 +52,13 @@ from charms.reactive.relations import ( ) from charms.reactive.flags import ( - is_flag_set + is_flag_set, + set_flag, + clear_flag, ) +from charms.layer import snap + # See https://www.vaultproject.io/docs/configuration/storage/postgresql.html VAULT_TABLE_DDL = """ @@ -121,6 +125,44 @@ def save_etcd_client_credentials(etcd, key, cert, ca): write_file(ca, credentials['client_ca'], perms=0o600) +def validate_snap_channel(channel): + """Validate a provided snap channel + + Any prefix is ignored ('0.10' in '0.10/stable' for example). + + :param: channel: string of the snap channel to validate + :returns: boolean: whether provided channel is valid + """ + channel_suffix = channel.split('/')[-1] + if channel_suffix not in ('stable', 'candidate', 'beta', 'edge'): + return False + return True + + +@when_not('snap.installed.vault') +def snap_install(): + channel = config('channel') or 'stable' + if validate_snap_channel(channel): + clear_flag('snap.channel.invalid') + snap.install('vault', channel=channel) + else: + set_flag('snap.channel.invalid') + + +@when('config.changed.channel') +@when('snap.installed.vault') +def snap_refresh(): + channel = config('channel') or 'stable' + if validate_snap_channel(channel): + clear_flag('snap.channel.invalid') + snap.refresh('vault', channel=channel) + if can_restart(): + log("Restarting vault", level=DEBUG) + service_restart('vault') + else: + set_flag('snap.channel.invalid') + + def configure_vault(context): context['disable_mlock'] = config()['disable-mlock'] context['ssl_available'] = is_state('vault.ssl.available') @@ -460,6 +502,12 @@ def cluster_connected(hacluster): def _assess_status(): """Assess status of relations and services for local unit""" + if is_flag_set('snap.channel.invalid'): + status_set('blocked', + 'Invalid snap channel ' + 'configured: {}'.format(config('channel'))) + return + health = None if service_running('vault'): health = get_vault_health() diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py index 367d6ee..ade4a53 100644 --- a/unit_tests/__init__.py +++ b/unit_tests/__init__.py @@ -1,6 +1,11 @@ # unit tests +import mock import sys sys.path.append('src') sys.path.append('src/lib') + +global snap +snap = mock.MagicMock() +sys.modules['charms.layer'] = snap diff --git a/unit_tests/test_vault.py b/unit_tests/test_vault.py index 705aa47..04105ba 100644 --- a/unit_tests/test_vault.py +++ b/unit_tests/test_vault.py @@ -65,6 +65,10 @@ class TestHandlers(unittest.TestCase): 'application_version_set', 'local_unit', 'network_get_primary_address', + 'snap', + 'is_flag_set', + 'set_flag', + 'clear_flag', ] self.patch_all() @@ -345,6 +349,7 @@ class TestHandlers(unittest.TestCase): @patch.object(handlers, 'get_vault_health') def test_assess_status(self, get_vault_health, _assess_interface_groups): + self.is_flag_set.return_value = False get_vault_health.return_value = self._health_response _assess_interface_groups.return_value = [] self.config.return_value = False @@ -366,10 +371,20 @@ class TestHandlers(unittest.TestCase): incomplete_interfaces=mock.ANY), ]) + def test_assess_status_invalid_channel(self): + self.is_flag_set.return_value = True + self.config.return_value = 'foorbar' + handlers._assess_status() + self.status_set.assert_called_with( + 'blocked', 'Invalid snap channel configured: foorbar') + self.is_flag_set.assert_called_with('snap.channel.invalid') + self.config.assert_called_with('channel') + @patch.object(handlers, '_assess_interface_groups') @patch.object(handlers, 'get_vault_health') def test_assess_status_not_running(self, get_vault_health, _assess_interface_groups): + self.is_flag_set.return_value = False get_vault_health.return_value = self._health_response self.service_running.return_value = False handlers._assess_status() @@ -381,6 +396,7 @@ class TestHandlers(unittest.TestCase): @patch.object(handlers, 'get_vault_health') def test_assess_status_vault_init(self, get_vault_health, _assess_interface_groups): + self.is_flag_set.return_value = False get_vault_health.return_value = self._health_response_needs_init _assess_interface_groups.return_value = [] self.service_running.return_value = True @@ -392,6 +408,7 @@ class TestHandlers(unittest.TestCase): @patch.object(handlers, 'get_vault_health') def test_assess_status_vault_sealed(self, get_vault_health, _assess_interface_groups): + self.is_flag_set.return_value = False get_vault_health.return_value = self._health_response_sealed _assess_interface_groups.return_value = [] self.service_running.return_value = True @@ -399,15 +416,14 @@ class TestHandlers(unittest.TestCase): self.status_set.assert_called_with( 'blocked', 'Unit is sealed') - @patch.object(handlers, 'is_flag_set') - def test_assess_interface_groups(self, is_flag_set): + def test_assess_interface_groups(self): flags = { 'db.master.available': True, 'db.connected': True, 'etcd.connected': True, 'baz.connected': True, } - is_flag_set.side_effect = lambda flag: flags.get(flag, False) + self.is_flag_set.side_effect = lambda flag: flags.get(flag, False) missing_interfaces = [] incomplete_interfaces = [] @@ -425,3 +441,60 @@ class TestHandlers(unittest.TestCase): self.assertEqual(incomplete_interfaces, ["'etcd' incomplete", "'baz' incomplete"]) + + def test_snap_install(self): + self.config.return_value = None + handlers.snap_install() + self.snap.install.assert_called_with('vault', channel='stable') + self.config.assert_called_with('channel') + self.clear_flag.assert_called_with('snap.channel.invalid') + + def test_snap_install_channel_set(self): + self.config.return_value = 'edge' + handlers.snap_install() + self.snap.install.assert_called_with('vault', channel='edge') + self.config.assert_called_with('channel') + self.clear_flag.assert_called_with('snap.channel.invalid') + + def test_snap_install_invalid_channel(self): + self.config.return_value = 'foorbar' + handlers.snap_install() + self.snap.install.assert_not_called() + self.config.assert_called_with('channel') + self.set_flag.assert_called_with('snap.channel.invalid') + + @patch.object(handlers, 'can_restart') + def test_snap_refresh_restartable(self, can_restart): + self.config.return_value = 'edge' + can_restart.return_value = True + handlers.snap_refresh() + self.snap.refresh.assert_called_with('vault', channel='edge') + self.config.assert_called_with('channel') + self.service_restart.assert_called_with('vault') + self.clear_flag.assert_called_with('snap.channel.invalid') + + @patch.object(handlers, 'can_restart') + def test_snap_refresh_not_restartable(self, can_restart): + self.config.return_value = 'edge' + can_restart.return_value = False + handlers.snap_refresh() + self.snap.refresh.assert_called_with('vault', channel='edge') + self.config.assert_called_with('channel') + self.service_restart.assert_not_called() + self.clear_flag.assert_called_with('snap.channel.invalid') + + def test_snap_refresh_invalid_channel(self): + self.config.return_value = 'foorbar' + handlers.snap_refresh() + self.snap.refresh.assert_not_called() + self.config.assert_called_with('channel') + self.set_flag.assert_called_with('snap.channel.invalid') + + def test_validate_snap_channel(self): + self.assertTrue(handlers.validate_snap_channel('stable')) + self.assertTrue(handlers.validate_snap_channel('0.10/stable')) + self.assertTrue(handlers.validate_snap_channel('edge')) + self.assertTrue(handlers.validate_snap_channel('beta')) + self.assertTrue(handlers.validate_snap_channel('candidate')) + self.assertFalse(handlers.validate_snap_channel('foobar')) + self.assertFalse(handlers.validate_snap_channel('0.10/foobar'))