Add async events (#1107)

This commit is contained in:
Nikhil Rao 2023-06-01 21:47:55 -07:00 committed by GitHub
parent f1ae27da69
commit a18c6880b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 151 additions and 57 deletions

View File

@ -2,7 +2,18 @@
import asyncio import asyncio
import inspect import inspect
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from fastapi import FastAPI, UploadFile from fastapi import FastAPI, UploadFile
from fastapi.middleware import cors from fastapi.middleware import cors
@ -411,7 +422,7 @@ class App(Base):
async def process( async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str app: App, event: Event, sid: str, headers: Dict, client_ip: str
) -> StateUpdate: ) -> AsyncIterator[StateUpdate]:
"""Process an event. """Process an event.
Args: Args:
@ -421,7 +432,7 @@ async def process(
headers: The client headers. headers: The client headers.
client_ip: The client_ip. client_ip: The client_ip.
Returns: Yields:
The state updates after processing the event. The state updates after processing the event.
""" """
# Get the state for the session. # Get the state for the session.
@ -447,20 +458,23 @@ async def process(
# Preprocess the event. # Preprocess the event.
update = await app.preprocess(state, event) update = await app.preprocess(state, event)
# If there was an update, yield it.
if update is not None:
yield update
# Only process the event if there is no update. # Only process the event if there is no update.
if update is None: else:
# Apply the event to the state. # Process the event.
update = await state._process(event) async for update in state._process(event):
yield update
# Postprocess the event. # Postprocess the event.
assert update is not None, "Process did not return an update."
update = await app.postprocess(state, event, update) update = await app.postprocess(state, event, update)
# Update the state. # Set the state for the session.
app.state_manager.set_state(event.token, state) app.state_manager.set_state(event.token, state)
# Return the update.
return update
async def ping() -> str: async def ping() -> str:
"""Test API endpoint. """Test API endpoint.
@ -531,7 +545,8 @@ def upload(app: App):
name=handler, name=handler,
payload={handler_upload_param[0]: files}, payload={handler_upload_param[0]: files},
) )
update = await state._process(event) # TODO: refactor this to handle yields.
update = await state._process(event).__anext__()
return update return update
return upload_file return upload_file
@ -595,10 +610,9 @@ class EventNamespace(AsyncNamespace):
client_ip = environ["REMOTE_ADDR"] client_ip = environ["REMOTE_ADDR"]
# Process the events. # Process the events.
update = await process(self.app, event, sid, headers, client_ip) async for update in process(self.app, event, sid, headers, client_ip):
# Emit the event.
# Emit the event. await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
async def on_ping(self, sid): async def on_ping(self, sid):
"""Event for testing the API endpoint. """Event for testing the API endpoint.

View File

@ -30,7 +30,7 @@ class EventHandler(Base):
"""An event handler responds to an event to update the state.""" """An event handler responds to an event to update the state."""
# The function to call in response to the event. # The function to call in response to the event.
fn: Callable fn: Any
class Config: class Config:
"""The Pydantic config.""" """The Pydantic config."""

View File

