fix tests
This commit is contained in:
parent
d0940b9cef
commit
5396f80604
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import typing
|
||||||
from typing import Optional, Type, Union
|
from typing import Optional, Type, Union
|
||||||
|
|
||||||
import attrs
|
import attrs
|
||||||
@ -352,8 +353,8 @@ class AttrClass:
|
|||||||
[
|
[
|
||||||
pytest.param("count", int, id="int"),
|
pytest.param("count", int, id="int"),
|
||||||
pytest.param("name", str, id="str"),
|
pytest.param("name", str, id="str"),
|
||||||
pytest.param("int_list", list[int], id="list[int]"),
|
pytest.param("int_list", (list[int], typing.List[int]), id="list[int]"),
|
||||||
pytest.param("str_list", list[str], id="list[str]"),
|
pytest.param("str_list", (list[str], typing.List[str]), id="list[str]"),
|
||||||
pytest.param("optional_int", Optional[int], id="Optional[int]"),
|
pytest.param("optional_int", Optional[int], id="Optional[int]"),
|
||||||
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
|
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
|
||||||
pytest.param("labels", list[SQLALabel], id="list[SQLALabel]"),
|
pytest.param("labels", list[SQLALabel], id="list[SQLALabel]"),
|
||||||
@ -371,25 +372,31 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
|
|||||||
attr: Attribute to test.
|
attr: Attribute to test.
|
||||||
expected: Expected type.
|
expected: Expected type.
|
||||||
"""
|
"""
|
||||||
assert get_attribute_access_type(cls, attr) == expected
|
if isinstance(expected, tuple):
|
||||||
|
assert get_attribute_access_type(cls, attr) in expected
|
||||||
|
else:
|
||||||
|
assert get_attribute_access_type(cls, attr) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"cls",
|
("cls", "expected"),
|
||||||
[
|
[
|
||||||
SQLAClassDataclass,
|
(SQLAClassDataclass, typing.List[int]),
|
||||||
BaseClass,
|
(BaseClass, list[int]),
|
||||||
ModelClass,
|
(ModelClass, list[int]),
|
||||||
AttrClass,
|
(AttrClass, list[int]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_get_attribute_access_type_default_factory(cls: type) -> None:
|
def test_get_attribute_access_type_default_factory(
|
||||||
|
cls: type, expected: GenericType
|
||||||
|
) -> None:
|
||||||
"""Test get_attribute_access_type returns the correct type for default factory fields.
|
"""Test get_attribute_access_type returns the correct type for default factory fields.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cls: Class to test.
|
cls: Class to test.
|
||||||
|
expected: Expected type.
|
||||||
"""
|
"""
|
||||||
assert get_attribute_access_type(cls, "default_factory") == list[int]
|
assert get_attribute_access_type(cls, "default_factory") == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
Reference in New Issue
Block a user