diff --git a/pynecone/app.py b/pynecone/app.py index 40e07de22..883654ad0 100644 --- a/pynecone/app.py +++ b/pynecone/app.py @@ -394,6 +394,8 @@ class App(Base): # Compile the custom components. compiler.compile_components(custom_components) + self.state.convert_handlers_to_fns() + async def process( app: App, event: Event, sid: str, headers: Dict, client_ip: str diff --git a/pynecone/state.py b/pynecone/state.py index e375486f3..8dff76cfa 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -51,6 +51,9 @@ class State(Base, ABC): # Backend vars inherited inherited_backend_vars: ClassVar[Dict[str, Any]] = {} + # The event handlers. + event_handlers: ClassVar[Dict[str, EventHandler]] = {} + # The parent state. parent_state: Optional[State] = None @@ -181,8 +184,18 @@ class State(Base, ABC): } for name, fn in events.items(): event_handler = EventHandler(fn=fn) + cls.event_handlers[name] = event_handler setattr(cls, name, event_handler) + @classmethod + def convert_handlers_to_fns(cls): + """Convert the event handlers to functions. + + This is done so the state functions can be called as normal functions during runtime. + """ + for name, event_handler in cls.event_handlers.items(): + setattr(cls, name, event_handler.fn) + @classmethod @functools.lru_cache() def get_parent_state(cls) -> Optional[Type[State]]: @@ -545,7 +558,7 @@ class State(Base, ABC): path = event.name.split(".") path, name = path[:-1], path[-1] substate = self.get_substate(path) - handler = getattr(substate, name) + handler = substate.event_handlers[name] # type: ignore # Process the event. fn = functools.partial(handler.fn, substate) diff --git a/tests/test_state.py b/tests/test_state.py index 81dfb50cd..500d34ba9 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -204,6 +204,33 @@ def test_class_vars(test_state): } +def test_event_handlers(test_state): + """Test that event handler is set correctly. + + Args: + test_state: A state. + """ + expected = { + "change_both", + "do_nothing", + "do_something", + "set_array", + "set_complex", + "set_count", + "set_fig", + "set_key", + "set_mapping", + "set_num1", + "set_num2", + "set_obj", + "set_value", + "set_value2", + } + + cls = type(test_state) + assert set(cls.event_handlers.keys()).intersection(expected) == expected + + def test_default_value(test_state): """Test that the default value of a var is correct.