state: _init_event_handlers recursively (#1640)
This commit is contained in:
parent
dbaa6a1e56
commit
12e516da64
@ -106,11 +106,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for substate in self.get_substates():
|
||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||
# Convert the event handlers to functions.
|
||||
for name, event_handler in self.event_handlers.items():
|
||||
fn = functools.partial(event_handler.fn, self)
|
||||
fn.__module__ = event_handler.fn.__module__ # type: ignore
|
||||
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
||||
setattr(self, name, fn)
|
||||
self._init_event_handlers()
|
||||
|
||||
# Initialize computed vars dependencies.
|
||||
inherited_vars = set(self.inherited_vars).union(
|
||||
@ -155,6 +151,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
self._clean()
|
||||
|
||||
def _init_event_handlers(self, state: State | None = None):
|
||||
"""Initialize event handlers.
|
||||
|
||||
Allow event handlers to be called directly on the instance. This is
|
||||
called recursively for all parent states.
|
||||
|
||||
Args:
|
||||
state: The state to initialize the event handlers on.
|
||||
"""
|
||||
if state is None:
|
||||
state = self
|
||||
|
||||
# Convert the event handlers to functions.
|
||||
for name, event_handler in state.event_handlers.items():
|
||||
fn = functools.partial(event_handler.fn, self)
|
||||
fn.__module__ = event_handler.fn.__module__ # type: ignore
|
||||
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
||||
setattr(self, name, fn)
|
||||
|
||||
# Also allow direct calling of parent state event handlers
|
||||
if state.parent_state is not None:
|
||||
self._init_event_handlers(state.parent_state)
|
||||
|
||||
def _reassign_field(self, field_name: str):
|
||||
"""Reassign the given field.
|
||||
|
||||
|
@ -992,10 +992,18 @@ def test_event_handlers_call_other_handlers():
|
||||
def set_v2(self, v: int):
|
||||
self.set_v(v)
|
||||
|
||||
class SubState(MainState):
|
||||
def set_v3(self, v: int):
|
||||
self.set_v2(v)
|
||||
|
||||
ms = MainState()
|
||||
ms.set_v2(1)
|
||||
assert ms.v == 1
|
||||
|
||||
# ensure handler can be called from substate
|
||||
ms.substates[SubState.get_name()].set_v3(2)
|
||||
assert ms.v == 2
|
||||
|
||||
|
||||
def test_computed_var_cached():
|
||||
"""Test that a ComputedVar doesn't recalculate when accessed."""
|
||||
|
Loading…
Reference in New Issue
Block a user