diff --git a/mistral/services/workflows.py b/mistral/services/workflows.py index 383322375..083796352 100644 --- a/mistral/services/workflows.py +++ b/mistral/services/workflows.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import yaml from mistral.db.v2 import api as db_api from mistral import exceptions as exc @@ -89,11 +90,16 @@ def create_workflows(definition, scope='private', is_system=False, def _append_all_workflows(definition, is_system, scope, namespace, wf_list_spec, db_wfs): - for wf_spec in wf_list_spec.get_workflows(): + wfs = wf_list_spec.get_workflows() + wfs_yaml = yaml.load(definition) + for wf_spec in wfs: db_wfs.append( _create_workflow( wf_spec, - definition, + _cut_wf_definition_from_all( + wfs_yaml, + wf_spec.get_name() + ), scope, namespace, is_system @@ -120,10 +126,14 @@ def update_workflows(definition, scope='private', identifier=None, db_wfs = [] with db_api.transaction(): - for wf_spec in wf_list_spec.get_workflows(): + wfs_yaml = yaml.load(definition) + for wf_spec in wfs: db_wfs.append(_update_workflow( wf_spec, - definition, + _cut_wf_definition_from_all( + wfs_yaml, + wf_spec.get_name() + ), scope, namespace=namespace, identifier=identifier @@ -176,3 +186,10 @@ def _update_workflow(wf_spec, definition, scope, identifier=None, identifier if identifier else values['name'], values ) + + +def _cut_wf_definition_from_all(wfs_yaml, wf_name): + return yaml.dump({ + 'version': wfs_yaml['version'], + wf_name: wfs_yaml[wf_name] + }) diff --git a/mistral/tests/unit/api/v2/test_workflows.py b/mistral/tests/unit/api/v2/test_workflows.py index dc744c833..3b250d1ab 100644 --- a/mistral/tests/unit/api/v2/test_workflows.py +++ b/mistral/tests/unit/api/v2/test_workflows.py @@ -18,6 +18,7 @@ import datetime import mock import sqlalchemy as sa +import yaml from mistral.db.v2 import api as db_api from mistral.db.v2.sqlalchemy import models @@ -179,6 +180,60 @@ wf2: action: std.echo output="Mistral" """ +WFS_YAML = yaml.load(WFS_DEFINITION) +FIRST_WF_DEF = yaml.dump({ + 'version': '2.0', + 'wf1': WFS_YAML['wf1'] +}) +SECOND_WF_DEF = yaml.dump({ + 'version': '2.0', + 'wf2': WFS_YAML['wf2'] +}) + +FIRST_WF_DICT = { + 'name': 'wf1', + 'tasks': { + 'task1': { + 'action': 'std.echo output="Hello"', + 'name': 'task1', + 'type': 'direct', + 'version': '2.0' + } + }, + 'version': '2.0' +} +FIRST_WF = { + 'name': 'wf1', + 'tags': [], + 'definition': FIRST_WF_DEF, + 'spec': FIRST_WF_DICT, + 'scope': 'private', + 'namespace': '', + 'is_system': False +} + +SECOND_WF_DICT = { + 'name': 'wf2', + 'tasks': { + 'task1': { + 'action': 'std.echo output="Mistral"', + 'name': 'task1', + 'type': 'direct', + 'version': '2.0' + } + }, + 'version': '2.0' +} +SECOND_WF = { + 'name': 'wf2', + 'tags': [], + 'definition': SECOND_WF_DEF, + 'spec': SECOND_WF_DICT, + 'scope': 'private', + 'namespace': '', + 'is_system': False +} + MOCK_WF = mock.MagicMock(return_value=WF_DB) MOCK_WF_SYSTEM = mock.MagicMock(return_value=WF_DB_SYSTEM) MOCK_WF_WITH_INPUT = mock.MagicMock(return_value=WF_DB_WITH_INPUT) @@ -343,6 +398,18 @@ class TestWorkflowsController(base.APITest): self.assertEqual(400, resp.status_int) self.assertIn("Invalid DSL", resp.body.decode()) + @mock.patch.object(db_api, "update_workflow_definition") + def test_put_multiple(self, mock_mtd): + self.app.put( + '/v2/workflows', + WFS_DEFINITION, + headers={'Content-Type': 'text/plain'} + ) + + self.assertEqual(2, mock_mtd.call_count) + mock_mtd.assert_any_call('wf1', FIRST_WF) + mock_mtd.assert_any_call('wf2', SECOND_WF) + def test_put_more_workflows_with_uuid(self): resp = self.app.put( '/v2/workflows/123e4567-e89b-12d3-a456-426655440000', @@ -414,6 +481,18 @@ class TestWorkflowsController(base.APITest): self.assertEqual(409, resp.status_int) + @mock.patch.object(db_api, "create_workflow_definition") + def test_post_multiple(self, mock_mtd): + self.app.post( + '/v2/workflows', + WFS_DEFINITION, + headers={'Content-Type': 'text/plain'} + ) + + self.assertEqual(2, mock_mtd.call_count) + mock_mtd.assert_any_call(FIRST_WF) + mock_mtd.assert_any_call(SECOND_WF) + def test_post_invalid(self): resp = self.app.post( '/v2/workflows',