Skip to content

Commit

Permalink
improve typing information in pyi files
Browse files Browse the repository at this point in the history
  • Loading branch information
aleneum committed May 13, 2024
1 parent 3f08612 commit cb913e0
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 76 deletions.
25 changes: 13 additions & 12 deletions transitions/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ from enum import Enum, EnumMeta

_LOGGER: Logger

Callback = Union[str, Callable]
CallbackFunc = Callable[..., Optional[bool]]
Callback = Union[str, CallbackFunc]
CallbackList = List[Callback]
CallbacksArg = Optional[Union[Callback, CallbackList]]
ModelState = Union[str, Enum, List]
ModelState = Union[str, Enum, List["ModelState"]]
ModelParameter = Union[Union[Literal['self'], Any], List[Union[Literal['self'], Any]]]

def listify(obj: Union[None, list, tuple, EnumMeta, Any]) -> Union[list, tuple, EnumMeta]: ...
def listify(obj: Union[None, List[Any], Tuple[Any], EnumMeta, Any]) -> Union[List[Any], Tuple[Any], EnumMeta]: ...

def _prep_ordered_arg(desired_length: int, arguments: CallbacksArg) -> CallbackList: ...

Expand Down Expand Up @@ -89,7 +90,7 @@ class Event:
transitions: DefaultDict[str, List[Transition]]
def __init__(self, name: str, machine: Machine) -> None: ...
def add_transition(self, transition: Transition) -> None: ...
def trigger(self, model: object, *args: List, **kwargs: Dict) -> bool: ...
def trigger(self, model: object, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def _trigger(self, event_data: EventData) -> bool: ...
def _process(self, event_data: EventData) -> bool: ...
def _is_valid_source(self, state: State) -> bool: ...
Expand All @@ -105,7 +106,7 @@ class Machine:
event_cls: Type[Event]
self_literal: Literal['self']
_queued: bool
_transition_queue: Deque[partial]
_transition_queue: Deque[CallbackFunc]
_before_state_change: CallbackList
_after_state_change: CallbackList
_prepare_event: CallbackList
Expand Down Expand Up @@ -178,9 +179,9 @@ class Machine:
on_enter: CallbacksArg = ..., on_exit: CallbacksArg = ...,
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Dict[str, Any]) -> None: ...
def _add_model_to_state(self, state: State, model: object) -> None: ...
def _checked_assignment(self, model: object, name: str, func: Callable) -> None: ...
def _checked_assignment(self, model: object, name: str, func: CallbackFunc) -> None: ...
def _add_trigger_to_model(self, trigger: str, model: object) -> None: ...
def _get_trigger(self, model: object, trigger_name: str, *args: List, **kwargs: Dict[str, Any]) -> bool: ...
def _get_trigger(self, model: object, trigger_name: str, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def get_triggers(self, *args: Union[str, Enum, State]) -> List[str]: ...
def add_transition(self, trigger: str,
source: Union[StateIdentifier, List[StateIdentifier]],
Expand All @@ -201,13 +202,13 @@ class Machine:
def get_transitions(self, trigger: str = ...,
source: StateIdentifier = ..., dest: StateIdentifier = ...) -> List[Transition]: ...
def remove_transition(self, trigger: str, source: str = ..., dest: str = ...) -> None: ...
def dispatch(self, trigger: str, *args: List, **kwargs: Dict[str, Any]) -> bool: ...
def callbacks(self, funcs: Iterable[Union[str, Callable]], event_data: EventData) -> None: ...
def callback(self, func: Union[str, Callable], event_data: EventData) -> None: ...
def dispatch(self, trigger: str, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def callbacks(self, funcs: Iterable[Callback], event_data: EventData) -> None: ...
def callback(self, func: Callback, event_data: EventData) -> None: ...
@staticmethod
def resolve_callable(func: Union[str, Callable], event_data: EventData) -> Callable: ...
def resolve_callable(func: Callback, event_data: EventData) -> CallbackFunc: ...
def _has_state(self, state: StateIdentifier, raise_error: bool = ...) -> bool: ...
def _process(self, trigger: partial) -> bool: ...
def _process(self, trigger: Callable[[], bool]) -> bool: ...
def _identify_callback(self, name: str) -> Tuple[Optional[str], Optional[str]]: ...
def __getattr__(self, name: str) -> Any: ...

Expand Down
40 changes: 21 additions & 19 deletions transitions/extensions/asyncio.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..core import Condition, Event, EventData, Machine, State, Transition, StateConfig, ModelParameter, TransitionConfig
from ..core import Callback, Condition, Event, EventData, Machine, State, Transition, StateConfig, ModelParameter, TransitionConfig
from .nesting import HierarchicalMachine, NestedEvent, NestedState, NestedTransition
from typing import Any, Optional, List, Type, Dict, Deque, Callable, Union, Iterable, DefaultDict, Literal, Sequence
from typing import Any, Awaitable, Optional, List, Type, Dict, Deque, Callable, Union, Iterable, DefaultDict, Literal, Sequence
from asyncio import Task
from functools import partial
from logging import Logger
Expand All @@ -11,6 +11,9 @@ from ..core import StateIdentifier, CallbacksArg, CallbackList

_LOGGER: Logger

AsyncCallbackFunc = Callable[..., Awaitable[Optional[bool]]]
AsyncCallback = Union[str, AsyncCallbackFunc]

class AsyncState(State):
async def enter(self, event_data: AsyncEventData) -> None: ... # type: ignore[override]
async def exit(self, event_data: AsyncEventData) -> None: ... # type: ignore[override]
Expand Down Expand Up @@ -42,7 +45,7 @@ class AsyncEvent(Event):
machine: AsyncMachine
transitions: DefaultDict[str, List[AsyncTransition]] # type: ignore

async def trigger(self, model: object, *args: List, **kwargs: Dict[str, Any]) -> bool: ... # type: ignore[override]
async def trigger(self, model: object, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ... # type: ignore[override]
async def _trigger(self, event_data: AsyncEventData) -> bool: ... # type: ignore[override]
async def _process(self, event_data: AsyncEventData) -> bool: ... # type: ignore[override]

Expand All @@ -56,12 +59,12 @@ class AsyncMachine(Machine):
state_cls: Type[NestedAsyncState]
transition_cls: Type[AsyncTransition]
event_cls: Type[AsyncEvent]
async_tasks: Dict[int, List[Task]]
async_tasks: Dict[int, List[Task[Any]]]
events: Dict[str, AsyncEvent] # type: ignore
queued: Union[bool, Literal["model"]]
protected_tasks: List[Task]
current_context: ContextVar
_transition_queue_dict: Dict[int, Deque[Callable]]
protected_tasks: List[Task[Any]]
current_context: ContextVar[Optional[Task[Any]]]
_transition_queue_dict: Dict[int, Deque[AsyncCallbackFunc]]
def __init__(self, model: Optional[ModelParameter] = ...,
states: Optional[Union[Sequence[StateConfig], Type[Enum]]] = ...,
initial: Optional[StateIdentifier] = ...,
Expand All @@ -75,44 +78,43 @@ class AsyncMachine(Machine):
**kwargs: Dict[str, Any]) -> None: ...
def add_model(self, model: Union[Union[Literal["self"], object], Sequence[Union[Literal["self"], object]]],
initial: Optional[StateIdentifier] = ...) -> None: ...
async def dispatch(self, trigger: str, *args: List, **kwargs: Dict[str, Any]) -> bool: ... # type: ignore[override]
async def callbacks(self, funcs: Iterable[Union[str, Callable]], event_data: AsyncEventData) -> None: ... # type: ignore[override]
async def callback(self, func: Union[str, Callable], event_data: AsyncEventData) -> None: ... # type: ignore[override]
async def dispatch(self, trigger: str, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ... # type: ignore[override]
async def callbacks(self, funcs: Iterable[Callback], event_data: AsyncEventData) -> None: ... # type: ignore[override]
async def callback(self, func: AsyncCallback, event_data: AsyncEventData) -> None: ... # type: ignore[override]
@staticmethod
async def await_all(callables: List[Callable]) -> List: ...
async def await_all(callables: List[AsyncCallbackFunc]) -> Awaitable[List[Any]]: ...
async def switch_model_context(self, model: object) -> None: ...
def get_state(self, state: Union[str, Enum]) -> AsyncState: ...
async def process_context(self, func: partial, model: object) -> bool: ...
async def process_context(self, func: Callable[[], Awaitable[None]], model: object) -> bool: ...
def remove_model(self, model: object) -> None: ...
def _process(self, trigger: partial) -> bool: ...
async def _process_async(self, trigger: partial, model: object) -> bool: ...
async def _process_async(self, trigger: Callable[[], Awaitable[None]], model: object) -> bool: ...


class HierarchicalAsyncMachine(HierarchicalMachine, AsyncMachine): # type: ignore
state_cls: Type[NestedAsyncState]
transition_cls: Type[NestedAsyncTransition]
event_cls: Type[NestedAsyncEvent] # type: ignore
async def trigger_event(self, model: object, trigger: str, # type: ignore[override]
*args: List, **kwargs: Dict[str, Any]) -> bool: ...
*args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
async def _trigger_event(self, event_data: AsyncEventData, trigger: str) -> bool: ... # type: ignore[override]


class AsyncTimeout(AsyncState):
dynamic_methods: List[str]
timeout: float
_on_timeout: CallbacksArg
runner: Dict[int, Task]
def __init__(self, *args: List, **kwargs: Dict[str, Any]) -> None: ...
runner: Dict[int, Task[Any]]
def __init__(self, *args: List[Any], **kwargs: Dict[str, Any]) -> None: ...
async def enter(self, event_data: AsyncEventData) -> None: ... # type: ignore[override]
async def exit(self, event_data: AsyncEventData) -> None: ... # type: ignore[override]
def create_timer(self, event_data: AsyncEventData) -> Task: ...
def create_timer(self, event_data: AsyncEventData) -> Task[Any]: ...
async def _process_timeout(self, event_data: AsyncEventData) -> None: ...
@property
def on_timeout(self) -> CallbackList: ...
@on_timeout.setter
def on_timeout(self, value: CallbacksArg) -> None: ...

class _DictionaryMock(dict):
class _DictionaryMock(Dict[Any, Any]):
_value: Any
def __init__(self, item: Any) -> None: ...
def __setitem__(self, key: Any, item: Any) -> None: ...
Expand Down
2 changes: 1 addition & 1 deletion transitions/extensions/diagrams.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ GraphvizParameters = Dict[str, Union[str, Dict[str, Any]]]

class TransitionGraphSupport(Transition):
label: str
def __init__(self, *args: List, **kwargs: Dict[str, Any]) -> None: ...
def __init__(self, *args: List[Any], **kwargs: Dict[str, Any]) -> None: ...
def _change_state(self, event_data: EventData) -> None: ...


Expand Down
4 changes: 2 additions & 2 deletions transitions/extensions/diagrams_base.pyi
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import abc
from typing import Protocol, Optional, Union, List, Dict, IO, Tuple, Generator
from typing import BinaryIO, Protocol, Optional, Union, List, Dict, Tuple, Generator

from .diagrams import GraphMachine, HierarchicalGraphMachine
from ..core import ModelState


class GraphProtocol(Protocol):

def draw(self, filename: Optional[Union[str, IO]], format:Optional[str] = ...,
def draw(self, filename: Optional[Union[str, BinaryIO]], format:Optional[str] = ...,
prog: Optional[str] = ..., args:str = ...) -> Optional[str]: ...

class GraphModelProtocol(Protocol):
Expand Down
8 changes: 4 additions & 4 deletions transitions/extensions/diagrams_graphviz.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from ..core import State, ModelState
from .diagrams import GraphMachine
from .diagrams_base import BaseGraph
from logging import Logger
from typing import Type, Optional, Dict, List, Union, IO, DefaultDict, Any
from typing import BinaryIO, Type, Optional, Dict, List, Union, DefaultDict, Any
try:
from graphviz import Digraph
from graphviz.dot import SubgraphContext
Expand All @@ -16,7 +16,7 @@ except ImportError:
_LOGGER: Logger

class Graph(BaseGraph):
custom_styles: Dict[str, DefaultDict]
custom_styles: Dict[str, DefaultDict[str, Union[str, DefaultDict[str, str]]]]
def __init__(self, machine: Type[GraphMachine]) -> None: ...
def set_previous_transition(self, src: str, dst: str) -> None: ...
def set_node_style(self, state: ModelState, style: str) -> None: ...
Expand All @@ -28,12 +28,12 @@ class Graph(BaseGraph):
def generate(self) -> None: ...
def get_graph(self, title: Optional[str] = ..., # type: ignore[no-any-unimported]
roi_state: Optional[str] = ...) -> Digraph: ...
def draw(self, filename: Optional[Union[str, IO]], format:Optional[str] = ...,
def draw(self, filename: Optional[Union[str, BinaryIO]], format:Optional[str] = ...,
prog: Optional[str] = ..., args:str = ...) -> Optional[str]: ...

class NestedGraph(Graph):
_cluster_states: List[str]
def __init__(self, *args: List, **kwargs: Dict[str, Any]) -> None: ...
def __init__(self, *args: List[Any], **kwargs: Dict[str, Any]) -> None: ...
def set_previous_transition(self, src: str, dst: str) -> None: ...
def _add_nodes(self, states: List[Dict[str, str]], # type: ignore[no-any-unimported]
container: Union[Digraph, SubgraphContext]) -> None: ...
Expand Down
2 changes: 1 addition & 1 deletion transitions/extensions/diagrams_pygraphviz.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Graph(BaseGraph):

class NestedGraph(Graph):
seen_transitions: Any
def __init__(self, *args: List, **kwargs: Dict[str, Any]) -> None: ...
def __init__(self, *args: List[Any], **kwargs: Dict[str, Any]) -> None: ...
def _add_nodes(self, # type: ignore[override, no-any-unimported]
states: List[Dict[str, Union[str, List[Dict[str, str]]]]],
container: AGraph, prefix: str = ..., default_style: str = ...) -> None: ...
Expand Down
8 changes: 4 additions & 4 deletions transitions/extensions/factory.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from ..core import Machine, State
from ..core import CallbackFunc, Machine, State
from .diagrams import GraphMachine, NestedGraphTransition, HierarchicalGraphMachine
from .locking import LockedMachine
from .nesting import HierarchicalMachine, NestedEvent
from typing import Type, Dict, Tuple, Callable, Union
from typing import Any, Type, Dict, Tuple, Callable, Union

try:
from transitions.extensions.asyncio import AsyncMachine, AsyncTransition
Expand Down Expand Up @@ -39,13 +39,13 @@ class LockedHierarchicalMachine(LockedMachine, HierarchicalMachine): # type: ig

class LockedGraphMachine(GraphMachine, LockedMachine): # type: ignore
@staticmethod
def format_references(func: Callable) -> str: ...
def format_references(func: CallbackFunc) -> str: ...

class LockedHierarchicalGraphMachine(GraphMachine, LockedHierarchicalMachine): # type: ignore
transition_cls: Type[NestedGraphTransition]
event_cls: Type[NestedEvent]
@staticmethod
def format_references(func: Callable) -> str: ...
def format_references(func: CallbackFunc) -> str: ...

class AsyncGraphMachine(GraphMachine, AsyncMachine):
# AsyncTransition already considers graph models when necessary
Expand Down
23 changes: 12 additions & 11 deletions transitions/extensions/locking.pyi
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from contextlib import AbstractContextManager
from transitions.core import Event, Machine, ModelParameter, TransitionConfig, CallbacksArg, StateConfig
from typing import Any, Dict, ContextManager, Literal, Optional, Type, List, DefaultDict, Union, Callable, Sequence
from typing import Any, Dict, Literal, Optional, Type, List, DefaultDict, Union, Callable, Sequence
from types import TracebackType
from logging import Logger
from threading import Lock
from enum import Enum

from ..core import StateIdentifier, State

_LOGGER: Logger

from enum import Enum

LockContext = AbstractContextManager[None]

class PicklableLock(ContextManager):
class PicklableLock(LockContext):
lock: Lock
def __init__(self) -> None: ...
def __getstate__(self) -> Dict[str, Any]: ...
Expand All @@ -20,7 +21,7 @@ class PicklableLock(ContextManager):
def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType]) -> None: ...

class IdentManager(ContextManager):
class IdentManager(LockContext):
current: int
def __init__(self) -> None: ...
def __enter__(self) -> None: ...
Expand All @@ -29,14 +30,14 @@ class IdentManager(ContextManager):

class LockedEvent(Event):
machine: LockedMachine
def trigger(self, model: object, *args: List, **kwargs: Dict[str, Any]) -> bool: ...
def trigger(self, model: object, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...


class LockedMachine(Machine):
event_cls: Type[LockedEvent]
_ident: IdentManager
machine_context: List[ContextManager]
model_context_map: DefaultDict[int, List[ContextManager]]
machine_context: List[LockContext]
model_context_map: DefaultDict[int, List[LockContext]]
def __init__(self, model: Optional[ModelParameter] = ...,
states: Optional[Union[Sequence[StateConfig], Type[Enum]]] = ...,
initial: Optional[StateIdentifier] = ...,
Expand All @@ -47,17 +48,17 @@ class LockedMachine(Machine):
name: str = ..., queued: bool = ...,
prepare_event: CallbacksArg = ..., finalize_event: CallbacksArg = ...,
model_attribute: str = ..., on_exception: CallbacksArg = ...,
machine_context: Optional[Union[List[ContextManager], ContextManager]] = ...,
machine_context: Optional[Union[List[LockContext], LockContext]] = ...,
**kwargs: Dict[str, Any]) -> None: ...
def __getstate__(self) -> Dict[str, Any]: ...
def __setstate__(self, state: Dict[str, Any]) -> None: ...
def add_model(self, model: Union[Union[Literal['self'], object], List[Union[Literal['self'], object]]],
initial: Optional[StateIdentifier] = ...,
model_context: Optional[Union[ContextManager, List[ContextManager]]] = ...) -> None: ...
model_context: Optional[Union[LockContext, List[LockContext]]] = ...) -> None: ...
def remove_model(self, model: Union[Union[Literal['self'], object],
List[Union[Literal['self'], object]]]) -> None: ...
def __getattribute__(self, item: str) -> Any: ...
def __getattr__(self, item: str) -> Any: ...
def _add_model_to_state(self, state: State, model: object) -> None: ...
def _get_qualified_state_name(self, state: State) -> str: ...
def _locked_method(self, func: Callable, *args: List, **kwargs: Dict[str, Any]) -> Any: ...
def _locked_method(self, func: Callable[..., Any], *args: List[Any], **kwargs: Dict[str, Any]) -> Any: ...
10 changes: 5 additions & 5 deletions transitions/extensions/markup.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numbers

from ..core import Machine, StateIdentifier, CallbacksArg, StateConfig, Event, TransitionConfig, ModelParameter
from ..core import CallbackFunc, Machine, StateIdentifier, CallbacksArg, StateConfig, Event, TransitionConfig, ModelParameter
from .nesting import HierarchicalMachine
from typing import List, Dict, Union, Optional, Callable, Tuple, Any, Type, Sequence, TypedDict

Expand Down Expand Up @@ -46,7 +46,7 @@ class MarkupMachine(Machine):
on_enter: CallbacksArg = ..., on_exit: CallbacksArg = ...,
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Dict[str, Any]) -> None: ...
@staticmethod
def format_references(func: Callable) -> str: ...
def format_references(func: CallbackFunc) -> str: ...
def _convert_states_and_transitions(self, root: MarkupConfig) -> None: ...
def _convert_states(self, root: MarkupConfig) -> None: ...
def _convert_transitions(self, root: MarkupConfig) -> None: ...
Expand All @@ -61,7 +61,7 @@ class HierarchicalMarkupMachine(MarkupMachine, HierarchicalMachine): # type: ig
pass


def rep(func: Union[Callable, str, numbers.Number],
format_references: Optional[Callable] = ...) -> str: ...
def _convert(obj: object, attributes: List[str], format_references: Optional[Callable]) -> MarkupConfig: ...
def rep(func: Union[CallbackFunc, str, Enum],
format_references: Optional[Callable[[CallbackFunc], str]] = ...) -> str: ...
def _convert(obj: object, attributes: List[str], format_references: Optional[Callable[[CallbackFunc], str]]) -> MarkupConfig: ...

Loading

0 comments on commit cb913e0

Please sign in to comment.