fix sqla python_type issues, add tests (#3613)

This commit is contained in:
benedikt-bartscher 2024-07-09 04:21:08 +02:00 committed by GitHub
parent 09ff952d01
commit 6d3321284c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 33 deletions

View File

@ -245,36 +245,41 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
# check for list types
column = insp.columns[name]
column_type = column.type
type_ = insp.columns[name].type.python_type
if hasattr(column_type, "item_type") and (
item_type := column_type.item_type.python_type # type: ignore
):
if type_ in PrimitiveToAnnotation:
type_ = PrimitiveToAnnotation[type_] # type: ignore
type_ = type_[item_type] # type: ignore
if column.nullable:
type_ = Optional[type_]
return type_
if name not in insp.all_orm_descriptors:
return None
descriptor = insp.all_orm_descriptors[name]
if hint := get_property_hint(descriptor):
return hint
if isinstance(descriptor, QueryableAttribute):
prop = descriptor.property
if not isinstance(prop, Relationship):
return None
type_ = prop.mapper.class_
# TODO: check for nullable?
type_ = List[type_] if prop.uselist else Optional[type_]
return type_
if isinstance(attr, AssociationProxyInstance):
return List[
get_attribute_access_type(
attr.target_class,
attr.remote_attr.key, # type: ignore[attr-defined]
)
]
try:
type_ = insp.columns[name].type.python_type
except NotImplementedError:
type_ = None
if type_ is not None:
if hasattr(column_type, "item_type"):
try:
item_type = column_type.item_type.python_type # type: ignore
except NotImplementedError:
item_type = None
if item_type is not None:
if type_ in PrimitiveToAnnotation:
type_ = PrimitiveToAnnotation[type_] # type: ignore
type_ = type_[item_type] # type: ignore
if column.nullable:
type_ = Optional[type_]
return type_
if name in insp.all_orm_descriptors:
descriptor = insp.all_orm_descriptors[name]
if hint := get_property_hint(descriptor):
return hint
if isinstance(descriptor, QueryableAttribute):
prop = descriptor.property
if isinstance(prop, Relationship):
type_ = prop.mapper.class_
# TODO: check for nullable?
type_ = List[type_] if prop.uselist else Optional[type_]
return type_
if isinstance(attr, AssociationProxyInstance):
return List[
get_attribute_access_type(
attr.target_class,
attr.remote_attr.key, # type: ignore[attr-defined]
)
]
elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
# Check in the annotations directly (for sqlmodel.Relationship)
hints = get_type_hints(cls)

View File

@ -1,10 +1,11 @@
from __future__ import annotations
from typing import List, Optional, Union
from typing import Dict, List, Optional, Type, Union
import attrs
import pytest
import sqlalchemy
from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@ -12,10 +13,29 @@ import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type
class SQLAType(TypeDecorator):
"""SQLAlchemy custom dict type."""
impl = JSON
@property
def python_type(self) -> Type[Dict[str, str]]:
"""Python type.
Returns:
Python Type of the column.
"""
return Dict[str, str]
class SQLABase(DeclarativeBase):
"""Base class for bare SQLAlchemy models."""
pass
type_annotation_map = {
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
Dict[str, str]: SQLAType,
}
class SQLATag(SQLABase):
@ -52,6 +72,9 @@ class SQLAClass(SQLABase):
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
@property
def str_property(self) -> str:
@ -82,6 +105,7 @@ class ModelClass(rx.Model):
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property
def str_property(self) -> str:
@ -112,6 +136,7 @@ class BaseClass(rx.Base):
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property
def str_property(self) -> str:
@ -142,6 +167,7 @@ class BareClass:
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property
def str_property(self) -> str:
@ -173,6 +199,7 @@ class AttrClass:
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property
def str_property(self) -> str:
@ -193,7 +220,15 @@ class AttrClass:
return self.name
@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass])
@pytest.fixture(
params=[
SQLAClass,
BaseClass,
BareClass,
ModelClass,
AttrClass,
]
)
def cls(request: pytest.FixtureRequest) -> type:
"""Fixture for the class to test.
@ -216,6 +251,7 @@ def cls(request: pytest.FixtureRequest) -> type:
pytest.param("optional_int", Optional[int], id="Optional[int]"),
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
pytest.param("dict_str_str", Dict[str, str], id="Dict[str, str]"),
pytest.param("str_property", str, id="str_property"),
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
],