Fix Event chaining in the on_load event handler return not working (#773)
* Fix Event chaining in the on_load event handler return not working * added async tests * addressed comments
This commit is contained in:
parent
bd6ea9d977
commit
e8387c8e26
@ -148,7 +148,9 @@ class App(Base):
|
||||
allow_origins=["*"],
|
||||
)
|
||||
|
||||
def preprocess(self, state: State, event: Event) -> Optional[Delta]:
|
||||
async def preprocess(
|
||||
self, state: State, event: Event
|
||||
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
|
||||
"""Preprocess the event.
|
||||
|
||||
This is where middleware can modify the event before it is processed.
|
||||
@ -165,11 +167,13 @@ class App(Base):
|
||||
An optional state to return.
|
||||
"""
|
||||
for middleware in self.middleware:
|
||||
out = middleware.preprocess(app=self, state=state, event=event)
|
||||
out = await middleware.preprocess(app=self, state=state, event=event)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
def postprocess(self, state: State, event: Event, delta: Delta) -> Optional[Delta]:
|
||||
async def postprocess(
|
||||
self, state: State, event: Event, delta: Delta
|
||||
) -> Optional[Delta]:
|
||||
"""Postprocess the event.
|
||||
|
||||
This is where middleware can modify the delta after it is processed.
|
||||
@ -187,7 +191,7 @@ class App(Base):
|
||||
An optional state to return.
|
||||
"""
|
||||
for middleware in self.middleware:
|
||||
out = middleware.postprocess(
|
||||
out = await middleware.postprocess(
|
||||
app=self, state=state, event=event, delta=delta
|
||||
)
|
||||
if out is not None:
|
||||
@ -400,7 +404,7 @@ class App(Base):
|
||||
|
||||
async def process(
|
||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||
) -> StateUpdate:
|
||||
) -> Union[StateUpdate, List[StateUpdate]]:
|
||||
"""Process an event.
|
||||
|
||||
Args:
|
||||
@ -411,7 +415,7 @@ async def process(
|
||||
client_ip: The client_ip.
|
||||
|
||||
Returns:
|
||||
The state update after processing the event.
|
||||
The state update(s) after processing the event.
|
||||
"""
|
||||
# Get the state for the session.
|
||||
state = app.state_manager.get_state(event.token)
|
||||
@ -430,21 +434,27 @@ async def process(
|
||||
state.router_data[constants.RouteVar.CLIENT_IP] = client_ip
|
||||
|
||||
# Preprocess the event.
|
||||
pre = app.preprocess(state, event)
|
||||
if pre is not None:
|
||||
return StateUpdate(delta=pre)
|
||||
pre = await app.preprocess(state, event)
|
||||
if pre is not None and not isinstance(pre, List):
|
||||
return pre
|
||||
|
||||
# Apply the event to the state.
|
||||
update = await state.process(event)
|
||||
updates = pre if pre else await state.process(event)
|
||||
app.state_manager.set_state(event.token, state)
|
||||
|
||||
updates = updates if isinstance(updates, List) else [updates]
|
||||
|
||||
# Postprocess the event.
|
||||
post = app.postprocess(state, event, update.delta)
|
||||
if post is not None:
|
||||
return StateUpdate(delta=post)
|
||||
post_list = []
|
||||
for update in updates:
|
||||
post = await app.postprocess(state, event, update.delta) # type: ignore
|
||||
post_list.append(post) if post else None
|
||||
|
||||
if post_list:
|
||||
return [StateUpdate(delta=post) for post in post_list]
|
||||
|
||||
# Return the update.
|
||||
return update
|
||||
return updates
|
||||
|
||||
|
||||
async def ping() -> str:
|
||||
@ -578,11 +588,12 @@ class EventNamespace(AsyncNamespace):
|
||||
# Get the client IP
|
||||
client_ip = environ["REMOTE_ADDR"]
|
||||
|
||||
# Process the event.
|
||||
update = await process(self.app, event, sid, headers, client_ip)
|
||||
# Process the events.
|
||||
updates = await process(self.app, event, sid, headers, client_ip)
|
||||
|
||||
# Emit the event.
|
||||
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
||||
for update in updates:
|
||||
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
||||
|
||||
async def on_ping(self, sid):
|
||||
"""Event for testing the API endpoint.
|
||||
|
@ -69,7 +69,7 @@ class Markdown(Component):
|
||||
"li": "{ListItem}",
|
||||
"p": "{Text}",
|
||||
"a": "{Link}",
|
||||
"code": """{({node, inline, className, children, ...props}) =>
|
||||
"code": """{({node, inline, className, children, ...props}) =>
|
||||
{
|
||||
const match = (className || '').match(/language-(?<lang>.*)/);
|
||||
return !inline ? (
|
||||
|
@ -1,12 +1,12 @@
|
||||
"""Middleware to hydrate the state."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
from pynecone import constants
|
||||
from pynecone.event import Event, EventHandler, get_hydrate_event
|
||||
from pynecone.middleware.middleware import Middleware
|
||||
from pynecone.state import Delta, State
|
||||
from pynecone.state import State, StateUpdate
|
||||
from pynecone.utils import format
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -16,7 +16,9 @@ if TYPE_CHECKING:
|
||||
class HydrateMiddleware(Middleware):
|
||||
"""Middleware to handle initial app hydration."""
|
||||
|
||||
def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]:
|
||||
async def preprocess(
|
||||
self, app: App, state: State, event: Event
|
||||
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
|
||||
"""Preprocess the event.
|
||||
|
||||
Args:
|
||||
@ -25,7 +27,7 @@ class HydrateMiddleware(Middleware):
|
||||
event: The event to preprocess.
|
||||
|
||||
Returns:
|
||||
An optional state to return.
|
||||
An optional delta or list of state updates to return.
|
||||
"""
|
||||
if event.name == get_hydrate_event(state):
|
||||
route = event.router_data.get(constants.RouteVar.PATH, "")
|
||||
@ -37,20 +39,43 @@ class HydrateMiddleware(Middleware):
|
||||
load_event = None
|
||||
|
||||
if load_event:
|
||||
if isinstance(load_event, list):
|
||||
for single_event in load_event:
|
||||
self.execute_load_event(state, single_event)
|
||||
else:
|
||||
self.execute_load_event(state, load_event)
|
||||
return format.format_state({state.get_name(): state.dict()})
|
||||
if not isinstance(load_event, List):
|
||||
load_event = [load_event]
|
||||
updates = []
|
||||
for single_event in load_event:
|
||||
updates.append(
|
||||
await self.execute_load_event(
|
||||
state, single_event, event.token, event.payload
|
||||
)
|
||||
)
|
||||
return updates
|
||||
delta = format.format_state({state.get_name(): state.dict()})
|
||||
return StateUpdate(delta=delta) if delta else None
|
||||
|
||||
def execute_load_event(self, state: State, load_event: EventHandler) -> None:
|
||||
async def execute_load_event(
|
||||
self, state: State, load_event: EventHandler, token: str, payload: Dict
|
||||
) -> StateUpdate:
|
||||
"""Execute single load event.
|
||||
|
||||
Args:
|
||||
state: The client state.
|
||||
load_event: A single load event to execute.
|
||||
token: Client token
|
||||
payload: The event payload
|
||||
|
||||
Returns:
|
||||
A state Update.
|
||||
|
||||
Raises:
|
||||
ValueError: If the state value is None.
|
||||
"""
|
||||
substate_path = format.format_event_handler(load_event).split(".")
|
||||
ex_state = state.get_substate(substate_path[:-1])
|
||||
load_event.fn(ex_state)
|
||||
if not ex_state:
|
||||
raise ValueError(
|
||||
"The value of state cannot be None when processing an on-load event."
|
||||
)
|
||||
|
||||
return await state.process_event(
|
||||
handler=load_event, state=ex_state, payload=payload, token=token
|
||||
)
|
||||
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
class LoggingMiddleware(Middleware):
|
||||
"""Middleware to log requests and responses."""
|
||||
|
||||
def preprocess(self, app: App, state: State, event: Event):
|
||||
async def preprocess(self, app: App, state: State, event: Event):
|
||||
"""Preprocess the event.
|
||||
|
||||
Args:
|
||||
@ -24,7 +24,7 @@ class LoggingMiddleware(Middleware):
|
||||
"""
|
||||
print(f"Event {event}")
|
||||
|
||||
def postprocess(self, app: App, state: State, event: Event, delta: Delta):
|
||||
async def postprocess(self, app: App, state: State, event: Event, delta: Delta):
|
||||
"""Postprocess the event.
|
||||
|
||||
Args:
|
||||
|
@ -2,11 +2,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from pynecone.base import Base
|
||||
from pynecone.event import Event
|
||||
from pynecone.state import Delta, State
|
||||
from pynecone.state import Delta, State, StateUpdate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pynecone.app import App
|
||||
@ -15,7 +15,9 @@ if TYPE_CHECKING:
|
||||
class Middleware(Base, ABC):
|
||||
"""Middleware to preprocess and postprocess requests."""
|
||||
|
||||
def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]:
|
||||
async def preprocess(
|
||||
self, app: App, state: State, event: Event
|
||||
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
|
||||
"""Preprocess the event.
|
||||
|
||||
Args:
|
||||
@ -28,7 +30,7 @@ class Middleware(Base, ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def postprocess(
|
||||
async def postprocess(
|
||||
self, app: App, state: State, event: Event, delta
|
||||
) -> Optional[Delta]:
|
||||
"""Postprocess the event.
|
||||
|
@ -203,12 +203,12 @@ def export(
|
||||
|
||||
if zipping:
|
||||
console.rule(
|
||||
"""Backend & Frontend compiled. See [green bold]backend.zip[/green bold]
|
||||
"""Backend & Frontend compiled. See [green bold]backend.zip[/green bold]
|
||||
and [green bold]frontend.zip[/green bold]."""
|
||||
)
|
||||
else:
|
||||
console.rule(
|
||||
"""Backend & Frontend compiled. See [green bold]app[/green bold]
|
||||
"""Backend & Frontend compiled. See [green bold]app[/green bold]
|
||||
and [green bold].web/_static[/green bold] directories."""
|
||||
)
|
||||
|
||||
|
@ -546,13 +546,16 @@ class State(Base, ABC):
|
||||
return self.substates[path[0]].get_substate(path[1:])
|
||||
|
||||
async def process(self, event: Event) -> StateUpdate:
|
||||
"""Process an event.
|
||||
"""Obtain event info and process event.
|
||||
|
||||
Args:
|
||||
event: The event to process.
|
||||
|
||||
Returns:
|
||||
The state update after processing the event.
|
||||
|
||||
Raises:
|
||||
ValueError: If the state value is None.
|
||||
"""
|
||||
# Get the event handler.
|
||||
path = event.name.split(".")
|
||||
@ -560,23 +563,48 @@ class State(Base, ABC):
|
||||
substate = self.get_substate(path)
|
||||
handler = substate.event_handlers[name] # type: ignore
|
||||
|
||||
# Process the event.
|
||||
fn = functools.partial(handler.fn, substate)
|
||||
if not substate:
|
||||
raise ValueError(
|
||||
"The value of state cannot be None when processing an event."
|
||||
)
|
||||
|
||||
return await self.process_event(
|
||||
handler=handler,
|
||||
state=substate,
|
||||
payload=event.payload,
|
||||
token=event.token,
|
||||
)
|
||||
|
||||
async def process_event(
|
||||
self, handler: EventHandler, state: State, payload: Dict, token: str
|
||||
) -> StateUpdate:
|
||||
"""Process event.
|
||||
|
||||
Args:
|
||||
handler: Eventhandler to process.
|
||||
state: State to process the handler.
|
||||
payload: The event payload.
|
||||
token: Client token.
|
||||
|
||||
Returns:
|
||||
The state update after processing the event.
|
||||
"""
|
||||
fn = functools.partial(handler.fn, state)
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(fn.func):
|
||||
events = await fn(**event.payload)
|
||||
events = await fn(**payload)
|
||||
else:
|
||||
events = fn(**event.payload)
|
||||
events = fn(**payload)
|
||||
except Exception:
|
||||
error = traceback.format_exc()
|
||||
print(error)
|
||||
events = fix_events(
|
||||
[window_alert("An error occurred. See logs for details.")], event.token
|
||||
[window_alert("An error occurred. See logs for details.")], token
|
||||
)
|
||||
return StateUpdate(events=events)
|
||||
|
||||
# Fix the returned events.
|
||||
events = fix_events(events, event.token)
|
||||
events = fix_events(events, token)
|
||||
|
||||
# Get the delta after processing the event.
|
||||
delta = self.get_delta()
|
||||
|
@ -94,7 +94,6 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
|
||||
is_data_frame: whether data field is a pandas dataframe.
|
||||
"""
|
||||
with pytest.raises(ValueError) as err:
|
||||
|
||||
if is_data_frame:
|
||||
data_table(data=request.getfixturevalue(fixture).data)
|
||||
else:
|
||||
|
0
tests/middleware/__init__.py
Normal file
0
tests/middleware/__init__.py
Normal file
34
tests/middleware/conftest.py
Normal file
34
tests/middleware/conftest.py
Normal file
@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
|
||||
from pynecone.event import Event
|
||||
|
||||
|
||||
def create_event(name):
|
||||
return Event(
|
||||
token="<token>",
|
||||
name=name,
|
||||
router_data={
|
||||
"pathname": "/",
|
||||
"query": {},
|
||||
"token": "<token>",
|
||||
"sid": "<sid>",
|
||||
"headers": {},
|
||||
"ip": "127.0.0.1",
|
||||
},
|
||||
payload={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event1():
|
||||
return create_event("test_state.hydrate")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event2():
|
||||
return create_event("test_state2.hydrate")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event3():
|
||||
return create_event("test_state3.hydrate")
|
96
tests/middleware/test_hydrate_middleware.py
Normal file
96
tests/middleware/test_hydrate_middleware.py
Normal file
@ -0,0 +1,96 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from pynecone.app import App
|
||||
from pynecone.middleware.hydrate_middleware import HydrateMiddleware
|
||||
from pynecone.state import State
|
||||
|
||||
|
||||
class TestState(State):
|
||||
"""A test state with no return in handler."""
|
||||
|
||||
num: int = 0
|
||||
|
||||
def test_handler(self):
|
||||
"""Test handler."""
|
||||
self.num += 1
|
||||
|
||||
|
||||
class TestState2(State):
|
||||
"""A test state with return in handler."""
|
||||
|
||||
num: int = 0
|
||||
name: str
|
||||
|
||||
def test_handler(self):
|
||||
"""Test handler that calls another handler.
|
||||
|
||||
Returns:
|
||||
Chain of EventHandlers
|
||||
"""
|
||||
self.num += 1
|
||||
return [self.change_name()]
|
||||
|
||||
def change_name(self):
|
||||
"""Test handler to change name."""
|
||||
self.name = "random"
|
||||
|
||||
|
||||
class TestState3(State):
|
||||
"""A test state with async handler."""
|
||||
|
||||
num: int = 0
|
||||
|
||||
async def test_handler(self):
|
||||
"""Test handler."""
|
||||
self.num += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"state, expected, event_fixture",
|
||||
[
|
||||
(TestState, {"test_state": {"num": 1}}, "event1"),
|
||||
(TestState2, {"test_state2": {"num": 1}}, "event2"),
|
||||
(TestState3, {"test_state3": {"num": 1}}, "event3"),
|
||||
],
|
||||
)
|
||||
async def test_preprocess(state, request, event_fixture, expected):
|
||||
"""Test that a state hydrate event is processed correctly.
|
||||
|
||||
Args:
|
||||
state: state to process event
|
||||
request: pytest fixture request
|
||||
event_fixture: The event fixture(an Event)
|
||||
expected: expected delta
|
||||
"""
|
||||
app = App(state=state, load_events={"index": state.test_handler})
|
||||
|
||||
hydrate_middleware = HydrateMiddleware()
|
||||
result = await hydrate_middleware.preprocess(
|
||||
app=app, event=request.getfixturevalue(event_fixture), state=state()
|
||||
)
|
||||
assert isinstance(result, List)
|
||||
assert result[0].delta == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preprocess_multiple_load_events(event1):
|
||||
"""Test that a state hydrate event for multiple on-load events is processed correctly.
|
||||
|
||||
Args:
|
||||
event1: an Event.
|
||||
"""
|
||||
app = App(
|
||||
state=TestState,
|
||||
load_events={"index": [TestState.test_handler, TestState.test_handler]},
|
||||
)
|
||||
|
||||
hydrate_middleware = HydrateMiddleware()
|
||||
result = await hydrate_middleware.preprocess(
|
||||
app=app, event=event1, state=TestState()
|
||||
)
|
||||
assert isinstance(result, List)
|
||||
assert result[0].delta == {"test_state": {"num": 1}}
|
||||
assert result[1].delta == {"test_state": {"num": 2}}
|
Loading…
Reference in New Issue
Block a user