improve sqlalchemy type parsing (#2474)
* improve sqlalchemy type parsing * add support for propertys and relationships * cleanup duplicate property check * avoid confusion, improve readability
This commit is contained in:
parent
032017df3a
commit
be7f7969ed
@ -18,9 +18,10 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
import sqlalchemy
|
||||
from pydantic.fields import ModelField
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
|
||||
|
||||
from reflex.base import Base
|
||||
from reflex.utils import serializers
|
||||
@ -105,6 +106,21 @@ def is_optional(cls: GenericType) -> bool:
|
||||
return is_union(cls) and type(None) in get_args(cls)
|
||||
|
||||
|
||||
def get_property_hint(attr: Any | None) -> GenericType | None:
|
||||
"""Check if an attribute is a property and return its type hint.
|
||||
|
||||
Args:
|
||||
attr: The descriptor to check.
|
||||
|
||||
Returns:
|
||||
The type hint of the property, if it is a property, else None.
|
||||
"""
|
||||
if not isinstance(attr, (property, hybrid_property)):
|
||||
return None
|
||||
hints = get_type_hints(attr.fget)
|
||||
return hints.get("return", None)
|
||||
|
||||
|
||||
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
|
||||
"""Check if an attribute can be accessed on the cls and return its type.
|
||||
|
||||
@ -119,6 +135,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
"""
|
||||
from reflex.model import Model
|
||||
|
||||
attr = getattr(cls, name, None)
|
||||
if hint := get_property_hint(attr):
|
||||
return hint
|
||||
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
||||
# pydantic models
|
||||
field = cls.__fields__[name]
|
||||
@ -129,7 +148,21 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
# Ensure frontend uses null coalescing when accessing.
|
||||
type_ = Optional[type_]
|
||||
return type_
|
||||
elif isinstance(cls, type) and issubclass(cls, (Model, DeclarativeBase)):
|
||||
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
||||
insp = sqlalchemy.inspect(cls)
|
||||
if name in insp.columns:
|
||||
return insp.columns[name].type.python_type
|
||||
if name not in insp.all_orm_descriptors.keys():
|
||||
return None
|
||||
descriptor = insp.all_orm_descriptors[name]
|
||||
if hint := get_property_hint(descriptor):
|
||||
return hint
|
||||
if isinstance(descriptor, QueryableAttribute):
|
||||
prop = descriptor.property
|
||||
if not isinstance(prop, Relationship):
|
||||
return None
|
||||
return prop.mapper.class_
|
||||
elif isinstance(cls, type) and issubclass(cls, Model):
|
||||
# Check in the annotations directly (for sqlmodel.Relationship)
|
||||
hints = get_type_hints(cls)
|
||||
if name in hints:
|
||||
@ -140,11 +173,6 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
if isinstance(type_, ModelField):
|
||||
return type_.type_ # SQLAlchemy v1.4
|
||||
return type_
|
||||
if name in cls.__dict__:
|
||||
value = cls.__dict__[name]
|
||||
if isinstance(value, hybrid_property):
|
||||
hints = get_type_hints(value.fget)
|
||||
return hints.get("return", None)
|
||||
elif is_union(cls):
|
||||
# Check in each arg of the annotation.
|
||||
for arg in get_args(cls):
|
||||
|
Loading…
Reference in New Issue
Block a user