From 510b71e6446d9fb818985ec6ac88fba2bc46f21e Mon Sep 17 00:00:00 2001 From: Martin Xu <15661672+martinxu9@users.noreply.github.com> Date: Fri, 18 Aug 2023 01:36:30 -0700 Subject: [PATCH] Check the yield/return from user defined event handlers are valid (#1614) --- reflex/app.py | 2 +- reflex/state.py | 40 ++++++++++++++++++++++++++++++++++------ tests/test_state.py | 26 ++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index d75fae40c..474ae3c5b 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -690,7 +690,7 @@ class EventNamespace(AsyncNamespace): self.app = app def on_connect(self, sid, environ): - """Event for when the websocket disconnects. + """Event for when the websocket is connected. Args: sid: The Socket.IO session id. diff --git a/reflex/state.py b/reflex/state.py index f5906197e..1b62d5c56 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -738,13 +738,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Clean the state to prepare for the next event. self._clean() + def _check_valid(self, handler: EventHandler, events: Any) -> Any: + """Check if the events yielded are valid. They must be EventHandlers or EventSpecs. + + Args: + handler: EventHandler. + events: The events to be checked. + + Raises: + TypeError: If any of the events are not valid. + + Returns: + The events as they are if valid. + """ + + def _is_valid_type(events: Any) -> bool: + return isinstance(events, (EventHandler, EventSpec)) + + if events is None or _is_valid_type(events): + return events + try: + if all(_is_valid_type(e) for e in events): + return events + except TypeError: + pass + + raise TypeError( + f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" + ) + async def _process_event( self, handler: EventHandler, state: State, payload: Dict ) -> AsyncIterator[Tuple[Optional[List[EventSpec]], bool]]: """Process event. Args: - handler: Eventhandler to process. + handler: EventHandler to process. state: State to process the handler. payload: The event payload. @@ -765,28 +794,27 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Handle regular functions. else: events = fn(**payload) - # Handle async generators. if inspect.isasyncgen(events): async for event in events: - yield event, False + yield self._check_valid(handler, event), False yield None, True # Handle regular generators. elif inspect.isgenerator(events): try: while True: - yield next(events), False + yield self._check_valid(handler, next(events)), False except StopIteration as si: # the "return" value of the generator is not available # in the loop, we must catch StopIteration to access it if si.value is not None: - yield si.value, False + yield self._check_valid(handler, si.value), False yield None, True # Handle regular event chains. else: - yield events, True + yield self._check_valid(handler, events), True # If an error occurs, throw a window alert. except Exception: diff --git a/tests/test_state.py b/tests/test_state.py index 9288fba2b..e13da9dac 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1205,3 +1205,29 @@ def test_error_on_state_method_shadow(): err.value.args[0] == f"The event handler name `reset` shadows a builtin State method; use a different name instead" ) + + +def test_state_with_invalid_yield(): + """Test that an error is thrown when a state yields an invalid value.""" + + class StateWithInvalidYield(rx.State): + """A state that yields an invalid value.""" + + def invalid_handler(self): + """Invalid handler. + + Yields: + an invalid value. + """ + yield 1 + + invalid_state = StateWithInvalidYield() + with pytest.raises(TypeError) as err: + invalid_state._check_valid( + invalid_state.event_handlers["invalid_handler"], + rx.event.Event(token="fake_token", name="invalid_handler"), + ) + assert ( + "must only return/yield: None, Events or other EventHandlers" + in err.value.args[0] + )