more types

This commit is contained in:
Khaleel Al-Adhami 2024-11-15 16:59:28 -08:00
parent 92b1232806
commit 079cc56f59
2 changed files with 67 additions and 15 deletions

View File

@ -7,6 +7,7 @@ import dataclasses
import inspect import inspect
import sys import sys
import types import types
from collections import abc
from functools import cached_property, lru_cache, wraps from functools import cached_property, lru_cache, wraps
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -21,6 +22,7 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
TypeVar,
Union, Union,
_GenericAlias, # type: ignore _GenericAlias, # type: ignore
get_args, get_args,
@ -29,6 +31,7 @@ from typing import (
from typing import get_origin as get_origin_og from typing import get_origin as get_origin_og
import sqlalchemy import sqlalchemy
import typing_extensions
import reflex import reflex
from reflex.components.core.breakpoints import Breakpoints from reflex.components.core.breakpoints import Breakpoints
@ -810,24 +813,63 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
provided_args = get_args(possible_subclass) provided_args = get_args(possible_subclass)
accepted_args = get_args(possible_superclass) accepted_args = get_args(possible_superclass)
if accepted_type_origin is Union: if provided_type_origin is Union:
if provided_type_origin is not Union:
return any(
typehint_issubclass(possible_subclass, accepted_arg)
for accepted_arg in accepted_args
)
return all( return all(
any( typehint_issubclass(provided_arg, possible_superclass)
typehint_issubclass(provided_arg, accepted_arg)
for accepted_arg in accepted_args
)
for provided_arg in provided_args for provided_arg in provided_args
) )
if accepted_type_origin is Union:
return any(
typehint_issubclass(possible_subclass, accepted_arg)
for accepted_arg in accepted_args
)
# Check specifically for Sequence and Iterable
if (accepted_type_origin or possible_superclass) in (
Sequence,
abc.Sequence,
Iterable,
abc.Iterable,
):
iterable_type = accepted_args[0] if accepted_args else Any
if provided_type_origin is None:
if not issubclass(
possible_subclass, (accepted_type_origin or possible_superclass)
):
return False
if issubclass(possible_subclass, str) and not isinstance(
iterable_type, TypeVar
):
return typehint_issubclass(str, iterable_type)
if not issubclass(
provided_type_origin, (accepted_type_origin or possible_superclass)
):
return False
if not isinstance(iterable_type, (TypeVar, typing_extensions.TypeVar)):
if provided_type_origin in (list, tuple, set):
# Ensure all specific types are compatible with accepted types
return all(
typehint_issubclass(provided_arg, iterable_type)
for provided_arg in provided_args
if provided_arg is not ... # Ellipsis in Tuples
)
if possible_subclass is dict:
# Ensure all specific types are compatible with accepted types
return all(
typehint_issubclass(provided_arg, iterable_type)
for provided_arg in provided_args[:1]
)
return True
# Check if the origin of both types is the same (e.g., list for List[int]) # Check if the origin of both types is the same (e.g., list for List[int])
# This probably should be issubclass instead of == if not issubclass(
if (provided_type_origin or possible_subclass) != ( provided_type_origin or possible_subclass,
accepted_type_origin or possible_superclass accepted_type_origin or possible_superclass,
): ):
return False return False

View File

@ -2,7 +2,7 @@ import os
import typing import typing
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Type, Union from typing import Any, ClassVar, Dict, List, Literal, Sequence, Tuple, Type, Union
import pytest import pytest
import typer import typer
@ -109,10 +109,20 @@ def test_is_generic_alias(cls: type, expected: bool):
(Dict[str, str], dict[str, str], True), (Dict[str, str], dict[str, str], True),
(Dict[str, str], dict[str, Any], True), (Dict[str, str], dict[str, Any], True),
(Dict[str, Any], dict[str, Any], True), (Dict[str, Any], dict[str, Any], True),
(List[int], Sequence[int], True),
(List[str], Sequence[int], False),
(Tuple[int], Sequence[int], True),
(Tuple[int, str], Sequence[int], False),
(Tuple[int, ...], Sequence[int], True),
(str, Sequence[int], False),
(str, Sequence[str], True),
], ],
) )
def test_typehint_issubclass(subclass, superclass, expected): def test_typehint_issubclass(subclass, superclass, expected):
assert types.typehint_issubclass(subclass, superclass) == expected if expected:
assert types.typehint_issubclass(subclass, superclass)
else:
assert not types.typehint_issubclass(subclass, superclass)
def test_validate_invalid_bun_path(mocker): def test_validate_invalid_bun_path(mocker):