diff --git a/reflex/utils/types.py b/reflex/utils/types.py index d58825ed5..0adbe6162 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -303,6 +303,30 @@ def get_property_hint(attr: Any | None) -> GenericType | None: return hints.get("return", None) +def can_access_properties(cls: GenericType) -> bool: + """Check if the class can access properties. + + Args: + cls: The class to check. + + Returns: + Whether the class can access properties. + """ + from reflex.model import Model + + return ( + hasattr(cls, "__fields__") + or (isinstance(cls, type) and issubclass(cls, DeclarativeBase)) + or ( + isinstance(cls, type) + and not is_generic_alias(cls) + and issubclass(cls, Model) + ) + or (is_union(cls) and all(can_access_properties(arg) for arg in get_args(cls))) + or (isinstance(cls, type) and dataclasses.is_dataclass(cls)) + ) + + def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None: """Check if an attribute can be accessed on the cls and return its type. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index b06e7b7c9..9a8204ccb 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -67,6 +67,7 @@ from reflex.utils.types import ( GenericType, Self, _isinstance, + can_access_properties, get_origin, has_args, unionize, @@ -630,13 +631,10 @@ class Var(Generic[VAR_TYPE]): return get_to_operation(NoneVar).create(self) # type: ignore # Handle fixed_output_type being Base or a dataclass. - try: - if issubclass(fixed_output_type, Base): - return self.to(ObjectVar, output) - except TypeError: - pass - if dataclasses.is_dataclass(fixed_output_type) and not issubclass( - fixed_output_type, Var + if ( + inspect.isclass(fixed_output_type) + and not issubclass(fixed_output_type, Var) + and can_access_properties(fixed_output_type) ): return self.to(ObjectVar, output) @@ -707,11 +705,7 @@ class Var(Generic[VAR_TYPE]): ): return self.to(NumberVar, self._var_type) - if all( - inspect.isclass(t) - and (issubclass(t, Base) or dataclasses.is_dataclass(t)) - for t in inner_types - ): + if all(can_access_properties(t) for t in inner_types): return self.to(ObjectVar, self._var_type) return self @@ -730,13 +724,9 @@ class Var(Generic[VAR_TYPE]): if issubclass(fixed_type, var_subclass.python_types): return self.to(var_subclass.var_subclass, self._var_type) - try: - if issubclass(fixed_type, Base): - return self.to(ObjectVar, self._var_type) - except TypeError: - pass - if dataclasses.is_dataclass(fixed_type): + if can_access_properties(fixed_type): return self.to(ObjectVar, self._var_type) + return self def get_default_value(self) -> Any: diff --git a/reflex/vars/number.py b/reflex/vars/number.py index e403e63e4..a762796e2 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -1116,7 +1116,9 @@ U = TypeVar("U") @var_operation -def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]): +def ternary_operation( + condition: BooleanVar, if_true: Var[T], if_false: Var[U] +) -> CustomVarOperationReturn[Union[T, U]]: """Create a ternary operation. Args: diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 56f3535d8..4d5827527 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -36,7 +36,7 @@ from .base import ( from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) +OBJECT_TYPE = TypeVar("OBJECT_TYPE") KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -59,7 +59,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def _value_type( - self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + self: ObjectVar[Dict[Any, VALUE_TYPE]], ) -> Type[VALUE_TYPE]: ... @overload @@ -87,7 +87,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def values( - self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + self: ObjectVar[Dict[Any, VALUE_TYPE]], ) -> ArrayVar[List[VALUE_TYPE]]: ... @overload @@ -103,7 +103,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def entries( - self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + self: ObjectVar[Dict[Any, VALUE_TYPE]], ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... @overload @@ -133,47 +133,47 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, NoReturn]], + self: ObjectVar[Dict[Any, NoReturn]], key: Var | Any, ) -> Var: ... @overload def __getitem__( self: ( - ObjectVar[Dict[KEY_TYPE, int]] - | ObjectVar[Dict[KEY_TYPE, float]] - | ObjectVar[Dict[KEY_TYPE, int | float]] + ObjectVar[Dict[Any, int]] + | ObjectVar[Dict[Any, float]] + | ObjectVar[Dict[Any, int | float]] ), key: Var | Any, ) -> NumberVar: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, str]], + self: ObjectVar[Dict[Any, str]], key: Var | Any, ) -> StringVar: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], key: Var | Any, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getitem__( - self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], key: Var | Any, ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... @@ -195,47 +195,47 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): # NoReturn is used here to catch when key value is Any @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, NoReturn]], + self: ObjectVar[Dict[Any, NoReturn]], name: str, ) -> Var: ... @overload def __getattr__( self: ( - ObjectVar[Dict[KEY_TYPE, int]] - | ObjectVar[Dict[KEY_TYPE, float]] - | ObjectVar[Dict[KEY_TYPE, int | float]] + ObjectVar[Dict[Any, int]] + | ObjectVar[Dict[Any, float]] + | ObjectVar[Dict[Any, int | float]] ), name: str, ) -> NumberVar: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, str]], + self: ObjectVar[Dict[Any, str]], name: str, ) -> StringVar: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], name: str, ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]], + self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], name: str, ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... @overload def __getattr__( - self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], name: str, ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... @@ -377,8 +377,8 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): @classmethod def create( cls, - _var_value: OBJECT_TYPE, - _var_type: GenericType | None = None, + _var_value: dict, + _var_type: Type[OBJECT_TYPE] | None = None, _var_data: VarData | None = None, ) -> LiteralObjectVar[OBJECT_TYPE]: """Create the literal object var. diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 39139ce3f..08429883f 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -853,31 +853,31 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): @overload def __getitem__( self: ( - ArrayVar[Tuple[OTHER_TUPLE, int]] - | ArrayVar[Tuple[OTHER_TUPLE, float]] - | ArrayVar[Tuple[OTHER_TUPLE, int | float]] + ArrayVar[Tuple[Any, int]] + | ArrayVar[Tuple[Any, float]] + | ArrayVar[Tuple[Any, int | float]] ), i: Literal[1, -1], ) -> NumberVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2] + self: ArrayVar[Tuple[str, Any]], i: Literal[0, -2] ) -> StringVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1] + self: ArrayVar[Tuple[Any, str]], i: Literal[1, -1] ) -> StringVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2] + self: ArrayVar[Tuple[bool, Any]], i: Literal[0, -2] ) -> BooleanVar: ... @overload def __getitem__( - self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1] + self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] ) -> BooleanVar: ... @overload diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 271f2e794..88dc6d39f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -45,7 +45,7 @@ from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.exceptions import SetUndefinedStateVarError from reflex.utils.format import json_dumps -from reflex.vars.base import ComputedVar, Var +from reflex.vars.base import Var, computed_var from tests.units.states.mutation import MutableSQLAModel, MutableTestState from .states import GenState @@ -109,7 +109,7 @@ class TestState(BaseState): _backend: int = 0 asynctest: int = 0 - @ComputedVar + @computed_var def sum(self) -> float: """Dynamically sum the numbers. @@ -118,7 +118,7 @@ class TestState(BaseState): """ return self.num1 + self.num2 - @ComputedVar + @computed_var def upper(self) -> str: """Uppercase the key. @@ -1124,7 +1124,7 @@ def test_child_state(): v: int = 2 class ChildState(MainState): - @ComputedVar + @computed_var def rendered_var(self): return self.v @@ -1143,7 +1143,7 @@ def test_conditional_computed_vars(): t1: str = "a" t2: str = "b" - @ComputedVar + @computed_var def rendered_var(self) -> str: if self.flag: return self.t1 @@ -3095,12 +3095,12 @@ def test_potentially_dirty_substates(): """ class State(RxState): - @ComputedVar + @computed_var def foo(self) -> str: return "" class C1(State): - @ComputedVar + @computed_var def bar(self) -> str: return ""