state: _init_event_handlers recursively (#1640)

This commit is contained in:
Masen Furer 2023-08-25 13:28:58 -07:00 committed by GitHub
parent dbaa6a1e56
commit 12e516da64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 5 deletions

View File

@ -106,11 +106,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
for substate in self.get_substates():
self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions.
for name, event_handler in self.event_handlers.items():
fn = functools.partial(event_handler.fn, self)
fn.__module__ = event_handler.fn.__module__ # type: ignore
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
setattr(self, name, fn)
self._init_event_handlers()
# Initialize computed vars dependencies.
inherited_vars = set(self.inherited_vars).union(
@ -155,6 +151,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
self._clean()
def _init_event_handlers(self, state: State | None = None):
"""Initialize event handlers.
Allow event handlers to be called directly on the instance. This is
called recursively for all parent states.
Args:
state: The state to initialize the event handlers on.
"""
if state is None:
state = self
# Convert the event handlers to functions.
for name, event_handler in state.event_handlers.items():
fn = functools.partial(event_handler.fn, self)
fn.__module__ = event_handler.fn.__module__ # type: ignore
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
setattr(self, name, fn)
# Also allow direct calling of parent state event handlers
if state.parent_state is not None:
self._init_event_handlers(state.parent_state)
def _reassign_field(self, field_name: str):
"""Reassign the given field.

View File

@ -992,10 +992,18 @@ def test_event_handlers_call_other_handlers():
def set_v2(self, v: int):
self.set_v(v)
class SubState(MainState):
def set_v3(self, v: int):
self.set_v2(v)
ms = MainState()
ms.set_v2(1)
assert ms.v == 1
# ensure handler can be called from substate
ms.substates[SubState.get_name()].set_v3(2)
assert ms.v == 2
def test_computed_var_cached():
"""Test that a ComputedVar doesn't recalculate when accessed."""