From 77fe5770bb26b1d939c6936c62e9be171fa11bf5 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 31 Oct 2024 15:07:01 -0700 Subject: [PATCH] allow for event handlers to ignore args --- reflex/components/component.py | 16 ++- reflex/event.py | 131 ++++++++++++----------- reflex/utils/exceptions.py | 6 +- reflex/utils/pyi_generator.py | 4 +- tests/units/components/test_component.py | 73 ++++++------- tests/units/test_event.py | 9 +- 6 files changed, 128 insertions(+), 111 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 85db3906d..50cc13007 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -17,6 +17,7 @@ from typing import ( Iterator, List, Optional, + Sequence, Set, Type, Union, @@ -533,7 +534,7 @@ class Component(BaseComponent, ABC): def _create_event_chain( self, - args_spec: Any, + args_spec: types.ArgsSpec | Sequence[types.ArgsSpec], value: Union[ Var, EventHandler, @@ -599,7 +600,7 @@ class Component(BaseComponent, ABC): # If the input is a callable, create an event chain. elif isinstance(value, Callable): - result = call_event_fn(value, args_spec) + result = call_event_fn(value, args_spec, key=key) if isinstance(result, Var): # Recursively call this function if the lambda returned an EventChain Var. return self._create_event_chain(args_spec, result, key=key) @@ -629,14 +630,16 @@ class Component(BaseComponent, ABC): event_actions={}, ) - def get_event_triggers(self) -> Dict[str, Any]: + def get_event_triggers( + self, + ) -> Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]]: """Get the event triggers for the component. Returns: The event triggers. """ - default_triggers = { + default_triggers: Dict[str, types.ArgsSpec | Sequence[types.ArgsSpec]] = { EventTriggers.ON_FOCUS: empty_event, EventTriggers.ON_BLUR: empty_event, EventTriggers.ON_CLICK: empty_event, @@ -1142,7 +1145,10 @@ class Component(BaseComponent, ABC): if isinstance(event, EventCallback): continue if isinstance(event, EventSpec): - if event.handler.state_full_name: + if ( + event.handler.state_full_name + and event.handler.state_full_name != "state" + ): return True else: if event._var_state: diff --git a/reflex/event.py b/reflex/event.py index c2e6955f6..25c6258b4 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -31,7 +31,6 @@ from reflex import constants from reflex.utils import console, format from reflex.utils.exceptions import ( EventFnArgMismatch, - EventHandlerArgMismatch, EventHandlerArgTypeMismatch, ) from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass @@ -689,7 +688,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec: fn.__qualname__ = name fn.__signature__ = sig return EventSpec( - handler=EventHandler(fn=fn), + handler=EventHandler(fn=fn, state_full_name="state"), args=tuple( ( Var(_js_expr=k), @@ -1058,8 +1057,8 @@ def get_hydrate_event(state) -> str: def call_event_handler( - event_handler: EventHandler | EventSpec, - arg_spec: ArgsSpec | Sequence[ArgsSpec], + event_callback: EventHandler | EventSpec, + event_spec: ArgsSpec | Sequence[ArgsSpec], key: Optional[str] = None, ) -> EventSpec: """Call an event handler to get the event spec. @@ -1069,12 +1068,12 @@ def call_event_handler( Otherwise, the event handler will be called with no args. Args: - event_handler: The event handler. - arg_spec: The lambda that define the argument(s) to pass to the event handler. + event_callback: The event handler. + event_spec: The lambda that define the argument(s) to pass to the event handler. key: The key to pass to the event handler. Raises: - EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec. + EventFnArgMismatch: if number of arguments expected by event_handler doesn't match the spec. Returns: The event spec from calling the event handler. @@ -1082,40 +1081,47 @@ def call_event_handler( # noqa: DAR401 failure """ - parsed_args = parse_args_spec(arg_spec) # type: ignore + event_spec_args = parse_args_spec(event_spec) # type: ignore - if isinstance(event_handler, EventSpec): - # Handle partial application of EventSpec args - return event_handler.add_args(*parsed_args) - - provided_callback_fullspec = inspect.getfullargspec(event_handler.fn) - - provided_callback_n_args = ( - len(provided_callback_fullspec.args) - 1 - ) # subtract 1 for bound self arg - - if provided_callback_n_args != len(parsed_args): - raise EventHandlerArgMismatch( - "The number of arguments accepted by " - f"{event_handler.fn.__qualname__} ({provided_callback_n_args}) " - "does not match the arguments passed by the event trigger: " - f"{[str(v) for v in parsed_args]}\n" - "See https://reflex.dev/docs/events/event-arguments/" + if isinstance(event_callback, EventSpec): + check_fn_match_arg_spec( + event_callback.handler.fn, + event_spec, + key, + bool(event_callback.handler.state_full_name) + len(event_callback.args), + event_callback.handler.fn.__qualname__, ) + # Handle partial application of EventSpec args + return event_callback.add_args(*event_spec_args) - all_arg_spec = [arg_spec] if not isinstance(arg_spec, Sequence) else arg_spec + check_fn_match_arg_spec( + event_callback.fn, + event_spec, + key, + bool(event_callback.state_full_name), + event_callback.fn.__qualname__, + ) + + all_acceptable_specs = ( + [event_spec] if not isinstance(event_spec, Sequence) else event_spec + ) event_spec_return_types = list( filter( lambda event_spec_return_type: event_spec_return_type is not None and get_origin(event_spec_return_type) is tuple, - (get_type_hints(arg_spec).get("return", None) for arg_spec in all_arg_spec), + ( + get_type_hints(arg_spec).get("return", None) + for arg_spec in all_acceptable_specs + ), ) ) if event_spec_return_types: failures = [] + event_callback_spec = inspect.getfullargspec(event_callback.fn) + for event_spec_index, event_spec_return_type in enumerate( event_spec_return_types ): @@ -1126,14 +1132,14 @@ def call_event_handler( ] try: - type_hints_of_provided_callback = get_type_hints(event_handler.fn) + type_hints_of_provided_callback = get_type_hints(event_callback.fn) except NameError: type_hints_of_provided_callback = {} failed_type_check = False # check that args of event handler are matching the spec if type hints are provided - for i, arg in enumerate(provided_callback_fullspec.args[1:]): + for i, arg in enumerate(event_callback_spec.args[1:]): if arg not in type_hints_of_provided_callback: continue @@ -1147,7 +1153,7 @@ def call_event_handler( # f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." # ) from e console.warn( - f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_handler.fn.__qualname__} provided for {key}." + f"Could not compare types {args_types_without_vars[i]} and {type_hints_of_provided_callback[arg]} for argument {arg} of {event_callback.fn.__qualname__} provided for {key}." ) compare_result = False @@ -1155,7 +1161,7 @@ def call_event_handler( continue else: failure = EventHandlerArgTypeMismatch( - f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_handler.fn.__qualname__} instead." + f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead." ) failures.append(failure) failed_type_check = True @@ -1176,14 +1182,14 @@ def call_event_handler( given_string = ", ".join( repr(type_hints_of_provided_callback.get(arg, Any)) - for arg in provided_callback_fullspec.args[1:] + for arg in event_callback_spec.args[1:] ).replace("[", "\\[") console.warn( - f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_handler.fn.__qualname__} instead. " + f"Event handler {key} expects ({expect_string}) -> () but got ({given_string}) -> () as annotated in {event_callback.fn.__qualname__} instead. " f"This may lead to unexpected behavior but is intentionally ignored for {key}." ) - return event_handler(*parsed_args) + return event_callback(*event_spec_args) if failures: console.deprecate( @@ -1193,7 +1199,7 @@ def call_event_handler( "0.7.0", ) - return event_handler(*parsed_args) # type: ignore + return event_callback(*event_spec_args) # type: ignore def unwrap_var_annotation(annotation: GenericType): @@ -1260,45 +1266,46 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]): def check_fn_match_arg_spec( - fn: Callable, - arg_spec: ArgsSpec, - key: Optional[str] = None, -) -> List[Var]: + user_func: Callable, + arg_spec: ArgsSpec | Sequence[ArgsSpec], + key: str | None = None, + number_of_bound_args: int = 0, + func_name: str | None = None, +): """Ensures that the function signature matches the passed argument specification or raises an EventFnArgMismatch if they do not. Args: - fn: The function to be validated. + user_func: The function to be validated. arg_spec: The argument specification for the event trigger. - key: The key to pass to the event handler. - - Returns: - The parsed arguments from the argument specification. + key: The key of the event trigger. + number_of_bound_args: The number of bound arguments to the function. + func_name: The name of the function to be validated. Raises: EventFnArgMismatch: Raised if the number of mandatory arguments do not match """ - fn_args = inspect.getfullargspec(fn).args - fn_defaults_args = inspect.getfullargspec(fn).defaults - n_fn_args = len(fn_args) - n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0 - if isinstance(fn, types.MethodType): - n_fn_args -= 1 # subtract 1 for bound self arg - parsed_args = parse_args_spec(arg_spec) - if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args): + user_args = inspect.getfullargspec(user_func).args + user_default_args = inspect.getfullargspec(user_func).defaults + number_of_user_args = len(user_args) - number_of_bound_args + number_of_user_default_args = len(user_default_args) if user_default_args else 0 + + parsed_event_args = parse_args_spec(arg_spec) + + number_of_event_args = len(parsed_event_args) + + if number_of_user_args - number_of_user_default_args > number_of_event_args: raise EventFnArgMismatch( - "The number of mandatory arguments accepted by " - f"{fn} ({n_fn_args - n_fn_defaults_args}) " - "does not match the arguments passed by the event trigger: " - f"{[str(v) for v in parsed_args]}\n" + f"Event {key} only provides {number_of_event_args} arguments, but " + f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} " + "arguments to be passed to the event handler.\n" "See https://reflex.dev/docs/events/event-arguments/" ) - return parsed_args def call_event_fn( fn: Callable, - arg_spec: ArgsSpec, + arg_spec: ArgsSpec | Sequence[ArgsSpec], key: Optional[str] = None, ) -> list[EventSpec] | Var: """Call a function to a list of event specs. @@ -1322,10 +1329,14 @@ def call_event_fn( from reflex.utils.exceptions import EventHandlerValueError # Check that fn signature matches arg_spec - parsed_args = check_fn_match_arg_spec(fn, arg_spec, key=key) + check_fn_match_arg_spec(fn, arg_spec, key=key) + + parsed_args = parse_args_spec(arg_spec) + + number_of_fn_args = len(inspect.getfullargspec(fn).args) # Call the function with the parsed args. - out = fn(*parsed_args) + out = fn(*[*parsed_args][:number_of_fn_args]) # If the function returns a Var, assume it's an EventChain and render it directly. if isinstance(out, Var): diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 661f29095..dc25a09e0 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -89,16 +89,12 @@ class MatchTypeError(ReflexError, TypeError): """Raised when the return types of match cases are different.""" -class EventHandlerArgMismatch(ReflexError, TypeError): - """Raised when the number of args accepted by an EventHandler differs from that provided by the event trigger.""" - - class EventHandlerArgTypeMismatch(ReflexError, TypeError): """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger.""" class EventFnArgMismatch(ReflexError, TypeError): - """Raised when the number of args accepted by a lambda differs from that provided by the event trigger.""" + """Raised when the number of args required by an event handler is more than provided by the event trigger.""" class DynamicRouteArgShadowsStateVar(ReflexError, NameError): diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 342277cad..961eea178 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -16,7 +16,7 @@ from itertools import chain from multiprocessing import Pool, cpu_count from pathlib import Path from types import ModuleType, SimpleNamespace -from typing import Any, Callable, Iterable, Type, get_args, get_origin +from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin from reflex.components.component import Component from reflex.utils import types as rx_types @@ -560,7 +560,7 @@ def _generate_component_create_functiondef( inspect.signature(event_specs).return_annotation ) if not isinstance( - event_specs := event_triggers[trigger], tuple + event_specs := event_triggers[trigger], Sequence ) else ast.Subscript( ast.Name("Union"), diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index a614fd715..89f4bd417 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -29,7 +29,6 @@ from reflex.style import Style from reflex.utils import imports from reflex.utils.exceptions import ( EventFnArgMismatch, - EventHandlerArgMismatch, ) from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData @@ -907,26 +906,28 @@ def test_invalid_event_handler_args(component2, test_state): test_state: A test state. """ # EventHandler args must match - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create(on_click=test_state.do_something_arg) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_open=test_state.do_something) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_prop_event=test_state.do_something) + # Does not raise because event handlers are allowed to have less args than the spec. + # with pytest.raises(EventFnArgMismatch): + # component2.create(on_open=test_state.do_something) + # with pytest.raises(EventFnArgMismatch): + # component2.create(on_prop_event=test_state.do_something) # Multiple EventHandler args: all must match - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create( on_click=[test_state.do_something_arg, test_state.do_something] ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_open=[test_state.do_something_arg, test_state.do_something] - ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_prop_event=[test_state.do_something_arg, test_state.do_something] - ) + # Same as above + # with pytest.raises(EventFnArgMismatch): + # component2.create( + # on_open=[test_state.do_something_arg, test_state.do_something] + # ) + # with pytest.raises(EventFnArgMismatch): + # component2.create( + # on_prop_event=[test_state.do_something_arg, test_state.do_something] + # ) # Enable when 0.7.0 happens # # Event Handler types must match @@ -957,38 +958,38 @@ def test_invalid_event_handler_args(component2, test_state): # lambda signature must match event trigger. with pytest.raises(EventFnArgMismatch): component2.create(on_click=lambda _: test_state.do_something_arg(1)) - with pytest.raises(EventFnArgMismatch): - component2.create(on_open=lambda: test_state.do_something) - with pytest.raises(EventFnArgMismatch): - component2.create(on_prop_event=lambda: test_state.do_something) + # with pytest.raises(EventFnArgMismatch): + # component2.create(on_open=lambda: test_state.do_something) + # with pytest.raises(EventFnArgMismatch): + # component2.create(on_prop_event=lambda: test_state.do_something) # lambda returning EventHandler must match spec - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create(on_click=lambda: test_state.do_something_arg) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_open=lambda _: test_state.do_something) - with pytest.raises(EventHandlerArgMismatch): - component2.create(on_prop_event=lambda _: test_state.do_something) + # with pytest.raises(EventFnArgMismatch): + # component2.create(on_open=lambda _: test_state.do_something) + # with pytest.raises(EventFnArgMismatch): + # component2.create(on_prop_event=lambda _: test_state.do_something) # Mixed EventSpec and EventHandler must match spec. - with pytest.raises(EventHandlerArgMismatch): + with pytest.raises(EventFnArgMismatch): component2.create( on_click=lambda: [ test_state.do_something_arg(1), test_state.do_something_arg, ] ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something] - ) - with pytest.raises(EventHandlerArgMismatch): - component2.create( - on_prop_event=lambda _: [ - test_state.do_something_arg(1), - test_state.do_something, - ] - ) + # with pytest.raises(EventFnArgMismatch): + # component2.create( + # on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something] + # ) + # with pytest.raises(EventFnArgMismatch): + # component2.create( + # on_prop_event=lambda _: [ + # test_state.do_something_arg(1), + # test_state.do_something, + # ] + # ) def test_valid_event_handler_args(component2, test_state): diff --git a/tests/units/test_event.py b/tests/units/test_event.py index d7b7cf7a2..0ea559e28 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -107,7 +107,7 @@ def test_call_event_handler_partial(): def spec(a2: Var[str]) -> List[Var[str]]: return [a2] - handler = EventHandler(fn=test_fn_with_args) + handler = EventHandler(fn=test_fn_with_args, state_full_name="BigState") event_spec = handler(make_var("first")) event_spec2 = call_event_handler(event_spec, spec) @@ -115,7 +115,10 @@ def test_call_event_handler_partial(): assert len(event_spec.args) == 1 assert event_spec.args[0][0].equals(Var(_js_expr="arg1")) assert event_spec.args[0][1].equals(Var(_js_expr="first")) - assert format.format_event(event_spec) == 'Event("test_fn_with_args", {arg1:first})' + assert ( + format.format_event(event_spec) + == 'Event("BigState.test_fn_with_args", {arg1:first})' + ) assert event_spec2 is not event_spec assert event_spec2.handler == handler @@ -126,7 +129,7 @@ def test_call_event_handler_partial(): assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str)) assert ( format.format_event(event_spec2) - == 'Event("test_fn_with_args", {arg1:first,arg2:_a2})' + == 'Event("BigState.test_fn_with_args", {arg1:first,arg2:_a2})' )