Fix processing flag for generator event handlers (#1136)

This commit is contained in:
Nikhil Rao 2023-06-05 11:30:59 -07:00 committed by GitHub
parent 9812ab2a58
commit 895719cf68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 30 additions and 19 deletions

View File

@ -47,7 +47,7 @@ export default function Component() {
{{const.result|react_setter}}({ {{const.result|react_setter}}({
{{const.state}}: null, {{const.state}}: null,
{{const.events}}: [], {{const.events}}: [],
{{const.processing}}: false, {{const.processing}}: {{const.result}}.{{const.processing}},
}) })
} }

View File

@ -214,7 +214,7 @@ export const connect = async (
update = JSON5.parse(update); update = JSON5.parse(update);
applyDelta(state, update.delta); applyDelta(state, update.delta);
setResult({ setResult({
processing: true, processing: update.processing,
state: state, state: state,
events: update.events, events: update.events,
}); });

View File

@ -466,11 +466,11 @@ async def process(
else: else:
# Process the event. # Process the event.
async for update in state._process(event): async for update in state._process(event):
yield update # Postprocess the event.
update = await app.postprocess(state, event, update)
# Postprocess the event. # Yield the update.
assert update is not None, "Process did not return an update." yield update
update = await app.postprocess(state, event, update)
# Set the state for the session. # Set the state for the session.
app.state_manager.set_state(event.token, state) app.state_manager.set_state(event.token, state)

View File

@ -18,6 +18,7 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple,
Type, Type,
Union, Union,
) )
@ -654,7 +655,7 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
self.clean() self.clean()
# Run the event generator and return state updates. # Run the event generator and return state updates.
async for events in event_iter: async for events, processing in event_iter:
# Fix the returned events. # Fix the returned events.
events = fix_events(events, event.token) # type: ignore events = fix_events(events, event.token) # type: ignore
@ -662,14 +663,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
delta = self.get_delta() delta = self.get_delta()
# Yield the state update. # Yield the state update.
yield StateUpdate(delta=delta, events=events) yield StateUpdate(delta=delta, events=events, processing=processing)
# Clean the state to prepare for the next event. # Clean the state to prepare for the next event.
self.clean() self.clean()
async def _process_event( async def _process_event(
self, handler: EventHandler, state: State, payload: Dict self, handler: EventHandler, state: State, payload: Dict
) -> AsyncIterator[Optional[List[EventSpec]]]: ) -> AsyncIterator[Tuple[Optional[List[EventSpec]], bool]]:
"""Process event. """Process event.
Args: Args:
@ -678,7 +679,9 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
payload: The event payload. payload: The event payload.
Yields: Yields:
The state update after processing the event. Tuple containing:
0: The state update after processing the event.
1: Whether the event is being processed.
""" """
# Get the function to process the event. # Get the function to process the event.
fn = functools.partial(handler.fn, state) fn = functools.partial(handler.fn, state)
@ -696,22 +699,24 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Handle async generators. # Handle async generators.
if inspect.isasyncgen(events): if inspect.isasyncgen(events):
async for event in events: async for event in events:
yield event yield event, True
yield None, False
# Handle regular generators. # Handle regular generators.
elif inspect.isgenerator(events): elif inspect.isgenerator(events):
for event in events: for event in events:
yield event yield event, True
yield None, False
# Handle regular event chains. # Handle regular event chains.
else: else:
yield events yield events, False
# If an error occurs, throw a window alert. # If an error occurs, throw a window alert.
except Exception: except Exception:
error = traceback.format_exc() error = traceback.format_exc()
print(error) print(error)
yield [window_alert("An error occurred. See logs for details.")] yield [window_alert("An error occurred. See logs for details.")], False
def _always_dirty_computed_vars(self) -> Set[str]: def _always_dirty_computed_vars(self) -> Set[str]:
"""The set of ComputedVars that always need to be recalculated. """The set of ComputedVars that always need to be recalculated.
@ -876,6 +881,9 @@ class StateUpdate(Base):
# Events to be added to the event queue. # Events to be added to the event queue.
events: List[Event] = [] events: List[Event] = []
# Whether the event is still processing.
processing: bool = False
class StateManager(Base): class StateManager(Base):
"""A class to manage many client states.""" """A class to manage many client states."""

View File

@ -673,12 +673,15 @@ async def test_process_event_generator(gen_state):
count = 0 count = 0
async for update in gen: async for update in gen:
count += 1 count += 1
assert gen_state.value == count if count == 6:
assert update.delta == { assert update.delta == {}
"gen_state": {"value": count}, else:
} assert gen_state.value == count
assert update.delta == {
"gen_state": {"value": count},
}
assert count == 5 assert count == 6
def test_format_event_handler(): def test_format_event_handler():