Merge "Fix rerun of task in subworkflow"
This commit is contained in:
commit
89dc5d9858
|
@ -157,7 +157,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
# policy worked.
|
||||
|
||||
wf_ex_id = task_ex.workflow_execution_id
|
||||
wf_ex = self._lock_workflow_execution(wf_ex_id)
|
||||
wf_ex = wf_handler.lock_workflow_execution(wf_ex_id)
|
||||
|
||||
wf_trace.info(
|
||||
task_ex,
|
||||
|
@ -239,7 +239,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
).get_clone()
|
||||
|
||||
wf_ex_id = action_ex.task_execution.workflow_execution_id
|
||||
wf_ex = self._lock_workflow_execution(wf_ex_id)
|
||||
wf_ex = wf_handler.lock_workflow_execution(wf_ex_id)
|
||||
|
||||
task_handler.on_action_complete(action_ex, result)
|
||||
|
||||
|
@ -255,7 +255,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
with db_api.transaction():
|
||||
action_ex = db_api.get_action_execution(action_ex_id)
|
||||
task_ex = action_ex.task_execution
|
||||
wf_ex = self._lock_workflow_execution(
|
||||
wf_ex = wf_handler.lock_workflow_execution(
|
||||
task_ex.workflow_execution_id
|
||||
)
|
||||
self._on_task_state_change(task_ex, wf_ex)
|
||||
|
@ -278,7 +278,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
@u.log_exec(LOG)
|
||||
def pause_workflow(self, execution_id):
|
||||
with db_api.transaction():
|
||||
wf_ex = self._lock_workflow_execution(execution_id)
|
||||
wf_ex = wf_handler.lock_workflow_execution(execution_id)
|
||||
|
||||
wf_handler.set_execution_state(wf_ex, states.PAUSED)
|
||||
|
||||
|
@ -287,7 +287,11 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
def _continue_workflow(self, wf_ex, task_ex=None, reset=True, env=None):
|
||||
wf_ex = wf_service.update_workflow_execution_env(wf_ex, env)
|
||||
|
||||
wf_handler.set_execution_state(wf_ex, states.RUNNING)
|
||||
wf_handler.set_execution_state(
|
||||
wf_ex,
|
||||
states.RUNNING,
|
||||
set_upstream=True
|
||||
)
|
||||
|
||||
wf_ctrl = wf_base.WorkflowController.get_controller(wf_ex)
|
||||
|
||||
|
@ -327,7 +331,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
def rerun_workflow(self, wf_ex_id, task_ex_id, reset=True, env=None):
|
||||
try:
|
||||
with db_api.transaction():
|
||||
wf_ex = self._lock_workflow_execution(wf_ex_id)
|
||||
wf_ex = wf_handler.lock_workflow_execution(wf_ex_id)
|
||||
|
||||
task_ex = db_api.get_task_execution(task_ex_id)
|
||||
|
||||
|
@ -350,7 +354,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
def resume_workflow(self, wf_ex_id, env=None):
|
||||
try:
|
||||
with db_api.transaction():
|
||||
wf_ex = self._lock_workflow_execution(wf_ex_id)
|
||||
wf_ex = wf_handler.lock_workflow_execution(wf_ex_id)
|
||||
|
||||
if wf_ex.state != states.PAUSED:
|
||||
return wf_ex.get_clone()
|
||||
|
@ -367,7 +371,7 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
@u.log_exec(LOG)
|
||||
def stop_workflow(self, execution_id, state, message=None):
|
||||
with db_api.transaction():
|
||||
wf_ex = self._lock_workflow_execution(execution_id)
|
||||
wf_ex = wf_handler.lock_workflow_execution(execution_id)
|
||||
|
||||
return self._stop_workflow(wf_ex, state, message)
|
||||
|
||||
|
@ -498,10 +502,3 @@ class DefaultEngine(base.Engine, coordination.Service):
|
|||
data_flow.add_workflow_variables_to_context(wf_ex, wf_spec)
|
||||
|
||||
return wf_ex
|
||||
|
||||
@staticmethod
def _lock_workflow_execution(wf_exec_id):
    """Lock a workflow execution and return it.

    The lock is taken through db_api.acquire_lock. As noted by the
    original comment, this expires all session objects, so the returned
    workflow execution is the up-to-date one from the DB.

    :param wf_exec_id: Id of the workflow execution to lock.
    :return: Whatever db_api.acquire_lock returns for the
        WorkflowExecution model and the given id.
    """
    locked_wf_ex = db_api.acquire_lock(
        db_models.WorkflowExecution,
        wf_exec_id
    )

    return locked_wf_ex
|
||||
|
|
|
@ -79,9 +79,7 @@ def run_existing_task(task_ex_id, reset=True):
|
|||
action_ex.accepted = False
|
||||
|
||||
# Explicitly change task state to RUNNING.
|
||||
task_ex.state = states.RUNNING
|
||||
task_ex.state_info = None
|
||||
task_ex.processed = False
|
||||
set_task_state(task_ex, states.RUNNING, None, processed=False)
|
||||
|
||||
_run_existing_task(task_ex, task_spec, wf_spec)
|
||||
|
||||
|
@ -128,7 +126,7 @@ def run_new_task(wf_cmd):
|
|||
)
|
||||
|
||||
if task_ex:
|
||||
_set_task_state(task_ex, states.RUNNING)
|
||||
set_task_state(task_ex, states.RUNNING, None)
|
||||
task_ex.in_context = ctx
|
||||
else:
|
||||
task_ex = _create_task_execution(wf_ex, task_spec, ctx)
|
||||
|
@ -502,7 +500,7 @@ def _complete_task(task_ex, task_spec, state, state_info=None):
|
|||
if states.is_completed(task_ex.state):
|
||||
return []
|
||||
|
||||
_set_task_state(task_ex, state, state_info=state_info)
|
||||
set_task_state(task_ex, state, state_info)
|
||||
|
||||
try:
|
||||
data_flow.publish_variables(
|
||||
|
@ -510,13 +508,13 @@ def _complete_task(task_ex, task_spec, state, state_info=None):
|
|||
task_spec
|
||||
)
|
||||
except Exception as e:
|
||||
_set_task_state(task_ex, states.ERROR, state_info=str(e))
|
||||
set_task_state(task_ex, states.ERROR, str(e))
|
||||
|
||||
if not task_spec.get_keep_result():
|
||||
data_flow.destroy_task_result(task_ex)
|
||||
|
||||
|
||||
def _set_task_state(task_ex, state, state_info=None):
|
||||
def set_task_state(task_ex, state, state_info, processed=None):
|
||||
# TODO(rakhmerov): How do we log task result?
|
||||
wf_trace.info(
|
||||
task_ex.workflow_execution,
|
||||
|
@ -525,9 +523,10 @@ def _set_task_state(task_ex, state, state_info=None):
|
|||
)
|
||||
|
||||
task_ex.state = state
|
||||
task_ex.state_info = state_info
|
||||
|
||||
if state_info:
|
||||
task_ex.state_info = state_info
|
||||
if processed is not None:
|
||||
task_ex.processed = processed
|
||||
|
||||
|
||||
def is_task_completed(task_ex, task_spec):
|
||||
|
|
|
@ -13,7 +13,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral.db.v2.sqlalchemy import models as db_models
|
||||
from mistral.engine import rpc
|
||||
from mistral.engine import task_handler
|
||||
from mistral import exceptions as exc
|
||||
from mistral.services import scheduler
|
||||
from mistral.utils import wf_trace
|
||||
|
@ -85,7 +87,7 @@ def send_result_to_parent_workflow(wf_ex_id):
|
|||
)
|
||||
|
||||
|
||||
def set_execution_state(wf_ex, state, state_info=None):
|
||||
def set_execution_state(wf_ex, state, state_info=None, set_upstream=False):
|
||||
cur_state = wf_ex.state
|
||||
|
||||
if states.is_valid_transition(cur_state, state):
|
||||
|
@ -106,3 +108,32 @@ def set_execution_state(wf_ex, state, state_info=None):
|
|||
# Workflow result should be accepted by parent workflows (if any)
|
||||
# only if it completed successfully.
|
||||
wf_ex.accepted = wf_ex.state == states.SUCCESS
|
||||
|
||||
# If specified, then recursively set the state of the parent workflow
|
||||
# executions to the same state. Only changing state to RUNNING is
|
||||
# supported.
|
||||
if set_upstream and state == states.RUNNING and wf_ex.task_execution_id:
|
||||
task_ex = db_api.get_task_execution(wf_ex.task_execution_id)
|
||||
|
||||
parent_wf_ex = lock_workflow_execution(task_ex.workflow_execution_id)
|
||||
|
||||
set_execution_state(
|
||||
parent_wf_ex,
|
||||
state,
|
||||
state_info=state_info,
|
||||
set_upstream=set_upstream
|
||||
)
|
||||
|
||||
task_handler.set_task_state(
|
||||
task_ex,
|
||||
state,
|
||||
state_info=None,
|
||||
processed=False
|
||||
)
|
||||
|
||||
|
||||
def lock_workflow_execution(wf_ex_id):
    """Lock a workflow execution and return it.

    Delegates to db_api.acquire_lock for the WorkflowExecution model.
    According to the original note, acquiring the lock expires all
    session objects, so the object returned here is the up-to-date
    workflow execution from the DB.

    :param wf_ex_id: Id of the workflow execution to lock.
    :return: The result of db_api.acquire_lock for that execution.
    """
    wf_ex_model = db_models.WorkflowExecution

    return db_api.acquire_lock(wf_ex_model, wf_ex_id)
