Add async events (#1107)
This commit is contained in:
parent
f1ae27da69
commit
a18c6880b5
@ -2,7 +2,18 @@
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from fastapi import FastAPI, UploadFile
|
||||
from fastapi.middleware import cors
|
||||
@ -411,7 +422,7 @@ class App(Base):
|
||||
|
||||
async def process(
|
||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||
) -> StateUpdate:
|
||||
) -> AsyncIterator[StateUpdate]:
|
||||
"""Process an event.
|
||||
|
||||
Args:
|
||||
@ -421,7 +432,7 @@ async def process(
|
||||
headers: The client headers.
|
||||
client_ip: The client_ip.
|
||||
|
||||
Returns:
|
||||
Yields:
|
||||
The state updates after processing the event.
|
||||
"""
|
||||
# Get the state for the session.
|
||||
@ -447,20 +458,23 @@ async def process(
|
||||
# Preprocess the event.
|
||||
update = await app.preprocess(state, event)
|
||||
|
||||
# If there was an update, yield it.
|
||||
if update is not None:
|
||||
yield update
|
||||
|
||||
# Only process the event if there is no update.
|
||||
if update is None:
|
||||
# Apply the event to the state.
|
||||
update = await state._process(event)
|
||||
else:
|
||||
# Process the event.
|
||||
async for update in state._process(event):
|
||||
yield update
|
||||
|
||||
# Postprocess the event.
|
||||
assert update is not None, "Process did not return an update."
|
||||
update = await app.postprocess(state, event, update)
|
||||
|
||||
# Update the state.
|
||||
# Set the state for the session.
|
||||
app.state_manager.set_state(event.token, state)
|
||||
|
||||
# Return the update.
|
||||
return update
|
||||
|
||||
|
||||
async def ping() -> str:
|
||||
"""Test API endpoint.
|
||||
@ -531,7 +545,8 @@ def upload(app: App):
|
||||
name=handler,
|
||||
payload={handler_upload_param[0]: files},
|
||||
)
|
||||
update = await state._process(event)
|
||||
# TODO: refactor this to handle yields.
|
||||
update = await state._process(event).__anext__()
|
||||
return update
|
||||
|
||||
return upload_file
|
||||
@ -595,10 +610,9 @@ class EventNamespace(AsyncNamespace):
|
||||
client_ip = environ["REMOTE_ADDR"]
|
||||
|
||||
# Process the events.
|
||||
update = await process(self.app, event, sid, headers, client_ip)
|
||||
|
||||
# Emit the event.
|
||||
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
||||
async for update in process(self.app, event, sid, headers, client_ip):
|
||||
# Emit the event.
|
||||
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
||||
|
||||
async def on_ping(self, sid):
|
||||
"""Event for testing the API endpoint.
|
||||
|
@ -30,7 +30,7 @@ class EventHandler(Base):
|
||||
"""An event handler responds to an event to update the state."""
|
||||
|
||||
# The function to call in response to the event.
|
||||
fn: Callable
|
||||
fn: Any
|
||||
|
||||
class Config:
|
||||
"""The Pydantic config."""
|
||||
|
@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
@ -26,7 +28,7 @@ from redis import Redis
|
||||
|
||||
from pynecone import constants
|
||||
from pynecone.base import Base
|
||||
from pynecone.event import Event, EventHandler, fix_events, window_alert
|
||||
from pynecone.event import Event, EventHandler, EventSpec, fix_events, window_alert
|
||||
from pynecone.utils import format, prerequisites, types
|
||||
from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
|
||||
|
||||
@ -618,13 +620,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
return self.substates[path[0]].get_substate(path[1:])
|
||||
|
||||
async def _process(self, event: Event) -> StateUpdate:
|
||||
async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
|
||||
"""Obtain event info and process event.
|
||||
|
||||
Args:
|
||||
event: The event to process.
|
||||
|
||||
Returns:
|
||||
Yields:
|
||||
The state update after processing the event.
|
||||
|
||||
Raises:
|
||||
@ -641,52 +643,75 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"The value of state cannot be None when processing an event."
|
||||
)
|
||||
|
||||
return await self._process_event(
|
||||
# Get the event generator.
|
||||
event_iter = self._process_event(
|
||||
handler=handler,
|
||||
state=substate,
|
||||
payload=event.payload,
|
||||
token=event.token,
|
||||
)
|
||||
|
||||
# Clean the state before processing the event.
|
||||
self.clean()
|
||||
|
||||
# Run the event generator and return state updates.
|
||||
async for events in event_iter:
|
||||
# Fix the returned events.
|
||||
events = fix_events(events, event.token) # type: ignore
|
||||
|
||||
# Get the delta after processing the event.
|
||||
delta = self.get_delta()
|
||||
|
||||
# Yield the state update.
|
||||
yield StateUpdate(delta=delta, events=events)
|
||||
|
||||
# Clean the state to prepare for the next event.
|
||||
self.clean()
|
||||
|
||||
async def _process_event(
|
||||
self, handler: EventHandler, state: State, payload: Dict, token: str
|
||||
) -> StateUpdate:
|
||||
self, handler: EventHandler, state: State, payload: Dict
|
||||
) -> AsyncIterator[Optional[List[EventSpec]]]:
|
||||
"""Process event.
|
||||
|
||||
Args:
|
||||
handler: Eventhandler to process.
|
||||
state: State to process the handler.
|
||||
payload: The event payload.
|
||||
token: Client token.
|
||||
|
||||
Returns:
|
||||
Yields:
|
||||
The state update after processing the event.
|
||||
"""
|
||||
# Get the function to process the event.
|
||||
fn = functools.partial(handler.fn, state)
|
||||
|
||||
# Wrap the function in a try/except block.
|
||||
try:
|
||||
# Handle async functions.
|
||||
if asyncio.iscoroutinefunction(fn.func):
|
||||
events = await fn(**payload)
|
||||
|
||||
# Handle regular functions.
|
||||
else:
|
||||
events = fn(**payload)
|
||||
|
||||
# Handle async generators.
|
||||
if inspect.isasyncgen(events):
|
||||
async for event in events:
|
||||
yield event
|
||||
|
||||
# Handle regular generators.
|
||||
elif inspect.isgenerator(events):
|
||||
for event in events:
|
||||
yield event
|
||||
|
||||
# Handle regular event chains.
|
||||
else:
|
||||
yield events
|
||||
|
||||
# If an error occurs, throw a window alert.
|
||||
except Exception:
|
||||
error = traceback.format_exc()
|
||||
print(error)
|
||||
events = fix_events(
|
||||
[window_alert("An error occurred. See logs for details.")], token
|
||||
)
|
||||
return StateUpdate(events=events)
|
||||
|
||||
# Fix the returned events.
|
||||
events = fix_events(events, token)
|
||||
|
||||
# Get the delta after processing the event.
|
||||
delta = self.get_delta()
|
||||
|
||||
# Reset the dirty vars.
|
||||
self.clean()
|
||||
|
||||
# Return the state update.
|
||||
return StateUpdate(delta=delta, events=events)
|
||||
yield [window_alert("An error occurred. See logs for details.")]
|
||||
|
||||
def _always_dirty_computed_vars(self) -> Set[str]:
|
||||
"""The set of ComputedVars that always need to be recalculated.
|
||||
|
@ -107,11 +107,11 @@ async def test_preprocess(State, hydrate_middleware, request, event_fixture, exp
|
||||
assert len(events) == 2
|
||||
|
||||
# Apply the on_load event.
|
||||
update = await state._process(events[0])
|
||||
update = await state._process(events[0]).__anext__()
|
||||
assert update.delta == expected
|
||||
|
||||
# Apply the hydrate event.
|
||||
update = await state._process(events[1])
|
||||
update = await state._process(events[1]).__anext__()
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
|
||||
|
||||
@ -136,13 +136,13 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
|
||||
|
||||
# Apply the events.
|
||||
events = update.events
|
||||
update = await state._process(events[0])
|
||||
update = await state._process(events[0]).__anext__()
|
||||
assert update.delta == {"test_state": {"num": 1}}
|
||||
|
||||
update = await state._process(events[1])
|
||||
update = await state._process(events[1]).__anext__()
|
||||
assert update.delta == {"test_state": {"num": 2}}
|
||||
|
||||
update = await state._process(events[2])
|
||||
update = await state._process(events[2]).__anext__()
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
|
||||
|
||||
@ -165,5 +165,5 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
|
||||
assert len(update.events) == 1
|
||||
assert isinstance(update, StateUpdate)
|
||||
|
||||
update = await state._process(update.events[0])
|
||||
update = await state._process(update.events[0]).__anext__()
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
|
@ -207,7 +207,7 @@ async def test_dynamic_var_event(test_state):
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={"value": 50},
|
||||
)
|
||||
)
|
||||
).__anext__()
|
||||
assert result.delta == {"test_state": {"int_val": 50}}
|
||||
|
||||
|
||||
@ -324,7 +324,7 @@ async def test_list_mutation_detection__plain_list(
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={},
|
||||
)
|
||||
)
|
||||
).__anext__()
|
||||
|
||||
assert result.delta == expected_delta
|
||||
|
||||
@ -451,7 +451,7 @@ async def test_dict_mutation_detection__plain_list(
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={},
|
||||
)
|
||||
)
|
||||
).__anext__()
|
||||
|
||||
assert result.delta == expected_delta
|
||||
|
||||
@ -645,7 +645,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
).__anext__()
|
||||
|
||||
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
||||
assert update == StateUpdate(
|
||||
delta={
|
||||
@ -675,7 +676,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
).__anext__()
|
||||
assert on_load_update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -695,7 +696,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
).__anext__()
|
||||
assert on_set_is_hydrated_update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -715,7 +716,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
).__anext__()
|
||||
assert update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
|
@ -91,6 +91,25 @@ class GrandchildState(ChildState):
|
||||
pass
|
||||
|
||||
|
||||
class GenState(State):
|
||||
"""A state with event handlers that generate multiple updates."""
|
||||
|
||||
value: int
|
||||
|
||||
def go(self, c: int):
|
||||
"""Increment the value c times and update each time.
|
||||
|
||||
Args:
|
||||
c: The number of times to increment.
|
||||
|
||||
Yields:
|
||||
After each increment.
|
||||
"""
|
||||
for _ in range(c):
|
||||
self.value += 1
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_state() -> TestState:
|
||||
"""A state.
|
||||
@ -146,6 +165,16 @@ def grandchild_state(child_state) -> GrandchildState:
|
||||
return grandchild_state
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gen_state() -> GenState:
|
||||
"""A state.
|
||||
|
||||
Returns:
|
||||
A test state.
|
||||
"""
|
||||
return GenState() # type: ignore
|
||||
|
||||
|
||||
def test_base_class_vars(test_state):
|
||||
"""Test that the class vars are set correctly.
|
||||
|
||||
@ -577,7 +606,7 @@ async def test_process_event_simple(test_state):
|
||||
assert test_state.num1 == 0
|
||||
|
||||
event = Event(token="t", name="set_num1", payload={"value": 69})
|
||||
update = await test_state._process(event)
|
||||
update = await test_state._process(event).__anext__()
|
||||
|
||||
# The event should update the value.
|
||||
assert test_state.num1 == 69
|
||||
@ -603,7 +632,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
event = Event(
|
||||
token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
|
||||
)
|
||||
update = await test_state._process(event)
|
||||
update = await test_state._process(event).__anext__()
|
||||
assert child_state.value == "HI"
|
||||
assert child_state.count == 24
|
||||
assert update.delta == {
|
||||
@ -619,7 +648,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
name="child_state.grandchild_state.set_value2",
|
||||
payload={"value": "new"},
|
||||
)
|
||||
update = await test_state._process(event)
|
||||
update = await test_state._process(event).__anext__()
|
||||
assert grandchild_state.value2 == "new"
|
||||
assert update.delta == {
|
||||
"test_state": {"sum": 3.14, "upper": ""},
|
||||
@ -627,6 +656,31 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event_generator(gen_state):
|
||||
"""Test event handlers that generate multiple updates.
|
||||
|
||||
Args:
|
||||
gen_state: A state.
|
||||
"""
|
||||
event = Event(
|
||||
token="t",
|
||||
name="go",
|
||||
payload={"c": 5},
|
||||
)
|
||||
gen = gen_state._process(event)
|
||||
|
||||
count = 0
|
||||
async for update in gen:
|
||||
count += 1
|
||||
assert gen_state.value == count
|
||||
assert update.delta == {
|
||||
"gen_state": {"value": count},
|
||||
}
|
||||
|
||||
assert count == 5
|
||||
|
||||
|
||||
def test_format_event_handler():
|
||||
"""Test formatting an event handler."""
|
||||
assert (
|
||||
|
Loading…
Reference in New Issue
Block a user