Check the yield/return from user defined event handlers are valid (#1614)
This commit is contained in:
parent
efefa757a0
commit
510b71e644
@ -690,7 +690,7 @@ class EventNamespace(AsyncNamespace):
|
||||
self.app = app
|
||||
|
||||
def on_connect(self, sid, environ):
|
||||
"""Event for when the websocket disconnects.
|
||||
"""Event for when the websocket is connected.
|
||||
|
||||
Args:
|
||||
sid: The Socket.IO session id.
|
||||
|
@ -738,13 +738,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Clean the state to prepare for the next event.
|
||||
self._clean()
|
||||
|
||||
def _check_valid(self, handler: EventHandler, events: Any) -> Any:
|
||||
"""Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
|
||||
|
||||
Args:
|
||||
handler: EventHandler.
|
||||
events: The events to be checked.
|
||||
|
||||
Raises:
|
||||
TypeError: If any of the events are not valid.
|
||||
|
||||
Returns:
|
||||
The events as they are if valid.
|
||||
"""
|
||||
|
||||
def _is_valid_type(events: Any) -> bool:
|
||||
return isinstance(events, (EventHandler, EventSpec))
|
||||
|
||||
if events is None or _is_valid_type(events):
|
||||
return events
|
||||
try:
|
||||
if all(_is_valid_type(e) for e in events):
|
||||
return events
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
raise TypeError(
|
||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||
)
|
||||
|
||||
async def _process_event(
|
||||
self, handler: EventHandler, state: State, payload: Dict
|
||||
) -> AsyncIterator[Tuple[Optional[List[EventSpec]], bool]]:
|
||||
"""Process event.
|
||||
|
||||
Args:
|
||||
handler: Eventhandler to process.
|
||||
handler: EventHandler to process.
|
||||
state: State to process the handler.
|
||||
payload: The event payload.
|
||||
|
||||
@ -765,28 +794,27 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Handle regular functions.
|
||||
else:
|
||||
events = fn(**payload)
|
||||
|
||||
# Handle async generators.
|
||||
if inspect.isasyncgen(events):
|
||||
async for event in events:
|
||||
yield event, False
|
||||
yield self._check_valid(handler, event), False
|
||||
yield None, True
|
||||
|
||||
# Handle regular generators.
|
||||
elif inspect.isgenerator(events):
|
||||
try:
|
||||
while True:
|
||||
yield next(events), False
|
||||
yield self._check_valid(handler, next(events)), False
|
||||
except StopIteration as si:
|
||||
# the "return" value of the generator is not available
|
||||
# in the loop, we must catch StopIteration to access it
|
||||
if si.value is not None:
|
||||
yield si.value, False
|
||||
yield self._check_valid(handler, si.value), False
|
||||
yield None, True
|
||||
|
||||
# Handle regular event chains.
|
||||
else:
|
||||
yield events, True
|
||||
yield self._check_valid(handler, events), True
|
||||
|
||||
# If an error occurs, throw a window alert.
|
||||
except Exception:
|
||||
|
@ -1205,3 +1205,29 @@ def test_error_on_state_method_shadow():
|
||||
err.value.args[0]
|
||||
== f"The event handler name `reset` shadows a builtin State method; use a different name instead"
|
||||
)
|
||||
|
||||
|
||||
def test_state_with_invalid_yield():
|
||||
"""Test that an error is thrown when a state yields an invalid value."""
|
||||
|
||||
class StateWithInvalidYield(rx.State):
|
||||
"""A state that yields an invalid value."""
|
||||
|
||||
def invalid_handler(self):
|
||||
"""Invalid handler.
|
||||
|
||||
Yields:
|
||||
an invalid value.
|
||||
"""
|
||||
yield 1
|
||||
|
||||
invalid_state = StateWithInvalidYield()
|
||||
with pytest.raises(TypeError) as err:
|
||||
invalid_state._check_valid(
|
||||
invalid_state.event_handlers["invalid_handler"],
|
||||
rx.event.Event(token="fake_token", name="invalid_handler"),
|
||||
)
|
||||
assert (
|
||||
"must only return/yield: None, Events or other EventHandlers"
|
||||
in err.value.args[0]
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user