Improve Var type handling for better rx.Model attribute access (#2010)
This commit is contained in:
parent
f404205c3f
commit
92dd68c51f
@ -15,6 +15,7 @@ import alembic.runtime.environment
|
||||
import alembic.script
|
||||
import alembic.util
|
||||
import sqlalchemy
|
||||
import sqlalchemy.orm
|
||||
import sqlmodel
|
||||
|
||||
from reflex import constants
|
||||
@ -68,6 +69,22 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
|
||||
super().__init_subclass__()
|
||||
|
||||
@classmethod
|
||||
def _dict_recursive(cls, value):
|
||||
"""Recursively serialize the relationship object(s).
|
||||
|
||||
Args:
|
||||
value: The value to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized value.
|
||||
"""
|
||||
if hasattr(value, "dict"):
|
||||
return value.dict()
|
||||
elif isinstance(value, list):
|
||||
return [cls._dict_recursive(item) for item in value]
|
||||
return value
|
||||
|
||||
def dict(self, **kwargs):
|
||||
"""Convert the object to a dictionary.
|
||||
|
||||
@ -77,7 +94,19 @@ class Model(Base, sqlmodel.SQLModel):
|
||||
Returns:
|
||||
The object as a dictionary.
|
||||
"""
|
||||
return {name: getattr(self, name) for name in self.__fields__}
|
||||
base_fields = {name: getattr(self, name) for name in self.__fields__}
|
||||
relationships = {}
|
||||
# SQLModel relationships do not appear in __fields__, but should be included if present.
|
||||
for name in self.__sqlmodel_relationships__:
|
||||
try:
|
||||
relationships[name] = self._dict_recursive(getattr(self, name))
|
||||
except sqlalchemy.orm.exc.DetachedInstanceError:
|
||||
# This happens when the relationship was never loaded and the session is closed.
|
||||
continue
|
||||
return {
|
||||
**base_fields,
|
||||
**relationships,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_all():
|
||||
|
@ -571,6 +571,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if default_value is not None:
|
||||
field.required = False
|
||||
field.default = default_value
|
||||
if (
|
||||
not field.required
|
||||
and field.default is None
|
||||
and not types.is_optional(prop._var_type)
|
||||
):
|
||||
# Ensure frontend uses null coalescing when accessing.
|
||||
prop._var_type = Optional[prop._var_type]
|
||||
|
||||
@staticmethod
|
||||
def _get_base_functions() -> dict[str, FunctionType]:
|
||||
|
@ -3,8 +3,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import typing
|
||||
from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
_GenericAlias, # type: ignore
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic.fields import ModelField
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.utils import serializers
|
||||
@ -21,18 +35,6 @@ StateIterVar = Union[list, set, tuple]
|
||||
ArgsSpec = Callable
|
||||
|
||||
|
||||
def get_args(alias: _GenericAlias) -> tuple[Type, ...]:
|
||||
"""Get the arguments of a type alias.
|
||||
|
||||
Args:
|
||||
alias: The type alias.
|
||||
|
||||
Returns:
|
||||
The arguments of the type alias.
|
||||
"""
|
||||
return alias.__args__
|
||||
|
||||
|
||||
def is_generic_alias(cls: GenericType) -> bool:
|
||||
"""Check whether the class is a generic alias.
|
||||
|
||||
@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool:
|
||||
Returns:
|
||||
Whether the class is a Union.
|
||||
"""
|
||||
with contextlib.suppress(ImportError):
|
||||
from typing import _UnionGenericAlias # type: ignore
|
||||
# UnionType added in py3.10
|
||||
if not hasattr(types, "UnionType"):
|
||||
return get_origin(cls) is Union
|
||||
|
||||
return isinstance(cls, _UnionGenericAlias)
|
||||
return cls.__origin__ == Union if is_generic_alias(cls) else False
|
||||
return get_origin(cls) in [Union, types.UnionType]
|
||||
|
||||
|
||||
def is_literal(cls: GenericType) -> bool:
|
||||
@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool:
|
||||
Returns:
|
||||
Whether the class is a literal.
|
||||
"""
|
||||
return hasattr(cls, "__origin__") and cls.__origin__ is Literal
|
||||
return get_origin(cls) is Literal
|
||||
|
||||
|
||||
def is_optional(cls: GenericType) -> bool:
|
||||
"""Check if a class is an Optional.
|
||||
|
||||
Args:
|
||||
cls: The class to check.
|
||||
|
||||
Returns:
|
||||
Whether the class is an Optional.
|
||||
"""
|
||||
return is_union(cls) and type(None) in get_args(cls)
|
||||
|
||||
|
||||
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
|
||||
"""Check if an attribute can be accessed on the cls and return its type.
|
||||
|
||||
Supports pydantic models, unions, and annotated attributes on rx.Model.
|
||||
|
||||
Args:
|
||||
cls: The class to check.
|
||||
name: The name of the attribute to check.
|
||||
|
||||
Returns:
|
||||
The type of the attribute, if accessible, or None
|
||||
"""
|
||||
from reflex.model import Model
|
||||
|
||||
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
||||
# pydantic models
|
||||
field = cls.__fields__[name]
|
||||
type_ = field.outer_type_
|
||||
if isinstance(type_, ModelField):
|
||||
type_ = type_.type_
|
||||
if not field.required and field.default is None:
|
||||
# Ensure frontend uses null coalescing when accessing.
|
||||
type_ = Optional[type_]
|
||||
return type_
|
||||
elif isinstance(cls, type) and issubclass(cls, Model):
|
||||
# Check in the annotations directly (for sqlmodel.Relationship)
|
||||
hints = get_type_hints(cls)
|
||||
if name in hints:
|
||||
type_ = hints[name]
|
||||
if isinstance(type_, ModelField):
|
||||
return type_.type_
|
||||
return type_
|
||||
elif is_union(cls):
|
||||
# Check in each arg of the annotation.
|
||||
for arg in get_args(cls):
|
||||
type_ = get_attribute_access_type(arg, name)
|
||||
if type_ is not None:
|
||||
# Return the first attribute type that is accessible.
|
||||
return type_
|
||||
return None # Attribute is not accessible.
|
||||
|
||||
|
||||
def get_base_class(cls: GenericType) -> Type:
|
||||
@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool:
|
||||
Returns:
|
||||
Whether the value is a dataframe.
|
||||
"""
|
||||
if is_generic_alias(value) or value == typing.Any:
|
||||
if is_generic_alias(value) or value == Any:
|
||||
return False
|
||||
return value.__name__ == "DataFrame"
|
||||
|
||||
@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool:
|
||||
Returns:
|
||||
Whether the type is a valid prop type.
|
||||
"""
|
||||
if is_union(type_):
|
||||
return all((is_valid_var_type(arg) for arg in get_args(type_)))
|
||||
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
|
||||
|
||||
|
||||
@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool:
|
||||
return name.startswith("_") and not name.startswith("__")
|
||||
|
||||
|
||||
def check_type_in_allowed_types(
|
||||
value_type: Type, allowed_types: typing.Iterable
|
||||
) -> bool:
|
||||
def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
|
||||
"""Check that a value type is found in a list of allowed types.
|
||||
|
||||
Args:
|
||||
|
@ -27,8 +27,6 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from pydantic.fields import ModelField
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.utils import console, format, serializers, types
|
||||
@ -420,15 +418,12 @@ class Var:
|
||||
raise TypeError(
|
||||
f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
|
||||
) from None
|
||||
if (
|
||||
hasattr(self._var_type, "__fields__")
|
||||
and name in self._var_type.__fields__
|
||||
):
|
||||
type_ = self._var_type.__fields__[name].outer_type_
|
||||
if isinstance(type_, ModelField):
|
||||
type_ = type_.type_
|
||||
is_optional = types.is_optional(self._var_type)
|
||||
type_ = types.get_attribute_access_type(self._var_type, name)
|
||||
|
||||
if type_ is not None:
|
||||
return BaseVar(
|
||||
_var_name=f"{self._var_name}.{name}",
|
||||
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
|
||||
_var_type=type_,
|
||||
_var_state=self._var_state,
|
||||
_var_is_local=self._var_is_local,
|
||||
@ -1235,6 +1230,9 @@ class BaseVar(Var):
|
||||
Raises:
|
||||
ImportError: If the var is a dataframe and pandas is not installed.
|
||||
"""
|
||||
if types.is_optional(self._var_type):
|
||||
return None
|
||||
|
||||
type_ = (
|
||||
get_origin(self._var_type)
|
||||
if types.is_generic_alias(self._var_type)
|
||||
|
@ -7,7 +7,7 @@ import functools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Generator, List
|
||||
from typing import Dict, Generator, List, Optional, Union
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
@ -30,7 +30,7 @@ from reflex.state import (
|
||||
StateProxy,
|
||||
StateUpdate,
|
||||
)
|
||||
from reflex.utils import prerequisites
|
||||
from reflex.utils import prerequisites, types
|
||||
from reflex.utils.format import json_dumps
|
||||
from reflex.vars import BaseVar, ComputedVar
|
||||
|
||||
@ -2239,3 +2239,52 @@ def test_reset_with_mutables():
|
||||
instance.items.append([3, 3])
|
||||
assert instance.items != default
|
||||
assert instance.items != copied_default
|
||||
|
||||
|
||||
class Custom1(Base):
|
||||
"""A custom class with a str field."""
|
||||
|
||||
foo: str
|
||||
|
||||
|
||||
class Custom2(Base):
|
||||
"""A custom class with a Custom1 field."""
|
||||
|
||||
c1: Optional[Custom1] = None
|
||||
c1r: Custom1
|
||||
|
||||
|
||||
class Custom3(Base):
|
||||
"""A custom class with a Custom2 field."""
|
||||
|
||||
c2: Optional[Custom2] = None
|
||||
c2r: Custom2
|
||||
|
||||
|
||||
def test_state_union_optional():
|
||||
"""Test that state can be defined with Union and Optional vars."""
|
||||
|
||||
class UnionState(State):
|
||||
int_float: Union[int, float] = 0
|
||||
opt_int: Optional[int]
|
||||
c3: Optional[Custom3]
|
||||
c3i: Custom3 # implicitly required
|
||||
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
|
||||
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
|
||||
|
||||
assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore
|
||||
assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore
|
||||
assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore
|
||||
assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore
|
||||
assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore
|
||||
assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore
|
||||
assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore
|
||||
assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore
|
||||
assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore
|
||||
assert UnionState.custom_union.foo is not None # type: ignore
|
||||
assert UnionState.custom_union.c1 is not None # type: ignore
|
||||
assert UnionState.custom_union.c1r is not None # type: ignore
|
||||
assert UnionState.custom_union.c2 is not None # type: ignore
|
||||
assert UnionState.custom_union.c2r is not None # type: ignore
|
||||
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
|
||||
assert types.is_union(UnionState.int_float._var_type) # type: ignore
|
||||
|
Loading…
Reference in New Issue
Block a user