diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 27b6e7ce7..0d6f878ec 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -7,6 +7,7 @@ import dataclasses import inspect import sys import types +from collections import abc from functools import cached_property, lru_cache, wraps from typing import ( TYPE_CHECKING, @@ -21,6 +22,7 @@ from typing import ( Sequence, Tuple, Type, + TypeVar, Union, _GenericAlias, # type: ignore get_args, @@ -29,6 +31,7 @@ from typing import ( from typing import get_origin as get_origin_og import sqlalchemy +import typing_extensions import reflex from reflex.components.core.breakpoints import Breakpoints @@ -810,24 +813,63 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo provided_args = get_args(possible_subclass) accepted_args = get_args(possible_superclass) - if accepted_type_origin is Union: - if provided_type_origin is not Union: - return any( - typehint_issubclass(possible_subclass, accepted_arg) - for accepted_arg in accepted_args - ) + if provided_type_origin is Union: return all( - any( - typehint_issubclass(provided_arg, accepted_arg) - for accepted_arg in accepted_args - ) + typehint_issubclass(provided_arg, possible_superclass) for provided_arg in provided_args ) + if accepted_type_origin is Union: + return any( + typehint_issubclass(possible_subclass, accepted_arg) + for accepted_arg in accepted_args + ) + + # Check specifically for Sequence and Iterable + if (accepted_type_origin or possible_superclass) in ( + Sequence, + abc.Sequence, + Iterable, + abc.Iterable, + ): + iterable_type = accepted_args[0] if accepted_args else Any + + if provided_type_origin is None: + if not issubclass( + possible_subclass, (accepted_type_origin or possible_superclass) + ): + return False + + if issubclass(possible_subclass, str) and not isinstance( + iterable_type, TypeVar + ): + return typehint_issubclass(str, iterable_type) + + if not issubclass( + provided_type_origin, (accepted_type_origin or possible_superclass) + ): + return False + + if not isinstance(iterable_type, (TypeVar, typing_extensions.TypeVar)): + if provided_type_origin in (list, tuple, set): + # Ensure all specific types are compatible with accepted types + return all( + typehint_issubclass(provided_arg, iterable_type) + for provided_arg in provided_args + if provided_arg is not ... # Ellipsis in Tuples + ) + if possible_subclass is dict: + # Ensure all specific types are compatible with accepted types + return all( + typehint_issubclass(provided_arg, iterable_type) + for provided_arg in provided_args[:1] + ) + return True + # Check if the origin of both types is the same (e.g., list for List[int]) - # This probably should be issubclass instead of == - if (provided_type_origin or possible_subclass) != ( - accepted_type_origin or possible_superclass + if not issubclass( + provided_type_origin or possible_subclass, + accepted_type_origin or possible_superclass, ): return False diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index dd1a3b3ef..2a71f3461 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -2,7 +2,7 @@ import os import typing from functools import cached_property from pathlib import Path -from typing import Any, ClassVar, Dict, List, Literal, Type, Union +from typing import Any, ClassVar, Dict, List, Literal, Sequence, Tuple, Type, Union import pytest import typer @@ -109,10 +109,20 @@ def test_is_generic_alias(cls: type, expected: bool): (Dict[str, str], dict[str, str], True), (Dict[str, str], dict[str, Any], True), (Dict[str, Any], dict[str, Any], True), + (List[int], Sequence[int], True), + (List[str], Sequence[int], False), + (Tuple[int], Sequence[int], True), + (Tuple[int, str], Sequence[int], False), + (Tuple[int, ...], Sequence[int], True), + (str, Sequence[int], False), + (str, Sequence[str], True), ], ) def test_typehint_issubclass(subclass, superclass, expected): - assert types.typehint_issubclass(subclass, superclass) == expected + if expected: + assert types.typehint_issubclass(subclass, superclass) + else: + assert not types.typehint_issubclass(subclass, superclass) def test_validate_invalid_bun_path(mocker):