Merge "Fixing tasks API endpoint"
This commit is contained in:
commit
cbab3c1e31
|
@ -23,12 +23,10 @@ import wsmeext.pecan as wsme_pecan
|
|||
from mistral.api.controllers import resource
|
||||
from mistral.api.controllers.v2 import action_execution
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral.engine1 import rpc
|
||||
from mistral import exceptions as exc
|
||||
from mistral.openstack.common import log as logging
|
||||
from mistral.utils import rest_utils
|
||||
from mistral.workflow import data_flow
|
||||
from mistral.workflow import states
|
||||
from mistral.workflow import utils as wf_utils
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
@ -60,8 +58,7 @@ class Task(resource.Resource):
|
|||
for key, val in d.items():
|
||||
if hasattr(e, key):
|
||||
# Nonetype check for dictionary must be explicit.
|
||||
if val is not None and (
|
||||
key == 'input' or key == 'result'):
|
||||
if val is not None and key == 'input':
|
||||
val = json.dumps(val)
|
||||
setattr(e, key, val)
|
||||
|
||||
|
@ -95,6 +92,25 @@ class Tasks(resource.Resource):
|
|||
return cls(tasks=[Task.sample()])
|
||||
|
||||
|
||||
def _get_task_resources_with_results(wf_ex_id=None):
|
||||
filters = {}
|
||||
|
||||
if wf_ex_id:
|
||||
filters['workflow_execution_id'] = wf_ex_id
|
||||
|
||||
tasks = []
|
||||
task_execs = db_api.get_task_executions(**filters)
|
||||
for task_ex in task_execs:
|
||||
task = Task.from_dict(task_ex.to_dict())
|
||||
task.result = json.dumps(
|
||||
data_flow.get_task_execution_result(task_ex)
|
||||
)
|
||||
|
||||
tasks += [task]
|
||||
|
||||
return Tasks(tasks=tasks)
|
||||
|
||||
|
||||
class TasksController(rest.RestController):
|
||||
action_executions = action_execution.TasksActionExecutionController()
|
||||
|
||||
|
@ -104,46 +120,19 @@ class TasksController(rest.RestController):
|
|||
"""Return the specified task."""
|
||||
LOG.info("Fetch task [id=%s]" % id)
|
||||
|
||||
db_model = db_api.get_task_execution(id)
|
||||
task_ex = db_api.get_task_execution(id)
|
||||
task = Task.from_dict(task_ex.to_dict())
|
||||
|
||||
return Task.from_dict(db_model.to_dict())
|
||||
task.result = json.dumps(data_flow.get_task_execution_result(task_ex))
|
||||
|
||||
@rest_utils.wrap_wsme_controller_exception
|
||||
@wsme_pecan.wsexpose(Task, wtypes.text, body=Task)
|
||||
def put(self, id, task):
|
||||
"""Update the specified task."""
|
||||
LOG.info("Update task [id=%s, task=%s]" % (id, task))
|
||||
|
||||
# Client must provide a valid json. It shouldn't necessarily be an
|
||||
# object but it should be json complaint so strings have to be escaped.
|
||||
result = None
|
||||
|
||||
if task.result:
|
||||
try:
|
||||
result = json.loads(task.result)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise exc.InvalidResultException(str(e))
|
||||
|
||||
if task.state == states.ERROR:
|
||||
task_result = wf_utils.Result(error=result)
|
||||
else:
|
||||
task_result = wf_utils.Result(data=result)
|
||||
|
||||
engine = rpc.get_engine_client()
|
||||
|
||||
values = engine.on_task_result(id, task_result)
|
||||
|
||||
return Task.from_dict(values)
|
||||
return task
|
||||
|
||||
@wsme_pecan.wsexpose(Tasks)
|
||||
def get_all(self):
|
||||
"""Return all tasks within the execution."""
|
||||
LOG.info("Fetch tasks")
|
||||
|
||||
tasks = [Task.from_dict(db_model.to_dict())
|
||||
for db_model in db_api.get_task_executions()]
|
||||
|
||||
return Tasks(tasks=tasks)
|
||||
return _get_task_resources_with_results()
|
||||
|
||||
|
||||
class ExecutionTasksController(rest.RestController):
|
||||
|
@ -152,12 +141,4 @@ class ExecutionTasksController(rest.RestController):
|
|||
"""Return all tasks within the workflow execution."""
|
||||
LOG.info("Fetch tasks")
|
||||
|
||||
task_execs = db_api.get_task_executions(
|
||||
workflow_execution_id=workflow_execution_id
|
||||
)
|
||||
|
||||
return Tasks(
|
||||
tasks=[
|
||||
Task.from_dict(db_model.to_dict()) for db_model in task_execs
|
||||
]
|
||||
)
|
||||
return _get_task_resources_with_results(workflow_execution_id)
|
||||
|
|
|
@ -16,17 +16,19 @@
|
|||
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import mock
|
||||
|
||||
from mistral.db.v2 import api as db_api
|
||||
from mistral.db.v2.sqlalchemy import models
|
||||
from mistral.engine1 import rpc
|
||||
from mistral import exceptions as exc
|
||||
from mistral.tests.unit.api import base
|
||||
from mistral.workflow import data_flow
|
||||
from mistral.workflow import states
|
||||
|
||||
# TODO(everyone): later we need additional tests verifying all the errors etc.
|
||||
|
||||
RESULT = {"some": "result"}
|
||||
task_ex = models.TaskExecution(
|
||||
id='123',
|
||||
name='task',
|
||||
|
@ -49,7 +51,8 @@ TASK = {
|
|||
'state': 'RUNNING',
|
||||
'workflow_execution_id': '123',
|
||||
'created_at': '1970-01-01 00:00:00',
|
||||
'updated_at': '1970-01-01 00:00:00'
|
||||
'updated_at': '1970-01-01 00:00:00',
|
||||
'result': json.dumps(RESULT)
|
||||
}
|
||||
|
||||
UPDATED_task_ex = copy.copy(task_ex)
|
||||
|
@ -63,7 +66,6 @@ ERROR_TASK = copy.copy(TASK)
|
|||
ERROR_TASK['state'] = 'ERROR'
|
||||
|
||||
BROKEN_TASK = copy.copy(TASK)
|
||||
BROKEN_TASK['result'] = 'string not escaped'
|
||||
|
||||
MOCK_TASK = mock.MagicMock(return_value=task_ex)
|
||||
MOCK_TASKS = mock.MagicMock(return_value=[task_ex])
|
||||
|
@ -71,13 +73,15 @@ MOCK_EMPTY = mock.MagicMock(return_value=[])
|
|||
MOCK_NOT_FOUND = mock.MagicMock(side_effect=exc.NotFoundException())
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
data_flow,
|
||||
'get_task_execution_result', mock.Mock(return_value=RESULT)
|
||||
)
|
||||
class TestTasksController(base.FunctionalTest):
|
||||
@mock.patch.object(db_api, 'get_task_execution', MOCK_TASK)
|
||||
def test_get(self):
|
||||
resp = self.app.get('/v2/tasks/123')
|
||||
|
||||
self.maxDiff = None
|
||||
|
||||
self.assertEqual(resp.status_int, 200)
|
||||
self.assertDictEqual(TASK, resp.json)
|
||||
|
||||
|
@ -87,13 +91,6 @@ class TestTasksController(base.FunctionalTest):
|
|||
|
||||
self.assertEqual(resp.status_int, 404)
|
||||
|
||||
@mock.patch.object(rpc.EngineClient, 'on_action_complete')
|
||||
def test_put_bad_result(self, f):
|
||||
resp = self.app.put_json('/v2/tasks/123', BROKEN_TASK,
|
||||
expect_errors=True)
|
||||
|
||||
self.assertEqual(resp.status_int, 400)
|
||||
|
||||
@mock.patch.object(db_api, 'get_task_executions', MOCK_TASKS)
|
||||
def test_get_all(self):
|
||||
resp = self.app.get('/v2/tasks')
|
||||
|
|
|
@ -57,9 +57,10 @@ def get_task_execution_result(task_ex):
|
|||
if hasattr(ex, 'output') and ex.accepted
|
||||
]
|
||||
|
||||
assert len(results) > 0
|
||||
|
||||
return results if len(results) > 1 else results[0]
|
||||
if results:
|
||||
return results if len(results) > 1 else results[0]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def publish_variables(task_ex, task_spec):
|
||||
|
|
Loading…
Reference in New Issue