@ -4,11 +4,13 @@ from __future__ import annotations
import asyncio import asyncio
import copy import copy
import functools import functools
import inspect
import traceback import traceback
from abc import ABC from abc import ABC
from collections import defaultdict from collections import defaultdict
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
ClassVar, ClassVar,
Dict, Dict,
@ -26,7 +28,7 @@ from redis import Redis
from pynecone import constants from pynecone import constants
from pynecone.base import Base from pynecone.base import Base
from pynecone.event import Event, EventHandler, fix_events, window_alert from pynecone.event import Event, EventHandler, EventSpec, fix_events, window_alert
from pynecone.utils import format, prerequisites, types from pynecone.utils import format, prerequisites, types
from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
@ -618,13 +620,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
raise ValueError(f"Invalid path: {path}") raise ValueError(f"Invalid path: {path}")
return self.substates[path[0]].get_substate(path[1:]) return self.substates[path[0]].get_substate(path[1:])
async def _process(self, event: Event) -> StateUpdate: async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
"""Obtain event info and process event. """Obtain event info and process event.
Args: Args:
event: The event to process. event: The event to process.
Returns: Yields:
The state update after processing the event. The state update after processing the event.
Raises: Raises:
@ -641,52 +643,75 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
"The value of state cannot be None when processing an event." "The value of state cannot be None when processing an event."
) )
return await self._process_event( # Get the event generator.
event_iter = self._process_event(
handler=handler, handler=handler,
state=substate, state=substate,
payload=event.payload, payload=event.payload,
token=event.token,
) )
# Clean the state before processing the event.
self.clean()
# Run the event generator and return state updates.
async for events in event_iter:
# Fix the returned events.
events = fix_events(events, event.token) # type: ignore
# Get the delta after processing the event.
delta = self.get_delta()
# Yield the state update.
yield StateUpdate(delta=delta, events=events)
# Clean the state to prepare for the next event.
self.clean()
async def _process_event( async def _process_event(
self, handler: EventHandler, state: State, payload: Dict, token: str self, handler: EventHandler, state: State, payload: Dict
) -> StateUpdate: ) -> AsyncIterator[Optional[List[EventSpec]]]:
"""Process event. """Process event.
Args: Args:
handler: Eventhandler to process. handler: Eventhandler to process.
state: State to process the handler. state: State to process the handler.
payload: The event payload. payload: The event payload.
token: Client token.
Returns: Yields:
The state update after processing the event. The state update after processing the event.
""" """
# Get the function to process the event.
fn = functools.partial(handler.fn, state) fn = functools.partial(handler.fn, state)
# Wrap the function in a try/except block.
try: try:
# Handle async functions.
if asyncio.iscoroutinefunction(fn.func): if asyncio.iscoroutinefunction(fn.func):
events = await fn(**payload) events = await fn(**payload)
# Handle regular functions.
else: else:
events = fn(**payload) events = fn(**payload)
# Handle async generators.
if inspect.isasyncgen(events):
async for event in events:
yield event
# Handle regular generators.
elif inspect.isgenerator(events):
for event in events:
yield event
# Handle regular event chains.
else:
yield events
# If an error occurs, throw a window alert.
except Exception: except Exception:
error = traceback.format_exc() error = traceback.format_exc()
print(error) print(error)
events = fix_events( yield [window_alert("An error occurred. See logs for details.")]
[window_alert("An error occurred. See logs for details.")], token
)
return StateUpdate(events=events)
# Fix the returned events.
events = fix_events(events, token)
# Get the delta after processing the event.
delta = self.get_delta()
# Reset the dirty vars.
self.clean()
# Return the state update.
return StateUpdate(delta=delta, events=events)
def _always_dirty_computed_vars(self) -> Set[str]: def _always_dirty_computed_vars(self) -> Set[str]:
"""The set of ComputedVars that always need to be recalculated. """The set of ComputedVars that always need to be recalculated.

View File

