From e0984aa83431bc70a31a99f48c98e87470e1be48 Mon Sep 17 00:00:00 2001
From: benedikt-bartscher
 <31854409+benedikt-bartscher@users.noreply.github.com>
Date: Thu, 21 Nov 2024 20:53:50 +0100
Subject: [PATCH] Allow bound method as event handler (#4348)

* subtract 1 arg if the method is a bound method

* fix it early in user_args

* only bound methods pls

* add test
---
 reflex/event.py           |  4 ++++
 tests/units/test_event.py | 15 +++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/reflex/event.py b/reflex/event.py
index 312c9887f..a9e92b635 100644
--- a/reflex/event.py
+++ b/reflex/event.py
@@ -1346,6 +1346,10 @@ def check_fn_match_arg_spec(
         EventFnArgMismatch: Raised if the number of mandatory arguments do not match
     """
     user_args = inspect.getfullargspec(user_func).args
+    # Drop the first argument if it's a bound method
+    if inspect.ismethod(user_func) and user_func.__self__ is not None:
+        user_args = user_args[1:]
+
     user_default_args = inspect.getfullargspec(user_func).defaults
     number_of_user_args = len(user_args) - number_of_bound_args
     number_of_user_default_args = len(user_default_args) if user_default_args else 0
diff --git a/tests/units/test_event.py b/tests/units/test_event.py
index f17b3c4e4..4399ab2a0 100644
--- a/tests/units/test_event.py
+++ b/tests/units/test_event.py
@@ -2,6 +2,7 @@ from typing import Callable, List
 
 import pytest
 
+import reflex as rx
 from reflex.event import (
     Event,
     EventChain,
@@ -439,3 +440,17 @@ def test_event_var_data():
     # Ensure chain carries _var_data
     chain_var = Var.create(EventChain(events=[S.s(S.x)], args_spec=_args_spec))
     assert chain_var._get_all_var_data() == S.x._get_all_var_data()
+
+
+def test_event_bound_method() -> None:
+    class S(BaseState):
+        @event
+        def e(self, arg: str):
+            print(arg)
+
+    class Wrapper:
+        def get_handler(self, arg: str):
+            return S.e(arg)
+
+    w = Wrapper()
+    _ = rx.input(on_change=w.get_handler)