diff --git a/pynecone/middleware/hydrate_middleware.py b/pynecone/middleware/hydrate_middleware.py index ea93035a1..0e7f091e8 100644 --- a/pynecone/middleware/hydrate_middleware.py +++ b/pynecone/middleware/hydrate_middleware.py @@ -13,6 +13,12 @@ if TYPE_CHECKING: from pynecone.app import App +IS_HYDRATED = "is_hydrated" + + +State.add_var(IS_HYDRATED, type_=bool, default_value=False) + + class HydrateMiddleware(Middleware): """Middleware to handle initial app hydration.""" @@ -38,19 +44,31 @@ class HydrateMiddleware(Middleware): else: load_event = None + updates = [] + + # first get the initial state + delta = format.format_state({state.get_name(): state.dict()}) + if delta: + updates.append(StateUpdate(delta=delta)) + + # then apply changes from on_load event handlers on top of that if load_event: if not isinstance(load_event, List): load_event = [load_event] - updates = [] for single_event in load_event: updates.append( await self.execute_load_event( state, single_event, event.token, event.payload ) ) - return updates - delta = format.format_state({state.get_name(): state.dict()}) - return StateUpdate(delta=delta) if delta else None + # extra message telling the client state that hydration is complete + updates.append( + StateUpdate( + delta=format.format_state({state.get_name(): {IS_HYDRATED: True}}) + ) + ) + + return updates async def execute_load_event( self, state: State, load_event: EventHandler, token: str, payload: Dict diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 4ed0a1082..0bc13de8e 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -1,15 +1,29 @@ -from typing import List +from typing import Any, Dict, List import pytest from pynecone.app import App -from pynecone.middleware.hydrate_middleware import HydrateMiddleware +from pynecone.middleware.hydrate_middleware import IS_HYDRATED, HydrateMiddleware from pynecone.state import State +def exp_is_hydrated(state: State) -> Dict[str, Any]: + """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. + + Args: + state: the State that is hydrated + + Returns: + dict similar to that returned by `State.get_delta` with IS_HYDRATED: True + """ + return {state.get_name(): {IS_HYDRATED: True}} + + class TestState(State): """A test state with no return in handler.""" + __test__ = False + num: int = 0 def test_handler(self): @@ -20,6 +34,8 @@ class TestState(State): class TestState2(State): """A test state with return in handler.""" + __test__ = False + num: int = 0 name: str @@ -40,6 +56,8 @@ class TestState2(State): class TestState3(State): """A test state with async handler.""" + __test__ = False + num: int = 0 async def test_handler(self): @@ -47,6 +65,16 @@ class TestState3(State): self.num += 1 +@pytest.fixture +def hydrate_middleware() -> HydrateMiddleware: + """Fixture creates an instance of HydrateMiddleware per test case. + + Returns: + instance of HydrateMiddleware + """ + return HydrateMiddleware() + + @pytest.mark.asyncio @pytest.mark.parametrize( "state, expected, event_fixture", @@ -56,30 +84,34 @@ class TestState3(State): (TestState3, {"test_state3": {"num": 1}}, "event3"), ], ) -async def test_preprocess(state, request, event_fixture, expected): +async def test_preprocess(state, hydrate_middleware, request, event_fixture, expected): """Test that a state hydrate event is processed correctly. Args: state: state to process event + hydrate_middleware: instance of HydrateMiddleware request: pytest fixture request event_fixture: The event fixture(an Event) expected: expected delta """ app = App(state=state, load_events={"index": state.test_handler}) - hydrate_middleware = HydrateMiddleware() result = await hydrate_middleware.preprocess( app=app, event=request.getfixturevalue(event_fixture), state=state() ) assert isinstance(result, List) - assert result[0].delta == expected + assert len(result) == 3 + assert result[0].delta == {state().get_name(): state().dict()} + assert result[1].delta == expected + assert result[2].delta == exp_is_hydrated(state()) @pytest.mark.asyncio -async def test_preprocess_multiple_load_events(event1): +async def test_preprocess_multiple_load_events(hydrate_middleware, event1): """Test that a state hydrate event for multiple on-load events is processed correctly. Args: + hydrate_middleware: instance of HydrateMiddleware event1: an Event. """ app = App( @@ -87,10 +119,31 @@ async def test_preprocess_multiple_load_events(event1): load_events={"index": [TestState.test_handler, TestState.test_handler]}, ) - hydrate_middleware = HydrateMiddleware() result = await hydrate_middleware.preprocess( app=app, event=event1, state=TestState() ) assert isinstance(result, List) - assert result[0].delta == {"test_state": {"num": 1}} - assert result[1].delta == {"test_state": {"num": 2}} + assert len(result) == 4 + assert result[0].delta == {"test_state": TestState().dict()} + assert result[1].delta == {"test_state": {"num": 1}} + assert result[2].delta == {"test_state": {"num": 2}} + assert result[3].delta == exp_is_hydrated(TestState()) + + +@pytest.mark.asyncio +async def test_preprocess_no_events(hydrate_middleware, event1): + """Test that app without on_load is processed correctly. + + Args: + hydrate_middleware: instance of HydrateMiddleware + event1: an Event. + """ + result = await hydrate_middleware.preprocess( + app=App(state=TestState), + event=event1, + state=TestState(), + ) + assert isinstance(result, List) + assert len(result) == 2 + assert result[0].delta == {"test_state": TestState().dict()} + assert result[1].delta == exp_is_hydrated(TestState()) diff --git a/tests/test_state.py b/tests/test_state.py index 28f1519ae..ba65e2f08 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -6,6 +6,7 @@ from plotly.graph_objects import Figure from pynecone.base import Base from pynecone.constants import RouteVar from pynecone.event import Event +from pynecone.middleware.hydrate_middleware import IS_HYDRATED from pynecone.state import State from pynecone.utils import format from pynecone.var import BaseVar, ComputedVar @@ -191,6 +192,7 @@ def test_class_vars(test_state): """ cls = type(test_state) assert set(cls.vars.keys()) == { + IS_HYDRATED, # added by hydrate_middleware to all State "num1", "num2", "key",