Improve Var type handling for better rx.Model attribute access (#2010)
This commit is contained in:
parent
f404205c3f
commit
92dd68c51f
@ -15,6 +15,7 @@ import alembic.runtime.environment
|
|||||||
import alembic.script
|
import alembic.script
|
||||||
import alembic.util
|
import alembic.util
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
import sqlalchemy.orm
|
||||||
import sqlmodel
|
import sqlmodel
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
@ -68,6 +69,22 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
|
|
||||||
super().__init_subclass__()
|
super().__init_subclass__()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _dict_recursive(cls, value):
|
||||||
|
"""Recursively serialize the relationship object(s).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to serialize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The serialized value.
|
||||||
|
"""
|
||||||
|
if hasattr(value, "dict"):
|
||||||
|
return value.dict()
|
||||||
|
elif isinstance(value, list):
|
||||||
|
return [cls._dict_recursive(item) for item in value]
|
||||||
|
return value
|
||||||
|
|
||||||
def dict(self, **kwargs):
|
def dict(self, **kwargs):
|
||||||
"""Convert the object to a dictionary.
|
"""Convert the object to a dictionary.
|
||||||
|
|
||||||
@ -77,7 +94,19 @@ class Model(Base, sqlmodel.SQLModel):
|
|||||||
Returns:
|
Returns:
|
||||||
The object as a dictionary.
|
The object as a dictionary.
|
||||||
"""
|
"""
|
||||||
return {name: getattr(self, name) for name in self.__fields__}
|
base_fields = {name: getattr(self, name) for name in self.__fields__}
|
||||||
|
relationships = {}
|
||||||
|
# SQLModel relationships do not appear in __fields__, but should be included if present.
|
||||||
|
for name in self.__sqlmodel_relationships__:
|
||||||
|
try:
|
||||||
|
relationships[name] = self._dict_recursive(getattr(self, name))
|
||||||
|
except sqlalchemy.orm.exc.DetachedInstanceError:
|
||||||
|
# This happens when the relationship was never loaded and the session is closed.
|
||||||
|
continue
|
||||||
|
return {
|
||||||
|
**base_fields,
|
||||||
|
**relationships,
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_all():
|
def create_all():
|
||||||
|
@ -571,6 +571,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if default_value is not None:
|
if default_value is not None:
|
||||||
field.required = False
|
field.required = False
|
||||||
field.default = default_value
|
field.default = default_value
|
||||||
|
if (
|
||||||
|
not field.required
|
||||||
|
and field.default is None
|
||||||
|
and not types.is_optional(prop._var_type)
|
||||||
|
):
|
||||||
|
# Ensure frontend uses null coalescing when accessing.
|
||||||
|
prop._var_type = Optional[prop._var_type]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_base_functions() -> dict[str, FunctionType]:
|
def _get_base_functions() -> dict[str, FunctionType]:
|
||||||
|
@ -3,8 +3,22 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import typing
|
import types
|
||||||
from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
_GenericAlias, # type: ignore
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
|
get_type_hints,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic.fields import ModelField
|
||||||
|
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import serializers
|
from reflex.utils import serializers
|
||||||
@ -21,18 +35,6 @@ StateIterVar = Union[list, set, tuple]
|
|||||||
ArgsSpec = Callable
|
ArgsSpec = Callable
|
||||||
|
|
||||||
|
|
||||||
def get_args(alias: _GenericAlias) -> tuple[Type, ...]:
|
|
||||||
"""Get the arguments of a type alias.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
alias: The type alias.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The arguments of the type alias.
|
|
||||||
"""
|
|
||||||
return alias.__args__
|
|
||||||
|
|
||||||
|
|
||||||
def is_generic_alias(cls: GenericType) -> bool:
|
def is_generic_alias(cls: GenericType) -> bool:
|
||||||
"""Check whether the class is a generic alias.
|
"""Check whether the class is a generic alias.
|
||||||
|
|
||||||
@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
Whether the class is a Union.
|
Whether the class is a Union.
|
||||||
"""
|
"""
|
||||||
with contextlib.suppress(ImportError):
|
# UnionType added in py3.10
|
||||||
from typing import _UnionGenericAlias # type: ignore
|
if not hasattr(types, "UnionType"):
|
||||||
|
return get_origin(cls) is Union
|
||||||
|
|
||||||
return isinstance(cls, _UnionGenericAlias)
|
return get_origin(cls) in [Union, types.UnionType]
|
||||||
return cls.__origin__ == Union if is_generic_alias(cls) else False
|
|
||||||
|
|
||||||
|
|
||||||
def is_literal(cls: GenericType) -> bool:
|
def is_literal(cls: GenericType) -> bool:
|
||||||
@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
Whether the class is a literal.
|
Whether the class is a literal.
|
||||||
"""
|
"""
|
||||||
return hasattr(cls, "__origin__") and cls.__origin__ is Literal
|
return get_origin(cls) is Literal
|
||||||
|
|
||||||
|
|
||||||
|
def is_optional(cls: GenericType) -> bool:
|
||||||
|
"""Check if a class is an Optional.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: The class to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the class is an Optional.
|
||||||
|
"""
|
||||||
|
return is_union(cls) and type(None) in get_args(cls)
|
||||||
|
|
||||||
|
|
||||||
|
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
|
||||||
|
"""Check if an attribute can be accessed on the cls and return its type.
|
||||||
|
|
||||||
|
Supports pydantic models, unions, and annotated attributes on rx.Model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: The class to check.
|
||||||
|
name: The name of the attribute to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The type of the attribute, if accessible, or None
|
||||||
|
"""
|
||||||
|
from reflex.model import Model
|
||||||
|
|
||||||
|
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
||||||
|
# pydantic models
|
||||||
|
field = cls.__fields__[name]
|
||||||
|
type_ = field.outer_type_
|
||||||
|
if isinstance(type_, ModelField):
|
||||||
|
type_ = type_.type_
|
||||||
|
if not field.required and field.default is None:
|
||||||
|
# Ensure frontend uses null coalescing when accessing.
|
||||||
|
type_ = Optional[type_]
|
||||||
|
return type_
|
||||||
|
elif isinstance(cls, type) and issubclass(cls, Model):
|
||||||
|
# Check in the annotations directly (for sqlmodel.Relationship)
|
||||||
|
hints = get_type_hints(cls)
|
||||||
|
if name in hints:
|
||||||
|
type_ = hints[name]
|
||||||
|
if isinstance(type_, ModelField):
|
||||||
|
return type_.type_
|
||||||
|
return type_
|
||||||
|
elif is_union(cls):
|
||||||
|
# Check in each arg of the annotation.
|
||||||
|
for arg in get_args(cls):
|
||||||
|
type_ = get_attribute_access_type(arg, name)
|
||||||
|
if type_ is not None:
|
||||||
|
# Return the first attribute type that is accessible.
|
||||||
|
return type_
|
||||||
|
return None # Attribute is not accessible.
|
||||||
|
|
||||||
|
|
||||||
def get_base_class(cls: GenericType) -> Type:
|
def get_base_class(cls: GenericType) -> Type:
|
||||||
@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
Whether the value is a dataframe.
|
Whether the value is a dataframe.
|
||||||
"""
|
"""
|
||||||
if is_generic_alias(value) or value == typing.Any:
|
if is_generic_alias(value) or value == Any:
|
||||||
return False
|
return False
|
||||||
return value.__name__ == "DataFrame"
|
return value.__name__ == "DataFrame"
|
||||||
|
|
||||||
@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
Whether the type is a valid prop type.
|
Whether the type is a valid prop type.
|
||||||
"""
|
"""
|
||||||
|
if is_union(type_):
|
||||||
|
return all((is_valid_var_type(arg) for arg in get_args(type_)))
|
||||||
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
|
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
|
||||||
|
|
||||||
|
|
||||||
@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool:
|
|||||||
return name.startswith("_") and not name.startswith("__")
|
return name.startswith("_") and not name.startswith("__")
|
||||||
|
|
||||||
|
|
||||||
def check_type_in_allowed_types(
|
def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
|
||||||
value_type: Type, allowed_types: typing.Iterable
|
|
||||||
) -> bool:
|
|
||||||
"""Check that a value type is found in a list of allowed types.
|
"""Check that a value type is found in a list of allowed types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -27,8 +27,6 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic.fields import ModelField
|
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import console, format, serializers, types
|
from reflex.utils import console, format, serializers, types
|
||||||
@ -420,15 +418,12 @@ class Var:
|
|||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
|
f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
|
||||||
) from None
|
) from None
|
||||||
if (
|
is_optional = types.is_optional(self._var_type)
|
||||||
hasattr(self._var_type, "__fields__")
|
type_ = types.get_attribute_access_type(self._var_type, name)
|
||||||
and name in self._var_type.__fields__
|
|
||||||
):
|
if type_ is not None:
|
||||||
type_ = self._var_type.__fields__[name].outer_type_
|
|
||||||
if isinstance(type_, ModelField):
|
|
||||||
type_ = type_.type_
|
|
||||||
return BaseVar(
|
return BaseVar(
|
||||||
_var_name=f"{self._var_name}.{name}",
|
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
|
||||||
_var_type=type_,
|
_var_type=type_,
|
||||||
_var_state=self._var_state,
|
_var_state=self._var_state,
|
||||||
_var_is_local=self._var_is_local,
|
_var_is_local=self._var_is_local,
|
||||||
@ -1235,6 +1230,9 @@ class BaseVar(Var):
|
|||||||
Raises:
|
Raises:
|
||||||
ImportError: If the var is a dataframe and pandas is not installed.
|
ImportError: If the var is a dataframe and pandas is not installed.
|
||||||
"""
|
"""
|
||||||
|
if types.is_optional(self._var_type):
|
||||||
|
return None
|
||||||
|
|
||||||
type_ = (
|
type_ = (
|
||||||
get_origin(self._var_type)
|
get_origin(self._var_type)
|
||||||
if types.is_generic_alias(self._var_type)
|
if types.is_generic_alias(self._var_type)
|
||||||
|
@ -7,7 +7,7 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, Generator, List
|
from typing import Dict, Generator, List, Optional, Union
|
||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -30,7 +30,7 @@ from reflex.state import (
|
|||||||
StateProxy,
|
StateProxy,
|
||||||
StateUpdate,
|
StateUpdate,
|
||||||
)
|
)
|
||||||
from reflex.utils import prerequisites
|
from reflex.utils import prerequisites, types
|
||||||
from reflex.utils.format import json_dumps
|
from reflex.utils.format import json_dumps
|
||||||
from reflex.vars import BaseVar, ComputedVar
|
from reflex.vars import BaseVar, ComputedVar
|
||||||
|
|
||||||
@ -2239,3 +2239,52 @@ def test_reset_with_mutables():
|
|||||||
instance.items.append([3, 3])
|
instance.items.append([3, 3])
|
||||||
assert instance.items != default
|
assert instance.items != default
|
||||||
assert instance.items != copied_default
|
assert instance.items != copied_default
|
||||||
|
|
||||||
|
|
||||||
|
class Custom1(Base):
|
||||||
|
"""A custom class with a str field."""
|
||||||
|
|
||||||
|
foo: str
|
||||||
|
|
||||||
|
|
||||||
|
class Custom2(Base):
|
||||||
|
"""A custom class with a Custom1 field."""
|
||||||
|
|
||||||
|
c1: Optional[Custom1] = None
|
||||||
|
c1r: Custom1
|
||||||
|
|
||||||
|
|
||||||
|
class Custom3(Base):
|
||||||
|
"""A custom class with a Custom2 field."""
|
||||||
|
|
||||||
|
c2: Optional[Custom2] = None
|
||||||
|
c2r: Custom2
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_union_optional():
|
||||||
|
"""Test that state can be defined with Union and Optional vars."""
|
||||||
|
|
||||||
|
class UnionState(State):
|
||||||
|
int_float: Union[int, float] = 0
|
||||||
|
opt_int: Optional[int]
|
||||||
|
c3: Optional[Custom3]
|
||||||
|
c3i: Custom3 # implicitly required
|
||||||
|
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
|
||||||
|
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
|
||||||
|
|
||||||
|
assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore
|
||||||
|
assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore
|
||||||
|
assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore
|
||||||
|
assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore
|
||||||
|
assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore
|
||||||
|
assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore
|
||||||
|
assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore
|
||||||
|
assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore
|
||||||
|
assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore
|
||||||
|
assert UnionState.custom_union.foo is not None # type: ignore
|
||||||
|
assert UnionState.custom_union.c1 is not None # type: ignore
|
||||||
|
assert UnionState.custom_union.c1r is not None # type: ignore
|
||||||
|
assert UnionState.custom_union.c2 is not None # type: ignore
|
||||||
|
assert UnionState.custom_union.c2r is not None # type: ignore
|
||||||
|
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
|
||||||
|
assert types.is_union(UnionState.int_float._var_type) # type: ignore
|
||||||
|
Loading…
Reference in New Issue
Block a user