move has_args to utils.types, add tests for it

This commit is contained in:
Benedikt Bartscher 2024-10-10 19:43:08 +02:00
parent bbd34ab592
commit 1877758c05
No known key found for this signature in database
3 changed files with 68 additions and 23 deletions

View File

@ -220,6 +220,27 @@ def is_literal(cls: GenericType) -> bool:
return get_origin(cls) is Literal return get_origin(cls) is Literal
def has_args(cls) -> bool:
"""Check if the class has generic parameters.
Args:
cls: The class to check.
Returns:
Whether the class has generic
"""
if get_args(cls):
return True
# Check if the class inherits from a generic class (using __orig_bases__)
if hasattr(cls, "__orig_bases__"):
for base in cls.__orig_bases__:
if get_args(base):
return True
return False
def is_optional(cls: GenericType) -> bool: def is_optional(cls: GenericType) -> bool:
"""Check if a class is an Optional. """Check if a class is an Optional.

View File

@ -56,7 +56,7 @@ from reflex.utils.imports import (
ParsedImportDict, ParsedImportDict,
parse_imports, parse_imports,
) )
from reflex.utils.types import GenericType, Self, get_origin from reflex.utils.types import GenericType, Self, get_origin, has_args
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.state import BaseState from reflex.state import BaseState
@ -1257,27 +1257,6 @@ def unionize(*args: Type) -> Type:
return Union[unionize(*first_half), unionize(*second_half)] return Union[unionize(*first_half), unionize(*second_half)]
def has_args(cls) -> bool:
"""Check if the class has generic parameters.
Args:
cls: The class to check.
Returns:
Whether the class has generic
"""
if get_args(cls):
return True
# Check if the class inherits from a generic class (using __orig_bases__)
if hasattr(cls, "__orig_bases__"):
for base in cls.__orig_bases__:
if get_args(base):
return True
return False
def figure_out_type(value: Any) -> types.GenericType: def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value. """Figure out the type of the value.

View File

@ -1,4 +1,4 @@
from typing import Any, List, Literal, Tuple, Union from typing import Any, Dict, List, Literal, Tuple, Union
import pytest import pytest
@ -45,3 +45,48 @@ def test_issubclass(
cls: types.GenericType, cls_check: types.GenericType, expected: bool cls: types.GenericType, cls_check: types.GenericType, expected: bool
) -> None: ) -> None:
assert types._issubclass(cls, cls_check) == expected assert types._issubclass(cls, cls_check) == expected
class CustomDict(dict[str, str]):
"""A custom dict with generic arguments."""
pass
class ChildCustomDict(CustomDict):
"""A child of CustomDict."""
pass
class GenericDict(dict):
"""A generic dict with no generic arguments."""
pass
class ChildGenericDict(GenericDict):
"""A child of GenericDict."""
pass
@pytest.mark.parametrize(
"cls,expected",
[
(int, False),
(str, False),
(float, False),
(Tuple[int], True),
(List[int], True),
(Union[int, str], True),
(Union[str, int], True),
(Dict[str, int], True),
(CustomDict, True),
(ChildCustomDict, True),
(GenericDict, False),
(ChildGenericDict, False),
],
)
def test_has_args(cls, expected: bool) -> None:
assert types.has_args(cls) == expected