@ -107,11 +107,11 @@ async def test_preprocess(State, hydrate_middleware, request, event_fixture, exp
assert len(events) == 2 assert len(events) == 2
# Apply the on_load event. # Apply the on_load event.
update = await state._process(events[0]) update = await state._process(events[0]).__anext__()
assert update.delta == expected assert update.delta == expected
# Apply the hydrate event. # Apply the hydrate event.
update = await state._process(events[1]) update = await state._process(events[1]).__anext__()
assert update.delta == exp_is_hydrated(state) assert update.delta == exp_is_hydrated(state)
@ -136,13 +136,13 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
# Apply the events. # Apply the events.
events = update.events events = update.events
update = await state._process(events[0]) update = await state._process(events[0]).__anext__()
assert update.delta == {"test_state": {"num": 1}} assert update.delta == {"test_state": {"num": 1}}
update = await state._process(events[1]) update = await state._process(events[1]).__anext__()
assert update.delta == {"test_state": {"num": 2}} assert update.delta == {"test_state": {"num": 2}}
update = await state._process(events[2]) update = await state._process(events[2]).__anext__()
assert update.delta == exp_is_hydrated(state) assert update.delta == exp_is_hydrated(state)
@ -165,5 +165,5 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
assert len(update.events) == 1 assert len(update.events) == 1
assert isinstance(update, StateUpdate) assert isinstance(update, StateUpdate)
update = await state._process(update.events[0]) update = await state._process(update.events[0]).__anext__()
assert update.delta == exp_is_hydrated(state) assert update.delta == exp_is_hydrated(state)

View File

@ -207,7 +207,7 @@ async def test_dynamic_var_event(test_state):
router_data={"pathname": "/", "query": {}}, router_data={"pathname": "/", "query": {}},
payload={"value": 50}, payload={"value": 50},
) )
) ).__anext__()
assert result.delta == {"test_state": {"int_val": 50}} assert result.delta == {"test_state": {"int_val": 50}}
@ -324,7 +324,7 @@ async def test_list_mutation_detection__plain_list(
router_data={"pathname": "/", "query": {}}, router_data={"pathname": "/", "query": {}},
payload={}, payload={},
) )
) ).__anext__()
assert result.delta == expected_delta assert result.delta == expected_delta
@ -451,7 +451,7 @@ async def test_dict_mutation_detection__plain_list(
router_data={"pathname": "/", "query": {}}, router_data={"pathname": "/", "query": {}},
payload={}, payload={},
) )
) ).__anext__()
assert result.delta == expected_delta assert result.delta == expected_delta
@ -645,7 +645,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid, sid=sid,
headers={}, headers={},
client_ip=client_ip, client_ip=client_ip,
) ).__anext__()
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)] # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
assert update == StateUpdate( assert update == StateUpdate(
delta={ delta={
@ -675,7 +676,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid, sid=sid,
headers={}, headers={},
client_ip=client_ip, client_ip=client_ip,
) ).__anext__()
assert on_load_update == StateUpdate( assert on_load_update == StateUpdate(
delta={ delta={
state.get_name(): { state.get_name(): {
@ -695,7 +696,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid, sid=sid,
headers={}, headers={},
client_ip=client_ip, client_ip=client_ip,
) ).__anext__()
assert on_set_is_hydrated_update == StateUpdate( assert on_set_is_hydrated_update == StateUpdate(
delta={ delta={
state.get_name(): { state.get_name(): {
@ -715,7 +716,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid, sid=sid,
headers={}, headers={},
client_ip=client_ip, client_ip=client_ip,
) ).__anext__()
assert update == StateUpdate( assert update == StateUpdate(
delta={ delta={
state.get_name(): { state.get_name(): {

View File

@ -91,6 +91,25 @@ class GrandchildState(ChildState):
pass pass
class GenState(State):
"""A state with event handlers that generate multiple updates."""
value: int
def go(self, c: int):
"""Increment the value c times and update each time.
Args:
c: The number of times to increment.
Yields:
After each increment.
"""
for _ in range(c):
self.value += 1
yield
@pytest.fixture @pytest.fixture
def test_state() -> TestState: def test_state() -> TestState:
"""A state. """A state.
@ -146,6 +165,16 @@ def grandchild_state(child_state) -> GrandchildState:
return grandchild_state return grandchild_state
@pytest.fixture
def gen_state() -> GenState:
"""A state.
Returns:
A test state.
"""
return GenState() # type: ignore
def test_base_class_vars(test_state): def test_base_class_vars(test_state):
"""Test that the class vars are set correctly. """Test that the class vars are set correctly.
@ -577,7 +606,7 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 0 assert test_state.num1 == 0
event = Event(token="t", name="set_num1", payload={"value": 69}) event = Event(token="t", name="set_num1", payload={"value": 69})
update = await test_state._process(event) update = await test_state._process(event).__anext__()
# The event should update the value. # The event should update the value.
assert test_state.num1 == 69 assert test_state.num1 == 69
@ -603,7 +632,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
event = Event( event = Event(
token="t", name="child_state.change_both", payload={"value": "hi", "count": 12} token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
) )
update = await test_state._process(event) update = await test_state._process(event).__anext__()
assert child_state.value == "HI" assert child_state.value == "HI"
assert child_state.count == 24 assert child_state.count == 24
assert update.delta == { assert update.delta == {
@ -619,7 +648,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name="child_state.grandchild_state.set_value2", name="child_state.grandchild_state.set_value2",
payload={"value": "new"}, payload={"value": "new"},
) )
update = await test_state._process(event) update = await test_state._process(event).__anext__()
assert grandchild_state.value2 == "new" assert grandchild_state.value2 == "new"
assert update.delta == { assert update.delta == {
"test_state": {"sum": 3.14, "upper": ""}, "test_state": {"sum": 3.14, "upper": ""},
@ -627,6 +656,31 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
} }
@pytest.mark.asyncio
async def test_process_event_generator(gen_state):
"""Test event handlers that generate multiple updates.
Args:
gen_state: A state.
"""
event = Event(
token="t",
name="go",
payload={"c": 5},
)
gen = gen_state._process(event)
count = 0
async for update in gen:
count += 1
assert gen_state.value == count
assert update.delta == {
"gen_state": {"value": count},
}
assert count == 5
def test_format_event_handler(): def test_format_event_handler():
"""Test formatting an event handler.""" """Test formatting an event handler."""
assert ( assert (