From eb723b6ebe67ac5570043d243d1763cb4148764b Mon Sep 17 00:00:00 2001 From: Nikhil Rao Date: Wed, 8 Feb 2023 16:27:13 -0800 Subject: [PATCH] Add compile error for invalid event handlers (#482) --- pynecone/components/component.py | 18 ++++++++++++-- pynecone/utils.py | 33 +++++++++++++++++++++++--- tests/components/test_component.py | 38 ++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 5 deletions(-) diff --git a/pynecone/components/component.py b/pynecone/components/component.py index 279a9ac2c..3ad2b3bde 100644 --- a/pynecone/components/component.py +++ b/pynecone/components/component.py @@ -165,6 +165,9 @@ class Component(Base, ABC): Raises: ValueError: If the value is not a valid event chain. """ + # Check if the trigger is a controlled event. + is_controlled_event = event_trigger in self.get_controlled_triggers() + # If it's an event chain var, return it. if isinstance(value, Var): if value.type_ is not EventChain: @@ -182,8 +185,19 @@ class Component(Base, ABC): events = [] for v in value: if isinstance(v, EventHandler): - events.append(utils.call_event_handler(v, arg)) + # Call the event handler to get the event. + event = utils.call_event_handler(v, arg) + + # Check that the event handler takes no args if it's uncontrolled. + if not is_controlled_event and len(event.args) > 0: + raise ValueError( + f"Event handler: {v.fn} for uncontrolled event {event_trigger} should not take any args." + ) + + # Add the event to the chain. + events.append(event) elif isinstance(v, Callable): + # Call the lambda to get the event chain. events.extend(utils.call_event_fn(v, arg)) else: raise ValueError(f"Invalid event: {v}") @@ -197,7 +211,7 @@ class Component(Base, ABC): raise ValueError(f"Invalid event chain: {value}") # Add args to the event specs if necessary. - if event_trigger in self.get_controlled_triggers(): + if is_controlled_event: events = [ EventSpec( handler=e.handler, diff --git a/pynecone/utils.py b/pynecone/utils.py index 1009d0034..92c7f2000 100644 --- a/pynecone/utils.py +++ b/pynecone/utils.py @@ -1238,16 +1238,43 @@ def call_event_fn(fn: Callable, arg: Var) -> List[EventSpec]: Raises: ValueError: If the lambda has an invalid signature. """ + # Import here to avoid circular imports. + from pynecone.event import EventHandler, EventSpec + + # Get the args of the lambda. args = inspect.getfullargspec(fn).args + + # Call the lambda. if len(args) == 0: out = fn() elif len(args) == 1: out = fn(arg) else: raise ValueError(f"Lambda {fn} must have 0 or 1 arguments.") + + # Convert the output to a list. if not isinstance(out, List): out = [out] - return out + + # Convert any event specs to event specs. + events = [] + for e in out: + # Convert handlers to event specs. + if isinstance(e, EventHandler): + if len(args) == 0: + e = e() + elif len(args) == 1: + e = e(arg) + + # Make sure the event spec is valid. + if not isinstance(e, EventSpec): + raise ValueError(f"Lambda {fn} returned an invalid event spec: {e}.") + + # Add the event spec to the chain. + events.append(e) + + # Return the events. + return events def get_handler_args(event_spec: EventSpec, arg: Var) -> Tuple[Tuple[str, str], ...]: @@ -1261,11 +1288,11 @@ def get_handler_args(event_spec: EventSpec, arg: Var) -> Tuple[Tuple[str, str], The handler args. Raises: - TypeError: If the event handler has an invalid signature. + ValueError: If the event handler has an invalid signature. """ args = inspect.getfullargspec(event_spec.handler.fn).args if len(args) < 2: - raise TypeError( + raise ValueError( f"Event handler has an invalid signature, needed a method with a parameter, got {event_spec.handler}." ) return event_spec.args if len(args) > 2 else ((args[1], arg.name),) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 8ebaecc6a..d21bc1e74 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -15,6 +15,12 @@ def TestState(): class TestState(State): num: int + def do_something(self): + pass + + def do_something_arg(self, arg): + pass + return TestState @@ -299,3 +305,35 @@ def test_custom_component_hash(my_component): component1 = CustomComponent(component_fn=my_component, prop1="test", prop2=1) component2 = CustomComponent(component_fn=my_component, prop1="test", prop2=2) assert {component1, component2} == {component1} + + +def test_invalid_event_handler_args(component2, TestState): + """Test that an invalid event handler raises an error. + + Args: + component2: A test component. + TestState: A test state. + """ + # Uncontrolled event handlers should not take args. + # This is okay. + component2.create(on_click=TestState.do_something) + # This is not okay. + with pytest.raises(ValueError): + component2.create(on_click=TestState.do_something_arg) + # However lambdas are okay. + component2.create(on_click=lambda: TestState.do_something_arg(1)) + component2.create( + on_click=lambda: [TestState.do_something_arg(1), TestState.do_something] + ) + component2.create( + on_click=lambda: [TestState.do_something_arg(1), TestState.do_something()] + ) + + # Controlled event handlers should take args. + # This is okay. + component2.create(on_open=TestState.do_something_arg) + # This is not okay. + with pytest.raises(ValueError): + component2.create(on_open=TestState.do_something) + with pytest.raises(ValueError): + component2.create(on_open=[TestState.do_something_arg, TestState.do_something])