Improved get_attribute_access_type (#3156)

This commit is contained in:
benedikt-bartscher 2024-04-27 02:28:30 +02:00 committed by GitHub
parent 74eaab5e19
commit c2017b295e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 208 additions and 8 deletions

View File

@ -4,16 +4,19 @@ from __future__ import annotations
import contextlib
import inspect
import sys
import types
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
_GenericAlias, # type: ignore
@ -37,11 +40,16 @@ except ModuleNotFoundError:
from sqlalchemy.ext.associationproxy import AssociationProxyInstance
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
QueryableAttribute,
Relationship,
)
from reflex import constants
from reflex.base import Base
from reflex.utils import serializers
from reflex.utils import console, serializers
# Potential GenericAlias types for isinstance checks.
GenericAliasTypes = [_GenericAlias]
@ -76,6 +84,13 @@ StateIterVar = Union[list, set, tuple]
ArgsSpec = Callable
PrimitiveToAnnotation = {
list: List,
tuple: Tuple,
dict: Dict,
}
class Unset:
"""A class to represent an unset value.
@ -192,7 +207,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
insp = sqlalchemy.inspect(cls)
if name in insp.columns:
return insp.columns[name].type.python_type
# check for list types
column = insp.columns[name]
column_type = column.type
type_ = insp.columns[name].type.python_type
if hasattr(column_type, "item_type") and (
item_type := column_type.item_type.python_type # type: ignore
):
if type_ in PrimitiveToAnnotation:
type_ = PrimitiveToAnnotation[type_] # type: ignore
type_ = type_[item_type] # type: ignore
if column.nullable:
type_ = Optional[type_]
return type_
if name not in insp.all_orm_descriptors:
return None
descriptor = insp.all_orm_descriptors[name]
@ -202,11 +229,10 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
prop = descriptor.property
if not isinstance(prop, Relationship):
return None
class_ = prop.mapper.class_
if prop.uselist:
return List[class_]
else:
return class_
type_ = prop.mapper.class_
# TODO: check for nullable?
type_ = List[type_] if prop.uselist else Optional[type_]
return type_
if isinstance(attr, AssociationProxyInstance):
return List[
get_attribute_access_type(
@ -232,6 +258,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
if type_ is not None:
# Return the first attribute type that is accessible.
return type_
elif isinstance(cls, type):
# Bare class
if sys.version_info >= (3, 10):
exceptions = NameError
else:
exceptions = (NameError, TypeError)
try:
hints = get_type_hints(cls)
if name in hints:
return hints[name]
except exceptions as e:
console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}")
pass
return None # Attribute is not accessible.

View File

@ -0,0 +1,161 @@
from __future__ import annotations
from typing import List, Optional
import pytest
import sqlalchemy
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type
class SQLABase(DeclarativeBase):
"""Base class for bare SQLAlchemy models."""
pass
class SQLATag(SQLABase):
"""Tag sqlalchemy model."""
__tablename__: str = "tag"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column()
class SQLALabel(SQLABase):
"""Label sqlalchemy model."""
__tablename__: str = "label"
id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
test: Mapped[SQLAClass] = relationship(back_populates="labels")
class SQLAClass(SQLABase):
"""Test sqlalchemy model."""
__tablename__: str = "test"
id: Mapped[int] = mapped_column(primary_key=True)
count: Mapped[int] = mapped_column()
name: Mapped[str] = mapped_column()
int_list: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
)
str_list: Mapped[List[str]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
)
optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name
class ModelClass(rx.Model):
"""Test reflex model."""
count: int = 0
name: str = "test"
int_list: List[int] = []
str_list: List[str] = []
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name
class BaseClass(rx.Base):
"""Test rx.Base class."""
count: int = 0
name: str = "test"
int_list: List[int] = []
str_list: List[str] = []
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name
class BareClass:
"""Bare python class."""
count: int = 0
name: str = "test"
int_list: List[int] = []
str_list: List[str] = []
optional_int: Optional[int] = None
sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = []
@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name
@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass])
def cls(request: pytest.FixtureRequest) -> type:
"""Fixture for the class to test.
Args:
request: pytest request object.
Returns:
Class to test.
"""
return request.param
@pytest.mark.parametrize(
"attr, expected",
[
pytest.param("count", int, id="int"),
pytest.param("name", str, id="str"),
pytest.param("int_list", List[int], id="List[int]"),
pytest.param("str_list", List[str], id="List[str]"),
pytest.param("optional_int", Optional[int], id="Optional[int]"),
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
pytest.param("str_property", str, id="str_property"),
],
)
def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None:
"""Test get_attribute_access_type returns the correct type.
Args:
cls: Class to test.
attr: Attribute to test.
expected: Expected type.
"""
assert get_attribute_access_type(cls, attr) == expected