From cec4e3c258a796a8c3064bc1bf36169d33b567bf Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Mon, 28 Oct 2024 12:10:02 -0700 Subject: [PATCH] add a test for type hint is subclass --- tests/units/utils/test_utils.py | 39 ++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 81579acc7..3c78bb83a 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -2,7 +2,7 @@ import os import typing from functools import cached_property from pathlib import Path -from typing import Any, ClassVar, List, Literal, Type, Union +from typing import Any, ClassVar, Dict, List, Literal, Type, Union import pytest import typer @@ -77,6 +77,43 @@ def test_is_generic_alias(cls: type, expected: bool): assert types.is_generic_alias(cls) == expected +@pytest.mark.parametrize( + ("subclass", "superclass", "expected"), + [ + *[ + (base_type, base_type, True) + for base_type in [int, float, str, bool, list, dict] + ], + *[ + (one_type, another_type, False) + for one_type in [int, float, str, list, dict] + for another_type in [int, float, str, list, dict] + if one_type != another_type + ], + (bool, int, True), + (int, bool, False), + (list, List, True), + (list, List[str], True), # this is wrong, but it's a limitation of the function + (List, list, True), + (List[int], list, True), + (List[int], List, True), + (List[int], List[str], False), + (List[int], List[int], True), + (List[int], List[float], False), + (List[int], List[Union[int, float]], True), + (List[int], List[Union[float, str]], False), + (Union[int, float], List[Union[int, float]], False), + (Union[int, float], Union[int, float, str], True), + (Union[int, float], Union[str, float], False), + (Dict[str, int], Dict[str, int], True), + (Dict[str, bool], Dict[str, int], True), + (Dict[str, int], Dict[str, bool], False), + ], +) +def test_typehint_issubclass(subclass, superclass, expected): + assert types.typehint_issubclass(subclass, superclass) == expected + + def test_validate_invalid_bun_path(mocker): """Test that an error is thrown when a custom specified bun path is not valid or does not exist.