|
||||
|
|
|
@ -133,7 +133,6 @@ workflows:
|
|||
action: std.echo output="Task 2"
|
||||
"""
|
||||
|
||||
|
||||
JOIN_WORKBOOK = """
|
||||
---
|
||||
version: '2.0'
|
||||
|
@ -155,6 +154,32 @@ workflows:
|
|||
join: all
|
||||
"""
|
||||
|
||||
# Workbook used by the subworkflow rerun tests: wf1 runs t1 -> t2 -> t3,
# where t2 starts wf2 as a subworkflow and wf2 has the single task wf2_t1.
# The YAML nesting is significant: tasks must sit under their workflow's
# "tasks" key, otherwise wf1/wf2 and their tasks collapse into duplicate
# top-level keys.
SUBFLOW_WORKBOOK = """
version: '2.0'
name: wb1
workflows:
  wf1:
    type: direct
    tasks:
      t1:
        action: std.echo output="Task 1"
        on-success:
          - t2
      t2:
        workflow: wf2
        on-success:
          - t3
      t3:
        action: std.echo output="Task 3"
  wf2:
    type: direct
    output:
      result: <% $.wf2_t1 %>
    tasks:
      wf2_t1:
        action: std.echo output="Task 2"
"""
|
||||
|
||||
|
||||
class DirectWorkflowRerunTest(base.EngineTestCase):
|
||||
|
||||
|
@ -1018,3 +1043,208 @@ class DirectWorkflowRerunTest(base.EngineTestCase):
|
|||
)
|
||||
|
||||
self.assertEqual(1, len(task_2_action_exs))
|
||||
|
||||
@mock.patch.object(
    std_actions.EchoAction,
    'run',
    mock.MagicMock(
        side_effect=[
            'Task 1',  # Mock task1 success for initial run.
            exc.ActionException(),  # Mock task2 exception for initial run.
            'Task 2',  # Mock task2 success for rerun.
            'Task 3'  # Mock task3 success.
        ]
    )
)
def test_rerun_subflow(self):
    # Scenario: wb1.wf1 fails on t2 (the task that runs wf2), then t2 is
    # rerun through the parent workflow and the whole workflow is expected
    # to finish successfully.
    wb_service.create_workbook_v2(SUBFLOW_WORKBOOK)

    # First pass: start the workflow and wait until it ends up in ERROR.
    wf_ex_id = self.engine.start_workflow('wb1.wf1', {}).id
    self._await(lambda: self.is_execution_error(wf_ex_id))
    wf_ex = db_api.get_workflow_execution(wf_ex_id)

    self.assertEqual(states.ERROR, wf_ex.state)
    self.assertIsNotNone(wf_ex.state_info)

    task_exs = wf_ex.task_executions

    self.assertEqual(2, len(task_exs))

    t1_ex = self._assert_single_item(task_exs, name='t1')
    t2_ex = self._assert_single_item(task_exs, name='t2')

    self.assertEqual(states.SUCCESS, t1_ex.state)
    self.assertEqual(states.ERROR, t2_ex.state)
    self.assertIsNotNone(t2_ex.state_info)

    # Second pass: rerun the failed task; the workflow must be RUNNING
    # again with its error info cleared.
    self.engine.rerun_workflow(wf_ex_id, t2_ex.id)
    wf_ex = db_api.get_workflow_execution(wf_ex_id)

    self.assertEqual(states.RUNNING, wf_ex.state)
    self.assertIsNone(wf_ex.state_info)

    # Let the workflow run to completion.
    self._await(lambda: self.is_execution_success(wf_ex_id))
    wf_ex = db_api.get_workflow_execution(wf_ex_id)

    self.assertEqual(states.SUCCESS, wf_ex.state)
    self.assertIsNone(wf_ex.state_info)

    task_exs = wf_ex.task_executions

    self.assertEqual(3, len(task_exs))

    t1_ex, t2_ex, t3_ex = [
        self._assert_single_item(task_exs, name=task_name)
        for task_name in ('t1', 't2', 't3')
    ]

    # t1: a single successful execution.
    self.assertEqual(states.SUCCESS, t1_ex.state)

    t1_action_exs = db_api.get_action_executions(
        task_execution_id=t1_ex.id
    )

    self.assertEqual(1, len(t1_action_exs))
    self.assertEqual(states.SUCCESS, t1_action_exs[0].state)

    # t2: the failed execution from the first pass plus the successful
    # one created by the rerun.
    self.assertEqual(states.SUCCESS, t2_ex.state)
    self.assertIsNone(t2_ex.state_info)

    t2_action_exs = db_api.get_action_executions(
        task_execution_id=t2_ex.id
    )

    self.assertEqual(2, len(t2_action_exs))
    self.assertEqual(states.ERROR, t2_action_exs[0].state)
    self.assertEqual(states.SUCCESS, t2_action_exs[1].state)

    # t3: a single successful execution.
    self.assertEqual(states.SUCCESS, t3_ex.state)

    t3_action_exs = db_api.get_action_executions(
        task_execution_id=t3_ex.id
    )

    self.assertEqual(1, len(t3_action_exs))
    self.assertEqual(states.SUCCESS, t3_action_exs[0].state)
|
||||
|
||||
@mock.patch.object(
    std_actions.EchoAction,
    'run',
    mock.MagicMock(
        side_effect=[
            'Task 1',  # Mock task1 success for initial run.
            exc.ActionException(),  # Mock task2 exception for initial run.
            'Task 2',  # Mock task2 success for rerun.
            'Task 3'  # Mock task3 success.
        ]
    )
)
def test_rerun_subflow_task(self):
    # Scenario: wb1.wf1 fails because wf2_t1 fails inside the subworkflow
    # started by t2. The failed task is rerun directly on the subworkflow
    # execution; both the subworkflow and the parent workflow are expected
    # to go back to RUNNING and then finish successfully.
    wb_service.create_workbook_v2(SUBFLOW_WORKBOOK)

    # Run workflow and fail task.
    wf_ex = self.engine.start_workflow('wb1.wf1', {})
    self._await(lambda: self.is_execution_error(wf_ex.id))
    wf_ex = db_api.get_workflow_execution(wf_ex.id)

    self.assertEqual(states.ERROR, wf_ex.state)
    self.assertIsNotNone(wf_ex.state_info)
    self.assertEqual(2, len(wf_ex.task_executions))

    task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
    task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')

    self.assertEqual(states.SUCCESS, task_1_ex.state)
    self.assertEqual(states.ERROR, task_2_ex.state)
    self.assertIsNotNone(task_2_ex.state_info)

    # Get subworkflow and related task.
    sub_wf_exs = db_api.get_workflow_executions(
        task_execution_id=task_2_ex.id
    )

    sub_wf_ex = sub_wf_exs[0]

    self.assertEqual(states.ERROR, sub_wf_ex.state)
    self.assertIsNotNone(sub_wf_ex.state_info)
    self.assertEqual(1, len(sub_wf_ex.task_executions))

    sub_wf_task_ex = self._assert_single_item(
        sub_wf_ex.task_executions,
        name='wf2_t1'
    )

    self.assertEqual(states.ERROR, sub_wf_task_ex.state)
    self.assertIsNotNone(sub_wf_task_ex.state_info)

    # Resume workflow and re-run failed subworkflow task.
    self.engine.rerun_workflow(sub_wf_ex.id, sub_wf_task_ex.id)
    sub_wf_ex = db_api.get_workflow_execution(sub_wf_ex.id)

    self.assertEqual(states.RUNNING, sub_wf_ex.state)
    self.assertIsNone(sub_wf_ex.state_info)

    # The parent workflow must have been moved back to RUNNING as well.
    wf_ex = db_api.get_workflow_execution(wf_ex.id)

    self.assertEqual(states.RUNNING, wf_ex.state)
    self.assertIsNone(wf_ex.state_info)

    # Wait for the subworkflow to succeed.
    self._await(lambda: self.is_execution_success(sub_wf_ex.id))
    sub_wf_ex = db_api.get_workflow_execution(sub_wf_ex.id)

    self.assertEqual(states.SUCCESS, sub_wf_ex.state)
    self.assertIsNone(sub_wf_ex.state_info)
    self.assertEqual(1, len(sub_wf_ex.task_executions))

    sub_wf_task_ex = self._assert_single_item(
        sub_wf_ex.task_executions,
        name='wf2_t1'
    )

    # Check action executions of subworkflow task: the failed one from
    # the initial run followed by the successful one from the rerun.
    self.assertEqual(states.SUCCESS, sub_wf_task_ex.state)
    self.assertIsNone(sub_wf_task_ex.state_info)

    sub_wf_task_ex_action_exs = db_api.get_action_executions(
        task_execution_id=sub_wf_task_ex.id
    )

    self.assertEqual(2, len(sub_wf_task_ex_action_exs))
    self.assertEqual(states.ERROR, sub_wf_task_ex_action_exs[0].state)
    self.assertEqual(states.SUCCESS, sub_wf_task_ex_action_exs[1].state)

    # Wait for the main workflow to succeed.
    self._await(lambda: self.is_execution_success(wf_ex.id))
    wf_ex = db_api.get_workflow_execution(wf_ex.id)

    self.assertEqual(states.SUCCESS, wf_ex.state)
    self.assertIsNone(wf_ex.state_info)
    self.assertEqual(3, len(wf_ex.task_executions))

    task_1_ex = self._assert_single_item(wf_ex.task_executions, name='t1')
    task_2_ex = self._assert_single_item(wf_ex.task_executions, name='t2')
    task_3_ex = self._assert_single_item(wf_ex.task_executions, name='t3')

    # Check action executions of task 1.
    self.assertEqual(states.SUCCESS, task_1_ex.state)

    task_1_action_exs = db_api.get_action_executions(
        task_execution_id=task_1_ex.id
    )

    self.assertEqual(1, len(task_1_action_exs))
    self.assertEqual(states.SUCCESS, task_1_action_exs[0].state)

    # Check action executions of task 2.
    self.assertEqual(states.SUCCESS, task_2_ex.state)
    self.assertIsNone(task_2_ex.state_info)

    task_2_action_exs = db_api.get_action_executions(
        task_execution_id=task_2_ex.id
    )

    self.assertEqual(1, len(task_2_action_exs))
    # Assert on task 2's own execution; the original re-checked
    # task_1_action_exs here, leaving task 2's execution state unverified.
    self.assertEqual(states.SUCCESS, task_2_action_exs[0].state)

    # Check action executions of task 3.
    self.assertEqual(states.SUCCESS, task_3_ex.state)

    task_3_action_exs = db_api.get_action_executions(
        task_execution_id=task_3_ex.id
    )

    self.assertEqual(1, len(task_3_action_exs))
    self.assertEqual(states.SUCCESS, task_3_action_exs[0].state)
|
||||
|
|
Loading…
Reference in New Issue