Fix Event chaining in the on_load event handler return not working (#773)

* Fix Event chaining in the on_load event handler return not working

* added async tests

* addressed comments
This commit is contained in:
Elijah Ahianyo 2023-04-07 05:26:43 +00:00 committed by GitHub
parent bd6ea9d977
commit e8387c8e26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 241 additions and 46 deletions

View File

@ -148,7 +148,9 @@ class App(Base):
allow_origins=["*"], allow_origins=["*"],
) )
def preprocess(self, state: State, event: Event) -> Optional[Delta]: async def preprocess(
self, state: State, event: Event
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
"""Preprocess the event. """Preprocess the event.
This is where middleware can modify the event before it is processed. This is where middleware can modify the event before it is processed.
@ -165,11 +167,13 @@ class App(Base):
An optional state to return. An optional state to return.
""" """
for middleware in self.middleware: for middleware in self.middleware:
out = middleware.preprocess(app=self, state=state, event=event) out = await middleware.preprocess(app=self, state=state, event=event)
if out is not None: if out is not None:
return out return out
def postprocess(self, state: State, event: Event, delta: Delta) -> Optional[Delta]: async def postprocess(
self, state: State, event: Event, delta: Delta
) -> Optional[Delta]:
"""Postprocess the event. """Postprocess the event.
This is where middleware can modify the delta after it is processed. This is where middleware can modify the delta after it is processed.
@ -187,7 +191,7 @@ class App(Base):
An optional state to return. An optional state to return.
""" """
for middleware in self.middleware: for middleware in self.middleware:
out = middleware.postprocess( out = await middleware.postprocess(
app=self, state=state, event=event, delta=delta app=self, state=state, event=event, delta=delta
) )
if out is not None: if out is not None:
@ -400,7 +404,7 @@ class App(Base):
async def process( async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str app: App, event: Event, sid: str, headers: Dict, client_ip: str
) -> StateUpdate: ) -> Union[StateUpdate, List[StateUpdate]]:
"""Process an event. """Process an event.
Args: Args:
@ -411,7 +415,7 @@ async def process(
client_ip: The client_ip. client_ip: The client_ip.
Returns: Returns:
The state update after processing the event. The state update(s) after processing the event.
""" """
# Get the state for the session. # Get the state for the session.
state = app.state_manager.get_state(event.token) state = app.state_manager.get_state(event.token)
@ -430,21 +434,27 @@ async def process(
state.router_data[constants.RouteVar.CLIENT_IP] = client_ip state.router_data[constants.RouteVar.CLIENT_IP] = client_ip
# Preprocess the event. # Preprocess the event.
pre = app.preprocess(state, event) pre = await app.preprocess(state, event)
if pre is not None: if pre is not None and not isinstance(pre, List):
return StateUpdate(delta=pre) return pre
# Apply the event to the state. # Apply the event to the state.
update = await state.process(event) updates = pre if pre else await state.process(event)
app.state_manager.set_state(event.token, state) app.state_manager.set_state(event.token, state)
updates = updates if isinstance(updates, List) else [updates]
# Postprocess the event. # Postprocess the event.
post = app.postprocess(state, event, update.delta) post_list = []
if post is not None: for update in updates:
return StateUpdate(delta=post) post = await app.postprocess(state, event, update.delta) # type: ignore
post_list.append(post) if post else None
if post_list:
return [StateUpdate(delta=post) for post in post_list]
# Return the update. # Return the update.
return update return updates
async def ping() -> str: async def ping() -> str:
@ -578,11 +588,12 @@ class EventNamespace(AsyncNamespace):
# Get the client IP # Get the client IP
client_ip = environ["REMOTE_ADDR"] client_ip = environ["REMOTE_ADDR"]
# Process the event. # Process the events.
update = await process(self.app, event, sid, headers, client_ip) updates = await process(self.app, event, sid, headers, client_ip)
# Emit the event. # Emit the event.
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) for update in updates:
await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
async def on_ping(self, sid): async def on_ping(self, sid):
"""Event for testing the API endpoint. """Event for testing the API endpoint.

View File

@ -69,7 +69,7 @@ class Markdown(Component):
"li": "{ListItem}", "li": "{ListItem}",
"p": "{Text}", "p": "{Text}",
"a": "{Link}", "a": "{Link}",
"code": """{({node, inline, className, children, ...props}) => "code": """{({node, inline, className, children, ...props}) =>
{ {
const match = (className || '').match(/language-(?<lang>.*)/); const match = (className || '').match(/language-(?<lang>.*)/);
return !inline ? ( return !inline ? (

View File

@ -1,12 +1,12 @@
"""Middleware to hydrate the state.""" """Middleware to hydrate the state."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Dict, List, Optional, Union
from pynecone import constants from pynecone import constants
from pynecone.event import Event, EventHandler, get_hydrate_event from pynecone.event import Event, EventHandler, get_hydrate_event
from pynecone.middleware.middleware import Middleware from pynecone.middleware.middleware import Middleware
from pynecone.state import Delta, State from pynecone.state import State, StateUpdate
from pynecone.utils import format from pynecone.utils import format
if TYPE_CHECKING: if TYPE_CHECKING:
@ -16,7 +16,9 @@ if TYPE_CHECKING:
class HydrateMiddleware(Middleware): class HydrateMiddleware(Middleware):
"""Middleware to handle initial app hydration.""" """Middleware to handle initial app hydration."""
def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]: async def preprocess(
self, app: App, state: State, event: Event
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
"""Preprocess the event. """Preprocess the event.
Args: Args:
@ -25,7 +27,7 @@ class HydrateMiddleware(Middleware):
event: The event to preprocess. event: The event to preprocess.
Returns: Returns:
An optional state to return. An optional delta or list of state updates to return.
""" """
if event.name == get_hydrate_event(state): if event.name == get_hydrate_event(state):
route = event.router_data.get(constants.RouteVar.PATH, "") route = event.router_data.get(constants.RouteVar.PATH, "")
@ -37,20 +39,43 @@ class HydrateMiddleware(Middleware):
load_event = None load_event = None
if load_event: if load_event:
if isinstance(load_event, list): if not isinstance(load_event, List):
for single_event in load_event: load_event = [load_event]
self.execute_load_event(state, single_event) updates = []
else: for single_event in load_event:
self.execute_load_event(state, load_event) updates.append(
return format.format_state({state.get_name(): state.dict()}) await self.execute_load_event(
state, single_event, event.token, event.payload
)
)
return updates
delta = format.format_state({state.get_name(): state.dict()})
return StateUpdate(delta=delta) if delta else None
def execute_load_event(self, state: State, load_event: EventHandler) -> None: async def execute_load_event(
self, state: State, load_event: EventHandler, token: str, payload: Dict
) -> StateUpdate:
"""Execute single load event. """Execute single load event.
Args: Args:
state: The client state. state: The client state.
load_event: A single load event to execute. load_event: A single load event to execute.
token: Client token
payload: The event payload
Returns:
A state Update.
Raises:
ValueError: If the state value is None.
""" """
substate_path = format.format_event_handler(load_event).split(".") substate_path = format.format_event_handler(load_event).split(".")
ex_state = state.get_substate(substate_path[:-1]) ex_state = state.get_substate(substate_path[:-1])
load_event.fn(ex_state) if not ex_state:
raise ValueError(
"The value of state cannot be None when processing an on-load event."
)
return await state.process_event(
handler=load_event, state=ex_state, payload=payload, token=token
)

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
class LoggingMiddleware(Middleware): class LoggingMiddleware(Middleware):
"""Middleware to log requests and responses.""" """Middleware to log requests and responses."""
def preprocess(self, app: App, state: State, event: Event): async def preprocess(self, app: App, state: State, event: Event):
"""Preprocess the event. """Preprocess the event.
Args: Args:
@ -24,7 +24,7 @@ class LoggingMiddleware(Middleware):
""" """
print(f"Event {event}") print(f"Event {event}")
def postprocess(self, app: App, state: State, event: Event, delta: Delta): async def postprocess(self, app: App, state: State, event: Event, delta: Delta):
"""Postprocess the event. """Postprocess the event.
Args: Args:

View File

@ -2,11 +2,11 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC from abc import ABC
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, List, Optional, Union
from pynecone.base import Base from pynecone.base import Base
from pynecone.event import Event from pynecone.event import Event
from pynecone.state import Delta, State from pynecone.state import Delta, State, StateUpdate
if TYPE_CHECKING: if TYPE_CHECKING:
from pynecone.app import App from pynecone.app import App
@ -15,7 +15,9 @@ if TYPE_CHECKING:
class Middleware(Base, ABC): class Middleware(Base, ABC):
"""Middleware to preprocess and postprocess requests.""" """Middleware to preprocess and postprocess requests."""
def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]: async def preprocess(
self, app: App, state: State, event: Event
) -> Optional[Union[StateUpdate, List[StateUpdate]]]:
"""Preprocess the event. """Preprocess the event.
Args: Args:
@ -28,7 +30,7 @@ class Middleware(Base, ABC):
""" """
return None return None
def postprocess( async def postprocess(
self, app: App, state: State, event: Event, delta self, app: App, state: State, event: Event, delta
) -> Optional[Delta]: ) -> Optional[Delta]:
"""Postprocess the event. """Postprocess the event.

View File

@ -203,12 +203,12 @@ def export(
if zipping: if zipping:
console.rule( console.rule(
"""Backend & Frontend compiled. See [green bold]backend.zip[/green bold] """Backend & Frontend compiled. See [green bold]backend.zip[/green bold]
and [green bold]frontend.zip[/green bold].""" and [green bold]frontend.zip[/green bold]."""
) )
else: else:
console.rule( console.rule(
"""Backend & Frontend compiled. See [green bold]app[/green bold] """Backend & Frontend compiled. See [green bold]app[/green bold]
and [green bold].web/_static[/green bold] directories.""" and [green bold].web/_static[/green bold] directories."""
) )

View File

@ -546,13 +546,16 @@ class State(Base, ABC):
return self.substates[path[0]].get_substate(path[1:]) return self.substates[path[0]].get_substate(path[1:])
async def process(self, event: Event) -> StateUpdate: async def process(self, event: Event) -> StateUpdate:
"""Process an event. """Obtain event info and process event.
Args: Args:
event: The event to process. event: The event to process.
Returns: Returns:
The state update after processing the event. The state update after processing the event.
Raises:
ValueError: If the state value is None.
""" """
# Get the event handler. # Get the event handler.
path = event.name.split(".") path = event.name.split(".")
@ -560,23 +563,48 @@ class State(Base, ABC):
substate = self.get_substate(path) substate = self.get_substate(path)
handler = substate.event_handlers[name] # type: ignore handler = substate.event_handlers[name] # type: ignore
# Process the event. if not substate:
fn = functools.partial(handler.fn, substate) raise ValueError(
"The value of state cannot be None when processing an event."
)
return await self.process_event(
handler=handler,
state=substate,
payload=event.payload,
token=event.token,
)
async def process_event(
self, handler: EventHandler, state: State, payload: Dict, token: str
) -> StateUpdate:
"""Process event.
Args:
handler: Eventhandler to process.
state: State to process the handler.
payload: The event payload.
token: Client token.
Returns:
The state update after processing the event.
"""
fn = functools.partial(handler.fn, state)
try: try:
if asyncio.iscoroutinefunction(fn.func): if asyncio.iscoroutinefunction(fn.func):
events = await fn(**event.payload) events = await fn(**payload)
else: else:
events = fn(**event.payload) events = fn(**payload)
except Exception: except Exception:
error = traceback.format_exc() error = traceback.format_exc()
print(error) print(error)
events = fix_events( events = fix_events(
[window_alert("An error occurred. See logs for details.")], event.token [window_alert("An error occurred. See logs for details.")], token
) )
return StateUpdate(events=events) return StateUpdate(events=events)
# Fix the returned events. # Fix the returned events.
events = fix_events(events, event.token) events = fix_events(events, token)
# Get the delta after processing the event. # Get the delta after processing the event.
delta = self.get_delta() delta = self.get_delta()

View File

@ -94,7 +94,6 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
is_data_frame: whether data field is a pandas dataframe. is_data_frame: whether data field is a pandas dataframe.
""" """
with pytest.raises(ValueError) as err: with pytest.raises(ValueError) as err:
if is_data_frame: if is_data_frame:
data_table(data=request.getfixturevalue(fixture).data) data_table(data=request.getfixturevalue(fixture).data)
else: else:

View File

View File

@ -0,0 +1,34 @@
import pytest
from pynecone.event import Event
def create_event(name):
return Event(
token="<token>",
name=name,
router_data={
"pathname": "/",
"query": {},
"token": "<token>",
"sid": "<sid>",
"headers": {},
"ip": "127.0.0.1",
},
payload={},
)
@pytest.fixture
def event1():
return create_event("test_state.hydrate")
@pytest.fixture
def event2():
return create_event("test_state2.hydrate")
@pytest.fixture
def event3():
return create_event("test_state3.hydrate")

View File

@ -0,0 +1,96 @@
from typing import List
import pytest
from pynecone.app import App
from pynecone.middleware.hydrate_middleware import HydrateMiddleware
from pynecone.state import State
class TestState(State):
"""A test state with no return in handler."""
num: int = 0
def test_handler(self):
"""Test handler."""
self.num += 1
class TestState2(State):
"""A test state with return in handler."""
num: int = 0
name: str
def test_handler(self):
"""Test handler that calls another handler.
Returns:
Chain of EventHandlers
"""
self.num += 1
return [self.change_name()]
def change_name(self):
"""Test handler to change name."""
self.name = "random"
class TestState3(State):
"""A test state with async handler."""
num: int = 0
async def test_handler(self):
"""Test handler."""
self.num += 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"state, expected, event_fixture",
[
(TestState, {"test_state": {"num": 1}}, "event1"),
(TestState2, {"test_state2": {"num": 1}}, "event2"),
(TestState3, {"test_state3": {"num": 1}}, "event3"),
],
)
async def test_preprocess(state, request, event_fixture, expected):
"""Test that a state hydrate event is processed correctly.
Args:
state: state to process event
request: pytest fixture request
event_fixture: The event fixture(an Event)
expected: expected delta
"""
app = App(state=state, load_events={"index": state.test_handler})
hydrate_middleware = HydrateMiddleware()
result = await hydrate_middleware.preprocess(
app=app, event=request.getfixturevalue(event_fixture), state=state()
)
assert isinstance(result, List)
assert result[0].delta == expected
@pytest.mark.asyncio
async def test_preprocess_multiple_load_events(event1):
"""Test that a state hydrate event for multiple on-load events is processed correctly.
Args:
event1: an Event.
"""
app = App(
state=TestState,
load_events={"index": [TestState.test_handler, TestState.test_handler]},
)
hydrate_middleware = HydrateMiddleware()
result = await hydrate_middleware.preprocess(
app=app, event=event1, state=TestState()
)
assert isinstance(result, List)
assert result[0].delta == {"test_state": {"num": 1}}
assert result[1].delta == {"test_state": {"num": 2}}