fix sqla python_type issues, add tests (#3613)

This commit is contained in:
benedikt-bartscher 2024-07-09 04:21:08 +02:00 committed by GitHub
parent 09ff952d01
commit 6d3321284c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 33 deletions

View File

@ -245,36 +245,41 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
# check for list types # check for list types
column = insp.columns[name] column = insp.columns[name]
column_type = column.type column_type = column.type
type_ = insp.columns[name].type.python_type try:
if hasattr(column_type, "item_type") and ( type_ = insp.columns[name].type.python_type
item_type := column_type.item_type.python_type # type: ignore except NotImplementedError:
): type_ = None
if type_ in PrimitiveToAnnotation: if type_ is not None:
type_ = PrimitiveToAnnotation[type_] # type: ignore if hasattr(column_type, "item_type"):
type_ = type_[item_type] # type: ignore try:
if column.nullable: item_type = column_type.item_type.python_type # type: ignore
type_ = Optional[type_] except NotImplementedError:
return type_ item_type = None
if name not in insp.all_orm_descriptors: if item_type is not None:
return None if type_ in PrimitiveToAnnotation:
descriptor = insp.all_orm_descriptors[name] type_ = PrimitiveToAnnotation[type_] # type: ignore
if hint := get_property_hint(descriptor): type_ = type_[item_type] # type: ignore
return hint if column.nullable:
if isinstance(descriptor, QueryableAttribute): type_ = Optional[type_]
prop = descriptor.property return type_
if not isinstance(prop, Relationship): if name in insp.all_orm_descriptors:
return None descriptor = insp.all_orm_descriptors[name]
type_ = prop.mapper.class_ if hint := get_property_hint(descriptor):
# TODO: check for nullable? return hint
type_ = List[type_] if prop.uselist else Optional[type_] if isinstance(descriptor, QueryableAttribute):
return type_ prop = descriptor.property
if isinstance(attr, AssociationProxyInstance): if isinstance(prop, Relationship):
return List[ type_ = prop.mapper.class_
get_attribute_access_type( # TODO: check for nullable?
attr.target_class, type_ = List[type_] if prop.uselist else Optional[type_]
attr.remote_attr.key, # type: ignore[attr-defined] return type_
) if isinstance(attr, AssociationProxyInstance):
] return List[
get_attribute_access_type(
attr.target_class,
attr.remote_attr.key, # type: ignore[attr-defined]
)
]
elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model): elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
# Check in the annotations directly (for sqlmodel.Relationship) # Check in the annotations directly (for sqlmodel.Relationship)
hints = get_type_hints(cls) hints = get_type_hints(cls)

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Optional, Union from typing import Dict, List, Optional, Type, Union
import attrs import attrs
import pytest import pytest
import sqlalchemy import sqlalchemy
from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@ -12,10 +13,29 @@ import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type from reflex.utils.types import GenericType, get_attribute_access_type
class SQLAType(TypeDecorator):
"""SQLAlchemy custom dict type."""
impl = JSON
@property
def python_type(self) -> Type[Dict[str, str]]:
"""Python type.
Returns:
Python Type of the column.
"""
return Dict[str, str]
class SQLABase(DeclarativeBase): class SQLABase(DeclarativeBase):
"""Base class for bare SQLAlchemy models.""" """Base class for bare SQLAlchemy models."""
pass type_annotation_map = {
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
Dict[str, str]: SQLAType,
}
class SQLATag(SQLABase): class SQLATag(SQLABase):
@ -52,6 +72,9 @@ class SQLAClass(SQLABase):
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id)) sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship() sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test") labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -82,6 +105,7 @@ class ModelClass(rx.Model):
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -112,6 +136,7 @@ class BaseClass(rx.Base):
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -142,6 +167,7 @@ class BareClass:
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -173,6 +199,7 @@ class AttrClass:
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -193,7 +220,15 @@ class AttrClass:
return self.name return self.name
@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass]) @pytest.fixture(
params=[
SQLAClass,
BaseClass,
BareClass,
ModelClass,
AttrClass,
]
)
def cls(request: pytest.FixtureRequest) -> type: def cls(request: pytest.FixtureRequest) -> type:
"""Fixture for the class to test. """Fixture for the class to test.
@ -216,6 +251,7 @@ def cls(request: pytest.FixtureRequest) -> type:
pytest.param("optional_int", Optional[int], id="Optional[int]"), pytest.param("optional_int", Optional[int], id="Optional[int]"),
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"), pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"), pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
pytest.param("dict_str_str", Dict[str, str], id="Dict[str, str]"),
pytest.param("str_property", str, id="str_property"), pytest.param("str_property", str, id="str_property"),
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"), pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
], ],