Improved get_attribute_access_type (#3156)
This commit is contained in:
parent
74eaab5e19
commit
c2017b295e
@ -4,16 +4,19 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import sys
|
||||||
import types
|
import types
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
_GenericAlias, # type: ignore
|
_GenericAlias, # type: ignore
|
||||||
@ -37,11 +40,16 @@ except ModuleNotFoundError:
|
|||||||
|
|
||||||
from sqlalchemy.ext.associationproxy import AssociationProxyInstance
|
from sqlalchemy.ext.associationproxy import AssociationProxyInstance
|
||||||
from sqlalchemy.ext.hybrid import hybrid_property
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, QueryableAttribute, Relationship
|
from sqlalchemy.orm import (
|
||||||
|
DeclarativeBase,
|
||||||
|
Mapped,
|
||||||
|
QueryableAttribute,
|
||||||
|
Relationship,
|
||||||
|
)
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import serializers
|
from reflex.utils import console, serializers
|
||||||
|
|
||||||
# Potential GenericAlias types for isinstance checks.
|
# Potential GenericAlias types for isinstance checks.
|
||||||
GenericAliasTypes = [_GenericAlias]
|
GenericAliasTypes = [_GenericAlias]
|
||||||
@ -76,6 +84,13 @@ StateIterVar = Union[list, set, tuple]
|
|||||||
ArgsSpec = Callable
|
ArgsSpec = Callable
|
||||||
|
|
||||||
|
|
||||||
|
PrimitiveToAnnotation = {
|
||||||
|
list: List,
|
||||||
|
tuple: Tuple,
|
||||||
|
dict: Dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Unset:
|
class Unset:
|
||||||
"""A class to represent an unset value.
|
"""A class to represent an unset value.
|
||||||
|
|
||||||
@ -192,7 +207,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
elif isinstance(cls, type) and issubclass(cls, DeclarativeBase):
|
||||||
insp = sqlalchemy.inspect(cls)
|
insp = sqlalchemy.inspect(cls)
|
||||||
if name in insp.columns:
|
if name in insp.columns:
|
||||||
return insp.columns[name].type.python_type
|
# check for list types
|
||||||
|
column = insp.columns[name]
|
||||||
|
column_type = column.type
|
||||||
|
type_ = insp.columns[name].type.python_type
|
||||||
|
if hasattr(column_type, "item_type") and (
|
||||||
|
item_type := column_type.item_type.python_type # type: ignore
|
||||||
|
):
|
||||||
|
if type_ in PrimitiveToAnnotation:
|
||||||
|
type_ = PrimitiveToAnnotation[type_] # type: ignore
|
||||||
|
type_ = type_[item_type] # type: ignore
|
||||||
|
if column.nullable:
|
||||||
|
type_ = Optional[type_]
|
||||||
|
return type_
|
||||||
if name not in insp.all_orm_descriptors:
|
if name not in insp.all_orm_descriptors:
|
||||||
return None
|
return None
|
||||||
descriptor = insp.all_orm_descriptors[name]
|
descriptor = insp.all_orm_descriptors[name]
|
||||||
@ -202,11 +229,10 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
prop = descriptor.property
|
prop = descriptor.property
|
||||||
if not isinstance(prop, Relationship):
|
if not isinstance(prop, Relationship):
|
||||||
return None
|
return None
|
||||||
class_ = prop.mapper.class_
|
type_ = prop.mapper.class_
|
||||||
if prop.uselist:
|
# TODO: check for nullable?
|
||||||
return List[class_]
|
type_ = List[type_] if prop.uselist else Optional[type_]
|
||||||
else:
|
return type_
|
||||||
return class_
|
|
||||||
if isinstance(attr, AssociationProxyInstance):
|
if isinstance(attr, AssociationProxyInstance):
|
||||||
return List[
|
return List[
|
||||||
get_attribute_access_type(
|
get_attribute_access_type(
|
||||||
@ -232,6 +258,19 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
|||||||
if type_ is not None:
|
if type_ is not None:
|
||||||
# Return the first attribute type that is accessible.
|
# Return the first attribute type that is accessible.
|
||||||
return type_
|
return type_
|
||||||
|
elif isinstance(cls, type):
|
||||||
|
# Bare class
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
exceptions = NameError
|
||||||
|
else:
|
||||||
|
exceptions = (NameError, TypeError)
|
||||||
|
try:
|
||||||
|
hints = get_type_hints(cls)
|
||||||
|
if name in hints:
|
||||||
|
return hints[name]
|
||||||
|
except exceptions as e:
|
||||||
|
console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}")
|
||||||
|
pass
|
||||||
return None # Attribute is not accessible.
|
return None # Attribute is not accessible.
|
||||||
|
|
||||||
|
|
||||||
|
161
tests/test_attribute_access_type.py
Normal file
161
tests/test_attribute_access_type.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
import reflex as rx
|
||||||
|
from reflex.utils.types import GenericType, get_attribute_access_type
|
||||||
|
|
||||||
|
|
||||||
|
class SQLABase(DeclarativeBase):
|
||||||
|
"""Base class for bare SQLAlchemy models."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SQLATag(SQLABase):
|
||||||
|
"""Tag sqlalchemy model."""
|
||||||
|
|
||||||
|
__tablename__: str = "tag"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column()
|
||||||
|
|
||||||
|
|
||||||
|
class SQLALabel(SQLABase):
|
||||||
|
"""Label sqlalchemy model."""
|
||||||
|
|
||||||
|
__tablename__: str = "label"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
|
||||||
|
test: Mapped[SQLAClass] = relationship(back_populates="labels")
|
||||||
|
|
||||||
|
|
||||||
|
class SQLAClass(SQLABase):
|
||||||
|
"""Test sqlalchemy model."""
|
||||||
|
|
||||||
|
__tablename__: str = "test"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
count: Mapped[int] = mapped_column()
|
||||||
|
name: Mapped[str] = mapped_column()
|
||||||
|
int_list: Mapped[List[int]] = mapped_column(
|
||||||
|
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
|
||||||
|
)
|
||||||
|
str_list: Mapped[List[str]] = mapped_column(
|
||||||
|
sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
|
||||||
|
)
|
||||||
|
optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
|
||||||
|
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
|
||||||
|
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
|
||||||
|
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def str_property(self) -> str:
|
||||||
|
"""String property.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Name attribute
|
||||||
|
"""
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class ModelClass(rx.Model):
|
||||||
|
"""Test reflex model."""
|
||||||
|
|
||||||
|
count: int = 0
|
||||||
|
name: str = "test"
|
||||||
|
int_list: List[int] = []
|
||||||
|
str_list: List[str] = []
|
||||||
|
optional_int: Optional[int] = None
|
||||||
|
sqla_tag: Optional[SQLATag] = None
|
||||||
|
labels: List[SQLALabel] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def str_property(self) -> str:
|
||||||
|
"""String property.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Name attribute
|
||||||
|
"""
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClass(rx.Base):
|
||||||
|
"""Test rx.Base class."""
|
||||||
|
|
||||||
|
count: int = 0
|
||||||
|
name: str = "test"
|
||||||
|
int_list: List[int] = []
|
||||||
|
str_list: List[str] = []
|
||||||
|
optional_int: Optional[int] = None
|
||||||
|
sqla_tag: Optional[SQLATag] = None
|
||||||
|
labels: List[SQLALabel] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def str_property(self) -> str:
|
||||||
|
"""String property.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Name attribute
|
||||||
|
"""
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
class BareClass:
|
||||||
|
"""Bare python class."""
|
||||||
|
|
||||||
|
count: int = 0
|
||||||
|
name: str = "test"
|
||||||
|
int_list: List[int] = []
|
||||||
|
str_list: List[str] = []
|
||||||
|
optional_int: Optional[int] = None
|
||||||
|
sqla_tag: Optional[SQLATag] = None
|
||||||
|
labels: List[SQLALabel] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def str_property(self) -> str:
|
||||||
|
"""String property.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Name attribute
|
||||||
|
"""
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[SQLAClass, BaseClass, BareClass, ModelClass])
|
||||||
|
def cls(request: pytest.FixtureRequest) -> type:
|
||||||
|
"""Fixture for the class to test.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: pytest request object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Class to test.
|
||||||
|
"""
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"attr, expected",
|
||||||
|
[
|
||||||
|
pytest.param("count", int, id="int"),
|
||||||
|
pytest.param("name", str, id="str"),
|
||||||
|
pytest.param("int_list", List[int], id="List[int]"),
|
||||||
|
pytest.param("str_list", List[str], id="List[str]"),
|
||||||
|
pytest.param("optional_int", Optional[int], id="Optional[int]"),
|
||||||
|
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
|
||||||
|
pytest.param("labels", List[SQLALabel], id="List[SQLALabel]"),
|
||||||
|
pytest.param("str_property", str, id="str_property"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None:
|
||||||
|
"""Test get_attribute_access_type returns the correct type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: Class to test.
|
||||||
|
attr: Attribute to test.
|
||||||
|
expected: Expected type.
|
||||||
|
"""
|
||||||
|
assert get_attribute_access_type(cls, attr) == expected
|
Loading…
Reference in New Issue
Block a user