diff --git a/reflex/event.py b/reflex/event.py index b354bc144..e83bd8e56 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -34,7 +34,7 @@ from reflex.utils.exceptions import ( EventHandlerArgMismatch, EventHandlerArgTypeMismatch, ) -from reflex.utils.types import ArgsSpec, GenericType +from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass from reflex.vars import VarData from reflex.vars.base import ( LiteralVar, @@ -1089,61 +1089,6 @@ def call_event_handler( "See https://reflex.dev/docs/events/event-arguments/" ) - def compare_types(provided_type, accepted_type): - if accepted_type is Any: - return True - - provided_type_origin = get_origin(provided_type) - accepted_type_origin = get_origin(accepted_type) - - if provided_type_origin is None and accepted_type_origin is None: - # Check if both are concrete types (e.g., int) - return issubclass(provided_type, accepted_type) - - # Remove this check when Python 3.10 is the minimum supported version - if hasattr(types, "UnionType"): - provided_type_origin = ( - Union - if provided_type_origin is types.UnionType - else provided_type_origin - ) - accepted_type_origin = ( - Union - if accepted_type_origin is types.UnionType - else accepted_type_origin - ) - - # Get type arguments (e.g., int vs. Union[float, int]) - provided_args = get_args(provided_type) - accepted_args = get_args(accepted_type) - - if accepted_type_origin is Union: - if provided_type_origin is not Union: - return any( - compare_types(provided_type, accepted_arg) - for accepted_arg in accepted_args - ) - return all( - any( - compare_types(provided_arg, accepted_arg) - for accepted_arg in accepted_args - ) - for provided_arg in provided_args - ) - - # Check if both are generic types (e.g., List) - if (provided_type_origin or provided_type) != ( - accepted_type_origin or accepted_type - ): - return False - - # Ensure all specific types are compatible with accepted types - return all( - compare_types(provided_arg, accepted_arg) - for provided_arg, accepted_arg in zip(provided_args, accepted_args) - if accepted_arg is not Any - ) - all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec event_spec_return_types = list( @@ -1177,7 +1122,7 @@ def call_event_handler( continue try: - compare_result = compare_types( + compare_result = typehint_issubclass( args_types_without_vars[i], type_hints_of_provided_callback[arg] ) except TypeError as e: diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 3d7992011..6ac8a4862 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -724,3 +724,67 @@ def validate_parameter_literals(func): # Store this here for performance. StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar) + + +def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: + """Check if a type hint is a subclass of another type hint. + + Args: + possible_subclass: The type hint to check. + possible_superclass: The type hint to check against. + + Returns: + Whether the type hint is a subclass of the other type hint. + """ + if possible_superclass is Any: + return True + + provided_type_origin = get_origin(possible_subclass) + accepted_type_origin = get_origin(possible_superclass) + + if provided_type_origin is None and accepted_type_origin is None: + # In this case, we are dealing with a non-generic type, so we can use issubclass + return issubclass(possible_subclass, possible_superclass) + + # Remove this check when Python 3.10 is the minimum supported version + if hasattr(types, "UnionType"): + provided_type_origin = ( + Union if provided_type_origin is types.UnionType else provided_type_origin + ) + accepted_type_origin = ( + Union if accepted_type_origin is types.UnionType else accepted_type_origin + ) + + # Get type arguments (e.g., [float, int] for Dict[float, int]) + provided_args = get_args(possible_subclass) + accepted_args = get_args(possible_superclass) + + if accepted_type_origin is Union: + if provided_type_origin is not Union: + return any( + typehint_issubclass(possible_subclass, accepted_arg) + for accepted_arg in accepted_args + ) + return all( + any( + typehint_issubclass(provided_arg, accepted_arg) + for accepted_arg in accepted_args + ) + for provided_arg in provided_args + ) + + # Check if the origin of both types is the same (e.g., list for List[int]) + # This probably should be issubclass instead of == + if (provided_type_origin or possible_subclass) != ( + accepted_type_origin or possible_superclass + ): + return False + + # Ensure all specific types are compatible with accepted types + # Note this is not necessarily correct, as it doesn't check against contravariance and covariance + # It also ignores when the length of the arguments is different + return all( + typehint_issubclass(provided_arg, accepted_arg) + for provided_arg, accepted_arg in zip(provided_args, accepted_args) + if accepted_arg is not Any + )