Fix event handler returns (#788)

This commit is contained in:
Nikhil Rao 2023-04-08 10:49:00 -07:00 committed by GitHub
parent e8387c8e26
commit e96f1c4d39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 20 deletions

View File

@ -404,7 +404,7 @@ class App(Base):
async def process( async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str app: App, event: Event, sid: str, headers: Dict, client_ip: str
) -> Union[StateUpdate, List[StateUpdate]]: ) -> List[StateUpdate]:
"""Process an event. """Process an event.
Args: Args:
@ -415,34 +415,37 @@ async def process(
client_ip: The client_ip. client_ip: The client_ip.
Returns: Returns:
The state update(s) after processing the event. The state updates after processing the event.
""" """
# Get the state for the session. # Get the state for the session.
state = app.state_manager.get_state(event.token) state = app.state_manager.get_state(event.token)
formatted_params = format.format_query_params(event.router_data) # Add request data to the state.
# Pass router_data to the state of the App.
state.router_data = event.router_data state.router_data = event.router_data
# also pass router_data to all substates state.router_data.update(
{
constants.RouteVar.QUERY: format.format_query_params(event.router_data),
constants.RouteVar.CLIENT_TOKEN: event.token,
constants.RouteVar.SESSION_ID: sid,
constants.RouteVar.HEADERS: headers,
constants.RouteVar.CLIENT_IP: client_ip,
}
)
# Also pass router_data to all substates. (TODO: this isn't recursive currently)
for _, substate in state.substates.items(): for _, substate in state.substates.items():
substate.router_data = event.router_data substate.router_data = state.router_data
state.router_data[constants.RouteVar.QUERY] = formatted_params
state.router_data[constants.RouteVar.CLIENT_TOKEN] = event.token
state.router_data[constants.RouteVar.SESSION_ID] = sid
state.router_data[constants.RouteVar.HEADERS] = headers
state.router_data[constants.RouteVar.CLIENT_IP] = client_ip
# Preprocess the event. # Preprocess the event.
pre = await app.preprocess(state, event) pre = await app.preprocess(state, event)
if pre is not None and not isinstance(pre, List): if isinstance(pre, StateUpdate):
return pre return [pre]
updates = pre
# Apply the event to the state. # Apply the event to the state.
updates = pre if pre else await state.process(event) if updates is None:
app.state_manager.set_state(event.token, state) updates = [await state.process(event)]
app.state_manager.set_state(event.token, state)
updates = updates if isinstance(updates, List) else [updates]
# Postprocess the event. # Postprocess the event.
post_list = [] post_list = []
@ -450,10 +453,10 @@ async def process(
post = await app.postprocess(state, event, update.delta) # type: ignore post = await app.postprocess(state, event, update.delta) # type: ignore
post_list.append(post) if post else None post_list.append(post) if post else None
if post_list: if len(post_list) > 0:
return [StateUpdate(delta=post) for post in post_list] return [StateUpdate(delta=post) for post in post_list]
# Return the update. # Return the updates.
return updates return updates

View File

@ -346,6 +346,8 @@ def fix_events(
# Fix the events created by the handler. # Fix the events created by the handler.
out = [] out = []
for e in events: for e in events:
if not isinstance(e, (EventHandler, EventSpec)):
e = EventHandler(fn=e)
# Otherwise, create an event from the event spec. # Otherwise, create an event from the event spec.
if isinstance(e, EventHandler): if isinstance(e, EventHandler):
e = e() e = e()