fix: handle default_factory in get_attribute_access_type, add tests for sqla dataclasses

This commit is contained in:
Benedikt Bartscher 2024-12-10 23:56:20 +01:00
parent fd0fd2c6d4
commit 930c1e2129
No known key found for this signature in database
2 changed files with 80 additions and 2 deletions

View File

@ -331,7 +331,11 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
type_ = field.outer_type_ type_ = field.outer_type_
if isinstance(type_, ModelField): if isinstance(type_, ModelField):
type_ = type_.type_ type_ = type_.type_
if not field.required and field.default is None: if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing. # Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_] type_ = Optional[type_]
return type_ return type_

View File

@ -3,11 +3,19 @@ from __future__ import annotations
from typing import Dict, List, Optional, Type, Union from typing import Dict, List, Optional, Type, Union
import attrs import attrs
import pydantic.v1
import pytest import pytest
import sqlalchemy import sqlalchemy
import sqlmodel
from sqlalchemy import JSON, TypeDecorator from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
mapped_column,
relationship,
)
import reflex as rx import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type from reflex.utils.types import GenericType, get_attribute_access_type
@ -53,6 +61,10 @@ class SQLALabel(SQLABase):
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id")) test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
test: Mapped[SQLAClass] = relationship(back_populates="labels") test: Mapped[SQLAClass] = relationship(back_populates="labels")
test_dataclass_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey("test_dataclass.id")
)
test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels")
class SQLAClass(SQLABase): class SQLAClass(SQLABase):
@ -75,6 +87,62 @@ class SQLAClass(SQLABase):
# do not use lower case dict here! # do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902 # https://github.com/sqlalchemy/sqlalchemy/issues/9902
dict_str_str: Mapped[Dict[str, str]] = mapped_column() dict_str_str: Mapped[Dict[str, str]] = mapped_column()
default_factory: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
)
@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name
@hybrid_property
def str_or_int_property(self) -> Union[str, int]:
"""String or int property.
Returns:
Name attribute
"""
return self.name
@hybrid_property
def first_label(self) -> Optional[SQLALabel]:
"""First label property.
Returns:
First label
"""
return self.labels[0] if self.labels else None
class SQLAClassDataclass(MappedAsDataclass, SQLABase):
"""Test sqlalchemy model."""
id: Mapped[int] = mapped_column(primary_key=True)
count: Mapped[int] = mapped_column()
name: Mapped[str] = mapped_column()
int_list: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
)
str_list: Mapped[List[str]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
)
optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass")
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
default_factory: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
default_factory=list,
)
__tablename__: str = "test_dataclass"
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -115,6 +183,7 @@ class ModelClass(rx.Model):
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {} dict_str_str: Dict[str, str] = {}
default_factory: List[int] = sqlmodel.Field(default_factory=list)
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -155,6 +224,7 @@ class BaseClass(rx.Base):
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {} dict_str_str: Dict[str, str] = {}
default_factory: List[int] = pydantic.v1.Field(default_factory=list)
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -195,6 +265,7 @@ class BareClass:
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {} dict_str_str: Dict[str, str] = {}
default_factory: List[int] = []
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -236,6 +307,7 @@ class AttrClass:
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {} dict_str_str: Dict[str, str] = {}
default_factory: List[int] = attrs.field(factory=list)
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -268,6 +340,7 @@ class AttrClass:
@pytest.fixture( @pytest.fixture(
params=[ params=[
SQLAClass, SQLAClass,
SQLAClassDataclass,
BaseClass, BaseClass,
BareClass, BareClass,
ModelClass, ModelClass,
@ -300,6 +373,7 @@ def cls(request: pytest.FixtureRequest) -> type:
pytest.param("str_property", str, id="str_property"), pytest.param("str_property", str, id="str_property"),
pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"), pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"),
pytest.param("first_label", Optional[SQLALabel], id="first_label"), pytest.param("first_label", Optional[SQLALabel], id="first_label"),
pytest.param("default_factory", List[int], id="default_factory"),
], ],
) )
def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None: def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None: