Refactoring of graph.py and its usage in the scheduler

This patch addresses two problems:
1. Many forced updates on every scheduler tick lead to increased
CPU consumption in solar-worker.
2. To represent the solar dbmodel Task through the networkx
interface, many Task properties are duplicated and copied every
time a graph object is created.

Solving the 2nd problem lets us move the update logic into the
scheduler, which guarantees that on each scheduler tick we update
no more than the reported task plus the childs of that task.

Closes-Bug: 1560059
Change-Id: I3ee368ff03b7e24e783e4a367d51e9a84b28a4d9
Dmitry Shulyak 2016-03-18 15:54:21 +02:00
parent 8f1ca9708a
commit 16072bce2d
22 changed files with 358 additions and 373 deletions
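
The crux of the fix for the 2nd problem is the __hash__/__eq__ pair added to
solar_models.Task below: it lets Task records themselves serve as networkx
node objects, instead of copying their fields into per-node attribute dicts.
A minimal standalone sketch of the idea (the Node class here is illustrative,
not the solar model):

    import networkx as nx

    class Node(object):
        # stand-in for a DB-backed task; hashable and comparable by key
        def __init__(self, key, status='PENDING'):
            self.key = key
            self.status = status

        def __hash__(self):
            return hash(self.key)

        def __eq__(self, other):
            if isinstance(other, str):
                return self.key == other
            return self.key == other.key

    dg = nx.DiGraph()
    a, b = Node('plan~a'), Node('plan~b')
    dg.add_edge(a, b)      # the objects are the nodes; nothing is copied
    assert 'plan~a' in dg  # __eq__ against strings keeps name lookups working
    b.status = 'SUCCESS'   # mutating the object is mutating the graph state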

View File

@@ -32,7 +32,7 @@ class DBLayerProxy(wrapt.ObjectProxy):
     def __eq__(self, other):
         if not isinstance(other, DBLayerProxy):
             return self.__wrapped__ == other
-        return self.__wrapped__ == self.__wrapped__
+        return self.__wrapped__ == other.__wrapped__

     def __repr__(self):
         return "<P: %r>" % self.__wrapped__

View File

@@ -1057,7 +1057,7 @@ class Task(Model):
     name = Field(basestring)
     status = Field(basestring)
     target = Field(basestring, default=str)
-    task_type = Field(basestring)
+    type = Field(basestring)
     args = Field(list)
     errmsg = Field(basestring, default=str)
     timelimit = Field(int, default=int)
@@ -1070,11 +1070,23 @@ class Task(Model):
     parents = ParentField(default=list)
     childs = ChildField(default=list)
+    type_limit = Field(int, default=int)

     @classmethod
     def new(cls, data):
         key = '%s~%s' % (data['execution'], data['name'])
         return Task.from_dict(key, data)

+    def __hash__(self):
+        return hash(self.key)
+
+    def __eq__(self, other):
+        if isinstance(other, basestring):
+            return self.key == other
+        return self.key == other.key
+
+    def __repr__(self):
+        return 'Task(execution={} name={})'.format(self.execution, self.name)
+

 """
 system log

View File

@@ -106,7 +106,7 @@ class React(Event):
             location_id = Resource.get(self.child).inputs[
                 'location_id']
         except (DBLayerNotFound, DBLayerSolarException):
-            location_id = None
+            location_id = ''
         changes_graph.add_node(
             self.child_node, status='PENDING',
             target=location_id,
@@ -128,7 +128,7 @@ class StateChange(Event):
         try:
             location_id = Resource.get(self.parent).inputs['location_id']
         except (DBLayerNotFound, DBLayerSolarException):
-            location_id = None
+            location_id = ''
         changes_graph.add_node(
             self.parent_node, status='PENDING',
             target=location_id,

View File

@@ -18,11 +18,19 @@ from solar.orchestration.traversal import states
 from solar.orchestration.traversal import VISITED


+def make_full_name(graph, name):
+    return '{}~{}'.format(graph.graph['uid'], name)
+
+
+def get_tasks_from_names(graph, names):
+    return [t for t in graph.nodes() if t.name in names]
+
+
 def get_dfs_postorder_subgraph(dg, nodes):
     result = set()
     for node in nodes:
         result.update(nx.dfs_postorder_nodes(dg, source=node))
-    return dg.subgraph(result)
+    return {n for n in dg if n in result}


 def end_at(dg, nodes):
@@ -31,12 +39,12 @@ def end_at(dg, nodes):
     dg - directed graph
     nodes - iterable with node names
     """
-    return set(get_dfs_postorder_subgraph(dg.reverse(), nodes).nodes())
+    return get_dfs_postorder_subgraph(dg.reverse(copy=False), nodes)


 def start_from(dg, start_nodes):
     """Ensures that all paths starting from specific *nodes* will be visited"""
-    visited = {n for n in dg if dg.node[n].get('status') in VISITED}
+    visited = {t for t in dg if t.status in VISITED}

     # sorting nodes in topological order will guarantee that all predecessors
     # of current node were already walked, when current going to be considered
@@ -58,10 +66,10 @@ def validate(dg, start_nodes, end_nodes, err_msgs):
     error_msgs = err_msgs[:]
     not_in_the_graph_msg = 'Node {} is not present in graph {}'
     for n in start_nodes:
-        if n not in dg:
+        if make_full_name(dg, n) not in dg:
             error_msgs.append(not_in_the_graph_msg.format(n, dg.graph['uid']))
     for n in end_nodes:
-        if n not in dg:
+        if make_full_name(dg, n) not in dg:
             if start_nodes:
                 error_msgs.append(
                     'No path from {} to {}'.format(start_nodes, n))
@@ -82,25 +90,22 @@ def filter(dg, start=None, end=None, tasks=(), skip_with=states.SKIPPED.name):
     error_msgs = []
     subpath = dg.nodes()
     if tasks:
-        subpath = tasks
+        subpath = get_tasks_from_names(dg, tasks)
     else:
         subgraph = dg
         if start:
             error_msgs = validate(subgraph, start, [], error_msgs)
             if error_msgs:
                 return error_msgs
-            subpath = start_from(subgraph, start)
+            subpath = start_from(subgraph, get_tasks_from_names(dg, start))
             subgraph = dg.subgraph(subpath)
         if end:
             error_msgs = validate(subgraph, start, end, error_msgs)
             if error_msgs:
                 return error_msgs
-            subpath = end_at(subgraph, end)
+            subpath = end_at(subgraph, get_tasks_from_names(dg, end))

-    for node in dg:
-        if node not in subpath:
-            dg.node[node]['status'] = skip_with
+    for task in dg.nodes():
+        if task not in subpath:
+            task.status = skip_with
     return None

View File

@@ -28,43 +28,16 @@ from solar import utils


 def save_graph(graph):
-    # maybe it is possible to store part of information in AsyncResult backend
-    uid = graph.graph['uid']
-
+    # TODO(dshulyak) remove duplication of parameters
+    # in solar_models.Task and this object
     for n in nx.topological_sort(graph):
-        t = Task.new(
-            {'name': n,
-             'execution': uid,
-             'status': graph.node[n].get('status', ''),
-             'target': graph.node[n].get('target', '') or '',
-             'task_type': graph.node[n].get('type', ''),
-             'args': graph.node[n].get('args', []),
-             'errmsg': graph.node[n].get('errmsg', '') or '',
-             'timelimit': graph.node[n].get('timelimit', 0),
-             'retry': graph.node[n].get('retry', 0),
-             'timeout': graph.node[n].get('timeout', 0),
-             'start_time': 0.0,
-             'end_time': 0.0})
+        values = {'name': n, 'execution': graph.graph['uid']}
+        values.update(graph.node[n])
+        t = Task.new(values)
         graph.node[n]['task'] = t
         for pred in graph.predecessors(n):
             pred_task = graph.node[pred]['task']
             t.parents.add(pred_task)
             pred_task.save()
-        t.save()
-
-
-def update_graph(graph, force=False):
-    for n in graph:
-        task = graph.node[n]['task']
-        task.status = graph.node[n]['status']
-        task.errmsg = graph.node[n]['errmsg'] or ''
-        task.retry = graph.node[n].get('retry', 0)
-        task.timeout = graph.node[n].get('timeout', 0)
-        task.start_time = graph.node[n].get('start_time', 0.0)
-        task.end_time = graph.node[n].get('end_time', 0.0)
-        task.save(force=force)
+        t.save_lazy()


 def set_states(uid, tasks):
@@ -72,31 +45,22 @@ def set_states(uid, tasks):
     for t in tasks:
         if t not in plan.node:
             raise Exception("No task %s in plan %s", t, uid)
-        plan.node[t]['task'].status = states.NOOP.name
-        plan.node[t]['task'].save_lazy()
+        plan.node[t].status = states.NOOP.name
+        plan.node[t].save_lazy()
+    ModelMeta.save_all_lazy()
+
+
+def get_task_by_name(dg, task_name):
+    return next(t for t in dg.nodes() if t.name == task_name)


 def get_graph(uid):
-    dg = nx.MultiDiGraph()
-    dg.graph['uid'] = uid
-    dg.graph['name'] = uid.split(':')[0]
-    tasks = map(Task.get, Task.execution.filter(uid))
-    for t in tasks:
-        dg.add_node(
-            t.name, status=t.status,
-            type=t.task_type, args=t.args,
-            target=t.target or None,
-            errmsg=t.errmsg or None,
-            task=t,
-            timelimit=t.timelimit,
-            retry=t.retry,
-            timeout=t.timeout,
-            start_time=t.start_time,
-            end_time=t.end_time)
-        for u in t.parents.all_names():
-            dg.add_edge(u, t.name)
-    return dg
+    mdg = nx.MultiDiGraph()
+    mdg.graph['uid'] = uid
+    mdg.graph['name'] = uid.split(':')[0]
+    mdg.add_nodes_from(Task.multi_get(Task.execution.filter(uid)))
+    mdg.add_edges_from([(parent, task) for task in mdg.nodes()
+                        for parent in task.parents.all()])
+    return mdg


 def longest_path_time(graph):
@@ -106,8 +70,8 @@ def longest_path_time(graph):
     start = float('inf')
     end = float('-inf')
     for n in graph:
-        node_start = graph.node[n]['start_time']
-        node_end = graph.node[n]['end_time']
+        node_start = n.start_time
+        node_end = n.end_time
         if int(node_start) == 0 or int(node_end) == 0:
             continue
@@ -122,8 +86,8 @@ def longest_path_time(graph):
 def total_delta(graph):
     delta = 0.0
     for n in graph:
-        node_start = graph.node[n]['start_time']
-        node_end = graph.node[n]['end_time']
+        node_start = n.start_time
+        node_end = n.end_time
         if int(node_start) == 0 or int(node_end) == 0:
             continue
         delta += node_end - node_start
@@ -153,11 +117,13 @@ def parse_plan(plan_path):
     return dg


-def create_plan_from_graph(dg, save=True):
+def create_plan_from_graph(dg):
     dg.graph['uid'] = "{0}:{1}".format(dg.graph['name'], str(uuid.uuid4()))
-    if save:
-        save_graph(dg)
-    return dg
+    # FIXME change save_graph api to return new graph with Task objects
+    # included
+    save_graph(dg)
+    ModelMeta.save_all_lazy()
+    return get_graph(dg.graph['uid'])


 def show(uid):
@@ -166,21 +132,19 @@ def show(uid):
     tasks = []
     result['uid'] = dg.graph['uid']
     result['name'] = dg.graph['name']
-    for n in nx.topological_sort(dg):
-        data = dg.node[n]
+    for task in nx.topological_sort(dg):
         tasks.append(
-            {'uid': n,
-             'parameters': data,
-             'before': dg.successors(n),
-             'after': dg.predecessors(n)
+            {'uid': task.name,
+             'parameters': task.to_dict(),
+             'before': dg.successors(task),
+             'after': dg.predecessors(task)
              })
     result['tasks'] = tasks
     return utils.yaml_dump(result)


-def create_plan(plan_path, save=True):
-    dg = parse_plan(plan_path)
-    return create_plan_from_graph(dg, save=save)
+def create_plan(plan_path):
+    return create_plan_from_graph(parse_plan(plan_path))


 def reset_by_uid(uid, state_list=None):
@@ -190,11 +154,11 @@ def reset_by_uid(uid, state_list=None):

 def reset(graph, state_list=None):
     for n in graph:
-        if state_list is None or graph.node[n]['status'] in state_list:
-            graph.node[n]['status'] = states.PENDING.name
-            graph.node[n]['start_time'] = 0.0
-            graph.node[n]['end_time'] = 0.0
-    update_graph(graph)
+        if state_list is None or n.status in state_list:
+            n.status = states.PENDING.name
+            n.start_time = 0.0
+            n.end_time = 0.0
+            n.save_lazy()


 def reset_filtered(uid):
@@ -212,14 +176,14 @@ def report_progress_graph(dg):
               'total_delta': total_delta(dg),
               'tasks': tasks}

+    # FIXME just return topologically sorted list of tasks
     for task in nx.topological_sort(dg):
-        data = dg.node[task]
         tasks.append([
-            task,
-            data['status'],
-            data['errmsg'],
-            data.get('start_time'),
-            data.get('end_time')])
+            task.name,
+            task.status,
+            task.errmsg,
+            task.start_time,
+            task.end_time])
     return report

@@ -237,7 +201,7 @@ def wait_finish(uid, timeout):
         dg = get_graph(uid)
         summary = Counter()
         summary.update({s.name: 0 for s in states})
-        summary.update([s['status'] for s in dg.node.values()])
+        summary.update([task.status for task in dg.nodes()])
         yield summary
         if summary[states.PENDING.name] + summary[states.INPROGRESS.name] == 0:
             return

View File

@@ -52,27 +52,25 @@ def type_based_rule(dg, inprogress, item):
     condition should be specified like:
         type_limit: 2
     """
-    _type = dg.node[item].get('resource_type')
-    if 'type_limit' not in dg.node[item]:
+    if not item.type_limit:
         return True
-    if not _type:
+    if not item.resource_type:
         return True

     type_count = 0
-    for n in inprogress:
-        if dg.node[n].get('resource_type') == _type:
+    for task in inprogress:
+        if task.resource_type == item.resource_type:
             type_count += 1
-    return dg.node[item]['type_limit'] > type_count
+    return item.type_limit > type_count


 def target_based_rule(dg, inprogress, item, limit=1):
-    target = dg.node[item].get('target')
-    if not target:
+    if not item.target:
         return True

     target_count = 0
     for n in inprogress:
-        if dg.node[n].get('target') == target:
+        if n.target == item.target:
             target_count += 1
     return limit > target_count

View File

@@ -39,20 +39,12 @@ VISITED = (states.SUCCESS.name, states.NOOP.name)
 BLOCKED = (states.INPROGRESS.name, states.SKIPPED.name, states.ERROR.name)


-def traverse(dg):
-
-    visited = set()
-    for node in dg:
-        data = dg.node[node]
-        if data['status'] in VISITED:
-            visited.add(node)
-
-    rst = []
-    for node in dg:
-        data = dg.node[node]
-        if node in visited or data['status'] in BLOCKED:
-            continue
-
-        if set(dg.predecessors(node)) <= visited:
-            rst.append(node)
-    return rst
+def find_visitable_tasks(dg):
+    """Filter to find tasks that satisfy next conditions:
+    - task is not in VISITED or BLOCKED state
+    - all predecessors of task can be considered visited
+    """
+    visited = set([t for t in dg if t.status in VISITED])
+    return [t for t in dg
+            if (not (t in visited or t.status in BLOCKED)
+                and set(dg.predecessors(t)) <= visited)]
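
The rewritten traversal reads status straight off the task objects. A quick
illustration of the contract, using plain stand-in objects rather than the
solar Task model: a task is visitable once it is neither visited nor blocked
and every predecessor is in a VISITED state.

    import networkx as nx

    VISITED = ('SUCCESS', 'NOOP')
    BLOCKED = ('INPROGRESS', 'SKIPPED', 'ERROR')

    def find_visitable_tasks(dg):
        visited = {t for t in dg if t.status in VISITED}
        return [t for t in dg
                if (not (t in visited or t.status in BLOCKED)
                    and set(dg.predecessors(t)) <= visited)]

    class T(object):  # minimal stand-in for a Task node
        def __init__(self, name, status='PENDING'):
            self.name, self.status = name, status

    t1, t2, t3 = T('t1', 'SUCCESS'), T('t2'), T('t3')
    dg = nx.DiGraph([(t1, t3), (t2, t3)])
    assert find_visitable_tasks(dg) == [t2]  # t3 waits until t2 is visited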

View File

@@ -22,6 +22,10 @@ def write_graph(plan):
     :param plan: networkx Graph object
     """
+    names_only = nx.MultiDiGraph()
+    names_only.add_nodes_from([n.name for n in plan.nodes()])
+    names_only.add_edges_from([(n.name, s.name) for n in plan.nodes()
+                               for s in plan.successors(n)])
     colors = {
         'PENDING': 'cyan',
         'ERROR': 'red',
@@ -30,11 +34,11 @@ def write_graph(plan):
         'SKIPPED': 'blue',
         'NOOP': 'black'}

-    for n in plan:
-        color = colors[plan.node[n]['status']]
-        plan.node[n]['color'] = color
+    for n in plan.nodes():
+        names_only.node[n.name]['color'] = colors[n.status]

-    nx.nx_pydot.write_dot(plan, '{name}.dot'.format(name=plan.graph['name']))
+    nx.nx_pydot.write_dot(names_only,
+                          '{name}.dot'.format(name=plan.graph['name']))
     subprocess.call(
         'tred {name}.dot | dot -Tsvg -o {name}.svg'.format(
             name=plan.graph['name']),

View File

@@ -18,10 +18,11 @@ import time
 from solar.core.log import log
 from solar.dblayer.locking import Lock
 from solar.dblayer.locking import Waiter
+from solar.dblayer.model import ModelMeta
 from solar.orchestration import graph
 from solar.orchestration import limits
+from solar.orchestration.traversal import find_visitable_tasks
 from solar.orchestration.traversal import states
-from solar.orchestration.traversal import traverse
 from solar.orchestration.traversal import VISITED
 from solar.orchestration.workers import base
 from solar.utils import get_current_ident
@@ -34,13 +35,10 @@ class Scheduler(base.Worker):
         super(Scheduler, self).__init__()

     def _next(self, plan):
-        tasks = traverse(plan)
-        filtered_tasks = list(limits.get_default_chain(
+        return list(limits.get_default_chain(
             plan,
-            [t for t in plan
-             if plan.node[t]['status'] == states.INPROGRESS.name],
-            tasks))
-        return filtered_tasks
+            [t for t in plan if t.status == states.INPROGRESS.name],
+            find_visitable_tasks(plan)))

     def next(self, ctxt, plan_uid):
         with Lock(
@@ -51,15 +49,16 @@ class Scheduler(base.Worker):
         ):
             log.debug('Received *next* event for %s', plan_uid)
             plan = graph.get_graph(plan_uid)
+            # FIXME get_graph should raise DBNotFound if graph is not
+            # created
             if len(plan) == 0:
                 raise ValueError('Plan {} is empty'.format(plan_uid))
-            rst = self._next(plan)
-            for task_name in rst:
-                self._do_scheduling(plan, task_name)
-            graph.update_graph(plan)
-            log.debug('Scheduled tasks %r', rst)
-            # process tasks with tasks client
-            return rst
+            tasks_to_schedule = self._next(plan)
+            for task in tasks_to_schedule:
+                self._do_scheduling(task)
+            log.debug('Scheduled tasks %r', tasks_to_schedule)
+            ModelMeta.save_all_lazy()
+            return tasks_to_schedule

     def soft_stop(self, ctxt, plan_uid):
         with Lock(
@@ -68,63 +67,56 @@ class Scheduler(base.Worker):
             retries=20,
             waiter=Waiter(1)
         ):
-            plan = graph.get_graph(plan_uid)
-            for n in plan:
-                if plan.node[n]['status'] in (
+            for task in graph.get_graph(plan_uid):
+                if task.status in (
                         states.PENDING.name, states.ERROR_RETRY.name):
-                    plan.node[n]['status'] = states.SKIPPED.name
-            graph.update_graph(plan)
+                    task.status = states.SKIPPED.name
+                    task.save_lazy()

-    def _do_update(self, plan, task_name, status, errmsg=''):
+    def _do_update(self, task, status, errmsg=''):
         """For single update correct state and other relevant data."""
-        old_status = plan.node[task_name]['status']
-        if old_status in VISITED:
+        if task.status in VISITED:
             log.debug(
                 'Task %s already in visited status %s'
                 ', skipping update to %s',
-                task_name, old_status, status)
+                task.name, task.status, status)
             return

-        retries_count = plan.node[task_name]['retry']
-        if status == states.ERROR.name and retries_count > 0:
-            retries_count -= 1
+        if status == states.ERROR.name and task.retry > 0:
+            task.retry -= 1
             status = states.ERROR_RETRY.name
             log.debug('Retry task %s in plan %s, retries left %s',
-                      task_name, plan.graph['uid'], retries_count)
+                      task.name, task.execution, task.retry)
         else:
-            plan.node[task_name]['end_time'] = time.time()
-        plan.node[task_name]['status'] = status
-        plan.node[task_name]['errmsg'] = errmsg
-        plan.node[task_name]['retry'] = retries_count
+            task.end_time = time.time()
+        task.status = status
+        task.errmsg = errmsg
+        task.save_lazy()

-    def _do_scheduling(self, plan, task_name):
-        task_id = '{}:{}'.format(plan.graph['uid'], task_name)
-        task_type = plan.node[task_name]['type']
-        plan.node[task_name]['status'] = states.INPROGRESS.name
-        plan.node[task_name]['start_time'] = time.time()
-        plan.node[task_name]['end_time'] = 0.0
-        timelimit = plan.node[task_name].get('timelimit', 0)
-        timeout = plan.node[task_name].get('timeout', 0)
+    def _do_scheduling(self, task):
+        task.status = states.INPROGRESS.name
+        task.start_time = time.time()
         ctxt = {
-            'task_id': task_id,
-            'task_name': task_name,
-            'plan_uid': plan.graph['uid'],
-            'timelimit': timelimit,
-            'timeout': timeout}
+            'task_id': task.key,
+            'task_name': task.name,
+            'plan_uid': task.execution,
+            'timelimit': task.timelimit,
+            'timeout': task.timeout}
         log.debug(
             'Timelimit for task %s - %s, timeout - %s',
-            task_id, timelimit, timeout)
+            task, task.timelimit, task.timeout)
+        task.save_lazy()
         self._tasks(
-            task_type, ctxt,
-            *plan.node[task_name]['args'])
-        if timeout:
-            self._configure_timeout(ctxt, timeout)
+            task.type, ctxt,
+            *task.args)
+        if task.timeout:
+            self._configure_timeout(ctxt, task.timeout)

     def update_next(self, ctxt, status, errmsg):
         log.debug(
             'Received update for TASK %s - %s %s',
             ctxt['task_id'], status, errmsg)
-        plan_uid, task_name = ctxt['task_id'].rsplit(':', 1)
+        plan_uid, task_name = ctxt['task_id'].rsplit('~', 1)
         with Lock(
             plan_uid,
             str(get_current_ident()),
@@ -132,13 +124,14 @@ class Scheduler(base.Worker):
             waiter=Waiter(1)
         ):
             plan = graph.get_graph(plan_uid)
-            self._do_update(plan, task_name, status, errmsg=errmsg)
-            rst = self._next(plan)
-            for task_name in rst:
-                self._do_scheduling(plan, task_name)
-            graph.update_graph(plan)
-            log.debug('Scheduled tasks %r', rst)
-            return rst
+            task = next(t for t in plan.nodes() if t.name == task_name)
+            self._do_update(task, status, errmsg=errmsg)
+            tasks_to_schedule = self._next(plan)
+            for task in tasks_to_schedule:
+                self._do_scheduling(task)
+            log.debug('Scheduled tasks %r', tasks_to_schedule)
+            ModelMeta.save_all_lazy()
+            return tasks_to_schedule

     def _configure_timeout(self, ctxt, timeout):
         if not hasattr(self._executor, 'register_timeout'):
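
This is where the 1st problem goes away: instead of update_graph(plan)
rewriting every node on each tick, only the tasks mutated during the tick
call save_lazy(), and a single ModelMeta.save_all_lazy() flushes them. A toy
sketch of that write-batching pattern (LazyMeta/Record below are illustrative
stand-ins, not the solar dblayer API):

    class LazyMeta(object):
        _dirty = set()  # objects registered for the next flush

        @classmethod
        def save_all_lazy(cls):
            saved = sorted(obj.save() for obj in cls._dirty)
            cls._dirty.clear()
            return saved

    class Record(object):
        def __init__(self, name):
            self.name = name

        def save(self):
            return self.name  # stand-in for a real DB write

        def save_lazy(self):
            LazyMeta._dirty.add(self)

    tasks = [Record('a'), Record('b'), Record('c')]
    tasks[0].save_lazy()  # only the task touched this tick is marked dirty
    assert LazyMeta.save_all_lazy() == ['a']  # one write instead of three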

View File

@@ -21,10 +21,10 @@ from solar.system_log.operations import set_error
 class SystemLog(base.Worker):

     def commit(self, ctxt, *args, **kwargs):
-        return move_to_commited(ctxt['task_id'].rsplit(':', 1)[-1])
+        return move_to_commited(ctxt['task_id'].rsplit('~', 1)[-1])

     def error(self, ctxt, *args, **kwargs):
-        return set_error(ctxt['task_id'].rsplit(':', 1)[-1])
+        return set_error(ctxt['task_id'].rsplit('~', 1)[-1])


 def tasks_subscribe(tasks, clients):

View File

@@ -18,6 +18,7 @@ import tempfile
 import time
 import unittest

+import networkx as nx
 import yaml

 from solar.core.resource import composer as cr
@@ -57,3 +58,9 @@ class BaseResourceTest(unittest.TestCase):
     def create_resource(self, name, src, args=None):
         args = args or {}
         return cr.create(name, src, inputs=args)[0]
+
+
+def compare_task_to_names(tasks, names):
+    if isinstance(tasks, nx.DiGraph):
+        tasks = tasks.nodes()
+    assert {t.name for t in tasks} == names

View File

@@ -31,7 +31,7 @@ def test_simple_fixture(simple_plan, scheduler, tasks):
     expected = [['echo_stuff'], ['just_fail'], []]

     def register(ctxt, rst, *args, **kwargs):
-        scheduling_results.append(rst)
+        scheduling_results.append([t.name for t in rst])
     worker.for_all.on_success(register)

     def _result_waiter():
@@ -47,7 +47,7 @@ def test_sequential_fixture(sequential_plan, scheduler, tasks):
     expected = {('s1',), ('s2',), ('s3',), ()}

     def register(ctxt, rst, *args, **kwargs):
-        scheduling_results.add(tuple(rst))
+        scheduling_results.add(tuple(t.name for t in rst))
     worker.for_all.on_success(register)

     def _result_waiter():

View File

@@ -24,8 +24,10 @@ from solar.orchestration.traversal import states

 @pytest.fixture
 def simple_plan_retries(simple_plan):
-    simple_plan.node['just_fail']['retry'] = 1
-    graph.update_graph(simple_plan, force=True)
+    fail_task = next(t for t in simple_plan.nodes()
+                     if t.name == 'just_fail')
+    fail_task.retry = 1
+    fail_task.save()
     return simple_plan

View File

@@ -39,14 +39,17 @@ def test_timelimit_plan(timelimit_plan, scheduler, tasks):
     waiter = gevent.spawn(wait_function, 3)
     waiter.join(timeout=3)
     finished_plan = graph.get_graph(timelimit_plan.graph['uid'])
-    assert 'ExecutionTimeout' in finished_plan.node['t1']['errmsg']
-    assert finished_plan.node['t2']['status'] == states.PENDING.name
+    t1 = graph.get_task_by_name(finished_plan, 't1')
+    t2 = graph.get_task_by_name(finished_plan, 't2')
+    assert 'ExecutionTimeout' in t1.errmsg
+    assert t2.status == states.PENDING.name


 @pytest.fixture
 def timeout_plan(simple_plan):
-    simple_plan.node['echo_stuff']['timeout'] = 1
-    graph.update_graph(simple_plan, force=True)
+    echo_task = graph.get_task_by_name(simple_plan, 'echo_stuff')
+    echo_task.timeout = 1
+    echo_task.save()
     return simple_plan

@@ -65,5 +68,5 @@ def test_timeout_plan(timeout_plan, scheduler):
     waiter = gevent.spawn(wait_function, 2)
     waiter.get(block=True, timeout=2)
     timeout_plan = graph.get_graph(timeout_plan.graph['uid'])
-    assert (timeout_plan.node['echo_stuff']['status']
-            == states.ERROR.name)
+    echo_task = graph.get_task_by_name(timeout_plan, 'echo_stuff')
+    assert echo_task.status == states.ERROR.name

View File

@@ -8,7 +8,6 @@ tasks:
   args:
   - node3
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: node3.run
@@ -20,7 +19,6 @@ tasks:
   args:
   - hosts_file3
  - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: hosts_file3.run
@@ -32,7 +30,6 @@ tasks:
   args:
   - node2
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: node2.run
@@ -44,7 +41,6 @@ tasks:
   args:
   - node1
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: node1.run
@@ -56,7 +52,6 @@ tasks:
   args:
   - hosts_file2
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: hosts_file2.run
@@ -68,7 +63,6 @@ tasks:
   args:
   - hosts_file1
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: hosts_file1.run
@@ -82,7 +76,6 @@ tasks:
   args:
   - riak_service1
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: riak_service1.run
@@ -96,7 +89,6 @@ tasks:
   args:
   - riak_service3
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: riak_service3.run
@@ -108,7 +100,6 @@ tasks:
   args:
   - riak_service3
   - join
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: riak_service3.join
@@ -122,7 +113,6 @@ tasks:
   args:
   - riak_service2
   - run
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: riak_service2.run
@@ -134,7 +124,6 @@ tasks:
   args:
   - riak_service2
   - join
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: riak_service2.join
@@ -146,9 +135,7 @@ tasks:
   args:
   - riak_service1
   - commit
-  errmsg: null
   status: PENDING
   type: solar_resource
   uid: riak_service1.commit
 uid: system_log:565581a1-80a0-425d-bb5c-d1cc4f48ffda

View File

@@ -26,8 +26,12 @@ def test_longest_path_time_returns_0_for_empty_graph():

 def test_reset_resets_times():
     g = nx.MultiDiGraph()
-    g.add_node('task1', task=mock.Mock(), status='status', errmsg='',
-               start_time=1, end_time=4)
+    task = mock.Mock(
+        name='task1',
+        status='status',
+        errmsg='',
+        start_time=1, end_time=4)
+    g.add_node(task)
     graph.reset(g)
-    assert int(g.node['task1']['start_time']) == 0
+    for n in g.nodes():
+        assert n.start_time == 0

View File

@@ -26,7 +26,6 @@ def test_scheduler_next_fails_with_empty_plan():

 def test_soft_stop(simple_plan):
-    # graph.save_graph(simple_plan)
     uid = simple_plan.graph['uid']

     scheduler = Scheduler(None)
@@ -34,4 +33,4 @@ def test_soft_stop(simple_plan):
     plan = graph.get_graph(uid)

     for n in plan:
-        assert plan.node[n]['status'] == states.SKIPPED.name
+        assert n.status == states.SKIPPED.name

View File

@@ -12,6 +12,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+from mock import Mock
 import networkx as nx
 from pytest import fixture

@@ -21,32 +22,39 @@ from solar.orchestration.traversal import states

 def test_simple_plan_plan_created_and_loaded(simple_plan):
     plan = graph.get_plan(simple_plan.graph['uid'])
-    assert set(plan.nodes()) == {'just_fail', 'echo_stuff'}
+    expected_names = {n.name for n in plan.nodes()}
+    assert expected_names == {'just_fail', 'echo_stuff'}


 def test_reset_all_states(simple_plan):
     for n in simple_plan:
-        simple_plan.node[n]['status'] == states.ERROR.name
+        n.status == states.ERROR.name
     graph.reset(simple_plan)

     for n in simple_plan:
-        assert simple_plan.node[n]['status'] == states.PENDING.name
+        assert n.status == states.PENDING.name


 def test_reset_only_provided(simple_plan):
-    simple_plan.node['just_fail']['status'] = states.ERROR.name
-    simple_plan.node['echo_stuff']['status'] = states.SUCCESS.name
+    for n in simple_plan.nodes():
+        if n.name == 'just_fail':
+            n.status = states.ERROR.name
+        elif n.name == 'echo_stuff':
+            n.status = states.SUCCESS.name

     graph.reset(simple_plan, [states.ERROR.name])

-    assert simple_plan.node['just_fail']['status'] == states.PENDING.name
-    assert simple_plan.node['echo_stuff']['status'] == states.SUCCESS.name
+    for n in simple_plan.nodes():
+        if n.name == 'just_fail':
+            assert n.status == states.PENDING.name
+        elif n.name == 'echo_stuff':
+            assert n.status == states.SUCCESS.name


 def test_wait_finish(simple_plan):
     for n in simple_plan:
-        simple_plan.node[n]['status'] = states.SUCCESS.name
-    graph.update_graph(simple_plan)
+        n.status = states.SUCCESS.name
+        n.save()
     assert next(graph.wait_finish(simple_plan.graph['uid'], 10)) == {
         'SKIPPED': 0,
         'SUCCESS': 2,
@@ -59,8 +67,10 @@ def test_wait_finish(simple_plan):

 def test_several_updates(simple_plan):
-    simple_plan.node['just_fail']['status'] = states.ERROR.name
-    graph.update_graph(simple_plan)
+    just_fail_task = next(t for t in simple_plan.nodes()
+                          if t.name == 'just_fail')
+    just_fail_task.status = states.ERROR.name
+    just_fail_task.save()

     assert next(graph.wait_finish(simple_plan.graph['uid'], 10)) == {
         'SKIPPED': 0,
@@ -72,8 +82,10 @@ def test_several_updates(simple_plan):
         'ERROR_RETRY': 0,
     }

-    simple_plan.node['echo_stuff']['status'] = states.ERROR.name
-    graph.update_graph(simple_plan)
+    echo_task = next(t for t in simple_plan.nodes()
+                     if t.name == 'echo_stuff')
+    echo_task.status = states.ERROR.name
+    echo_task.save()

     assert next(graph.wait_finish(simple_plan.graph['uid'], 10)) == {
         'SKIPPED': 0,
@@ -89,18 +101,19 @@ def test_several_updates(simple_plan):
 @fixture
 def times():
     rst = nx.DiGraph()
-    rst.add_node('t1', start_time=1.0, end_time=12.0,
-                 status='', errmsg='')
-    rst.add_node('t2', start_time=1.0, end_time=3.0,
-                 status='', errmsg='')
-    rst.add_node('t3', start_time=3.0, end_time=7.0,
-                 status='', errmsg='')
-    rst.add_node('t4', start_time=7.0, end_time=13.0,
-                 status='', errmsg='')
-    rst.add_node('t5', start_time=12.0, end_time=14.0,
-                 status='', errmsg='')
-    rst.add_path(['t1', 't5'])
-    rst.add_path(['t2', 't3', 't4'])
+    t1 = Mock(name='t1', start_time=1.0, end_time=12.0,
+              status='', errmsg='')
+    t2 = Mock(name='t2', start_time=1.0, end_time=3.0,
+              status='', errmsg='')
+    t3 = Mock(name='t3', start_time=3.0, end_time=7.0,
+              status='', errmsg='')
+    t4 = Mock(name='t4', start_time=7.0, end_time=13.0,
+              status='', errmsg='')
+    t5 = Mock(name='t5', start_time=12.0, end_time=14.0,
+              status='', errmsg='')
+    rst.add_nodes_from([t1, t2, t3, t4, t5])
+    rst.add_path([t1, t5])
+    rst.add_path([t2, t3, t4])
     return rst

View File

@@ -12,15 +12,13 @@
 # License for the specific language governing permissions and limitations
 # under the License.

-import os
-
 import networkx as nx
 from pytest import fixture
 from pytest import mark

 from solar.orchestration import filters
-from solar.orchestration import graph
 from solar.orchestration.traversal import states
+from solar.test.base import compare_task_to_names


 @fixture
@@ -39,80 +37,55 @@ def dg_ex1():
     (['n4', 'n5'], {'n1', 'n2', 'n3', 'n4', 'n5'}),
 ])
 def test_end_at(dg_ex1, end_nodes, visited):
-    assert set(filters.end_at(dg_ex1, end_nodes)) == visited
-
-
-@mark.parametrize("start_nodes,visited", [
-    (['n3'], {'n3'}), (['n1'], {'n1', 'n2', 'n4'}),
-    (['n1', 'n3'], {'n1', 'n2', 'n3', 'n4', 'n5'})
-])
-def test_start_from(dg_ex1, start_nodes, visited):
-    assert set(filters.start_from(dg_ex1, start_nodes)) == visited
-
-
-@fixture
-def dg_ex2():
-    dg = nx.DiGraph()
-    dg.add_nodes_from(['n1', 'n2', 'n3', 'n4', 'n5'])
-    dg.add_edges_from([('n1', 'n3'), ('n2', 'n3'), ('n3', 'n4'), ('n3', 'n5')])
-    return dg
-
-
-@fixture
-def riak_plan():
-    riak_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), 'orch_fixtures',
-        'riak.yaml')
-    return graph.create_plan(riak_path, save=False)
+    assert filters.end_at(dg_ex1, end_nodes) == visited


 def test_riak_start_node1(riak_plan):
-    assert filters.start_from(riak_plan, ['node1.run']) == {
-        'node1.run', 'hosts_file1.run', 'riak_service1.run'
-    }
+    start_tasks = filters.get_tasks_from_names(riak_plan, ['node1.run'])
+    compare_task_to_names(
+        filters.start_from(riak_plan, start_tasks),
+        {'node1.run', 'hosts_file1.run', 'riak_service1.run'})


 def test_riak_end_hosts_file1(riak_plan):
-    assert filters.end_at(riak_plan, ['hosts_file1.run']) == {
-        'node1.run', 'hosts_file1.run'
-    }
+    compare_task_to_names(filters.end_at(
+        riak_plan,
+        filters.get_tasks_from_names(riak_plan, ['hosts_file1.run'])),
+        {'node1.run', 'hosts_file1.run'})


 def test_start_at_two_nodes(riak_plan):
-    assert filters.start_from(riak_plan, ['node1.run', 'node2.run']) == \
-        {'hosts_file1.run', 'riak_service2.run', 'riak_service2.join',
-         'hosts_file2.run', 'node2.run', 'riak_service1.run', 'node1.run'}
+    compare_task_to_names(filters.start_from(
+        riak_plan,
+        filters.get_tasks_from_names(riak_plan, ['node1.run', 'node2.run'])),
+        {'hosts_file1.run', 'riak_service2.run', 'riak_service2.join',
+         'hosts_file2.run', 'node2.run', 'riak_service1.run', 'node1.run'})


 def test_initial_from_node1_traverse(riak_plan):
     filters.filter(riak_plan, start=['node1.run'])
-    pending = {n
-               for n in riak_plan
-               if riak_plan.node[
-                   n]['status'] == states.PENDING.name}
-    assert pending == {'hosts_file1.run', 'riak_service1.run', 'node1.run'}
+    compare_task_to_names(
+        {t for t in riak_plan if t.status == states.PENDING.name},
+        {'hosts_file1.run', 'riak_service1.run', 'node1.run'})


 def test_second_from_node2_with_node1_walked(riak_plan):
     success = {'hosts_file1.run', 'riak_service1.run', 'node1.run'}
-    for n in success:
-        riak_plan.node[n]['status'] = states.SUCCESS.name
+    for task in riak_plan.nodes():
+        if task.name in success:
+            task.status = states.SUCCESS.name
     filters.filter(riak_plan, start=['node2.run'])
-    pending = {n
-               for n in riak_plan
-               if riak_plan.node[
-                   n]['status'] == states.PENDING.name}
-    assert pending == {'hosts_file2.run', 'riak_service2.run', 'node2.run',
-                       'riak_service2.join'}
+    compare_task_to_names(
+        {t for t in riak_plan if t.status == states.PENDING.name},
+        {'hosts_file2.run', 'riak_service2.run', 'node2.run',
+         'riak_service2.join'})


 def test_end_joins(riak_plan):
     filters.filter(riak_plan,
                    start=['node1.run', 'node2.run', 'node3.run'],
                    end=['riak_service2.join', 'riak_service3.join'])
-    skipped = {n
-               for n in riak_plan
-               if riak_plan.node[
-                   n]['status'] == states.SKIPPED.name}
-    assert skipped == {'riak_service1.commit'}
+    compare_task_to_names(
+        {n for n in riak_plan if n.status == states.SKIPPED.name},
+        {'riak_service1.commit'})

View File

@@ -14,6 +14,7 @@
 import os

+from mock import Mock
 import networkx as nx
 from pytest import fixture

@@ -22,56 +23,58 @@ from solar.orchestration import limits

 @fixture
-def dg():
-    ex = nx.DiGraph()
-    ex.add_node('t1',
-                status='PENDING',
-                target='1',
-                resource_type='node',
-                type_limit=2)
-    ex.add_node('t2',
-                status='PENDING',
-                target='1',
-                resource_type='node',
-                type_limit=2)
-    ex.add_node('t3',
-                status='PENDING',
-                target='1',
-                resource_type='node',
-                type_limit=2)
-    return ex
-
-
-def test_target_rule(dg):
-    assert limits.target_based_rule(dg, [], 't1') is True
-    assert limits.target_based_rule(dg, ['t1'], 't2') is False
-
-
-def test_type_limit_rule(dg):
-    assert limits.type_based_rule(dg, ['t1'], 't2') is True
-    assert limits.type_based_rule(dg, ['t1', 't2'], 't3') is False
-
-
-def test_items_rule(dg):
-    assert limits.items_rule(dg, ['1'] * 99, '2')
-    assert limits.items_rule(dg, ['1'] * 99, '2', limit=10) is False
+def t1():
+    return Mock(name='t1',
+                status='PENDING',
+                target='1',
+                resource_type='node',
+                type_limit=2)


 @fixture
-def target_dg():
-    ex = nx.DiGraph()
-    ex.add_node('t1', status='PENDING', target='1')
-    ex.add_node('t2', status='PENDING', target='1')
-
-    return ex
+def t2():
+    return Mock(name='t2',
+                status='PENDING',
+                target='1',
+                resource_type='node',
+                type_limit=2)


-def test_filtering_chain(target_dg):
+@fixture
+def t3():
+    return Mock(name='t3',
+                status='PENDING',
+                target='1',
+                resource_type='node',
+                type_limit=2)

-    chain = limits.get_default_chain(target_dg, [], ['t1', 't2'])
-    assert list(chain) == ['t1']

+@fixture
+def dg(t1, t2, t3):
+    example = nx.DiGraph()
+    example.add_nodes_from((t1, t2, t3))
+    return example
+
+
+def test_target_rule(dg, t1, t2):
+    assert limits.target_based_rule(dg, [], t1)
+    assert limits.target_based_rule(dg, [t1], t2) is False
+
+
+def test_type_limit_rule(dg, t1, t2, t3):
+    assert limits.type_based_rule(dg, [t1], t2)
+    assert limits.type_based_rule(dg, [t1, t2], t3) is False
+
+
+def test_items_rule(dg):
+    assert limits.items_rule(dg, [t1] * 99, t2)
+    assert limits.items_rule(dg, [t1] * 99, t2, limit=10) is False
+
+
+def test_filtering_chain(dg, t1, t2):
+    chain = limits.get_default_chain(dg, [], [t1, t2])
+    assert list(chain) == [t1]


 @fixture
@@ -79,7 +82,7 @@ def seq_plan():
     seq_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), 'orch_fixtures',
         'sequential.yaml')
-    return graph.create_plan(seq_path, save=False)
+    return graph.create_plan(seq_path)


 def test_limits_sequential(seq_plan):

View File

@@ -22,6 +22,7 @@ from solar.dblayer.solar_models import CommitedResource
 from solar.dblayer.solar_models import Resource as DBResource
 from solar.system_log import change
 from solar.system_log import operations
+from solar.test.base import compare_task_to_names


 def create_resource(name, tags=None):
@@ -260,12 +261,12 @@ def test_stage_and_process_partially():

     a_graph = change.send_to_orchestration(a)
     a_expected = set(['%s.restart' % n for n in range_a])
-    assert set(a_graph.nodes()) == a_expected
+    compare_task_to_names(set(a_graph.nodes()), a_expected)

     b_graph = change.send_to_orchestration(b)
     b_expected = set(['%s.restart' % n for n in range_b])
-    assert set(b_graph.nodes()) == b_expected
+    compare_task_to_names(set(b_graph.nodes()), b_expected)

     both_graph = change.send_to_orchestration(both)
-    assert set(both_graph.nodes()) == a_expected | b_expected
+    compare_task_to_names(set(both_graph.nodes()), a_expected | b_expected)


 def test_childs_added_on_stage():

View File

@@ -12,60 +12,85 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+from mock import Mock
 import networkx as nx
 from pytest import fixture

-from solar.orchestration.traversal import traverse
+from solar.orchestration.traversal import find_visitable_tasks


 @fixture
-def tasks():
-    return [
-        {'id': 't1', 'status': 'PENDING'},
-        {'id': 't2', 'status': 'PENDING'},
-        {'id': 't3', 'status': 'PENDING'},
-        {'id': 't4', 'status': 'PENDING'},
-        {'id': 't5', 'status': 'PENDING'}]
+def task():
+    number = {'count': 0}
+
+    def make_task():
+        number['count'] += 1
+        return Mock(name='t%s' % number, status='PENDING')
+    return make_task


 @fixture
-def dg(tasks):
+def t1(task):
+    return task()
+
+
+@fixture
+def t2(task):
+    return task()
+
+
+@fixture
+def t3(task):
+    return task()
+
+
+@fixture
+def t4(task):
+    return task()
+
+
+@fixture
+def t5(task):
+    return task()
+
+
+@fixture
+def dg(t1, t2, t3, t4, t5):
     ex = nx.DiGraph()
-    for t in tasks:
-        ex.add_node(t['id'], status=t['status'])
+    ex.add_nodes_from((t1, t2, t3, t4, t5))
     return ex


-def test_parallel(dg):
-    dg.add_path(['t1', 't3', 't4', 't5'])
-    dg.add_path(['t2', 't3'])
+def test_parallel(dg, t1, t2, t3, t4, t5):
+    dg.add_path([t1, t3, t4, t5])
+    dg.add_path([t2, t3])

-    assert set(traverse(dg)) == {'t1', 't2'}
+    assert set(find_visitable_tasks(dg)) == {t1, t2}


-def test_walked_only_when_all_predecessors_visited(dg):
-    dg.add_path(['t1', 't3', 't4', 't5'])
-    dg.add_path(['t2', 't3'])
+def test_walked_only_when_all_predecessors_visited(dg, t1, t2, t3, t4, t5):
+    dg.add_path([t1, t3, t4, t5])
+    dg.add_path([t2, t3])

-    dg.node['t1']['status'] = 'SUCCESS'
-    dg.node['t2']['status'] = 'INPROGRESS'
+    t1.status = 'SUCCESS'
+    t2.status = 'INPROGRESS'

-    assert set(traverse(dg)) == set()
+    assert set(find_visitable_tasks(dg)) == set()

-    dg.node['t2']['status'] = 'SUCCESS'
+    t2.status = 'SUCCESS'

-    assert set(traverse(dg)) == {'t3'}
+    assert set(find_visitable_tasks(dg)) == {t3}


-def test_nothing_will_be_walked_if_parent_is_skipped(dg):
-    dg.add_path(['t1', 't2', 't3', 't4', 't5'])
-    dg.node['t1']['status'] = 'SKIPPED'
+def test_nothing_will_be_walked_if_parent_is_skipped(dg, t1, t2, t3, t4, t5):
+    dg.add_path([t1, t2, t3, t4, t5])
+    t1.status = 'SKIPPED'

-    assert set(traverse(dg)) == set()
+    assert set(find_visitable_tasks(dg)) == set()


-def test_node_will_be_walked_if_parent_is_noop(dg):
-    dg.add_path(['t1', 't2', 't3', 't4', 't5'])
-    dg.node['t1']['status'] = 'NOOP'
+def test_node_will_be_walked_if_parent_is_noop(dg, t1, t2, t3, t4, t5):
+    dg.add_path([t1, t2, t3, t4, t5])
+    t1.status = 'NOOP'

-    assert set(traverse(dg)) == {'t2'}
+    assert set(find_visitable_tasks(dg)) == {t2}