Optimize the logic that check if 'join' task is allowed to start

* Moved DB related lookup functions from workflow/utils to
  a separate module lookup_utils
* Optimized data access pattern where we calculate 'join'
  states induced by upstream tasks, applied caching for some
  lookup operations
* Added caching in direct workflow specification for inbound
  and outbound task specs, if workflow size is large calculating
  them may be expensive
* Added an adaptive delay between calls that refresh 'join'
  task state based on a number of unfulfilled preconditions
  returned by workflow controller

Change-Id: I383fa52f2f05877df7522048020cc7ff280324a2
This commit is contained in:
Renat Akhmerov 2016-09-13 12:37:20 +03:00
parent a0f6c7ae3f
commit 9d06a61fe4
10 changed files with 293 additions and 122 deletions

View File

@ -73,7 +73,7 @@ def run_task(wf_cmd):
return return
if task.is_waiting() and (task.is_created() or task.is_state_changed()): if task.is_waiting() and (task.is_created() or task.is_state_changed()):
_schedule_refresh_task_state(task.task_ex) _schedule_refresh_task_state(task.task_ex, 1)
@profiler.trace('task-handler-on-action-complete') @profiler.trace('task-handler-on-action-complete')
@ -251,15 +251,27 @@ def _refresh_task_state(task_ex_id):
wf_spec wf_spec
) )
state, state_info = wf_ctrl.get_logical_task_state(task_ex) state, state_info, cardinality = wf_ctrl.get_logical_task_state(
task_ex
)
if state == states.RUNNING: if state == states.RUNNING:
continue_task(task_ex) continue_task(task_ex)
elif state == states.ERROR: elif state == states.ERROR:
fail_task(task_ex, state_info) fail_task(task_ex, state_info)
elif state == states.WAITING: elif state == states.WAITING:
# TODO(rakhmerov): Algorithm for increasing rescheduling delay. # Let's assume that a task takes 0.01 sec in average to complete
_schedule_refresh_task_state(task_ex, 1) # and based on this assumption calculate a time of the next check.
# The estimation is very rough, of course, but this delay will be
# decreasing as task preconditions will be completing which will
# give a decent asymptotic approximation.
# For example, if a 'join' task has 100 inbound incomplete tasks
# then the next 'refresh_task_state' call will happen in 10
# seconds. For 500 tasks it will be 50 seconds. The larger the
# workflow is, the more beneficial this mechanism will be.
delay = int(cardinality * 0.01)
_schedule_refresh_task_state(task_ex, max(1, delay))
else: else:
# Must never get here. # Must never get here.
raise RuntimeError( raise RuntimeError(

View File

@ -34,6 +34,7 @@ from mistral.workbook import parser as spec_parser
from mistral.workflow import base as wf_base from mistral.workflow import base as wf_base
from mistral.workflow import commands from mistral.workflow import commands
from mistral.workflow import data_flow from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states from mistral.workflow import states
from mistral.workflow import utils as wf_utils from mistral.workflow import utils as wf_utils
@ -158,6 +159,11 @@ class Workflow(object):
assert self.wf_ex assert self.wf_ex
# Since some lookup utils functions may use cache for completed tasks
# we need to clean caches to make sure that stale objects can't be
# retrieved.
lookup_utils.clean_caches()
wf_service.update_workflow_execution_env(self.wf_ex, env) wf_service.update_workflow_execution_env(self.wf_ex, env)
self.set_state(states.RUNNING, recursive=True) self.set_state(states.RUNNING, recursive=True)
@ -429,7 +435,7 @@ def _build_fail_info_message(wf_ctrl, wf_ex):
failed_tasks = sorted( failed_tasks = sorted(
filter( filter(
lambda t: not wf_ctrl.is_error_handled_for(t), lambda t: not wf_ctrl.is_error_handled_for(t),
wf_utils.find_error_task_executions(wf_ex) lookup_utils.find_error_task_executions(wf_ex.id)
), ),
key=lambda t: t.name key=lambda t: t.name
) )
@ -468,7 +474,7 @@ def _build_fail_info_message(wf_ctrl, wf_ex):
def _build_cancel_info_message(wf_ctrl, wf_ex): def _build_cancel_info_message(wf_ctrl, wf_ex):
# Try to find where cancel is exactly. # Try to find where cancel is exactly.
cancelled_tasks = sorted( cancelled_tasks = sorted(
wf_utils.find_cancelled_task_executions(wf_ex), lookup_utils.find_cancelled_task_executions(wf_ex.id),
key=lambda t: t.name key=lambda t: t.name
) )

View File

@ -36,6 +36,7 @@ from mistral.tests.unit import config as test_config
from mistral.utils import inspect_utils as i_utils from mistral.utils import inspect_utils as i_utils
from mistral import version from mistral import version
from mistral.workbook import parser as spec_parser from mistral.workbook import parser as spec_parser
from mistral.workflow import lookup_utils
RESOURCES_PATH = 'tests/resources/' RESOURCES_PATH = 'tests/resources/'
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -244,6 +245,8 @@ class DbTestCase(BaseTest):
action_manager.sync_db() action_manager.sync_db()
def _clean_db(self): def _clean_db(self):
lookup_utils.clean_caches()
contexts = [ contexts = [
get_context(default=False), get_context(default=False),
get_context(default=True) get_context(default=True)

View File

@ -13,13 +13,12 @@
# limitations under the License. # limitations under the License.
from oslo_serialization import jsonutils
from stevedore import extension
import yaql import yaql
from mistral.db.v2 import api as db_api from mistral.db.v2 import api as db_api
from mistral import utils from mistral import utils
from mistral.workflow import utils as wf_utils
from oslo_serialization import jsonutils
from stevedore import extension
ROOT_CONTEXT = None ROOT_CONTEXT = None
@ -87,8 +86,6 @@ def task_(context, task_name):
# Importing data_flow in order to break cycle dependency between modules. # Importing data_flow in order to break cycle dependency between modules.
from mistral.workflow import data_flow from mistral.workflow import data_flow
wf_ex = db_api.get_workflow_execution(context['__execution']['id'])
# This section may not exist in a context if it's calculated not in # This section may not exist in a context if it's calculated not in
# task scope. # task scope.
cur_task = context['__task_execution'] cur_task = context['__task_execution']
@ -96,7 +93,10 @@ def task_(context, task_name):
if cur_task and cur_task['name'] == task_name: if cur_task and cur_task['name'] == task_name:
task_ex = db_api.get_task_execution(cur_task['id']) task_ex = db_api.get_task_execution(cur_task['id'])
else: else:
task_execs = wf_utils.find_task_executions_by_name(wf_ex, task_name) task_execs = db_api.get_task_executions(
workflow_execution_id=context['__execution']['id'],
name=task_name
)
# TODO(rakhmerov): Account for multiple executions (i.e. in case of # TODO(rakhmerov): Account for multiple executions (i.e. in case of
# cycles). # cycles).

View File

@ -15,6 +15,7 @@
from oslo_utils import uuidutils from oslo_utils import uuidutils
import six import six
import threading
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral import utils from mistral import utils
@ -150,6 +151,18 @@ class DirectWorkflowSpec(WorkflowSpec):
} }
} }
def __init__(self, data):
super(DirectWorkflowSpec, self).__init__(data)
# Init simple dictionary based caches for inbound and
# outbound task specifications. In fact, we don't need
# any special cache implementations here because these
# structures can't grow indefinitely.
self.inbound_tasks_cache_lock = threading.RLock()
self.inbound_tasks_cache = {}
self.outbound_tasks_cache_lock = threading.RLock()
self.outbound_tasks_cache = {}
def validate_semantics(self): def validate_semantics(self):
super(DirectWorkflowSpec, self).validate_semantics() super(DirectWorkflowSpec, self).validate_semantics()
@ -211,17 +224,43 @@ class DirectWorkflowSpec(WorkflowSpec):
] ]
def find_inbound_task_specs(self, task_spec): def find_inbound_task_specs(self, task_spec):
return [ task_name = task_spec.get_name()
with self.inbound_tasks_cache_lock:
specs = self.inbound_tasks_cache.get(task_name)
if specs is not None:
return specs
specs = [
t_s for t_s in self.get_tasks() t_s for t_s in self.get_tasks()
if self.transition_exists(t_s.get_name(), task_spec.get_name()) if self.transition_exists(t_s.get_name(), task_name)
] ]
with self.inbound_tasks_cache_lock:
self.inbound_tasks_cache[task_name] = specs
return specs
def find_outbound_task_specs(self, task_spec): def find_outbound_task_specs(self, task_spec):
return [ task_name = task_spec.get_name()
with self.outbound_tasks_cache_lock:
specs = self.outbound_tasks_cache.get(task_name)
if specs is not None:
return specs
specs = [
t_s for t_s in self.get_tasks() t_s for t_s in self.get_tasks()
if self.transition_exists(task_spec.get_name(), t_s.get_name()) if self.transition_exists(task_name, t_s.get_name())
] ]
with self.outbound_tasks_cache_lock:
self.outbound_tasks_cache[task_name] = specs
return specs
def has_inbound_transitions(self, task_spec): def has_inbound_transitions(self, task_spec):
return len(self.find_inbound_task_specs(task_spec)) > 0 return len(self.find_inbound_task_specs(task_spec)) > 0

