diff --git a/pynecone/app.py b/pynecone/app.py index 8c359e30d..c4868e8d9 100644 --- a/pynecone/app.py +++ b/pynecone/app.py @@ -148,7 +148,9 @@ class App(Base): allow_origins=["*"], ) - def preprocess(self, state: State, event: Event) -> Optional[Delta]: + async def preprocess( + self, state: State, event: Event + ) -> Optional[Union[StateUpdate, List[StateUpdate]]]: """Preprocess the event. This is where middleware can modify the event before it is processed. @@ -165,11 +167,13 @@ class App(Base): An optional state to return. """ for middleware in self.middleware: - out = middleware.preprocess(app=self, state=state, event=event) + out = await middleware.preprocess(app=self, state=state, event=event) if out is not None: return out - def postprocess(self, state: State, event: Event, delta: Delta) -> Optional[Delta]: + async def postprocess( + self, state: State, event: Event, delta: Delta + ) -> Optional[Delta]: """Postprocess the event. This is where middleware can modify the delta after it is processed. @@ -187,7 +191,7 @@ class App(Base): An optional state to return. """ for middleware in self.middleware: - out = middleware.postprocess( + out = await middleware.postprocess( app=self, state=state, event=event, delta=delta ) if out is not None: @@ -400,7 +404,7 @@ class App(Base): async def process( app: App, event: Event, sid: str, headers: Dict, client_ip: str -) -> StateUpdate: +) -> Union[StateUpdate, List[StateUpdate]]: """Process an event. Args: @@ -411,7 +415,7 @@ async def process( client_ip: The client_ip. Returns: - The state update after processing the event. + The state update(s) after processing the event. """ # Get the state for the session. state = app.state_manager.get_state(event.token) @@ -430,21 +434,27 @@ async def process( state.router_data[constants.RouteVar.CLIENT_IP] = client_ip # Preprocess the event. - pre = app.preprocess(state, event) - if pre is not None: - return StateUpdate(delta=pre) + pre = await app.preprocess(state, event) + if pre is not None and not isinstance(pre, List): + return pre # Apply the event to the state. - update = await state.process(event) + updates = pre if pre else await state.process(event) app.state_manager.set_state(event.token, state) + updates = updates if isinstance(updates, List) else [updates] + # Postprocess the event. - post = app.postprocess(state, event, update.delta) - if post is not None: - return StateUpdate(delta=post) + post_list = [] + for update in updates: + post = await app.postprocess(state, event, update.delta) # type: ignore + post_list.append(post) if post else None + + if post_list: + return [StateUpdate(delta=post) for post in post_list] # Return the update. - return update + return updates async def ping() -> str: @@ -578,11 +588,12 @@ class EventNamespace(AsyncNamespace): # Get the client IP client_ip = environ["REMOTE_ADDR"] - # Process the event. - update = await process(self.app, event, sid, headers, client_ip) + # Process the events. + updates = await process(self.app, event, sid, headers, client_ip) # Emit the event. - await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) + for update in updates: + await self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore async def on_ping(self, sid): """Event for testing the API endpoint. diff --git a/pynecone/components/typography/markdown.py b/pynecone/components/typography/markdown.py index a2fa771fa..8d50912d1 100644 --- a/pynecone/components/typography/markdown.py +++ b/pynecone/components/typography/markdown.py @@ -69,7 +69,7 @@ class Markdown(Component): "li": "{ListItem}", "p": "{Text}", "a": "{Link}", - "code": """{({node, inline, className, children, ...props}) => + "code": """{({node, inline, className, children, ...props}) => { const match = (className || '').match(/language-(?.*)/); return !inline ? ( diff --git a/pynecone/middleware/hydrate_middleware.py b/pynecone/middleware/hydrate_middleware.py index 4c2b254f1..159d06e8b 100644 --- a/pynecone/middleware/hydrate_middleware.py +++ b/pynecone/middleware/hydrate_middleware.py @@ -1,12 +1,12 @@ """Middleware to hydrate the state.""" from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from pynecone import constants from pynecone.event import Event, EventHandler, get_hydrate_event from pynecone.middleware.middleware import Middleware -from pynecone.state import Delta, State +from pynecone.state import State, StateUpdate from pynecone.utils import format if TYPE_CHECKING: @@ -16,7 +16,9 @@ if TYPE_CHECKING: class HydrateMiddleware(Middleware): """Middleware to handle initial app hydration.""" - def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]: + async def preprocess( + self, app: App, state: State, event: Event + ) -> Optional[Union[StateUpdate, List[StateUpdate]]]: """Preprocess the event. Args: @@ -25,7 +27,7 @@ class HydrateMiddleware(Middleware): event: The event to preprocess. Returns: - An optional state to return. + An optional delta or list of state updates to return. """ if event.name == get_hydrate_event(state): route = event.router_data.get(constants.RouteVar.PATH, "") @@ -37,20 +39,43 @@ class HydrateMiddleware(Middleware): load_event = None if load_event: - if isinstance(load_event, list): - for single_event in load_event: - self.execute_load_event(state, single_event) - else: - self.execute_load_event(state, load_event) - return format.format_state({state.get_name(): state.dict()}) + if not isinstance(load_event, List): + load_event = [load_event] + updates = [] + for single_event in load_event: + updates.append( + await self.execute_load_event( + state, single_event, event.token, event.payload + ) + ) + return updates + delta = format.format_state({state.get_name(): state.dict()}) + return StateUpdate(delta=delta) if delta else None - def execute_load_event(self, state: State, load_event: EventHandler) -> None: + async def execute_load_event( + self, state: State, load_event: EventHandler, token: str, payload: Dict + ) -> StateUpdate: """Execute single load event. Args: state: The client state. load_event: A single load event to execute. + token: Client token + payload: The event payload + + Returns: + A state Update. + + Raises: + ValueError: If the state value is None. """ substate_path = format.format_event_handler(load_event).split(".") ex_state = state.get_substate(substate_path[:-1]) - load_event.fn(ex_state) + if not ex_state: + raise ValueError( + "The value of state cannot be None when processing an on-load event." + ) + + return await state.process_event( + handler=load_event, state=ex_state, payload=payload, token=token + ) diff --git a/pynecone/middleware/logging_middleware.py b/pynecone/middleware/logging_middleware.py index e5e95399d..351364686 100644 --- a/pynecone/middleware/logging_middleware.py +++ b/pynecone/middleware/logging_middleware.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: class LoggingMiddleware(Middleware): """Middleware to log requests and responses.""" - def preprocess(self, app: App, state: State, event: Event): + async def preprocess(self, app: App, state: State, event: Event): """Preprocess the event. Args: @@ -24,7 +24,7 @@ class LoggingMiddleware(Middleware): """ print(f"Event {event}") - def postprocess(self, app: App, state: State, event: Event, delta: Delta): + async def postprocess(self, app: App, state: State, event: Event, delta: Delta): """Postprocess the event. Args: diff --git a/pynecone/middleware/middleware.py b/pynecone/middleware/middleware.py index 94fbc0cd4..b7f792053 100644 --- a/pynecone/middleware/middleware.py +++ b/pynecone/middleware/middleware.py @@ -2,11 +2,11 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional, Union from pynecone.base import Base from pynecone.event import Event -from pynecone.state import Delta, State +from pynecone.state import Delta, State, StateUpdate if TYPE_CHECKING: from pynecone.app import App @@ -15,7 +15,9 @@ if TYPE_CHECKING: class Middleware(Base, ABC): """Middleware to preprocess and postprocess requests.""" - def preprocess(self, app: App, state: State, event: Event) -> Optional[Delta]: + async def preprocess( + self, app: App, state: State, event: Event + ) -> Optional[Union[StateUpdate, List[StateUpdate]]]: """Preprocess the event. Args: @@ -28,7 +30,7 @@ class Middleware(Base, ABC): """ return None - def postprocess( + async def postprocess( self, app: App, state: State, event: Event, delta ) -> Optional[Delta]: """Postprocess the event. diff --git a/pynecone/pc.py b/pynecone/pc.py index 7edc6bff9..ab3c19e23 100644 --- a/pynecone/pc.py +++ b/pynecone/pc.py @@ -203,12 +203,12 @@ def export( if zipping: console.rule( - """Backend & Frontend compiled. See [green bold]backend.zip[/green bold] + """Backend & Frontend compiled. See [green bold]backend.zip[/green bold] and [green bold]frontend.zip[/green bold].""" ) else: console.rule( - """Backend & Frontend compiled. See [green bold]app[/green bold] + """Backend & Frontend compiled. See [green bold]app[/green bold] and [green bold].web/_static[/green bold] directories.""" ) diff --git a/pynecone/state.py b/pynecone/state.py index 8dff76cfa..d8a5f1db3 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -546,13 +546,16 @@ class State(Base, ABC): return self.substates[path[0]].get_substate(path[1:]) async def process(self, event: Event) -> StateUpdate: - """Process an event. + """Obtain event info and process event. Args: event: The event to process. Returns: The state update after processing the event. + + Raises: + ValueError: If the state value is None. """ # Get the event handler. path = event.name.split(".") @@ -560,23 +563,48 @@ class State(Base, ABC): substate = self.get_substate(path) handler = substate.event_handlers[name] # type: ignore - # Process the event. - fn = functools.partial(handler.fn, substate) + if not substate: + raise ValueError( + "The value of state cannot be None when processing an event." + ) + + return await self.process_event( + handler=handler, + state=substate, + payload=event.payload, + token=event.token, + ) + + async def process_event( + self, handler: EventHandler, state: State, payload: Dict, token: str + ) -> StateUpdate: + """Process event. + + Args: + handler: Eventhandler to process. + state: State to process the handler. + payload: The event payload. + token: Client token. + + Returns: + The state update after processing the event. + """ + fn = functools.partial(handler.fn, state) try: if asyncio.iscoroutinefunction(fn.func): - events = await fn(**event.payload) + events = await fn(**payload) else: - events = fn(**event.payload) + events = fn(**payload) except Exception: error = traceback.format_exc() print(error) events = fix_events( - [window_alert("An error occurred. See logs for details.")], event.token + [window_alert("An error occurred. See logs for details.")], token ) return StateUpdate(events=events) # Fix the returned events. - events = fix_events(events, event.token) + events = fix_events(events, token) # Get the delta after processing the event. delta = self.get_delta() diff --git a/tests/components/datadisplay/test_datatable.py b/tests/components/datadisplay/test_datatable.py index b2ed7f2a9..f733cc63c 100644 --- a/tests/components/datadisplay/test_datatable.py +++ b/tests/components/datadisplay/test_datatable.py @@ -94,7 +94,6 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram is_data_frame: whether data field is a pandas dataframe. """ with pytest.raises(ValueError) as err: - if is_data_frame: data_table(data=request.getfixturevalue(fixture).data) else: diff --git a/tests/middleware/__init__.py b/tests/middleware/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/middleware/conftest.py b/tests/middleware/conftest.py new file mode 100644 index 000000000..b06406930 --- /dev/null +++ b/tests/middleware/conftest.py @@ -0,0 +1,34 @@ +import pytest + +from pynecone.event import Event + + +def create_event(name): + return Event( + token="", + name=name, + router_data={ + "pathname": "/", + "query": {}, + "token": "", + "sid": "", + "headers": {}, + "ip": "127.0.0.1", + }, + payload={}, + ) + + +@pytest.fixture +def event1(): + return create_event("test_state.hydrate") + + +@pytest.fixture +def event2(): + return create_event("test_state2.hydrate") + + +@pytest.fixture +def event3(): + return create_event("test_state3.hydrate") diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py new file mode 100644 index 000000000..948b1f06a --- /dev/null +++ b/tests/middleware/test_hydrate_middleware.py @@ -0,0 +1,96 @@ +from typing import List + +import pytest + +from pynecone.app import App +from pynecone.middleware.hydrate_middleware import HydrateMiddleware +from pynecone.state import State + + +class TestState(State): + """A test state with no return in handler.""" + + num: int = 0 + + def test_handler(self): + """Test handler.""" + self.num += 1 + + +class TestState2(State): + """A test state with return in handler.""" + + num: int = 0 + name: str + + def test_handler(self): + """Test handler that calls another handler. + + Returns: + Chain of EventHandlers + """ + self.num += 1 + return [self.change_name()] + + def change_name(self): + """Test handler to change name.""" + self.name = "random" + + +class TestState3(State): + """A test state with async handler.""" + + num: int = 0 + + async def test_handler(self): + """Test handler.""" + self.num += 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "state, expected, event_fixture", + [ + (TestState, {"test_state": {"num": 1}}, "event1"), + (TestState2, {"test_state2": {"num": 1}}, "event2"), + (TestState3, {"test_state3": {"num": 1}}, "event3"), + ], +) +async def test_preprocess(state, request, event_fixture, expected): + """Test that a state hydrate event is processed correctly. + + Args: + state: state to process event + request: pytest fixture request + event_fixture: The event fixture(an Event) + expected: expected delta + """ + app = App(state=state, load_events={"index": state.test_handler}) + + hydrate_middleware = HydrateMiddleware() + result = await hydrate_middleware.preprocess( + app=app, event=request.getfixturevalue(event_fixture), state=state() + ) + assert isinstance(result, List) + assert result[0].delta == expected + + +@pytest.mark.asyncio +async def test_preprocess_multiple_load_events(event1): + """Test that a state hydrate event for multiple on-load events is processed correctly. + + Args: + event1: an Event. + """ + app = App( + state=TestState, + load_events={"index": [TestState.test_handler, TestState.test_handler]}, + ) + + hydrate_middleware = HydrateMiddleware() + result = await hydrate_middleware.preprocess( + app=app, event=event1, state=TestState() + ) + assert isinstance(result, List) + assert result[0].delta == {"test_state": {"num": 1}} + assert result[1].delta == {"test_state": {"num": 2}}