Merge "Task skipping feature"

This commit is contained in:
Zuul 2023-02-14 13:22:17 +00:00 committed by Gerrit Code Review
commit 4289317a91
25 changed files with 552 additions and 52 deletions

View File

@ -350,3 +350,33 @@ create many workflows with the same name as long as they are in different
namespaces.
See more at :doc:`Workflow namespaces </user/wf_namespaces>`
Task skip
---------
Mistral has an ability to skip tasks in ERROR state.
The task moves from ERROR state to SKIPPED state, publish variables from
publish-on-skip section, and the workflow continues from tasks specified in
on-skip section.
To configure task's behavior on skip, fill the following attributes in
the task definition:
* *on-skip* - Optional. This parameter specifies which tasks should be started
after skipping this task.
* *publish-on-skip* - Optional. This parameter specifies which variables should
be published after skipping this task.
It is also possible to skip task which does not have predefined parameters
described above, in this case task will not publish anything and will continue
by *on-success* branch. It could be not safe for next tasks, because they
probably would not have some inputs, so think twice before skipping such tasks.
Task skip could be performed by following request::
PUT /v2/tasks
{
"id": "<task-id>",
"state": "SKIPPED"
}

View File

@ -557,7 +557,7 @@ class Task(resource.Resource):
finished_at = wtypes.text
# Add this param to make Mistral API work with WSME 0.8.0 or higher version
reset = wsme.wsattr(bool, mandatory=True)
reset = wsme.wsattr(bool)
env = types.jsontype

View File

@ -19,6 +19,7 @@ import json
from oslo_log import log as logging
from pecan import rest
from wsme import types as wtypes
from wsme import Unset
import wsmeext.pecan as wsme_pecan
from mistral.api import access_control as acl
@ -341,7 +342,7 @@ class TasksController(rest.RestController):
task_ex = db_api.get_task_execution(id)
task_spec = spec_parser.get_task_spec(task_ex.spec)
task_name = task.name or None
reset = task.reset
reset = task.reset or None
env = task.env or None
if task_name and task_name != task_ex.name:
@ -366,10 +367,10 @@ class TasksController(rest.RestController):
if wf_name and wf_name != wf_ex.name:
raise exc.WorkflowException('Workflow name does not match.')
if task.state != states.RUNNING:
if task.state != states.RUNNING and task.state != states.SKIPPED:
raise exc.WorkflowException(
'Invalid task state. '
'Only updating task to rerun is supported.'
'Only updating task to RUNNING or SKIPPED is supported.'
)
if task_ex.state != states.ERROR:
@ -378,14 +379,21 @@ class TasksController(rest.RestController):
' Only updating task to rerun is supported.'
)
if not task_spec.get_with_items() and not reset:
raise exc.WorkflowException(
'Only with-items task has the option to not reset.'
)
if task.state == states.RUNNING:
if task.reset is Unset:
raise exc.WorkflowException(
'Reset field is mandatory to rerun task.'
)
if not task_spec.get_with_items() and not reset:
raise exc.WorkflowException(
'Only with-items task has the option to not reset.'
)
rpc.get_engine_client().rerun_workflow(
task_ex.id,
reset=reset,
skip=(task.state == states.SKIPPED),
env=env
)

View File

@ -1197,6 +1197,7 @@ def _get_completed_task_executions_query(kwargs):
models.TaskExecution.state.in_(
[states.ERROR,
states.CANCELLED,
states.SKIPPED,
states.SUCCESS]
)
)

View File

@ -101,12 +101,14 @@ class Engine(object, metaclass=abc.ABCMeta):
raise NotImplementedError
@abc.abstractmethod
def rerun_workflow(self, task_ex_id, reset=True, env=None):
def rerun_workflow(self, task_ex_id, reset=True, skip=False, env=None):
"""Rerun workflow from the specified task.
:param task_ex_id: Task execution id.
:param reset: If True, reset task state including deleting its action
executions.
:param skip: If True, then skip failed task and continue workflow
execution.
:param env: Workflow environment.
:return: Workflow execution object.
"""

View File

@ -197,13 +197,19 @@ class DefaultEngine(base.Engine):
@db_utils.retry_on_db_error
@post_tx_queue.run
def rerun_workflow(self, task_ex_id, reset=True, env=None):
def rerun_workflow(self, task_ex_id, reset=True, skip=False, env=None):
with db_api.transaction():
task_ex = db_api.get_task_execution(task_ex_id)
wf_ex = task_ex.workflow_execution
wf_handler.rerun_workflow(wf_ex, task_ex, reset=reset, env=env)
wf_handler.rerun_workflow(
wf_ex,
task_ex,
reset=reset,
skip=skip,
env=env
)
return wf_ex.get_clone()

View File

@ -147,6 +147,8 @@ def _process_commands(wf_ex, cmds):
if isinstance(cmd, (commands.RunTask, commands.RunExistingTask)):
task_handler.run_task(cmd)
elif isinstance(cmd, commands.SkipTask):
task_handler.skip_task(cmd)
elif isinstance(cmd, commands.SetWorkflowState):
wf_handler.set_workflow_state(wf_ex, cmd.new_state, cmd.msg)
else:

View File

@ -246,12 +246,15 @@ class EngineServer(service_base.MistralService):
return self.engine.pause_workflow(wf_ex_id)
def rerun_workflow(self, rpc_ctx, task_ex_id, reset=True, env=None):
def rerun_workflow(self, rpc_ctx, task_ex_id, reset=True,
skip=False, env=None):
"""Receives calls over RPC to rerun workflows on engine.
:param rpc_ctx: RPC request context.
:param task_ex_id: Task execution id.
:param reset: If true, then purge action execution for the task.
:param skip: If True, then skip failed task and continue workflow
execution.
:param env: Environment variables to update.
:return: Workflow execution.
"""
@ -260,7 +263,7 @@ class EngineServer(service_base.MistralService):
task_ex_id
)
return self.engine.rerun_workflow(task_ex_id, reset, env)
return self.engine.rerun_workflow(task_ex_id, reset, skip, env)
def resume_workflow(self, rpc_ctx, wf_ex_id, env=None):
"""Receives calls over RPC to resume workflows on engine.

View File

@ -87,6 +87,18 @@ def run_task(wf_cmd):
_check_affected_tasks(task)
@profiler.trace('task-handler-skip-task', hide_args=True)
def skip_task(wf_cmd):
"""Skip workflow task.
:param wf_cmd: Workflow command.
"""
task = _build_task_from_command(wf_cmd)
task.complete(states.SKIPPED, "Task was skipped.", skip=True)
_check_affected_tasks(task)
return
def mark_task_running(task_ex, wf_spec):
task = build_task_from_execution(wf_spec, task_ex)
@ -367,6 +379,19 @@ def _build_task_from_command(cmd):
return task
if isinstance(cmd, wf_cmds.SkipTask):
task = _create_task(
cmd.wf_ex,
cmd.wf_spec,
spec_parser.get_task_spec(cmd.task_ex.spec),
cmd.ctx,
task_ex=cmd.task_ex,
unique_key=cmd.task_ex.unique_key,
triggered_by=cmd.triggered_by,
)
return task
raise exc.MistralError('Unsupported workflow command: %s' % cmd)

View File

@ -351,7 +351,7 @@ class Task(object, metaclass=abc.ABCMeta):
return True
@profiler.trace('task-complete')
def complete(self, state, state_info=None):
def complete(self, state, state_info=None, skip=False):
"""Complete task and set specified state.
Method sets specified task state and runs all necessary post
@ -365,7 +365,7 @@ class Task(object, metaclass=abc.ABCMeta):
assert self.task_ex
# Ignore if task already completed.
if self.is_completed():
if self.is_completed() and not states.is_skipped(state):
return
# If we were unable to change the task state it means that it was
@ -383,7 +383,8 @@ class Task(object, metaclass=abc.ABCMeta):
if hasattr(ex, 'output'):
ex.output = {}
self._after_task_complete()
if not states.is_skipped(state):
self._after_task_complete()
# Ignore DELAYED state.
if self.task_ex.state == states.RUNNING_DELAYED:

View File

