fix tests

This commit is contained in:
Khaleel Al-Adhami 2025-02-18 14:16:14 -08:00
parent d0940b9cef
commit 5396f80604

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import typing
from typing import Optional, Type, Union
import attrs
@ -352,8 +353,8 @@ class AttrClass:
[
pytest.param("count", int, id="int"),
pytest.param("name", str, id="str"),
pytest.param("int_list", list[int], id="list[int]"),
pytest.param("str_list", list[str], id="list[str]"),
pytest.param("int_list", (list[int], typing.List[int]), id="list[int]"),
pytest.param("str_list", (list[str], typing.List[str]), id="list[str]"),
pytest.param("optional_int", Optional[int], id="Optional[int]"),
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
pytest.param("labels", list[SQLALabel], id="list[SQLALabel]"),
@ -371,25 +372,31 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
attr: Attribute to test.
expected: Expected type.
"""
assert get_attribute_access_type(cls, attr) == expected
if isinstance(expected, tuple):
assert get_attribute_access_type(cls, attr) in expected
else:
assert get_attribute_access_type(cls, attr) == expected
@pytest.mark.parametrize(
"cls",
("cls", "expected"),
[
SQLAClassDataclass,
BaseClass,
ModelClass,
AttrClass,
(SQLAClassDataclass, typing.List[int]),
(BaseClass, list[int]),
(ModelClass, list[int]),
(AttrClass, list[int]),
],
)
def test_get_attribute_access_type_default_factory(cls: type) -> None:
def test_get_attribute_access_type_default_factory(
cls: type, expected: GenericType
) -> None:
"""Test get_attribute_access_type returns the correct type for default factory fields.
Args:
cls: Class to test.
expected: Expected type.
"""
assert get_attribute_access_type(cls, "default_factory") == list[int]
assert get_attribute_access_type(cls, "default_factory") == expected
@pytest.mark.parametrize(