Check the yield/return from user defined event handlers are valid (#1614)

This commit is contained in:
Martin Xu 2023-08-18 01:36:30 -07:00 committed by GitHub
parent efefa757a0
commit 510b71e644
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 7 deletions

View File

@ -690,7 +690,7 @@ class EventNamespace(AsyncNamespace):
self.app = app
def on_connect(self, sid, environ):
"""Event for when the websocket disconnects.
"""Event for when the websocket is connected.
Args:
sid: The Socket.IO session id.

View File

@ -738,13 +738,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Clean the state to prepare for the next event.
self._clean()
def _check_valid(self, handler: EventHandler, events: Any) -> Any:
"""Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
Args:
handler: EventHandler.
events: The events to be checked.
Raises:
TypeError: If any of the events are not valid.
Returns:
The events as they are if valid.
"""
def _is_valid_type(events: Any) -> bool:
return isinstance(events, (EventHandler, EventSpec))
if events is None or _is_valid_type(events):
return events
try:
if all(_is_valid_type(e) for e in events):
return events
except TypeError:
pass
raise TypeError(
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
)
async def _process_event(
self, handler: EventHandler, state: State, payload: Dict
) -> AsyncIterator[Tuple[Optional[List[EventSpec]], bool]]:
"""Process event.
Args:
handler: Eventhandler to process.
handler: EventHandler to process.
state: State to process the handler.
payload: The event payload.
@ -765,28 +794,27 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Handle regular functions.
else:
events = fn(**payload)
# Handle async generators.
if inspect.isasyncgen(events):
async for event in events:
yield event, False
yield self._check_valid(handler, event), False
yield None, True
# Handle regular generators.
elif inspect.isgenerator(events):
try:
while True:
yield next(events), False
yield self._check_valid(handler, next(events)), False
except StopIteration as si:
# the "return" value of the generator is not available
# in the loop, we must catch StopIteration to access it
if si.value is not None:
yield si.value, False
yield self._check_valid(handler, si.value), False
yield None, True
# Handle regular event chains.
else:
yield events, True
yield self._check_valid(handler, events), True
# If an error occurs, throw a window alert.
except Exception:

View File

@ -1205,3 +1205,29 @@ def test_error_on_state_method_shadow():
err.value.args[0]
== f"The event handler name `reset` shadows a builtin State method; use a different name instead"
)
def test_state_with_invalid_yield():
"""Test that an error is thrown when a state yields an invalid value."""
class StateWithInvalidYield(rx.State):
"""A state that yields an invalid value."""
def invalid_handler(self):
"""Invalid handler.
Yields:
an invalid value.
"""
yield 1
invalid_state = StateWithInvalidYield()
with pytest.raises(TypeError) as err:
invalid_state._check_valid(
invalid_state.event_handlers["invalid_handler"],
rx.event.Event(token="fake_token", name="invalid_handler"),
)
assert (
"must only return/yield: None, Events or other EventHandlers"
in err.value.args[0]
)