improve object var symantics (#4290)

* improve object var symantics

* add case for serializers

* check against serializer with to = dict

* add tests

* fix typing issues

* remove default value

* older version of python doesn't have assert type

* add base to rx field cases

* get it from typing_extension
This commit is contained in:
Khaleel Al-Adhami 2024-11-05 09:56:10 -08:00 committed by GitHub
parent 0ed7c5d969
commit b5d1e03de1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 222 additions and 77 deletions

View File

@ -78,7 +78,7 @@ def serializer(
)
# Apply type transformation if requested
if to is not None:
if to is not None or ((to := type_hints.get("return")) is not None):
SERIALIZER_TYPES[type_] = to
get_serializer_type.cache_clear()
@ -189,16 +189,37 @@ def get_serializer_type(type_: Type) -> Optional[Type]:
return None
def has_serializer(type_: Type) -> bool:
def has_serializer(type_: Type, into_type: Type | None = None) -> bool:
"""Check if there is a serializer for the type.
Args:
type_: The type to check.
into_type: The type to serialize into.
Returns:
Whether there is a serializer for the type.
"""
return get_serializer(type_) is not None
serializer_for_type = get_serializer(type_)
return serializer_for_type is not None and (
into_type is None or get_serializer_type(type_) == into_type
)
def can_serialize(type_: Type, into_type: Type | None = None) -> bool:
"""Check if there is a serializer for the type.
Args:
type_: The type to check.
into_type: The type to serialize into.
Returns:
Whether there is a serializer for the type.
"""
return has_serializer(type_, into_type) or (
isinstance(type_, type)
and dataclasses.is_dataclass(type_)
and (into_type is None or into_type is dict)
)
@serializer(to=str)
@ -214,7 +235,7 @@ def serialize_type(value: type) -> str:
return value.__name__
@serializer
@serializer(to=dict)
def serialize_base(value: Base) -> dict:
"""Serialize a Base instance.

View File

@ -75,7 +75,6 @@ from reflex.utils.types import (
if TYPE_CHECKING:
from reflex.state import BaseState
from .function import FunctionVar
from .number import (
BooleanVar,
NumberVar,
@ -279,6 +278,24 @@ def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
return VarData.merge(*var_datas) if var_datas else None, value
def can_use_in_object_var(cls: GenericType) -> bool:
"""Check if the class can be used in an ObjectVar.
Args:
cls: The class to check.
Returns:
Whether the class can be used in an ObjectVar.
"""
if types.is_union(cls):
return all(can_use_in_object_var(t) for t in types.get_args(cls))
return (
inspect.isclass(cls)
and not issubclass(cls, Var)
and serializers.can_serialize(cls, dict)
)
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -565,36 +582,33 @@ class Var(Generic[VAR_TYPE]):
# Encode the _var_data into the formatted output for tracking purposes.
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}"
@overload
def to(self, output: Type[StringVar]) -> StringVar: ...
@overload
def to(self, output: Type[str]) -> StringVar: ...
@overload
def to(self, output: Type[BooleanVar]) -> BooleanVar: ...
def to(self, output: Type[bool]) -> BooleanVar: ...
@overload
def to(
self, output: Type[NumberVar], var_type: type[int] | type[float] = float
) -> NumberVar: ...
def to(self, output: type[int] | type[float]) -> NumberVar: ...
@overload
def to(
self,
output: Type[ArrayVar],
var_type: type[list] | type[tuple] | type[set] = list,
output: type[list] | type[tuple] | type[set],
) -> ArrayVar: ...
@overload
def to(
self, output: Type[ObjectVar], var_type: types.GenericType = dict
) -> ObjectVar: ...
self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]
) -> ObjectVar[VAR_INSIDE]: ...
@overload
def to(
self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
) -> FunctionVar: ...
self, output: Type[ObjectVar], var_type: None = None
) -> ObjectVar[VAR_TYPE]: ...
@overload
def to(self, output: VAR_SUBCLASS, var_type: None = None) -> VAR_SUBCLASS: ...
@overload
def to(
@ -630,21 +644,19 @@ class Var(Generic[VAR_TYPE]):
return get_to_operation(NoneVar).create(self) # type: ignore
# Handle fixed_output_type being Base or a dataclass.
try:
if issubclass(fixed_output_type, Base):
return self.to(ObjectVar, output)
except TypeError:
pass
if dataclasses.is_dataclass(fixed_output_type) and not issubclass(
fixed_output_type, Var
):
if can_use_in_object_var(fixed_output_type):
return self.to(ObjectVar, output)
if inspect.isclass(output):
for var_subclass in _var_subclasses[::-1]:
if issubclass(output, var_subclass.var_subclass):
current_var_type = self._var_type
if current_var_type is Any:
new_var_type = var_type
else:
new_var_type = var_type or current_var_type
to_operation_return = var_subclass.to_var_subclass.create(
value=self, _var_type=var_type
value=self, _var_type=new_var_type
)
return to_operation_return # type: ignore
@ -707,11 +719,7 @@ class Var(Generic[VAR_TYPE]):
):
return self.to(NumberVar, self._var_type)
if all(
inspect.isclass(t)
and (issubclass(t, Base) or dataclasses.is_dataclass(t))
for t in inner_types
):
if can_use_in_object_var(var_type):
return self.to(ObjectVar, self._var_type)
return self
@ -730,13 +738,9 @@ class Var(Generic[VAR_TYPE]):
if issubclass(fixed_type, var_subclass.python_types):
return self.to(var_subclass.var_subclass, self._var_type)
try:
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
except TypeError:
pass
if dataclasses.is_dataclass(fixed_type):
if can_use_in_object_var(fixed_type):
return self.to(ObjectVar, self._var_type)
return self
def get_default_value(self) -> Any:
@ -1181,6 +1185,9 @@ class Var(Generic[VAR_TYPE]):
OUTPUT = TypeVar("OUTPUT", bound=Var)
VAR_SUBCLASS = TypeVar("VAR_SUBCLASS", bound=Var)
VAR_INSIDE = TypeVar("VAR_INSIDE")
class ToOperation:
"""A var operation that converts a var to another type."""
@ -2888,6 +2895,8 @@ def dispatch(
V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
class Field(Generic[T]):
"""Shadow class for Var to allow for type hinting in the IDE."""
@ -2924,6 +2933,11 @@ class Field(Generic[T]):
self: Field[Dict[str, V]], instance: None, owner
) -> ObjectVar[Dict[str, V]]: ...
@overload
def __get__(
self: Field[BASE_TYPE], instance: None, owner
) -> ObjectVar[BASE_TYPE]: ...
@overload
def __get__(self, instance: None, owner) -> Var[T]: ...

View File

@ -1116,7 +1116,9 @@ U = TypeVar("U")
@var_operation
def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]):
def ternary_operation(
condition: BooleanVar, if_true: Var[T], if_false: Var[U]
) -> CustomVarOperationReturn[Union[T, U]]:
"""Create a ternary operation.
Args:

View File

@ -36,7 +36,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
OBJECT_TYPE = TypeVar("OBJECT_TYPE")
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -59,7 +59,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def _value_type(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
self: ObjectVar[Dict[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ...
@overload
@ -87,7 +87,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def values(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
self: ObjectVar[Dict[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload
@ -103,7 +103,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def entries(
self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]],
self: ObjectVar[Dict[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
@overload
@ -133,47 +133,47 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
self: ObjectVar[Dict[Any, NoReturn]],
key: Var | Any,
) -> Var: ...
@overload
def __getitem__(
self: (
ObjectVar[Dict[KEY_TYPE, int]]
| ObjectVar[Dict[KEY_TYPE, float]]
| ObjectVar[Dict[KEY_TYPE, int | float]]
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
),
key: Var | Any,
) -> NumberVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, str]],
self: ObjectVar[Dict[Any, str]],
key: Var | Any,
) -> StringVar: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getitem__(
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@ -195,50 +195,56 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, NoReturn]],
self: ObjectVar[Dict[Any, NoReturn]],
name: str,
) -> Var: ...
@overload
def __getattr__(
self: (
ObjectVar[Dict[KEY_TYPE, int]]
| ObjectVar[Dict[KEY_TYPE, float]]
| ObjectVar[Dict[KEY_TYPE, int | float]]
ObjectVar[Dict[Any, int]]
| ObjectVar[Dict[Any, float]]
| ObjectVar[Dict[Any, int | float]]
),
name: str,
) -> NumberVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, str]],
self: ObjectVar[Dict[Any, str]],
name: str,
) -> StringVar: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]],
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload
def __getattr__(
self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar,
name: str,
) -> ObjectItemOperation: ...
def __getattr__(self, name) -> Var:
"""Get an attribute of the var.
@ -377,8 +383,8 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@classmethod
def create(
cls,
_var_value: OBJECT_TYPE,
_var_type: GenericType | None = None,
_var_value: dict,
_var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None,
) -> LiteralObjectVar[OBJECT_TYPE]:
"""Create the literal object var.

