Call event handlers from other event handlers (#691)

This commit is contained in:
Elijah Ahianyo 2023-03-16 23:57:28 +00:00 committed by GitHub
parent 7067baf176
commit 592be487c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 1 deletions

View File

@ -394,6 +394,8 @@ class App(Base):
# Compile the custom components.
compiler.compile_components(custom_components)
self.state.convert_handlers_to_fns()
async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str

View File

@ -51,6 +51,9 @@ class State(Base, ABC):
# Backend vars inherited
inherited_backend_vars: ClassVar[Dict[str, Any]] = {}
# The event handlers.
event_handlers: ClassVar[Dict[str, EventHandler]] = {}
# The parent state.
parent_state: Optional[State] = None
@ -181,8 +184,18 @@ class State(Base, ABC):
}
for name, fn in events.items():
event_handler = EventHandler(fn=fn)
cls.event_handlers[name] = event_handler
setattr(cls, name, event_handler)
@classmethod
def convert_handlers_to_fns(cls):
"""Convert the event handlers to functions.
This is done so the state functions can be called as normal functions during runtime.
"""
for name, event_handler in cls.event_handlers.items():
setattr(cls, name, event_handler.fn)
@classmethod
@functools.lru_cache()
def get_parent_state(cls) -> Optional[Type[State]]:
@ -545,7 +558,7 @@ class State(Base, ABC):
path = event.name.split(".")
path, name = path[:-1], path[-1]
substate = self.get_substate(path)
handler = getattr(substate, name)
handler = substate.event_handlers[name] # type: ignore
# Process the event.
fn = functools.partial(handler.fn, substate)

View File

@ -204,6 +204,33 @@ def test_class_vars(test_state):
}
def test_event_handlers(test_state):
"""Test that event handler is set correctly.
Args:
test_state: A state.
"""
expected = {
"change_both",
"do_nothing",
"do_something",
"set_array",
"set_complex",
"set_count",
"set_fig",
"set_key",
"set_mapping",
"set_num1",
"set_num2",
"set_obj",
"set_value",
"set_value2",
}
cls = type(test_state)
assert set(cls.event_handlers.keys()).intersection(expected) == expected
def test_default_value(test_state):
"""Test that the default value of a var is correct.