improve sqlalchemy type parsing (#2474)

* improve sqlalchemy type parsing

* add support for propertys and relationships

* cleanup duplicate property check

* avoid confusion, improve readability
This commit is contained in:
benedikt-bartscher 2024-01-31 00:57:56 +01:00 committed by GitHub
parent 032017df3a
commit be7f7969ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,9 +18,10 @@ from typing import (
get_type_hints,
)
import sqlalchemy
from pydantic.fields import ModelField
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped
from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
from reflex.base import Base
from reflex.utils import serializers
@ -105,6 +106,21 @@ def is_optional(cls: GenericType) -> bool:
return is_union(cls) and type(None) in get_args(cls)
def get_property_hint(attr: Any | None) -> GenericType | None:
"""Check if an attribute is a property and return its type hint.
Args:
attr: The descriptor to check.
Returns:
The type hint of the property, if it is a property, else None.
"""
if not isinstance(attr, (property, hybrid_property)):
return None
hints = get_type_hints(attr.fget)
return hints.get("return", None)
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
"""Check if an attribute can be accessed on the cls and return its type.
@ -119,6 +135,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
"""
from reflex.model import Model
attr = getattr(cls, name, None)
if hint := get_property_hint(attr):
return hint
if hasattr(cls, "__fields__") and name in cls.__fields__:
# pydantic models
field = cls.__fields__[name]
@ -129,7 +148,21 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
# Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_]
return type_
elif isinstance(cls, type) and issubclass(cls, (Model, DeclarativeBase)):
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls)
if name in insp.columns:
return insp.columns[name].type.python_type
if name not in insp.all_orm_descriptors.keys():
return None
descriptor = insp.all_orm_descriptors[name]
if hint := get_property_hint(descriptor):
return hint
if isinstance(descriptor, QueryableAttribute):
prop = descriptor.property
if not isinstance(prop, Relationship):
return None
return prop.mapper.class_
elif isinstance(cls, type) and issubclass(cls, Model):
# Check in the annotations directly (for sqlmodel.Relationship)
hints = get_type_hints(cls)
if name in hints:
@ -140,11 +173,6 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if isinstance(type_, ModelField):
return type_.type_ # SQLAlchemy v1.4
return type_
if name in cls.__dict__:
value = cls.__dict__[name]
if isinstance(value, hybrid_property):
hints = get_type_hints(value.fget)
return hints.get("return", None)
elif is_union(cls):
# Check in each arg of the annotation.
for arg in get_args(cls):