diff --git a/reflex/components/radix/themes/components/slider.py b/reflex/components/radix/themes/components/slider.py index 0d99eda27..bb017ea73 100644 --- a/reflex/components/radix/themes/components/slider.py +++ b/reflex/components/radix/themes/components/slider.py @@ -2,11 +2,11 @@ from __future__ import annotations -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Union from reflex.components.component import Component from reflex.components.core.breakpoints import Responsive -from reflex.event import EventHandler +from reflex.event import EventHandler, identity_event from reflex.vars.base import Var from ..base import ( @@ -14,19 +14,11 @@ from ..base import ( RadixThemesComponent, ) - -def on_value_event_spec( - value: Var[List[int | float]], -) -> Tuple[Var[List[int | float]]]: - """Event handler spec for the value event. - - Args: - value: The value of the event. - - Returns: - The event handler spec. - """ - return (value,) # type: ignore +on_value_event_spec = ( + identity_event(list[Union[int, float]]), + identity_event(list[int]), + identity_event(list[float]), +) class Slider(RadixThemesComponent): diff --git a/reflex/components/radix/themes/components/slider.pyi b/reflex/components/radix/themes/components/slider.pyi index dec836835..270d1ebf5 100644 --- a/reflex/components/radix/themes/components/slider.pyi +++ b/reflex/components/radix/themes/components/slider.pyi @@ -3,18 +3,20 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload +from typing import Any, Dict, List, Literal, Optional, Union, overload from reflex.components.core.breakpoints import Breakpoints -from reflex.event import EventType +from reflex.event import EventType, identity_event from reflex.style import Style from reflex.vars.base import Var from ..base import RadixThemesComponent -def on_value_event_spec( - value: Var[List[int | float]], -) -> Tuple[Var[List[int | float]]]: ... +on_value_event_spec = ( + identity_event(list[int]), + identity_event(list[Union[int, float]]), + identity_event(list[float]), +) class Slider(RadixThemesComponent): @overload @@ -138,7 +140,13 @@ class Slider(RadixThemesComponent): autofocus: Optional[bool] = None, custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, on_blur: Optional[EventType[[]]] = None, - on_change: Optional[EventType[List[int | float]]] = None, + on_change: Optional[ + Union[ + EventType[list[int]], + EventType[list[Union[int, float]]], + EventType[list[float]], + ] + ] = None, on_click: Optional[EventType[[]]] = None, on_context_menu: Optional[EventType[[]]] = None, on_double_click: Optional[EventType[[]]] = None, @@ -153,7 +161,13 @@ class Slider(RadixThemesComponent): on_mouse_up: Optional[EventType[[]]] = None, on_scroll: Optional[EventType[[]]] = None, on_unmount: Optional[EventType[[]]] = None, - on_value_commit: Optional[EventType[List[int | float]]] = None, + on_value_commit: Optional[ + Union[ + EventType[list[int]], + EventType[list[Union[int, float]]], + EventType[list[float]], + ] + ] = None, **props, ) -> "Slider": """Create a Slider component. diff --git a/reflex/event.py b/reflex/event.py index e5ed60c2e..78b48b87b 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -16,6 +16,7 @@ from typing import ( Generic, List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -395,7 +396,9 @@ class EventChain(EventActionsMixin): events: List[Union[EventSpec, EventVar]] = dataclasses.field(default_factory=list) - args_spec: Optional[Callable] = dataclasses.field(default=None) + args_spec: Optional[Union[Callable, Sequence[Callable]]] = dataclasses.field( + default=None + ) invocation: Optional[Var] = dataclasses.field(default=None) @@ -1040,7 +1043,7 @@ def get_hydrate_event(state) -> str: def call_event_handler( event_handler: EventHandler | EventSpec, - arg_spec: ArgsSpec, + arg_spec: ArgsSpec | Sequence[ArgsSpec], key: Optional[str] = None, ) -> EventSpec: """Call an event handler to get the event spec. @@ -1084,6 +1087,9 @@ def call_event_handler( ) def compare_types(provided_type, accepted_type): + if accepted_type is Any: + return True + provided_type_origin = get_origin(provided_type) accepted_type_origin = get_origin(accepted_type) @@ -1091,6 +1097,13 @@ def call_event_handler( # Check if both are concrete types (e.g., int) return issubclass(provided_type, accepted_type) + provided_type_origin = ( + Union if provided_type_origin is types.UnionType else provided_type_origin + ) + accepted_type_origin = ( + Union if accepted_type_origin is types.UnionType else accepted_type_origin + ) + # Check if both are generic types (e.g., List) if (provided_type_origin or provided_type) != ( accepted_type_origin or accepted_type @@ -1103,48 +1116,70 @@ def call_event_handler( # Ensure all specific types are compatible with accepted types return all( - issubclass(provided_arg, accepted_arg) + compare_types(provided_arg, accepted_arg) for provided_arg, accepted_arg in zip(provided_args, accepted_args) if accepted_arg is not Any ) - event_spec_return_type = get_type_hints(arg_spec).get("return", None) + all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec - if ( - event_spec_return_type is not None - and get_origin(event_spec_return_type) is tuple - ): - args = get_args(event_spec_return_type) + event_spec_return_types = list( + filter( + lambda event_spec_return_type: event_spec_return_type is not None + and get_origin(event_spec_return_type) is tuple, + (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec), + ) + ) - args_types_without_vars = [ - arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args - ] + if event_spec_return_types: + failures = [] - try: - type_hints_of_provided_callback = get_type_hints(event_handler.fn) - except NameError: - type_hints_of_provided_callback = {} + for event_spec_return_type in event_spec_return_types: + args = get_args(event_spec_return_type) - # check that args of event handler are matching the spec if type hints are provided - for i, arg in enumerate(provided_callback_fullspec.args[1:]): - if arg not in type_hints_of_provided_callback: - continue + args_types_without_vars = [ + arg if get_origin(arg) is not Var else get_args(arg)[0] for arg in args + ] try: - compare_result = compare_types( - args_types_without_vars[i], type_hints_of_provided_callback[arg] - ) - except TypeError as e: - raise TypeError( - f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." - ) from e + type_hints_of_provided_callback = get_type_hints(event_handler.fn) + except NameError: + type_hints_of_provided_callback = {} - if compare_result: - continue - else: - raise EventHandlerArgTypeMismatch( - f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead." - ) + failed_type_check = False + + # check that args of event handler are matching the spec if type hints are provided + for i, arg in enumerate(provided_callback_fullspec.args[1:]): + if arg not in type_hints_of_provided_callback: + continue + + try: + compare_result = compare_types( + args_types_without_vars[i], type_hints_of_provided_callback[arg] + ) + except TypeError as e: + raise TypeError( + f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." + ) from e + + if compare_result: + continue + else: + failure = EventHandlerArgTypeMismatch( + f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead." + ) + if len(event_spec_return_types) == 1: + raise failure + else: + failures.append(failure) + failed_type_check = True + break + + if not failed_type_check: + return event_handler(*parsed_args) + + if failures: + raise EventHandlerArgTypeMismatch("\n".join([str(f) for f in failures])) return event_handler(*parsed_args) # type: ignore @@ -1186,7 +1221,7 @@ def resolve_annotation(annotations: dict[str, Any], arg_name: str): return annotation -def parse_args_spec(arg_spec: ArgsSpec): +def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]): """Parse the args provided in the ArgsSpec of an event trigger. Args: @@ -1195,6 +1230,8 @@ def parse_args_spec(arg_spec: ArgsSpec): Returns: The parsed args. """ + # if there's multiple, the first is the default + arg_spec = arg_spec[0] if isinstance(arg_spec, Sequence) else arg_spec spec = inspect.getfullargspec(arg_spec) annotations = get_type_hints(arg_spec) @@ -1501,7 +1538,12 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar): Returns: The created LiteralEventChainVar instance. """ - sig = inspect.signature(value.args_spec) # type: ignore + arg_spec = ( + value.args_spec[0] + if isinstance(value.args_spec, Sequence) + else value.args_spec + ) + sig = inspect.signature(arg_spec) # type: ignore if sig.parameters: arg_def = tuple((f"_{p}" for p in sig.parameters)) arg_def_expr = LiteralVar.create([Var(_js_expr=arg) for arg in arg_def]) diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 1fc17341b..329503924 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -490,7 +490,7 @@ def _generate_component_create_functiondef( def figure_out_return_type(annotation: Any): if inspect.isclass(annotation) and issubclass(annotation, inspect._empty): - return ast.Name(id="Optional[EventType]") + return ast.Name(id="EventType") if not isinstance(annotation, str) and get_origin(annotation) is tuple: arguments = get_args(annotation) @@ -509,20 +509,13 @@ def _generate_component_create_functiondef( # Create EventType using the joined string event_type = ast.Name(id=f"EventType[{args_str}]") - # Wrap in Optional - optional_type = ast.Subscript( - value=ast.Name(id="Optional"), - slice=ast.Index(value=event_type), - ctx=ast.Load(), - ) - - return ast.Name(id=ast.unparse(optional_type)) + return event_type if isinstance(annotation, str) and annotation.startswith("Tuple["): inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]") if inside_of_tuple == "()": - return ast.Name(id="Optional[EventType[[]]]") + return ast.Name(id="EventType[[]]") arguments = [""] @@ -548,10 +541,8 @@ def _generate_component_create_functiondef( for argument in arguments ] - return ast.Name( - id=f"Optional[EventType[{', '.join(arguments_without_var)}]]" - ) - return ast.Name(id="Optional[EventType]") + return ast.Name(id=f"EventType[{', '.join(arguments_without_var)}]") + return ast.Name(id="EventType") event_triggers = clz().get_event_triggers() @@ -560,8 +551,33 @@ def _generate_component_create_functiondef( ( ast.arg( arg=trigger, - annotation=figure_out_return_type( - inspect.signature(event_triggers[trigger]).return_annotation + annotation=ast.Subscript( + ast.Name("Optional"), + ast.Index( # type: ignore + value=ast.Name( + id=ast.unparse( + figure_out_return_type( + inspect.signature(event_specs).return_annotation + ) + if not isinstance( + event_specs := event_triggers[trigger], tuple + ) + else ast.Subscript( + ast.Name("Union"), + ast.Tuple( + [ + figure_out_return_type( + inspect.signature( + event_spec + ).return_annotation + ) + for event_spec in event_specs + ] + ), + ) + ) + ) + ), ), ), ast.Constant(value=None),