improve object var symantics

This commit is contained in:
Khaleel Al-Adhami 2024-11-01 16:42:22 -07:00
parent b70f33d972
commit 7794e065a8
6 changed files with 73 additions and 57 deletions

View File

@ -303,6 +303,30 @@ def get_property_hint(attr: Any | None) -> GenericType | None:
return hints.get("return", None)
def can_access_properties(cls: GenericType) -> bool:
"""Check if the class can access properties.
Args:
cls: The class to check.
Returns:
Whether the class can access properties.
"""
from reflex.model import Model
return (
hasattr(cls, "__fields__")
or (isinstance(cls, type) and issubclass(cls, DeclarativeBase))
or (
isinstance(cls, type)
and not is_generic_alias(cls)
and issubclass(cls, Model)
)
or (is_union(cls) and all(can_access_properties(arg) for arg in get_args(cls)))
or (isinstance(cls, type) and dataclasses.is_dataclass(cls))
)
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
"""Check if an attribute can be accessed on the cls and return its type.

View File

@ -67,6 +67,7 @@ from reflex.utils.types import (
GenericType,
Self,
_isinstance,
can_access_properties,
get_origin,
has_args,
unionize,
@ -630,13 +631,10 @@ class Var(Generic[VAR_TYPE]):
return get_to_operation(NoneVar).create(self) # type: ignore
# Handle fixed_output_type being Base or a dataclass.
try:
if issubclass(fixed_output_type, Base):
return self.to(ObjectVar, output)
except TypeError:
pass
if dataclasses.is_dataclass(fixed_output_type) and not issubclass(
fixed_output_type, Var
if (
inspect.isclass(fixed_output_type)
and not issubclass(fixed_output_type, Var)
and can_access_properties(fixed_output_type)
):
return self.to(ObjectVar, output)
@ -707,11 +705,7 @@ class Var(Generic[VAR_TYPE]):
):
return self.to(NumberVar, self._var_type)
if all(
inspect.isclass(t)
and (issubclass(t, Base) or dataclasses.is_dataclass(t))
for t in inner_types
):
if all(can_access_properties(t) for t in inner_types):
return self.to(ObjectVar, self._var_type)
return self
@ -730,13 +724,9 @@ class Var(Generic[VAR_TYPE]):
if issubclass(fixed_type, var_subclass.python_types):
return self.to(var_subclass.var_subclass, self._var_type)
try:
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
except TypeError:
pass
if dataclasses.is_dataclass(fixed_type):
if can_access_properties(fixed_type):
return self.to(ObjectVar, self._var_type)
return self
def get_default_value(self) -> Any:

View File

@ -1116,7 +1116,9 @@ U = TypeVar("U")
@var_operation
def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]):
def ternary_operation(
condition: BooleanVar, if_true: Var[T], if_false: Var[U]
) -> CustomVarOperationReturn[Union[T, U]]:
"""Create a ternary operation.
Args:

View File

@ -36,7 +36,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
OBJECT_TYPE = TypeVar("OBJECT_TYPE")
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -59,7 +59,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def _value_type(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
self: ObjectVar[Dict[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ...
@overload
@ -87,7 +87,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def values(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
self: ObjectVar[Dict[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload
@ -103,7 +103,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def entries(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
self: ObjectVar[Dict[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
@overload
@ -133,47 +133,47 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
self: ObjectVar[Dict[Any, NoReturn]],
key: Var | Any,
) -> Var: ...
@overload
def __getitem__(
self: (
ObjectVar[Dict[KEY_TYPE, int]]
| ObjectVar[Dict[KEY_TYPE, float]]
| ObjectVar[Dict[KEY_TYPE, int | float]]
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
),
key: Var | Any,
) -> NumberVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, str]],
self: ObjectVar[Dict[Any, str]],
key: Var | Any,
) -> StringVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@ -195,47 +195,47 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
self: ObjectVar[Dict[Any, NoReturn]],
name: str,
) -> Var: ...
@overload
def __getattr__(
self: (
ObjectVar[Dict[KEY_TYPE, int]]
| ObjectVar[Dict[KEY_TYPE, float]]
| ObjectVar[Dict[KEY_TYPE, int | float]]
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
),
name: str,
) -> NumberVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, str]],
self: ObjectVar[Dict[Any, str]],
name: str,
) -> StringVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@ -377,8 +377,8 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@classmethod
def create(
cls,
_var_value: OBJECT_TYPE,
_var_type: GenericType | None = None,
_var_value: dict,
_var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None,
) -> LiteralObjectVar[OBJECT_TYPE]:
"""Create the literal object var.

View File

@ -853,31 +853,31 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
@overload
def __getitem__(
self: (
ArrayVar[Tuple[OTHER_TUPLE, int]]
| ArrayVar[Tuple[OTHER_TUPLE, float]]
| ArrayVar[Tuple[OTHER_TUPLE, int | float]]
ArrayVar[Tuple[Any, int]]
| ArrayVar[Tuple[Any, float]]
| ArrayVar[Tuple[Any, int | float]]
),
i: Literal[1, -1],
) -> NumberVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2]
self: ArrayVar[Tuple[str, Any]], i: Literal[0, -2]
) -> StringVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1]
self: ArrayVar[Tuple[Any, str]], i: Literal[1, -1]
) -> StringVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2]
self: ArrayVar[Tuple[bool, Any]], i: Literal[0, -2]
) -> BooleanVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1]
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
) -> BooleanVar: ...
@overload

View File

@ -45,7 +45,7 @@ from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import SetUndefinedStateVarError
from reflex.utils.format import json_dumps
from reflex.vars.base import ComputedVar, Var
from reflex.vars.base import Var, computed_var
from tests.units.states.mutation import MutableSQLAModel, MutableTestState
from .states import GenState
@ -109,7 +109,7 @@ class TestState(BaseState):
_backend: int = 0
asynctest: int = 0
@ComputedVar
@computed_var
def sum(self) -> float:
"""Dynamically sum the numbers.
@ -118,7 +118,7 @@ class TestState(BaseState):
"""
return self.num1 + self.num2
@ComputedVar
@computed_var
def upper(self) -> str:
"""Uppercase the key.
@ -1124,7 +1124,7 @@ def test_child_state():
v: int = 2
class ChildState(MainState):
@ComputedVar
@computed_var
def rendered_var(self):
return self.v
@ -1143,7 +1143,7 @@ def test_conditional_computed_vars():
t1: str = "a"
t2: str = "b"
@ComputedVar
@computed_var
def rendered_var(self) -> str:
if self.flag:
return self.t1
@ -3095,12 +3095,12 @@ def test_potentially_dirty_substates():
"""
class State(RxState):
@ComputedVar
@computed_var
def foo(self) -> str:
return ""
class C1(State):
@ComputedVar
@computed_var
def bar(self) -> str:
return ""