From 12e516da64e4834733a5b8b8efdfde08e5dd4406 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 25 Aug 2023 13:28:58 -0700 Subject: [PATCH] state: _init_event_handlers recursively (#1640) --- reflex/state.py | 29 ++++++++++++++++++++++++----- tests/test_state.py | 8 ++++++++ 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 1b62d5c56..7cd789fa0 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -106,11 +106,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow): for substate in self.get_substates(): self.substates[substate.get_name()] = substate(parent_state=self) # Convert the event handlers to functions. - for name, event_handler in self.event_handlers.items(): - fn = functools.partial(event_handler.fn, self) - fn.__module__ = event_handler.fn.__module__ # type: ignore - fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore - setattr(self, name, fn) + self._init_event_handlers() # Initialize computed vars dependencies. inherited_vars = set(self.inherited_vars).union( @@ -155,6 +151,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow): self._clean() + def _init_event_handlers(self, state: State | None = None): + """Initialize event handlers. + + Allow event handlers to be called directly on the instance. This is + called recursively for all parent states. + + Args: + state: The state to initialize the event handlers on. + """ + if state is None: + state = self + + # Convert the event handlers to functions. + for name, event_handler in state.event_handlers.items(): + fn = functools.partial(event_handler.fn, self) + fn.__module__ = event_handler.fn.__module__ # type: ignore + fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore + setattr(self, name, fn) + + # Also allow direct calling of parent state event handlers + if state.parent_state is not None: + self._init_event_handlers(state.parent_state) + def _reassign_field(self, field_name: str): """Reassign the given field. diff --git a/tests/test_state.py b/tests/test_state.py index e13da9dac..f5650fc0f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -992,10 +992,18 @@ def test_event_handlers_call_other_handlers(): def set_v2(self, v: int): self.set_v(v) + class SubState(MainState): + def set_v3(self, v: int): + self.set_v2(v) + ms = MainState() ms.set_v2(1) assert ms.v == 1 + # ensure handler can be called from substate + ms.substates[SubState.get_name()].set_v3(2) + assert ms.v == 2 + def test_computed_var_cached(): """Test that a ComputedVar doesn't recalculate when accessed."""