diff --git a/mistral/db/v2/api.py b/mistral/db/v2/api.py index fb93f92b7..98738bd17 100644 --- a/mistral/db/v2/api.py +++ b/mistral/db/v2/api.py @@ -71,13 +71,13 @@ def acquire_lock(model, id): # Workbooks. -def get_workbook(name): - return IMPL.get_workbook(name) +def get_workbook(name, fields=()): + return IMPL.get_workbook(name, fields=fields) -def load_workbook(name): +def load_workbook(name, fields=()): """Unlike get_workbook this method is allowed to return None.""" - return IMPL.load_workbook(name) + return IMPL.load_workbook(name, fields=fields) def get_workbooks(limit=None, marker=None, sort_keys=None, @@ -114,17 +114,21 @@ def delete_workbooks(**kwargs): # Workflow definitions. -def get_workflow_definition(identifier, namespace=''): - return IMPL.get_workflow_definition(identifier, namespace=namespace) +def get_workflow_definition(identifier, namespace='', fields=()): + return IMPL.get_workflow_definition( + identifier, + namespace=namespace, + fields=fields + ) -def get_workflow_definition_by_id(id): - return IMPL.get_workflow_definition_by_id(id) +def get_workflow_definition_by_id(id, fields=()): + return IMPL.get_workflow_definition_by_id(id, fields=fields) -def load_workflow_definition(name, namespace=''): +def load_workflow_definition(name, namespace='', fields=()): """Unlike get_workflow_definition this method is allowed to return None.""" - return IMPL.load_workflow_definition(name, namespace) + return IMPL.load_workflow_definition(name, namespace, fields=fields) def get_workflow_definitions(limit=None, marker=None, sort_keys=None, @@ -161,17 +165,17 @@ def delete_workflow_definitions(**kwargs): # Action definitions. -def get_action_definition_by_id(id): - return IMPL.get_action_definition_by_id(id) +def get_action_definition_by_id(id, fields=()): + return IMPL.get_action_definition_by_id(id, fields=fields) -def get_action_definition(name): - return IMPL.get_action_definition(name) +def get_action_definition(name, fields=()): + return IMPL.get_action_definition(name, fields=fields) -def load_action_definition(name): +def load_action_definition(name, fields=()): """Unlike get_action_definition this method is allowed to return None.""" - return IMPL.load_action_definition(name) + return IMPL.load_action_definition(name, fields=fields) def get_action_definitions(limit=None, marker=None, sort_keys=None, @@ -207,13 +211,13 @@ def delete_action_definitions(**kwargs): # Action executions. -def get_action_execution(id): - return IMPL.get_action_execution(id) +def get_action_execution(id, fields=()): + return IMPL.get_action_execution(id, fields=fields) -def load_action_execution(name): +def load_action_execution(name, fields=()): """Unlike get_action_execution this method is allowed to return None.""" - return IMPL.load_action_execution(name) + return IMPL.load_action_execution(name, fields=fields) def get_action_executions(**kwargs): @@ -242,14 +246,13 @@ def delete_action_executions(**kwargs): # Workflow executions. -# TODO(rakhmerov): Add 'fields' parameter to all 'get' methods. def get_workflow_execution(id, fields=()): return IMPL.get_workflow_execution(id, fields=fields) -def load_workflow_execution(name): +def load_workflow_execution(name, fields=()): """Unlike get_workflow_execution this method is allowed to return None.""" - return IMPL.load_workflow_execution(name) + return IMPL.load_workflow_execution(name, fields=fields) def get_workflow_executions(limit=None, marker=None, sort_keys=None, @@ -289,13 +292,13 @@ def update_workflow_execution_state(**kwargs): # Tasks executions. -def get_task_execution(id): - return IMPL.get_task_execution(id) +def get_task_execution(id, fields=()): + return IMPL.get_task_execution(id, fields=fields) -def load_task_execution(id): +def load_task_execution(id, fields=()): """Unlike get_task_execution this method is allowed to return None.""" - return IMPL.load_task_execution(id) + return IMPL.load_task_execution(id, fields=fields) def get_task_executions(limit=None, marker=None, sort_keys=None, diff --git a/mistral/db/v2/sqlalchemy/api.py b/mistral/db/v2/sqlalchemy/api.py index a4e667129..4270d974b 100644 --- a/mistral/db/v2/sqlalchemy/api.py +++ b/mistral/db/v2/sqlalchemy/api.py @@ -25,8 +25,6 @@ from oslo_db.sqlalchemy import utils as db_utils from oslo_log import log as logging from oslo_utils import uuidutils # noqa import sqlalchemy as sa -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.expression import Insert from mistral import context from mistral.db.sqlalchemy import base as b @@ -279,18 +277,10 @@ def _get_collection(model, insecure=False, limit=None, marker=None, return query.all() -def _get_db_object_by_name(model, name, filter_=None, order_by=None): +def _get_db_object_by_name(model, name, columns=()): + query = _secure_query(model, *columns) - query = _secure_query(model) - final_filter = model.name == name - - if filter_ is not None: - final_filter = sa.and_(final_filter, filter_) - - if order_by is not None: - query = query.order_by(order_by) - - return query.filter(final_filter).first() + return query.filter_by(name=name).first() def _get_db_object_by_id(model, id, insecure=False, columns=()): @@ -304,9 +294,13 @@ def _get_db_object_by_id(model, id, insecure=False, columns=()): def _get_db_object_by_name_and_namespace_or_id(model, identifier, - namespace=None, insecure=False): - - query = b.model_query(model) if insecure else _secure_query(model) + namespace=None, insecure=False, + columns=()): + query = ( + b.model_query(model, columns=columns) + if insecure + else _secure_query(model, *columns) + ) match_name = model.name == identifier @@ -325,30 +319,11 @@ def _get_db_object_by_name_and_namespace_or_id(model, identifier, return query.first() -@compiles(Insert) -def append_string(insert, compiler, **kw): - s = compiler.visit_insert(insert, **kw) - - if 'append_string' in insert.kwargs: - append = insert.kwargs['append_string'] - - if append: - s += " " + append - - if 'replace_string' in insert.kwargs: - replace = insert.kwargs['replace_string'] - - if isinstance(replace, tuple): - s = s.replace(replace[0], replace[1]) - - return s - - # Workbook definitions. @b.session_aware() -def get_workbook(name, session=None): - wb = _get_db_object_by_name(models.Workbook, name) +def get_workbook(name, fields=(), session=None): + wb = _get_db_object_by_name(models.Workbook, name, columns=fields) if not wb: raise exc.DBEntityNotFoundError( @@ -359,8 +334,8 @@ def get_workbook(name, session=None): @b.session_aware() -def load_workbook(name, session=None): - return _get_db_object_by_name(models.Workbook, name) +def load_workbook(name, fields=(), session=None): + return _get_db_object_by_name(models.Workbook, name, columns=fields) @b.session_aware() @@ -421,12 +396,14 @@ def delete_workbooks(session=None, **kwargs): # Workflow definitions. @b.session_aware() -def get_workflow_definition(identifier, namespace='', session=None): +def get_workflow_definition(identifier, namespace='', fields=(), session=None): """Gets workflow definition by name or uuid. :param identifier: Identifier could be in the format of plain string or uuid. :param namespace: The namespace the workflow is in. Optional. + :param fields: Fields that need to be loaded. For example, + (WorkflowDefinition.name,) :return: Workflow definition. """ ctx = context.ctx() @@ -435,7 +412,8 @@ def get_workflow_definition(identifier, namespace='', session=None): models.WorkflowDefinition, identifier, namespace=namespace, - insecure=ctx.is_admin + insecure=ctx.is_admin, + columns=fields ) if not wf_def: @@ -448,8 +426,12 @@ def get_workflow_definition(identifier, namespace='', session=None): @b.session_aware() -def get_workflow_definition_by_id(id, session=None): - wf_def = _get_db_object_by_id(models.WorkflowDefinition, id) +def get_workflow_definition_by_id(id, fields=(), session=None): + wf_def = _get_db_object_by_id( + models.WorkflowDefinition, + id, + columns=fields + ) if not wf_def: raise exc.DBEntityNotFoundError( @@ -460,20 +442,22 @@ def get_workflow_definition_by_id(id, session=None): @b.session_aware() -def load_workflow_definition(name, namespace='', session=None): +def load_workflow_definition(name, namespace='', fields=(), session=None): model = models.WorkflowDefinition - filter_ = model.namespace.in_([namespace, '']) + query = _secure_query(model, *fields) + + filter_ = sa.and_( + model.name == name, model.namespace.in_([namespace, '']) + ) # Give priority to objects not in the default namespace. order_by = model.namespace.desc() - return _get_db_object_by_name( - model, - name, - filter_, - order_by - ) + if order_by is not None: + query = query.order_by(order_by) + + return query.filter(filter_).first() @b.session_aware() @@ -528,6 +512,7 @@ def update_workflow_definition(identifier, values, namespace='', session=None): insecure=True, workflow_id=wf_def.id ) + for e_t in event_triggers: if e_t.project_id != wf_def.project_id: raise exc.NotAllowedException( @@ -590,8 +575,12 @@ def delete_workflow_definitions(session=None, **kwargs): # Action definitions. @b.session_aware() -def get_action_definition_by_id(id, session=None): - action_def = _get_db_object_by_id(models.ActionDefinition, id) +def get_action_definition_by_id(id, fields=(), session=None): + action_def = _get_db_object_by_id( + models.ActionDefinition, + id, + columns=fields + ) if not action_def: raise exc.DBEntityNotFoundError( @@ -602,10 +591,11 @@ def get_action_definition_by_id(id, session=None): @b.session_aware() -def get_action_definition(identifier, session=None): +def get_action_definition(identifier, fields=(), session=None): a_def = _get_db_object_by_name_and_namespace_or_id( models.ActionDefinition, - identifier + identifier, + columns=fields ) if not a_def: @@ -617,8 +607,12 @@ def get_action_definition(identifier, session=None): @b.session_aware() -def load_action_definition(name, session=None): - return _get_db_object_by_name(models.ActionDefinition, name) +def load_action_definition(name, fields=(), session=None): + return _get_db_object_by_name( + models.ActionDefinition, + name, + columns=fields + ) @b.session_aware() @@ -675,8 +669,8 @@ def delete_action_definitions(session=None, **kwargs): # Action executions. @b.session_aware() -def get_action_execution(id, session=None): - a_ex = _get_db_object_by_id(models.ActionExecution, id) +def get_action_execution(id, fields=(), session=None): + a_ex = _get_db_object_by_id(models.ActionExecution, id, columns=fields) if not a_ex: raise exc.DBEntityNotFoundError( @@ -687,8 +681,8 @@ def get_action_execution(id, session=None): @b.session_aware() -def load_action_execution(id, session=None): - return _get_db_object_by_id(models.ActionExecution, id) +def load_action_execution(id, fields=(), session=None): + return _get_db_object_by_id(models.ActionExecution, id, columns=fields) @b.session_aware() @@ -771,8 +765,8 @@ def get_workflow_execution(id, fields=(), session=None): @b.session_aware() -def load_workflow_execution(id, session=None): - return _get_db_object_by_id(models.WorkflowExecution, id) +def load_workflow_execution(id, fields=(), session=None): + return _get_db_object_by_id(models.WorkflowExecution, id, columns=fields) @b.session_aware() @@ -846,8 +840,8 @@ def update_workflow_execution_state(id, cur_state, state): # Tasks executions. @b.session_aware() -def get_task_execution(id, session=None): - task_ex = _get_db_object_by_id(models.TaskExecution, id) +def get_task_execution(id, fields=(), session=None): + task_ex = _get_db_object_by_id(models.TaskExecution, id, columns=fields) if not task_ex: raise exc.DBEntityNotFoundError( @@ -858,8 +852,8 @@ def get_task_execution(id, session=None): @b.session_aware() -def load_task_execution(id, session=None): - return _get_db_object_by_id(models.TaskExecution, id) +def load_task_execution(id, fields=(), session=None): + return _get_db_object_by_id(models.TaskExecution, id, columns=fields) @b.session_aware() diff --git a/mistral/tests/unit/db/v2/test_sqlalchemy_db_api.py b/mistral/tests/unit/db/v2/test_sqlalchemy_db_api.py index d0827f55b..960d9431c 100644 --- a/mistral/tests/unit/db/v2/test_sqlalchemy_db_api.py +++ b/mistral/tests/unit/db/v2/test_sqlalchemy_db_api.py @@ -84,6 +84,20 @@ class WorkbookTest(SQLAlchemyTest): self.assertIsNone(db_api.load_workbook("not-existing-wb")) + def test_get_workbook_with_fields(self): + with db_api.transaction(): + created = db_api.create_workbook(WORKBOOKS[0]) + + fetched = db_api.get_workbook( + created['name'], + fields=(db_models.Workbook.scope,) + ) + + self.assertNotEqual(created, fetched) + self.assertIsInstance(fetched, tuple) + self.assertEqual(1, len(fetched)) + self.assertEqual(created.scope, fetched[0]) + def test_create_workbook_duplicate_without_auth(self): cfg.CONF.set_default('auth_enable', False, group='pecan') db_api.create_workbook(WORKBOOKS[0]) @@ -459,6 +473,20 @@ class WorkflowDefinitionTest(SQLAlchemyTest): self.assertIsNone(db_api.load_workflow_definition("not-existing-wf")) + def test_get_workflow_definition_with_fields(self): + with db_api.transaction(): + created = db_api.create_workflow_definition(WF_DEFINITIONS[0]) + + fetched = db_api.get_workflow_definition( + created.name, + fields=(db_models.WorkflowDefinition.scope,) + ) + + self.assertNotEqual(created, fetched) + self.assertIsInstance(fetched, tuple) + self.assertEqual(1, len(fetched)) + self.assertEqual(created.scope, fetched[0]) + def test_get_workflow_definition_with_uuid(self): created = db_api.create_workflow_definition(WF_DEFINITIONS[0]) fetched = db_api.get_workflow_definition(created.id) @@ -1006,6 +1034,20 @@ class ActionDefinitionTest(SQLAlchemyTest): self.assertIsNone(db_api.load_action_definition("not-existing-id")) + def test_get_action_definition_with_fields(self): + with db_api.transaction(): + created = db_api.create_action_definition(ACTION_DEFINITIONS[0]) + + fetched = db_api.get_action_definition( + created.name, + fields=(db_models.ActionDefinition.scope,) + ) + + self.assertNotEqual(created, fetched) + self.assertIsInstance(fetched, tuple) + self.assertEqual(1, len(fetched)) + self.assertEqual(created.scope, fetched[0]) + def test_get_action_definition_with_uuid(self): created = db_api.create_action_definition(ACTION_DEFINITIONS[0]) fetched = db_api.get_action_definition(created.id) @@ -1343,6 +1385,20 @@ class ActionExecutionTest(SQLAlchemyTest): self.assertIsNone(db_api.load_action_execution("not-existing-id")) + def test_get_action_execution_with_fields(self): + with db_api.transaction(): + created = db_api.create_action_execution(ACTION_EXECS[0]) + + fetched = db_api.get_action_execution( + created.id, + fields=(db_models.ActionExecution.name,) + ) + + self.assertNotEqual(created, fetched) + self.assertIsInstance(fetched, tuple) + self.assertEqual(1, len(fetched)) + self.assertEqual(created.name, fetched[0]) + def test_update_action_execution(self): with db_api.transaction(): created = db_api.create_action_execution(ACTION_EXECS[0]) @@ -1504,7 +1560,7 @@ class WorkflowExecutionTest(SQLAlchemyTest): db_api.load_workflow_execution("not-existing-id") ) - def test_get_workflow_execution_with_columns(self): + def test_get_workflow_execution_with_fields(self): with db_api.transaction(): created = db_api.create_workflow_execution(WF_EXECS[0]) @@ -1516,7 +1572,7 @@ class WorkflowExecutionTest(SQLAlchemyTest): self.assertNotEqual(created, fetched) self.assertIsInstance(fetched, tuple) self.assertEqual(1, len(fetched)) - self.assertEqual(created.state, fetched.state) + self.assertEqual(created.state, fetched[0]) def test_update_workflow_execution(self): with db_api.transaction(): @@ -1900,7 +1956,6 @@ TASK_EXECS = [ class TaskExecutionTest(SQLAlchemyTest): - def test_create_and_get_and_load_task_execution(self): with db_api.transaction(): wf_ex = db_api.create_workflow_execution(WF_EXECS[0]) @@ -1922,6 +1977,25 @@ class TaskExecutionTest(SQLAlchemyTest): self.assertIsNone(db_api.load_task_execution("not-existing-id")) + def test_get_task_execution_with_fields(self): + with db_api.transaction(): + wf_ex = db_api.create_workflow_execution(WF_EXECS[0]) + + values = copy.deepcopy(TASK_EXECS[0]) + values.update({'workflow_execution_id': wf_ex.id}) + + created = db_api.create_task_execution(values) + + fetched = db_api.get_task_execution( + created.id, + fields=(db_models.TaskExecution.name,) + ) + + self.assertNotEqual(created, fetched) + self.assertIsInstance(fetched, tuple) + self.assertEqual(1, len(fetched)) + self.assertEqual(created.name, fetched[0]) + def test_action_executions(self): # Store one task with two invocations. with db_api.transaction():