From 2dd6b063f5f8e7a03425b874102785f3171faeae Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 11 Dec 2024 00:36:46 +0100 Subject: [PATCH] only test classes which have default_factory + add test for no default --- tests/units/test_attribute_access_type.py | 62 +++++++++++++++-------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/tests/units/test_attribute_access_type.py b/tests/units/test_attribute_access_type.py index c204e5551..d08c17c8c 100644 --- a/tests/units/test_attribute_access_type.py +++ b/tests/units/test_attribute_access_type.py @@ -87,9 +87,6 @@ class SQLAClass(SQLABase): # do not use lower case dict here! # https://github.com/sqlalchemy/sqlalchemy/issues/9902 dict_str_str: Mapped[Dict[str, str]] = mapped_column() - default_factory: Mapped[List[int]] = mapped_column( - sqlalchemy.types.ARRAY(item_type=sqlalchemy.INTEGER), - ) @property def str_property(self) -> str: @@ -123,6 +120,7 @@ class SQLAClassDataclass(MappedAsDataclass, SQLABase): """Test sqlalchemy model.""" id: Mapped[int] = mapped_column(primary_key=True) + no_default: Mapped[int] = mapped_column(nullable=True) count: Mapped[int] = mapped_column() name: Mapped[str] = mapped_column() int_list: Mapped[List[int]] = mapped_column( @@ -175,6 +173,7 @@ class SQLAClassDataclass(MappedAsDataclass, SQLABase): class ModelClass(rx.Model): """Test reflex model.""" + no_default: Optional[int] = sqlmodel.Field(nullable=True) count: int = 0 name: str = "test" int_list: List[int] = [] @@ -216,6 +215,7 @@ class ModelClass(rx.Model): class BaseClass(rx.Base): """Test rx.Base class.""" + no_default: Optional[int] = pydantic.v1.Field(required=False) count: int = 0 name: str = "test" int_list: List[int] = [] @@ -265,7 +265,6 @@ class BareClass: sqla_tag: Optional[SQLATag] = None labels: List[SQLALabel] = [] dict_str_str: Dict[str, str] = {} - default_factory: List[int] = [] @property def str_property(self) -> str: @@ -337,28 +336,17 @@ class AttrClass: return self.labels[0] if self.labels else None -@pytest.fixture( - params=[ +@pytest.mark.parametrize( + "cls", + [ SQLAClass, SQLAClassDataclass, BaseClass, BareClass, ModelClass, AttrClass, - ] + ], ) -def cls(request: pytest.FixtureRequest) -> type: - """Fixture for the class to test. - - Args: - request: pytest request object. - - Returns: - Class to test. - """ - return request.param - - @pytest.mark.parametrize( "attr, expected", [ @@ -373,7 +361,6 @@ def cls(request: pytest.FixtureRequest) -> type: pytest.param("str_property", str, id="str_property"), pytest.param("str_or_int_property", Union[str, int], id="str_or_int_property"), pytest.param("first_label", Optional[SQLALabel], id="first_label"), - pytest.param("default_factory", List[int], id="default_factory"), ], ) def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) -> None: @@ -385,3 +372,38 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) expected: Expected type. """ assert get_attribute_access_type(cls, attr) == expected + + +@pytest.mark.parametrize( + "cls", + [ + SQLAClassDataclass, + BaseClass, + ModelClass, + AttrClass, + ], +) +def test_get_attribute_access_type_default_factory(cls: type) -> None: + """Test get_attribute_access_type returns the correct type for default factory fields. + + Args: + cls: Class to test. + """ + assert get_attribute_access_type(cls, "default_factory") == List[int] + + +@pytest.mark.parametrize( + "cls", + [ + SQLAClassDataclass, + BaseClass, + ModelClass, + ], +) +def test_get_attribute_access_type_no_default(cls: type) -> None: + """Test get_attribute_access_type returns the correct type for fields with no default which are not required. + + Args: + cls: Class to test. + """ + assert get_attribute_access_type(cls, "no_default") == Optional[int]