diff --git a/src/snakia/core/rx/__init__.py b/src/snakia/core/rx/__init__.py index b4bf207..3c6bf2d 100644 --- a/src/snakia/core/rx/__init__.py +++ b/src/snakia/core/rx/__init__.py @@ -1,7 +1,7 @@ from .async_bindable import AsyncBindable from .base_bindable import BaseBindable, BindableSubscriber, ValueChanged from .bindable import Bindable -from .chains import chain +from .chains import async_chain, chain from .combines import async_combine, combine from .concats import concat from .conds import cond @@ -16,6 +16,7 @@ __all__ = [ "BaseBindable", "BindableSubscriber", "ValueChanged", + "async_chain", "async_combine", "async_merge", "chain", diff --git a/src/snakia/core/rx/chains.py b/src/snakia/core/rx/chains.py index 0672e96..5181d4d 100644 --- a/src/snakia/core/rx/chains.py +++ b/src/snakia/core/rx/chains.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, ParamSpec, TypeVar, overload +from typing import Any, Awaitable, Callable, ParamSpec, TypeVar, overload P = ParamSpec("P") @@ -14,7 +14,9 @@ def chain(func1: Callable[P, A], /) -> Callable[P, A]: ... @overload -def chain(func1: Callable[P, A], func2: Callable[[A], B], /) -> Callable[P, B]: ... +def chain( + func1: Callable[P, A], func2: Callable[[A], B], / +) -> Callable[P, B]: ... @overload def chain( func1: Callable[P, A], func2: Callable[[A], B], func3: Callable[[B], C], / @@ -46,7 +48,9 @@ def chain( ) -> Callable[P, Any]: ... -def chain(func1: Callable[P, Any], /, *funcs: Callable[[Any], Any]) -> Callable[P, Any]: +def chain( + func1: Callable[P, Any], /, *funcs: Callable[[Any], Any] +) -> Callable[P, Any]: def inner(*args: P.args, **kwargs: P.kwargs) -> Any: v = func1(*args, **kwargs) @@ -55,3 +59,66 @@ def chain(func1: Callable[P, Any], /, *funcs: Callable[[Any], Any]) -> Callable[ return v return inner + + +@overload +def async_chain( + func1: Callable[P, Awaitable[A]], / +) -> Callable[P, Awaitable[A]]: ... + + +@overload +def async_chain( + func1: Callable[P, Awaitable[A]], func2: Callable[[A], Awaitable[B]], / +) -> Callable[P, Awaitable[B]]: ... + + +@overload +def async_chain( + func1: Callable[P, Awaitable[A]], + func2: Callable[[A], Awaitable[B]], + func3: Callable[[B], Awaitable[C]], + /, +) -> Callable[P, Awaitable[C]]: ... + + +@overload +def async_chain( + func1: Callable[P, Awaitable[A]], + func2: Callable[[A], Awaitable[B]], + func3: Callable[[B], Awaitable[C]], + func4: Callable[[C], Awaitable[D]], + /, +) -> Callable[P, Awaitable[D]]: ... + + +@overload +def async_chain( + func1: Callable[P, Awaitable[A]], + func2: Callable[[A], Awaitable[B]], + func3: Callable[[B], Awaitable[C]], + func4: Callable[[C], Awaitable[D]], + func5: Callable[[D], Awaitable[E]], + /, +) -> Callable[P, Awaitable[E]]: ... + + +@overload +def async_chain( + func1: Callable[P, Any], /, *funcs: Callable[[Any], Awaitable[Any]] +) -> Callable[P, Awaitable[Any]]: ... + + +def async_chain( + func1: Callable[P, Awaitable[Any]], + /, + *funcs: Callable[[Any], Awaitable[Any]], +) -> Callable[P, Awaitable[Any]]: + + async def inner(*args: P.args, **kwargs: P.kwargs) -> Any: + v = await func1(*args, **kwargs) + for f in funcs: + v = await f(v) + return v + + return inner