diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 614257181..d3dbb1d4c 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -78,7 +78,7 @@ def serializer( ) # Apply type transformation if requested - if to is not None: + if to is not None or ((to := type_hints.get("return")) is not None): SERIALIZER_TYPES[type_] = to get_serializer_type.cache_clear() @@ -189,16 +189,37 @@ def get_serializer_type(type_: Type) -> Optional[Type]: return None -def has_serializer(type_: Type) -> bool: +def has_serializer(type_: Type, into_type: Type | None = None) -> bool: """Check if there is a serializer for the type. Args: type_: The type to check. + into_type: The type to serialize into. Returns: Whether there is a serializer for the type. """ - return get_serializer(type_) is not None + serializer_for_type = get_serializer(type_) + return serializer_for_type is not None and ( + into_type is None or get_serializer_type(type_) == into_type + ) + + +def can_serialize(type_: Type, into_type: Type | None = None) -> bool: + """Check if there is a serializer for the type. + + Args: + type_: The type to check. + into_type: The type to serialize into. + + Returns: + Whether there is a serializer for the type. + """ + return has_serializer(type_, into_type) or ( + isinstance(type_, type) + and dataclasses.is_dataclass(type_) + and (into_type is None or into_type is dict) + ) @serializer(to=str) @@ -214,7 +235,7 @@ def serialize_type(value: type) -> str: return value.__name__ -@serializer +@serializer(to=dict) def serialize_base(value: Base) -> dict: """Serialize a Base instance. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index b06e7b7c9..0e6bbaec7 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -75,7 +75,6 @@ from reflex.utils.types import ( if TYPE_CHECKING: from reflex.state import BaseState - from .function import FunctionVar from .number import ( BooleanVar, NumberVar, @@ -279,6 +278,24 @@ def _decode_var_immutable(value: str) -> tuple[VarData | None, str]: return VarData.merge(*var_datas) if var_datas else None, value +def can_use_in_object_var(cls: GenericType) -> bool: + """Check if the class can be used in an ObjectVar. + + Args: + cls: The class to check. + + Returns: + Whether the class can be used in an ObjectVar. + """ + if types.is_union(cls): + return all(can_use_in_object_var(t) for t in types.get_args(cls)) + return ( + inspect.isclass(cls) + and not issubclass(cls, Var) + and serializers.can_serialize(cls, dict) + ) + + @dataclasses.dataclass( eq=False, frozen=True, @@ -565,36 +582,33 @@ class Var(Generic[VAR_TYPE]): # Encode the _var_data into the formatted output for tracking purposes. return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}" - @overload - def to(self, output: Type[StringVar]) -> StringVar: ... - @overload def to(self, output: Type[str]) -> StringVar: ... @overload - def to(self, output: Type[BooleanVar]) -> BooleanVar: ... + def to(self, output: Type[bool]) -> BooleanVar: ... @overload - def to( - self, output: Type[NumberVar], var_type: type[int] | type[float] = float - ) -> NumberVar: ... + def to(self, output: type[int] | type[float]) -> NumberVar: ... @overload def to( self, - output: Type[ArrayVar], - var_type: type[list] | type[tuple] | type[set] = list, + output: type[list] | type[tuple] | type[set], ) -> ArrayVar: ... @overload def to( - self, output: Type[ObjectVar], var_type: types.GenericType = dict - ) -> ObjectVar: ... + self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE] + ) -> ObjectVar[VAR_INSIDE]: ... @overload def to( - self, output: Type[FunctionVar], var_type: Type[Callable] = Callable - ) -> FunctionVar: ... + self, output: Type[ObjectVar], var_type: None = None + ) -> ObjectVar[VAR_TYPE]: ... + + @overload + def to(self, output: VAR_SUBCLASS, var_type: None = None) -> VAR_SUBCLASS: ... @overload def to( @@ -630,21 +644,19 @@ class Var(Generic[VAR_TYPE]): return get_to_operation(NoneVar).create(self) # type: ignore # Handle fixed_output_type being Base or a dataclass. - try: - if issubclass(fixed_output_type, Base): - return self.to(ObjectVar, output) - except TypeError: - pass - if dataclasses.is_dataclass(fixed_output_type) and not issubclass( - fixed_output_type, Var - ): + if can_use_in_object_var(fixed_output_type): return self.to(ObjectVar, output) if inspect.isclass(output): for var_subclass in _var_subclasses[::-1]: if issubclass(output, var_subclass.var_subclass): + current_var_type = self._var_type + if current_var_type is Any: + new_var_type = var_type + else: + new_var_type = var_type or current_var_type to_operation_return = var_subclass.to_var_subclass.create( - value=self, _var_type=var_type + value=self, _var_type=new_var_type ) return to_operation_return # type: ignore @@ -707,11 +719,7 @@ class Var(Generic[VAR_TYPE]): ): return self.to(NumberVar, self._var_type) - if all( - inspect.isclass(t) - and (issubclass(t, Base) or dataclasses.is_dataclass(t)) - for t in inner_types - ): + if can_use_in_object_var(var_type): return self.to(ObjectVar, self._var_type) return self @@ -730,13 +738,9 @@ class Var(Generic[VAR_TYPE]): if issubclass(fixed_type, var_subclass.python_types): return self.to(var_subclass.var_subclass, self._var_type) - try: - if issubclass(fixed_type, Base): - return self.to(ObjectVar, self._var_type) - except TypeError: - pass - if dataclasses.is_dataclass(fixed_type): + if can_use_in_object_var(fixed_type): return self.to(ObjectVar, self._var_type) + return self def get_default_value(self) -> Any: @@ -1181,6 +1185,9 @@ class Var(Generic[VAR_TYPE]): OUTPUT = TypeVar("OUTPUT", bound=Var) +VAR_SUBCLASS = TypeVar("VAR_SUBCLASS", bound=Var) +VAR_INSIDE = TypeVar("VAR_INSIDE") + class ToOperation: """A var operation that converts a var to another type.""" @@ -2888,6 +2895,8 @@ def dispatch( V = TypeVar("V") +BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) + class Field(Generic[T]): """Shadow class for Var to allow for type hinting in the IDE.""" @@ -2924,6 +2933,11 @@ class Field(Generic[T]): self: Field[Dict[str, V]], instance: None, owner ) -> ObjectVar[Dict[str, V]]: ... + @overload + def __get__( + self: Field[BASE_TYPE], instance: None, owner + ) -> ObjectVar[BASE_TYPE]: ... + @overload def __get__(self, instance: None, owner) -> Var[T]: ... diff --git a/reflex/vars/number.py b/reflex/vars/number.py index e403e63e4..a762796e2 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -1116,7 +1116,9 @@ U = TypeVar("U") @var_operation -def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]): +def ternary_operation( + condition: BooleanVar, if_true: Var[T], if_false: Var[U] +) -> CustomVarOperationReturn[Union[T, U]]: """Create a ternary operation. Args: diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 56f3535d8..e60ea09e3 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -36,7 +36,7 @@ from .base import ( from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) +OBJECT_TYPE = TypeVar("OBJECT_TYPE") KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -59,7 +59,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def _value_type( - self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + self: ObjectVar[Dict[Any, VALUE_TYPE]], ) -> Type[VALUE_TYPE]: ... @overload @@ -87,7 +87,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def values( - self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + self: ObjectVar[Dict[Any, VALUE_TYPE]], ) -> ArrayVar[List[VALUE_TYPE]]: ... @overload @@ -103,7 +103,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def entries( - self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + self: ObjectVar[Dict[Any, VALUE_TYPE]], ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... @overload @@ -133,47 +133,47 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, NoReturn]], + self: ObjectVar[Dict[Any, NoReturn]], key: Var | Any, ) -> Var: ... @overload def __getitem__( self: ( - ObjectVar[Dict[KEY_TYPE, int]] - | ObjectVar[Dict[KEY_TYPE, float]] - | ObjectVar[Dict[KEY_TYPE, int | float]] + ObjectVar[Dict[Any, int]] + | ObjectVar[Dict[Any, float]] + | ObjectVar[Dict[Any, int | float]] ), key: Var | Any, ) -> NumberVar: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, str]], + self: ObjectVar[Dict[Any, str]], key: Var | Any, ) -> StringVar: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], key: Var | Any, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], key: Var | Any, ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... @@ -195,50 +195,56 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, NoReturn]], + self: ObjectVar[Dict[Any, NoReturn]], name: str, ) -> Var: ... @overload def __getattr__( self: ( - ObjectVar[Dict[KEY_TYPE, int]] - | ObjectVar[Dict[KEY_TYPE, float]] - | ObjectVar[Dict[KEY_TYPE, int | float]] + ObjectVar[Dict[Any, int]] + | ObjectVar[Dict[Any, float]] + | ObjectVar[Dict[Any, int | float]] ), name: str, ) -> NumberVar: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, str]], + self: ObjectVar[Dict[Any, str]], name: str, ) -> StringVar: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], name: str, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], name: str, ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + @overload + def __getattr__( + self: ObjectVar, + name: str, + ) -> ObjectItemOperation: ... + def __getattr__(self, name) -> Var: """Get an attribute of the var. @@ -377,8 +383,8 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): @classmethod def create( cls, - _var_value: OBJECT_TYPE, - _var_type: GenericType | None = None, + _var_value: dict, + _var_type: Type[OBJECT_TYPE] | None = None, _var_data: VarData | None = None, ) -> LiteralObjectVar[OBJECT_TYPE]: """Create the literal object var. diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 39139ce3f..08429883f 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -853,31 +853,31 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): @overload def __getitem__( self: ( - ArrayVar[Tuple[OTHER_TUPLE, int]] - | ArrayVar[Tuple[OTHER_TUPLE, float]] - | ArrayVar[Tuple[OTHER_TUPLE, int | float]] + ArrayVar[Tuple[Any, int]] + | ArrayVar[Tuple[Any, float]] + | ArrayVar[Tuple[Any, int | float]] ), i: Literal[1, -1], ) -> NumberVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2] + self: ArrayVar[Tuple[str, Any]], i: Literal[0, -2] ) -> StringVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1] + self: ArrayVar[Tuple[Any, str]], i: Literal[1, -1] ) -> StringVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2] + self: ArrayVar[Tuple[bool, Any]], i: Literal[0, -2] ) -> BooleanVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1] + self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] ) -> BooleanVar: ... @overload diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 3ff8d453c..83e348cd2 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -45,7 +45,7 @@ from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.exceptions import SetUndefinedStateVarError from reflex.utils.format import json_dumps -from reflex.vars.base import ComputedVar, Var +from reflex.vars.base import Var, computed_var from tests.units.states.mutation import MutableSQLAModel, MutableTestState from .states import GenState @@ -109,7 +109,7 @@ class TestState(BaseState): _backend: int = 0 asynctest: int = 0 - @ComputedVar + @computed_var def sum(self) -> float: """Dynamically sum the numbers. @@ -118,7 +118,7 @@ class TestState(BaseState): """ return self.num1 + self.num2 - @ComputedVar + @computed_var def upper(self) -> str: """Uppercase the key. @@ -1124,7 +1124,7 @@ def test_child_state(): v: int = 2 class ChildState(MainState): - @ComputedVar + @computed_var def rendered_var(self): return self.v @@ -1143,7 +1143,7 @@ def test_conditional_computed_vars(): t1: str = "a" t2: str = "b" - @ComputedVar + @computed_var def rendered_var(self) -> str: if self.flag: return self.t1 @@ -3095,12 +3095,12 @@ def test_potentially_dirty_substates(): """ class State(RxState): - @ComputedVar + @computed_var def foo(self) -> str: return "" class C1(State): - @ComputedVar + @computed_var def bar(self) -> str: return "" diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py new file mode 100644 index 000000000..efcb21166 --- /dev/null +++ b/tests/units/vars/test_object.py @@ -0,0 +1,102 @@ +import pytest +from typing_extensions import assert_type + +import reflex as rx +from reflex.utils.types import GenericType +from reflex.vars.base import Var +from reflex.vars.object import LiteralObjectVar, ObjectVar + + +class Bare: + """A bare class with a single attribute.""" + + quantity: int = 0 + + +@rx.serializer +def serialize_bare(obj: Bare) -> dict: + """A serializer for the bare class. + + Args: + obj: The object to serialize. + + Returns: + A dictionary with the quantity attribute. + """ + return {"quantity": obj.quantity} + + +class Base(rx.Base): + """A reflex base class with a single attribute.""" + + quantity: int = 0 + + +class ObjectState(rx.State): + """A reflex state with bare and base objects.""" + + bare: rx.Field[Bare] = rx.field(Bare()) + base: rx.Field[Base] = rx.field(Base()) + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_var_create(type_: GenericType) -> None: + my_object = type_() + var = Var.create(my_object) + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_literal_create(type_: GenericType) -> None: + my_object = type_() + var = LiteralObjectVar.create(my_object) + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_guess(type_: GenericType) -> None: + my_object = type_() + var = Var.create(my_object) + var = var.guess_type() + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_state(type_: GenericType) -> None: + attr_name = type_.__name__.lower() + var = getattr(ObjectState, attr_name) + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_state_to_operation(type_: GenericType) -> None: + attr_name = type_.__name__.lower() + original_var = getattr(ObjectState, attr_name) + + var = original_var.to(ObjectVar, type_) + assert var._var_type is type_ + + var = original_var.to(ObjectVar) + assert var._var_type is type_ + + +def test_typing() -> None: + # Bare + var = ObjectState.bare.to(ObjectVar) + _ = assert_type(var, ObjectVar[Bare]) + + # Base + var = ObjectState.base + _ = assert_type(var, ObjectVar[Base])