@ -206,7 +206,7 @@ def pause_workflow(wf_ex, msg=None):
wf.pause(msg=msg)
def rerun_workflow(wf_ex, task_ex, reset=True, env=None):
def rerun_workflow(wf_ex, task_ex, reset=True, skip=False, env=None):
if wf_ex.state == states.PAUSED:
return wf_ex.get_clone()
@ -217,7 +217,7 @@ def rerun_workflow(wf_ex, task_ex, reset=True, env=None):
task = task_handler.build_task_from_execution(wf.wf_spec, task_ex)
wf.rerun(task, reset=reset, env=env)
wf.rerun(task, reset=reset, skip=skip, env=env)
_schedule_check_and_fix_integrity(
wf_ex,

View File

@ -238,13 +238,15 @@ class Workflow(object, metaclass=abc.ABCMeta):
self.wf_spec.__class__.__name__
)
def rerun(self, task, reset=True, env=None):
def rerun(self, task, reset=True, skip=False, env=None):
"""Rerun workflow from the given task.
:param task: An engine task associated with the task the workflow
needs to rerun from.
:param reset: If True, reset task state including deleting its action
executions.
:param skip: If True, then skip failed task and continue workflow
execution.
:param env: Environment.
"""
@ -257,7 +259,10 @@ class Workflow(object, metaclass=abc.ABCMeta):
wf_ctrl = wf_base.get_controller(self.wf_ex)
# Calculate commands to process next.
cmds = wf_ctrl.rerun_tasks([task.task_ex], reset=reset)
if skip:
cmds = wf_ctrl.skip_tasks([task.task_ex])
else:
cmds = wf_ctrl.rerun_tasks([task.task_ex], reset=reset)
if cmds:
task.cleanup_runtime_context()

View File

@ -40,6 +40,7 @@ class TaskDefaultsSpec(base.BaseSpec):
"on-complete": on_clause.OnClauseSpec.get_schema(),
"on-success": on_clause.OnClauseSpec.get_schema(),
"on-error": on_clause.OnClauseSpec.get_schema(),
"on-skip": on_clause.OnClauseSpec.get_schema(),
"safe-rerun": types.EXPRESSION_OR_BOOLEAN,
"requires": {
"oneOf": [types.NONEMPTY_STRING, types.UNIQUE_STRING_LIST]
@ -71,6 +72,7 @@ class TaskDefaultsSpec(base.BaseSpec):
self._on_complete = self._spec_property('on-complete', on_spec_cls)
self._on_success = self._spec_property('on-success', on_spec_cls)
self._on_error = self._spec_property('on-error', on_spec_cls)
self._on_skip = self._spec_property('on-skip', on_spec_cls)
self._safe_rerun = data.get('safe-rerun')
@ -88,6 +90,7 @@ class TaskDefaultsSpec(base.BaseSpec):
self._validate_transitions(self._on_complete)
self._validate_transitions(self._on_success)
self._validate_transitions(self._on_error)
self._validate_transitions(self._on_skip)
def _validate_transitions(self, on_clause_spec):
val = on_clause_spec.get_next() if on_clause_spec else []
@ -110,6 +113,9 @@ class TaskDefaultsSpec(base.BaseSpec):
def get_on_error(self):
return self._on_error
def get_on_skip(self):
return self._on_skip
def get_safe_rerun(self):
return self._safe_rerun

View File

@ -75,6 +75,7 @@ class TaskSpec(base.BaseSpec):
},
"publish": types.NONEMPTY_DICT,
"publish-on-error": types.NONEMPTY_DICT,
"publish-on-skip": types.NONEMPTY_DICT,
"retry": retry_policy.RetrySpec.get_schema(),
"wait-before": types.EXPRESSION_OR_POSITIVE_INTEGER,
"wait-after": types.EXPRESSION_OR_POSITIVE_INTEGER,
@ -121,6 +122,7 @@ class TaskSpec(base.BaseSpec):
self._with_items = self._get_with_items_as_dict()
self._publish = data.get('publish', {})
self._publish_on_error = data.get('publish-on-error', {})
self._publish_on_skip = data.get('publish-on-skip', {})
self._policies = self._group_spec(
policies.PoliciesSpec,
'retry',
@ -153,6 +155,7 @@ class TaskSpec(base.BaseSpec):
self.validate_expr(self._data.get('input', {}))
self.validate_expr(self._data.get('publish', {}))
self.validate_expr(self._data.get('publish-on-error', {}))
self.validate_expr(self._data.get('publish-on-skip', {}))
self.validate_expr(self._data.get('keep-result', {}))
self.validate_expr(self._data.get('safe-rerun', {}))
@ -260,6 +263,11 @@ class TaskSpec(base.BaseSpec):
{'branch': self._publish_on_error},
validate=self._validate
)
elif state == states.SKIPPED and self._publish_on_skip:
spec = publish.PublishSpec(
{'branch': self._publish_on_skip},
validate=self._validate
)
return spec
def get_keep_result(self):
@ -288,7 +296,8 @@ class DirectWorkflowTaskSpec(TaskSpec):
},
"on-complete": on_clause.OnClauseSpec.get_schema(),
"on-success": on_clause.OnClauseSpec.get_schema(),
"on-error": on_clause.OnClauseSpec.get_schema()
"on-error": on_clause.OnClauseSpec.get_schema(),
"on-skip": on_clause.OnClauseSpec.get_schema()
}
}
@ -307,12 +316,14 @@ class DirectWorkflowTaskSpec(TaskSpec):
self._on_complete = self._spec_property('on-complete', on_spec_cls)
self._on_success = self._spec_property('on-success', on_spec_cls)
self._on_error = self._spec_property('on-error', on_spec_cls)
self._on_skip = self._spec_property('on-skip', on_spec_cls)
def validate_semantics(self):
# Validate YAQL expressions.
self._validate_transitions(self._on_complete)
self._validate_transitions(self._on_success)
self._validate_transitions(self._on_error)
self._validate_transitions(self._on_skip)
if self._join:
join_task_name = self.get_name()
@ -345,6 +356,8 @@ class DirectWorkflowTaskSpec(TaskSpec):
on_clause = self._on_success
elif state == states.ERROR:
on_clause = self._on_error
elif state == states.SKIPPED:
on_clause = self._on_skip
if on_clause and on_clause.get_publish():
if spec:
@ -366,6 +379,9 @@ class DirectWorkflowTaskSpec(TaskSpec):
def get_on_error(self):
return self._on_error
def get_on_skip(self):
return self._on_skip
class ReverseWorkflowTaskSpec(TaskSpec):
_polymorphic_value = 'reverse'