View File

@ -853,31 +853,31 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
@overload
def __getitem__(
self: (
ArrayVar[Tuple[OTHER_TUPLE, int]]
| ArrayVar[Tuple[OTHER_TUPLE, float]]
| ArrayVar[Tuple[OTHER_TUPLE, int | float]]
ArrayVar[Tuple[Any, int]]
| ArrayVar[Tuple[Any, float]]
| ArrayVar[Tuple[Any, int | float]]
),
i: Literal[1, -1],
) -> NumberVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2]
self: ArrayVar[Tuple[str, Any]], i: Literal[0, -2]
) -> StringVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1]
self: ArrayVar[Tuple[Any, str]], i: Literal[1, -1]
) -> StringVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2]
self: ArrayVar[Tuple[bool, Any]], i: Literal[0, -2]
) -> BooleanVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1]
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
) -> BooleanVar: ...
@overload

View File

@ -45,7 +45,7 @@ from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import SetUndefinedStateVarError
from reflex.utils.format import json_dumps
from reflex.vars.base import ComputedVar, Var
from reflex.vars.base import Var, computed_var
from tests.units.states.mutation import MutableSQLAModel, MutableTestState
from .states import GenState
@ -109,7 +109,7 @@ class TestState(BaseState):
_backend: int = 0
asynctest: int = 0
@ComputedVar
@computed_var
def sum(self) -> float:
"""Dynamically sum the numbers.
@ -118,7 +118,7 @@ class TestState(BaseState):
"""
return self.num1 + self.num2
@ComputedVar
@computed_var
def upper(self) -> str:
"""Uppercase the key.
@ -1124,7 +1124,7 @@ def test_child_state():
v: int = 2
class ChildState(MainState):
@ComputedVar
@computed_var
def rendered_var(self):
return self.v
@ -1143,7 +1143,7 @@ def test_conditional_computed_vars():
t1: str = "a"
t2: str = "b"
@ComputedVar
@computed_var
def rendered_var(self) -> str:
if self.flag:
return self.t1
@ -3095,12 +3095,12 @@ def test_potentially_dirty_substates():
"""
class State(RxState):
@ComputedVar
@computed_var
def foo(self) -> str:
return ""
class C1(State):
@ComputedVar
@computed_var
def bar(self) -> str:
return ""

