diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 9167656f3..d3dbb1d4c 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -78,7 +78,7 @@ def serializer( ) # Apply type transformation if requested - if to is not None: + if to is not None or ((to := type_hints.get("return")) is not None): SERIALIZER_TYPES[type_] = to get_serializer_type.cache_clear() diff --git a/reflex/vars/base.py b/reflex/vars/base.py index d0e71b871..303cabcc2 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -654,8 +654,13 @@ class Var(Generic[VAR_TYPE]): if inspect.isclass(output): for var_subclass in _var_subclasses[::-1]: if issubclass(output, var_subclass.var_subclass): + current_var_type = self._var_type + if current_var_type is Any: + new_var_type = var_type + else: + new_var_type = var_type or current_var_type to_operation_return = var_subclass.to_var_subclass.create( - value=self, _var_type=var_type + value=self, _var_type=new_var_type ) return to_operation_return # type: ignore diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py new file mode 100644 index 000000000..d762b208e --- /dev/null +++ b/tests/units/vars/test_object.py @@ -0,0 +1,113 @@ +from typing import assert_type, reveal_type + +import pytest + +import reflex as rx +from reflex.utils.types import GenericType +from reflex.vars.base import Var +from reflex.vars.object import LiteralObjectVar, ObjectVar + + +class Bare: + """A bare class with a single attribute.""" + + quantity: int = 0 + + +@rx.serializer +def serialize_bare(obj: Bare) -> dict: + """A serializer for the bare class. + + Args: + obj: The object to serialize. + + Returns: + A dictionary with the quantity attribute. + """ + return {"quantity": obj.quantity} + + +class Base(rx.Base): + """A reflex base class with a single attribute.""" + + quantity: int = 0 + + +class ObjectState(rx.State): + """A reflex state with bare and base objects.""" + + bare: rx.Field[Bare] = rx.field(Bare()) + base: rx.Field[Base] = rx.field(Base()) + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_var_create(type_: GenericType) -> None: + my_object = type_() + var = Var.create(my_object) + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_literal_create(type_: GenericType) -> None: + my_object = type_() + var = LiteralObjectVar.create(my_object) + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_guess(type_: GenericType) -> None: + my_object = type_() + var = Var.create(my_object) + var = var.guess_type() + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_state(type_: GenericType) -> None: + attr_name = type_.__name__.lower() + var = getattr(ObjectState, attr_name) + assert var._var_type is type_ + + quantity = var.quantity + assert quantity._var_type is int + + +@pytest.mark.parametrize("type_", [Base, Bare]) +def test_state_to_operation(type_: GenericType) -> None: + attr_name = type_.__name__.lower() + original_var = getattr(ObjectState, attr_name) + + var = original_var.to(ObjectVar, type_) + assert var._var_type is type_ + + var = original_var.to(ObjectVar) + assert var._var_type is type_ + + +def test_typing() -> None: + # Bare + var = ObjectState.bare.to(ObjectVar) + reveal_type(var) + _ = assert_type(var, ObjectVar[Bare]) + + var = ObjectState.base.to(ObjectVar, Base) + reveal_type(var) + _ = assert_type(var, ObjectVar[Base]) + + # Base + var = ObjectState.base.to(ObjectVar) + reveal_type(var) + _ = assert_type(var, ObjectVar[Base]) + + var = ObjectState.base.to(LiteralObjectVar, Base) + reveal_type(var) + _ = assert_type(var, LiteralObjectVar[Base])