Improve Var type handling for better rx.Model attribute access (#2010)

This commit is contained in:
Masen Furer 2023-10-25 11:55:50 -07:00 committed by GitHub
parent f404205c3f
commit 92dd68c51f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 175 additions and 36 deletions

View File

@ -15,6 +15,7 @@ import alembic.runtime.environment
import alembic.script import alembic.script
import alembic.util import alembic.util
import sqlalchemy import sqlalchemy
import sqlalchemy.orm
import sqlmodel import sqlmodel
from reflex import constants from reflex import constants
@ -68,6 +69,22 @@ class Model(Base, sqlmodel.SQLModel):
super().__init_subclass__() super().__init_subclass__()
@classmethod
def _dict_recursive(cls, value):
"""Recursively serialize the relationship object(s).
Args:
value: The value to serialize.
Returns:
The serialized value.
"""
if hasattr(value, "dict"):
return value.dict()
elif isinstance(value, list):
return [cls._dict_recursive(item) for item in value]
return value
def dict(self, **kwargs): def dict(self, **kwargs):
"""Convert the object to a dictionary. """Convert the object to a dictionary.
@ -77,7 +94,19 @@ class Model(Base, sqlmodel.SQLModel):
Returns: Returns:
The object as a dictionary. The object as a dictionary.
""" """
return {name: getattr(self, name) for name in self.__fields__} base_fields = {name: getattr(self, name) for name in self.__fields__}
relationships = {}
# SQLModel relationships do not appear in __fields__, but should be included if present.
for name in self.__sqlmodel_relationships__:
try:
relationships[name] = self._dict_recursive(getattr(self, name))
except sqlalchemy.orm.exc.DetachedInstanceError:
# This happens when the relationship was never loaded and the session is closed.
continue
return {
**base_fields,
**relationships,
}
@staticmethod @staticmethod
def create_all(): def create_all():

View File

@ -571,6 +571,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
if default_value is not None: if default_value is not None:
field.required = False field.required = False
field.default = default_value field.default = default_value
if (
not field.required
and field.default is None
and not types.is_optional(prop._var_type)
):
# Ensure frontend uses null coalescing when accessing.
prop._var_type = Optional[prop._var_type]
@staticmethod @staticmethod
def _get_base_functions() -> dict[str, FunctionType]: def _get_base_functions() -> dict[str, FunctionType]:

View File

@ -3,8 +3,22 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
import typing import types
from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore from typing import (
Any,
Callable,
Iterable,
Literal,
Optional,
Type,
Union,
_GenericAlias, # type: ignore
get_args,
get_origin,
get_type_hints,
)
from pydantic.fields import ModelField
from reflex.base import Base from reflex.base import Base
from reflex.utils import serializers from reflex.utils import serializers
@ -21,18 +35,6 @@ StateIterVar = Union[list, set, tuple]
ArgsSpec = Callable ArgsSpec = Callable
def get_args(alias: _GenericAlias) -> tuple[Type, ...]:
"""Get the arguments of a type alias.
Args:
alias: The type alias.
Returns:
The arguments of the type alias.
"""
return alias.__args__
def is_generic_alias(cls: GenericType) -> bool: def is_generic_alias(cls: GenericType) -> bool:
"""Check whether the class is a generic alias. """Check whether the class is a generic alias.
@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool:
Returns: Returns:
Whether the class is a Union. Whether the class is a Union.
""" """
with contextlib.suppress(ImportError): # UnionType added in py3.10
from typing import _UnionGenericAlias # type: ignore if not hasattr(types, "UnionType"):
return get_origin(cls) is Union
return isinstance(cls, _UnionGenericAlias) return get_origin(cls) in [Union, types.UnionType]
return cls.__origin__ == Union if is_generic_alias(cls) else False
def is_literal(cls: GenericType) -> bool: def is_literal(cls: GenericType) -> bool:
@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool:
Returns: Returns:
Whether the class is a literal. Whether the class is a literal.
""" """
return hasattr(cls, "__origin__") and cls.__origin__ is Literal return get_origin(cls) is Literal
def is_optional(cls: GenericType) -> bool:
"""Check if a class is an Optional.
Args:
cls: The class to check.
Returns:
Whether the class is an Optional.
"""
return is_union(cls) and type(None) in get_args(cls)
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
"""Check if an attribute can be accessed on the cls and return its type.
Supports pydantic models, unions, and annotated attributes on rx.Model.
Args:
cls: The class to check.
name: The name of the attribute to check.
Returns:
The type of the attribute, if accessible, or None
"""
from reflex.model import Model
if hasattr(cls, "__fields__") and name in cls.__fields__:
# pydantic models
field = cls.__fields__[name]
type_ = field.outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
if not field.required and field.default is None:
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
elif isinstance(cls, type) and issubclass(cls, Model):
# Check in the annotations directly (for sqlmodel.Relationship)
hints = get_type_hints(cls)
if name in hints:
type_ = hints[name]
if isinstance(type_, ModelField):
return type_.type_
return type_
elif is_union(cls):
# Check in each arg of the annotation.
for arg in get_args(cls):
type_ = get_attribute_access_type(arg, name)
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
return None # Attribute is not accessible.
def get_base_class(cls: GenericType) -> Type: def get_base_class(cls: GenericType) -> Type:
@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool:
Returns: Returns:
Whether the value is a dataframe. Whether the value is a dataframe.
""" """
if is_generic_alias(value) or value == typing.Any: if is_generic_alias(value) or value == Any:
return False return False
return value.__name__ == "DataFrame" return value.__name__ == "DataFrame"
@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool:
Returns: Returns:
Whether the type is a valid prop type. Whether the type is a valid prop type.
""" """
if is_union(type_):
return all((is_valid_var_type(arg) for arg in get_args(type_)))
return _issubclass(type_, StateVar) or serializers.has_serializer(type_) return _issubclass(type_, StateVar) or serializers.has_serializer(type_)
@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool:
return name.startswith("_") and not name.startswith("__") return name.startswith("_") and not name.startswith("__")
def check_type_in_allowed_types( def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool:
value_type: Type, allowed_types: typing.Iterable
) -> bool:
"""Check that a value type is found in a list of allowed types. """Check that a value type is found in a list of allowed types.
Args: Args:

View File

@ -27,8 +27,6 @@ from typing import (
get_type_hints, get_type_hints,
) )
from pydantic.fields import ModelField
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.utils import console, format, serializers, types from reflex.utils import console, format, serializers, types
@ -420,15 +418,12 @@ class Var:
raise TypeError( raise TypeError(
f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`" f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`"
) from None ) from None
if ( is_optional = types.is_optional(self._var_type)
hasattr(self._var_type, "__fields__") type_ = types.get_attribute_access_type(self._var_type, name)
and name in self._var_type.__fields__
): if type_ is not None:
type_ = self._var_type.__fields__[name].outer_type_
if isinstance(type_, ModelField):
type_ = type_.type_
return BaseVar( return BaseVar(
_var_name=f"{self._var_name}.{name}", _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
_var_type=type_, _var_type=type_,
_var_state=self._var_state, _var_state=self._var_state,
_var_is_local=self._var_is_local, _var_is_local=self._var_is_local,
@ -1235,6 +1230,9 @@ class BaseVar(Var):
Raises: Raises:
ImportError: If the var is a dataframe and pandas is not installed. ImportError: If the var is a dataframe and pandas is not installed.
""" """
if types.is_optional(self._var_type):
return None
type_ = ( type_ = (
get_origin(self._var_type) get_origin(self._var_type)
if types.is_generic_alias(self._var_type) if types.is_generic_alias(self._var_type)

View File

@ -7,7 +7,7 @@ import functools
import json import json
import os import os
import sys import sys
from typing import Dict, Generator, List from typing import Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
import pytest import pytest
@ -30,7 +30,7 @@ from reflex.state import (
StateProxy, StateProxy,
StateUpdate, StateUpdate,
) )
from reflex.utils import prerequisites from reflex.utils import prerequisites, types
from reflex.utils.format import json_dumps from reflex.utils.format import json_dumps
from reflex.vars import BaseVar, ComputedVar from reflex.vars import BaseVar, ComputedVar
@ -2239,3 +2239,52 @@ def test_reset_with_mutables():
instance.items.append([3, 3]) instance.items.append([3, 3])
assert instance.items != default assert instance.items != default
assert instance.items != copied_default assert instance.items != copied_default
class Custom1(Base):
"""A custom class with a str field."""
foo: str
class Custom2(Base):
"""A custom class with a Custom1 field."""
c1: Optional[Custom1] = None
c1r: Custom1
class Custom3(Base):
"""A custom class with a Custom2 field."""
c2: Optional[Custom2] = None
c2r: Custom2
def test_state_union_optional():
"""Test that state can be defined with Union and Optional vars."""
class UnionState(State):
int_float: Union[int, float] = 0
opt_int: Optional[int]
c3: Optional[Custom3]
c3i: Custom3 # implicitly required
c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo="")))
custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="")
assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore
assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore
assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore
assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore
assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore
assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore
assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore
assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore
assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore
assert UnionState.custom_union.foo is not None # type: ignore
assert UnionState.custom_union.c1 is not None # type: ignore
assert UnionState.custom_union.c1r is not None # type: ignore
assert UnionState.custom_union.c2 is not None # type: ignore
assert UnionState.custom_union.c2r is not None # type: ignore
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
assert types.is_union(UnionState.int_float._var_type) # type: ignore