diff --git a/doc/source/user/features.rst b/doc/source/user/features.rst index 2e2b9c0f99..74add84209 100644 --- a/doc/source/user/features.rst +++ b/doc/source/user/features.rst @@ -269,3 +269,11 @@ There are 2 types of string currently supported: After placeholders are replaced, the real URLs are stored in the ``data_source_urls`` field of the job execution object. This is used later to find objects created by a particular job run. + +Keypair replacement +------------------- + +A cluster allows users to create a new keypair to access to the running cluster +when the cluster's keypair is deleted. But the name of new keypair should be +same as the deleted one, and the new keypair will be available for cluster +scaling. diff --git a/releasenotes/notes/keypair-replacement-0c0cc3db0551c112.yaml b/releasenotes/notes/keypair-replacement-0c0cc3db0551c112.yaml new file mode 100644 index 0000000000..8bd0bb9f47 --- /dev/null +++ b/releasenotes/notes/keypair-replacement-0c0cc3db0551c112.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Use a new keypair to access to the running cluster when the cluster's + keypair is deleted. diff --git a/sahara/service/api/v2/clusters.py b/sahara/service/api/v2/clusters.py index 9f632256c5..6e6db0264e 100644 --- a/sahara/service/api/v2/clusters.py +++ b/sahara/service/api/v2/clusters.py @@ -162,6 +162,10 @@ def terminate_cluster(id, force=False): def update_cluster(id, values): + if "update_keypair" in values: + if values["update_keypair"]: + api.OPS.update_keypair(id) + values.pop("update_keypair") if verification_base.update_verification_required(values): api.OPS.handle_verification(id, values) return conductor.cluster_get(context.ctx(), id) diff --git a/sahara/service/ops.py b/sahara/service/ops.py index 6c1271cd55..9280fd8485 100644 --- a/sahara/service/ops.py +++ b/sahara/service/ops.py @@ -26,12 +26,14 @@ from sahara import context from sahara import exceptions from sahara.i18n import _ from sahara.plugins import base as plugin_base +from sahara.plugins import utils as u from sahara.service.edp import job_manager from sahara.service.edp.utils import shares from sahara.service.health import verification_base as ver_base from sahara.service import ntp_service from sahara.service import trusts from sahara.utils import cluster as c_u +from sahara.utils.openstack import nova from sahara.utils import remote from sahara.utils import rpc as rpc_utils @@ -96,6 +98,9 @@ class RemoteOps(rpc_utils.RPCClient): def provision_cluster(self, cluster_id): self.cast('provision_cluster', cluster_id=cluster_id) + def update_keypair(self, cluster_id): + self.cast('update_keypair', cluster_id=cluster_id) + def provision_scaled_cluster(self, cluster_id, node_group_id_map, node_group_instance_map=None): self.cast('provision_scaled_cluster', cluster_id=cluster_id, @@ -146,6 +151,10 @@ class OpsServer(rpc_utils.RPCServer): def provision_cluster(self, cluster_id): _provision_cluster(cluster_id) + @request_context + def update_keypair(self, cluster_id): + _update_keypair(cluster_id) + @request_context def provision_scaled_cluster(self, cluster_id, node_group_id_map, node_group_instance_map=None): @@ -454,3 +463,16 @@ def _refresh_health_for_cluster(cluster_id): def _handle_verification(cluster_id, values): ver_base.handle_verification(cluster_id, values) + + +def _update_keypair(cluster_id): + ctx = context.ctx() + cluster = conductor.cluster_get(ctx, cluster_id) + keypair_name = cluster.user_keypair_id + key = nova.get_keypair(keypair_name) + nodes = u.get_instances(cluster) + for node in nodes: + with node.remote() as r: + r.execute_command( + "echo {keypair} >> ~/.ssh/authorized_keys". + format(keypair=key.public_key)) diff --git a/sahara/service/validations/clusters_schema.py b/sahara/service/validations/clusters_schema.py index 9a83f00b9d..77e0419c54 100644 --- a/sahara/service/validations/clusters_schema.py +++ b/sahara/service/validations/clusters_schema.py @@ -71,6 +71,9 @@ CLUSTER_UPDATE_SCHEMA = { "description": { "type": ["string", "null"] }, + "update_keypair": { + "type": ["boolean", "null"] + }, "name": { "type": "string", "minLength": 1, diff --git a/sahara/tests/unit/service/api/v2/test_clusters.py b/sahara/tests/unit/service/api/v2/test_clusters.py index 7496d6ab21..efe0050200 100644 --- a/sahara/tests/unit/service/api/v2/test_clusters.py +++ b/sahara/tests/unit/service/api/v2/test_clusters.py @@ -22,6 +22,7 @@ from sahara import conductor as cond from sahara import context from sahara import exceptions as exc from sahara.plugins import base as pl_base +from sahara.plugins import utils as u from sahara.service import api as service_api from sahara.service.api.v2 import clusters as api from sahara.tests.unit import base @@ -74,6 +75,21 @@ class FakeOps(object): conductor.node_group_update(context.ctx(), ng, {'count': count}) conductor.cluster_update(context.ctx(), id, {'status': 'Scaled'}) + def update_keypair(self, id): + self.calls_order.append('ops.update_keypair') + cluster = conductor.cluster_get(context.ctx(), id) + keypair_name = cluster.user_keypair_id + nova_p = mock.patch("sahara.utils.openstack.nova.client") + nova = nova_p.start() + key = nova.get_keypair(keypair_name) + nodes = u.get_instances(cluster) + for instance in nodes: + remote = mock.Mock() + remote.execute_command( + "echo {keypair} >> ~/.ssh/authorized_keys".format( + keypair=key.public_key)) + remote.reset_mock() + def terminate_cluster(self, id, force): self.calls_order.append('ops.terminate_cluster') @@ -249,3 +265,8 @@ class TestClusterApi(base.SaharaWithDbTestCase): updated_cluster = api.update_cluster( cluster.id, {'description': 'Cluster'}) self.assertEqual('Cluster', updated_cluster.description) + + def test_cluster_keypair_update(self): + with mock.patch('sahara.service.quotas.check_cluster'): + cluster = api.create_cluster(api_base.SAMPLE_CLUSTER) + api.update_cluster(cluster.id, {'update_keypair': True}) diff --git a/sahara/utils/openstack/nova.py b/sahara/utils/openstack/nova.py index f6dab1e195..412a7699f1 100644 --- a/sahara/utils/openstack/nova.py +++ b/sahara/utils/openstack/nova.py @@ -56,3 +56,8 @@ def get_flavor(**kwargs): def get_instance_info(instance): return base.execute_with_retries( client().servers.get, instance.instance_id) + + +def get_keypair(keypair_name): + return base.execute_with_retries( + client().keypairs.get, keypair_name)