From 1877758c055d12681ea4b50dac164f3a7ec31de6 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Thu, 10 Oct 2024 19:43:08 +0200 Subject: [PATCH] move has_args to utils.types, add tests for it --- reflex/utils/types.py | 21 +++++++++++++++ reflex/vars/base.py | 23 +--------------- tests/units/utils/test_types.py | 47 ++++++++++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 6bedf5b61..3c5182e04 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -220,6 +220,27 @@ def is_literal(cls: GenericType) -> bool: return get_origin(cls) is Literal +def has_args(cls) -> bool: + """Check if the class has generic parameters. + + Args: + cls: The class to check. + + Returns: + Whether the class has generic + """ + if get_args(cls): + return True + + # Check if the class inherits from a generic class (using __orig_bases__) + if hasattr(cls, "__orig_bases__"): + for base in cls.__orig_bases__: + if get_args(base): + return True + + return False + + def is_optional(cls: GenericType) -> bool: """Check if a class is an Optional. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index f84dfbdd5..0b560d93d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -56,7 +56,7 @@ from reflex.utils.imports import ( ParsedImportDict, parse_imports, ) -from reflex.utils.types import GenericType, Self, get_origin +from reflex.utils.types import GenericType, Self, get_origin, has_args if TYPE_CHECKING: from reflex.state import BaseState @@ -1257,27 +1257,6 @@ def unionize(*args: Type) -> Type: return Union[unionize(*first_half), unionize(*second_half)] -def has_args(cls) -> bool: - """Check if the class has generic parameters. - - Args: - cls: The class to check. - - Returns: - Whether the class has generic - """ - if get_args(cls): - return True - - # Check if the class inherits from a generic class (using __orig_bases__) - if hasattr(cls, "__orig_bases__"): - for base in cls.__orig_bases__: - if get_args(base): - return True - - return False - - def figure_out_type(value: Any) -> types.GenericType: """Figure out the type of the value. diff --git a/tests/units/utils/test_types.py b/tests/units/utils/test_types.py index fc9261e04..623aacc1f 100644 --- a/tests/units/utils/test_types.py +++ b/tests/units/utils/test_types.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal, Tuple, Union +from typing import Any, Dict, List, Literal, Tuple, Union import pytest @@ -45,3 +45,48 @@ def test_issubclass( cls: types.GenericType, cls_check: types.GenericType, expected: bool ) -> None: assert types._issubclass(cls, cls_check) == expected + + +class CustomDict(dict[str, str]): + """A custom dict with generic arguments.""" + + pass + + +class ChildCustomDict(CustomDict): + """A child of CustomDict.""" + + pass + + +class GenericDict(dict): + """A generic dict with no generic arguments.""" + + pass + + +class ChildGenericDict(GenericDict): + """A child of GenericDict.""" + + pass + + +@pytest.mark.parametrize( + "cls,expected", + [ + (int, False), + (str, False), + (float, False), + (Tuple[int], True), + (List[int], True), + (Union[int, str], True), + (Union[str, int], True), + (Dict[str, int], True), + (CustomDict, True), + (ChildCustomDict, True), + (GenericDict, False), + (ChildGenericDict, False), + ], +) +def test_has_args(cls, expected: bool) -> None: + assert types.has_args(cls) == expected