diff --git a/heat/engine/scheduler.py b/heat/engine/scheduler.py index 3b896811fa..927a81fe84 100644 --- a/heat/engine/scheduler.py +++ b/heat/engine/scheduler.py @@ -363,9 +363,12 @@ class DependencyTaskGroup(object): dependency tree is passed as an argument. If an error_wait_time is specified, tasks that are already running at - the time of an error will continue to run for up to the specified - time before being cancelled. Once all remaining tasks are complete or - have been cancelled, the original exception is raised. + the time of an error will continue to run for up to the specified time + before being cancelled. Once all remaining tasks are complete or have + been cancelled, the original exception is raised. If error_wait_time is + a callable function it will be called for each task, passing the + dependency key as an argument, to determine the error_wait_time for + that particular task. If aggregate_exceptions is True, then execution of parallel operations will not be cancelled in the event of an error (operations downstream @@ -426,9 +429,16 @@ class DependencyTaskGroup(object): del raised_exceptions def cancel_all(self, grace_period=None): - for r in six.itervalues(self._runners): + if callable(grace_period): + get_grace_period = grace_period + else: + def get_grace_period(key): + return grace_period + + for k, r in six.iteritems(self._runners): + gp = get_grace_period(k) try: - r.cancel(grace_period=grace_period) + r.cancel(grace_period=gp) except Exception as ex: LOG.debug('Exception cancelling task: %s' % six.text_type(ex)) diff --git a/heat/tests/engine/test_scheduler.py b/heat/tests/engine/test_scheduler.py index 4196e2c69b..3dc65e1cc6 100644 --- a/heat/tests/engine/test_scheduler.py +++ b/heat/tests/engine/test_scheduler.py @@ -349,6 +349,28 @@ class DependencyTaskGroupTest(common.HeatTestCase): exc = self.assertRaises(type(e1), run_tasks_with_exceptions) self.assertEqual(e1, exc) + def test_exception_grace_period_per_task(self): + e1 = Exception('e1') + + def get_wait_time(key): + if key == 'B': + return 5 + else: + return None + + def run_tasks_with_exceptions(): + self.error_wait_time = get_wait_time + tasks = (('A', None), ('B', None), ('C', 'A')) + with self._dep_test(*tasks) as dummy: + dummy.do_step(1, 'A').InAnyOrder('1') + dummy.do_step(1, 'B').InAnyOrder('1') + dummy.do_step(2, 'A').InAnyOrder('2').AndRaise(e1) + dummy.do_step(2, 'B').InAnyOrder('2') + dummy.do_step(3, 'B') + + exc = self.assertRaises(type(e1), run_tasks_with_exceptions) + self.assertEqual(e1, exc) + class TaskTest(common.HeatTestCase):