fix sqla python_type issues, add tests (#3613)
This commit is contained in:
parent
09ff952d01
commit
6d3321284c
@ -245,36 +245,41 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
# check for list types
|
||||
column = insp.columns[name]
|
||||
column_type = column.type
|
||||
type_ = insp.columns[name].type.python_type
|
||||
if hasattr(column_type, "item_type") and (
|
||||
item_type := column_type.item_type.python_type # type: ignore
|
||||
):
|
||||
if type_ in PrimitiveToAnnotation:
|
||||
type_ = PrimitiveToAnnotation[type_] # type: ignore
|
||||
type_ = type_[item_type] # type: ignore
|
||||
if column.nullable:
|
||||
type_ = Optional[type_]
|
||||
return type_
|
||||
if name not in insp.all_orm_descriptors:
|
||||
return None
|
||||
descriptor = insp.all_orm_descriptors[name]
|
||||
if hint := get_property_hint(descriptor):
|
||||
return hint
|
||||
if isinstance(descriptor, QueryableAttribute):
|
||||
prop = descriptor.property
|
||||
if not isinstance(prop, Relationship):
|
||||
return None
|
||||
type_ = prop.mapper.class_
|
||||
# TODO: check for nullable?
|
||||
type_ = List[type_] if prop.uselist else Optional[type_]
|
||||
return type_
|
||||
if isinstance(attr, AssociationProxyInstance):
|
||||
return List[
|
||||
get_attribute_access_type(
|
||||
attr.target_class,
|
||||
attr.remote_attr.key, # type: ignore[attr-defined]
|
||||
)
|
||||
]
|
||||
try:
|
||||
type_ = insp.columns[name].type.python_type
|
||||
except NotImplementedError:
|
||||
type_ = None
|
||||
if type_ is not None:
|
||||
if hasattr(column_type, "item_type"):
|
||||
try:
|
||||
item_type = column_type.item_type.python_type # type: ignore
|
||||
except NotImplementedError:
|
||||
item_type = None
|
||||
if item_type is not None:
|
||||
if type_ in PrimitiveToAnnotation:
|
||||
type_ = PrimitiveToAnnotation[type_] # type: ignore
|
||||
type_ = type_[item_type] # type: ignore
|
||||
if column.nullable:
|
||||
type_ = Optional[type_]
|
||||
return type_
|
||||
if name in insp.all_orm_descriptors:
|
||||
descriptor = insp.all_orm_descriptors[name]
|
||||
if hint := get_property_hint(descriptor):
|
||||
return hint
|
||||
if isinstance(descriptor, QueryableAttribute):
|
||||
prop = descriptor.property
|
||||
if isinstance(prop, Relationship):
|
||||
type_ = prop.mapper.class_
|
||||
# TODO: check for nullable?
|
||||
type_ = List[type_] if prop.uselist else Optional[type_]
|
||||
return type_
|
||||
if isinstance(attr, AssociationProxyInstance):
|
||||
return List[
|
||||
get_attribute_access_type(
|
||||
attr.target_class,
|
||||
attr.remote_attr.key, # type: ignore[attr-defined]
|
||||
)
|
||||
]
|
||||
elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
|
||||
# Check in the annotations directly (for sqlmodel.Relationship)
|
||||
hints = get_type_hints(cls)
|
||||
|
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from sqlalchemy import JSON, TypeDecorator
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
|
||||
@ -12,10 +13,29 @@ import reflex as rx
|
||||
from reflex.utils.types import GenericType, get_attribute_access_type
|
||||
|
||||
|
||||
class SQLAType(TypeDecorator):
|
||||
"""SQLAlchemy custom dict type."""
|
||||
|
||||
impl = JSON
|
||||
|
||||
@property
|
||||
def python_type(self) -> Type[Dict[str, str]]:
|
||||
"""Python type.
|
||||
|
||||
Returns:
|
||||
Python Type of the column.
|
||||
"""
|
||||
return Dict[str, str]
|
||||
|
||||
|
||||
class SQLABase(DeclarativeBase):
|
||||
"""Base class for bare SQLAlchemy models."""
|
||||
|
||||
pass
|
||||
type_annotation_map = {
|
||||
# do not use lower case dict here!
|
||||
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
|
||||
Dict[str, str]: SQLAType,
|
||||
}
|
||||
|
||||
|
||||
class SQLATag(SQLABase):
|
||||
@ -52,6 +72,9 @@ class SQLAClass(SQLABase):
|
||||
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
|
||||
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
|
||||
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
|
||||
# do not use lower case dict here!
|
||||
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
|
||||
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -82,6 +105,7 @@ class ModelClass(rx.Model):
|
||||
optional_int: Optional[int] = None
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -112,6 +136,7 @@ class BaseClass(rx.Base):
|
||||
optional_int: Optional[int] = None
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -142,6 +167,7 @@ class BareClass:
|
||||
optional_int: Optional[int] = None
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -173,6 +199,7 @@ class AttrClass:
|
||||
optional_int: Optional[int] = None
|
||||
sqla_tag: Optional[SQLATag] = None
|
||||
labels: List[SQLALabel] = []
|
||||
dict_str_str: Dict[str, str] = {}
|
||||
|
||||
@property
|
||||
def str_property(self) -> str:
|
||||
@ -193,7 +220,15 @@ class AttrClass:
|
||||
return self.name
|
||||
|
||||
|
||||
@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass])
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
SQLAClass,
|
||||
BaseClass,
|
||||
BareClass,
|
||||
ModelClass,
|
||||
AttrClass,
|
||||
]
|
||||
)
|
||||
def cls(request: pytest.FixtureRequest) -> type:
|
||||
"""Fixture for the class to test.
|
||||
|
||||
@ -216,6 +251,7 @@ def cls(request: pytest.FixtureRequest) -> type:
|
||||
pytest.param("optional_int", Optional[int], id="Optional[int]"),
|
||||
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
|
||||
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
|
||||
pytest.param("dict_str_str", Dict[str, str], id="Dict[str, str]"),
|
||||
pytest.param("str_property", str, id="str_property"),
|
||||
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
|
||||
],
|
||||
|
Loading…
Reference in New Issue
Block a user