diff --git a/mistral/engine/task_handler.py b/mistral/engine/task_handler.py index bf58c1c4f..b180f2088 100644 --- a/mistral/engine/task_handler.py +++ b/mistral/engine/task_handler.py @@ -409,32 +409,43 @@ def _refresh_task_state(task_ex_id): wf_ctrl = wf_base.get_controller(wf_ex, wf_spec) - log_state = wf_ctrl.get_logical_task_state(task_ex) + with db_api.named_lock(task_ex.id): + # NOTE: we have to use this lock to prevent two (or more) such + # methods from changing task state and starting its action or + # workflow. Checking task state outside of this section is a + # performance optimization because locking is pretty expensive. + db_api.refresh(task_ex) - state = log_state.state - state_info = log_state.state_info + if (states.is_completed(task_ex.state) + or task_ex.state == states.RUNNING): + return - # Update 'triggered_by' because it could have changed. - task_ex.runtime_context['triggered_by'] = log_state.triggered_by + log_state = wf_ctrl.get_logical_task_state(task_ex) - if state == states.RUNNING: - continue_task(task_ex) - elif state == states.ERROR: - complete_task(task_ex, state, state_info) - elif state == states.WAITING: - LOG.info( - "Task execution is still in WAITING state" - " [task_ex_id=%s, task_name=%s]", - task_ex_id, - task_ex.name - ) - else: - # Must never get here. - raise RuntimeError( - 'Unexpected logical task state [task_ex_id=%s, ' - 'task_name=%s, state=%s]' % - (task_ex_id, task_ex.name, state) - ) + state = log_state.state + state_info = log_state.state_info + + # Update 'triggered_by' because it could have changed. + task_ex.runtime_context['triggered_by'] = log_state.triggered_by + + if state == states.RUNNING: + continue_task(task_ex) + elif state == states.ERROR: + complete_task(task_ex, state, state_info) + elif state == states.WAITING: + LOG.info( + "Task execution is still in WAITING state" + " [task_ex_id=%s, task_name=%s]", + task_ex_id, + task_ex.name + ) + else: + # Must never get here. + raise RuntimeError( + 'Unexpected logical task state [task_ex_id=%s, ' + 'task_name=%s, state=%s]' % + (task_ex_id, task_ex.name, state) + ) def _schedule_refresh_task_state(task_ex_id, delay=0):