fix tests
This commit is contained in:
parent
d0940b9cef
commit
5396f80604
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import attrs
|
||||
@ -352,8 +353,8 @@ class AttrClass:
|
||||
[
|
||||
pytest.param("count", int, id="int"),
|
||||
pytest.param("name", str, id="str"),
|
||||
pytest.param("int_list", list[int], id="list[int]"),
|
||||
pytest.param("str_list", list[str], id="list[str]"),
|
||||
pytest.param("int_list", (list[int], typing.List[int]), id="list[int]"),
|
||||
pytest.param("str_list", (list[str], typing.List[str]), id="list[str]"),
|
||||
pytest.param("optional_int", Optional[int], id="Optional[int]"),
|
||||
pytest.param("sqla_tag", Optional[SQLATag], id="Optional[SQLATag]"),
|
||||
pytest.param("labels", list[SQLALabel], id="list[SQLALabel]"),
|
||||
@ -371,25 +372,31 @@ def test_get_attribute_access_type(cls: type, attr: str, expected: GenericType)
|
||||
attr: Attribute to test.
|
||||
expected: Expected type.
|
||||
"""
|
||||
assert get_attribute_access_type(cls, attr) == expected
|
||||
if isinstance(expected, tuple):
|
||||
assert get_attribute_access_type(cls, attr) in expected
|
||||
else:
|
||||
assert get_attribute_access_type(cls, attr) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cls",
|
||||
("cls", "expected"),
|
||||
[
|
||||
SQLAClassDataclass,
|
||||
BaseClass,
|
||||
ModelClass,
|
||||
AttrClass,
|
||||
(SQLAClassDataclass, typing.List[int]),
|
||||
(BaseClass, list[int]),
|
||||
(ModelClass, list[int]),
|
||||
(AttrClass, list[int]),
|
||||
],
|
||||
)
|
||||
def test_get_attribute_access_type_default_factory(cls: type) -> None:
|
||||
def test_get_attribute_access_type_default_factory(
|
||||
cls: type, expected: GenericType
|
||||
) -> None:
|
||||
"""Test get_attribute_access_type returns the correct type for default factory fields.
|
||||
|
||||
Args:
|
||||
cls: Class to test.
|
||||
expected: Expected type.
|
||||
"""
|
||||
assert get_attribute_access_type(cls, "default_factory") == list[int]
|
||||
assert get_attribute_access_type(cls, "default_factory") == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
Loading…
Reference in New Issue
Block a user