Optimize the logic that checks if a 'join' task is allowed to start

* Moved DB-related lookup functions from workflow/utils to
  a separate module, lookup_utils
* Optimized the data access pattern used when calculating 'join'
  states induced by upstream tasks, and applied caching for some
  lookup operations
* Added caching of inbound and outbound task specs in the direct
  workflow specification; for large workflows, calculating them
  can be expensive
* Added an adaptive delay between calls that refresh 'join'
  task state, based on the number of unfulfilled preconditions
  returned by the workflow controller (see the sketch below)
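
As a rough sketch of that last bullet (illustrative only, not the exact engine code; the 0.01 factor is the assumed average task duration used in the task_handler change below):

    def refresh_delay(cardinality, avg_task_sec=0.01):
        # 'cardinality' is the number of unfulfilled preconditions
        # reported by the workflow controller. Wait roughly as long
        # as the remaining tasks are expected to take, with a floor
        # of 1 second.
        return max(1, int(cardinality * avg_task_sec))

    # 100 unmet preconditions -> 1 sec, 500 -> 5 sec, 5000 -> 50 sec.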

Change-Id: I383fa52f2f05877df7522048020cc7ff280324a2
Renat Akhmerov 2016-09-13 12:37:20 +03:00
parent a0f6c7ae3f
commit 9d06a61fe4
10 changed files with 293 additions and 122 deletions

mistral/engine/task_handler.py

@@ -73,7 +73,7 @@ def run_task(wf_cmd):
return
if task.is_waiting() and (task.is_created() or task.is_state_changed()):
_schedule_refresh_task_state(task.task_ex)
_schedule_refresh_task_state(task.task_ex, 1)
@profiler.trace('task-handler-on-action-complete')
@@ -251,15 +251,27 @@ def _refresh_task_state(task_ex_id):
wf_spec
)
state, state_info = wf_ctrl.get_logical_task_state(task_ex)
state, state_info, cardinality = wf_ctrl.get_logical_task_state(
task_ex
)
if state == states.RUNNING:
continue_task(task_ex)
elif state == states.ERROR:
fail_task(task_ex, state_info)
elif state == states.WAITING:
# TODO(rakhmerov): Algorithm for increasing rescheduling delay.
_schedule_refresh_task_state(task_ex, 1)
# Let's assume that a task takes 0.01 sec on average to complete
# and, based on this assumption, calculate the time of the next check.
# The estimation is very rough, of course, but this delay will keep
# decreasing as task preconditions complete, which gives a decent
# asymptotic approximation.
# For example, if a 'join' task has 100 inbound incomplete tasks
# then the next 'refresh_task_state' call will happen in 1 second.
# For 500 tasks it will be 5 seconds. The larger the workflow is,
# the more beneficial this mechanism becomes.
delay = int(cardinality * 0.01)
_schedule_refresh_task_state(task_ex, max(1, delay))
else:
# Must never get here.
raise RuntimeError(

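In effect, the WAITING branch above turns the state refresh into a self-rescheduling loop whose period shrinks as preconditions complete. A simplified synchronous sketch of the same idea ('count_unmet_preconditions' and 'start_task' are hypothetical stand-ins for the controller and engine calls):

    import time

    def refresh_until_ready(count_unmet_preconditions, start_task):
        while True:
            unmet = count_unmet_preconditions()

            if unmet == 0:
                start_task()
                return

            # Back off proportionally to the remaining work. In the real
            # engine this is a delayed scheduler call, not a sleep.
            time.sleep(max(1, int(unmet * 0.01)))
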
mistral/engine/workflows.py

@@ -34,6 +34,7 @@ from mistral.workbook import parser as spec_parser
from mistral.workflow import base as wf_base
from mistral.workflow import commands
from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states
from mistral.workflow import utils as wf_utils
@@ -158,6 +159,11 @@ class Workflow(object):
assert self.wf_ex
# Since some lookup utils functions may use a cache for completed
# tasks, we need to clean the caches to make sure that stale objects
# can't be retrieved.
lookup_utils.clean_caches()
wf_service.update_workflow_execution_env(self.wf_ex, env)
self.set_state(states.RUNNING, recursive=True)
@@ -429,7 +435,7 @@ def _build_fail_info_message(wf_ctrl, wf_ex):
failed_tasks = sorted(
filter(
lambda t: not wf_ctrl.is_error_handled_for(t),
wf_utils.find_error_task_executions(wf_ex)
lookup_utils.find_error_task_executions(wf_ex.id)
),
key=lambda t: t.name
)
@@ -468,7 +474,7 @@ def _build_fail_info_message(wf_ctrl, wf_ex):
def _build_cancel_info_message(wf_ctrl, wf_ex):
# Try to find where cancel is exactly.
cancelled_tasks = sorted(
wf_utils.find_cancelled_task_executions(wf_ex),
lookup_utils.find_cancelled_task_executions(wf_ex.id),
key=lambda t: t.name
)

mistral/tests/unit/base.py

@@ -36,6 +36,7 @@ from mistral.tests.unit import config as test_config
from mistral.utils import inspect_utils as i_utils
from mistral import version
from mistral.workbook import parser as spec_parser
from mistral.workflow import lookup_utils
RESOURCES_PATH = 'tests/resources/'
LOG = logging.getLogger(__name__)
@@ -244,6 +245,8 @@ class DbTestCase(BaseTest):
action_manager.sync_db()
def _clean_db(self):
lookup_utils.clean_caches()
contexts = [
get_context(default=False),
get_context(default=True)

mistral/utils/expression_utils.py

@@ -13,13 +13,12 @@
# limitations under the License.
from oslo_serialization import jsonutils
from stevedore import extension
import yaql
from mistral.db.v2 import api as db_api
from mistral import utils
from mistral.workflow import utils as wf_utils
from oslo_serialization import jsonutils
from stevedore import extension
ROOT_CONTEXT = None
@@ -87,8 +86,6 @@ def task_(context, task_name):
# Importing data_flow in order to break cycle dependency between modules.
from mistral.workflow import data_flow
wf_ex = db_api.get_workflow_execution(context['__execution']['id'])
# This section may not exist in the context if it's not calculated
# in task scope.
cur_task = context['__task_execution']
@@ -96,7 +93,10 @@
if cur_task and cur_task['name'] == task_name:
task_ex = db_api.get_task_execution(cur_task['id'])
else:
task_execs = wf_utils.find_task_executions_by_name(wf_ex, task_name)
task_execs = db_api.get_task_executions(
workflow_execution_id=context['__execution']['id'],
name=task_name
)
# TODO(rakhmerov): Account for multiple executions (i.e. in case of
# cycles).

mistral/workbook/v2/workflows.py

@@ -15,6 +15,7 @@
from oslo_utils import uuidutils
import six
import threading
from mistral import exceptions as exc
from mistral import utils
@@ -150,6 +151,18 @@ class DirectWorkflowSpec(WorkflowSpec):
}
}
def __init__(self, data):
super(DirectWorkflowSpec, self).__init__(data)
# Init simple dictionary-based caches for inbound and
# outbound task specifications. We don't need any special
# cache implementation here because these structures can't
# grow indefinitely.
self.inbound_tasks_cache_lock = threading.RLock()
self.inbound_tasks_cache = {}
self.outbound_tasks_cache_lock = threading.RLock()
self.outbound_tasks_cache = {}
def validate_semantics(self):
super(DirectWorkflowSpec, self).validate_semantics()
@@ -211,17 +224,43 @@ class DirectWorkflowSpec(WorkflowSpec):
]
def find_inbound_task_specs(self, task_spec):
return [
task_name = task_spec.get_name()
with self.inbound_tasks_cache_lock:
specs = self.inbound_tasks_cache.get(task_name)
if specs is not None:
return specs
specs = [
t_s for t_s in self.get_tasks()
if self.transition_exists(t_s.get_name(), task_spec.get_name())
if self.transition_exists(t_s.get_name(), task_name)
]
with self.inbound_tasks_cache_lock:
self.inbound_tasks_cache[task_name] = specs
return specs
def find_outbound_task_specs(self, task_spec):
return [
task_name = task_spec.get_name()
with self.outbound_tasks_cache_lock:
specs = self.outbound_tasks_cache.get(task_name)
if specs is not None:
return specs
specs = [
t_s for t_s in self.get_tasks()
if self.transition_exists(task_spec.get_name(), t_s.get_name())
if self.transition_exists(task_name, t_s.get_name())
]
with self.outbound_tasks_cache_lock:
self.outbound_tasks_cache[task_name] = specs
return specs
def has_inbound_transitions(self, task_spec):
return len(self.find_inbound_task_specs(task_spec)) > 0

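The inbound/outbound lookups above follow a check-then-compute-then-store pattern guarded by an RLock. The same pattern in isolation (a generic sketch, not the Mistral class itself):

    import threading

    class MemoizedLookup(object):
        def __init__(self, compute_fn):
            self._compute_fn = compute_fn
            self._lock = threading.RLock()
            self._cache = {}

        def get(self, key):
            with self._lock:
                value = self._cache.get(key)

            if value is not None:
                return value

            # The lock is released during computation, so two threads
            # may compute the same value concurrently. The result is
            # deterministic, so the duplicate write is harmless.
            value = self._compute_fn(key)

            with self._lock:
                self._cache[key] = value

            return value

Since task specs never change for a given workflow definition, no eviction is needed, which is why a plain dict suffices here.
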
mistral/workflow/base.py

@@ -26,13 +26,14 @@ from mistral import utils as u
from mistral.workbook import parser as spec_parser
from mistral.workflow import commands
from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states
from mistral.workflow import utils as wf_utils
LOG = logging.getLogger(__name__)
@profiler.trace('wf-controller-get-controller')
def get_controller(wf_ex, wf_spec=None):
"""Gets a workflow controller instance by given workflow execution object.
@@ -130,8 +131,13 @@ class WorkflowController(object):
"""Determines a logical state of the given task.
:param task_ex: Task execution.
:return: Tuple (state, state_info) which the given task should have
according to workflow rules and current states of other tasks.
:return: Tuple (state, state_info, cardinality) where 'state' and
'state_info' are the corresponding values which the given
task should have according to workflow rules and current
states of other tasks. 'cardinality' gives an estimate of
the number of preconditions that are not yet met when the
state is WAITING. This number can be used to decide how
frequently we can refresh the state of this task.
"""
raise NotImplementedError
@@ -159,7 +165,9 @@ class WorkflowController(object):
:return: True if there is one or more tasks in cancelled state.
"""
return len(wf_utils.find_cancelled_task_executions(self.wf_ex)) > 0
t_execs = lookup_utils.find_cancelled_task_executions(self.wf_ex.id)
return len(t_execs) > 0
@abc.abstractmethod
def evaluate_workflow_final_context(self):
@@ -214,8 +222,8 @@ class WorkflowController(object):
return []
# Add all tasks in IDLE state.
idle_tasks = wf_utils.find_task_executions_with_state(
self.wf_ex,
idle_tasks = lookup_utils.find_task_executions_with_state(
self.wf_ex.id,
states.IDLE
)

mistral/workflow/direct_workflow.py

@@ -13,6 +13,7 @@
# limitations under the License.
from oslo_log import log as logging
from osprofiler import profiler
from mistral import exceptions as exc
from mistral import expressions as expr
@@ -20,8 +21,8 @@ from mistral import utils
from mistral.workflow import base
from mistral.workflow import commands
from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states
from mistral.workflow import utils as wf_utils
LOG = logging.getLogger(__name__)
@@ -46,8 +47,8 @@ class DirectWorkflowController(base.WorkflowController):
return list(
filter(
lambda t_e: self._is_upstream_task_execution(task_spec, t_e),
wf_utils.find_task_executions_by_specs(
self.wf_ex,
lookup_utils.find_task_executions_by_specs(
self.wf_ex.id,
self.wf_spec.find_inbound_task_specs(task_spec)
)
)
@@ -60,7 +61,7 @@ class DirectWorkflowController(base.WorkflowController):
if not t_spec.get_join():
return not t_ex_candidate.processed
induced_state = self._get_induced_join_state(
induced_state, _ = self._get_induced_join_state(
self.wf_spec.get_tasks()[t_ex_candidate.name],
t_spec
)
@@ -173,7 +174,7 @@ class DirectWorkflowController(base.WorkflowController):
# A simple 'non-join' task does not have any preconditions
# based on state of other tasks so its logical state always
# equals to its real state.
return task_ex.state, task_ex.state_info
return task_ex.state, task_ex.state_info, 0
return self._get_join_logical_state(task_spec)
@@ -181,8 +182,7 @@ class DirectWorkflowController(base.WorkflowController):
return bool(self.wf_spec.get_on_error_clause(task_ex.name))
def all_errors_handled(self):
for t_ex in wf_utils.find_error_task_executions(self.wf_ex):
for t_ex in lookup_utils.find_error_task_executions(self.wf_ex.id):
tasks_on_error = self._find_next_tasks_for_clause(
self.wf_spec.get_on_error_clause(t_ex.name),
data_flow.evaluate_task_outbound_context(t_ex)
@@ -197,7 +197,7 @@ class DirectWorkflowController(base.WorkflowController):
return list(
filter(
lambda t_ex: not self._has_outbound_tasks(t_ex),
wf_utils.find_successful_task_executions(self.wf_ex)
lookup_utils.find_successful_task_executions(self.wf_ex.id)
)
)
@@ -270,64 +270,94 @@ class DirectWorkflowController(base.WorkflowController):
if not condition or expr.evaluate(condition, ctx)
]
@profiler.trace('direct-wf-controller-get-join-logical-state')
def _get_join_logical_state(self, task_spec):
"""Evaluates logical state of 'join' task.
:param task_spec: 'join' task specification.
:return: Tuple (state, state_info, spec_cardinality) where 'state' and
'state_info' describe the logical state of the given 'join'
task and 'spec_cardinality' gives the remaining number of
unfulfilled preconditions. If the logical state is not WAITING
then 'spec_cardinality' is always 0.
"""
# TODO(rakhmerov): We need to use task_ex instead of task_spec
# in order to cover a use case when there's more than one instance
# of the same 'join' task in a workflow.
# TODO(rakhmerov): In some cases this method will be expensive because
# it uses a multistep recursive search. We need to optimize it moving
# forward (e.g. with Workflow Execution Graph).
join_expr = task_spec.get_join()
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs:
return states.RUNNING
return states.RUNNING, None, 0
# List of tuples (task_name, state).
# List of tuples (task_name, (state, depth)).
induced_states = [
(t_s.get_name(), self._get_induced_join_state(t_s, task_spec))
for t_s in in_task_specs
]
def count(state):
return len(list(filter(lambda s: s[1] == state, induced_states)))
cnt = 0
total_depth = 0
error_count = count(states.ERROR)
running_count = count(states.RUNNING)
for s in induced_states:
if s[1][0] == state:
cnt += 1
total_depth += s[1][1]
return cnt, total_depth
errors_tuples = count(states.ERROR)
runnings_tuple = count(states.RUNNING)
total_count = len(induced_states)
def _blocked_message():
return (
'Blocked by tasks: %s' %
[s[0] for s in induced_states if s[1] == states.WAITING]
[s[0] for s in induced_states if s[1][0] == states.WAITING]
)
def _failed_message():
return (
'Failed by tasks: %s' %
[s[0] for s in induced_states if s[1] == states.ERROR]
[s[0] for s in induced_states if s[1][0] == states.ERROR]
)
# If "join" is configured as a number or 'one'.
if isinstance(join_expr, int) or join_expr == 'one':
cardinality = 1 if join_expr == 'one' else join_expr
spec_cardinality = 1 if join_expr == 'one' else join_expr
if running_count >= cardinality:
return states.RUNNING, None
if runnings_tuple[0] >= spec_cardinality:
return states.RUNNING, None, 0
# E.g. 'join: 3' with inbound [ERROR, ERROR, RUNNING, WAITING]
# No chance to get 3 RUNNING states.
if error_count > (total_count - cardinality):
return states.ERROR, _failed_message()
if errors_tuples[0] > (total_count - spec_cardinality):
return states.ERROR, _failed_message(), 0
return states.WAITING, _blocked_message()
# Calculate how many tasks need to finish to trigger this 'join'.
cardinality = spec_cardinality - runnings_tuple[0]
return states.WAITING, _blocked_message(), cardinality
if join_expr == 'all':
if total_count == running_count:
return states.RUNNING, None
if total_count == runnings_tuple[0]:
return states.RUNNING, None, 0
if error_count > 0:
return states.ERROR, _failed_message()
if errors_tuples[0] > 0:
return states.ERROR, _failed_message(), 0
return states.WAITING, _blocked_message()
# Remaining cardinality is just the difference between the total
# number of tasks and the number of those that induce RUNNING state.
cardinality = total_count - runnings_tuple[1]
return states.WAITING, _blocked_message(), cardinality
raise RuntimeError('Unexpected join expression: %s' % join_expr)
@@ -337,51 +367,54 @@
def _get_induced_join_state(self, inbound_task_spec, join_task_spec):
join_task_name = join_task_spec.get_name()
in_task_ex = self._find_task_execution_by_spec(inbound_task_spec)
in_task_ex = self._find_task_execution_by_name(
inbound_task_spec.get_name()
)
if not in_task_ex:
if self._possible_route(inbound_task_spec):
return states.WAITING
possible, depth = self._possible_route(inbound_task_spec)
if possible:
return states.WAITING, depth
else:
return states.ERROR
return states.ERROR, depth
if not states.is_completed(in_task_ex.state):
return states.WAITING
return states.WAITING, 1
if join_task_name not in self._find_next_task_names(in_task_ex):
return states.ERROR
return states.ERROR, 1
return states.RUNNING
return states.RUNNING, 1
def _find_task_execution_by_spec(self, task_spec):
in_t_execs = wf_utils.find_task_executions_by_spec(
self.wf_ex,
task_spec
def _find_task_execution_by_name(self, t_name):
# Note: in the case of a 'join' completion check it's better to
# initialize the entire task_executions collection to avoid too
# many DB queries.
t_execs = lookup_utils.find_task_executions_by_name(
self.wf_ex.id,
t_name
)
# TODO(rakhmerov): Temporary hack. See the previous comment.
return in_t_execs[-1] if in_t_execs else None
return t_execs[-1] if t_execs else None
def _possible_route(self, task_spec):
# TODO(rakhmerov): In some cases this method will be expensive because
# it uses a multistep recursive search with DB queries.
# It will be optimized with Workflow Execution Graph moving forward.
def _possible_route(self, task_spec, depth=1):
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs:
return True
return True, depth
for t_s in in_task_specs:
t_ex = self._find_task_execution_by_spec(t_s)
t_ex = self._find_task_execution_by_name(t_s.get_name())
if not t_ex:
if self._possible_route(t_s):
return True
if self._possible_route(t_s, depth + 1):
return True, depth
else:
t_name = task_spec.get_name()
if (not states.is_completed(t_ex.state) or
t_name in self._find_next_task_names(t_ex)):
return True
return True, depth
return False
return False, depth

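To make the counting above concrete, here is a toy example for a 'join: all' task with three inbound tasks (values chosen for illustration; entries are (task_name, (state, depth)) tuples as produced by _get_induced_join_state):

    RUNNING, WAITING = 'RUNNING', 'WAITING'

    induced_states = [
        ('t1', (RUNNING, 1)),  # completed; triggers the join
        ('t2', (WAITING, 1)),  # started but not completed yet
        ('t3', (WAITING, 2)),  # not created; reachable via a 2-step route
    ]

    def count(state):
        cnt, total_depth = 0, 0
        for s in induced_states:
            if s[1][0] == state:
                cnt += 1
                total_depth += s[1][1]
        return cnt, total_depth

    runnings_tuple = count(RUNNING)  # (1, 1)

    # 'join: all' is not satisfied: the logical state is WAITING with
    # cardinality 3 - 1 = 2, so the next refresh is scheduled
    # max(1, int(2 * 0.01)) = 1 second later.
    cardinality = len(induced_states) - runnings_tuple[1]
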
mistral/workflow/lookup_utils.py

@@ -0,0 +1,109 @@
# Copyright 2015 - Mirantis, Inc.
# Copyright 2015 - StackStorm, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides various DB-related lookup functions for more
convenient use within the workflow engine.
Some of the functions may provide caching capabilities.
WARNING: Oftentimes, persistent objects returned by the methods in this
module won't be attached to the current SQLAlchemy DB session because
they are returned from the cache, and therefore they need to be used
carefully without trying to do any lazy loading etc.
These objects are also not suitable for re-attaching to a session
in order to update their persistent DB state.
Mostly, they are useful for doing fast lookups in order to make
decisions based on their state.
"""
import cachetools
import threading
from mistral.db.v2 import api as db_api
from mistral.workflow import states
_TASK_EXECUTIONS_CACHE_LOCK = threading.RLock()
_TASK_EXECUTIONS_CACHE = cachetools.LRUCache(maxsize=20000)
def find_task_executions_by_name(wf_ex_id, task_name):
"""Finds task executions by workflow execution id and task name.
:param wf_ex_id: Workflow execution id.
:param task_name: Task name.
:return: Task executions (possibly a cached value).
"""
cache_key = (wf_ex_id, task_name)
with _TASK_EXECUTIONS_CACHE_LOCK:
t_execs = _TASK_EXECUTIONS_CACHE.get(cache_key)
if t_execs:
return t_execs
t_execs = db_api.get_task_executions(
workflow_execution_id=wf_ex_id,
name=task_name
)
# We can cache only finished tasks because they won't change.
all_finished = (
t_execs and
all([states.is_completed(t_ex.state) for t_ex in t_execs])
)
if all_finished:
with _TASK_EXECUTIONS_CACHE_LOCK:
_TASK_EXECUTIONS_CACHE[cache_key] = t_execs
return t_execs
def find_task_executions_by_spec(wf_ex_id, task_spec):
return find_task_executions_by_name(wf_ex_id, task_spec.get_name())
def find_task_executions_by_specs(wf_ex_id, task_specs):
res = []
for t_s in task_specs:
res = res + find_task_executions_by_spec(wf_ex_id, t_s)
return res
def find_task_executions_with_state(wf_ex_id, state):
return db_api.get_task_executions(
workflow_execution_id=wf_ex_id,
state=state
)
def find_successful_task_executions(wf_ex_id):
return find_task_executions_with_state(wf_ex_id, states.SUCCESS)
def find_error_task_executions(wf_ex_id):
return find_task_executions_with_state(wf_ex_id, states.ERROR)
def find_cancelled_task_executions(wf_ex_id):
return find_task_executions_with_state(wf_ex_id, states.CANCELLED)
def clean_caches():
with _TASK_EXECUTIONS_CACHE_LOCK:
_TASK_EXECUTIONS_CACHE.clear()

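Given the policy above, repeated lookups for finished tasks are served from the LRU cache, while lookups for still-running tasks always hit the DB. A usage sketch ('wf_ex_id' stands in for a real workflow execution id):

    from mistral.workflow import lookup_utils

    # The first call queries the DB; if every returned execution is in
    # a completed state, the result is memoized under (wf_ex_id, name).
    t_execs = lookup_utils.find_task_executions_by_name(wf_ex_id, 'task1')

    # A repeated call may return cached objects detached from the current
    # SQLAlchemy session: read their fields, but don't lazy-load
    # relationships or re-attach them to a session.
    t_execs = lookup_utils.find_task_executions_by_name(wf_ex_id, 'task1')

    # On events that can invalidate results (e.g. resuming a workflow or
    # cleaning the test DB), the caches are dropped wholesale.
    lookup_utils.clean_caches()
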
mistral/workflow/reverse_workflow.py

@@ -19,8 +19,8 @@ from mistral import exceptions as exc
from mistral.workflow import base
from mistral.workflow import commands
from mistral.workflow import data_flow
from mistral.workflow import lookup_utils
from mistral.workflow import states
from mistral.workflow import utils as wf_utils
class ReverseWorkflowController(base.WorkflowController):
@@ -92,13 +92,16 @@ class ReverseWorkflowController(base.WorkflowController):
return list(
filter(
lambda t_e: t_e.state == states.SUCCESS,
wf_utils.find_task_executions_by_specs(self.wf_ex, t_specs)
lookup_utils.find_task_executions_by_specs(
self.wf_ex.id,
t_specs
)
)
)
def evaluate_workflow_final_context(self):
task_execs = wf_utils.find_task_executions_by_spec(
self.wf_ex,
task_execs = lookup_utils.find_task_executions_by_spec(
self.wf_ex.id,
self._get_target_task_specification()
)
@@ -110,13 +113,15 @@ class ReverseWorkflowController(base.WorkflowController):
def get_logical_task_state(self, task_ex):
# TODO(rakhmerov): Implement.
return task_ex.state, task_ex.state_info
return task_ex.state, task_ex.state_info, 0
def is_error_handled_for(self, task_ex):
return task_ex.state != states.ERROR
def all_errors_handled(self):
return len(wf_utils.find_error_task_executions(self.wf_ex)) == 0
task_execs = lookup_utils.find_error_task_executions(self.wf_ex.id)
return len(task_execs) == 0
def _find_task_specs_with_satisfied_dependencies(self):
"""Given a target task name finds tasks with no dependencies.
@@ -139,7 +144,8 @@
]
def _is_satisfied_task(self, task_spec):
if wf_utils.find_task_executions_by_spec(self.wf_ex, task_spec):
if lookup_utils.find_task_executions_by_spec(
self.wf_ex.id, task_spec):
return False
if not self.wf_spec.get_task_requires(task_spec):

mistral/workflow/utils.py

@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.db.v2 import api as db_api
from mistral.utils import serializers
from mistral.workflow import states
class Result(object):
@@ -72,46 +70,3 @@ class ResultSerializer(serializers.Serializer):
entity['error'],
entity.get('cancel', False)
)
def find_task_executions_by_name(wf_ex, task_name):
return db_api.get_task_executions(
workflow_execution_id=wf_ex.id,
name=task_name
)
def find_task_executions_by_spec(wf_ex, task_spec):
return find_task_executions_by_name(wf_ex, task_spec.get_name())
def find_task_executions_by_specs(wf_ex, task_specs):
res = []
for t_s in task_specs:
res = res + find_task_executions_by_spec(wf_ex, t_s)
return res
def find_task_executions_with_state(wf_ex, state):
return db_api.get_task_executions(
workflow_execution_id=wf_ex.id,
state=state
)
def find_running_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.RUNNING)
def find_successful_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.SUCCESS)
def find_error_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.ERROR)
def find_cancelled_task_executions(wf_ex):
return find_task_executions_with_state(wf_ex, states.CANCELLED)