fix: handle default_factory in get_attribute_access_type (#4517)

* fix: handle default_factory in get_attribute_access_type, add tests for sqla dataclasses

* only test classes which have default_factory + add test for no default
This commit is contained in:
benedikt-bartscher 2024-12-12 03:22:31 +01:00 committed by GitHub
parent 95eb663347
commit e4b5755568
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 117 additions and 17 deletions

View File

@ -331,7 +331,11 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
type_ = field.outer_type_ type_ = field.outer_type_
if isinstance(type_, ModelField): if isinstance(type_, ModelField):
type_ = type_.type_ type_ = type_.type_
if not field.required and field.default is None: if (
not field.required
and field.default is None
and field.default_factory is None
):
# Ensure frontend uses null coalescing when accessing. # Ensure frontend uses null coalescing when accessing.
type_ = Optional[type_] type_ = Optional[type_]
return type_ return type_

View File

@ -3,11 +3,19 @@ from __future__ import annotations
from typing import Dict, List, Optional, Type, Union from typing import Dict, List, Optional, Type, Union
import attrs import attrs
import pydantic.v1
import pytest import pytest
import sqlalchemy import sqlalchemy
import sqlmodel
from sqlalchemy import JSON, TypeDecorator from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
mapped_column,
relationship,
)
import reflex as rx import reflex as rx
from reflex.utils.types import GenericType, get_attribute_access_type from reflex.utils.types import GenericType, get_attribute_access_type
@ -53,6 +61,10 @@ class SQLALabel(SQLABase):
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id")) test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id"))
test: Mapped[SQLAClass] = relationship(back_populates="labels") test: Mapped[SQLAClass] = relationship(back_populates="labels")
test_dataclass_id: Mapped[int] = mapped_column(
sqlalchemy.ForeignKey("test_dataclass.id")
)
test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels")
class SQLAClass(SQLABase): class SQLAClass(SQLABase):
@ -104,9 +116,64 @@ class SQLAClass(SQLABase):
return self.labels[0] if self.labels else None return self.labels[0] if self.labels else None
class SQLAClassDataclass(MappedAsDataclass, SQLABase):
"""Test sqlalchemy model."""
id: Mapped[int] = mapped_column(primary_key=True)
no_default: Mapped[int] = mapped_column(nullable=True)
count: Mapped[int] = mapped_column()
name: Mapped[str] = mapped_column()
int_list: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER)
)
str_list: Mapped[List[str]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.String)
)
optional_int: Mapped[Optional[int]] = mapped_column(nullable=True)
sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id))
sqla_tag: Mapped[Optional[SQLATag]] = relationship()
labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass")
# do not use lower case dict here!
# https://github.com/sqlalchemy/sqlalchemy/issues/9902
dict_str_str: Mapped[Dict[str, str]] = mapped_column()
default_factory: Mapped[List[int]] = mapped_column(
sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER),
default_factory=list,
)
__tablename__: str = "test_dataclass"
@property
def str_property(self) -> str:
"""String property.
Returns:
Name attribute
"""
return self.name
@hybrid_property
def str_or_int_property(self) -> Union[str, int]:
"""String or int property.
Returns:
Name attribute
"""
return self.name
@hybrid_property
def first_label(self) -> Optional[SQLALabel]:
"""First label property.
Returns:
First label
"""
return self.labels[0] if self.labels else None
class ModelClass(rx.Model): class ModelClass(rx.Model):
"""Test reflex model.""" """Test reflex model."""
no_default: Optional[int] = sqlmodel.Field(nullable=True)
count: int = 0 count: int = 0
name: str = "test" name: str = "test"
int_list: List[int] = [] int_list: List[int] = []
@ -115,6 +182,7 @@ class ModelClass(rx.Model):
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {} dict_str_str: Dict[str, str] = {}
default_factory: List[int] = sqlmodel.Field(default_factory=list)
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -147,6 +215,7 @@ class ModelClass(rx.Model):
class BaseClass(rx.Base): class BaseClass(rx.Base):
"""Test rx.Base class.""" """Test rx.Base class."""
no_default: Optional[int] = pydantic.v1.Field(required=False)
count: int = 0 count: int = 0
name: str = "test" name: str = "test"
int_list: List[int] = [] int_list: List[int] = []
@ -155,6 +224,7 @@ class BaseClass(rx.Base):
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {} dict_str_str: Dict[str, str] = {}
default_factory: List[int] = pydantic.v1.Field(default_factory=list)
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -236,6 +306,7 @@ class AttrClass:
sqla_tag: Optional[SQLATag] = None sqla_tag: Optional[SQLATag] = None
labels: List[SQLALabel] = [] labels: List[SQLALabel] = []
dict_str_str: Dict[str, str] = {} dict_str_str: Dict[str, str] = {}
default_factory: List[int] = attrs.field(factory=list)
@property @property
def str_property(self) -> str: def str_property(self) -> str:
@ -265,27 +336,17 @@ class AttrClass:
return self.labels[0] if self.labels else None return self.labels[0] if self.labels else None
@pytest.fixture( @pytest.mark.parametrize(
params=[ "cls",
[
SQLAClass, SQLAClass,
SQLAClassDataclass,
BaseClass, BaseClass,
BareClass, BareClass,
ModelClass, ModelClass,
AttrClass, AttrClass,
] ],
) )
def cls(request: pytest.FixtureRequest) -> type:
"""Fixture for the class to test.
Args:
request: pytest request object.
Returns:
Class to test.
"""
return request.param
@pytest.mark.parametrize( @pytest.mark.parametrize(
"attr, expected", "attr, expected",
[ [
@ -311,3 +372,38 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
expected: Expected type. expected: Expected type.
""" """
assert get_attribute_access_type(cls, attr) == expected assert get_attribute_access_type(cls, attr) == expected
@pytest.mark.parametrize(
"cls",
[
SQLAClassDataclass,
BaseClass,
ModelClass,
AttrClass,
],
)
def test_get_attribute_access_type_default_factory(cls: type) -> None:
"""Test get_attribute_access_type returns the correct type for default factory fields.
Args:
cls: Class to test.
"""
assert get_attribute_access_type(cls, "default_factory") == List[int]
@pytest.mark.parametrize(
"cls",
[
SQLAClassDataclass,
BaseClass,
ModelClass,
],
)
def test_get_attribute_access_type_no_default(cls: type) -> None:
"""Test get_attribute_access_type returns the correct type for fields with no default which are not required.
Args:
cls: Class to test.
"""
assert get_attribute_access_type(cls, "no_default") == Optional[int]