diff --git a/reflex/.templates/jinja/web/utils/context.js.jinja2 b/reflex/.templates/jinja/web/utils/context.js.jinja2 index 54ca36ff8..5ff4a44bd 100644 --- a/reflex/.templates/jinja/web/utils/context.js.jinja2 +++ b/reflex/.templates/jinja/web/utils/context.js.jinja2 @@ -23,10 +23,15 @@ export const clientStorage = {} {% endif %} {% if state_name %} +export const onLoadInternalEvent = () => [Event('{{state_name}}.{{const.on_load_internal}}')] + export const initialEvents = () => [ Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)), + ...onLoadInternalEvent() ] {% else %} +export const onLoadInternalEvent = () => [] + export const initialEvents = () => [] {% endif %} diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index aedce21c0..f3ea46428 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -6,7 +6,7 @@ import env from "env.json"; import Cookies from "universal-cookie"; import { useEffect, useReducer, useRef, useState } from "react"; import Router, { useRouter } from "next/router"; -import { initialEvents, initialState } from "utils/context.js" +import { initialEvents, initialState, onLoadInternalEvent } from "utils/context.js" // Endpoint URLs. const EVENTURL = env.EVENT @@ -529,10 +529,15 @@ export const useEventLoop = ( } const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode - // initial state hydrate useEffect(() => { if (router.isReady && !sentHydrate.current) { - addEvents(initial_events()) + const events = initial_events() + addEvents(events.map((e) => ( + { + ...e, + router_data: (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(router) + } + ))) sentHydrate.current = true } }, [router.isReady]) @@ -560,7 +565,7 @@ export const useEventLoop = ( // Route after the initial page hydration. useEffect(() => { - const change_complete = () => addEvents(initial_events()) + const change_complete = () => addEvents(onLoadInternalEvent()) router.events.on('routeChangeComplete', change_complete) return () => { router.events.off('routeChangeComplete', change_complete) diff --git a/reflex/app.pyi b/reflex/app.pyi index 667ebf52a..b3b77bd20 100644 --- a/reflex/app.pyi +++ b/reflex/app.pyi @@ -125,7 +125,7 @@ class App(Base): self, state: State, event: Event ) -> asyncio.Task | None: ... -async def process( +def process( app: App, event: Event, sid: str, headers: Dict, client_ip: str ) -> AsyncIterator[StateUpdate]: ... async def ping() -> str: ... diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index 5099d2884..472d1fbdc 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -40,6 +40,7 @@ class ReflexJinjaEnvironment(Environment): "toggle_color_mode": constants.ColorMode.TOGGLE, "use_color_mode": constants.ColorMode.USE, "hydrate": constants.CompileVars.HYDRATE, + "on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL, } diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index 175bd91a6..d2307eb4a 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -48,6 +48,7 @@ from .route import ( ROUTE_NOT_FOUND, ROUTER, ROUTER_DATA, + ROUTER_DATA_INCLUDE, DefaultPage, Page404, RouteArgType, @@ -97,6 +98,7 @@ __ALL__ = [ RouteVar, ROUTER, ROUTER_DATA, + ROUTER_DATA_INCLUDE, ROUTE_NOT_FOUND, SETTER_PREFIX, SKIP_COMPILE_ENV_VAR, diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 91e663837..a8ae11610 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -58,6 +58,8 @@ class CompileVars(SimpleNamespace): CONNECT_ERROR = "connectError" # The name of the function for converting a dict to an event. TO_EVENT = "Event" + # The name of the internal on_load event. + ON_LOAD_INTERNAL = "on_load_internal" class PageNames(SimpleNamespace): diff --git a/reflex/constants/route.py b/reflex/constants/route.py index fad285f2f..2ed399a54 100644 --- a/reflex/constants/route.py +++ b/reflex/constants/route.py @@ -30,6 +30,10 @@ class RouteVar(SimpleNamespace): COOKIE = "cookie" +# This subset of router_data is included in chained on_load events. +ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY)) + + class RouteRegex(SimpleNamespace): """Regex used for extracting route args in route.""" diff --git a/reflex/event.py b/reflex/event.py index f0ad0957b..da2b3a1e3 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -826,6 +826,10 @@ def fix_events( # Fix the events created by the handler. out = [] for e in events: + if isinstance(e, Event): + # If the event is already an event, append it to the list. + out.append(e) + continue if not isinstance(e, (EventHandler, EventSpec)): e = EventHandler(fn=e) # Otherwise, create an event from the event spec. @@ -835,13 +839,19 @@ def fix_events( name = format.format_event_handler(e.handler) payload = {k._var_name: v._decode() for k, v in e.args} # type: ignore + # Filter router_data to reduce payload size + event_router_data = { + k: v + for k, v in (router_data or {}).items() + if k in constants.route.ROUTER_DATA_INCLUDE + } # Create an event and append it to the list. out.append( Event( token=token, name=name, payload=payload, - router_data=router_data or {}, + router_data=event_router_data, ) ) diff --git a/reflex/middleware/hydrate_middleware.py b/reflex/middleware/hydrate_middleware.py index 6108a90c4..dc80971c7 100644 --- a/reflex/middleware/hydrate_middleware.py +++ b/reflex/middleware/hydrate_middleware.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional from reflex import constants -from reflex.event import Event, fix_events, get_hydrate_event +from reflex.event import Event, get_hydrate_event from reflex.middleware.middleware import Middleware from reflex.state import BaseState, StateUpdate from reflex.utils import format @@ -52,11 +52,5 @@ class HydrateMiddleware(Middleware): # since a full dict was captured, clean any dirtiness state._clean() - # Get the route for on_load events. - route = event.router_data.get(constants.RouteVar.PATH, "") - # Add the on_load events and set is_hydrated to True. - events = [*app.get_load_events(route), type(state).set_is_hydrated(True)] # type: ignore - events = fix_events(events, event.token, router_data=event.router_data) - # Return the state update. - return StateUpdate(delta=delta, events=events) + return StateUpdate(delta=delta, events=[]) diff --git a/reflex/state.py b/reflex/state.py index f6b10e849..d6157e3cb 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1016,7 +1016,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ def _is_valid_type(events: Any) -> bool: - return isinstance(events, (EventHandler, EventSpec)) + return isinstance(events, (Event, EventHandler, EventSpec)) if events is None or _is_valid_type(events): return events @@ -1313,6 +1313,26 @@ class State(BaseState): # The hydrated bool. is_hydrated: bool = False + def on_load_internal(self) -> list[Event | EventSpec] | None: + """Queue on_load handlers for the current page. + + Returns: + The list of events to queue for on load handling. + """ + app = getattr(prerequisites.get_app(), constants.CompileVars.APP) + load_events = app.get_load_events(self.router.page.path) + if not load_events and self.is_hydrated: + return # Fast path for page-to-page navigation + self.is_hydrated = False + return [ + *fix_events( + load_events, + self.router.session.client_token, + router_data=self.router_data, + ), + type(self).set_is_hydrated(True), # type: ignore + ] + class StateProxy(wrapt.ObjectProxy): """Proxy of a state instance to control mutability of vars for a background task. diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index f0e53b2d2..dee7c8688 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -123,9 +123,17 @@ def get_app(reload: bool = False) -> ModuleType: Returns: The app based on the default config. + + Raises: + RuntimeError: If the app name is not set in the config. """ os.environ[constants.RELOAD_CONFIG] = str(reload) config = get_config() + if not config.app_name: + raise RuntimeError( + "Cannot get the app module because `app_name` is not set in rxconfig! " + "If this error occurs in a reflex test case, ensure that `get_app` is mocked." + ) module = ".".join([config.app_name, config.app_name]) sys.path.insert(0, os.getcwd()) app = __import__(module, fromlist=(constants.CompileVars.APP,)) diff --git a/tests/conftest.py b/tests/conftest.py index e5dddc470..be9290edb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,11 +5,13 @@ import platform import uuid from pathlib import Path from typing import Dict, Generator +from unittest import mock import pytest from reflex.app import App from reflex.event import EventSpec +from reflex.utils import prerequisites from .states import ( DictMutationTestState, @@ -30,6 +32,26 @@ def app() -> App: return App() +@pytest.fixture +def app_module_mock(monkeypatch) -> mock.Mock: + """Mock the app module. + + This overwrites prerequisites.get_app to return the mock for the app module. + + To use this in your test, assign `app_module_mock.app = rx.App(...)`. + + Args: + monkeypatch: pytest monkeypatch fixture. + + Returns: + The mock for the main app module. + """ + app_module_mock = mock.Mock() + get_app_mock = mock.Mock(return_value=app_module_mock) + monkeypatch.setattr(prerequisites, "get_app", get_app_mock) + return app_module_mock + + @pytest.fixture(scope="session") def windows_platform() -> Generator: """Check if system is windows. diff --git a/tests/middleware/conftest.py b/tests/middleware/conftest.py index cad706dc9..5a1897110 100644 --- a/tests/middleware/conftest.py +++ b/tests/middleware/conftest.py @@ -21,14 +21,4 @@ def create_event(name): @pytest.fixture def event1(): - return create_event("test_state.hydrate") - - -@pytest.fixture -def event2(): - return create_event("test_state2.hydrate") - - -@pytest.fixture -def event3(): - return create_event("test_state3.hydrate") + return create_event("state.hydrate") diff --git a/tests/middleware/test_hydrate_middleware.py b/tests/middleware/test_hydrate_middleware.py index 2f21557f0..9ee8d8d25 100644 --- a/tests/middleware/test_hydrate_middleware.py +++ b/tests/middleware/test_hydrate_middleware.py @@ -1,27 +1,13 @@ -from typing import Any, Dict +from __future__ import annotations import pytest -from reflex import constants from reflex.app import App -from reflex.constants import CompileVars from reflex.middleware.hydrate_middleware import HydrateMiddleware -from reflex.state import BaseState, StateUpdate +from reflex.state import State, StateUpdate -def exp_is_hydrated(state: BaseState) -> Dict[str, Any]: - """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. - - Args: - state: the State that is hydrated - - Returns: - dict similar to that returned by `State.get_delta` with IS_HYDRATED: True - """ - return {state.get_name(): {CompileVars.IS_HYDRATED: True}} - - -class TestState(BaseState): +class TestState(State): """A test state with no return in handler.""" __test__ = False @@ -33,40 +19,6 @@ class TestState(BaseState): self.num += 1 -class TestState2(BaseState): - """A test state with return in handler.""" - - __test__ = False - - num: int = 0 - name: str - - def test_handler(self): - """Test handler that calls another handler. - - Returns: - Chain of EventHandlers - """ - self.num += 1 - return self.change_name - - def change_name(self): - """Test handler to change name.""" - self.name = "random" - - -class TestState3(BaseState): - """A test state with async handler.""" - - __test__ = False - - num: int = 0 - - async def test_handler(self): - """Test handler.""" - self.num += 1 - - @pytest.fixture def hydrate_middleware() -> HydrateMiddleware: """Fixture creates an instance of HydrateMiddleware per test case. @@ -78,98 +30,21 @@ def hydrate_middleware() -> HydrateMiddleware: @pytest.mark.asyncio -@pytest.mark.parametrize( - "test_state, expected, event_fixture", - [ - (TestState, {"test_state": {"num": 1}}, "event1"), - (TestState2, {"test_state2": {"num": 1}}, "event2"), - (TestState3, {"test_state3": {"num": 1}}, "event3"), - ], -) -async def test_preprocess( - test_state, hydrate_middleware, request, event_fixture, expected -): - """Test that a state hydrate event is processed correctly. - - Args: - test_state: State to process event. - hydrate_middleware: Instance of HydrateMiddleware. - request: Pytest fixture request. - event_fixture: The event fixture(an Event). - expected: Expected delta. - """ - test_state.add_var( - constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False - ) - app = App(state=test_state, load_events={"index": [test_state.test_handler]}) - state = test_state() - - update = await hydrate_middleware.preprocess( - app=app, event=request.getfixturevalue(event_fixture), state=state - ) - assert isinstance(update, StateUpdate) - assert update.delta == state.dict() - events = update.events - assert len(events) == 2 - - # Apply the on_load event. - update = await state._process(events[0]).__anext__() - assert update.delta == expected - - # Apply the hydrate event. - update = await state._process(events[1]).__anext__() - assert update.delta == exp_is_hydrated(state) - - -@pytest.mark.asyncio -async def test_preprocess_multiple_load_events(hydrate_middleware, event1): - """Test that a state hydrate event for multiple on-load events is processed correctly. - - Args: - hydrate_middleware: Instance of HydrateMiddleware - event1: An Event. - """ - app = App( - state=TestState, - load_events={"index": [TestState.test_handler, TestState.test_handler]}, - ) - state = TestState() - - update = await hydrate_middleware.preprocess(app=app, event=event1, state=state) - assert isinstance(update, StateUpdate) - assert update.delta == state.dict() - assert len(update.events) == 3 - - # Apply the events. - events = update.events - update = await state._process(events[0]).__anext__() - assert update.delta == {"test_state": {"num": 1}} - - update = await state._process(events[1]).__anext__() - assert update.delta == {"test_state": {"num": 2}} - - update = await state._process(events[2]).__anext__() - assert update.delta == exp_is_hydrated(state) - - -@pytest.mark.asyncio -async def test_preprocess_no_events(hydrate_middleware, event1): +async def test_preprocess_no_events(hydrate_middleware, event1, mocker): """Test that app without on_load is processed correctly. Args: hydrate_middleware: Instance of HydrateMiddleware event1: An Event. + mocker: pytest mock object. """ - state = TestState() + mocker.patch("reflex.state.State.class_subclasses", {TestState}) + state = State() update = await hydrate_middleware.preprocess( - app=App(state=TestState), + app=App(state=State), event=event1, state=state, ) assert isinstance(update, StateUpdate) assert update.delta == state.dict() - assert len(update.events) == 1 - assert isinstance(update, StateUpdate) - - update = await state._process(update.events[0]).__anext__() - assert update.delta == exp_is_hydrated(state) + assert not update.events diff --git a/tests/test_app.py b/tests/test_app.py index 7d667ac1e..6f159172a 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -25,10 +25,10 @@ from reflex.app import ( upload, ) from reflex.components import Box, Component, Cond, Fragment, Text -from reflex.event import Event, get_hydrate_event +from reflex.event import Event from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate +from reflex.state import BaseState, State, StateManagerRedis, StateUpdate from reflex.style import Style from reflex.utils import format from reflex.vars import ComputedVar @@ -870,6 +870,7 @@ class DynamicState(BaseState): recalculated when the dynamic route var was dirty """ + is_hydrated: bool = False loaded: int = 0 counter: int = 0 @@ -893,10 +894,16 @@ class DynamicState(BaseState): # self.side_effect_counter = self.side_effect_counter + 1 return self.dynamic + on_load_internal = State.on_load_internal.fn + @pytest.mark.asyncio async def test_dynamic_route_var_route_change_completed_on_load( - index_page, windows_platform: bool, token: str, mocker + index_page, + windows_platform: bool, + token: str, + app_module_mock: unittest.mock.Mock, + mocker, ): """Create app with dynamic route var, and simulate navigation. @@ -907,17 +914,14 @@ async def test_dynamic_route_var_route_change_completed_on_load( index_page: The index page. windows_platform: Whether the system is windows. token: a Token. + app_module_mock: Mocked app module. mocker: pytest mocker object. """ - mocker.patch("reflex.state.State.class_subclasses", {DynamicState}) - DynamicState.add_var( - constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False - ) arg_name = "dynamic" route = f"/test/[{arg_name}]" if windows_platform: route.lstrip("/").replace("/", "\\") - app = App(state=DynamicState) + app = app_module_mock.app = App(state=DynamicState) assert arg_name not in app.state.vars app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore assert arg_name in app.state.vars @@ -953,33 +957,25 @@ async def test_dynamic_route_var_route_change_completed_on_load( prev_exp_val = "" for exp_index, exp_val in enumerate(exp_vals): - hydrate_event = _event(name=get_hydrate_event(state), val=exp_val) - exp_router_data = { - "headers": {}, - "ip": client_ip, - "sid": sid, - "token": token, - **hydrate_event.router_data, - } - exp_router = RouterData(exp_router_data) + on_load_internal = _event( + name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}", + val=exp_val, + ) process_coro = process( app, - event=hydrate_event, + event=on_load_internal, sid=sid, headers={}, client_ip=client_ip, ) - update = await process_coro.__anext__() # type: ignore - # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)] + update = await process_coro.__anext__() + # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)] assert update == StateUpdate( delta={ state.get_name(): { arg_name: exp_val, f"comp_{arg_name}": exp_val, constants.CompileVars.IS_HYDRATED: False, - "loaded": exp_index, - "counter": exp_index, - "router": exp_router, # "side_effect_counter": exp_index, } }, @@ -987,13 +983,12 @@ async def test_dynamic_route_var_route_change_completed_on_load( _dynamic_state_event( name="on_load", val=exp_val, - router_data=exp_router_data, ), _dynamic_state_event( name="set_is_hydrated", payload={"value": True}, val=exp_val, - router_data=exp_router_data, + router_data={}, ), ], ) @@ -1004,7 +999,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( # complete the processing with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() # type: ignore + await process_coro.__anext__() # check that router data was written to the state_manager store state = await app.state_manager.get_state(token) @@ -1017,7 +1012,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( headers={}, client_ip=client_ip, ) - on_load_update = await process_coro.__anext__() # type: ignore + on_load_update = await process_coro.__anext__() assert on_load_update == StateUpdate( delta={ state.get_name(): { @@ -1031,7 +1026,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( ) # complete the processing with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() # type: ignore + await process_coro.__anext__() process_coro = process( app, event=_dynamic_state_event( @@ -1041,7 +1036,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( headers={}, client_ip=client_ip, ) - on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore + on_set_is_hydrated_update = await process_coro.__anext__() assert on_set_is_hydrated_update == StateUpdate( delta={ state.get_name(): { @@ -1055,7 +1050,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( ) # complete the processing with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() # type: ignore + await process_coro.__anext__() # a simple state update event should NOT trigger on_load or route var side effects process_coro = process( @@ -1065,7 +1060,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( headers={}, client_ip=client_ip, ) - update = await process_coro.__anext__() # type: ignore + update = await process_coro.__anext__() assert update == StateUpdate( delta={ state.get_name(): { @@ -1079,7 +1074,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( ) # complete the processing with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() # type: ignore + await process_coro.__anext__() prev_exp_val = exp_val state = await app.state_manager.get_state(token) @@ -1116,7 +1111,7 @@ async def test_process_events(mocker, token: str): token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data ) - async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore + async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): pass assert (await app.state_manager.get_state(token)).value == 5 diff --git a/tests/test_state.py b/tests/test_state.py index b4ca0c87a..28cb41106 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -7,7 +7,7 @@ import functools import json import os import sys -from typing import Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union from unittest.mock import AsyncMock, Mock import pytest @@ -24,6 +24,7 @@ from reflex.state import ( LockExpiredError, MutableProxy, RouterData, + State, StateManager, StateManagerMemory, StateManagerRedis, @@ -1374,8 +1375,13 @@ def test_error_on_state_method_shadow(): ) -def test_state_with_invalid_yield(): - """Test that an error is thrown when a state yields an invalid value.""" +@pytest.mark.asyncio +async def test_state_with_invalid_yield(capsys): + """Test that an error is thrown when a state yields an invalid value. + + Args: + capsys: Pytest fixture for capture standard streams. + """ class StateWithInvalidYield(BaseState): """A state that yields an invalid value.""" @@ -1389,15 +1395,16 @@ def test_state_with_invalid_yield(): yield 1 invalid_state = StateWithInvalidYield() - with pytest.raises(TypeError) as err: - invalid_state._check_valid( - invalid_state.event_handlers["invalid_handler"], - rx.event.Event(token="fake_token", name="invalid_handler"), + async for update in invalid_state._process( + rx.event.Event(token="fake_token", name="invalid_handler") + ): + assert not update.delta + assert update.events == rx.event.fix_events( + [rx.window_alert("An error occurred. See logs for details.")], + token="", ) - assert ( - "must only return/yield: None, Events or other EventHandlers" - in err.value.args[0] - ) + captured = capsys.readouterr() + assert "must only return/yield: None, Events or other EventHandlers" in captured.out @pytest.fixture(scope="function", params=["in_process", "redis"]) @@ -2303,3 +2310,150 @@ def test_state_union_optional(): assert UnionState.custom_union.c2r is not None # type: ignore assert types.is_optional(UnionState.opt_int._var_type) # type: ignore assert types.is_union(UnionState.int_float._var_type) # type: ignore + + +def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]: + """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. + + Args: + state: the State that is hydrated. + is_hydrated: whether the state is hydrated. + + Returns: + dict similar to that returned by `State.get_delta` with IS_HYDRATED: is_hydrated + """ + return {state.get_full_name(): {CompileVars.IS_HYDRATED: is_hydrated}} + + +class OnLoadState(State): + """A test state with no return in handler.""" + + num: int = 0 + + def test_handler(self): + """Test handler.""" + self.num += 1 + + +class OnLoadState2(State): + """A test state with return in handler.""" + + num: int = 0 + name: str + + def test_handler(self): + """Test handler that calls another handler. + + Returns: + Chain of EventHandlers + """ + self.num += 1 + return self.change_name + + def change_name(self): + """Test handler to change name.""" + self.name = "random" + + +class OnLoadState3(State): + """A test state with async handler.""" + + num: int = 0 + + async def test_handler(self): + """Test handler.""" + self.num += 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_state, expected", + [ + (OnLoadState, {"on_load_state": {"num": 1}}), + (OnLoadState2, {"on_load_state2": {"num": 1}}), + (OnLoadState3, {"on_load_state3": {"num": 1}}), + ], +) +async def test_preprocess(app_module_mock, token, test_state, expected, mocker): + """Test that a state hydrate event is processed correctly. + + Args: + app_module_mock: The app module that will be returned by get_app(). + token: A token. + test_state: State to process event. + expected: Expected delta. + mocker: pytest mock object. + """ + mocker.patch("reflex.state.State.class_subclasses", {test_state}) + app = app_module_mock.app = App( + state=State, load_events={"index": [test_state.test_handler]} + ) + state = State() + + updates = [] + async for update in rx.app.process( + app=app, + event=Event( + token=token, + name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}", + router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, + ), + sid="sid", + headers={}, + client_ip="", + ): + assert isinstance(update, StateUpdate) + updates.append(update) + assert len(updates) == 1 + assert updates[0].delta == exp_is_hydrated(state, False) + + events = updates[0].events + assert len(events) == 2 + assert (await state._process(events[0]).__anext__()).delta == { + test_state.get_full_name(): {"num": 1} + } + assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state) + + +@pytest.mark.asyncio +async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): + """Test that a state hydrate event for multiple on-load events is processed correctly. + + Args: + app_module_mock: The app module that will be returned by get_app(). + token: A token. + mocker: pytest mock object. + """ + mocker.patch("reflex.state.State.class_subclasses", {OnLoadState}) + app = app_module_mock.app = App( + state=State, + load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]}, + ) + state = State() + + updates = [] + async for update in rx.app.process( + app=app, + event=Event( + token=token, + name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}", + router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, + ), + sid="sid", + headers={}, + client_ip="", + ): + assert isinstance(update, StateUpdate) + updates.append(update) + assert len(updates) == 1 + assert updates[0].delta == exp_is_hydrated(state, False) + + events = updates[0].events + assert len(events) == 3 + assert (await state._process(events[0]).__anext__()).delta == { + OnLoadState.get_full_name(): {"num": 1} + } + assert (await state._process(events[1]).__anext__()).delta == { + OnLoadState.get_full_name(): {"num": 2} + } + assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)