Add async events (#1107)
This commit is contained in:
parent
f1ae27da69
commit
a18c6880b5
@ -2,7 +2,18 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Callable,
|
||||||
|
Coroutine,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from fastapi import FastAPI, UploadFile
|
from fastapi import FastAPI, UploadFile
|
||||||
from fastapi.middleware import cors
|
from fastapi.middleware import cors
|
||||||
@ -411,7 +422,7 @@ class App(Base):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||||
) -> StateUpdate:
|
) -> AsyncIterator[StateUpdate]:
|
||||||
"""Process an event.
|
"""Process an event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -421,7 +432,7 @@ async def process(
|
|||||||
headers: The client headers.
|
headers: The client headers.
|
||||||
client_ip: The client_ip.
|
client_ip: The client_ip.
|
||||||
|
|
||||||
Returns:
|
Yields:
|
||||||
The state updates after processing the event.
|
The state updates after processing the event.
|
||||||
"""
|
"""
|
||||||
# Get the state for the session.
|
# Get the state for the session.
|
||||||
@ -447,20 +458,23 @@ async def process(
|
|||||||
# Preprocess the event.
|
# Preprocess the event.
|
||||||
update = await app.preprocess(state, event)
|
update = await app.preprocess(state, event)
|
||||||
|
|
||||||
|
# If there was an update, yield it.
|
||||||
|
if update is not None:
|
||||||
|
yield update
|
||||||
|
|
||||||
# Only process the event if there is no update.
|
# Only process the event if there is no update.
|
||||||
if update is None:
|
else:
|
||||||
# Apply the event to the state.
|
# Process the event.
|
||||||
update = await state._process(event)
|
async for update in state._process(event):
|
||||||
|
yield update
|
||||||
|
|
||||||
# Postprocess the event.
|
# Postprocess the event.
|
||||||
|
assert update is not None, "Process did not return an update."
|
||||||
update = await app.postprocess(state, event, update)
|
update = await app.postprocess(state, event, update)
|
||||||
|
|
||||||
# Update the state.
|
# Set the state for the session.
|
||||||
app.state_manager.set_state(event.token, state)
|
app.state_manager.set_state(event.token, state)
|
||||||
|
|
||||||
# Return the update.
|
|
||||||
return update
|
|
||||||
|
|
||||||
|
|
||||||
async def ping() -> str:
|
async def ping() -> str:
|
||||||
"""Test API endpoint.
|
"""Test API endpoint.
|
||||||
@ -531,7 +545,8 @@ def upload(app: App):
|
|||||||
name=handler,
|
name=handler,
|
||||||
payload={handler_upload_param[0]: files},
|
payload={handler_upload_param[0]: files},
|
||||||
)
|
)
|
||||||
update = await state._process(event)
|
# TODO: refactor this to handle yields.
|
||||||
|
update = await state._process(event).__anext__()
|
||||||
return update
|
return update
|
||||||
|
|
||||||
return upload_file
|
return upload_file
|
||||||
@ -595,10 +610,9 @@ class EventNamespace(AsyncNamespace):
|
|||||||
client_ip = environ["REMOTE_ADDR"]
|
client_ip = environ["REMOTE_ADDR"]
|
||||||
|
|
||||||
# Process the events.
|
# Process the events.
|
||||||
update = await process(self.app, event, sid, headers, client_ip)
|
async for update in process(self.app, event, sid, headers, client_ip):
|
||||||
|
# Emit the event.
|
||||||
# Emit the event.
|
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
||||||
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
|
||||||
|
|
||||||
async def on_ping(self, sid):
|
async def on_ping(self, sid):
|
||||||
"""Event for testing the API endpoint.
|
"""Event for testing the API endpoint.
|
||||||
|
@ -30,7 +30,7 @@ class EventHandler(Base):
|
|||||||
"""An event handler responds to an event to update the state."""
|
"""An event handler responds to an event to update the state."""
|
||||||
|
|
||||||
# The function to call in response to the event.
|
# The function to call in response to the event.
|
||||||
fn: Callable
|
fn: Any
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""The Pydantic config."""
|
"""The Pydantic config."""
|
||||||
|
@ -4,11 +4,13 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
import traceback
|
import traceback
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
@ -26,7 +28,7 @@ from redis import Redis
|
|||||||
|
|
||||||
from pynecone import constants
|
from pynecone import constants
|
||||||
from pynecone.base import Base
|
from pynecone.base import Base
|
||||||
from pynecone.event import Event, EventHandler, fix_events, window_alert
|
from pynecone.event import Event, EventHandler, EventSpec, fix_events, window_alert
|
||||||
from pynecone.utils import format, prerequisites, types
|
from pynecone.utils import format, prerequisites, types
|
||||||
from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
|
from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var
|
||||||
|
|
||||||
@ -618,13 +620,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
raise ValueError(f"Invalid path: {path}")
|
raise ValueError(f"Invalid path: {path}")
|
||||||
return self.substates[path[0]].get_substate(path[1:])
|
return self.substates[path[0]].get_substate(path[1:])
|
||||||
|
|
||||||
async def _process(self, event: Event) -> StateUpdate:
|
async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
|
||||||
"""Obtain event info and process event.
|
"""Obtain event info and process event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event: The event to process.
|
event: The event to process.
|
||||||
|
|
||||||
Returns:
|
Yields:
|
||||||
The state update after processing the event.
|
The state update after processing the event.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -641,52 +643,75 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"The value of state cannot be None when processing an event."
|
"The value of state cannot be None when processing an event."
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._process_event(
|
# Get the event generator.
|
||||||
|
event_iter = self._process_event(
|
||||||
handler=handler,
|
handler=handler,
|
||||||
state=substate,
|
state=substate,
|
||||||
payload=event.payload,
|
payload=event.payload,
|
||||||
token=event.token,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Clean the state before processing the event.
|
||||||
|
self.clean()
|
||||||
|
|
||||||
|
# Run the event generator and return state updates.
|
||||||
|
async for events in event_iter:
|
||||||
|
# Fix the returned events.
|
||||||
|
events = fix_events(events, event.token) # type: ignore
|
||||||
|
|
||||||
|
# Get the delta after processing the event.
|
||||||
|
delta = self.get_delta()
|
||||||
|
|
||||||
|
# Yield the state update.
|
||||||
|
yield StateUpdate(delta=delta, events=events)
|
||||||
|
|
||||||
|
# Clean the state to prepare for the next event.
|
||||||
|
self.clean()
|
||||||
|
|
||||||
async def _process_event(
|
async def _process_event(
|
||||||
self, handler: EventHandler, state: State, payload: Dict, token: str
|
self, handler: EventHandler, state: State, payload: Dict
|
||||||
) -> StateUpdate:
|
) -> AsyncIterator[Optional[List[EventSpec]]]:
|
||||||
"""Process event.
|
"""Process event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
handler: Eventhandler to process.
|
handler: Eventhandler to process.
|
||||||
state: State to process the handler.
|
state: State to process the handler.
|
||||||
payload: The event payload.
|
payload: The event payload.
|
||||||
token: Client token.
|
|
||||||
|
|
||||||
Returns:
|
Yields:
|
||||||
The state update after processing the event.
|
The state update after processing the event.
|
||||||
"""
|
"""
|
||||||
|
# Get the function to process the event.
|
||||||
fn = functools.partial(handler.fn, state)
|
fn = functools.partial(handler.fn, state)
|
||||||
|
|
||||||
|
# Wrap the function in a try/except block.
|
||||||
try:
|
try:
|
||||||
|
# Handle async functions.
|
||||||
if asyncio.iscoroutinefunction(fn.func):
|
if asyncio.iscoroutinefunction(fn.func):
|
||||||
events = await fn(**payload)
|
events = await fn(**payload)
|
||||||
|
|
||||||
|
# Handle regular functions.
|
||||||
else:
|
else:
|
||||||
events = fn(**payload)
|
events = fn(**payload)
|
||||||
|
|
||||||
|
# Handle async generators.
|
||||||
|
if inspect.isasyncgen(events):
|
||||||
|
async for event in events:
|
||||||
|
yield event
|
||||||
|
|
||||||
|
# Handle regular generators.
|
||||||
|
elif inspect.isgenerator(events):
|
||||||
|
for event in events:
|
||||||
|
yield event
|
||||||
|
|
||||||
|
# Handle regular event chains.
|
||||||
|
else:
|
||||||
|
yield events
|
||||||
|
|
||||||
|
# If an error occurs, throw a window alert.
|
||||||
except Exception:
|
except Exception:
|
||||||
error = traceback.format_exc()
|
error = traceback.format_exc()
|
||||||
print(error)
|
print(error)
|
||||||
events = fix_events(
|
yield [window_alert("An error occurred. See logs for details.")]
|
||||||
[window_alert("An error occurred. See logs for details.")], token
|
|
||||||
)
|
|
||||||
return StateUpdate(events=events)
|
|
||||||
|
|
||||||
# Fix the returned events.
|
|
||||||
events = fix_events(events, token)
|
|
||||||
|
|
||||||
# Get the delta after processing the event.
|
|
||||||
delta = self.get_delta()
|
|
||||||
|
|
||||||
# Reset the dirty vars.
|
|
||||||
self.clean()
|
|
||||||
|
|
||||||
# Return the state update.
|
|
||||||
return StateUpdate(delta=delta, events=events)
|
|
||||||
|
|
||||||
def _always_dirty_computed_vars(self) -> Set[str]:
|
def _always_dirty_computed_vars(self) -> Set[str]:
|
||||||
"""The set of ComputedVars that always need to be recalculated.
|
"""The set of ComputedVars that always need to be recalculated.
|
||||||
|
@ -107,11 +107,11 @@ async def test_preprocess(State, hydrate_middleware, request, event_fixture, exp
|
|||||||
assert len(events) == 2
|
assert len(events) == 2
|
||||||
|
|
||||||
# Apply the on_load event.
|
# Apply the on_load event.
|
||||||
update = await state._process(events[0])
|
update = await state._process(events[0]).__anext__()
|
||||||
assert update.delta == expected
|
assert update.delta == expected
|
||||||
|
|
||||||
# Apply the hydrate event.
|
# Apply the hydrate event.
|
||||||
update = await state._process(events[1])
|
update = await state._process(events[1]).__anext__()
|
||||||
assert update.delta == exp_is_hydrated(state)
|
assert update.delta == exp_is_hydrated(state)
|
||||||
|
|
||||||
|
|
||||||
@ -136,13 +136,13 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
|
|||||||
|
|
||||||
# Apply the events.
|
# Apply the events.
|
||||||
events = update.events
|
events = update.events
|
||||||
update = await state._process(events[0])
|
update = await state._process(events[0]).__anext__()
|
||||||
assert update.delta == {"test_state": {"num": 1}}
|
assert update.delta == {"test_state": {"num": 1}}
|
||||||
|
|
||||||
update = await state._process(events[1])
|
update = await state._process(events[1]).__anext__()
|
||||||
assert update.delta == {"test_state": {"num": 2}}
|
assert update.delta == {"test_state": {"num": 2}}
|
||||||
|
|
||||||
update = await state._process(events[2])
|
update = await state._process(events[2]).__anext__()
|
||||||
assert update.delta == exp_is_hydrated(state)
|
assert update.delta == exp_is_hydrated(state)
|
||||||
|
|
||||||
|
|
||||||
@ -165,5 +165,5 @@ async def test_preprocess_no_events(hydrate_middleware, event1):
|
|||||||
assert len(update.events) == 1
|
assert len(update.events) == 1
|
||||||
assert isinstance(update, StateUpdate)
|
assert isinstance(update, StateUpdate)
|
||||||
|
|
||||||
update = await state._process(update.events[0])
|
update = await state._process(update.events[0]).__anext__()
|
||||||
assert update.delta == exp_is_hydrated(state)
|
assert update.delta == exp_is_hydrated(state)
|
||||||
|
@ -207,7 +207,7 @@ async def test_dynamic_var_event(test_state):
|
|||||||
router_data={"pathname": "/", "query": {}},
|
router_data={"pathname": "/", "query": {}},
|
||||||
payload={"value": 50},
|
payload={"value": 50},
|
||||||
)
|
)
|
||||||
)
|
).__anext__()
|
||||||
assert result.delta == {"test_state": {"int_val": 50}}
|
assert result.delta == {"test_state": {"int_val": 50}}
|
||||||
|
|
||||||
|
|
||||||
@ -324,7 +324,7 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
router_data={"pathname": "/", "query": {}},
|
router_data={"pathname": "/", "query": {}},
|
||||||
payload={},
|
payload={},
|
||||||
)
|
)
|
||||||
)
|
).__anext__()
|
||||||
|
|
||||||
assert result.delta == expected_delta
|
assert result.delta == expected_delta
|
||||||
|
|
||||||
@ -451,7 +451,7 @@ async def test_dict_mutation_detection__plain_list(
|
|||||||
router_data={"pathname": "/", "query": {}},
|
router_data={"pathname": "/", "query": {}},
|
||||||
payload={},
|
payload={},
|
||||||
)
|
)
|
||||||
)
|
).__anext__()
|
||||||
|
|
||||||
assert result.delta == expected_delta
|
assert result.delta == expected_delta
|
||||||
|
|
||||||
@ -645,7 +645,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
).__anext__()
|
||||||
|
|
||||||
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
||||||
assert update == StateUpdate(
|
assert update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
@ -675,7 +676,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
).__anext__()
|
||||||
assert on_load_update == StateUpdate(
|
assert on_load_update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -695,7 +696,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
).__anext__()
|
||||||
assert on_set_is_hydrated_update == StateUpdate(
|
assert on_set_is_hydrated_update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -715,7 +716,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
).__anext__()
|
||||||
assert update == StateUpdate(
|
assert update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
|
@ -91,6 +91,25 @@ class GrandchildState(ChildState):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GenState(State):
|
||||||
|
"""A state with event handlers that generate multiple updates."""
|
||||||
|
|
||||||
|
value: int
|
||||||
|
|
||||||
|
def go(self, c: int):
|
||||||
|
"""Increment the value c times and update each time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
c: The number of times to increment.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
After each increment.
|
||||||
|
"""
|
||||||
|
for _ in range(c):
|
||||||
|
self.value += 1
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_state() -> TestState:
|
def test_state() -> TestState:
|
||||||
"""A state.
|
"""A state.
|
||||||
@ -146,6 +165,16 @@ def grandchild_state(child_state) -> GrandchildState:
|
|||||||
return grandchild_state
|
return grandchild_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gen_state() -> GenState:
|
||||||
|
"""A state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A test state.
|
||||||
|
"""
|
||||||
|
return GenState() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_base_class_vars(test_state):
|
def test_base_class_vars(test_state):
|
||||||
"""Test that the class vars are set correctly.
|
"""Test that the class vars are set correctly.
|
||||||
|
|
||||||
@ -577,7 +606,7 @@ async def test_process_event_simple(test_state):
|
|||||||
assert test_state.num1 == 0
|
assert test_state.num1 == 0
|
||||||
|
|
||||||
event = Event(token="t", name="set_num1", payload={"value": 69})
|
event = Event(token="t", name="set_num1", payload={"value": 69})
|
||||||
update = await test_state._process(event)
|
update = await test_state._process(event).__anext__()
|
||||||
|
|
||||||
# The event should update the value.
|
# The event should update the value.
|
||||||
assert test_state.num1 == 69
|
assert test_state.num1 == 69
|
||||||
@ -603,7 +632,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
event = Event(
|
event = Event(
|
||||||
token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
|
token="t", name="child_state.change_both", payload={"value": "hi", "count": 12}
|
||||||
)
|
)
|
||||||
update = await test_state._process(event)
|
update = await test_state._process(event).__anext__()
|
||||||
assert child_state.value == "HI"
|
assert child_state.value == "HI"
|
||||||
assert child_state.count == 24
|
assert child_state.count == 24
|
||||||
assert update.delta == {
|
assert update.delta == {
|
||||||
@ -619,7 +648,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
name="child_state.grandchild_state.set_value2",
|
name="child_state.grandchild_state.set_value2",
|
||||||
payload={"value": "new"},
|
payload={"value": "new"},
|
||||||
)
|
)
|
||||||
update = await test_state._process(event)
|
update = await test_state._process(event).__anext__()
|
||||||
assert grandchild_state.value2 == "new"
|
assert grandchild_state.value2 == "new"
|
||||||
assert update.delta == {
|
assert update.delta == {
|
||||||
"test_state": {"sum": 3.14, "upper": ""},
|
"test_state": {"sum": 3.14, "upper": ""},
|
||||||
@ -627,6 +656,31 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_event_generator(gen_state):
|
||||||
|
"""Test event handlers that generate multiple updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gen_state: A state.
|
||||||
|
"""
|
||||||
|
event = Event(
|
||||||
|
token="t",
|
||||||
|
name="go",
|
||||||
|
payload={"c": 5},
|
||||||
|
)
|
||||||
|
gen = gen_state._process(event)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
async for update in gen:
|
||||||
|
count += 1
|
||||||
|
assert gen_state.value == count
|
||||||
|
assert update.delta == {
|
||||||
|
"gen_state": {"value": count},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert count == 5
|
||||||
|
|
||||||
|
|
||||||
def test_format_event_handler():
|
def test_format_event_handler():
|
||||||
"""Test formatting an event handler."""
|
"""Test formatting an event handler."""
|
||||||
assert (
|
assert (
|
||||||
|
Loading…
Reference in New Issue
Block a user