fix sqla python_type issues, add tests (#3613)

This commit is contained in:
benedikt-bartscher 2024-07-09 04:21:08 +02:00 committed by GitHub
parent 09ff952d01
commit 6d3321284c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 33 deletions

View File

@ -245,25 +245,30 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
# check for list types # check for list types
column = insp.columns[name] column = insp.columns[name]
column_type = column.type column_type = column.type
try:
type_ = insp.columns[name].type.python_type type_ = insp.columns[name].type.python_type
if hasattr(column_type, "item_type") and ( except NotImplementedError:
item_type := column_type.item_type.python_type # type: ignore type_ = None
): if type_ is not None:
if hasattr(column_type, "item_type"):
try:
item_type = column_type.item_type.python_type # type: ignore
except NotImplementedError:
item_type = None
if item_type is not None:
if type_ in PrimitiveToAnnotation: if type_ in PrimitiveToAnnotation:
type_ = PrimitiveToAnnotation[type_] # type: ignore type_ = PrimitiveToAnnotation[type_] # type: ignore
type_ = type_[item_type] # type: ignore type_ = type_[item_type] # type: ignore
if column.nullable: if column.nullable:
type_ = Optional[type_] type_ = Optional[type_]
return type_ return type_
if name not in insp.all_orm_descriptors: if name in insp.all_orm_descriptors:
return None
descriptor = insp.all_orm_descriptors[name] descriptor = insp.all_orm_descriptors[name]
if hint := get_property_hint(descriptor): if hint := get_property_hint(descriptor):
return hint return hint
if isinstance(descriptor, QueryableAttribute): if isinstance(descriptor, QueryableAttribute):
prop = descriptor.property prop = descriptor.property
if not isinstance(prop, Relationship): if isinstance(prop, Relationship):
return None
type_ = prop.mapper.class_ type_ = prop.mapper.class_
# TODO: check for nullable? # TODO: check for nullable?
type_ = List[type_] if prop.uselist else Optional[type_] type_ = List[type_] if prop.uselist else Optional[type_]

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Optional, Union from typing import Dict, List, Optional, Type, Union
import attrs import attrs
import pytest import pytest
import sqlalchemy import sqlalchemy
from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
@ -12,10 +13,29 @@ import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type from reflex.utils.types import GenericType, get_attribute_access_type
class SQLAType(TypeDecorator):
"""SQLAlchemy custom dict type."""
impl = JSON
@property
def python_type(self) -> Type[Dict[str, str]]:
"""Python type.
Returns:
Python Type of the column.
"""
return Dict[str, str]
class SQLABase(DeclarativeBase): class SQLABase(DeclarativeBase):
"""Base class for bare SQLAlchemy models.""" """Base class for bare SQLAlchemy models."""
pass type_annotation_map = {
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
Dict[str, str]: SQLAType,
}
class SQLATag(SQLABase): class SQLATag(SQLABase):
@ -52,6 +72,9 @@ class SQLAClass(SQLABase):
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id)) sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship() sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test") labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -82,6 +105,7 @@ class ModelClass(rx.Model):
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -112,6 +136,7 @@ class BaseClass(rx.Base):
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -142,6 +167,7 @@ class BareClass:
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -173,6 +199,7 @@ class AttrClass:
optional_int: Optional[int] = None optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {}
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -193,7 +220,15 @@ class AttrClass:
return self.name return self.name
@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass]) @pytest.fixture(
params=[
SQLAClass,
BaseClass,
BareClass,
ModelClass,
AttrClass,
]
)
def cls(request: pytest.FixtureRequest) -> type: def cls(request: pytest.FixtureRequest) -> type:
"""Fixture for the class to test. """Fixture for the class to test.
@ -216,6 +251,7 @@ def cls(request: pytest.FixtureRequest) -> type:
pytest.param("optional_int", Optional[int], id="Optional[int]"), pytest.param("optional_int", Optional[int], id="Optional[int]"),
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"), pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"), pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
pytest.param("dict_str_str", Dict[str, str], id="Dict[str, str]"),
pytest.param("str_property", str, id="str_property"), pytest.param("str_property", str, id="str_property"),
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"), pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
], ],