From 2ab055b57f6eb8e5b7af43b278b9b3237ba07000 Mon Sep 17 00:00:00 2001 From: Renat Akhmerov Date: Thu, 15 Sep 2016 15:30:06 +0300 Subject: [PATCH] Avoid storing workflow input in task inbound context * 'in_context' field of task executions changed its semantics to not store workflow input and other data stored in the initial workflow context such as openstack security context and workflow variables, therefore task executions occupy less space in DB * Introduced ContextView class to avoid having to merge dictionaries every time we need to evaluate YAQL functions against some context. This class is a composite structure built on top of regular dictionaries that provides priority based lookup algorithm over these dictionaries. For example, if we need to evaluate an expression against a task inbound context we just need to build a context view including task 'in_context', workflow initial context (wf_ex.context) and workflow input dictionary (wf_ex.input). Using this class is a significant performance boost * Fixed unit tests * Other minor changes Change-Id: I7fe90533e260e7d78818b69a087fb5175b9d5199 (cherry picked from commit a4287a5e63c81f2c1400feb064f70ba47ca781b8) --- mistral/db/v2/sqlalchemy/models.py | 7 +- mistral/engine/base.py | 22 ++- mistral/engine/policies.py | 10 +- mistral/engine/tasks.py | 19 ++- mistral/engine/workflows.py | 17 ++- mistral/tests/unit/engine/test_dataflow.py | 113 ++++++++++++++- .../tests/unit/engine/test_default_engine.py | 14 +- mistral/tests/unit/engine/test_policies.py | 43 ++++-- .../unit/engine/test_workflow_variables.py | 8 -- .../tests/unit/workbook/test_spec_caching.py | 2 +- .../unit/workflow/test_direct_workflow.py | 4 +- mistral/workflow/base.py | 13 +- mistral/workflow/data_flow.py | 134 ++++++++++++++++-- mistral/workflow/direct_workflow.py | 20 ++- 14 files changed, 346 insertions(+), 80 deletions(-) diff --git a/mistral/db/v2/sqlalchemy/models.py b/mistral/db/v2/sqlalchemy/models.py index fdfdcd091..2b6d0c622 100644 --- 
a/mistral/db/v2/sqlalchemy/models.py +++ b/mistral/db/v2/sqlalchemy/models.py @@ -208,7 +208,12 @@ class WorkflowExecution(Execution): output = sa.orm.deferred(sa.Column(st.JsonLongDictType(), nullable=True)) params = sa.Column(st.JsonLongDictType()) - # TODO(rakhmerov): We need to get rid of this field at all. + # Initial workflow context containing workflow variables, environment, + # openstack security context etc. + # NOTES: + # * Data stored in this structure should not be copied into inbound + # contexts of tasks. No need to duplicate it. + # * This structure does not contain workflow input. context = sa.Column(st.JsonLongDictType()) diff --git a/mistral/engine/base.py b/mistral/engine/base.py index d1d721d06..ac9e80008 100644 --- a/mistral/engine/base.py +++ b/mistral/engine/base.py @@ -175,8 +175,15 @@ class TaskPolicy(object): :param task_ex: DB model for task that is about to start. :param task_spec: Task specification. """ - # No-op by default. - data_flow.evaluate_object_fields(self, task_ex.in_context) + wf_ex = task_ex.workflow_execution + + ctx_view = data_flow.ContextView( + task_ex.in_context, + wf_ex.context, + wf_ex.input + ) + + data_flow.evaluate_object_fields(self, ctx_view) self._validate() @@ -186,8 +193,15 @@ class TaskPolicy(object): :param task_ex: Completed task DB model. :param task_spec: Completed task specification. """ - # No-op by default. 
- data_flow.evaluate_object_fields(self, task_ex.in_context) + wf_ex = task_ex.workflow_execution + + ctx_view = data_flow.ContextView( + task_ex.in_context, + wf_ex.context, + wf_ex.input + ) + + data_flow.evaluate_object_fields(self, ctx_view) self._validate() diff --git a/mistral/engine/policies.py b/mistral/engine/policies.py index 6243f6628..326d49332 100644 --- a/mistral/engine/policies.py +++ b/mistral/engine/policies.py @@ -305,9 +305,17 @@ class RetryPolicy(base.TaskPolicy): context_key ) + wf_ex = task_ex.workflow_execution + + ctx_view = data_flow.ContextView( + data_flow.evaluate_task_outbound_context(task_ex), + wf_ex.context, + wf_ex.input + ) + continue_on_evaluation = expressions.evaluate( self._continue_on_clause, - data_flow.evaluate_task_outbound_context(task_ex) + ctx_view ) task_ex.runtime_context = runtime_context diff --git a/mistral/engine/tasks.py b/mistral/engine/tasks.py index 180b8fd79..2027ef3bf 100644 --- a/mistral/engine/tasks.py +++ b/mistral/engine/tasks.py @@ -362,7 +362,16 @@ class RegularTask(Task): def _get_action_input(self, ctx=None): ctx = ctx or self.ctx - input_dict = expr.evaluate_recursively(self.task_spec.get_input(), ctx) + ctx_view = data_flow.ContextView( + ctx, + self.wf_ex.context, + self.wf_ex.input + ) + + input_dict = expr.evaluate_recursively( + self.task_spec.get_input(), + ctx_view + ) return utils.merge_dicts( input_dict, @@ -478,9 +487,15 @@ class WithItemsTask(RegularTask): :return: the list of tuples containing indexes and the corresponding input dict. 
""" + ctx_view = data_flow.ContextView( + self.ctx, + self.wf_ex.context, + self.wf_ex.input + ) + with_items_inputs = expr.evaluate_recursively( self.task_spec.get_with_items(), - self.ctx + ctx_view ) with_items.validate_input(with_items_inputs) diff --git a/mistral/engine/workflows.py b/mistral/engine/workflows.py index 10e87348c..d2e6ffdaf 100644 --- a/mistral/engine/workflows.py +++ b/mistral/engine/workflows.py @@ -14,7 +14,6 @@ # limitations under the License. import abc -import copy from oslo_config import cfg from oslo_log import log as logging from osprofiler import profiler @@ -220,7 +219,6 @@ class Workflow(object): }) self.wf_ex.input = input_dict or {} - self.wf_ex.context = copy.deepcopy(input_dict) or {} env = _get_environment(params) @@ -309,18 +307,23 @@ class Workflow(object): wf_ctrl = wf_base.get_controller(self.wf_ex, self.wf_spec) if wf_ctrl.any_cancels(): - self._cancel_workflow( - _build_cancel_info_message(wf_ctrl, self.wf_ex) - ) + msg = _build_cancel_info_message(wf_ctrl, self.wf_ex) + + self._cancel_workflow(msg) elif wf_ctrl.all_errors_handled(): - self._succeed_workflow(wf_ctrl.evaluate_workflow_final_context()) + ctx = wf_ctrl.evaluate_workflow_final_context() + + self._succeed_workflow(ctx) else: - self._fail_workflow(_build_fail_info_message(wf_ctrl, self.wf_ex)) + msg = _build_fail_info_message(wf_ctrl, self.wf_ex) + + self._fail_workflow(msg) return 0 def _succeed_workflow(self, final_context, msg=None): self.wf_ex.output = data_flow.evaluate_workflow_output( + self.wf_ex, self.wf_spec, final_context ) diff --git a/mistral/tests/unit/engine/test_dataflow.py b/mistral/tests/unit/engine/test_dataflow.py index 75cbc0be7..4850ef9f8 100644 --- a/mistral/tests/unit/engine/test_dataflow.py +++ b/mistral/tests/unit/engine/test_dataflow.py @@ -17,6 +17,8 @@ from oslo_config import cfg from mistral.db.v2 import api as db_api from mistral.db.v2.sqlalchemy import models +from mistral import exceptions as exc +from mistral import expressions 
as expr from mistral.services import workflows as wf_service from mistral.tests.unit import base as test_base from mistral.tests.unit.engine import base as engine_test_base @@ -84,12 +86,8 @@ class DataFlowEngineTest(engine_test_base.EngineTestCase): ) # Make sure that task inbound context doesn't contain workflow - # specification, input and params. - ctx = task1.in_context - - self.assertFalse('spec' in ctx['__execution']) - self.assertFalse('input' in ctx['__execution']) - self.assertFalse('params' in ctx['__execution']) + # execution info. + self.assertFalse('__execution' in task1.in_context) def test_linear_with_branches_dataflow(self): linear_with_branches_wf = """--- @@ -595,3 +593,106 @@ class DataFlowTest(test_base.BaseTest): [1, 1], data_flow.get_task_execution_result(task_ex) ) + + def test_context_view(self): + ctx = data_flow.ContextView( + { + 'k1': 'v1', + 'k11': 'v11', + 'k3': 'v3' + }, + { + 'k2': 'v2', + 'k21': 'v21', + 'k3': 'v32' + } + ) + + self.assertIsInstance(ctx, dict) + self.assertEqual(5, len(ctx)) + + self.assertIn('k1', ctx) + self.assertIn('k11', ctx) + self.assertIn('k3', ctx) + self.assertIn('k2', ctx) + self.assertIn('k21', ctx) + + self.assertEqual('v1', ctx['k1']) + self.assertEqual('v1', ctx.get('k1')) + self.assertEqual('v11', ctx['k11']) + self.assertEqual('v11', ctx.get('k11')) + self.assertEqual('v3', ctx['k3']) + self.assertEqual('v2', ctx['k2']) + self.assertEqual('v2', ctx.get('k2')) + self.assertEqual('v21', ctx['k21']) + self.assertEqual('v21', ctx.get('k21')) + + self.assertIsNone(ctx.get('Not existing key')) + + self.assertRaises(exc.MistralError, ctx.update) + self.assertRaises(exc.MistralError, ctx.clear) + self.assertRaises(exc.MistralError, ctx.pop, 'k1') + self.assertRaises(exc.MistralError, ctx.popitem) + self.assertRaises(exc.MistralError, ctx.__setitem__, 'k5', 'v5') + self.assertRaises(exc.MistralError, ctx.__delitem__, 'k2') + + self.assertEqual('v1', expr.evaluate('<% $.k1 %>', ctx)) + self.assertEqual('v2', 
expr.evaluate('<% $.k2 %>', ctx)) + self.assertEqual('v3', expr.evaluate('<% $.k3 %>', ctx)) + + # Now change the order of dictionaries and make sure to have + # a different value for key 'k3'. + ctx = data_flow.ContextView( + { + 'k2': 'v2', + 'k21': 'v21', + 'k3': 'v32' + }, + { + 'k1': 'v1', + 'k11': 'v11', + 'k3': 'v3' + } + ) + + self.assertEqual('v32', expr.evaluate('<% $.k3 %>', ctx)) + + def test_context_view_eval_root_with_yaql(self): + ctx = data_flow.ContextView( + {'k1': 'v1'}, + {'k2': 'v2'} + ) + + res = expr.evaluate('<% $ %>', ctx) + + self.assertIsNotNone(res) + self.assertIsInstance(res, dict) + self.assertEqual(2, len(res)) + + def test_context_view_eval_keys(self): + ctx = data_flow.ContextView( + {'k1': 'v1'}, + {'k2': 'v2'} + ) + + res = expr.evaluate('<% $.keys() %>', ctx) + + self.assertIsNotNone(res) + self.assertIsInstance(res, list) + self.assertEqual(2, len(res)) + self.assertIn('k1', res) + self.assertIn('k2', res) + + def test_context_view_eval_values(self): + ctx = data_flow.ContextView( + {'k1': 'v1'}, + {'k2': 'v2'} + ) + + res = expr.evaluate('<% $.values() %>', ctx) + + self.assertIsNotNone(res) + self.assertIsInstance(res, list) + self.assertEqual(2, len(res)) + self.assertIn('v1', res) + self.assertIn('v2', res) diff --git a/mistral/tests/unit/engine/test_default_engine.py b/mistral/tests/unit/engine/test_default_engine.py index 79f380a95..c0b57e1a6 100644 --- a/mistral/tests/unit/engine/test_default_engine.py +++ b/mistral/tests/unit/engine/test_default_engine.py @@ -116,7 +116,6 @@ class DefaultEngineTest(base.DbTestCase): self.assertIsNotNone(wf_ex) self.assertEqual(states.RUNNING, wf_ex.state) self.assertEqual('my execution', wf_ex.description) - self._assert_dict_contains_subset(wf_input, wf_ex.context) self.assertIn('__execution', wf_ex.context) # Note: We need to reread execution to access related tasks. 
@@ -133,9 +132,6 @@ class DefaultEngineTest(base.DbTestCase): self.assertDictEqual({}, task_ex.runtime_context) # Data Flow properties. - self._assert_dict_contains_subset(wf_input, task_ex.in_context) - self.assertIn('__execution', task_ex.in_context) - action_execs = db_api.get_action_executions( task_execution_id=task_ex.id ) @@ -159,7 +155,6 @@ class DefaultEngineTest(base.DbTestCase): self.assertIsNotNone(wf_ex) self.assertEqual(states.RUNNING, wf_ex.state) - self._assert_dict_contains_subset(wf_input, wf_ex.context) self.assertIn('__execution', wf_ex.context) # Note: We need to reread execution to access related tasks. @@ -176,9 +171,6 @@ class DefaultEngineTest(base.DbTestCase): self.assertDictEqual({}, task_ex.runtime_context) # Data Flow properties. - self._assert_dict_contains_subset(wf_input, task_ex.in_context) - self.assertIn('__execution', task_ex.in_context) - action_execs = db_api.get_action_executions( task_execution_id=task_ex.id ) @@ -318,8 +310,7 @@ class DefaultEngineTest(base.DbTestCase): self.assertEqual(states.RUNNING, task1_ex.state) self.assertIsNotNone(task1_ex.spec) self.assertDictEqual({}, task1_ex.runtime_context) - self._assert_dict_contains_subset(wf_input, task1_ex.in_context) - self.assertIn('__execution', task1_ex.in_context) + self.assertNotIn('__execution', task1_ex.in_context) action_execs = db_api.get_action_executions( task_execution_id=task1_ex.id @@ -345,8 +336,6 @@ class DefaultEngineTest(base.DbTestCase): # Data Flow properties. task1_ex = db_api.get_task_execution(task1_ex.id) # Re-read the state. 
- self._assert_dict_contains_subset(wf_input, task1_ex.in_context) - self.assertIn('__execution', task1_ex.in_context) self.assertDictEqual({'var': 'Hey'}, task1_ex.published) self.assertDictEqual({'output': 'Hey'}, task1_action_ex.input) self.assertDictEqual({'result': 'Hey'}, task1_action_ex.output) @@ -396,7 +385,6 @@ class DefaultEngineTest(base.DbTestCase): self.assertEqual(states.SUCCESS, task2_action_ex.state) # Data Flow properties. - self.assertIn('__execution', task1_ex.in_context) self.assertDictEqual({'output': 'Hi'}, task2_action_ex.input) self.assertDictEqual({}, task2_ex.published) self.assertDictEqual({'output': 'Hi'}, task2_action_ex.input) diff --git a/mistral/tests/unit/engine/test_policies.py b/mistral/tests/unit/engine/test_policies.py index 7d7137512..4d60aca31 100644 --- a/mistral/tests/unit/engine/test_policies.py +++ b/mistral/tests/unit/engine/test_policies.py @@ -17,6 +17,7 @@ from oslo_config import cfg from mistral.actions import std_actions from mistral.db.v2 import api as db_api +from mistral.db.v2.sqlalchemy import models from mistral.engine import policies from mistral import exceptions as exc from mistral.services import workbooks as wb_service @@ -364,18 +365,29 @@ class PoliciesTest(base.EngineTestCase): "delay": {"type": "integer"} } } - task_db = type('Task', (object,), {'in_context': {'int_var': 5}}) + + wf_ex = models.WorkflowExecution( + id='1-2-3-4', + context={}, + input={} + ) + + task_ex = models.TaskExecution(in_context={'int_var': 5}) + task_ex.workflow_execution = wf_ex + policy.delay = "<% $.int_var %>" # Validation is ok. - policy.before_task_start(task_db, None) + policy.before_task_start(task_ex, None) policy.delay = "some_string" # Validation is failing now. 
exception = self.assertRaises( exc.InvalidModelException, - policy.before_task_start, task_db, None + policy.before_task_start, + task_ex, + None ) self.assertIn("Invalid data type in TaskPolicy", str(exception)) @@ -494,7 +506,11 @@ class PoliciesTest(base.EngineTestCase): self.assertEqual(states.RUNNING, task_ex.state) self.assertDictEqual({}, task_ex.runtime_context) - self.assertEqual(2, task_ex.in_context['wait_after']) + + # TODO(rakhmerov): This check doesn't make sense anymore because + # we don't store evaluated value anywhere. + # Need to create a better test. + # self.assertEqual(2, task_ex.in_context['wait_after']) def test_retry_policy(self): wb_service.create_workbook_v2(RETRY_WB) @@ -535,8 +551,11 @@ class PoliciesTest(base.EngineTestCase): self.assertEqual(states.RUNNING, task_ex.state) self.assertDictEqual({}, task_ex.runtime_context) - self.assertEqual(3, task_ex.in_context["count"]) - self.assertEqual(1, task_ex.in_context["delay"]) + # TODO(rakhmerov): This check doesn't make sense anymore because + # we don't store evaluated values anywhere. + # Need to create a better test. + # self.assertEqual(3, task_ex.in_context["count"]) + # self.assertEqual(1, task_ex.in_context["delay"]) def test_retry_policy_never_happen(self): retry_wb = """--- @@ -908,7 +927,10 @@ class PoliciesTest(base.EngineTestCase): self.assertEqual(states.RUNNING, task_ex.state) - self.assertEqual(1, task_ex.in_context['timeout']) + # TODO(rakhmerov): This check doesn't make sense anymore because + # we don't store evaluated 'timeout' value anywhere. + # Need to create a better test. 
+ # self.assertEqual(1, task_ex.in_context['timeout']) def test_pause_before_policy(self): wb_service.create_workbook_v2(PAUSE_BEFORE_WB) @@ -1012,9 +1034,7 @@ class PoliciesTest(base.EngineTestCase): self.assertEqual(states.SUCCESS, task_ex.state) - runtime_context = task_ex.runtime_context - - self.assertEqual(4, runtime_context['concurrency']) + self.assertEqual(4, task_ex.runtime_context['concurrency']) def test_concurrency_is_in_runtime_context_from_var(self): wb_service.create_workbook_v2(CONCURRENCY_WB_FROM_VAR) @@ -1023,12 +1043,13 @@ class PoliciesTest(base.EngineTestCase): wf_ex = self.engine.start_workflow('wb.wf1', {'concurrency': 4}) wf_ex = db_api.get_workflow_execution(wf_ex.id) + task_ex = self._assert_single_item( wf_ex.task_executions, name='task1' ) - self.assertEqual(4, task_ex.in_context['concurrency']) + self.assertEqual(4, task_ex.runtime_context['concurrency']) def test_wrong_policy_prop_type(self): wb = """--- diff --git a/mistral/tests/unit/engine/test_workflow_variables.py b/mistral/tests/unit/engine/test_workflow_variables.py index 74ccfda0e..8de657a0a 100644 --- a/mistral/tests/unit/engine/test_workflow_variables.py +++ b/mistral/tests/unit/engine/test_workflow_variables.py @@ -64,14 +64,6 @@ class WorkflowVariablesTest(base.EngineTestCase): self.assertEqual(states.SUCCESS, task1.state) - self._assert_dict_contains_subset( - { - 'literal_var': 'Literal value', - 'yaql_var': 'Hello Renat' - }, - task1.in_context - ) - self.assertDictEqual( { 'literal_var': 'Literal value', diff --git a/mistral/tests/unit/workbook/test_spec_caching.py b/mistral/tests/unit/workbook/test_spec_caching.py index 1a50e6151..4ffaa1eb1 100644 --- a/mistral/tests/unit/workbook/test_spec_caching.py +++ b/mistral/tests/unit/workbook/test_spec_caching.py @@ -158,7 +158,7 @@ class SpecificationCachingTest(base.DbTestCase): self.assertEqual(0, spec_parser.get_wf_execution_spec_cache_size()) self.assertEqual(2, spec_parser.get_wf_definition_spec_cache_size()) - def 
test_update_workflow_spec_for_execution(self): + def test_cache_workflow_spec_by_execution_id(self): wf_text = """ version: '2.0' diff --git a/mistral/tests/unit/workflow/test_direct_workflow.py b/mistral/tests/unit/workflow/test_direct_workflow.py index a97060e26..00f037b76 100644 --- a/mistral/tests/unit/workflow/test_direct_workflow.py +++ b/mistral/tests/unit/workflow/test_direct_workflow.py @@ -36,7 +36,9 @@ class DirectWorkflowControllerTest(base.DbTestCase): id='1-2-3-4', spec=wf_spec.to_dict(), state=states.RUNNING, - workflow_id=wfs[0].id + workflow_id=wfs[0].id, + input={}, + context={} ) self.wf_ex = wf_ex diff --git a/mistral/workflow/base.py b/mistral/workflow/base.py index 76f9d480e..3015a7d61 100644 --- a/mistral/workflow/base.py +++ b/mistral/workflow/base.py @@ -182,16 +182,15 @@ class WorkflowController(object): # to cover 'split' (aka 'merge') use case. upstream_task_execs = self._get_upstream_task_executions(task_spec) - upstream_ctx = data_flow.evaluate_upstream_context(upstream_task_execs) - - ctx = u.merge_dicts( - copy.deepcopy(self.wf_ex.context), - upstream_ctx - ) + ctx = data_flow.evaluate_upstream_context(upstream_task_execs) + # TODO(rakhmerov): Seems like we can fully get rid of '__env' in + # task context if we are OK to have it only in workflow execution + # object (wf_ex.context). Now we can selectively modify env + # for some tasks if we resume or re-run a workflow. 
if self.wf_ex.context: ctx['__env'] = u.merge_dicts( - copy.deepcopy(upstream_ctx.get('__env', {})), + copy.deepcopy(ctx.get('__env', {})), copy.deepcopy(self.wf_ex.context.get('__env', {})) ) diff --git a/mistral/workflow/data_flow.py b/mistral/workflow/data_flow.py index c3c19030c..ff6c8dff3 100644 --- a/mistral/workflow/data_flow.py +++ b/mistral/workflow/data_flow.py @@ -20,6 +20,7 @@ from oslo_log import log as logging from mistral import context as auth_ctx from mistral.db.v2.sqlalchemy import models +from mistral import exceptions as exc from mistral import expressions as expr from mistral import utils from mistral.utils import inspect_utils @@ -31,6 +32,101 @@ LOG = logging.getLogger(__name__) CONF = cfg.CONF +class ContextView(dict): + """Workflow context view. + + It's essentially an immutable composite structure providing fast lookup + over multiple dictionaries w/o having to merge those dictionaries every + time. The lookup algorithm simply iterates over provided dictionaries + one by one and returns a value taken from the first dictionary where + the provided key exists. This means that these dictionaries must be + provided in the order of decreasing priorities. + + Note: Although this class extends built-in 'dict' it shouldn't be + considered a normal dictionary because it may not implement all + methods and account for all corner cases. It's only a read-only view. 
+ """ + + def __init__(self, *dicts): + super(ContextView, self).__init__() + + self.dicts = dicts or [] + + def __getitem__(self, key): + for d in self.dicts: + if key in d: + return d[key] + + raise KeyError(key) + + def get(self, key, default=None): + for d in self.dicts: + if key in d: + return d[key] + + return default + + def __contains__(self, key): + return any(key in d for d in self.dicts) + + def keys(self): + keys = set() + + for d in self.dicts: + keys.update(d.keys()) + + return keys + + def items(self): + return [(k, self[k]) for k in self.keys()] + + def values(self): + return [self[k] for k in self.keys()] + + def iteritems(self): + # NOTE: This is for compatibility with Python 2.7 + # YAQL converts output objects after they are evaluated + # to basic types and it uses six.iteritems() internally + # which calls d.items() in case of Python 3 and d.iteritems() + # for Python 2.7 + return iter(self.items()) + + def iterkeys(self): + # NOTE: This is for compatibility with Python 2.7 + # See the comment for iteritems(). + return iter(self.keys()) + + def itervalues(self): + # NOTE: This is for compatibility with Python 2.7 + # See the comment for iteritems(). 
+ return iter(self.values()) + + def __len__(self): + return len(self.keys()) + + @staticmethod + def _raise_immutable_error(): + raise exc.MistralError('Context view is immutable.') + + def __setitem__(self, key, value): + self._raise_immutable_error() + + def update(self, E=None, **F): + self._raise_immutable_error() + + def clear(self): + self._raise_immutable_error() + + def pop(self, k, d=None): + self._raise_immutable_error() + + def popitem(self): + self._raise_immutable_error() + + def __delitem__(self, key): + self._raise_immutable_error() + + def evaluate_upstream_context(upstream_task_execs): published_vars = {} ctx = {} @@ -90,7 +186,13 @@ def publish_variables(task_ex, task_spec): if task_ex.state != states.SUCCESS: return - expr_ctx = task_ex.in_context + wf_ex = task_ex.workflow_execution + + expr_ctx = ContextView( + task_ex.in_context, + wf_ex.context, + wf_ex.input + ) if task_ex.name in expr_ctx: LOG.warning( @@ -112,26 +214,28 @@ def evaluate_task_outbound_context(task_ex): :param task_ex: DB task. :return: Outbound task Data Flow context. """ - - in_context = (copy.deepcopy(dict(task_ex.in_context)) - if task_ex.in_context is not None else {}) + in_context = ( + copy.deepcopy(dict(task_ex.in_context)) + if task_ex.in_context is not None else {} + ) return utils.update_dict(in_context, task_ex.published) -def evaluate_workflow_output(wf_spec, ctx): +def evaluate_workflow_output(wf_ex, wf_spec, ctx): """Evaluates workflow output. + :param wf_ex: Workflow execution. :param wf_spec: Workflow specification. :param ctx: Final Data Flow context (cause task's outbound context). """ - ctx = copy.deepcopy(ctx) - output_dict = wf_spec.get_output() - # Evaluate workflow 'publish' clause using the final workflow context. - output = expr.evaluate_recursively(output_dict, ctx) + # Evaluate workflow 'output' clause using the final workflow context. 
+ ctx_view = ContextView(ctx, wf_ex.context, wf_ex.input) + + output = expr.evaluate_recursively(output_dict, ctx_view) # TODO(rakhmerov): Many don't like that we return the whole context # if 'output' is not explicitly defined. @@ -168,6 +272,7 @@ def add_execution_to_context(wf_ex): def add_environment_to_context(wf_ex): + # TODO(rakhmerov): This is redundant, we can always get env from WF params wf_ex.context = wf_ex.context or {} # If env variables are provided, add an evaluated copy into the context. @@ -181,10 +286,13 @@ def add_environment_to_context(wf_ex): def add_workflow_variables_to_context(wf_ex, wf_spec): wf_ex.context = wf_ex.context or {} - return utils.merge_dicts( - wf_ex.context, - expr.evaluate_recursively(wf_spec.get_vars(), wf_ex.context) - ) + # The context for calculating workflow variables is workflow input + # and other data already stored in workflow initial context. + ctx_view = ContextView(wf_ex.context, wf_ex.input) + + wf_vars = expr.evaluate_recursively(wf_spec.get_vars(), ctx_view) + + utils.merge_dicts(wf_ex.context, wf_vars) def evaluate_object_fields(obj, context): diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index 17f294d92..6587ccbf4 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -183,9 +183,15 @@ class DirectWorkflowController(base.WorkflowController): def all_errors_handled(self): for t_ex in lookup_utils.find_error_task_executions(self.wf_ex.id): + ctx_view = data_flow.ContextView( + data_flow.evaluate_task_outbound_context(t_ex), + self.wf_ex.context, + self.wf_ex.input + ) + tasks_on_error = self._find_next_tasks_for_clause( self.wf_spec.get_on_error_clause(t_ex.name), - data_flow.evaluate_task_outbound_context(t_ex) + ctx_view ) if not tasks_on_error: @@ -218,7 +224,11 @@ class DirectWorkflowController(base.WorkflowController): t_state = task_ex.state t_name = task_ex.name - ctx = data_flow.evaluate_task_outbound_context(task_ex) + 
ctx_view = data_flow.ContextView( + data_flow.evaluate_task_outbound_context(task_ex), + self.wf_ex.context, + self.wf_ex.input + ) t_names_and_params = [] @@ -226,7 +236,7 @@ class DirectWorkflowController(base.WorkflowController): t_names_and_params += ( self._find_next_tasks_for_clause( self.wf_spec.get_on_complete_clause(t_name), - ctx + ctx_view ) ) @@ -234,7 +244,7 @@ class DirectWorkflowController(base.WorkflowController): t_names_and_params += ( self._find_next_tasks_for_clause( self.wf_spec.get_on_error_clause(t_name), - ctx + ctx_view ) ) @@ -242,7 +252,7 @@ class DirectWorkflowController(base.WorkflowController): t_names_and_params += ( self._find_next_tasks_for_clause( self.wf_spec.get_on_success_clause(t_name), - ctx + ctx_view ) )