View File

@ -26,13 +26,14 @@ from mistral import utils as u
from mistral.workbook import parser as spec_parser from mistral.workbook import parser as spec_parser
from mistral.workflow import commands from mistral.workflow import commands
from mistral.workflow import data_flow from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states from mistral.workflow import states
from mistral.workflow import utils as wf_utils
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@profiler.trace('wf-controller-get-controller')
def get_controller(wf_ex, wf_spec=None): def get_controller(wf_ex, wf_spec=None):
"""Gets a workflow controller instance by given workflow execution object. """Gets a workflow controller instance by given workflow execution object.
@ -130,8 +131,13 @@ class WorkflowController(object):
"""Determines a logical state of the given task. """Determines a logical state of the given task.
:param task_ex: Task execution. :param task_ex: Task execution.
:return: Tuple (state, state_info) which the given task should have :return: Tuple (state, state_info, cardinality) where 'state' and
according to workflow rules and current states of other tasks. 'state_info' are the corresponding values which the given
task should have according to workflow rules and current
states of other tasks. 'cardinality' gives the estimation on
the number of preconditions that are not yet met in case if
state is WAITING. This number can be used to estimate how
frequently we can refresh the state of this task.
""" """
raise NotImplementedError raise NotImplementedError
@ -159,7 +165,9 @@ class WorkflowController(object):
:return: True if there is one or more tasks in cancelled state. :return: True if there is one or more tasks in cancelled state.
""" """
return len(wf_utils.find_cancelled_task_executions(self.wf_ex)) > 0 t_execs = lookup_utils.find_cancelled_task_executions(self.wf_ex.id)
return len(t_execs) > 0
@abc.abstractmethod @abc.abstractmethod
def evaluate_workflow_final_context(self): def evaluate_workflow_final_context(self):
@ -214,8 +222,8 @@ class WorkflowController(object):
return [] return []
# Add all tasks in IDLE state. # Add all tasks in IDLE state.
idle_tasks = wf_utils.find_task_executions_with_state( idle_tasks = lookup_utils.find_task_executions_with_state(
self.wf_ex, self.wf_ex.id,
states.IDLE states.IDLE
) )

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from oslo_log import log as logging from oslo_log import log as logging
from osprofiler import profiler
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral import expressions as expr from mistral import expressions as expr
@ -20,8 +21,8 @@ from mistral import utils
from mistral.workflow import base from mistral.workflow import base
from mistral.workflow import commands from mistral.workflow import commands
from mistral.workflow import data_flow from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states from mistral.workflow import states
from mistral.workflow import utils as wf_utils
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -46,8 +47,8 @@ class DirectWorkflowController(base.WorkflowController):
return list( return list(
filter( filter(
lambda t_e: self._is_upstream_task_execution(task_spec, t_e), lambda t_e: self._is_upstream_task_execution(task_spec, t_e),
wf_utils.find_task_executions_by_specs( lookup_utils.find_task_executions_by_specs(
self.wf_ex, self.wf_ex.id,
self.wf_spec.find_inbound_task_specs(task_spec) self.wf_spec.find_inbound_task_specs(task_spec)
) )
) )
@ -60,7 +61,7 @@ class DirectWorkflowController(base.WorkflowController):
if not t_spec.get_join(): if not t_spec.get_join():
return not t_ex_candidate.processed return not t_ex_candidate.processed
induced_state = self._get_induced_join_state( induced_state, _ = self._get_induced_join_state(
self.wf_spec.get_tasks()[t_ex_candidate.name], self.wf_spec.get_tasks()[t_ex_candidate.name],
t_spec t_spec
) )
@ -173,7 +174,7 @@ class DirectWorkflowController(base.WorkflowController):
# A simple 'non-join' task does not have any preconditions # A simple 'non-join' task does not have any preconditions
# based on state of other tasks so its logical state always # based on state of other tasks so its logical state always
# equals to its real state. # equals to its real state.
return task_ex.state, task_ex.state_info return task_ex.state, task_ex.state_info, 0
return self._get_join_logical_state(task_spec) return self._get_join_logical_state(task_spec)
@ -181,8 +182,7 @@ class DirectWorkflowController(base.WorkflowController):
return bool(self.wf_spec.get_on_error_clause(task_ex.name)) return bool(self.wf_spec.get_on_error_clause(task_ex.name))
def all_errors_handled(self): def all_errors_handled(self):
for t_ex in wf_utils.find_error_task_executions(self.wf_ex): for t_ex in lookup_utils.find_error_task_executions(self.wf_ex.id):
tasks_on_error = self._find_next_tasks_for_clause( tasks_on_error = self._find_next_tasks_for_clause(
self.wf_spec.get_on_error_clause(t_ex.name), self.wf_spec.get_on_error_clause(t_ex.name),
data_flow.evaluate_task_outbound_context(t_ex) data_flow.evaluate_task_outbound_context(t_ex)
@ -197,7 +197,7 @@ class DirectWorkflowController(base.WorkflowController):
return list( return list(
filter( filter(
lambda t_ex: not self._has_outbound_tasks(t_ex), lambda t_ex: not self._has_outbound_tasks(t_ex),
wf_utils.find_successful_task_executions(self.wf_ex) lookup_utils.find_successful_task_executions(self.wf_ex.id)
) )
) )
@ -270,64 +270,94 @@ class DirectWorkflowController(base.WorkflowController):
if not condition or expr.evaluate(condition, ctx) if not condition or expr.evaluate(condition, ctx)
] ]
@profiler.trace('direct-wf-controller-get-join-logical-state')
def _get_join_logical_state(self, task_spec): def _get_join_logical_state(self, task_spec):
"""Evaluates logical state of 'join' task.
:param task_spec: 'join' task specification.
:return: Tuple (state, state_info, spec_cardinality) where 'state' and
'state_info' describe the logical state of the given 'join'
task and 'spec_cardinality' gives the remaining number of
unfulfilled preconditions. If logical state is not WAITING then
'spec_cardinality' should always be 0.
"""
# TODO(rakhmerov): We need to use task_ex instead of task_spec # TODO(rakhmerov): We need to use task_ex instead of task_spec
# in order to cover a use case when there's more than one instance # in order to cover a use case when there's more than one instance
# of the same 'join' task in a workflow. # of the same 'join' task in a workflow.
# TODO(rakhmerov): In some cases this method will be expensive because
# it uses a multistep recursive search. We need to optimize it moving
# forward (e.g. with Workflow Execution Graph).
join_expr = task_spec.get_join() join_expr = task_spec.get_join()
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec) in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs: if not in_task_specs:
return states.RUNNING return states.RUNNING, None, 0
# List of tuples (task_name, state). # List of tuples (task_name, (state, depth)).
induced_states = [ induced_states = [
(t_s.get_name(), self._get_induced_join_state(t_s, task_spec)) (t_s.get_name(), self._get_induced_join_state(t_s, task_spec))
for t_s in in_task_specs for t_s in in_task_specs
] ]
def count(state): def count(state):
return len(list(filter(lambda s: s[1] == state, induced_states))) cnt = 0
total_depth = 0
error_count = count(states.ERROR) for s in induced_states:
running_count = count(states.RUNNING) if s[1][0] == state:
cnt += 1
total_depth += s[1][1]
return cnt, total_depth
errors_tuples = count(states.ERROR)
runnings_tuple = count(states.RUNNING)
total_count = len(induced_states) total_count = len(induced_states)
def _blocked_message(): def _blocked_message():
return ( return (
'Blocked by tasks: %s' % 'Blocked by tasks: %s' %
[s[0] for s in induced_states if s[1] == states.WAITING] [s[0] for s in induced_states if s[1][0] == states.WAITING]
) )
def _failed_message(): def _failed_message():
return ( return (
'Failed by tasks: %s' % 'Failed by tasks: %s' %
[s[0] for s in induced_states if s[1] == states.ERROR] [s[0] for s in induced_states if s[1][0] == states.ERROR]
) )
# If "join" is configured as a number or 'one'. # If "join" is configured as a number or 'one'.
if isinstance(join_expr, int) or join_expr == 'one': if isinstance(join_expr, int) or join_expr == 'one':
cardinality = 1 if join_expr == 'one' else join_expr spec_cardinality = 1 if join_expr == 'one' else join_expr
if running_count >= cardinality: if runnings_tuple[0] >= spec_cardinality:
return states.RUNNING, None return states.RUNNING, None, 0
# E.g. 'join: 3' with inbound [ERROR, ERROR, RUNNING, WAITING] # E.g. 'join: 3' with inbound [ERROR, ERROR, RUNNING, WAITING]
# No chance to get 3 RUNNING states. # No chance to get 3 RUNNING states.
if error_count > (total_count - cardinality): if errors_tuples[0] > (total_count - spec_cardinality):
return states.ERROR, _failed_message() return states.ERROR, _failed_message(), 0
return states.WAITING, _blocked_message() # Calculate how many tasks need to finish to trigger this 'join'.
cardinality = spec_cardinality - runnings_tuple[0]
return states.WAITING, _blocked_message(), cardinality
if join_expr == 'all': if join_expr == 'all':
if total_count == running_count: if total_count == runnings_tuple[0]:
return states.RUNNING, None return states.RUNNING, None, 0
if error_count > 0: if errors_tuples[0] > 0:
return states.ERROR, _failed_message() return states.ERROR, _failed_message(), 0
return states.WAITING, _blocked_message() # Remaining cardinality is just a difference between all tasks and
# a number of those tasks that induce RUNNING state.
cardinality = total_count - runnings_tuple[1]
return states.WAITING, _blocked_message(), cardinality
raise RuntimeError('Unexpected join expression: %s' % join_expr) raise RuntimeError('Unexpected join expression: %s' % join_expr)
@ -337,51 +367,54 @@ class DirectWorkflowController(base.WorkflowController):
def _get_induced_join_state(self, inbound_task_spec, join_task_spec): def _get_induced_join_state(self, inbound_task_spec, join_task_spec):
join_task_name = join_task_spec.get_name() join_task_name = join_task_spec.get_name()
in_task_ex = self._find_task_execution_by_spec(inbound_task_spec) in_task_ex = self._find_task_execution_by_name(
inbound_task_spec.get_name()
)
if not in_task_ex: if not in_task_ex:
if self._possible_route(inbound_task_spec): possible, depth = self._possible_route(inbound_task_spec)
return states.WAITING
if possible:
return states.WAITING, depth
else: else:
return states.ERROR return states.ERROR, depth
if not states.is_completed(in_task_ex.state): if not states.is_completed(in_task_ex.state):
return states.WAITING return states.WAITING, 1
if join_task_name not in self._find_next_task_names(in_task_ex): if join_task_name not in self._find_next_task_names(in_task_ex):
return states.ERROR return states.ERROR, 1
return states.RUNNING return states.RUNNING, 1
def _find_task_execution_by_spec(self, task_spec): def _find_task_execution_by_name(self, t_name):
in_t_execs = wf_utils.find_task_executions_by_spec( # Note: in case of 'join' completion check it's better to initialize
self.wf_ex, # the entire task_executions collection to avoid too many DB queries.
task_spec t_execs = lookup_utils.find_task_executions_by_name(
self.wf_ex.id,
t_name
) )
# TODO(rakhmerov): Temporary hack. See the previous comment. # TODO(rakhmerov): Temporary hack. See the previous comment.
return in_t_execs[-1] if in_t_execs else None return t_execs[-1] if t_execs else None
def _possible_route(self, task_spec): def _possible_route(self, task_spec, depth=1):
# TODO(rakhmerov): In some cases this method will be expensive because
# it uses a multistep recursive search with DB queries.
# It will be optimized with Workflow Execution Graph moving forward.
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec) in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs: if not in_task_specs:
return True return True, depth
for t_s in in_task_specs: for t_s in in_task_specs:
t_ex = self._find_task_execution_by_spec(t_s) t_ex = self._find_task_execution_by_name(t_s.get_name())
if not t_ex: if not t_ex:
if self._possible_route(t_s): if self._possible_route(t_s, depth + 1):
return True return True, depth
else: else:
t_name = task_spec.get_name() t_name = task_spec.get_name()
if (not states.is_completed(t_ex.state) or if (not states.is_completed(t_ex.state) or
t_name in self._find_next_task_names(t_ex)): t_name in self._find_next_task_names(t_ex)):
return True return True, depth
return False return False, depth

