From c2017b295e2d67da66debfdb7cfa67353bfeee00 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Sat, 27 Apr 2024 02:28:30 +0200 Subject: [PATCH] Improved get_attribute_access_type (#3156) --- reflex/utils/types.py | 55 ++++++++-- tests/test_attribute_access_type.py | 161 ++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+), 8 deletions(-) create mode 100644 tests/test_attribute_access_type.py diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 413a88dd8..32d3e406c 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -4,16 +4,19 @@ from __future__ import annotations import contextlib import inspect +import sys import types from functools import wraps from typing import ( TYPE_CHECKING, Any, Callable, + Dict, Iterable, List, Literal, Optional, + Tuple, Type, Union, _GenericAlias, # type: ignore @@ -37,11 +40,16 @@ except ModuleNotFoundError: from sqlalchemy.ext.associationproxy import AssociationProxyInstance from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + QueryableAttribute, + Relationship, +) from reflex import constants from reflex.base import Base -from reflex.utils import serializers +from reflex.utils import console, serializers # Potential GenericAlias types for isinstance checks. GenericAliasTypes = [_GenericAlias] @@ -76,6 +84,13 @@ StateIterVar = Union[list, set, tuple] ArgsSpec = Callable +PrimitiveToAnnotation = { + list: List, + tuple: Tuple, + dict: Dict, +} + + class Unset: """A class to represent an unset value. @@ -192,7 +207,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None elif isinstance(cls, type) and issubclass(cls, DeclarativeBase): insp = sqlalchemy.inspect(cls) if name in insp.columns: - return insp.columns[name].type.python_type + # check for list types + column = insp.columns[name] + column_type = column.type + type_ = insp.columns[name].type.python_type + if hasattr(column_type, "item_type") and ( + item_type := column_type.item_type.python_type # type: ignore + ): + if type_ in PrimitiveToAnnotation: + type_ = PrimitiveToAnnotation[type_] # type: ignore + type_ = type_[item_type] # type: ignore + if column.nullable: + type_ = Optional[type_] + return type_ if name not in insp.all_orm_descriptors: return None descriptor = insp.all_orm_descriptors[name] @@ -202,11 +229,10 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None prop = descriptor.property if not isinstance(prop, Relationship): return None - class_ = prop.mapper.class_ - if prop.uselist: - return List[class_] - else: - return class_ + type_ = prop.mapper.class_ + # TODO: check for nullable? + type_ = List[type_] if prop.uselist else Optional[type_] + return type_ if isinstance(attr, AssociationProxyInstance): return List[ get_attribute_access_type( @@ -232,6 +258,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None if type_ is not None: # Return the first attribute type that is accessible. return type_ + elif isinstance(cls, type): + # Bare class + if sys.version_info >= (3, 10): + exceptions = NameError + else: + exceptions = (NameError, TypeError) + try: + hints = get_type_hints(cls) + if name in hints: + return hints[name] + except exceptions as e: + console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}") + pass return None # Attribute is not accessible. diff --git a/tests/test_attribute_access_type.py b/tests/test_attribute_access_type.py new file mode 100644 index 000000000..fd17a2b37 --- /dev/null +++ b/tests/test_attribute_access_type.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import List, Optional + +import pytest +import sqlalchemy +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +import reflex as rx +from reflex.utils.types import GenericType, get_attribute_access_type + + +class SQLABase(DeclarativeBase): + """Base class for bare SQLAlchemy models.""" + + pass + + +class SQLATag(SQLABase): + """Tag sqlalchemy model.""" + + __tablename__: str = "tag" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + +class SQLALabel(SQLABase): + """Label sqlalchemy model.""" + + __tablename__: str = "label" + id: Mapped[int] = mapped_column(primary_key=True) + test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id")) + test: Mapped[SQLAClass] = relationship(back_populates="labels") + + +class SQLAClass(SQLABase): + """Test sqlalchemy model.""" + + __tablename__: str = "test" + id: Mapped[int] = mapped_column(primary_key=True) + count: Mapped[int] = mapped_column() + name: Mapped[str] = mapped_column() + int_list: Mapped[List[int]] = mapped_column( + sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER) + ) + str_list: Mapped[List[str]] = mapped_column( + sqlalchemy.types.ARRAY(item_type=sqlalchemy.String) + ) + optional_int: Mapped[Optional[int]] = mapped_column(nullable=True) + sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id)) + sqla_tag: Mapped[Optional[SQLATag]] = relationship() + labels: Mapped[List[SQLALabel]] = relationship(back_populates="test") + + @property + def str_property(self) -> str: + """String property. + + Returns: + Name attribute + """ + return self.name + + +class ModelClass(rx.Model): + """Test reflex model.""" + + count: int = 0 + name: str = "test" + int_list: List[int] = [] + str_list: List[str] = [] + optional_int: Optional[int] = None + sqla_tag: Optional[SQLATag] = None + labels: List[SQLALabel] = [] + + @property + def str_property(self) -> str: + """String property. + + Returns: + Name attribute + """ + return self.name + + +class BaseClass(rx.Base): + """Test rx.Base class.""" + + count: int = 0 + name: str = "test" + int_list: List[int] = [] + str_list: List[str] = [] + optional_int: Optional[int] = None + sqla_tag: Optional[SQLATag] = None + labels: List[SQLALabel] = [] + + @property + def str_property(self) -> str: + """String property. + + Returns: + Name attribute + """ + return self.name + + +class BareClass: + """Bare python class.""" + + count: int = 0 + name: str = "test" + int_list: List[int] = [] + str_list: List[str] = [] + optional_int: Optional[int] = None + sqla_tag: Optional[SQLATag] = None + labels: List[SQLALabel] = [] + + @property + def str_property(self) -> str: + """String property. + + Returns: + Name attribute + """ + return self.name + + +@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass]) +def cls(request: pytest.FixtureRequest) -> type: + """Fixture for the class to test. + + Args: + request: pytest request object. + + Returns: + Class to test. + """ + return request.param + + +@pytest.mark.parametrize( + "attr, expected", + [ + pytest.param("count", int, id="int"), + pytest.param("name", str, id="str"), + pytest.param("int_list", List[int], id="List[int]"), + pytest.param("str_list", List[str], id="List[str]"), + pytest.param("optional_int", Optional[int], id="Optional[int]"), + pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"), + pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"), + pytest.param("str_property", str, id="str_property"), + ], +) +def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None: + """Test get_attribute_access_type returns the correct type. + + Args: + cls: Class to test. + attr: Attribute to test. + expected: Expected type. + """ + assert get_attribute_access_type(cls, attr) == expected