diff --git a/tests/test_async.py b/tests/test_async.py index 8957e2b5..4166475b 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -498,6 +498,21 @@ async def run(): asyncio.run(run()) + def test_may_transition_internal(self): + states = ['A', 'B', 'C'] + d = DummyModel() + _ = self.machine_cls(model=d, states=states, transitions=[["go", "A", "B"], ["wait", "B", None]], + initial='A', auto_transitions=False) + + async def run(): + assert await d.may_go() + assert not await d.may_wait() + await d.go() + assert not await d.may_go() + assert await d.may_wait() + + asyncio.run(run()) + @skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz") class TestAsyncGraphMachine(TestAsync): diff --git a/tests/test_nesting.py b/tests/test_nesting.py index a982aba0..2a29af0a 100644 --- a/tests/test_nesting.py +++ b/tests/test_nesting.py @@ -910,6 +910,7 @@ def test_machine_may_transitions(self): transitions = [ {'trigger': 'walk', 'source': 'A', 'dest': 'B'}, {'trigger': 'run', 'source': 'B', 'dest': 'C'}, + {'trigger': 'wait', 'source': 'B', 'dest': None}, {'trigger': 'run_fast', 'source': 'C', 'dest': 'C{0}1'.format(self.separator)}, {'trigger': 'sprint', 'source': 'C', 'dest': 'D'} ] @@ -920,11 +921,13 @@ def test_machine_may_transitions(self): assert not m.may_run() assert not m.may_run_fast() assert not m.may_sprint() + assert not m.may_wait() m.walk() assert not m.may_walk() assert m.may_run() assert not m.may_run_fast() + assert m.may_wait() m.run() assert m.may_run_fast() diff --git a/transitions/core.py b/transitions/core.py index 1c6c5786..a1968526 100644 --- a/transitions/core.py +++ b/transitions/core.py @@ -884,7 +884,7 @@ def _can_trigger(self, model, trigger, *args, **kwargs): continue for transition in self.events[trigger_name].transitions[state]: try: - _ = transition.source if transition.dest is None else self.get_state(transition.dest) + _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue diff --git a/transitions/extensions/asyncio.py b/transitions/extensions/asyncio.py index 7785743a..7f46db28 100644 --- a/transitions/extensions/asyncio.py +++ b/transitions/extensions/asyncio.py @@ -433,7 +433,7 @@ async def _can_trigger(self, model, trigger, *args, **kwargs): continue for transition in self.events[trigger_name].transitions[state]: try: - _ = self.get_state(transition.dest) + _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue await self.callbacks(self.prepare_event, evt) @@ -559,7 +559,7 @@ async def _can_trigger_nested(self, model, trigger, path, *args, **kwargs): state_name = self.state_cls.separator.join(source_path) for transition in self.events[trigger].transitions.get(state_name, []): try: - _ = self.get_state(transition.dest) + _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue await self.callbacks(self.prepare_event, evt) diff --git a/transitions/extensions/nesting.py b/transitions/extensions/nesting.py index 9c759453..3935bfae 100644 --- a/transitions/extensions/nesting.py +++ b/transitions/extensions/nesting.py @@ -686,7 +686,7 @@ def _can_trigger_nested(self, model, trigger, path, *args, **kwargs): state_name = self.state_cls.separator.join(source_path) for transition in self.events[trigger].transitions.get(state_name, []): try: - _ = self.get_state(transition.dest) + _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue self.callbacks(self.prepare_event, evt)