diff --git a/reflex/utils/types.py b/reflex/utils/types.py index bd3584ab6..3bb5eae35 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -245,36 +245,41 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None # check for list types column = insp.columns[name] column_type = column.type - type_ = insp.columns[name].type.python_type - if hasattr(column_type, "item_type") and ( - item_type := column_type.item_type.python_type # type: ignore - ): - if type_ in PrimitiveToAnnotation: - type_ = PrimitiveToAnnotation[type_] # type: ignore - type_ = type_[item_type] # type: ignore - if column.nullable: - type_ = Optional[type_] - return type_ - if name not in insp.all_orm_descriptors: - return None - descriptor = insp.all_orm_descriptors[name] - if hint := get_property_hint(descriptor): - return hint - if isinstance(descriptor, QueryableAttribute): - prop = descriptor.property - if not isinstance(prop, Relationship): - return None - type_ = prop.mapper.class_ - # TODO: check for nullable? - type_ = List[type_] if prop.uselist else Optional[type_] - return type_ - if isinstance(attr, AssociationProxyInstance): - return List[ - get_attribute_access_type( - attr.target_class, - attr.remote_attr.key, # type: ignore[attr-defined] - ) - ] + try: + type_ = insp.columns[name].type.python_type + except NotImplementedError: + type_ = None + if type_ is not None: + if hasattr(column_type, "item_type"): + try: + item_type = column_type.item_type.python_type # type: ignore + except NotImplementedError: + item_type = None + if item_type is not None: + if type_ in PrimitiveToAnnotation: + type_ = PrimitiveToAnnotation[type_] # type: ignore + type_ = type_[item_type] # type: ignore + if column.nullable: + type_ = Optional[type_] + return type_ + if name in insp.all_orm_descriptors: + descriptor = insp.all_orm_descriptors[name] + if hint := get_property_hint(descriptor): + return hint + if isinstance(descriptor, QueryableAttribute): + prop = descriptor.property + if isinstance(prop, Relationship): + type_ = prop.mapper.class_ + # TODO: check for nullable? + type_ = List[type_] if prop.uselist else Optional[type_] + return type_ + if isinstance(attr, AssociationProxyInstance): + return List[ + get_attribute_access_type( + attr.target_class, + attr.remote_attr.key, # type: ignore[attr-defined] + ) + ] elif isinstance(cls, type) and not is_generic_alias(cls) and issubclass(cls, Model): # Check in the annotations directly (for sqlmodel.Relationship) hints = get_type_hints(cls) diff --git a/tests/test_attribute_access_type.py b/tests/test_attribute_access_type.py index 821ccad04..0813a5e62 100644 --- a/tests/test_attribute_access_type.py +++ b/tests/test_attribute_access_type.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import List, Optional, Union +from typing import Dict, List, Optional, Type, Union import attrs import pytest import sqlalchemy +from sqlalchemy import JSON, TypeDecorator from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship @@ -12,10 +13,29 @@ import reflex as rx from reflex.utils.types import GenericType, get_attribute_access_type +class SQLAType(TypeDecorator): + """SQLAlchemy custom dict type.""" + + impl = JSON + + @property + def python_type(self) -> Type[Dict[str, str]]: + """Python type. + + Returns: + Python Type of the column. + """ + return Dict[str, str] + + class SQLABase(DeclarativeBase): """Base class for bare SQLAlchemy models.""" - pass + type_annotation_map = { + # do not use lower case dict here! + # https://github.com/sqlalchemy/sqlalchemy/issues/9902 + Dict[str, str]: SQLAType, + } class SQLATag(SQLABase): @@ -52,6 +72,9 @@ class SQLAClass(SQLABase): sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id)) sqla_tag: Mapped[Optional[SQLATag]] = relationship() labels: Mapped[List[SQLALabel]] = relationship(back_populates="test") + # do not use lower case dict here! + # https://github.com/sqlalchemy/sqlalchemy/issues/9902 + dict_str_str: Mapped[Dict[str, str]] = mapped_column() @property def str_property(self) -> str: @@ -82,6 +105,7 @@ class ModelClass(rx.Model): optional_int: Optional[int] = None sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] + dict_str_str: Dict[str, str] = {} @property def str_property(self) -> str: @@ -112,6 +136,7 @@ class BaseClass(rx.Base): optional_int: Optional[int] = None sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] + dict_str_str: Dict[str, str] = {} @property def str_property(self) -> str: @@ -142,6 +167,7 @@ class BareClass: optional_int: Optional[int] = None sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] + dict_str_str: Dict[str, str] = {} @property def str_property(self) -> str: @@ -173,6 +199,7 @@ class AttrClass: optional_int: Optional[int] = None sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] + dict_str_str: Dict[str, str] = {} @property def str_property(self) -> str: @@ -193,7 +220,15 @@ class AttrClass: return self.name -@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass, AttrClass]) +@pytest.fixture( + params=[ + SQLAClass, + BaseClass, + BareClass, + ModelClass, + AttrClass, + ] +) def cls(request: pytest.FixtureRequest) -> type: """Fixture for the class to test. @@ -216,6 +251,7 @@ def cls(request: pytest.FixtureRequest) -> type: pytest.param("optional_int", Optional[int], id="Optional[int]"), pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"), pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"), + pytest.param("dict_str_str", Dict[str, str], id="Dict[str, str]"), pytest.param("str_property", str, id="str_property"), pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"), ],