From 1e82a457ace343c15b481d08bf3f1d30d34f80de Mon Sep 17 00:00:00 2001 From: rus07tam Date: Wed, 26 Nov 2025 13:26:21 +0000 Subject: [PATCH] feat: add default_factory in PrivProperty --- src/snakia/field/field.py | 21 ++++++++++++------ src/snakia/property/priv_property.py | 33 +++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/snakia/field/field.py b/src/snakia/field/field.py index 628b7c7..57e9de6 100644 --- a/src/snakia/field/field.py +++ b/src/snakia/field/field.py @@ -1,7 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, final +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + TypeVar, + final, +) from snakia.property.priv_property import PrivProperty from snakia.utils import inherit @@ -11,10 +18,6 @@ R = TypeVar("R") class Field(ABC, PrivProperty[T], Generic[T]): - def __init__(self, default_value: T) -> None: - self.default_value: Final[T] = default_value - super().__init__(default_value) - @abstractmethod def serialize(self, value: T, /) -> bytes: """Serialize a value @@ -42,14 +45,18 @@ class Field(ABC, PrivProperty[T], Generic[T]): serialize: Callable[[Field[R], R], bytes], deserialize: Callable[[Field[R], bytes], R], ) -> type[Field[R]]: - return inherit(cls, {"serialize": serialize, "deserialize": deserialize}) + return inherit( + cls, {"serialize": serialize, "deserialize": deserialize} + ) @final @staticmethod def get_fields(class_: type[Any] | Any, /) -> dict[str, Field[Any]]: if not isinstance(class_, type): class_ = class_.__class__ - return {k: v for k, v in class_.__dict__.items() if isinstance(v, Field)} + return { + k: v for k, v in class_.__dict__.items() if isinstance(v, Field) + } if TYPE_CHECKING: diff --git a/src/snakia/property/priv_property.py b/src/snakia/property/priv_property.py index 1a85e09..705a2e6 100644 --- a/src/snakia/property/priv_property.py +++ b/src/snakia/property/priv_property.py @@ -1,22 +1,43 @@ -from typing import Any, Generic, TypeVar +from typing import Any, Callable, Final, Generic, TypeVar, overload +from typing_extensions import Self T = TypeVar("T") class PrivProperty(Generic[T]): - __slots__ = "__name", "__default_value" + __slots__ = "__name", "__default_value", "__default_factory" __name: str - def __init__(self, default_value: T | None = None) -> None: - self.__default_value: T | None = default_value + @overload + def __init__(self) -> None: ... + @overload + def __init__(self, default_value: T) -> None: ... + @overload + def __init__(self, *, default_factory: Callable[[Self], T]) -> None: ... + def __init__( + self, + default_value: T | None = None, + default_factory: Callable[[Self], T] | None = None, + ) -> None: + self.__default_value: Final[T | None] = default_value + self.__default_factory: Final[Callable[[Self], T] | None] = ( + default_factory + ) + + def _get_default(self: Self) -> T: + if self.__default_value is not None: + return self.__default_value + if self.__default_factory is not None: + return self.__default_factory(self) + raise ValueError("Either default_value or default_factory must be set") def __set_name__(self, owner: type, name: str) -> None: self.__name = f"_{owner.__name__}__{name}" def __get__(self, instance: Any, owner: type | None = None, /) -> T: - if self.__default_value: - return getattr(instance, self.__name, self.__default_value) + if not hasattr(instance, self.__name): + setattr(instance, self.__name, self._get_default()) return getattr(instance, self.__name) # type: ignore def __set__(self, instance: Any, value: T, /) -> None: