diff --git a/reflex/components/component.py b/reflex/components/component.py index 632a7b482..6d4fdded1 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -527,15 +527,7 @@ class Component(BaseComponent, ABC): for v in value: if isinstance(v, (EventHandler, EventSpec)): # Call the event handler to get the event. - try: - event = call_event_handler(v, args_spec) - except ValueError as err: - raise ValueError( - f" {err} defined in the `{type(self).__name__}` component" - ) from err - - # Add the event to the chain. - events.append(event) + events.append(call_event_handler(v, args_spec)) elif isinstance(v, Callable): # Call the lambda to get the event chain. result = call_event_fn(v, args_spec) diff --git a/reflex/event.py b/reflex/event.py index 6d0577c6f..b8fed5708 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -3,6 +3,7 @@ from __future__ import annotations import inspect +import types import urllib.parse from base64 import b64encode from typing import ( @@ -22,6 +23,7 @@ from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.ivars.function import FunctionStringVar, FunctionVar from reflex.ivars.object import ObjectVar from reflex.utils import format +from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.types import ArgsSpec from reflex.vars import ImmutableVarData, Var @@ -831,7 +833,7 @@ def call_event_handler( arg_spec: The lambda that define the argument(s) to pass to the event handler. Raises: - ValueError: if number of arguments expected by event_handler doesn't match the spec. + EventHandlerArgMismatch: if number of arguments expected by event_handler doesn't match the spec. Returns: The event spec from calling the event handler. @@ -843,13 +845,16 @@ def call_event_handler( return event_handler.add_args(*parsed_args) args = inspect.getfullargspec(event_handler.fn).args - if len(args) == len(["self", *parsed_args]): + n_args = len(args) - 1 # subtract 1 for bound self arg + if n_args == len(parsed_args): return event_handler(*parsed_args) # type: ignore else: - source = inspect.getsource(arg_spec) # type: ignore - raise ValueError( - f"number of arguments in {event_handler.fn.__qualname__} " - f"doesn't match the definition of the event trigger '{source.strip().strip(',')}'" + raise EventHandlerArgMismatch( + "The number of arguments accepted by " + f"{event_handler.fn.__qualname__} ({n_args}) " + "does not match the arguments passed by the event trigger: " + f"{[str(v) for v in parsed_args]}\n" + "See https://reflex.dev/docs/events/event-arguments/" ) @@ -874,58 +879,60 @@ def parse_args_spec(arg_spec: ArgsSpec): ) -def call_event_fn(fn: Callable, arg: Union[Var, ArgsSpec]) -> list[EventSpec] | Var: +def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: """Call a function to a list of event specs. The function should return a single EventSpec, a list of EventSpecs, or a - single Var. If the function takes in an arg, the arg will be passed to the - function. Otherwise, the function will be called with no args. + single Var. The function signature must match the passed arg_spec or + EventFnArgsMismatch will be raised. Args: fn: The function to call. - arg: The argument to pass to the function. + arg_spec: The argument spec for the event trigger. Returns: The event specs from calling the function or a Var. Raises: - EventHandlerValueError: If the lambda has an invalid signature. + EventFnArgMismatch: If the function signature doesn't match the arg spec. + EventHandlerValueError: If the lambda returns an unusable value. """ # Import here to avoid circular imports. from reflex.event import EventHandler, EventSpec from reflex.utils.exceptions import EventHandlerValueError - # Get the args of the lambda. - args = inspect.getfullargspec(fn).args + # Check that fn signature matches arg_spec + fn_args = inspect.getfullargspec(fn).args + n_fn_args = len(fn_args) + if isinstance(fn, types.MethodType): + n_fn_args -= 1 # subtract 1 for bound self arg + parsed_args = parse_args_spec(arg_spec) + if len(parsed_args) != n_fn_args: + raise EventFnArgMismatch( + "The number of arguments accepted by " + f"{fn} ({n_fn_args}) " + "does not match the arguments passed by the event trigger: " + f"{[str(v) for v in parsed_args]}\n" + "See https://reflex.dev/docs/events/event-arguments/" + ) - if isinstance(arg, ArgsSpec): - out = fn(*parse_args_spec(arg)) # type: ignore - else: - # Call the lambda. - if len(args) == 0: - out = fn() - elif len(args) == 1: - out = fn(arg) - else: - raise EventHandlerValueError(f"Lambda {fn} must have 0 or 1 arguments.") + # Call the function with the parsed args. + out = fn(*parsed_args) # If the function returns a Var, assume it's an EventChain and render it directly. if isinstance(out, Var): return out # Convert the output to a list. - if not isinstance(out, List): + if not isinstance(out, list): out = [out] # Convert any event specs to event specs. events = [] for e in out: - # Convert handlers to event specs. if isinstance(e, EventHandler): - if len(args) == 0: - e = e() - elif len(args) == 1: - e = e(arg) # type: ignore + # An un-called EventHandler gets all of the args of the event trigger. + e = call_event_handler(e, arg_spec) # Make sure the event spec is valid. if not isinstance(e, EventSpec): diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index d219dcf0c..8c1a1f07f 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -79,3 +79,11 @@ class LockExpiredError(ReflexError): class MatchTypeError(ReflexError, TypeError): """Raised when the return types of match cases are different.""" + + +class EventHandlerArgMismatch(ReflexError, TypeError): + """Raised when the number of args accepted by an EventHandler is differs from that provided by the event trigger.""" + + +class EventFnArgMismatch(ReflexError, TypeError): + """Raised when the number of args accepted by a lambda differs from that provided by the event trigger.""" diff --git a/tests/components/test_component.py b/tests/components/test_component.py index d88550041..3c824ed01 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -22,6 +22,7 @@ from reflex.ivars.base import LiteralVar from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports +from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import BaseVar, Var, VarData @@ -79,6 +80,8 @@ def component2() -> Type[Component]: # A test list prop. arr: Var[List[str]] + on_prop_event: EventHandler[lambda e0: [e0]] + def get_event_triggers(self) -> Dict[str, Any]: """Test controlled triggers. @@ -496,7 +499,7 @@ def test_get_props(component1, component2): component2: A test component. """ assert component1.get_props() == {"text", "number", "text_or_number"} - assert component2.get_props() == {"arr"} + assert component2.get_props() == {"arr", "on_prop_event"} @pytest.mark.parametrize( @@ -574,7 +577,7 @@ def test_get_event_triggers(component1, component2): assert component1().get_event_triggers().keys() == default_triggers assert ( component2().get_event_triggers().keys() - == {"on_open", "on_close"} | default_triggers + == {"on_open", "on_close", "on_prop_event"} | default_triggers ) @@ -888,18 +891,105 @@ def test_invalid_event_handler_args(component2, test_state): component2: A test component. test_state: A test state. """ - # Uncontrolled event handlers should not take args. - # This is okay. - component2.create(on_click=test_state.do_something) - # This is not okay. - with pytest.raises(ValueError): + # EventHandler args must match + with pytest.raises(EventHandlerArgMismatch): component2.create(on_click=test_state.do_something_arg) + with pytest.raises(EventHandlerArgMismatch): component2.create(on_open=test_state.do_something) + with pytest.raises(EventHandlerArgMismatch): + component2.create(on_prop_event=test_state.do_something) + + # Multiple EventHandler args: all must match + with pytest.raises(EventHandlerArgMismatch): + component2.create( + on_click=[test_state.do_something_arg, test_state.do_something] + ) + with pytest.raises(EventHandlerArgMismatch): component2.create( on_open=[test_state.do_something_arg, test_state.do_something] ) - # However lambdas are okay. + with pytest.raises(EventHandlerArgMismatch): + component2.create( + on_prop_event=[test_state.do_something_arg, test_state.do_something] + ) + + # lambda cannot return weird values. + with pytest.raises(ValueError): + component2.create(on_click=lambda: 1) + with pytest.raises(ValueError): + component2.create(on_click=lambda: [1]) + with pytest.raises(ValueError): + component2.create( + on_click=lambda: (test_state.do_something_arg(1), test_state.do_something) + ) + + # lambda signature must match event trigger. + with pytest.raises(EventFnArgMismatch): + component2.create(on_click=lambda _: test_state.do_something_arg(1)) + with pytest.raises(EventFnArgMismatch): + component2.create(on_open=lambda: test_state.do_something) + with pytest.raises(EventFnArgMismatch): + component2.create(on_prop_event=lambda: test_state.do_something) + + # lambda returning EventHandler must match spec + with pytest.raises(EventHandlerArgMismatch): + component2.create(on_click=lambda: test_state.do_something_arg) + with pytest.raises(EventHandlerArgMismatch): + component2.create(on_open=lambda _: test_state.do_something) + with pytest.raises(EventHandlerArgMismatch): + component2.create(on_prop_event=lambda _: test_state.do_something) + + # Mixed EventSpec and EventHandler must match spec. + with pytest.raises(EventHandlerArgMismatch): + component2.create( + on_click=lambda: [ + test_state.do_something_arg(1), + test_state.do_something_arg, + ] + ) + with pytest.raises(EventHandlerArgMismatch): + component2.create( + on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something] + ) + with pytest.raises(EventHandlerArgMismatch): + component2.create( + on_prop_event=lambda _: [ + test_state.do_something_arg(1), + test_state.do_something, + ] + ) + + +def test_valid_event_handler_args(component2, test_state): + """Test that an valid event handler args do not raise exception. + + Args: + component2: A test component. + test_state: A test state. + """ + # Uncontrolled event handlers should not take args. + component2.create(on_click=test_state.do_something) + component2.create(on_click=test_state.do_something_arg(1)) + + # Controlled event handlers should take args. + component2.create(on_open=test_state.do_something_arg) + component2.create(on_prop_event=test_state.do_something_arg) + + # Using a partial event spec bypasses arg validation (ignoring the args). + component2.create(on_open=test_state.do_something()) + component2.create(on_prop_event=test_state.do_something()) + + # lambda returning EventHandler is okay if the spec matches. + component2.create(on_click=lambda: test_state.do_something) + component2.create(on_open=lambda _: test_state.do_something_arg) + component2.create(on_prop_event=lambda _: test_state.do_something_arg) + + # lambda can always return an EventSpec. component2.create(on_click=lambda: test_state.do_something_arg(1)) + component2.create(on_open=lambda _: test_state.do_something_arg(1)) + component2.create(on_prop_event=lambda _: test_state.do_something_arg(1)) + + # Return EventSpec and EventHandler (no arg). component2.create( on_click=lambda: [test_state.do_something_arg(1), test_state.do_something] ) @@ -907,9 +997,24 @@ def test_invalid_event_handler_args(component2, test_state): on_click=lambda: [test_state.do_something_arg(1), test_state.do_something()] ) - # Controlled event handlers should take args. - # This is okay. - component2.create(on_open=test_state.do_something_arg) + # Return 2 EventSpec. + component2.create( + on_open=lambda _: [test_state.do_something_arg(1), test_state.do_something()] + ) + component2.create( + on_prop_event=lambda _: [ + test_state.do_something_arg(1), + test_state.do_something(), + ] + ) + + # Return EventHandler (1 arg) and EventSpec. + component2.create( + on_open=lambda _: [test_state.do_something_arg, test_state.do_something()] + ) + component2.create( + on_prop_event=lambda _: [test_state.do_something_arg, test_state.do_something()] + ) def test_get_hooks_nested(component1, component2, component3):