View File

@ -0,0 +1,109 @@
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The intention of the module is providing various DB related lookup functions
for more convenient usage withing the workflow engine.
Some of the functions may provide caching capabilities.
WARNING: Oftentimes, persistent objects returned by the methods in this
module won't be attached to the current DB SQLAlchemy session because
they are returned from the cache and therefore they need to be used
carefully without trying to do any lazy loading etc.
These objects are also not suitable for re-attaching them to a session
in order to update their persistent DB state.
Mostly, they are useful for doing any kind of fast lookups with in order
to make some decision based on their state.
"""
import cachetools
import threading
from mistral.db.v2 import api as db_api
from mistral.workflow import states
_TASK_EXECUTIONS_CACHE_LOCK = threading.RLock()
_TASK_EXECUTIONS_CACHE = cachetools.LRUCache(maxsize=20000)
def find_task_executions_by_name(wf_ex_id, task_name):
"""Finds task executions by workflow execution id and task name.
:param wf_ex_id: Workflow execution id.
:param task_name: Task name.
:return: Task executions (possibly a cached value).
"""
cache_key = (wf_ex_id, task_name)
with _TASK_EXECUTIONS_CACHE_LOCK:
t_execs = _TASK_EXECUTIONS_CACHE.get(cache_key)
if t_execs:
return t_execs
t_execs = db_api.get_task_executions(
workflow_execution_id=wf_ex_id,
name=task_name
)
# We can cache only finished tasks because they won't change.
all_finished = (
t_execs and
all([states.is_completed(t_ex.state) for t_ex in t_execs])
)
if all_finished:
with _TASK_EXECUTIONS_CACHE_LOCK:
_TASK_EXECUTIONS_CACHE[cache_key] = t_execs
return t_execs
def find_task_executions_by_spec(wf_ex_id, task_spec):
return find_task_executions_by_name(wf_ex_id, task_spec.get_name())
def find_task_executions_by_specs(wf_ex_id, task_specs):
res = []
for t_s in task_specs:
res = res + find_task_executions_by_spec(wf_ex_id, t_s)
return res
def find_task_executions_with_state(wf_ex_id, state):
return db_api.get_task_executions(
workflow_execution_id=wf_ex_id,
state=state
)
def find_successful_task_executions(wf_ex_id):
return find_task_executions_with_state(wf_ex_id, states.SUCCESS)
def find_error_task_executions(wf_ex_id):
return find_task_executions_with_state(wf_ex_id, states.ERROR)
def find_cancelled_task_executions(wf_ex_id):
return find_task_executions_with_state(wf_ex_id, states.CANCELLED)
def clean_caches():
with _TASK_EXECUTIONS_CACHE_LOCK:
_TASK_EXECUTIONS_CACHE.clear()

View File

@ -19,8 +19,8 @@ from mistral import exceptions as exc
from mistral.workflow import base from mistral.workflow import base
from mistral.workflow import commands from mistral.workflow import commands
from mistral.workflow import data_flow from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states from mistral.workflow import states
from mistral.workflow import utils as wf_utils
class ReverseWorkflowController(base.WorkflowController): class ReverseWorkflowController(base.WorkflowController):
@ -92,13 +92,16 @@ class ReverseWorkflowController(base.WorkflowController):
return list( return list(
filter( filter(
lambda t_e: t_e.state == states.SUCCESS, lambda t_e: t_e.state == states.SUCCESS,
wf_utils.find_task_executions_by_specs(self.wf_ex, t_specs) lookup_utils.find_task_executions_by_specs(
self.wf_ex.id,
t_specs
)
) )
) )
def evaluate_workflow_final_context(self): def evaluate_workflow_final_context(self):
task_execs = wf_utils.find_task_executions_by_spec( task_execs = lookup_utils.find_task_executions_by_spec(
self.wf_ex, self.wf_ex.id,
self._get_target_task_specification() self._get_target_task_specification()
) )
@ -110,13 +113,15 @@ class ReverseWorkflowController(base.WorkflowController):
def get_logical_task_state(self, task_ex): def get_logical_task_state(self, task_ex):
# TODO(rakhmerov): Implement. # TODO(rakhmerov): Implement.
return task_ex.state, task_ex.state_info return task_ex.state, task_ex.state_info, 0
def is_error_handled_for(self, task_ex): def is_error_handled_for(self, task_ex):
return task_ex.state != states.ERROR return task_ex.state != states.ERROR
def all_errors_handled(self): def all_errors_handled(self):
return len(wf_utils.find_error_task_executions(self.wf_ex)) == 0 task_execs = lookup_utils.find_error_task_executions(self.wf_ex.id)
return len(task_execs) == 0
def _find_task_specs_with_satisfied_dependencies(self): def _find_task_specs_with_satisfied_dependencies(self):
"""Given a target task name finds tasks with no dependencies. """Given a target task name finds tasks with no dependencies.
@ -139,7 +144,8 @@ class ReverseWorkflowController(base.WorkflowController):
] ]
def _is_satisfied_task(self, task_spec): def _is_satisfied_task(self, task_spec):
if wf_utils.find_task_executions_by_spec(self.wf_ex, task_spec): if lookup_utils.find_task_executions_by_spec(
self.wf_ex.id, task_spec):
return False return False
if not self.wf_spec.get_task_requires(task_spec): if not self.wf_spec.get_task_requires(task_spec):

View File

@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mistral.db.v2 import api as db_api
from mistral.utils import serializers from mistral.utils import serializers
from mistral.workflow import states
class Result(object): class Result(object):
@ -72,46 +70,3 @@ class ResultSerializer(serializers.Serializer):
entity['error'], entity['error'],
entity.get('cancel', False) entity.get('cancel', False)
) )
def find_task_executions_by_name(wf_ex, task_name):
return db_api.get_task_executions(
workflow_execution_id=wf_ex.id,
name=task_name
)
def find_task_executions_by_spec(wf_ex, task_spec):
return find_task_executions_by_name(wf_ex, task_spec.get_name())
def find_task_executions_by_specs(wf_ex, task_specs):
res = []
for t_s in task_specs:
res = res + find_task_executions_by_spec(wf_ex, t_s)
return res
def find_task_executions_with_state(wf_ex, state):
return db_api.get_task_executions(
workflow_execution_id=wf_ex.id,
state=state
)
def find_running_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.RUNNING)
def find_successful_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.SUCCESS)
def find_error_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.ERROR)
def find_cancelled_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.CANCELLED)