From 92dd68c51f425b31d8114eb7956bc8ed8fd88cc4 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Oct 2023 11:55:50 -0700 Subject: [PATCH] Improve Var type handling for better rx.Model attribute access (#2010) --- reflex/model.py | 31 ++++++++++++- reflex/state.py | 7 +++ reflex/utils/types.py | 102 ++++++++++++++++++++++++++++++++---------- reflex/vars.py | 18 ++++---- tests/test_state.py | 53 +++++++++++++++++++++- 5 files changed, 175 insertions(+), 36 deletions(-) diff --git a/reflex/model.py b/reflex/model.py index 12703e7be..3167782e5 100644 --- a/reflex/model.py +++ b/reflex/model.py @@ -15,6 +15,7 @@ import alembic.runtime.environment import alembic.script import alembic.util import sqlalchemy +import sqlalchemy.orm import sqlmodel from reflex import constants @@ -68,6 +69,22 @@ class Model(Base, sqlmodel.SQLModel): super().__init_subclass__() + @classmethod + def _dict_recursive(cls, value): + """Recursively serialize the relationship object(s). + + Args: + value: The value to serialize. + + Returns: + The serialized value. + """ + if hasattr(value, "dict"): + return value.dict() + elif isinstance(value, list): + return [cls._dict_recursive(item) for item in value] + return value + def dict(self, **kwargs): """Convert the object to a dictionary. @@ -77,7 +94,19 @@ class Model(Base, sqlmodel.SQLModel): Returns: The object as a dictionary. """ - return {name: getattr(self, name) for name in self.__fields__} + base_fields = {name: getattr(self, name) for name in self.__fields__} + relationships = {} + # SQLModel relationships do not appear in __fields__, but should be included if present. + for name in self.__sqlmodel_relationships__: + try: + relationships[name] = self._dict_recursive(getattr(self, name)) + except sqlalchemy.orm.exc.DetachedInstanceError: + # This happens when the relationship was never loaded and the session is closed. + continue + return { + **base_fields, + **relationships, + } @staticmethod def create_all(): diff --git a/reflex/state.py b/reflex/state.py index 77445aa03..ff2233aed 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -571,6 +571,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow): if default_value is not None: field.required = False field.default = default_value + if ( + not field.required + and field.default is None + and not types.is_optional(prop._var_type) + ): + # Ensure frontend uses null coalescing when accessing. + prop._var_type = Optional[prop._var_type] @staticmethod def _get_base_functions() -> dict[str, FunctionType]: diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 973196af1..c114e94ac 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -3,8 +3,22 @@ from __future__ import annotations import contextlib -import typing -from typing import Any, Callable, Literal, Type, Union, _GenericAlias # type: ignore +import types +from typing import ( + Any, + Callable, + Iterable, + Literal, + Optional, + Type, + Union, + _GenericAlias, # type: ignore + get_args, + get_origin, + get_type_hints, +) + +from pydantic.fields import ModelField from reflex.base import Base from reflex.utils import serializers @@ -21,18 +35,6 @@ StateIterVar = Union[list, set, tuple] ArgsSpec = Callable -def get_args(alias: _GenericAlias) -> tuple[Type, ...]: - """Get the arguments of a type alias. - - Args: - alias: The type alias. - - Returns: - The arguments of the type alias. - """ - return alias.__args__ - - def is_generic_alias(cls: GenericType) -> bool: """Check whether the class is a generic alias. @@ -69,11 +71,11 @@ def is_union(cls: GenericType) -> bool: Returns: Whether the class is a Union. """ - with contextlib.suppress(ImportError): - from typing import _UnionGenericAlias # type: ignore + # UnionType added in py3.10 + if not hasattr(types, "UnionType"): + return get_origin(cls) is Union - return isinstance(cls, _UnionGenericAlias) - return cls.__origin__ == Union if is_generic_alias(cls) else False + return get_origin(cls) in [Union, types.UnionType] def is_literal(cls: GenericType) -> bool: @@ -85,7 +87,61 @@ def is_literal(cls: GenericType) -> bool: Returns: Whether the class is a literal. """ - return hasattr(cls, "__origin__") and cls.__origin__ is Literal + return get_origin(cls) is Literal + + +def is_optional(cls: GenericType) -> bool: + """Check if a class is an Optional. + + Args: + cls: The class to check. + + Returns: + Whether the class is an Optional. + """ + return is_union(cls) and type(None) in get_args(cls) + + +def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None: + """Check if an attribute can be accessed on the cls and return its type. + + Supports pydantic models, unions, and annotated attributes on rx.Model. + + Args: + cls: The class to check. + name: The name of the attribute to check. + + Returns: + The type of the attribute, if accessible, or None + """ + from reflex.model import Model + + if hasattr(cls, "__fields__") and name in cls.__fields__: + # pydantic models + field = cls.__fields__[name] + type_ = field.outer_type_ + if isinstance(type_, ModelField): + type_ = type_.type_ + if not field.required and field.default is None: + # Ensure frontend uses null coalescing when accessing. + type_ = Optional[type_] + return type_ + elif isinstance(cls, type) and issubclass(cls, Model): + # Check in the annotations directly (for sqlmodel.Relationship) + hints = get_type_hints(cls) + if name in hints: + type_ = hints[name] + if isinstance(type_, ModelField): + return type_.type_ + return type_ + elif is_union(cls): + # Check in each arg of the annotation. + for arg in get_args(cls): + type_ = get_attribute_access_type(arg, name) + if type_ is not None: + # Return the first attribute type that is accessible. + return type_ + return None # Attribute is not accessible. def get_base_class(cls: GenericType) -> Type: @@ -171,7 +227,7 @@ def is_dataframe(value: Type) -> bool: Returns: Whether the value is a dataframe. """ - if is_generic_alias(value) or value == typing.Any: + if is_generic_alias(value) or value == Any: return False return value.__name__ == "DataFrame" @@ -185,6 +241,8 @@ def is_valid_var_type(type_: Type) -> bool: Returns: Whether the type is a valid prop type. """ + if is_union(type_): + return all((is_valid_var_type(arg) for arg in get_args(type_))) return _issubclass(type_, StateVar) or serializers.has_serializer(type_) @@ -200,9 +258,7 @@ def is_backend_variable(name: str) -> bool: return name.startswith("_") and not name.startswith("__") -def check_type_in_allowed_types( - value_type: Type, allowed_types: typing.Iterable -) -> bool: +def check_type_in_allowed_types(value_type: Type, allowed_types: Iterable) -> bool: """Check that a value type is found in a list of allowed types. Args: diff --git a/reflex/vars.py b/reflex/vars.py index be8ac0dd3..cb55db023 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -27,8 +27,6 @@ from typing import ( get_type_hints, ) -from pydantic.fields import ModelField - from reflex import constants from reflex.base import Base from reflex.utils import console, format, serializers, types @@ -420,15 +418,12 @@ class Var: raise TypeError( f"You must provide an annotation for the state var `{self._var_full_name}`. Annotation cannot be `{self._var_type}`" ) from None - if ( - hasattr(self._var_type, "__fields__") - and name in self._var_type.__fields__ - ): - type_ = self._var_type.__fields__[name].outer_type_ - if isinstance(type_, ModelField): - type_ = type_.type_ + is_optional = types.is_optional(self._var_type) + type_ = types.get_attribute_access_type(self._var_type, name) + + if type_ is not None: return BaseVar( - _var_name=f"{self._var_name}.{name}", + _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}", _var_type=type_, _var_state=self._var_state, _var_is_local=self._var_is_local, @@ -1235,6 +1230,9 @@ class BaseVar(Var): Raises: ImportError: If the var is a dataframe and pandas is not installed. """ + if types.is_optional(self._var_type): + return None + type_ = ( get_origin(self._var_type) if types.is_generic_alias(self._var_type) diff --git a/tests/test_state.py b/tests/test_state.py index 62d051819..b8673fb86 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -7,7 +7,7 @@ import functools import json import os import sys -from typing import Dict, Generator, List +from typing import Dict, Generator, List, Optional, Union from unittest.mock import AsyncMock, Mock import pytest @@ -30,7 +30,7 @@ from reflex.state import ( StateProxy, StateUpdate, ) -from reflex.utils import prerequisites +from reflex.utils import prerequisites, types from reflex.utils.format import json_dumps from reflex.vars import BaseVar, ComputedVar @@ -2239,3 +2239,52 @@ def test_reset_with_mutables(): instance.items.append([3, 3]) assert instance.items != default assert instance.items != copied_default + + +class Custom1(Base): + """A custom class with a str field.""" + + foo: str + + +class Custom2(Base): + """A custom class with a Custom1 field.""" + + c1: Optional[Custom1] = None + c1r: Custom1 + + +class Custom3(Base): + """A custom class with a Custom2 field.""" + + c2: Optional[Custom2] = None + c2r: Custom2 + + +def test_state_union_optional(): + """Test that state can be defined with Union and Optional vars.""" + + class UnionState(State): + int_float: Union[int, float] = 0 + opt_int: Optional[int] + c3: Optional[Custom3] + c3i: Custom3 # implicitly required + c3r: Custom3 = Custom3(c2r=Custom2(c1r=Custom1(foo=""))) + custom_union: Union[Custom1, Custom2, Custom3] = Custom1(foo="") + + assert UnionState.c3.c2._var_name == "c3?.c2" # type: ignore + assert UnionState.c3.c2.c1._var_name == "c3?.c2?.c1" # type: ignore + assert UnionState.c3.c2.c1.foo._var_name == "c3?.c2?.c1?.foo" # type: ignore + assert UnionState.c3.c2.c1r.foo._var_name == "c3?.c2?.c1r.foo" # type: ignore + assert UnionState.c3.c2r.c1._var_name == "c3?.c2r.c1" # type: ignore + assert UnionState.c3.c2r.c1.foo._var_name == "c3?.c2r.c1?.foo" # type: ignore + assert UnionState.c3.c2r.c1r.foo._var_name == "c3?.c2r.c1r.foo" # type: ignore + assert UnionState.c3i.c2._var_name == "c3i.c2" # type: ignore + assert UnionState.c3r.c2._var_name == "c3r.c2" # type: ignore + assert UnionState.custom_union.foo is not None # type: ignore + assert UnionState.custom_union.c1 is not None # type: ignore + assert UnionState.custom_union.c1r is not None # type: ignore + assert UnionState.custom_union.c2 is not None # type: ignore + assert UnionState.custom_union.c2r is not None # type: ignore + assert types.is_optional(UnionState.opt_int._var_type) # type: ignore + assert types.is_union(UnionState.int_float._var_type) # type: ignore