diff --git a/tests/units/test_attribute_access_type.py b/tests/units/test_attribute_access_type.py index d1f134881..2d5d87c36 100644 --- a/tests/units/test_attribute_access_type.py +++ b/tests/units/test_attribute_access_type.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from typing import Optional, Type, Union import attrs @@ -352,8 +353,8 @@ class AttrClass: [ pytest.param("count", int, id="int"), pytest.param("name", str, id="str"), - pytest.param("int_list", list[int], id="list[int]"), - pytest.param("str_list", list[str], id="list[str]"), + pytest.param("int_list", (list[int], typing.List[int]), id="list[int]"), + pytest.param("str_list", (list[str], typing.List[str]), id="list[str]"), pytest.param("optional_int", Optional[int], id="Optional[int]"), pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"), pytest.param("labels", list[SQLALabel], id="list[SQLALabel]"), @@ -371,25 +372,31 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType) attr: Attribute to test. expected: Expected type. """ - assert get_attribute_access_type(cls, attr) == expected + if isinstance(expected, tuple): + assert get_attribute_access_type(cls, attr) in expected + else: + assert get_attribute_access_type(cls, attr) == expected @pytest.mark.parametrize( - "cls", + ("cls", "expected"), [ - SQLAClassDataclass, - BaseClass, - ModelClass, - AttrClass, + (SQLAClassDataclass, typing.List[int]), + (BaseClass, list[int]), + (ModelClass, list[int]), + (AttrClass, list[int]), ], ) -def test_get_attribute_access_type_default_factory(cls: type) -> None: +def test_get_attribute_access_type_default_factory( + cls: type, expected: GenericType +) -> None: """Test get_attribute_access_type returns the correct type for default factory fields. Args: cls: Class to test. + expected: Expected type. """ - assert get_attribute_access_type(cls, "default_factory") == list[int] + assert get_attribute_access_type(cls, "default_factory") == expected @pytest.mark.parametrize(