diff --git a/automaton/machines.py b/automaton/machines.py index 1dac13c..469d84e 100644 --- a/automaton/machines.py +++ b/automaton/machines.py @@ -33,12 +33,18 @@ class State(object): :ivar name: The name of the state. :ivar is_terminal: Whether this state is terminal (or not). :ivar next_states: Dictionary of 'event' -> 'next state name' (or none). + :ivar on_enter: callback that will be called when the state is entered. + :ivar on_exit: callback that will be called when the state is exited. """ - def __init__(self, name, is_terminal=False, next_states=None): + def __init__(self, name, + is_terminal=False, next_states=None, + on_enter=None, on_exit=None): self.name = name self.is_terminal = bool(is_terminal) self.next_states = next_states + self.on_enter = on_enter + self.on_exit = on_exit def _convert_to_states(state_space): @@ -141,7 +147,10 @@ class FiniteMachine(object): state_space = list(_convert_to_states(state_space)) m = cls() for state in state_space: - m.add_state(state.name, terminal=state.is_terminal) + m.add_state(state.name, + terminal=state.is_terminal, + on_enter=state.on_enter, + on_exit=state.on_exit) for state in state_space: if state.next_states: for event, next_state in six.iteritems(state.next_states): diff --git a/automaton/tests/test_fsm.py b/automaton/tests/test_fsm.py index a58ba29..f78728a 100644 --- a/automaton/tests/test_fsm.py +++ b/automaton/tests/test_fsm.py @@ -14,6 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. +import collections import functools import random @@ -69,6 +70,40 @@ class FSMTest(testcase.TestCase): expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')] self.assertEqual(expected, list(m)) + def test_build_transitions_with_callbacks(self): + entered = collections.defaultdict(list) + exitted = collections.defaultdict(list) + + def on_enter(state, event): + entered[state].append(event) + + def on_exit(state, event): + exitted[state].append(event) + + space = [ + machines.State('down', is_terminal=False, + next_states={'jump': 'up'}, + on_enter=on_enter, on_exit=on_exit), + machines.State('up', is_terminal=False, + next_states={'fall': 'down'}, + on_enter=on_enter, on_exit=on_exit), + ] + m = machines.FiniteMachine.build(space) + m.default_start_state = 'down' + expected = [('down', 'jump', 'up'), ('up', 'fall', 'down')] + self.assertEqual(expected, list(m)) + + m.initialize() + m.process_event('jump') + + self.assertEqual({'down': ['jump']}, dict(exitted)) + self.assertEqual({'up': ['jump']}, dict(entered)) + + m.process_event('fall') + + self.assertEqual({'down': ['jump'], 'up': ['fall']}, dict(exitted)) + self.assertEqual({'up': ['jump'], 'down': ['fall']}, dict(entered)) + def test_build_transitions_dct(self): space = [ {