From a18c6880b524a44cbb420612bb79e14ba9c218ce Mon Sep 17 00:00:00 2001 From: Nikhil Rao Date: Thu, 1 Jun 2023 21:47:55 -0700 Subject: [PATCH] Add async events (#1107) --- pynecone/app.py | 44 +++++++----- pynecone/event.py | 2 +- pynecone/state.py | 75 ++++++++++++++------- tests/middleware/test_hydrate_middleware.py | 12 ++-- tests/test_app.py | 15 +++-- tests/test_state.py | 60 ++++++++++++++++- 6 files changed, 151 insertions(+), 57 deletions(-) diff --git a/pynecone/app.py b/pynecone/app.py index 5ad080b92..fb4835293 100644 --- a/pynecone/app.py +++ b/pynecone/app.py @@ -2,7 +2,18 @@ import asyncio import inspect -from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Coroutine, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) from fastapi import FastAPI, UploadFile from fastapi.middleware import cors @@ -411,7 +422,7 @@ class App(Base): async def process( app: App, event: Event, sid: str, headers: Dict, client_ip: str -) -> StateUpdate: +) -> AsyncIterator[StateUpdate]: """Process an event. Args: @@ -421,7 +432,7 @@ async def process( headers: The client headers. client_ip: The client_ip. - Returns: + Yields: The state updates after processing the event. """ # Get the state for the session. @@ -447,20 +458,23 @@ async def process( # Preprocess the event. update = await app.preprocess(state, event) + # If there was an update, yield it. + if update is not None: + yield update + # Only process the event if there is no update. - if update is None: - # Apply the event to the state. - update = await state._process(event) + else: + # Process the event. + async for update in state._process(event): + yield update # Postprocess the event. + assert update is not None, "Process did not return an update." update = await app.postprocess(state, event, update) - # Update the state. + # Set the state for the session. app.state_manager.set_state(event.token, state) - # Return the update. - return update - async def ping() -> str: """Test API endpoint. @@ -531,7 +545,8 @@ def upload(app: App): name=handler, payload={handler_upload_param[0]: files}, ) - update = await state._process(event) + # TODO: refactor this to handle yields. + update = await state._process(event).__anext__() return update return upload_file @@ -595,10 +610,9 @@ class EventNamespace(AsyncNamespace): client_ip = environ["REMOTE_ADDR"] # Process the events. - update = await process(self.app, event, sid, headers, client_ip) - - # Emit the event. - await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore + async for update in process(self.app, event, sid, headers, client_ip): + # Emit the event. + await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore async def on_ping(self, sid): """Event for testing the API endpoint. diff --git a/pynecone/event.py b/pynecone/event.py index a606f7a19..d9d3d0de7 100644 --- a/pynecone/event.py +++ b/pynecone/event.py @@ -30,7 +30,7 @@ class EventHandler(Base): """An event handler responds to an event to update the state.""" # The function to call in response to the event. - fn: Callable + fn: Any class Config: """The Pydantic config.""" diff --git a/pynecone/state.py b/pynecone/state.py index 429831e56..85811fcf0 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -4,11 +4,13 @@ from __future__ import annotations import asyncio import copy import functools +import inspect import traceback from abc import ABC from collections import defaultdict from typing import ( Any, + AsyncIterator, Callable, ClassVar, Dict, @@ -26,7 +28,7 @@ from redis import Redis from pynecone import constants from pynecone.base import Base -from pynecone.event import Event, EventHandler, fix_events, window_alert +from pynecone.event import Event, EventHandler, EventSpec, fix_events, window_alert from pynecone.utils import format, prerequisites, types from pynecone.vars import BaseVar, ComputedVar, PCDict, PCList, Var @@ -618,13 +620,13 @@ class State(Base, ABC, extra=pydantic.Extra.allow): raise ValueError(f"Invalid path: {path}") return self.substates[path[0]].get_substate(path[1:]) - async def _process(self, event: Event) -> StateUpdate: + async def _process(self, event: Event) -> AsyncIterator[StateUpdate]: """Obtain event info and process event. Args: event: The event to process. - Returns: + Yields: The state update after processing the event. Raises: @@ -641,52 +643,75 @@ class State(Base, ABC, extra=pydantic.Extra.allow): "The value of state cannot be None when processing an event." ) - return await self._process_event( + # Get the event generator. + event_iter = self._process_event( handler=handler, state=substate, payload=event.payload, - token=event.token, ) + # Clean the state before processing the event. + self.clean() + + # Run the event generator and return state updates. + async for events in event_iter: + # Fix the returned events. + events = fix_events(events, event.token) # type: ignore + + # Get the delta after processing the event. + delta = self.get_delta() + + # Yield the state update. + yield StateUpdate(delta=delta, events=events) + + # Clean the state to prepare for the next event. + self.clean() + async def _process_event( - self, handler: EventHandler, state: State, payload: Dict, token: str - ) -> StateUpdate: + self, handler: EventHandler, state: State, payload: Dict + ) -> AsyncIterator[Optional[List[EventSpec]]]: """Process event. Args: handler: Eventhandler to process. state: State to process the handler. payload: The event payload. - token: Client token. - Returns: + Yields: The state update after processing the event. """ + # Get the function to process the event. fn = functools.partial(handler.fn, state) + + # Wrap the function in a try/except block. try: + # Handle async functions. if asyncio.iscoroutinefunction(fn.func): events = await fn(**payload) + + # Handle regular functions. else: events = fn(**payload) + + # Handle async generators. + if inspect.isasyncgen(events): + async for event in events: + yield event + + # Handle regular generators. + elif inspect.isgenerator(events): + for event in events: + yield event + + # Handle regular event chains. + else: + yield events + + # If an error occurs, throw a window alert. except Exception: error = traceback.format_exc() print(error) - events = fix_events( - [window_alert("An error occurred. See logs for details.")], token - ) - return StateUpdate(events=events) - - # Fix the returned events. - events = fix_events(events, token) - - # Get the delta after processing the event. - delta = self.get_delta() - - # Reset the dirty vars. - self.clean() - - # Return the state update. - return StateUpdate(delta=delta, events=events) + yield [window_alert("An error occurred. See logs for details.")] def _always_dirty_computed_vars(self) -> Set[str]: """The set of ComputedVars that always need to be recalculated. diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 8a43cbec7..9d8541983 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -107,11 +107,11 @@ async def test_preprocess(State, hydrate_middleware, request, event_fixture, exp assert len(events) == 2 # Apply the on_load event. - update = await state._process(events[0]) + update = await state._process(events[0]).__anext__() assert update.delta == expected # Apply the hydrate event. - update = await state._process(events[1]) + update = await state._process(events[1]).__anext__() assert update.delta == exp_is_hydrated(state) @@ -136,13 +136,13 @@ async def test_preprocess_multiple_load_events(hydrate_middleware, event1): # Apply the events. events = update.events - update = await state._process(events[0]) + update = await state._process(events[0]).__anext__() assert update.delta == {"test_state": {"num": 1}} - update = await state._process(events[1]) + update = await state._process(events[1]).__anext__() assert update.delta == {"test_state": {"num": 2}} - update = await state._process(events[2]) + update = await state._process(events[2]).__anext__() assert update.delta == exp_is_hydrated(state) @@ -165,5 +165,5 @@ async def test_preprocess_no_events(hydrate_middleware, event1): assert len(update.events) == 1 assert isinstance(update, StateUpdate) - update = await state._process(update.events[0]) + update = await state._process(update.events[0]).__anext__() assert update.delta == exp_is_hydrated(state) diff --git a/tests/test_app.py b/tests/test_app.py index 71443060a..046c49550 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -207,7 +207,7 @@ async def test_dynamic_var_event(test_state): router_data={"pathname": "/", "query": {}}, payload={"value": 50}, ) - ) + ).__anext__() assert result.delta == {"test_state": {"int_val": 50}} @@ -324,7 +324,7 @@ async def test_list_mutation_detection__plain_list( router_data={"pathname": "/", "query": {}}, payload={}, ) - ) + ).__anext__() assert result.delta == expected_delta @@ -451,7 +451,7 @@ async def test_dict_mutation_detection__plain_list( router_data={"pathname": "/", "query": {}}, payload={}, ) - ) + ).__anext__() assert result.delta == expected_delta @@ -645,7 +645,8 @@ async def test_dynamic_route_var_route_change_completed_on_load( sid=sid, headers={}, client_ip=client_ip, - ) + ).__anext__() + # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)] assert update == StateUpdate( delta={ @@ -675,7 +676,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( sid=sid, headers={}, client_ip=client_ip, - ) + ).__anext__() assert on_load_update == StateUpdate( delta={ state.get_name(): { @@ -695,7 +696,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( sid=sid, headers={}, client_ip=client_ip, - ) + ).__anext__() assert on_set_is_hydrated_update == StateUpdate( delta={ state.get_name(): { @@ -715,7 +716,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( sid=sid, headers={}, client_ip=client_ip, - ) + ).__anext__() assert update == StateUpdate( delta={ state.get_name(): { diff --git a/tests/test_state.py b/tests/test_state.py index ae0447cf4..b0c1a8b38 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -91,6 +91,25 @@ class GrandchildState(ChildState): pass +class GenState(State): + """A state with event handlers that generate multiple updates.""" + + value: int + + def go(self, c: int): + """Increment the value c times and update each time. + + Args: + c: The number of times to increment. + + Yields: + After each increment. + """ + for _ in range(c): + self.value += 1 + yield + + @pytest.fixture def test_state() -> TestState: """A state. @@ -146,6 +165,16 @@ def grandchild_state(child_state) -> GrandchildState: return grandchild_state +@pytest.fixture +def gen_state() -> GenState: + """A state. + + Returns: + A test state. + """ + return GenState() # type: ignore + + def test_base_class_vars(test_state): """Test that the class vars are set correctly. @@ -577,7 +606,7 @@ async def test_process_event_simple(test_state): assert test_state.num1 == 0 event = Event(token="t", name="set_num1", payload={"value": 69}) - update = await test_state._process(event) + update = await test_state._process(event).__anext__() # The event should update the value. assert test_state.num1 == 69 @@ -603,7 +632,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) event = Event( token="t", name="child_state.change_both", payload={"value": "hi", "count": 12} ) - update = await test_state._process(event) + update = await test_state._process(event).__anext__() assert child_state.value == "HI" assert child_state.count == 24 assert update.delta == { @@ -619,7 +648,7 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) name="child_state.grandchild_state.set_value2", payload={"value": "new"}, ) - update = await test_state._process(event) + update = await test_state._process(event).__anext__() assert grandchild_state.value2 == "new" assert update.delta == { "test_state": {"sum": 3.14, "upper": ""}, @@ -627,6 +656,31 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) } +@pytest.mark.asyncio +async def test_process_event_generator(gen_state): + """Test event handlers that generate multiple updates. + + Args: + gen_state: A state. + """ + event = Event( + token="t", + name="go", + payload={"c": 5}, + ) + gen = gen_state._process(event) + + count = 0 + async for update in gen: + count += 1 + assert gen_state.value == count + assert update.delta == { + "gen_state": {"value": count}, + } + + assert count == 5 + + def test_format_event_handler(): """Test formatting an event handler.""" assert (