diff --git a/src/snakia/decorators/__init__.py b/src/snakia/decorators/__init__.py index afce779..5d84da4 100644 --- a/src/snakia/decorators/__init__.py +++ b/src/snakia/decorators/__init__.py @@ -2,17 +2,21 @@ from .inject_after import after_hook, inject_after from .inject_before import before_hook, inject_before from .inject_const import inject_const from .inject_replace import inject_replace, replace_hook +from .meta_decorators import hook_decorator, inject_decorator, replace_decorator from .pass_exceptions import pass_exceptions from .singleton import singleton __all__ = [ - "inject_replace", - "replace_hook", - "inject_after", "after_hook", - "inject_before", "before_hook", + "inject_after", + "inject_before", "inject_const", + "inject_decorator", + "inject_replace", + "hook_decorator", "pass_exceptions", + "replace_decorator", + "replace_hook", "singleton", ] diff --git a/src/snakia/decorators/meta_decorators.py b/src/snakia/decorators/meta_decorators.py new file mode 100644 index 0000000..e684d71 --- /dev/null +++ b/src/snakia/decorators/meta_decorators.py @@ -0,0 +1,72 @@ +import functools +from typing import Callable, Concatenate, ParamSpec, TypeVar + +T = TypeVar("T") +R = TypeVar("R") +D = ParamSpec("D") +P = ParamSpec("P") + + +def inject_decorator( + decorator: Callable[Concatenate[Callable[P, T], D], None], +) -> Callable[D, Callable[[Callable[P, T]], Callable[P, T]]]: + + @functools.wraps(decorator) + def wrapper( + *d_args: D.args, **d_kwargs: D.kwargs + ) -> Callable[[Callable[P, T]], Callable[P, T]]: + def inner(obj: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(obj) + def func(*args: P.args, **kwargs: P.kwargs) -> T: + decorator(obj, *d_args, **d_kwargs) + return obj(*args, **kwargs) + + return func + + return inner + + return wrapper + + +def hook_decorator( + decorator: Callable[Concatenate[Callable[P, T], T, D], T], +) -> Callable[D, Callable[[Callable[P, T]], Callable[P, T]]]: + + @functools.wraps(decorator) + def wrapper( + *d_args: D.args, **d_kwargs: D.kwargs + ) -> Callable[[Callable[P, T]], Callable[P, T]]: + def inner(obj: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(obj) + def func(*args: P.args, **kwargs: P.kwargs) -> T: + val = obj(*args, **kwargs) + return decorator(obj, val, *d_args, **d_kwargs) + + return func + + return inner + + return wrapper + + +def replace_decorator( + decorator: Callable[Concatenate[T, D], T], +) -> Callable[D, Callable[[T], T]]: + @functools.wraps(decorator) + def wrapper(*d_args: D.args, **d_kwargs: D.kwargs) -> Callable[[T], T]: + def inner(obj: T) -> T: + result = decorator(obj, *d_args, **d_kwargs) + if not callable(obj): + return result + for attr in functools.WRAPPER_ASSIGNMENTS: + try: + value = getattr(obj, attr) + except AttributeError: + pass + else: + setattr(result, attr, value) + return result + + return inner + + return wrapper