Improve Var type handling for better rx.Model attribute access (#2010)

This commit is contained in:
Masen Furer 2023-10-25 11:55:50 -07:00 committed by GitHub
parent f404205c3f
commit 92dd68c51f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 175 additions and 36 deletions

View File

@ -15,6 +15,7 @@ import alembic.runtime.environment
import alembic.script
import alembic.util
import sqlalchemy
import sqlalchemy.orm
import sqlmodel
from reflex import constants
@ -68,6 +69,22 @@ class Model(Base, sqlmodel.SQLModel):
super().__init_subclass__()
@classmethod
def _dict_recursive(cls, value):
"""Recursively serialize the relationship object(s).
Args:
value: The value to serialize.
Returns:
The serialized value.
"""
if hasattr(value, "dict"):
return value.dict()
elif isinstance(value, list):
return [cls._dict_recursive(item) for item in value]
return value
def dict(self, **kwargs):
"""Convert the object to a dictionary.
@ -77,7 +94,19 @@ class Model(Base, sqlmodel.SQLModel):
Returns:
The object as a dictionary.
"""
return {name: getattr(self, name) for name in self.__fields__}
base_fields = {name: getattr(self, name) for name in self.__fields__}
relationships = {}
# SQLModel relationships do not appear in __fields__, but should be included if present.
for name in self.__sqlmodel_relationships__:
try:
relationships[name] = self._dict_recursive(getattr(self, name))
except sqlalchemy.orm.exc.DetachedInstanceError:
# This happens when the relationship was never loaded and the session is closed.
continue
return {
**base_fields,
**relationships,
}
@staticmethod
def create_all():

View File

@ -571,6 +571,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if default_value is not None:
field.required = False
field.default = default_value
if (
not field.required
and field.default is None
and not types.is_optional(prop._var_type)
):
# Ensure frontend uses null coalescing when accessing.
prop._var_type = Optional[prop._var_type]
@staticmethod
def _get_base_functions() -> dict[str, FunctionType]:

View File

@ -3,8 +3,22 @@
from __future__ import annotations
import contextlib
import typing
from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore
import types
from typing import (
Any,
Callable,
Iterable,
Literal,
Optional,
Type,
Union,
_GenericAlias, # type: ignore
get_args,
get_origin,
get_type_hints,
)
from pydantic.fields import ModelField
from reflex.base import Base
from reflex.utils import serializers
@ -21,18 +35,6 @@ StateIterVar = Union[list, set, tuple]
ArgsSpec = Callable
def get_args(alias: _GenericAlias) -> tuple[Type, ...]:
"""Get the arguments of a type alias.
Args:
alias: The type alias.
Returns:
The arguments of the type alias.
"""
return alias.__args__
def is_generic_alias(cls: GenericType) -> bool:
"""Check whether the class is a generic alias.
@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool:
Returns:
Whether the class is a Union.
"""
with contextlib.suppress(ImportError):
from typing import _UnionGenericAlias # type: ignore
# UnionType added in py3.10
if not hasattr(types, "UnionType"):
return get_origin(cls) is Union
return isinstance(cls, _UnionGenericAlias)
return cls.__origin__ == Union if is_generic_alias(cls) else False
return get_origin(cls) in [Union, types.UnionType]
def is_literal(cls: GenericType) -> bool:
@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool:
Returns:
Whether the class is a literal.
"""
return hasattr(cls, "__origin__") and cls.__origin__ is Literal
return get_origin(cls) is Literal
def is_optional(cls: GenericType) -> bool:
"""Check if a class is an Optional.
Args:
cls: The class to check.
Returns:
Whether the class is an Optional.
"""
return is_union(cls) and type(None) in get_args(cls)
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
"""Check if an attribute can be accessed on the cls and return its type.
Supports pydantic models, unions, and annotated attributes on rx.Model.
Args:
cls: The class to check.
name: The name of the attribute to check.
Returns:
The type of the attribute, if accessible, or None
"""
from reflex.model import Model
if hasattr(cls, "__fields__") and name in cls.__fields__:
# pydantic models
field = cls.__fields__[name]
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if not field.required and field.default is None:
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
elif isinstance(cls, type) and issubclass(cls, Model):
# Check in the annotations directly (for sqlmodel.Relationship)
hints = get_type_hints(cls)
if name in hints:
type_ = hints[name]
if isinstance(type_, ModelField):
return type_.type_
return type_
elif is_union(cls):
# Check in each arg of the annotation.
for arg in get_args(cls):
type_ = get_attribute_access_type(arg, name)
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
return None # Attribute is not accessible.
def get_base_class(cls: GenericType) -> Type:
@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool:
Returns:
Whether the value is a dataframe.
"""
if is_generic_alias(value) or value == typing.Any:
if is_generic_alias(value) or value == Any:
return False
return value.__name__ == "DataFrame"
@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool:
Returns:
Whether the type is a valid prop type.
"""
if is_union(type_):
return all((is_valid_var_type(arg) for arg in get_args(type_)))
return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool:
return name.startswith("_") and not name.startswith("__")
def check_type_in_allowed_types(
value_type: Type, allowed_types: typing.Iterable
) -> bool:
def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
"""Check that a value type is found in a list of allowed types.
Args:

View File

@ -27,8 +27,6 @@ from typing import (
get_type_hints,
)
from pydantic.fields import ModelField
from reflex import constants
from reflex.base import Base
from reflex.utils import console, format, serializers, types
@ -420,15 +418,12 @@ class Var:
raise TypeError(
f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
) from None
if (
hasattr(self._var_type, "__fields__")
and name in self._var_type.__fields__
):
type_ = self._var_type.__fields__[name].outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
is_optional = types.is_optional(self._var_type)
type_ = types.get_attribute_access_type(self._var_type, name)
if type_ is not None:
return BaseVar(
_var_name=f"{self._var_name}.{name}",
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
_var_type=type_,
_var_state=self._var_state,
_var_is_local=self._var_is_local,
@ -1235,6 +1230,9 @@ class BaseVar(Var):
Raises:
ImportError: If the var is a dataframe and pandas is not installed.
"""
if types.is_optional(self._var_type):
return None
type_ = (
get_origin(self._var_type)
if types.is_generic_alias(self._var_type)

View File

@ -7,7 +7,7 @@ import functools
import json
import os
import sys
from typing import Dict, Generator, List
from typing import Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock
import pytest
@ -30,7 +30,7 @@ from reflex.state import (
StateProxy,
StateUpdate,
)
from reflex.utils import prerequisites
from reflex.utils import prerequisites, types
from reflex.utils.format import json_dumps
from reflex.vars import BaseVar, ComputedVar
@ -2239,3 +2239,52 @@ def test_reset_with_mutables():
instance.items.append([3, 3])
assert instance.items != default
assert instance.items != copied_default
class Custom1(Base):
"""A custom class with a str field."""
foo: str
class Custom2(Base):
"""A custom class with a Custom1 field."""
c1: Optional[Custom1] = None
c1r: Custom1
class Custom3(Base):
"""A custom class with a Custom2 field."""
c2: Optional[Custom2] = None
c2r: Custom2
def test_state_union_optional():
"""Test that state can be defined with Union and Optional vars."""
class UnionState(State):
int_float: Union[int, float] = 0
opt_int: Optional[int]
c3: Optional[Custom3]
c3i: Custom3 # implicitly required
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore
assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore
assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore
assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore
assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore
assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore
assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore
assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore
assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore
assert UnionState.custom_union.foo is not None # type: ignore
assert UnionState.custom_union.c1 is not None # type: ignore
assert UnionState.custom_union.c1r is not None # type: ignore
assert UnionState.custom_union.c2 is not None # type: ignore
assert UnionState.custom_union.c2r is not None # type: ignore
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
assert types.is_union(UnionState.int_float._var_type) # type: ignore