improve sqlalchemy type parsing (#2474)
* improve sqlalchemy type parsing * add support for propertys and relationships * cleanup duplicate property check * avoid confusion, improve readability
This commit is contained in:
parent
032017df3a
commit
be7f7969ed
@ -18,9 +18,10 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
from pydantic.fields import ModelField
|
from pydantic.fields import ModelField
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped
|
from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
|
||||||
|
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import serializers
|
from reflex.utils import serializers
|
||||||
@ -105,6 +106,21 @@ def is_optional(cls: GenericType) -> bool:
|
|||||||
return is_union(cls) and type(None) in get_args(cls)
|
return is_union(cls) and type(None) in get_args(cls)
|
||||||
|
|
||||||
|
|
||||||
|
def get_property_hint(attr: Any | None) -> GenericType | None:
|
||||||
|
"""Check if an attribute is a property and return its type hint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attr: The descriptor to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The type hint of the property, if it is a property, else None.
|
||||||
|
"""
|
||||||
|
if not isinstance(attr, (property, hybrid_property)):
|
||||||
|
return None
|
||||||
|
hints = get_type_hints(attr.fget)
|
||||||
|
return hints.get("return", None)
|
||||||
|
|
||||||
|
|
||||||
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
|
def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None:
|
||||||
"""Check if an attribute can be accessed on the cls and return its type.
|
"""Check if an attribute can be accessed on the cls and return its type.
|
||||||
|
|
||||||
@ -119,6 +135,9 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
"""
|
"""
|
||||||
from reflex.model import Model
|
from reflex.model import Model
|
||||||
|
|
||||||
|
attr = getattr(cls, name, None)
|
||||||
|
if hint := get_property_hint(attr):
|
||||||
|
return hint
|
||||||
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
if hasattr(cls, "__fields__") and name in cls.__fields__:
|
||||||
# pydantic models
|
# pydantic models
|
||||||
field = cls.__fields__[name]
|
field = cls.__fields__[name]
|
||||||
@ -129,7 +148,21 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
# Ensure frontend uses null coalescing when accessing.
|
# Ensure frontend uses null coalescing when accessing.
|
||||||
type_ = Optional[type_]
|
type_ = Optional[type_]
|
||||||
return type_
|
return type_
|
||||||
elif isinstance(cls, type) and issubclass(cls, (Model, DeclarativeBase)):
|
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
||||||
|
insp = sqlalchemy.inspect(cls)
|
||||||
|
if name in insp.columns:
|
||||||
|
return insp.columns[name].type.python_type
|
||||||
|
if name not in insp.all_orm_descriptors.keys():
|
||||||
|
return None
|
||||||
|
descriptor = insp.all_orm_descriptors[name]
|
||||||
|
if hint := get_property_hint(descriptor):
|
||||||
|
return hint
|
||||||
|
if isinstance(descriptor, QueryableAttribute):
|
||||||
|
prop = descriptor.property
|
||||||
|
if not isinstance(prop, Relationship):
|
||||||
|
return None
|
||||||
|
return prop.mapper.class_
|
||||||
|
elif isinstance(cls, type) and issubclass(cls, Model):
|
||||||
# Check in the annotations directly (for sqlmodel.Relationship)
|
# Check in the annotations directly (for sqlmodel.Relationship)
|
||||||
hints = get_type_hints(cls)
|
hints = get_type_hints(cls)
|
||||||
if name in hints:
|
if name in hints:
|
||||||
@ -140,11 +173,6 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
if isinstance(type_, ModelField):
|
if isinstance(type_, ModelField):
|
||||||
return type_.type_ # SQLAlchemy v1.4
|
return type_.type_ # SQLAlchemy v1.4
|
||||||
return type_
|
return type_
|
||||||
if name in cls.__dict__:
|
|
||||||
value = cls.__dict__[name]
|
|
||||||
if isinstance(value, hybrid_property):
|
|
||||||
hints = get_type_hints(value.fget)
|
|
||||||
return hints.get("return", None)
|
|
||||||
elif is_union(cls):
|
elif is_union(cls):
|
||||||
# Check in each arg of the annotation.
|
# Check in each arg of the annotation.
|
||||||
for arg in get_args(cls):
|
for arg in get_args(cls):
|
||||||
|
Loading…
Reference in New Issue
Block a user