diff --git a/sahara/service/ops.py b/sahara/service/ops.py index 9280fd8485..1118ec22d8 100644 --- a/sahara/service/ops.py +++ b/sahara/service/ops.py @@ -327,10 +327,27 @@ def _provision_cluster(cluster_id): _refresh_health_for_cluster(cluster_id) +def _specific_inst_to_delete(node_group, node_group_instance_map=None): + if node_group_instance_map: + if node_group.id in node_group_instance_map: + return True + return False + + @ops_error_handler( _("Scaling cluster failed for the following reason(s): {reason}")) def _provision_scaled_cluster(cluster_id, node_group_id_map, node_group_instance_map=None): + """Provision scaled cluster. + + :param cluster_id: Id of cluster to be scaled. + + :param node_group_id_map: Dictionary in the format + node_group_id: number of instances. + + :param node_group_instance_map: Specifies the instances to be removed in + each node group. + """ ctx, cluster, plugin = _prepare_provisioning(cluster_id) # Decommissioning surplus nodes with the plugin @@ -340,19 +357,25 @@ def _provision_scaled_cluster(cluster_id, node_group_id_map, try: instances_to_delete = [] for node_group in cluster.node_groups: + ng_inst_to_delete_count = 0 + # new_count is the new number of instance on the current node group new_count = node_group_id_map[node_group.id] if new_count < node_group.count: - if (node_group_instance_map and - node_group.id in node_group_instance_map): - for instance_ref in node_group_instance_map[ - node_group.id]: - instance = _get_instance_obj(node_group.instances, - instance_ref) - instances_to_delete.append(instance) + # Adding selected instances to delete to the list + if _specific_inst_to_delete(node_group, + node_group_instance_map): + for instance_ref in node_group_instance_map[node_group.id]: + instances_to_delete.append(_get_instance_obj( + node_group.instances, instance_ref)) + ng_inst_to_delete_count += 1 - while node_group.count - new_count > len(instances_to_delete): + # Adding random instances to the list when the number of + # specific instances does not equals the difference between the + # current count and the new count of instances. + while node_group.count - new_count > ng_inst_to_delete_count: instances_to_delete.append(_get_random_instance_from_ng( node_group.instances, instances_to_delete)) + ng_inst_to_delete_count += 1 if instances_to_delete: context.set_step_type(_("Plugin: decommission cluster")) diff --git a/sahara/tests/unit/service/api/v2/base.py b/sahara/tests/unit/service/api/v2/base.py index 642c63c7f3..b53cff5504 100644 --- a/sahara/tests/unit/service/api/v2/base.py +++ b/sahara/tests/unit/service/api/v2/base.py @@ -85,7 +85,7 @@ SCALE_DATA_SPECIFIC_INSTANCE = { }, { 'name': 'ng_2', - 'count': 2, + 'count': 1, 'instances': ['ng_2_0'] } ], diff --git a/sahara/tests/unit/service/api/v2/test_clusters.py b/sahara/tests/unit/service/api/v2/test_clusters.py index efe0050200..2e6d843877 100644 --- a/sahara/tests/unit/service/api/v2/test_clusters.py +++ b/sahara/tests/unit/service/api/v2/test_clusters.py @@ -244,6 +244,37 @@ class TestClusterApi(base.SaharaWithDbTestCase): 'ops.provision_scaled_cluster', 'ops.terminate_cluster'], self.calls_order) + @mock.patch('sahara.service.quotas.check_cluster', return_value=None) + @mock.patch('sahara.service.quotas.check_scaling', return_value=None) + def test_scale_cluster_specific_and_non_specific(self, check_scaling, + check_cluster): + cluster = api.create_cluster(api_base.SAMPLE_CLUSTER) + cluster = api.get_cluster(cluster.id) + api.scale_cluster(cluster.id, api_base.SCALE_DATA_SPECIFIC_INSTANCE) + result_cluster = api.get_cluster(cluster.id) + self.assertEqual('Scaled', result_cluster.status) + expected_count = { + 'ng_1': 3, + 'ng_2': 1, + 'ng_3': 1, + } + ng_count = 0 + for ng in result_cluster.node_groups: + self.assertEqual(expected_count[ng.name], ng.count) + ng_count += 1 + self.assertEqual(1, result_cluster.node_groups[1].count) + self.assertNotIn('ng_2_0', + self._get_instances_ids( + result_cluster.node_groups[1])) + self.assertEqual(3, ng_count) + api.terminate_cluster(result_cluster.id) + self.assertEqual( + ['get_open_ports', 'recommend_configs', 'validate', + 'ops.provision_cluster', 'get_open_ports', + 'recommend_configs', 'validate_scaling', + 'ops.provision_scaled_cluster', + 'ops.terminate_cluster'], self.calls_order) + def _get_instances_ids(self, node_group): instance_ids = [] for instance in node_group.instances: