move function to types
This commit is contained in:
parent
92190c2137
commit
95711b8ac9
@ -34,7 +34,7 @@ from reflex.utils.exceptions import (
|
||||
EventHandlerArgMismatch,
|
||||
EventHandlerArgTypeMismatch,
|
||||
)
|
||||
from reflex.utils.types import ArgsSpec, GenericType
|
||||
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import (
|
||||
LiteralVar,
|
||||
@ -1089,61 +1089,6 @@ def call_event_handler(
|
||||
"See https://reflex.dev/docs/events/event-arguments/"
|
||||
)
|
||||
|
||||
def compare_types(provided_type, accepted_type):
|
||||
if accepted_type is Any:
|
||||
return True
|
||||
|
||||
provided_type_origin = get_origin(provided_type)
|
||||
accepted_type_origin = get_origin(accepted_type)
|
||||
|
||||
if provided_type_origin is None and accepted_type_origin is None:
|
||||
# Check if both are concrete types (e.g., int)
|
||||
return issubclass(provided_type, accepted_type)
|
||||
|
||||
# Remove this check when Python 3.10 is the minimum supported version
|
||||
if hasattr(types, "UnionType"):
|
||||
provided_type_origin = (
|
||||
Union
|
||||
if provided_type_origin is types.UnionType
|
||||
else provided_type_origin
|
||||
)
|
||||
accepted_type_origin = (
|
||||
Union
|
||||
if accepted_type_origin is types.UnionType
|
||||
else accepted_type_origin
|
||||
)
|
||||
|
||||
# Get type arguments (e.g., int vs. Union[float, int])
|
||||
provided_args = get_args(provided_type)
|
||||
accepted_args = get_args(accepted_type)
|
||||
|
||||
if accepted_type_origin is Union:
|
||||
if provided_type_origin is not Union:
|
||||
return any(
|
||||
compare_types(provided_type, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
return all(
|
||||
any(
|
||||
compare_types(provided_arg, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
for provided_arg in provided_args
|
||||
)
|
||||
|
||||
# Check if both are generic types (e.g., List)
|
||||
if (provided_type_origin or provided_type) != (
|
||||
accepted_type_origin or accepted_type
|
||||
):
|
||||
return False
|
||||
|
||||
# Ensure all specific types are compatible with accepted types
|
||||
return all(
|
||||
compare_types(provided_arg, accepted_arg)
|
||||
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
|
||||
if accepted_arg is not Any
|
||||
)
|
||||
|
||||
all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec
|
||||
|
||||
event_spec_return_types = list(
|
||||
@ -1177,7 +1122,7 @@ def call_event_handler(
|
||||
continue
|
||||
|
||||
try:
|
||||
compare_result = compare_types(
|
||||
compare_result = typehint_issubclass(
|
||||
args_types_without_vars[i], type_hints_of_provided_callback[arg]
|
||||
)
|
||||
except TypeError as e:
|
||||
|
@ -724,3 +724,67 @@ def validate_parameter_literals(func):
|
||||
# Store this here for performance.
|
||||
StateBases = get_base_class(StateVar)
|
||||
StateIterBases = get_base_class(StateIterVar)
|
||||
|
||||
|
||||
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
|
||||
"""Check if a type hint is a subclass of another type hint.
|
||||
|
||||
Args:
|
||||
possible_subclass: The type hint to check.
|
||||
possible_superclass: The type hint to check against.
|
||||
|
||||
Returns:
|
||||
Whether the type hint is a subclass of the other type hint.
|
||||
"""
|
||||
if possible_superclass is Any:
|
||||
return True
|
||||
|
||||
provided_type_origin = get_origin(possible_subclass)
|
||||
accepted_type_origin = get_origin(possible_superclass)
|
||||
|
||||
if provided_type_origin is None and accepted_type_origin is None:
|
||||
# In this case, we are dealing with a non-generic type, so we can use issubclass
|
||||
return issubclass(possible_subclass, possible_superclass)
|
||||
|
||||
# Remove this check when Python 3.10 is the minimum supported version
|
||||
if hasattr(types, "UnionType"):
|
||||
provided_type_origin = (
|
||||
Union if provided_type_origin is types.UnionType else provided_type_origin
|
||||
)
|
||||
accepted_type_origin = (
|
||||
Union if accepted_type_origin is types.UnionType else accepted_type_origin
|
||||
)
|
||||
|
||||
# Get type arguments (e.g., [float, int] for Dict[float, int])
|
||||
provided_args = get_args(possible_subclass)
|
||||
accepted_args = get_args(possible_superclass)
|
||||
|
||||
if accepted_type_origin is Union:
|
||||
if provided_type_origin is not Union:
|
||||
return any(
|
||||
typehint_issubclass(possible_subclass, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
return all(
|
||||
any(
|
||||
typehint_issubclass(provided_arg, accepted_arg)
|
||||
for accepted_arg in accepted_args
|
||||
)
|
||||
for provided_arg in provided_args
|
||||
)
|
||||
|
||||
# Check if the origin of both types is the same (e.g., list for List[int])
|
||||
# This probably should be issubclass instead of ==
|
||||
if (provided_type_origin or possible_subclass) != (
|
||||
accepted_type_origin or possible_superclass
|
||||
):
|
||||
return False
|
||||
|
||||
# Ensure all specific types are compatible with accepted types
|
||||
# Note this is not necessarily correct, as it doesn't check against contravariance and covariance
|
||||
# It also ignores when the length of the arguments is different
|
||||
return all(
|
||||
typehint_issubclass(provided_arg, accepted_arg)
|
||||
for provided_arg, accepted_arg in zip(provided_args, accepted_args)
|
||||
if accepted_arg is not Any
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user