state: _init_event_handlers recursively (#1640)
This commit is contained in:
parent
dbaa6a1e56
commit
12e516da64
@ -106,11 +106,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
for substate in self.get_substates():
|
for substate in self.get_substates():
|
||||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||||
# Convert the event handlers to functions.
|
# Convert the event handlers to functions.
|
||||||
for name, event_handler in self.event_handlers.items():
|
self._init_event_handlers()
|
||||||
fn = functools.partial(event_handler.fn, self)
|
|
||||||
fn.__module__ = event_handler.fn.__module__ # type: ignore
|
|
||||||
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
|
||||||
setattr(self, name, fn)
|
|
||||||
|
|
||||||
# Initialize computed vars dependencies.
|
# Initialize computed vars dependencies.
|
||||||
inherited_vars = set(self.inherited_vars).union(
|
inherited_vars = set(self.inherited_vars).union(
|
||||||
@ -155,6 +151,29 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
self._clean()
|
self._clean()
|
||||||
|
|
||||||
|
def _init_event_handlers(self, state: State | None = None):
|
||||||
|
"""Initialize event handlers.
|
||||||
|
|
||||||
|
Allow event handlers to be called directly on the instance. This is
|
||||||
|
called recursively for all parent states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state to initialize the event handlers on.
|
||||||
|
"""
|
||||||
|
if state is None:
|
||||||
|
state = self
|
||||||
|
|
||||||
|
# Convert the event handlers to functions.
|
||||||
|
for name, event_handler in state.event_handlers.items():
|
||||||
|
fn = functools.partial(event_handler.fn, self)
|
||||||
|
fn.__module__ = event_handler.fn.__module__ # type: ignore
|
||||||
|
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
||||||
|
setattr(self, name, fn)
|
||||||
|
|
||||||
|
# Also allow direct calling of parent state event handlers
|
||||||
|
if state.parent_state is not None:
|
||||||
|
self._init_event_handlers(state.parent_state)
|
||||||
|
|
||||||
def _reassign_field(self, field_name: str):
|
def _reassign_field(self, field_name: str):
|
||||||
"""Reassign the given field.
|
"""Reassign the given field.
|
||||||
|
|
||||||
|
@ -992,10 +992,18 @@ def test_event_handlers_call_other_handlers():
|
|||||||
def set_v2(self, v: int):
|
def set_v2(self, v: int):
|
||||||
self.set_v(v)
|
self.set_v(v)
|
||||||
|
|
||||||
|
class SubState(MainState):
|
||||||
|
def set_v3(self, v: int):
|
||||||
|
self.set_v2(v)
|
||||||
|
|
||||||
ms = MainState()
|
ms = MainState()
|
||||||
ms.set_v2(1)
|
ms.set_v2(1)
|
||||||
assert ms.v == 1
|
assert ms.v == 1
|
||||||
|
|
||||||
|
# ensure handler can be called from substate
|
||||||
|
ms.substates[SubState.get_name()].set_v3(2)
|
||||||
|
assert ms.v == 2
|
||||||
|
|
||||||
|
|
||||||
def test_computed_var_cached():
|
def test_computed_var_cached():
|
||||||
"""Test that a ComputedVar doesn't recalculate when accessed."""
|
"""Test that a ComputedVar doesn't recalculate when accessed."""
|
||||||
|
Loading…
Reference in New Issue
Block a user