move function to types

This commit is contained in:
Khaleel Al-Adhami 2024-10-28 11:51:50 -07:00
parent 92190c2137
commit 95711b8ac9
2 changed files with 66 additions and 57 deletions

View File

@ -34,7 +34,7 @@ from reflex.utils.exceptions import (
EventHandlerArgMismatch,
EventHandlerArgTypeMismatch,
)
from reflex.utils.types import ArgsSpec, GenericType
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
from reflex.vars import VarData
from reflex.vars.base import (
LiteralVar,
@ -1089,61 +1089,6 @@ def call_event_handler(
"See https://reflex.dev/docs/events/event-arguments/"
)
def compare_types(provided_type, accepted_type):
if accepted_type is Any:
return True
provided_type_origin = get_origin(provided_type)
accepted_type_origin = get_origin(accepted_type)
if provided_type_origin is None and accepted_type_origin is None:
# Check if both are concrete types (e.g., int)
return issubclass(provided_type, accepted_type)
# Remove this check when Python 3.10 is the minimum supported version
if hasattr(types, "UnionType"):
provided_type_origin = (
Union
if provided_type_origin is types.UnionType
else provided_type_origin
)
accepted_type_origin = (
Union
if accepted_type_origin is types.UnionType
else accepted_type_origin
)
# Get type arguments (e.g., int vs. Union[float, int])
provided_args = get_args(provided_type)
accepted_args = get_args(accepted_type)
if accepted_type_origin is Union:
if provided_type_origin is not Union:
return any(
compare_types(provided_type, accepted_arg)
for accepted_arg in accepted_args
)
return all(
any(
compare_types(provided_arg, accepted_arg)
for accepted_arg in accepted_args
)
for provided_arg in provided_args
)
# Check if both are generic types (e.g., List)
if (provided_type_origin or provided_type) != (
accepted_type_origin or accepted_type
):
return False
# Ensure all specific types are compatible with accepted types
return all(
compare_types(provided_arg, accepted_arg)
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
if accepted_arg is not Any
)
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
event_spec_return_types = list(
@ -1177,7 +1122,7 @@ def call_event_handler(
continue
try:
compare_result = compare_types(
compare_result = typehint_issubclass(
args_types_without_vars[i], type_hints_of_provided_callback[arg]
)
except TypeError as e:

View File

@ -724,3 +724,67 @@ def validate_parameter_literals(func):
# Store this here for performance.
StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar)
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint.
Args:
possible_subclass: The type hint to check.
possible_superclass: The type hint to check against.
Returns:
Whether the type hint is a subclass of the other type hint.
"""
if possible_superclass is Any:
return True
provided_type_origin = get_origin(possible_subclass)
accepted_type_origin = get_origin(possible_superclass)
if provided_type_origin is None and accepted_type_origin is None:
# In this case, we are dealing with a non-generic type, so we can use issubclass
return issubclass(possible_subclass, possible_superclass)
# Remove this check when Python 3.10 is the minimum supported version
if hasattr(types, "UnionType"):
provided_type_origin = (
Union if provided_type_origin is types.UnionType else provided_type_origin
)
accepted_type_origin = (
Union if accepted_type_origin is types.UnionType else accepted_type_origin
)
# Get type arguments (e.g., [float, int] for Dict[float, int])
provided_args = get_args(possible_subclass)
accepted_args = get_args(possible_superclass)
if accepted_type_origin is Union:
if provided_type_origin is not Union:
return any(
typehint_issubclass(possible_subclass, accepted_arg)
for accepted_arg in accepted_args
)
return all(
any(
typehint_issubclass(provided_arg, accepted_arg)
for accepted_arg in accepted_args
)
for provided_arg in provided_args
)
# Check if the origin of both types is the same (e.g., list for List[int])
# This probably should be issubclass instead of ==
if (provided_type_origin or possible_subclass) != (
accepted_type_origin or possible_superclass
):
return False
# Ensure all specific types are compatible with accepted types
# Note this is not necessarily correct, as it doesn't check against contravariance and covariance
# It also ignores when the length of the arguments is different
return all(
typehint_issubclass(provided_arg, accepted_arg)
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
if accepted_arg is not Any
)