diff --git a/adjutant/api/v1/tasks.py b/adjutant/api/v1/tasks.py index 5fa22e5..15875a9 100644 --- a/adjutant/api/v1/tasks.py +++ b/adjutant/api/v1/tasks.py @@ -179,16 +179,19 @@ class TaskView(APIViewWithLogger): task.save() # Instantiate actions with serializers + action_instances = [] for i, action in enumerate(action_serializer_list): data = action['serializer'].validated_data # construct the action class - action_instance = action['action']( + action_instances.append(action['action']( data=data, task=task, order=i - ) + )) + # We run pre_approve on the actions once we've setup all of them. + for action_instance in action_instances: try: action_instance.pre_approve() except Exception as e: diff --git a/adjutant/api/v1/tests/test_api_taskview.py b/adjutant/api/v1/tests/test_api_taskview.py index f50628d..a767265 100644 --- a/adjutant/api/v1/tests/test_api_taskview.py +++ b/adjutant/api/v1/tests/test_api_taskview.py @@ -15,11 +15,13 @@ import mock from django.test.utils import override_settings +from django.conf import settings from django.core import mail from rest_framework import status from adjutant.api.models import Task, Token +from adjutant.api.v1.tasks import CreateProject from adjutant.common.tests.fake_clients import ( FakeManager, setup_identity_cache) from adjutant.common.tests import fake_clients @@ -1364,3 +1366,35 @@ class TaskViewTests(AdjutantAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.json(), {'errors': ['actions invalid']}) self.assertEqual(len(mail.outbox), 0) + + @mock.patch('adjutant.common.tests.fake_clients.FakeManager.find_project') + def test_all_actions_setup(self, mocked_find): + """ + Ensures that all actions have been setup before pre_approve is + run on any actions, even if we have a pre_approve failure. + + Deals with: bug/1745053 + """ + + setup_identity_cache() + + mocked_find.side_effect = KeyError() + + url = "/v1/actions/CreateProject" + data = {'project_name': "test_project", 'email': "test@example.com"} + response = self.client.post(url, data, format='json') + self.assertEqual( + response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + + new_task = Task.objects.all()[0] + + class_conf = settings.TASK_SETTINGS.get( + CreateProject.task_type, settings.DEFAULT_TASK_SETTINGS) + expected_action_names = ( + class_conf.get('default_actions', []) or + CreateProject.default_actions[:]) + expected_action_names += class_conf.get('additional_actions', []) + + actions = new_task.actions + observed_action_names = [a.action_name for a in actions] + self.assertEqual(observed_action_names, expected_action_names)