Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #682 asynctimeout #686

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from asyncio import CancelledError

from transitions.extensions.factory import AsyncGraphMachine, HierarchicalAsyncGraphMachine
from transitions.extensions.states import add_state_features

try:
import asyncio
from transitions.extensions.asyncio import AsyncMachine, HierarchicalAsyncMachine, AsyncEventData, \
AsyncTransition
AsyncTransition, AsyncTimeout

except (ImportError, SyntaxError):
asyncio = None # type: ignore
Expand Down Expand Up @@ -342,9 +345,6 @@ async def run():
asyncio.run(run())

def test_async_timeout(self):
from transitions.extensions.states import add_state_features
from transitions.extensions.asyncio import AsyncTimeout

timeout_called = MagicMock()

@add_state_features(AsyncTimeout)
Expand Down Expand Up @@ -376,6 +376,72 @@ async def run():

asyncio.run(run())

def test_timeout_cancel(self):
error_mock = MagicMock()
timout_mock = MagicMock()
long_op_mock = MagicMock()

@add_state_features(AsyncTimeout)
class TimeoutMachine(self.machine_cls): # type: ignore
async def on_enter_B(self):
await asyncio.sleep(0.2)
long_op_mock() # should never be called

async def handle_timeout(self):
timout_mock()
await self.to_A()

machine = TimeoutMachine(states=["A", {"name": "B", "timeout": 0.1, "on_timeout": "handle_timeout"}],
initial="A", on_exception=error_mock)

async def run():
await machine.to_B()
assert timout_mock.called
assert error_mock.call_count == 1 # should only be one CancelledError
assert not long_op_mock.called
assert machine.is_A()
asyncio.run(run())

def test_queued_timeout_cancel(self):
error_mock = MagicMock()
timout_mock = MagicMock()
long_op_mock = MagicMock()

@add_state_features(AsyncTimeout)
class TimeoutMachine(self.machine_cls): # type: ignore
async def long_op(self, event_data):
await self.to_C()
await self.to_D()
await self.to_E()
await asyncio.sleep(1)
long_op_mock()

async def handle_timeout(self, event_data):
timout_mock()
raise TimeoutError()

async def handle_error(self, event_data):
if isinstance(event_data.error, CancelledError):
if error_mock.called:
raise RuntimeError()
error_mock()
raise event_data.error

machine = TimeoutMachine(states=["A", "C", "D", "E",
{"name": "B", "timeout": 0.1, "on_timeout": "handle_timeout",
"on_enter": "long_op"}],
initial="A", queued=True, send_event=True, on_exception="handle_error")

async def run():
await machine.to_B()
assert timout_mock.called
assert error_mock.called
assert not long_op_mock.called
assert machine.is_B()
with self.assertRaises(RuntimeError):
await machine.to_B()
asyncio.run(run())

def test_callback_order(self):
finished = []

Expand Down
41 changes: 34 additions & 7 deletions transitions/extensions/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
import contextvars
import inspect
import warnings
from collections import deque
from functools import partial, reduce
import copy
Expand Down Expand Up @@ -116,7 +117,7 @@ async def execute(self, event_data):

machine = event_data.machine
# cancel running tasks since the transition will happen
await machine.switch_model_context(event_data.model)
await machine.cancel_running_transitions(event_data.model, event_data.event.name)

await event_data.machine.callbacks(event_data.machine.before_state_change, event_data)
await event_data.machine.callbacks(self.before, event_data)
Expand Down Expand Up @@ -189,7 +190,8 @@ async def _trigger(self, event_data):
if self._is_valid_source(event_data.state):
await self._process(event_data)
except BaseException as err: # pylint: disable=broad-except; Exception will be handled elsewhere
_LOGGER.error("%sException was raised while processing the trigger: %s", self.machine.name, err)
_LOGGER.error("%sException was raised while processing the trigger '%s': %s",
self.machine.name, event_data.event.name, repr(err))
event_data.error = err
if self.machine.on_exception:
await self.machine.callbacks(self.machine.on_exception, event_data)
Expand Down Expand Up @@ -374,18 +376,24 @@ async def await_all(callables):
return await asyncio.gather(*[func() for func in callables])

