diff --git a/heat/engine/worker.py b/heat/engine/worker.py index b1e821204..1a544a794 100644 --- a/heat/engine/worker.py +++ b/heat/engine/worker.py @@ -28,6 +28,7 @@ from heat.common.i18n import _LW from heat.common import messaging as rpc_messaging from heat.db import api as db_api from heat.engine import check_resource +from heat.engine import stack as parser from heat.engine import sync_point from heat.objects import stack as stack_objects from heat.rpc import api as rpc_api @@ -104,24 +105,17 @@ class WorkerService(service.Service): in_progress resources to complete normally; no worker is stopped abruptly. """ - old_trvsl = stack.current_traversal - updated = _update_current_traversal(stack) - if not updated: - LOG.warning(_LW("Failed to update stack %(name)s with new " - "traversal, aborting stack cancel"), - {'name': stack.name}) - return + _stop_traversal(stack) - reason = 'User cancelled stack %s ' % stack.action - updated = stack.state_set(stack.action, stack.FAILED, reason) - if not updated: - LOG.warning(_LW("Failed to update stack %(name)s status" - " to %(action)_%(state)"), - {'name': stack.name, 'action': stack.action, - 'state': stack.FAILED}) - return + db_child_stacks = stack_objects.Stack.get_all_by_root_owner_id( + stack.context, stack.id) - sync_point.delete_all(stack.context, stack.id, old_trvsl) + for db_child in db_child_stacks: + if db_child.status == parser.Stack.IN_PROGRESS: + child = parser.Stack.load(stack.context, + stack_id=db_child.id, + stack=db_child) + _stop_traversal(child) def stop_all_workers(self, stack): # stop the traversal @@ -184,6 +178,27 @@ class WorkerService(service.Service): _cancel_check_resource(stack_id, self.engine_id, self.thread_group_mgr) +def _stop_traversal(stack): + old_trvsl = stack.current_traversal + updated = _update_current_traversal(stack) + if not updated: + LOG.warning(_LW("Failed to update stack %(name)s with new " + "traversal, aborting stack cancel"), + {'name': stack.name}) + return + + reason = 'Stack %(action)s cancelled' % {'action': stack.action} + updated = stack.state_set(stack.action, stack.FAILED, reason) + if not updated: + LOG.warning(_LW("Failed to update stack %(name)s status" + " to %(action)_%(state)"), + {'name': stack.name, 'action': stack.action, + 'state': stack.FAILED}) + return + + sync_point.delete_all(stack.context, stack.id, old_trvsl) + + def _update_current_traversal(stack): previous_traversal = stack.current_traversal stack.current_traversal = uuidutils.generate_uuid() diff --git a/heat/tests/engine/test_engine_worker.py b/heat/tests/engine/test_engine_worker.py index f02c18ab4..175691694 100644 --- a/heat/tests/engine/test_engine_worker.py +++ b/heat/tests/engine/test_engine_worker.py @@ -17,6 +17,8 @@ import mock from heat.db import api as db_api from heat.engine import check_resource +from heat.engine import stack as parser +from heat.engine import template as templatem from heat.engine import worker from heat.objects import stack as stack_objects from heat.rpc import worker_client as wc @@ -178,6 +180,26 @@ class WorkerServiceTest(common.HeatTestCase): mock_ccr.assert_has_calls(calls, any_order=True) self.assertTrue(mock_wc.called) + @mock.patch.object(worker, '_stop_traversal') + def test_stop_traversal_stops_nested_stack(self, mock_st): + mock_tgm = mock.Mock() + ctx = utils.dummy_context() + tmpl = templatem.Template.create_empty_template() + stack1 = parser.Stack(ctx, 'stack1', tmpl, + current_traversal='123') + stack1.store() + stack2 = parser.Stack(ctx, 'stack2', tmpl, + owner_id=stack1.id, current_traversal='456') + stack2.store() + _worker = worker.WorkerService('host-1', 'topic-1', 'engine-001', + mock_tgm) + _worker.stop_traversal(stack1) + self.assertEqual(2, mock_st.call_count) + call1, call2 = mock_st.call_args_list + call_args1, call_args2 = call1[0][0], call2[0][0] + self.assertEqual('stack1', call_args1.name) + self.assertEqual('stack2', call_args2.name) + @mock.patch.object(worker, '_cancel_workers') @mock.patch.object(worker.WorkerService, 'stop_traversal') def test_stop_all_workers_when_stack_in_progress(self, mock_st, mock_cw):