From e0984aa83431bc70a31a99f48c98e87470e1be48 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:53:50 +0100 Subject: [PATCH] Allow bound method as event handler (#4348) * subtract 1 arg if the method is a bound method * fix it early in user_args * only bound methods pls * add test --- reflex/event.py | 4 ++++ tests/units/test_event.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/reflex/event.py b/reflex/event.py index 312c9887f..a9e92b635 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -1346,6 +1346,10 @@ def check_fn_match_arg_spec( EventFnArgMismatch: Raised if the number of mandatory arguments do not match """ user_args = inspect.getfullargspec(user_func).args + # Drop the first argument if it's a bound method + if inspect.ismethod(user_func) and user_func.__self__ is not None: + user_args = user_args[1:] + user_default_args = inspect.getfullargspec(user_func).defaults number_of_user_args = len(user_args) - number_of_bound_args number_of_user_default_args = len(user_default_args) if user_default_args else 0 diff --git a/tests/units/test_event.py b/tests/units/test_event.py index f17b3c4e4..4399ab2a0 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -2,6 +2,7 @@ from typing import Callable, List import pytest +import reflex as rx from reflex.event import ( Event, EventChain, @@ -439,3 +440,17 @@ def test_event_var_data(): # Ensure chain carries _var_data chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec)) assert chain_var._get_all_var_data() == S.x._get_all_var_data() + + +def test_event_bound_method() -> None: + class S(BaseState): + @event + def e(self, arg: str): + print(arg) + + class Wrapper: + def get_handler(self, arg: str): + return S.e(arg) + + w = Wrapper() + _ = rx.input(on_change=w.get_handler)