async def switch_model_context(self, model):
warnings.warn("Please replace 'AsyncMachine.switch_model_context' with "
"'AsyncMachine.cancel_running_transitions'.", category=DeprecationWarning)
await self.cancel_running_transitions(model)

async def cancel_running_transitions(self, model, msg=None):
"""
This method is called by an `AsyncTransition` when all conditional tests have passed
and the transition will happen. This requires already running tasks to be cancelled.
Args:
model (object): The currently processed model
msg (str): Optional message to pass to a running task's cancel request
"""
for running_task in self.async_tasks.get(id(model), []):
if self.current_context.get() == running_task or running_task in self.protected_tasks:
continue
if running_task.done() is False:
_LOGGER.debug("Cancel running tasks...")
running_task.cancel()
running_task.cancel(msg)

async def process_context(self, func, model):
"""
Expand All @@ -399,7 +407,7 @@ async def process_context(self, func, model):
bool: returns the success state of the triggered event
"""
if self.current_context.get() is None:
self.current_context.set(asyncio.current_task())
token = self.current_context.set(asyncio.current_task())
if id(model) in self.async_tasks:
self.async_tasks[id(model)].append(asyncio.current_task())
else:
Expand All @@ -410,6 +418,7 @@ async def process_context(self, func, model):
res = False
finally:
self.async_tasks[id(model)].remove(asyncio.current_task())
self.current_context.reset(token)
if len(self.async_tasks[id(model)]) == 0:
del self.async_tasks[id(model)]
else:
Expand Down Expand Up @@ -677,12 +686,30 @@ async def _timeout():
await asyncio.shield(self._process_timeout(event_data))
except asyncio.CancelledError:
pass

return asyncio.ensure_future(_timeout())
return asyncio.create_task(_timeout())

async def _process_timeout(self, event_data):
_LOGGER.debug("%sTimeout state %s. Processing callbacks...", event_data.machine.name, self.name)
await event_data.machine.callbacks(self.on_timeout, event_data)
event_data = AsyncEventData(event_data.state, AsyncEvent("timeout", event_data.machine),
event_data.machine, event_data.model, args=tuple(), kwargs={})
token = AsyncMachine.current_context.set(None)
try:
await event_data.machine.callbacks(self.on_timeout, event_data)
except BaseException as err:
_LOGGER.warning("%sException raised while processing timeout!",
event_data.machine.name)
event_data.error = err
try:
if event_data.machine.on_exception:
await event_data.machine.callbacks(event_data.machine.on_exception, event_data)
else:
raise
except BaseException as err2:
_LOGGER.error("%sHandling timeout exception '%s' caused another exception: %s. "
"Cancel running transitions...", event_data.machine.name, repr(err), repr(err2))
await event_data.machine.cancel_running_transitions(event_data.model, "timeout")
finally:
AsyncMachine.current_context.reset(token)
_LOGGER.info("%sTimeout state %s processed.", event_data.machine.name, self.name)

@property
Expand Down
1 change: 1 addition & 0 deletions transitions/extensions/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class AsyncMachine(Machine):
async def callback(self, func: AsyncCallback, event_data: AsyncEventData) -> None: ... # type: ignore[override]
@staticmethod
async def await_all(callables: List[AsyncCallbackFunc]) -> List[Optional[bool]]: ...
async def cancel_running_transitions(self, model: object, msg: Optional[str] = ...) -> None: ...
async def switch_model_context(self, model: object) -> None: ...
def get_state(self, state: Union[str, Enum]) -> AsyncState: ...
async def process_context(self, func: Callable[[], Awaitable[None]], model: object) -> bool: ...
Expand Down