diff --git a/heat/engine/scheduler.py b/heat/engine/scheduler.py index ab767ca304..36a808ae30 100644 --- a/heat/engine/scheduler.py +++ b/heat/engine/scheduler.py @@ -236,6 +236,35 @@ class TaskRunner(object): while not self.step(): self._sleep(wait_time) + def as_task(self, timeout=None): + """Return a task that drives the TaskRunner.""" + resuming = self.started() + if not resuming: + self.start(timeout=timeout) + else: + if timeout is not None: + new_timeout = Timeout(self, timeout) + if self._timeout is None or new_timeout < self._timeout: + self._timeout = new_timeout + + done = self.step() if resuming else self.done() + while not done: + try: + yield + except GeneratorExit: + self.cancel() + raise + except: # noqa + self._done = True + try: + self._runner.throw(*sys.exc_info()) + except StopIteration: + return + else: + self._done = False + else: + done = self.step() + def cancel(self, grace_period=None): """Cancel the task and mark it as done.""" if self.done(): diff --git a/heat/tests/engine/test_scheduler.py b/heat/tests/engine/test_scheduler.py index 3dc65e1cc6..333f282fab 100644 --- a/heat/tests/engine/test_scheduler.py +++ b/heat/tests/engine/test_scheduler.py @@ -395,6 +395,94 @@ class TaskTest(common.HeatTestCase): scheduler.TaskRunner(task)() + def test_run_as_task(self): + task = DummyTask() + self.m.StubOutWithMock(task, 'do_step') + self.m.StubOutWithMock(scheduler.TaskRunner, '_sleep') + + task.do_step(1).AndReturn(None) + task.do_step(2).AndReturn(None) + task.do_step(3).AndReturn(None) + + self.m.ReplayAll() + + tr = scheduler.TaskRunner(task) + rt = tr.as_task() + for step in rt: + pass + self.assertTrue(tr.done()) + + def test_run_as_task_started(self): + task = DummyTask() + self.m.StubOutWithMock(task, 'do_step') + self.m.StubOutWithMock(scheduler.TaskRunner, '_sleep') + + task.do_step(1).AndReturn(None) + task.do_step(2).AndReturn(None) + task.do_step(3).AndReturn(None) + + self.m.ReplayAll() + + tr = scheduler.TaskRunner(task) + tr.start() + for step in tr.as_task(): + pass + self.assertTrue(tr.done()) + + def test_run_as_task_cancel(self): + task = DummyTask() + self.m.StubOutWithMock(task, 'do_step') + self.m.StubOutWithMock(scheduler.TaskRunner, '_sleep') + + task.do_step(1).AndReturn(None) + + self.m.ReplayAll() + + tr = scheduler.TaskRunner(task) + rt = tr.as_task() + next(rt) + rt.close() + + self.assertTrue(tr.done()) + + def test_run_as_task_exception(self): + class TestException(Exception): + pass + + task = DummyTask() + self.m.StubOutWithMock(task, 'do_step') + self.m.StubOutWithMock(scheduler.TaskRunner, '_sleep') + + task.do_step(1).AndReturn(None) + + self.m.ReplayAll() + + tr = scheduler.TaskRunner(task) + rt = tr.as_task() + next(rt) + self.assertRaises(TestException, rt.throw, TestException) + + self.assertTrue(tr.done()) + + def test_run_as_task_swallow_exception(self): + class TestException(Exception): + pass + + def task(): + try: + yield + except TestException: + yield + + tr = scheduler.TaskRunner(task) + rt = tr.as_task() + next(rt) + rt.throw(TestException) + + self.assertFalse(tr.done()) + self.assertRaises(StopIteration, next, rt) + self.assertTrue(tr.done()) + def test_run_delays(self): task = DummyTask(delays=itertools.repeat(2)) self.m.StubOutWithMock(task, 'do_step') @@ -688,6 +776,73 @@ class TaskTest(common.HeatTestCase): self.assertFalse(runner) self.assertTrue(runner.step()) + def test_as_task_timeout(self): + st = timeutils.wallclock() + + def task(): + while True: + yield + + self.m.StubOutWithMock(timeutils, 'wallclock') + timeutils.wallclock().AndReturn(st) + timeutils.wallclock().AndReturn(st + 0.5) + timeutils.wallclock().AndReturn(st + 1.5) + + self.m.ReplayAll() + + runner = scheduler.TaskRunner(task) + + rt = runner.as_task(timeout=1) + next(rt) + self.assertTrue(runner) + self.assertRaises(scheduler.Timeout, next, rt) + + def test_as_task_timeout_shorter(self): + st = timeutils.wallclock() + + def task(): + while True: + yield + + self.m.StubOutWithMock(timeutils, 'wallclock') + timeutils.wallclock().AndReturn(st) + timeutils.wallclock().AndReturn(st + 0.5) + timeutils.wallclock().AndReturn(st + 0.7) + timeutils.wallclock().AndReturn(st + 1.6) + timeutils.wallclock().AndReturn(st + 2.6) + + self.m.ReplayAll() + + runner = scheduler.TaskRunner(task) + runner.start(timeout=10) + self.assertTrue(runner) + + rt = runner.as_task(timeout=1) + next(rt) + self.assertRaises(scheduler.Timeout, next, rt) + + def test_as_task_timeout_longer(self): + st = timeutils.wallclock() + + def task(): + while True: + yield + + self.m.StubOutWithMock(timeutils, 'wallclock') + timeutils.wallclock().AndReturn(st) + timeutils.wallclock().AndReturn(st + 0.5) + timeutils.wallclock().AndReturn(st + 0.6) + timeutils.wallclock().AndReturn(st + 1.5) + + self.m.ReplayAll() + + runner = scheduler.TaskRunner(task) + runner.start(timeout=1) + self.assertTrue(runner) + + rt = runner.as_task(timeout=10) + self.assertRaises(scheduler.Timeout, next, rt) + def test_cancel_not_started(self): task = DummyTask(1)