diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 7138dafb1..fb26b14c9 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -331,7 +331,11 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None type_ = field.outer_type_ if isinstance(type_, ModelField): type_ = type_.type_ - if not field.required and field.default is None: + if ( + not field.required + and field.default is None + and field.default_factory is None + ): # Ensure frontend uses null coalescing when accessing. type_ = Optional[type_] return type_ diff --git a/tests/units/test_attribute_access_type.py b/tests/units/test_attribute_access_type.py index 0d490ec1e..d08c17c8c 100644 --- a/tests/units/test_attribute_access_type.py +++ b/tests/units/test_attribute_access_type.py @@ -3,11 +3,19 @@ from __future__ import annotations from typing import Dict, List, Optional, Type, Union import attrs +import pydantic.v1 import pytest import sqlalchemy +import sqlmodel from sqlalchemy import JSON, TypeDecorator from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + MappedAsDataclass, + mapped_column, + relationship, +) import reflex as rx from reflex.utils.types import GenericType, get_attribute_access_type @@ -53,6 +61,10 @@ class SQLALabel(SQLABase): id: Mapped[int] = mapped_column(primary_key=True) test_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey("test.id")) test: Mapped[SQLAClass] = relationship(back_populates="labels") + test_dataclass_id: Mapped[int] = mapped_column( + sqlalchemy.ForeignKey("test_dataclass.id") + ) + test_dataclass: Mapped[SQLAClassDataclass] = relationship(back_populates="labels") class SQLAClass(SQLABase): @@ -104,9 +116,64 @@ class SQLAClass(SQLABase): return self.labels[0] if self.labels else None +class SQLAClassDataclass(MappedAsDataclass, SQLABase): + """Test sqlalchemy model.""" + + id: Mapped[int] = mapped_column(primary_key=True) + no_default: Mapped[int] = mapped_column(nullable=True) + count: Mapped[int] = mapped_column() + name: Mapped[str] = mapped_column() + int_list: Mapped[List[int]] = mapped_column( + sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER) + ) + str_list: Mapped[List[str]] = mapped_column( + sqlalchemy.types.ARRAY(item_type=sqlalchemy.String) + ) + optional_int: Mapped[Optional[int]] = mapped_column(nullable=True) + sqla_tag_id: Mapped[int] = mapped_column(sqlalchemy.ForeignKey(SQLATag.id)) + sqla_tag: Mapped[Optional[SQLATag]] = relationship() + labels: Mapped[List[SQLALabel]] = relationship(back_populates="test_dataclass") + # do not use lower case dict here! + # https://github.com/sqlalchemy/sqlalchemy/issues/9902 + dict_str_str: Mapped[Dict[str, str]] = mapped_column() + default_factory: Mapped[List[int]] = mapped_column( + sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER), + default_factory=list, + ) + __tablename__: str = "test_dataclass" + + @property + def str_property(self) -> str: + """String property. + + Returns: + Name attribute + """ + return self.name + + @hybrid_property + def str_or_int_property(self) -> Union[str, int]: + """String or int property. + + Returns: + Name attribute + """ + return self.name + + @hybrid_property + def first_label(self) -> Optional[SQLALabel]: + """First label property. + + Returns: + First label + """ + return self.labels[0] if self.labels else None + + class ModelClass(rx.Model): """Test reflex model.""" + no_default: Optional[int] = sqlmodel.Field(nullable=True) count: int = 0 name: str = "test" int_list: List[int] = [] @@ -115,6 +182,7 @@ class ModelClass(rx.Model): sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] dict_str_str: Dict[str, str] = {} + default_factory: List[int] = sqlmodel.Field(default_factory=list) @property def str_property(self) -> str: @@ -147,6 +215,7 @@ class ModelClass(rx.Model): class BaseClass(rx.Base): """Test rx.Base class.""" + no_default: Optional[int] = pydantic.v1.Field(required=False) count: int = 0 name: str = "test" int_list: List[int] = [] @@ -155,6 +224,7 @@ class BaseClass(rx.Base): sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] dict_str_str: Dict[str, str] = {} + default_factory: List[int] = pydantic.v1.Field(default_factory=list) @property def str_property(self) -> str: @@ -236,6 +306,7 @@ class AttrClass: sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] dict_str_str: Dict[str, str] = {} + default_factory: List[int] = attrs.field(factory=list) @property def str_property(self) -> str: @@ -265,27 +336,17 @@ class AttrClass: return self.labels[0] if self.labels else None -@pytest.fixture( - params=[ +@pytest.mark.parametrize( + "cls", + [ SQLAClass, + SQLAClassDataclass, BaseClass, BareClass, ModelClass, AttrClass, - ] + ], ) -def cls(request: pytest.FixtureRequest) -> type: - """Fixture for the class to test. - - Args: - request: pytest request object. - - Returns: - Class to test. - """ - return request.param - - @pytest.mark.parametrize( "attr, expected", [ @@ -311,3 +372,38 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) expected: Expected type. """ assert get_attribute_access_type(cls, attr) == expected + + +@pytest.mark.parametrize( + "cls", + [ + SQLAClassDataclass, + BaseClass, + ModelClass, + AttrClass, + ], +) +def test_get_attribute_access_type_default_factory(cls: type) -> None: + """Test get_attribute_access_type returns the correct type for default factory fields. + + Args: + cls: Class to test. + """ + assert get_attribute_access_type(cls, "default_factory") == List[int] + + +@pytest.mark.parametrize( + "cls", + [ + SQLAClassDataclass, + BaseClass, + ModelClass, + ], +) +def test_get_attribute_access_type_no_default(cls: type) -> None: + """Test get_attribute_access_type returns the correct type for fields with no default which are not required. + + Args: + cls: Class to test. + """ + assert get_attribute_access_type(cls, "no_default") == Optional[int]