Fix Event chaining in the on_load event handler return not working (#773)

* Fix Event chaining in the on_load event handler return not working

* added async tests

* addressed comments
This commit is contained in:
Elijah Ahianyo 2023-04-07 05:26:43 +00:00 committed by GitHub
parent bd6ea9d977
commit e8387c8e26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 241 additions and 46 deletions

View File

@ -148,7 +148,9 @@ class App(Base):
allow_origins=["*"],
)
def preprocess(self, state: State, event: Event) -> Optional[Delta]:
async def preprocess(
self, state: State, event: Event
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
"""Preprocess the event.
This is where middleware can modify the event before it is processed.
@ -165,11 +167,13 @@ class App(Base):
An optional state to return.
"""
for middleware in self.middleware:
out = middleware.preprocess(app=self, state=state, event=event)
out = await middleware.preprocess(app=self, state=state, event=event)
if out is not None:
return out
def postprocess(self, state: State, event: Event, delta: Delta) -> Optional[Delta]:
async def postprocess(
self, state: State, event: Event, delta: Delta
) -> Optional[Delta]:
"""Postprocess the event.
This is where middleware can modify the delta after it is processed.
@ -187,7 +191,7 @@ class App(Base):
An optional state to return.
"""
for middleware in self.middleware:
out = middleware.postprocess(
out = await middleware.postprocess(
app=self, state=state, event=event, delta=delta
)
if out is not None:
@ -400,7 +404,7 @@ class App(Base):
async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str
) -> StateUpdate:
) -> Union[StateUpdate, List[StateUpdate]]:
"""Process an event.
Args:
@ -411,7 +415,7 @@ async def process(
client_ip: The client_ip.
Returns:
The state update after processing the event.
The state update(s) after processing the event.
"""
# Get the state for the session.
state = app.state_manager.get_state(event.token)
@ -430,21 +434,27 @@ async def process(
state.router_data[constants.RouteVar.CLIENT_IP] = client_ip
# Preprocess the event.
pre = app.preprocess(state, event)
if pre is not None:
return StateUpdate(delta=pre)
pre = await app.preprocess(state, event)
if pre is not None and not isinstance(pre, List):
return pre
# Apply the event to the state.
update = await state.process(event)
updates = pre if pre else await state.process(event)
app.state_manager.set_state(event.token, state)
updates = updates if isinstance(updates, List) else [updates]
# Postprocess the event.
post = app.postprocess(state, event, update.delta)
if post is not None:
return StateUpdate(delta=post)
post_list = []
for update in updates:
post = await app.postprocess(state, event, update.delta) # type: ignore
post_list.append(post) if post else None
if post_list:
return [StateUpdate(delta=post) for post in post_list]
# Return the update.
return update
return updates
async def ping() -> str:
@ -578,11 +588,12 @@ class EventNamespace(AsyncNamespace):
# Get the client IP
client_ip = environ["REMOTE_ADDR"]
# Process the event.
update = await process(self.app, event, sid, headers, client_ip)
# Process the events.
updates = await process(self.app, event, sid, headers, client_ip)
# Emit the event.
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
for update in updates:
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
async def on_ping(self, sid):
"""Event for testing the API endpoint.

View File

@ -69,7 +69,7 @@ class Markdown(Component):
"li": "{ListItem}",
"p": "{Text}",
"a": "{Link}",
"code": """{({node, inline, className, children, ...props}) =>
"code": """{({node, inline, className, children, ...props}) =>
{
const match = (className || '').match(/language-(?<lang>.*)/);
return !inline ? (

View File

@ -1,12 +1,12 @@
"""Middleware to hydrate the state."""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from pynecone import constants
from pynecone.event import Event, EventHandler, get_hydrate_event
from pynecone.middleware.middleware import Middleware
from pynecone.state import Delta, State
from pynecone.state import State, StateUpdate
from pynecone.utils import format
if TYPE_CHECKING:
@ -16,7 +16,9 @@ if TYPE_CHECKING:
class HydrateMiddleware(Middleware):
"""Middleware to handle initial app hydration."""
def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]:
async def preprocess(
self, app: App, state: State, event: Event
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
"""Preprocess the event.
Args:
@ -25,7 +27,7 @@ class HydrateMiddleware(Middleware):
event: The event to preprocess.
Returns:
An optional state to return.
An optional delta or list of state updates to return.
"""
if event.name == get_hydrate_event(state):
route = event.router_data.get(constants.RouteVar.PATH, "")
@ -37,20 +39,43 @@ class HydrateMiddleware(Middleware):
load_event = None
if load_event:
if isinstance(load_event, list):
for single_event in load_event:
self.execute_load_event(state, single_event)
else:
self.execute_load_event(state, load_event)
return format.format_state({state.get_name(): state.dict()})
if not isinstance(load_event, List):
load_event = [load_event]
updates = []
for single_event in load_event:
updates.append(
await self.execute_load_event(
state, single_event, event.token, event.payload
)
)
return updates
delta = format.format_state({state.get_name(): state.dict()})
return StateUpdate(delta=delta) if delta else None
def execute_load_event(self, state: State, load_event: EventHandler) -> None:
async def execute_load_event(
self, state: State, load_event: EventHandler, token: str, payload: Dict
) -> StateUpdate:
"""Execute single load event.
Args:
state: The client state.
load_event: A single load event to execute.
token: Client token
payload: The event payload
Returns:
A state Update.
Raises:
ValueError: If the state value is None.
"""
substate_path = format.format_event_handler(load_event).split(".")
ex_state = state.get_substate(substate_path[:-1])
load_event.fn(ex_state)
if not ex_state:
raise ValueError(
"The value of state cannot be None when processing an on-load event."
)
return await state.process_event(
handler=load_event, state=ex_state, payload=payload, token=token
)

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
class LoggingMiddleware(Middleware):
"""Middleware to log requests and responses."""
def preprocess(self, app: App, state: State, event: Event):
async def preprocess(self, app: App, state: State, event: Event):
"""Preprocess the event.
Args:
@ -24,7 +24,7 @@ class LoggingMiddleware(Middleware):
"""
print(f"Event {event}")
def postprocess(self, app: App, state: State, event: Event, delta: Delta):
async def postprocess(self, app: App, state: State, event: Event, delta: Delta):
"""Postprocess the event.
Args:

View File

@ -2,11 +2,11 @@
from __future__ import annotations
from abc import ABC
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional, Union
from pynecone.base import Base
from pynecone.event import Event
from pynecone.state import Delta, State
from pynecone.state import Delta, State, StateUpdate
if TYPE_CHECKING:
from pynecone.app import App
@ -15,7 +15,9 @@ if TYPE_CHECKING:
class Middleware(Base, ABC):
"""Middleware to preprocess and postprocess requests."""
def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]:
async def preprocess(
self, app: App, state: State, event: Event
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
"""Preprocess the event.
Args:
@ -28,7 +30,7 @@ class Middleware(Base, ABC):
"""
return None
def postprocess(
async def postprocess(
self, app: App, state: State, event: Event, delta
) -> Optional[Delta]:
"""Postprocess the event.

View File

@ -203,12 +203,12 @@ def export(
if zipping:
console.rule(
"""Backend & Frontend compiled. See [green bold]backend.zip[/green bold]
"""Backend & Frontend compiled. See [green bold]backend.zip[/green bold]
and [green bold]frontend.zip[/green bold]."""
)
else:
console.rule(
"""Backend & Frontend compiled. See [green bold]app[/green bold]
"""Backend & Frontend compiled. See [green bold]app[/green bold]
and [green bold].web/_static[/green bold] directories."""
)

View File

@ -546,13 +546,16 @@ class State(Base, ABC):
return self.substates[path[0]].get_substate(path[1:])
async def process(self, event: Event) -> StateUpdate:
"""Process an event.
"""Obtain event info and process event.
Args:
event: The event to process.
Returns:
The state update after processing the event.
Raises:
ValueError: If the state value is None.
"""
# Get the event handler.
path = event.name.split(".")
@ -560,23 +563,48 @@ class State(Base, ABC):
substate = self.get_substate(path)
handler = substate.event_handlers[name] # type: ignore
# Process the event.
fn = functools.partial(handler.fn, substate)
if not substate:
raise ValueError(
"The value of state cannot be None when processing an event."
)
return await self.process_event(
handler=handler,
state=substate,
payload=event.payload,
token=event.token,
)
async def process_event(
self, handler: EventHandler, state: State, payload: Dict, token: str
) -> StateUpdate:
"""Process event.
Args:
handler: Eventhandler to process.
state: State to process the handler.
payload: The event payload.
token: Client token.
Returns:
The state update after processing the event.
"""
fn = functools.partial(handler.fn, state)
try:
if asyncio.iscoroutinefunction(fn.func):
events = await fn(**event.payload)
events = await fn(**payload)
else:
events = fn(**event.payload)
events = fn(**payload)
except Exception:
error = traceback.format_exc()
print(error)
events = fix_events(
[window_alert("An error occurred. See logs for details.")], event.token
[window_alert("An error occurred. See logs for details.")], token
)
return StateUpdate(events=events)
# Fix the returned events.
events = fix_events(events, event.token)
events = fix_events(events, token)
# Get the delta after processing the event.
delta = self.get_delta()

View File

@ -94,7 +94,6 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
is_data_frame: whether data field is a pandas dataframe.
"""
with pytest.raises(ValueError) as err:
if is_data_frame:
data_table(data=request.getfixturevalue(fixture).data)
else:

View File

View File

@ -0,0 +1,34 @@
import pytest
from pynecone.event import Event
def create_event(name):
return Event(
token="<token>",
name=name,
router_data={
"pathname": "/",
"query": {},
"token": "<token>",
"sid": "<sid>",
"headers": {},
"ip": "127.0.0.1",
},
payload={},
)
@pytest.fixture
def event1():
return create_event("test_state.hydrate")
@pytest.fixture
def event2():
return create_event("test_state2.hydrate")
@pytest.fixture
def event3():
return create_event("test_state3.hydrate")

View File

@ -0,0 +1,96 @@
from typing import List
import pytest
from pynecone.app import App
from pynecone.middleware.hydrate_middleware import HydrateMiddleware
from pynecone.state import State
class TestState(State):
"""A test state with no return in handler."""
num: int = 0
def test_handler(self):
"""Test handler."""
self.num += 1
class TestState2(State):
"""A test state with return in handler."""
num: int = 0
name: str
def test_handler(self):
"""Test handler that calls another handler.
Returns:
Chain of EventHandlers
"""
self.num += 1
return [self.change_name()]
def change_name(self):
"""Test handler to change name."""
self.name = "random"
class TestState3(State):
"""A test state with async handler."""
num: int = 0
async def test_handler(self):
"""Test handler."""
self.num += 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"state, expected, event_fixture",
[
(TestState, {"test_state": {"num": 1}}, "event1"),
(TestState2, {"test_state2": {"num": 1}}, "event2"),
(TestState3, {"test_state3": {"num": 1}}, "event3"),
],
)
async def test_preprocess(state, request, event_fixture, expected):
"""Test that a state hydrate event is processed correctly.
Args:
state: state to process event
request: pytest fixture request
event_fixture: The event fixture(an Event)
expected: expected delta
"""
app = App(state=state, load_events={"index": state.test_handler})
hydrate_middleware = HydrateMiddleware()
result = await hydrate_middleware.preprocess(
app=app, event=request.getfixturevalue(event_fixture), state=state()
)
assert isinstance(result, List)
assert result[0].delta == expected
@pytest.mark.asyncio
async def test_preprocess_multiple_load_events(event1):
"""Test that a state hydrate event for multiple on-load events is processed correctly.
Args:
event1: an Event.
"""
app = App(
state=TestState,
load_events={"index": [TestState.test_handler, TestState.test_handler]},
)
hydrate_middleware = HydrateMiddleware()
result = await hydrate_middleware.preprocess(
app=app, event=event1, state=TestState()
)
assert isinstance(result, List)
assert result[0].delta == {"test_state": {"num": 1}}
assert result[1].delta == {"test_state": {"num": 2}}