Check the yield/return from user defined event handlers are valid (#1614)
This commit is contained in:
parent
efefa757a0
commit
510b71e644
@ -690,7 +690,7 @@ class EventNamespace(AsyncNamespace):
|
|||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
def on_connect(self, sid, environ):
|
def on_connect(self, sid, environ):
|
||||||
"""Event for when the websocket disconnects.
|
"""Event for when the websocket is connected.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sid: The Socket.IO session id.
|
sid: The Socket.IO session id.
|
||||||
|
@ -738,13 +738,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Clean the state to prepare for the next event.
|
# Clean the state to prepare for the next event.
|
||||||
self._clean()
|
self._clean()
|
||||||
|
|
||||||
|
def _check_valid(self, handler: EventHandler, events: Any) -> Any:
|
||||||
|
"""Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler: EventHandler.
|
||||||
|
events: The events to be checked.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If any of the events are not valid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The events as they are if valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _is_valid_type(events: Any) -> bool:
|
||||||
|
return isinstance(events, (EventHandler, EventSpec))
|
||||||
|
|
||||||
|
if events is None or _is_valid_type(events):
|
||||||
|
return events
|
||||||
|
try:
|
||||||
|
if all(_is_valid_type(e) for e in events):
|
||||||
|
return events
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
raise TypeError(
|
||||||
|
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||||
|
)
|
||||||
|
|
||||||
async def _process_event(
|
async def _process_event(
|
||||||
self, handler: EventHandler, state: State, payload: Dict
|
self, handler: EventHandler, state: State, payload: Dict
|
||||||
) -> AsyncIterator[Tuple[Optional[List[EventSpec]], bool]]:
|
) -> AsyncIterator[Tuple[Optional[List[EventSpec]], bool]]:
|
||||||
"""Process event.
|
"""Process event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
handler: Eventhandler to process.
|
handler: EventHandler to process.
|
||||||
state: State to process the handler.
|
state: State to process the handler.
|
||||||
payload: The event payload.
|
payload: The event payload.
|
||||||
|
|
||||||
@ -765,28 +794,27 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Handle regular functions.
|
# Handle regular functions.
|
||||||
else:
|
else:
|
||||||
events = fn(**payload)
|
events = fn(**payload)
|
||||||
|
|
||||||
# Handle async generators.
|
# Handle async generators.
|
||||||
if inspect.isasyncgen(events):
|
if inspect.isasyncgen(events):
|
||||||
async for event in events:
|
async for event in events:
|
||||||
yield event, False
|
yield self._check_valid(handler, event), False
|
||||||
yield None, True
|
yield None, True
|
||||||
|
|
||||||
# Handle regular generators.
|
# Handle regular generators.
|
||||||
elif inspect.isgenerator(events):
|
elif inspect.isgenerator(events):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield next(events), False
|
yield self._check_valid(handler, next(events)), False
|
||||||
except StopIteration as si:
|
except StopIteration as si:
|
||||||
# the "return" value of the generator is not available
|
# the "return" value of the generator is not available
|
||||||
# in the loop, we must catch StopIteration to access it
|
# in the loop, we must catch StopIteration to access it
|
||||||
if si.value is not None:
|
if si.value is not None:
|
||||||
yield si.value, False
|
yield self._check_valid(handler, si.value), False
|
||||||
yield None, True
|
yield None, True
|
||||||
|
|
||||||
# Handle regular event chains.
|
# Handle regular event chains.
|
||||||
else:
|
else:
|
||||||
yield events, True
|
yield self._check_valid(handler, events), True
|
||||||
|
|
||||||
# If an error occurs, throw a window alert.
|
# If an error occurs, throw a window alert.
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -1205,3 +1205,29 @@ def test_error_on_state_method_shadow():
|
|||||||
err.value.args[0]
|
err.value.args[0]
|
||||||
== f"The event handler name `reset` shadows a builtin State method; use a different name instead"
|
== f"The event handler name `reset` shadows a builtin State method; use a different name instead"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_with_invalid_yield():
|
||||||
|
"""Test that an error is thrown when a state yields an invalid value."""
|
||||||
|
|
||||||
|
class StateWithInvalidYield(rx.State):
|
||||||
|
"""A state that yields an invalid value."""
|
||||||
|
|
||||||
|
def invalid_handler(self):
|
||||||
|
"""Invalid handler.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
an invalid value.
|
||||||
|
"""
|
||||||
|
yield 1
|
||||||
|
|
||||||
|
invalid_state = StateWithInvalidYield()
|
||||||
|
with pytest.raises(TypeError) as err:
|
||||||
|
invalid_state._check_valid(
|
||||||
|
invalid_state.event_handlers["invalid_handler"],
|
||||||
|
rx.event.Event(token="fake_token", name="invalid_handler"),
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"must only return/yield: None, Events or other EventHandlers"
|
||||||
|
in err.value.args[0]
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user