View File

@ -287,6 +287,9 @@ class DirectWorkflowSpec(WorkflowSpec):
for tup in self.get_on_complete_clause(task_name):
t_names.add(tup[0])
for tup in self.get_on_skip_clause(task_name):
t_names.add(tup[0])
return t_names
def transition_exists(self, from_task_name, to_task_name):
@ -313,6 +316,25 @@ class DirectWorkflowSpec(WorkflowSpec):
return result
def get_on_skip_clause(self, t_name):
result = []
on_clause = self.get_tasks()[t_name].get_on_skip()
if on_clause:
result = on_clause.get_next()
if not result:
t_defaults = self.get_task_defaults()
if t_defaults and t_defaults.get_on_skip():
result = self._remove_task_from_clause(
t_defaults.get_on_skip().get_next(),
t_name
)
return result
def get_on_success_clause(self, t_name):
result = []

View File

@ -40,6 +40,7 @@ TASK_CANCELLED = 'TASK_CANCELLED'
TASK_PAUSED = 'TASK_PAUSED'
TASK_RESUMED = 'TASK_RESUMED'
TASK_RERUN = 'TASK_RERUN'
TASK_SKIPPED = 'TASK_SKIPPED'
TASKS = [
TASK_LAUNCHED,
@ -48,7 +49,8 @@ TASKS = [
TASK_CANCELLED,
TASK_PAUSED,
TASK_RESUMED,
TASK_RERUN
TASK_RERUN,
TASK_SKIPPED
]
EVENTS = WORKFLOWS + TASKS
@ -66,7 +68,8 @@ _TASK_EVENT_MAP = {
states.SUCCESS: {'ANY': TASK_SUCCEEDED},
states.ERROR: {'ANY': TASK_FAILED},
states.CANCELLED: {'ANY': TASK_CANCELLED},
states.PAUSED: {'ANY': TASK_PAUSED}
states.PAUSED: {'ANY': TASK_PAUSED},
states.SKIPPED: {'ANY': TASK_SKIPPED}
}
# Describes what state transition matches to what event.

View File

@ -284,7 +284,7 @@ class EngineClient(eng.Engine):
)
@base.wrap_messaging_exception
def rerun_workflow(self, task_ex_id, reset=True, env=None):
def rerun_workflow(self, task_ex_id, reset=True, skip=False, env=None):
"""Rerun the workflow.
This method reruns workflow with the given execution id
@ -293,6 +293,8 @@ class EngineClient(eng.Engine):
:param task_ex_id: Task execution id.
:param reset: If true, then reset task execution state and purge
action execution for the task.
:param skip: If True, then skip failed task and continue workflow
execution.
:param env: Environment variables to update.
:return: Workflow execution.
"""
@ -307,6 +309,7 @@ class EngineClient(eng.Engine):
'rerun_workflow',
task_ex_id=task_ex_id,
reset=reset,
skip=skip,
env=env
)