View File

@ -0,0 +1,102 @@
import pytest
from typing_extensions import assert_type
import reflex as rx
from reflex.utils.types import GenericType
from reflex.vars.base import Var
from reflex.vars.object import LiteralObjectVar, ObjectVar
class Bare:
"""A bare class with a single attribute."""
quantity: int = 0
@rx.serializer
def serialize_bare(obj: Bare) -> dict:
"""A serializer for the bare class.
Args:
obj: The object to serialize.
Returns:
A dictionary with the quantity attribute.
"""
return {"quantity": obj.quantity}
class Base(rx.Base):
"""A reflex base class with a single attribute."""
quantity: int = 0
class ObjectState(rx.State):
"""A reflex state with bare and base objects."""
bare: rx.Field[Bare] = rx.field(Bare())
base: rx.Field[Base] = rx.field(Base())
@pytest.mark.parametrize("type_", [Base, Bare])
def test_var_create(type_: GenericType) -> None:
my_object = type_()
var = Var.create(my_object)
assert var._var_type is type_
quantity = var.quantity
assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare])
def test_literal_create(type_: GenericType) -> None:
my_object = type_()
var = LiteralObjectVar.create(my_object)
assert var._var_type is type_
quantity = var.quantity
assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare])
def test_guess(type_: GenericType) -> None:
my_object = type_()
var = Var.create(my_object)
var = var.guess_type()
assert var._var_type is type_
quantity = var.quantity
assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare])
def test_state(type_: GenericType) -> None:
attr_name = type_.__name__.lower()
var = getattr(ObjectState, attr_name)
assert var._var_type is type_
quantity = var.quantity
assert quantity._var_type is int
@pytest.mark.parametrize("type_", [Base, Bare])
def test_state_to_operation(type_: GenericType) -> None:
attr_name = type_.__name__.lower()
original_var = getattr(ObjectState, attr_name)
var = original_var.to(ObjectVar, type_)
assert var._var_type is type_
var = original_var.to(ObjectVar)
assert var._var_type is type_
def test_typing() -> None:
# Bare
var = ObjectState.bare.to(ObjectVar)
_ = assert_type(var, ObjectVar[Bare])
# Base
var = ObjectState.base
_ = assert_type(var, ObjectVar[Base])