Add async events (#1107)

This commit is contained in:
Nikhil Rao 2023-06-01 21:47:55 -07:00 committed by GitHub
parent f1ae27da69
commit a18c6880b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 151 additions and 57 deletions

View File

@ -2,7 +2,18 @@
import asyncio
import inspect
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from fastapi import FastAPI, UploadFile
from fastapi.middleware import cors
@ -411,7 +422,7 @@ class App(Base):
async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str
) -> StateUpdate:
) -> AsyncIterator[StateUpdate]:
"""Process an event.
Args:
@ -421,7 +432,7 @@ async def process(
headers: The client headers.
client_ip: The client_ip.
Returns:
Yields:
The state updates after processing the event.
"""
# Get the state for the session.
@ -447,20 +458,23 @@ async def process(
# Preprocess the event.
update = await app.preprocess(state, event)
# If there was an update, yield it.
if update is not None:
yield update
# Only process the event if there is no update.
if update is None:
# Apply the event to the state.
update = await state._process(event)
else:
# Process the event.
async for update in state._process(event):
yield update
# Postprocess the event.
assert update is not None, "Process did not return an update."
update = await app.postprocess(state, event, update)
# Update the state.
# Set the state for the session.
app.state_manager.set_state(event.token, state)
# Return the update.
return update
async def ping() -> str:
"""Test API endpoint.
@ -531,7 +545,8 @@ def upload(app: App):
name=handler,
payload={handler_upload_param[0]: files},
)
update = await state._process(event)
# TODO: refactor this to handle yields.
update = await state._process(event).__anext__()
return update
return upload_file
@ -595,10 +610,9 @@ class EventNamespace(AsyncNamespace):
client_ip = environ["REMOTE_ADDR"]
# Process the events.
update = await process(self.app, event, sid, headers, client_ip)
# Emit the event.
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
async for update in process(self.app, event, sid, headers, client_ip):
# Emit the event.
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
async def on_ping(self, sid):
"""Event for testing the API endpoint.

View File

@ -30,7 +30,7 @@ class EventHandler(Base):
"""An event handler responds to an event to update the state."""
# The function to call in response to the event.
fn: Callable
fn: Any
class Config:
"""The Pydantic config."""

View File

@ -4,11 +4,13 @@ from __future__ import annotations
import asyncio
import copy
import functools
import inspect
import traceback
from abc import ABC
from collections import defaultdict
from typing import (
Any,
AsyncIterator,
Callable,
ClassVar,
Dict,
@ -26,7 +28,7 @@ from redis import Redis
from pynecone import constants
from pynecone.base import Base
from pynecone.event import Event, EventHandler, fix_events, window_alert
from pynecone.event import Event, EventHandler, EventSpec, fix_events, window_alert
from pynecone.utils import format, prerequisites, types
from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
@ -618,13 +620,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
raise ValueError(f"Invalid path: {path}")
return self.substates[path[0]].get_substate(path[1:])
async def _process(self, event: Event) -> StateUpdate:
async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
"""Obtain event info and process event.
Args:
event: The event to process.
Returns:
Yields:
The state update after processing the event.
Raises:
@ -641,52 +643,75 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
"The value of state cannot be None when processing an event."
)
return await self._process_event(
# Get the event generator.
event_iter = self._process_event(
handler=handler,
state=substate,
payload=event.payload,
token=event.token,
)
# Clean the state before processing the event.
self.clean()
# Run the event generator and return state updates.
async for events in event_iter:
# Fix the returned events.
events = fix_events(events, event.token) # type: ignore
# Get the delta after processing the event.
delta = self.get_delta()
# Yield the state update.
yield StateUpdate(delta=delta, events=events)
# Clean the state to prepare for the next event.
self.clean()
async def _process_event(
self, handler: EventHandler, state: State, payload: Dict, token: str
) -> StateUpdate:
self, handler: EventHandler, state: State, payload: Dict
) -> AsyncIterator[Optional[List[EventSpec]]]:
"""Process event.
Args:
handler: Eventhandler to process.
state: State to process the handler.
payload: The event payload.
token: Client token.
Returns:
Yields:
The state update after processing the event.
"""
# Get the function to process the event.
fn = functools.partial(handler.fn, state)
# Wrap the function in a try/except block.
try:
# Handle async functions.
if asyncio.iscoroutinefunction(fn.func):
events = await fn(**payload)
# Handle regular functions.
else:
events = fn(**payload)
# Handle async generators.
if inspect.isasyncgen(events):
async for event in events:
yield event
# Handle regular generators.
elif inspect.isgenerator(events):
for event in events:
yield event
# Handle regular event chains.
else:
yield events
# If an error occurs, throw a window alert.
except Exception:
error = traceback.format_exc()
print(error)
events = fix_events(
[window_alert("An error occurred. See logs for details.")], token
)
return StateUpdate(events=events)
# Fix the returned events.
events = fix_events(events, token)
# Get the delta after processing the event.
delta = self.get_delta()
# Reset the dirty vars.
self.clean()
# Return the state update.
return StateUpdate(delta=delta, events=events)
yield [window_alert("An error occurred. See logs for details.")]
def _always_dirty_computed_vars(self) -> Set[str]:
"""The set of ComputedVars that always need to be recalculated.

View File

@ -107,11 +107,11 @@ async def test_preprocess(State, hydrate_middleware, request, event_fixture, exp
assert len(events) == 2
# Apply the on_load event.
update = await state._process(events[0])
update = await state._process(events[0]).__anext__()
assert update.delta == expected
# Apply the hydrate event.
update = await state._process(events[1])
update = await state._process(events[1]).__anext__()
assert update.delta == exp_is_hydrated(state)
@ -136,13 +136,13 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
# Apply the events.
events = update.events
update = await state._process(events[0])
update = await state._process(events[0]).__anext__()
assert update.delta == {"test_state": {"num": 1}}
update = await state._process(events[1])
update = await state._process(events[1]).__anext__()
assert update.delta == {"test_state": {"num": 2}}
update = await state._process(events[2])
update = await state._process(events[2]).__anext__()
assert update.delta == exp_is_hydrated(state)
@ -165,5 +165,5 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
assert len(update.events) == 1
assert isinstance(update, StateUpdate)
update = await state._process(update.events[0])
update = await state._process(update.events[0]).__anext__()
assert update.delta == exp_is_hydrated(state)

View File

@ -207,7 +207,7 @@ async def test_dynamic_var_event(test_state):
router_data={"pathname": "/", "query": {}},
payload={"value": 50},
)
)
).__anext__()
assert result.delta == {"test_state": {"int_val": 50}}
@ -324,7 +324,7 @@ async def test_list_mutation_detection__plain_list(
router_data={"pathname": "/", "query": {}},
payload={},
)
)
).__anext__()
assert result.delta == expected_delta
@ -451,7 +451,7 @@ async def test_dict_mutation_detection__plain_list(
router_data={"pathname": "/", "query": {}},
payload={},
)
)
).__anext__()
assert result.delta == expected_delta
@ -645,7 +645,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid,
headers={},
client_ip=client_ip,
)
).__anext__()
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
assert update == StateUpdate(
delta={
@ -675,7 +676,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid,
headers={},
client_ip=client_ip,
)
).__anext__()
assert on_load_update == StateUpdate(
delta={
state.get_name(): {
@ -695,7 +696,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid,
headers={},
client_ip=client_ip,
)
).__anext__()
assert on_set_is_hydrated_update == StateUpdate(
delta={
state.get_name(): {
@ -715,7 +716,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid,
headers={},
client_ip=client_ip,
)
).__anext__()
assert update == StateUpdate(
delta={
state.get_name(): {

View File

@ -91,6 +91,25 @@ class GrandchildState(ChildState):
pass
class GenState(State):
"""A state with event handlers that generate multiple updates."""
value: int
def go(self, c: int):
"""Increment the value c times and update each time.
Args:
c: The number of times to increment.
Yields:
After each increment.
"""
for _ in range(c):
self.value += 1
yield
@pytest.fixture
def test_state() -> TestState:
"""A state.
@ -146,6 +165,16 @@ def grandchild_state(child_state) -> GrandchildState:
return grandchild_state
@pytest.fixture
def gen_state() -> GenState:
"""A state.
Returns:
A test state.
"""
return GenState() # type: ignore
def test_base_class_vars(test_state):
"""Test that the class vars are set correctly.
@ -577,7 +606,7 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 0
event = Event(token="t", name="set_num1", payload={"value": 69})
update = await test_state._process(event)
update = await test_state._process(event).__anext__()
# The event should update the value.
assert test_state.num1 == 69
@ -603,7 +632,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
event = Event(
token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
)
update = await test_state._process(event)
update = await test_state._process(event).__anext__()
assert child_state.value == "HI"
assert child_state.count == 24
assert update.delta == {
@ -619,7 +648,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name="child_state.grandchild_state.set_value2",
payload={"value": "new"},
)
update = await test_state._process(event)
update = await test_state._process(event).__anext__()
assert grandchild_state.value2 == "new"
assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""},
@ -627,6 +656,31 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
}
@pytest.mark.asyncio
async def test_process_event_generator(gen_state):
"""Test event handlers that generate multiple updates.
Args:
gen_state: A state.
"""
event = Event(
token="t",
name="go",
payload={"c": 5},
)
gen = gen_state._process(event)
count = 0
async for update in gen:
count += 1
assert gen_state.value == count
assert update.delta == {
"gen_state": {"value": count},
}
assert count == 5
def test_format_event_handler():
"""Test formatting an event handler."""
assert (