View File

@ -138,6 +138,11 @@ ERROR_ITEMS_TASK_EX['state'] = 'ERROR'
ERROR_TASK = copy.deepcopy(TASK)
ERROR_TASK['state'] = 'ERROR'
SKIPPED_TASK_EX = copy.deepcopy(TASK_EX)
SKIPPED_TASK_EX['state'] = 'SKIPPED'
SKIPPED_TASK = copy.deepcopy(TASK)
SKIPPED_TASK['state'] = 'SKIPPED'
BROKEN_TASK = copy.deepcopy(TASK)
RERUN_TASK = {
@ -145,6 +150,11 @@ RERUN_TASK = {
'state': 'RUNNING'
}
SKIP_TASK = {
'id': '123',
'state': 'SKIPPED'
}
MOCK_WF_EX = mock.MagicMock(return_value=WF_EX)
TASK_EX.workflow_execution = WF_EX
MOCK_TASK = mock.MagicMock(return_value=TASK_EX)
@ -249,7 +259,7 @@ class TestTasksController(base.APITest):
mock.MagicMock(side_effect=[ERROR_TASK_EX, TASK_EX])
)
@mock.patch.object(rpc.EngineClient, 'rerun_workflow', MOCK_WF_EX)
def test_put(self):
def test_put_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['reset'] = True
@ -261,6 +271,7 @@ class TestTasksController(base.APITest):
rpc.EngineClient.rerun_workflow.assert_called_with(
TASK_EX.id,
reset=params['reset'],
skip=False,
env=None
)
@ -271,7 +282,29 @@ class TestTasksController(base.APITest):
mock.MagicMock(side_effect=[ERROR_TASK_EX, TASK_EX])
)
@mock.patch.object(rpc.EngineClient, 'rerun_workflow', MOCK_WF_EX)
def test_put_missing_reset(self):
def test_put_skip(self):
params = copy.deepcopy(SKIP_TASK)
resp = self.app.put_json('/v2/tasks/123', params=params)
self.assertEqual(200, resp.status_int)
self.assertDictEqual(TASK, resp.json)
rpc.EngineClient.rerun_workflow.assert_called_with(
TASK_EX.id,
reset=None,
skip=True,
env=None
)
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(
db_api,
'get_task_execution',
mock.MagicMock(side_effect=[ERROR_TASK_EX, TASK_EX])
)
@mock.patch.object(rpc.EngineClient, 'rerun_workflow', MOCK_WF_EX)
def test_put_missing_reset_rerun(self):
params = copy.deepcopy(RERUN_TASK)
resp = self.app.put_json(
@ -281,7 +314,10 @@ class TestTasksController(base.APITest):
self.assertEqual(400, resp.status_int)
self.assertIn('faultstring', resp.json)
self.assertIn('Mandatory field missing', resp.json['faultstring'])
self.assertIn(
'Reset field is mandatory to rerun task',
resp.json['faultstring']
)
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(
@ -290,7 +326,7 @@ class TestTasksController(base.APITest):
mock.MagicMock(side_effect=[ERROR_ITEMS_TASK_EX, WITH_ITEMS_TASK_EX])
)
@mock.patch.object(rpc.EngineClient, 'rerun_workflow', MOCK_WF_EX)
def test_put_with_items(self):
def test_put_with_items_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['reset'] = False
@ -319,12 +355,13 @@ class TestTasksController(base.APITest):
rpc.EngineClient.rerun_workflow.assert_called_with(
TASK_EX.id,
reset=params['reset'],
skip=False,
env=json.loads(params['env'])
)
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_TASK)
def test_put_current_task_not_in_error(self):
def test_put_current_task_not_in_error_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['reset'] = True
@ -338,10 +375,25 @@ class TestTasksController(base.APITest):
self.assertIn('faultstring', resp.json)
self.assertIn('execution must be in ERROR', resp.json['faultstring'])
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_TASK)
def test_put_current_task_not_in_error_skip(self):
params = copy.deepcopy(SKIP_TASK)
resp = self.app.put_json(
'/v2/tasks/123',
params=params,
expect_errors=True
)
self.assertEqual(400, resp.status_int)
self.assertIn('faultstring', resp.json)
self.assertIn('execution must be in ERROR', resp.json['faultstring'])
@mock.patch.object(rpc.EngineClient, 'rerun_workflow', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_current_task_in_error(self):
def test_put_current_task_in_error_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['reset'] = True
params['env'] = '{"k1": "def"}'
@ -352,7 +404,7 @@ class TestTasksController(base.APITest):
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_invalid_state(self):
def test_put_invalid_state_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['state'] = states.IDLE
params['reset'] = True
@ -369,7 +421,7 @@ class TestTasksController(base.APITest):
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_invalid_reset(self):
def test_put_invalid_reset_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['reset'] = False
@ -383,23 +435,24 @@ class TestTasksController(base.APITest):
self.assertIn('faultstring', resp.json)
self.assertIn('Only with-items task', resp.json['faultstring'])
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_valid_state(self):
params = copy.deepcopy(RERUN_TASK)
params['state'] = states.RUNNING
params['reset'] = True
@mock.patch.object(rpc.EngineClient, 'rerun_workflow', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_valid_state_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['state'] = states.RUNNING
params['reset'] = True
resp = self.app.put_json(
'/v2/tasks/123',
params=params
)
resp = self.app.put_json(
'/v2/tasks/123',
params=params
)
self.assertEqual(200, resp.status_int)
self.assertEqual(200, resp.status_int)
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_mismatch_task_name(self):
def test_put_mismatch_task_name_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['name'] = 'abc'
params['reset'] = True
@ -417,7 +470,7 @@ class TestTasksController(base.APITest):
@mock.patch.object(rpc.EngineClient, 'rerun_workflow', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_match_task_name(self):
def test_put_match_task_name_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['name'] = 'task'
params['reset'] = True
@ -432,7 +485,7 @@ class TestTasksController(base.APITest):
@mock.patch.object(db_api, 'get_workflow_execution', MOCK_WF_EX)
@mock.patch.object(db_api, 'get_task_execution', MOCK_ERROR_TASK)
def test_put_mismatch_workflow_name(self):
def test_put_mismatch_workflow_name_rerun(self):
params = copy.deepcopy(RERUN_TASK)
params['workflow_name'] = 'xyz'
params['reset'] = True

View File

@ -0,0 +1,242 @@
# Copyright 2022 - NetCracker Technology Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from mistral.db.v2 import api as db_api
from mistral.services import workflows as wf_service
from mistral.tests.unit.engine import base
from mistral.workflow import states
class TaskSkipTest(base.EngineTestCase):
def test_basic_task_skip(self):
workflow = """
version: '2.0'
wf:
tasks:
t1:
action: std.fail
on-skip: t2
on-success: t3
t2:
action: std.noop
t3:
action: std.noop
"""
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_error(wf_ex.id)
# Check that on-skip branch was not executed
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(1, len(task_execs))
t1_ex = self._assert_single_item(
task_execs,
name='t1',
state=states.ERROR
)
# Skip t1 and wait for wf to complete
self.engine.rerun_workflow(t1_ex.id, skip=True)
self.await_workflow_success(wf_ex.id)
# Check that on-skip branch was executed
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(2, len(task_execs))
self._assert_single_item(task_execs, name='t1', state=states.SKIPPED)
self._assert_single_item(task_execs, name='t2', state=states.SUCCESS)
def test_task_skip_on_workflow_tail(self):
workflow = """
version: '2.0'
wf:
tasks:
t0:
action: std.noop
on-success: t1
t1:
action: std.fail
"""
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_error(wf_ex.id)
# Check that on-skip branch was not executed
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(2, len(task_execs))
t1_ex = self._assert_single_item(
task_execs,
name='t1',
state=states.ERROR
)
# Skip t1 and wait for wf to complete
self.engine.rerun_workflow(t1_ex.id, skip=True)
self.await_workflow_success(wf_ex.id)
# Check that on-skip branch was executed
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(2, len(task_execs))
self._assert_single_item(task_execs, name='t0', state=states.SUCCESS)
self._assert_single_item(task_execs, name='t1', state=states.SKIPPED)
def test_skip_subworkflow(self):
workflow = """
version: '2.0'
wf:
tasks:
t0:
action: std.noop
on-success: t1
t1:
workflow: subwf
subwf:
tasks:
t0:
action: std.fail
"""
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_error(wf_ex.id)
# Check that on-skip branch was not executed
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(2, len(task_execs))
t1_ex = self._assert_single_item(
task_execs,
name='t1',
state=states.ERROR
)
# Skip t1 and wait for wf to complete
self.engine.rerun_workflow(t1_ex.id, skip=True)
self.await_workflow_success(wf_ex.id)
# Check that on-skip branch was executed
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(2, len(task_execs))
self._assert_single_item(task_execs, name='t0', state=states.SUCCESS)
self._assert_single_item(task_execs, name='t1', state=states.SKIPPED)
def test_publish_on_skip(self):
workflow = """
version: '2.0'
wf:
tasks:
t0:
action: std.noop
on-success: t1
t1:
action: std.fail
publish:
success: 1
publish-on-error:
error: 1
publish-on-skip:
skip: 1
"""
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_error(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(states.ERROR, wf_ex.state)
self.assertEqual(2, len(task_execs))
t1_ex = self._assert_single_item(
task_execs,
name='t1',
state=states.ERROR
)
publish_before_skip = {"error": 1}
self.assertDictEqual(publish_before_skip, t1_ex.published)
# Skip t1 and wait for wf to complete
self.engine.rerun_workflow(t1_ex.id, skip=True)
self.await_workflow_success(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
t1_ex = self._assert_single_item(
task_execs,
name='t1',
state=states.SKIPPED
)
publish_after_skip = {"skip": 1}
self.assertDictEqual(publish_after_skip, t1_ex.published)
self.assertDictEqual(publish_after_skip, wf_ex.output)
def test_retry_dont_conflict_with_skip(self):
workflow = """
version: '2.0'
wf:
tasks:
t1:
action: std.fail
on-skip: skip
retry:
count: 2
delay: 0
skip:
action: std.noop
"""
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_error(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(1, len(task_execs))
t1_ex = self._assert_single_item(
task_execs,
name='t1',
state=states.ERROR
)
self.engine.rerun_workflow(t1_ex.id, skip=True)
self.await_workflow_success(wf_ex.id)

View File

@ -57,12 +57,14 @@ class StatesModuleTest(base.BaseTest):
self.assertFalse(s.is_valid_transition(s.SUCCESS, s.RUNNING))
self.assertFalse(s.is_valid_transition(s.SUCCESS, s.ERROR))
self.assertFalse(s.is_valid_transition(s.SUCCESS, s.PAUSED))
self.assertFalse(s.is_valid_transition(s.SUCCESS, s.SKIPPED))
self.assertFalse(s.is_valid_transition(s.SUCCESS, s.RUNNING_DELAYED))
self.assertFalse(s.is_valid_transition(s.SUCCESS, s.IDLE))
# From ERROR
self.assertTrue(s.is_valid_transition(s.ERROR, s.ERROR))
self.assertTrue(s.is_valid_transition(s.ERROR, s.RUNNING))
self.assertTrue(s.is_valid_transition(s.ERROR, s.SKIPPED))
self.assertFalse(s.is_valid_transition(s.ERROR, s.PAUSED))
self.assertFalse(s.is_valid_transition(s.ERROR, s.RUNNING_DELAYED))
self.assertFalse(s.is_valid_transition(s.ERROR, s.SUCCESS))

View File

@ -158,6 +158,24 @@ class WorkflowController(object):
return cmds
def skip_tasks(self, task_execs):
"""Gets commands to skip existing task executions.
:param task_execs: List of task executions.
:return: List of workflow commands.
"""
if self._is_paused_or_completed():
return []
cmds = [
commands.SkipTask(self.wf_ex, self.wf_spec, t_e)
for t_e in task_execs
]
LOG.debug("Commands to skip workflow tasks: %s", cmds)
return cmds
@abc.abstractmethod
def get_logical_task_state(self, task_ex):
"""Determines a logical state of the given task.

View File

@ -132,6 +132,33 @@ class RunExistingTask(WorkflowCommand):
return d
class SkipTask(WorkflowCommand):
"""Command to skip an existing workflow task."""
def __init__(self, wf_ex, wf_spec, task_ex, triggered_by=None,
handles_error=False):
super(SkipTask, self).__init__(
wf_ex,
wf_spec,
spec_parser.get_task_spec(task_ex.spec),
task_ex.in_context,
triggered_by=triggered_by,
handles_error=handles_error
)
self.task_ex = task_ex
self.unique_key = task_ex.unique_key
def to_dict(self):
d = super(SkipTask, self).to_dict()
d['cmd_name'] = 'skip_task'
d['task_ex_id'] = self.task_ex.id
d['unique_key'] = self.unique_key
return d
class SetWorkflowState(WorkflowCommand):
"""Instruction to change a workflow state."""

View File

@ -181,7 +181,7 @@ def get_task_execution_result(task_ex):
def publish_variables(task_ex, task_spec):
if task_ex.state not in [states.SUCCESS, states.ERROR]:
if task_ex.state not in [states.SUCCESS, states.ERROR, states.SKIPPED]:
return
wf_ex = task_ex.workflow_execution

View File

@ -298,14 +298,27 @@ class DirectWorkflowController(base.WorkflowController):
result.append((name, params, 'on-error'))
if t_s == states.SUCCESS:
skip_is_empty = False
if t_s == states.SKIPPED:
for name, cond, params in self.wf_spec.get_on_skip_clause(t_n):
if not cond or expr.evaluate(cond, ctx_view):
params = expr.evaluate_recursively(params, ctx_view)
result.append((name, params, 'on-skip'))
# We should go to 'on-success' branch in case of
# skipping task with no 'on-skip' specified.
if len(result) == 0:
skip_is_empty = True
if t_s == states.SUCCESS or skip_is_empty:
for name, cond, params in self.wf_spec.get_on_success_clause(t_n):
if not cond or expr.evaluate(cond, ctx_view):
params = expr.evaluate_recursively(params, ctx_view)
result.append((name, params, 'on-success'))
if states.is_completed(t_s) and not states.is_cancelled(t_s):
if states.is_completed(t_s) \
and not states.is_cancelled_or_skipped(t_s):
for name, cond, params in self.wf_spec.get_on_complete_clause(t_n):
if not cond or expr.evaluate(cond, ctx_view):
params = expr.evaluate_recursively(params, ctx_view)

View File

@ -44,6 +44,9 @@ CANCELLED = 'CANCELLED'
ERROR = 'ERROR'
"""Task, action or workflow has finished with an error."""
SKIPPED = 'SKIPPED'
"""Task has been skipped."""
_ALL = [
IDLE,
WAITING,
@ -52,7 +55,8 @@ _ALL = [
PAUSED,
SUCCESS,
CANCELLED,
ERROR
ERROR,
SKIPPED
]
_VALID_TRANSITIONS = {
@ -63,7 +67,7 @@ _VALID_TRANSITIONS = {
PAUSED: [RUNNING, ERROR, CANCELLED],
SUCCESS: [],
CANCELLED: [RUNNING],
ERROR: [RUNNING]
ERROR: [RUNNING, SKIPPED]
}
TERMINAL_STATES = {SUCCESS, ERROR, CANCELLED}
@ -78,13 +82,17 @@ def is_invalid(state):
def is_completed(state):
return state in [SUCCESS, ERROR, CANCELLED]
return state in [SUCCESS, ERROR, CANCELLED, SKIPPED]
def is_cancelled(state):
return state == CANCELLED
def is_skipped(state):
return state == SKIPPED
def is_running(state):
return state in [RUNNING, RUNNING_DELAYED]
@ -105,6 +113,10 @@ def is_paused_or_completed(state):
return is_paused(state) or is_completed(state)
def is_cancelled_or_skipped(state):
return is_cancelled(state) or is_skipped(state)
def is_paused_or_idle(state):
return is_paused(state) or is_idle(state)