From e887dd143bc11013ced63fbf4336e36075da1388 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 10 May 2023 01:11:54 -0700 Subject: [PATCH] Trigger on_load when router completes navigation (#984) --- pynecone/app.py | 13 +- pynecone/compiler/templates.py | 1 + pynecone/constants.py | 4 + pynecone/middleware/hydrate_middleware.py | 3 + pynecone/state.py | 8 + pynecone/templates/web/pages/index.js.jinja2 | 7 + tests/test_app.py | 196 ++++++++++++++++++- 7 files changed, 222 insertions(+), 10 deletions(-) diff --git a/pynecone/app.py b/pynecone/app.py index 8c273981d..cb372e627 100644 --- a/pynecone/app.py +++ b/pynecone/app.py @@ -441,8 +441,8 @@ async def process( state = app.state_manager.get_state(event.token) # Add request data to the state. - state.router_data = event.router_data - state.router_data.update( + router_data = event.router_data + router_data.update( { constants.RouteVar.QUERY: format.format_query_params(event.router_data), constants.RouteVar.CLIENT_TOKEN: event.token, @@ -451,10 +451,11 @@ async def process( constants.RouteVar.CLIENT_IP: client_ip, } ) - - # Also pass router_data to all substates. (TODO: this isn't recursive currently) - for _, substate in state.substates.items(): - substate.router_data = state.router_data + # re-assign only when the value is different + if state.router_data != router_data: + # assignment will recurse into substates and force recalculation of + # dependent ComputedVar (dynamic route variables) + state.router_data = router_data # Preprocess the event. update = await app.preprocess(state, event) diff --git a/pynecone/compiler/templates.py b/pynecone/compiler/templates.py index c97bea9a5..e42b6f966 100644 --- a/pynecone/compiler/templates.py +++ b/pynecone/compiler/templates.py @@ -37,6 +37,7 @@ class PyneconeJinjaEnvironment(Environment): "color_mode": constants.COLOR_MODE, "toggle_color_mode": constants.TOGGLE_COLOR_MODE, "use_color_mode": constants.USE_COLOR_MODE, + "hydrate": constants.HYDRATE, } diff --git a/pynecone/constants.py b/pynecone/constants.py index d7f01b7d3..863d37905 100644 --- a/pynecone/constants.py +++ b/pynecone/constants.py @@ -267,6 +267,10 @@ class RouteArgType(SimpleNamespace): LIST = str("arg_list") +# the name of the backend var containing path and client information +ROUTER_DATA = "router_data" + + class RouteVar(SimpleNamespace): """Names of variables used in the router_data dict stored in State.""" diff --git a/pynecone/middleware/hydrate_middleware.py b/pynecone/middleware/hydrate_middleware.py index 779802a44..9d227c2b4 100644 --- a/pynecone/middleware/hydrate_middleware.py +++ b/pynecone/middleware/hydrate_middleware.py @@ -37,7 +37,10 @@ class HydrateMiddleware(Middleware): return None # Get the initial state. + setattr(state, constants.IS_HYDRATED, False) delta = format.format_state({state.get_name(): state.dict()}) + # since a full dict was captured, clean any dirtiness + state.clean() # Get the route for on_load events. route = event.router_data.get(constants.RouteVar.PATH, "") diff --git a/pynecone/state.py b/pynecone/state.py index ab3b5731b..3f3ba1eab 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -574,6 +574,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow): self.dirty_vars.add(name) self.mark_dirty() + # For now, handle router_data updates as a special case + if name == constants.ROUTER_DATA: + self.dirty_vars.add(name) + self.mark_dirty() + # propagate router_data updates down the state tree + for substate in self.substates.values(): + setattr(substate, name, value) + def reset(self): """Reset all the base vars to their default values.""" # Reset the base vars. diff --git a/pynecone/templates/web/pages/index.js.jinja2 b/pynecone/templates/web/pages/index.js.jinja2 index 819f7c673..f8e4bdfdd 100644 --- a/pynecone/templates/web/pages/index.js.jinja2 +++ b/pynecone/templates/web/pages/index.js.jinja2 @@ -54,6 +54,13 @@ export default function Component() { } update() }) + useEffect(() => { + const change_complete = () => Event([E('{{state_name}}.{{const.hydrate}}', {})]) + {{const.router}}.events.on('routeChangeComplete', change_complete) + return () => { + {{const.router}}.events.off('routeChangeComplete', change_complete) + } + }, [{{const.router}}]) {% for hook in hooks %} {{ hook }} diff --git a/tests/test_app.py b/tests/test_app.py index 543627f32..353ee4294 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -5,12 +5,15 @@ from typing import List, Tuple, Type import pytest from fastapi import UploadFile -from pynecone.app import App, DefaultState, upload +from pynecone import constants +from pynecone.app import App, DefaultState, process, upload from pynecone.components import Box -from pynecone.event import Event +from pynecone.event import Event, get_hydrate_event from pynecone.middleware import HydrateMiddleware from pynecone.state import State, StateUpdate from pynecone.style import Style +from pynecone.utils import format +from pynecone.var import ComputedVar @pytest.fixture @@ -121,9 +124,9 @@ def test_add_page_set_route_dynamic(app: App, index_page, windows_platform: bool assert set(app.pages.keys()) == {"test/[dynamic]"} assert "dynamic" in app.state.computed_vars assert app.state.computed_vars["dynamic"].deps(objclass=DefaultState) == { - "router_data" + constants.ROUTER_DATA } - assert "router_data" in app.state().computed_var_dependencies + assert constants.ROUTER_DATA in app.state().computed_var_dependencies def test_add_page_set_route_nested(app: App, index_page, windows_platform: bool): @@ -547,3 +550,188 @@ async def test_upload_file_without_annotation(fixture, request): err.value.args[0] == "`file_upload_state.handle_upload2` handler should have a parameter annotated as List[pc.UploadFile]" ) + + +class DynamicState(State): + """State class for testing dynamic route var. + + This is defined at module level because event handlers cannot be addressed + correctly when the class is defined as a local. + + There are several counters: + * loaded: counts how many times `on_load` was triggered by the hydrate middleware + * counter: counts how many times `on_counter` was triggered by a non-naviagational event + -> these events should NOT trigger reload or recalculation of router_data dependent vars + * side_effect_counter: counts how many times a computed var was + recalculated when the dynamic route var was dirty + """ + + loaded: int = 0 + counter: int = 0 + # side_effect_counter: int = 0 + + def on_load(self): + """Event handler for page on_load, should trigger for all navigation events.""" + self.loaded = self.loaded + 1 + + def on_counter(self): + """Increment the counter var.""" + self.counter = self.counter + 1 + + @ComputedVar + def comp_dynamic(self) -> str: + """A computed var that depends on the dynamic var. + + Returns: + same as self.dynamic + """ + # self.side_effect_counter = self.side_effect_counter + 1 + return self.dynamic + + +@pytest.mark.asyncio +async def test_dynamic_route_var_route_change_completed_on_load( + index_page, + windows_platform: bool, +): + """Create app with dynamic route var, and simulate navigation. + + on_load should fire, allowing any additional vars to be updated before the + initial page hydrate. + + Args: + index_page: The index page. + windows_platform: Whether the system is windows. + """ + arg_name = "dynamic" + route = f"/test/[{arg_name}]" + if windows_platform: + route.lstrip("/").replace("/", "\\") + app = App(state=DynamicState) + assert arg_name not in app.state.vars + app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore + assert arg_name in app.state.vars + assert arg_name in app.state.computed_vars + assert app.state.computed_vars[arg_name].deps(objclass=DynamicState) == { + constants.ROUTER_DATA + } + assert constants.ROUTER_DATA in app.state().computed_var_dependencies + + token = "mock_token" + sid = "mock_sid" + client_ip = "127.0.0.1" + state = app.state_manager.get_state(token) + assert state.dynamic == "" + exp_vals = ["foo", "foobar", "baz"] + + def _event(name, val, **kwargs): + return Event( + token=kwargs.pop("token", token), + name=name, + router_data=kwargs.pop( + "router_data", {"pathname": route, "query": {arg_name: val}} + ), + payload=kwargs.pop("payload", {}), + **kwargs, + ) + + def _dynamic_state_event(name, val, **kwargs): + return _event( + name=format.format_event_handler(getattr(DynamicState, name)), # type: ignore + val=val, + **kwargs, + ) + + for exp_index, exp_val in enumerate(exp_vals): + update = await process( + app, + event=_event(name=get_hydrate_event(state), val=exp_val), + sid=sid, + headers={}, + client_ip=client_ip, + ) + # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)] + assert update == StateUpdate( + delta={ + state.get_name(): { + arg_name: exp_val, + f"comp_{arg_name}": exp_val, + constants.IS_HYDRATED: False, + "loaded": exp_index, + "counter": exp_index, + # "side_effect_counter": exp_index, + } + }, + events=[ + _dynamic_state_event(name="on_load", val=exp_val, router_data={}), + _dynamic_state_event( + name="set_is_hydrated", + payload={"value": "true"}, + val=exp_val, + router_data={}, + ), + ], + ) + assert state.dynamic == exp_val + on_load_update = await process( + app, + event=_dynamic_state_event(name="on_load", val=exp_val), + sid=sid, + headers={}, + client_ip=client_ip, + ) + assert on_load_update == StateUpdate( + delta={ + state.get_name(): { + # These computed vars _shouldn't_ be here, because they didn't change + arg_name: exp_val, + f"comp_{arg_name}": exp_val, + "loaded": exp_index + 1, + }, + }, + events=[], + ) + on_set_is_hydrated_update = await process( + app, + event=_dynamic_state_event( + name="set_is_hydrated", payload={"value": True}, val=exp_val + ), + sid=sid, + headers={}, + client_ip=client_ip, + ) + assert on_set_is_hydrated_update == StateUpdate( + delta={ + state.get_name(): { + # These computed vars _shouldn't_ be here, because they didn't change + arg_name: exp_val, + f"comp_{arg_name}": exp_val, + "is_hydrated": True, + }, + }, + events=[], + ) + + # a simple state update event should NOT trigger on_load or route var side effects + update = await process( + app, + event=_dynamic_state_event(name="on_counter", val=exp_val), + sid=sid, + headers={}, + client_ip=client_ip, + ) + assert update == StateUpdate( + delta={ + state.get_name(): { + # These computed vars _shouldn't_ be here, because they didn't change + f"comp_{arg_name}": exp_val, + arg_name: exp_val, + "counter": exp_index + 1, + } + }, + events=[], + ) + assert state.loaded == len(exp_vals) + assert state.counter == len(exp_vals) + # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}") + # assert state.side_effect_counter == len(exp_vals)