diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2f26e9170..1ed9acb99 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -64,6 +64,7 @@ from reflex.utils.imports import ( parse_imports, ) from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize +from reflex.vars.object import LiteralObjectVar if TYPE_CHECKING: from reflex.state import BaseState @@ -579,10 +580,35 @@ class Var(Generic[VAR_TYPE]): var_type: type[list] | type[tuple] | type[set] = list, ) -> ArrayVar: ... + VAR_TYPE_ARG = TypeVar("VAR_TYPE_ARG") + @overload def to( - self, output: Type[ObjectVar], var_type: types.GenericType = dict - ) -> ObjectVar: ... + self, + output: Type[LiteralObjectVar[Any]], + var_type: None = None, + ) -> LiteralObjectVar[VAR_TYPE]: ... + + @overload + def to( + self, + output: Type[LiteralObjectVar[Any]], + var_type: Type[VAR_TYPE_ARG], + ) -> LiteralObjectVar[VAR_TYPE_ARG]: ... + + @overload + def to( + self, + output: Type[ObjectVar[Any]], + var_type: None = None, + ) -> ObjectVar[VAR_TYPE]: ... + + @overload + def to( + self, + output: Type[ObjectVar[Any]], + var_type: Type[VAR_TYPE_ARG], + ) -> ObjectVar[VAR_TYPE_ARG]: ... @overload def to( diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 56f3535d8..19b24887c 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -36,7 +36,7 @@ from .base import ( from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) +OBJECT_TYPE = TypeVar("OBJECT_TYPE") KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") diff --git a/tests/units/vars/test_object.py b/tests/units/vars/test_object.py index 497b5adb2..daa46e6ca 100644 --- a/tests/units/vars/test_object.py +++ b/tests/units/vars/test_object.py @@ -94,17 +94,17 @@ def test_typing() -> None: # Bare var = ObjectState.bare.to(ObjectVar) reveal_type(var) - assert_type(var, ObjectVar[Bare]) + _ = assert_type(var, ObjectVar[Bare]) var = ObjectState.base.to(ObjectVar, Base) reveal_type(var) - assert_type(var, ObjectVar[Base]) + _ = assert_type(var, ObjectVar[Base]) # Base var = ObjectState.base.to(ObjectVar) reveal_type(var) - assert_type(var, ObjectVar[Base]) + _ = assert_type(var, ObjectVar[Base]) var = ObjectState.base.to(LiteralObjectVar, Base) reveal_type(var) - assert_type(var, LiteralObjectVar[Base]) + _ = assert_type(var, LiteralObjectVar[Base])