From be7f7969ed5d57cbdf2608cc5195b06fd14e9c0a Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Wed, 31 Jan 2024 00:57:56 +0100 Subject: [PATCH] improve sqlalchemy type parsing (#2474) * improve sqlalchemy type parsing * add support for propertys and relationships * cleanup duplicate property check * avoid confusion, improve readability --- reflex/utils/types.py | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 66d6a0a30..93469e7b7 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -18,9 +18,10 @@ from typing import ( get_type_hints, ) +import sqlalchemy from pydantic.fields import ModelField from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import DeclarativeBase, Mapped +from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship from reflex.base import Base from reflex.utils import serializers @@ -105,6 +106,21 @@ def is_optional(cls: GenericType) -> bool: return is_union(cls) and type(None) in get_args(cls) +def get_property_hint(attr: Any | None) -> GenericType | None: + """Check if an attribute is a property and return its type hint. + + Args: + attr: The descriptor to check. + + Returns: + The type hint of the property, if it is a property, else None. + """ + if not isinstance(attr, (property, hybrid_property)): + return None + hints = get_type_hints(attr.fget) + return hints.get("return", None) + + def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None: """Check if an attribute can be accessed on the cls and return its type. @@ -119,6 +135,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None """ from reflex.model import Model + attr = getattr(cls, name, None) + if hint := get_property_hint(attr): + return hint if hasattr(cls, "__fields__") and name in cls.__fields__: # pydantic models field = cls.__fields__[name] @@ -129,7 +148,21 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None # Ensure frontend uses null coalescing when accessing. type_ = Optional[type_] return type_ - elif isinstance(cls, type) and issubclass(cls, (Model, DeclarativeBase)): + elif isinstance(cls, type) and issubclass(cls, DeclarativeBase): + insp = sqlalchemy.inspect(cls) + if name in insp.columns: + return insp.columns[name].type.python_type + if name not in insp.all_orm_descriptors.keys(): + return None + descriptor = insp.all_orm_descriptors[name] + if hint := get_property_hint(descriptor): + return hint + if isinstance(descriptor, QueryableAttribute): + prop = descriptor.property + if not isinstance(prop, Relationship): + return None + return prop.mapper.class_ + elif isinstance(cls, type) and issubclass(cls, Model): # Check in the annotations directly (for sqlmodel.Relationship) hints = get_type_hints(cls) if name in hints: @@ -140,11 +173,6 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None if isinstance(type_, ModelField): return type_.type_ # SQLAlchemy v1.4 return type_ - if name in cls.__dict__: - value = cls.__dict__[name] - if isinstance(value, hybrid_property): - hints = get_type_hints(value.fget) - return hints.get("return", None) elif is_union(cls): # Check in each arg of the annotation. for arg in get_args(cls):