fix sqla python_type issues, add tests (#3613)
This commit is contained in:
parent
09ff952d01
commit
6d3321284c
@ -245,36 +245,41 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
# check for list types
|
# check for list types
|
||||||
column = insp.columns[name]
|
column = insp.columns[name]
|
||||||
column_type = column.type
|
column_type = column.type
|
||||||
type_ = insp.columns[name].type.python_type
|
try:
|
||||||
if hasattr(column_type, "item_type") and (
|
type_ = insp.columns[name].type.python_type
|
||||||
item_type := column_type.item_type.python_type # type: ignore
|
except NotImplementedError:
|
||||||
):
|
type_ = None
|
||||||
if type_ in PrimitiveToAnnotation:
|
if type_ is not None:
|
||||||
type_ = PrimitiveToAnnotation[type_] # type: ignore
|
if hasattr(column_type, "item_type"):
|
||||||
type_ = type_[item_type] # type: ignore
|
try:
|
||||||
if column.nullable:
|
item_type = column_type.item_type.python_type # type: ignore
|
||||||
type_ = Optional[type_]
|
except NotImplementedError:
|
||||||
return type_
|
item_type = None
|
||||||
if name not in insp.all_orm_descriptors:
|
if item_type is not None:
|
||||||
return None
|
if type_ in PrimitiveToAnnotation:
|
||||||
descriptor = insp.all_orm_descriptors[name]
|
type_ = PrimitiveToAnnotation[type_] # type: ignore
|
||||||
if hint := get_property_hint(descriptor):
|
type_ = type_[item_type] # type: ignore
|
||||||
return hint
|
if column.nullable:
|
||||||
if isinstance(descriptor, QueryableAttribute):
|
type_ = Optional[type_]
|
||||||
prop = descriptor.property
|
return type_
|
||||||
if not isinstance(prop, Relationship):
|
if name in insp.all_orm_descriptors:
|
||||||
return None
|
descriptor = insp.all_orm_descriptors[name]
|
||||||
type_ = prop.mapper.class_
|
if hint := get_property_hint(descriptor):
|
||||||
# TODO: check for nullable?
|
return hint
|
||||||
type_ = List[type_] if prop.uselist else Optional[type_]
|
if isinstance(descriptor, QueryableAttribute):
|
||||||
return type_
|
prop = descriptor.property
|
||||||
if isinstance(attr, AssociationProxyInstance):
|
if isinstance(prop, Relationship):
|
||||||
return List[
|
type_ = prop.mapper.class_
|
||||||
get_attribute_access_type(
|
# TODO: check for nullable?
|
||||||
attr.target_class,
|
type_ = List[type_] if prop.uselist else Optional[type_]
|
||||||
attr.remote_attr.key, # type: ignore[attr-defined]
|
return type_
|
||||||
)
|
if isinstance(attr, AssociationProxyInstance):
|
||||||
]
|
return List[
|
||||||
|
get_attribute_access_type(
|
||||||
|
attr.target_class,
|
||||||
|
attr.remote_attr.key, # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
]
|
||||||
elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
|
elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model):
|
||||||
# Check in the annotations directly (for sqlmodel.Relationship)
|
# Check in the annotations directly (for sqlmodel.Relationship)
|
||||||
hints = get_type_hints(cls)
|
hints = get_type_hints(cls)
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
import attrs
|
import attrs
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
from sqlalchemy import JSON, TypeDecorator
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@ -12,10 +13,29 @@ import reflex as rx
|
|||||||
from reflex.utils.types import GenericType, get_attribute_access_type
|
from reflex.utils.types import GenericType, get_attribute_access_type
|
||||||
|
|
||||||
|
|
||||||
|
class SQLAType(TypeDecorator):
|
||||||
|
"""SQLAlchemy custom dict type."""
|
||||||
|
|
||||||
|
impl = JSON
|
||||||
|
|
||||||
|
@property
|
||||||
|
def python_type(self) -> Type[Dict[str, str]]:
|
||||||
|
"""Python type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Python Type of the column.
|
||||||
|
"""
|
||||||
|
return Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
class SQLABase(DeclarativeBase):
|
class SQLABase(DeclarativeBase):
|
||||||
"""Base class for bare SQLAlchemy models."""
|
"""Base class for bare SQLAlchemy models."""
|
||||||
|
|
||||||
pass
|
type_annotation_map = {
|
||||||
|
# do not use lower case dict here!
|
||||||
|
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
|
||||||
|
Dict[str, str]: SQLAType,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class SQLATag(SQLABase):
|
class SQLATag(SQLABase):
|
||||||
@ -52,6 +72,9 @@ class SQLAClass(SQLABase):
|
|||||||
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
|
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
|
||||||
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
|
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
|
||||||
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
|
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
|
||||||
|
# do not use lower case dict here!
|
||||||
|
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
|
||||||
|
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_property(self) -> str:
|
def str_property(self) -> str:
|
||||||
@ -82,6 +105,7 @@ class ModelClass(rx.Model):
|
|||||||
optional_int: Optional[int] = None
|
optional_int: Optional[int] = None
|
||||||
sqla_tag: Optional[SQLATag] = None
|
sqla_tag: Optional[SQLATag] = None
|
||||||
labels: List[SQLALabel] = []
|
labels: List[SQLALabel] = []
|
||||||
|
dict_str_str: Dict[str, str] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_property(self) -> str:
|
def str_property(self) -> str:
|
||||||
@ -112,6 +136,7 @@ class BaseClass(rx.Base):
|
|||||||
optional_int: Optional[int] = None
|
optional_int: Optional[int] = None
|
||||||
sqla_tag: Optional[SQLATag] = None
|
sqla_tag: Optional[SQLATag] = None
|
||||||
labels: List[SQLALabel] = []
|
labels: List[SQLALabel] = []
|
||||||
|
dict_str_str: Dict[str, str] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_property(self) -> str:
|
def str_property(self) -> str:
|
||||||
@ -142,6 +167,7 @@ class BareClass:
|
|||||||
optional_int: Optional[int] = None
|
optional_int: Optional[int] = None
|
||||||
sqla_tag: Optional[SQLATag] = None
|
sqla_tag: Optional[SQLATag] = None
|
||||||
labels: List[SQLALabel] = []
|
labels: List[SQLALabel] = []
|
||||||
|
dict_str_str: Dict[str, str] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_property(self) -> str:
|
def str_property(self) -> str:
|
||||||
@ -173,6 +199,7 @@ class AttrClass:
|
|||||||
optional_int: Optional[int] = None
|
optional_int: Optional[int] = None
|
||||||
sqla_tag: Optional[SQLATag] = None
|
sqla_tag: Optional[SQLATag] = None
|
||||||
labels: List[SQLALabel] = []
|
labels: List[SQLALabel] = []
|
||||||
|
dict_str_str: Dict[str, str] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def str_property(self) -> str:
|
def str_property(self) -> str:
|
||||||
@ -193,7 +220,15 @@ class AttrClass:
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass])
|
@pytest.fixture(
|
||||||
|
params=[
|
||||||
|
SQLAClass,
|
||||||
|
BaseClass,
|
||||||
|
BareClass,
|
||||||
|
ModelClass,
|
||||||
|
AttrClass,
|
||||||
|
]
|
||||||
|
)
|
||||||
def cls(request: pytest.FixtureRequest) -> type:
|
def cls(request: pytest.FixtureRequest) -> type:
|
||||||
"""Fixture for the class to test.
|
"""Fixture for the class to test.
|
||||||
|
|
||||||
@ -216,6 +251,7 @@ def cls(request: pytest.FixtureRequest) -> type:
|
|||||||
pytest.param("optional_int", Optional[int], id="Optional[int]"),
|
pytest.param("optional_int", Optional[int], id="Optional[int]"),
|
||||||
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
|
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
|
||||||
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
|
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
|
||||||
|
pytest.param("dict_str_str", Dict[str, str], id="Dict[str, str]"),
|
||||||
pytest.param("str_property", str, id="str_property"),
|
pytest.param("str_property", str, id="str_property"),
|
||||||
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
|
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
|
||||||
],
|
],
|
||||||
|
Loading…
Reference in New Issue
Block a user