[REF-201] Separate on_load handler from initial hydration (#1847)
This commit is contained in:
parent
3c7af9fad4
commit
60147dec65
@ -23,10 +23,15 @@ export const clientStorage = {}
|
||||
{% endif %}
|
||||
|
||||
{% if state_name %}
|
||||
export const onLoadInternalEvent = () => [Event('{{state_name}}.{{const.on_load_internal}}')]
|
||||
|
||||
export const initialEvents = () => [
|
||||
Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
|
||||
...onLoadInternalEvent()
|
||||
]
|
||||
{% else %}
|
||||
export const onLoadInternalEvent = () => []
|
||||
|
||||
export const initialEvents = () => []
|
||||
{% endif %}
|
||||
|
||||
|
@ -6,7 +6,7 @@ import env from "env.json";
|
||||
import Cookies from "universal-cookie";
|
||||
import { useEffect, useReducer, useRef, useState } from "react";
|
||||
import Router, { useRouter } from "next/router";
|
||||
import { initialEvents, initialState } from "utils/context.js"
|
||||
import { initialEvents, initialState, onLoadInternalEvent } from "utils/context.js"
|
||||
|
||||
// Endpoint URLs.
|
||||
const EVENTURL = env.EVENT
|
||||
@ -529,10 +529,15 @@ export const useEventLoop = (
|
||||
}
|
||||
|
||||
const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode
|
||||
// initial state hydrate
|
||||
useEffect(() => {
|
||||
if (router.isReady && !sentHydrate.current) {
|
||||
addEvents(initial_events())
|
||||
const events = initial_events()
|
||||
addEvents(events.map((e) => (
|
||||
{
|
||||
...e,
|
||||
router_data: (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(router)
|
||||
}
|
||||
)))
|
||||
sentHydrate.current = true
|
||||
}
|
||||
}, [router.isReady])
|
||||
@ -560,7 +565,7 @@ export const useEventLoop = (
|
||||
|
||||
// Route after the initial page hydration.
|
||||
useEffect(() => {
|
||||
const change_complete = () => addEvents(initial_events())
|
||||
const change_complete = () => addEvents(onLoadInternalEvent())
|
||||
router.events.on('routeChangeComplete', change_complete)
|
||||
return () => {
|
||||
router.events.off('routeChangeComplete', change_complete)
|
||||
|
@ -125,7 +125,7 @@ class App(Base):
|
||||
self, state: State, event: Event
|
||||
) -> asyncio.Task | None: ...
|
||||
|
||||
async def process(
|
||||
def process(
|
||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||
) -> AsyncIterator[StateUpdate]: ...
|
||||
async def ping() -> str: ...
|
||||
|
@ -40,6 +40,7 @@ class ReflexJinjaEnvironment(Environment):
|
||||
"toggle_color_mode": constants.ColorMode.TOGGLE,
|
||||
"use_color_mode": constants.ColorMode.USE,
|
||||
"hydrate": constants.CompileVars.HYDRATE,
|
||||
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
|
||||
}
|
||||
|
||||
|
||||
|
@ -48,6 +48,7 @@ from .route import (
|
||||
ROUTE_NOT_FOUND,
|
||||
ROUTER,
|
||||
ROUTER_DATA,
|
||||
ROUTER_DATA_INCLUDE,
|
||||
DefaultPage,
|
||||
Page404,
|
||||
RouteArgType,
|
||||
@ -97,6 +98,7 @@ __ALL__ = [
|
||||
RouteVar,
|
||||
ROUTER,
|
||||
ROUTER_DATA,
|
||||
ROUTER_DATA_INCLUDE,
|
||||
ROUTE_NOT_FOUND,
|
||||
SETTER_PREFIX,
|
||||
SKIP_COMPILE_ENV_VAR,
|
||||
|
@ -58,6 +58,8 @@ class CompileVars(SimpleNamespace):
|
||||
CONNECT_ERROR = "connectError"
|
||||
# The name of the function for converting a dict to an event.
|
||||
TO_EVENT = "Event"
|
||||
# The name of the internal on_load event.
|
||||
ON_LOAD_INTERNAL = "on_load_internal"
|
||||
|
||||
|
||||
class PageNames(SimpleNamespace):
|
||||
|
@ -30,6 +30,10 @@ class RouteVar(SimpleNamespace):
|
||||
COOKIE = "cookie"
|
||||
|
||||
|
||||
# This subset of router_data is included in chained on_load events.
|
||||
ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY))
|
||||
|
||||
|
||||
class RouteRegex(SimpleNamespace):
|
||||
"""Regex used for extracting route args in route."""
|
||||
|
||||
|
@ -826,6 +826,10 @@ def fix_events(
|
||||
# Fix the events created by the handler.
|
||||
out = []
|
||||
for e in events:
|
||||
if isinstance(e, Event):
|
||||
# If the event is already an event, append it to the list.
|
||||
out.append(e)
|
||||
continue
|
||||
if not isinstance(e, (EventHandler, EventSpec)):
|
||||
e = EventHandler(fn=e)
|
||||
# Otherwise, create an event from the event spec.
|
||||
@ -835,13 +839,19 @@ def fix_events(
|
||||
name = format.format_event_handler(e.handler)
|
||||
payload = {k._var_name: v._decode() for k, v in e.args} # type: ignore
|
||||
|
||||
# Filter router_data to reduce payload size
|
||||
event_router_data = {
|
||||
k: v
|
||||
for k, v in (router_data or {}).items()
|
||||
if k in constants.route.ROUTER_DATA_INCLUDE
|
||||
}
|
||||
# Create an event and append it to the list.
|
||||
out.append(
|
||||
Event(
|
||||
token=token,
|
||||
name=name,
|
||||
payload=payload,
|
||||
router_data=router_data or {},
|
||||
router_data=event_router_data,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from reflex import constants
|
||||
from reflex.event import Event, fix_events, get_hydrate_event
|
||||
from reflex.event import Event, get_hydrate_event
|
||||
from reflex.middleware.middleware import Middleware
|
||||
from reflex.state import BaseState, StateUpdate
|
||||
from reflex.utils import format
|
||||
@ -52,11 +52,5 @@ class HydrateMiddleware(Middleware):
|
||||
# since a full dict was captured, clean any dirtiness
|
||||
state._clean()
|
||||
|
||||
# Get the route for on_load events.
|
||||
route = event.router_data.get(constants.RouteVar.PATH, "")
|
||||
# Add the on_load events and set is_hydrated to True.
|
||||
events = [*app.get_load_events(route), type(state).set_is_hydrated(True)] # type: ignore
|
||||
events = fix_events(events, event.token, router_data=event.router_data)
|
||||
|
||||
# Return the state update.
|
||||
return StateUpdate(delta=delta, events=events)
|
||||
return StateUpdate(delta=delta, events=[])
|
||||
|
@ -1016,7 +1016,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
|
||||
def _is_valid_type(events: Any) -> bool:
|
||||
return isinstance(events, (EventHandler, EventSpec))
|
||||
return isinstance(events, (Event, EventHandler, EventSpec))
|
||||
|
||||
if events is None or _is_valid_type(events):
|
||||
return events
|
||||
@ -1313,6 +1313,26 @@ class State(BaseState):
|
||||
# The hydrated bool.
|
||||
is_hydrated: bool = False
|
||||
|
||||
def on_load_internal(self) -> list[Event | EventSpec] | None:
|
||||
"""Queue on_load handlers for the current page.
|
||||
|
||||
Returns:
|
||||
The list of events to queue for on load handling.
|
||||
"""
|
||||
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
|
||||
load_events = app.get_load_events(self.router.page.path)
|
||||
if not load_events and self.is_hydrated:
|
||||
return # Fast path for page-to-page navigation
|
||||
self.is_hydrated = False
|
||||
return [
|
||||
*fix_events(
|
||||
load_events,
|
||||
self.router.session.client_token,
|
||||
router_data=self.router_data,
|
||||
),
|
||||
type(self).set_is_hydrated(True), # type: ignore
|
||||
]
|
||||
|
||||
|
||||
class StateProxy(wrapt.ObjectProxy):
|
||||
"""Proxy of a state instance to control mutability of vars for a background task.
|
||||
|
@ -123,9 +123,17 @@ def get_app(reload: bool = False) -> ModuleType:
|
||||
|
||||
Returns:
|
||||
The app based on the default config.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the app name is not set in the config.
|
||||
"""
|
||||
os.environ[constants.RELOAD_CONFIG] = str(reload)
|
||||
config = get_config()
|
||||
if not config.app_name:
|
||||
raise RuntimeError(
|
||||
"Cannot get the app module because `app_name` is not set in rxconfig! "
|
||||
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
|
||||
)
|
||||
module = ".".join([config.app_name, config.app_name])
|
||||
sys.path.insert(0, os.getcwd())
|
||||
app = __import__(module, fromlist=(constants.CompileVars.APP,))
|
||||
|
@ -5,11 +5,13 @@ import platform
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from reflex.app import App
|
||||
from reflex.event import EventSpec
|
||||
from reflex.utils import prerequisites
|
||||
|
||||
from .states import (
|
||||
DictMutationTestState,
|
||||
@ -30,6 +32,26 @@ def app() -> App:
|
||||
return App()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_module_mock(monkeypatch) -> mock.Mock:
|
||||
"""Mock the app module.
|
||||
|
||||
This overwrites prerequisites.get_app to return the mock for the app module.
|
||||
|
||||
To use this in your test, assign `app_module_mock.app = rx.App(...)`.
|
||||
|
||||
Args:
|
||||
monkeypatch: pytest monkeypatch fixture.
|
||||
|
||||
Returns:
|
||||
The mock for the main app module.
|
||||
"""
|
||||
app_module_mock = mock.Mock()
|
||||
get_app_mock = mock.Mock(return_value=app_module_mock)
|
||||
monkeypatch.setattr(prerequisites, "get_app", get_app_mock)
|
||||
return app_module_mock
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def windows_platform() -> Generator:
|
||||
"""Check if system is windows.
|
||||
|
@ -21,14 +21,4 @@ def create_event(name):
|
||||
|
||||
@pytest.fixture
|
||||
def event1():
|
||||
return create_event("test_state.hydrate")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event2():
|
||||
return create_event("test_state2.hydrate")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event3():
|
||||
return create_event("test_state3.hydrate")
|
||||
return create_event("state.hydrate")
|
||||
|
@ -1,27 +1,13 @@
|
||||
from typing import Any, Dict
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from reflex import constants
|
||||
from reflex.app import App
|
||||
from reflex.constants import CompileVars
|
||||
from reflex.middleware.hydrate_middleware import HydrateMiddleware
|
||||
from reflex.state import BaseState, StateUpdate
|
||||
from reflex.state import State, StateUpdate
|
||||
|
||||
|
||||
def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
|
||||
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
|
||||
|
||||
Args:
|
||||
state: the State that is hydrated
|
||||
|
||||
Returns:
|
||||
dict similar to that returned by `State.get_delta` with IS_HYDRATED: True
|
||||
"""
|
||||
return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
|
||||
|
||||
|
||||
class TestState(BaseState):
|
||||
class TestState(State):
|
||||
"""A test state with no return in handler."""
|
||||
|
||||
__test__ = False
|
||||
@ -33,40 +19,6 @@ class TestState(BaseState):
|
||||
self.num += 1
|
||||
|
||||
|
||||
class TestState2(BaseState):
|
||||
"""A test state with return in handler."""
|
||||
|
||||
__test__ = False
|
||||
|
||||
num: int = 0
|
||||
name: str
|
||||
|
||||
def test_handler(self):
|
||||
"""Test handler that calls another handler.
|
||||
|
||||
Returns:
|
||||
Chain of EventHandlers
|
||||
"""
|
||||
self.num += 1
|
||||
return self.change_name
|
||||
|
||||
def change_name(self):
|
||||
"""Test handler to change name."""
|
||||
self.name = "random"
|
||||
|
||||
|
||||
class TestState3(BaseState):
|
||||
"""A test state with async handler."""
|
||||
|
||||
__test__ = False
|
||||
|
||||
num: int = 0
|
||||
|
||||
async def test_handler(self):
|
||||
"""Test handler."""
|
||||
self.num += 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hydrate_middleware() -> HydrateMiddleware:
|
||||
"""Fixture creates an instance of HydrateMiddleware per test case.
|
||||
@ -78,98 +30,21 @@ def hydrate_middleware() -> HydrateMiddleware:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_state, expected, event_fixture",
|
||||
[
|
||||
(TestState, {"test_state": {"num": 1}}, "event1"),
|
||||
(TestState2, {"test_state2": {"num": 1}}, "event2"),
|
||||
(TestState3, {"test_state3": {"num": 1}}, "event3"),
|
||||
],
|
||||
)
|
||||
async def test_preprocess(
|
||||
test_state, hydrate_middleware, request, event_fixture, expected
|
||||
):
|
||||
"""Test that a state hydrate event is processed correctly.
|
||||
|
||||
Args:
|
||||
test_state: State to process event.
|
||||
hydrate_middleware: Instance of HydrateMiddleware.
|
||||
request: Pytest fixture request.
|
||||
event_fixture: The event fixture(an Event).
|
||||
expected: Expected delta.
|
||||
"""
|
||||
test_state.add_var(
|
||||
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
|
||||
)
|
||||
app = App(state=test_state, load_events={"index": [test_state.test_handler]})
|
||||
state = test_state()
|
||||
|
||||
update = await hydrate_middleware.preprocess(
|
||||
app=app, event=request.getfixturevalue(event_fixture), state=state
|
||||
)
|
||||
assert isinstance(update, StateUpdate)
|
||||
assert update.delta == state.dict()
|
||||
events = update.events
|
||||
assert len(events) == 2
|
||||
|
||||
# Apply the on_load event.
|
||||
update = await state._process(events[0]).__anext__()
|
||||
assert update.delta == expected
|
||||
|
||||
# Apply the hydrate event.
|
||||
update = await state._process(events[1]).__anext__()
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
|
||||
"""Test that a state hydrate event for multiple on-load events is processed correctly.
|
||||
|
||||
Args:
|
||||
hydrate_middleware: Instance of HydrateMiddleware
|
||||
event1: An Event.
|
||||
"""
|
||||
app = App(
|
||||
state=TestState,
|
||||
load_events={"index": [TestState.test_handler, TestState.test_handler]},
|
||||
)
|
||||
state = TestState()
|
||||
|
||||
update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
|
||||
assert isinstance(update, StateUpdate)
|
||||
assert update.delta == state.dict()
|
||||
assert len(update.events) == 3
|
||||
|
||||
# Apply the events.
|
||||
events = update.events
|
||||
update = await state._process(events[0]).__anext__()
|
||||
assert update.delta == {"test_state": {"num": 1}}
|
||||
|
||||
update = await state._process(events[1]).__anext__()
|
||||
assert update.delta == {"test_state": {"num": 2}}
|
||||
|
||||
update = await state._process(events[2]).__anext__()
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preprocess_no_events(hydrate_middleware, event1):
|
||||
async def test_preprocess_no_events(hydrate_middleware, event1, mocker):
|
||||
"""Test that app without on_load is processed correctly.
|
||||
|
||||
Args:
|
||||
hydrate_middleware: Instance of HydrateMiddleware
|
||||
event1: An Event.
|
||||
mocker: pytest mock object.
|
||||
"""
|
||||
state = TestState()
|
||||
mocker.patch("reflex.state.State.class_subclasses", {TestState})
|
||||
state = State()
|
||||
update = await hydrate_middleware.preprocess(
|
||||
app=App(state=TestState),
|
||||
app=App(state=State),
|
||||
event=event1,
|
||||
state=state,
|
||||
)
|
||||
assert isinstance(update, StateUpdate)
|
||||
assert update.delta == state.dict()
|
||||
assert len(update.events) == 1
|
||||
assert isinstance(update, StateUpdate)
|
||||
|
||||
update = await state._process(update.events[0]).__anext__()
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
assert not update.events
|
||||
|
@ -25,10 +25,10 @@ from reflex.app import (
|
||||
upload,
|
||||
)
|
||||
from reflex.components import Box, Component, Cond, Fragment, Text
|
||||
from reflex.event import Event, get_hydrate_event
|
||||
from reflex.event import Event
|
||||
from reflex.middleware import HydrateMiddleware
|
||||
from reflex.model import Model
|
||||
from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
|
||||
from reflex.state import BaseState, State, StateManagerRedis, StateUpdate
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format
|
||||
from reflex.vars import ComputedVar
|
||||
@ -870,6 +870,7 @@ class DynamicState(BaseState):
|
||||
recalculated when the dynamic route var was dirty
|
||||
"""
|
||||
|
||||
is_hydrated: bool = False
|
||||
loaded: int = 0
|
||||
counter: int = 0
|
||||
|
||||
@ -893,10 +894,16 @@ class DynamicState(BaseState):
|
||||
# self.side_effect_counter = self.side_effect_counter + 1
|
||||
return self.dynamic
|
||||
|
||||
on_load_internal = State.on_load_internal.fn
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
index_page, windows_platform: bool, token: str, mocker
|
||||
index_page,
|
||||
windows_platform: bool,
|
||||
token: str,
|
||||
app_module_mock: unittest.mock.Mock,
|
||||
mocker,
|
||||
):
|
||||
"""Create app with dynamic route var, and simulate navigation.
|
||||
|
||||
@ -907,17 +914,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
index_page: The index page.
|
||||
windows_platform: Whether the system is windows.
|
||||
token: a Token.
|
||||
app_module_mock: Mocked app module.
|
||||
mocker: pytest mocker object.
|
||||
"""
|
||||
mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
|
||||
DynamicState.add_var(
|
||||
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
|
||||
)
|
||||
arg_name = "dynamic"
|
||||
route = f"/test/[{arg_name}]"
|
||||
if windows_platform:
|
||||
route.lstrip("/").replace("/", "\\")
|
||||
app = App(state=DynamicState)
|
||||
app = app_module_mock.app = App(state=DynamicState)
|
||||
assert arg_name not in app.state.vars
|
||||
app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore
|
||||
assert arg_name in app.state.vars
|
||||
@ -953,33 +957,25 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
|
||||
prev_exp_val = ""
|
||||
for exp_index, exp_val in enumerate(exp_vals):
|
||||
hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
|
||||
exp_router_data = {
|
||||
"headers": {},
|
||||
"ip": client_ip,
|
||||
"sid": sid,
|
||||
"token": token,
|
||||
**hydrate_event.router_data,
|
||||
}
|
||||
exp_router = RouterData(exp_router_data)
|
||||
on_load_internal = _event(
|
||||
name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}",
|
||||
val=exp_val,
|
||||
)
|
||||
process_coro = process(
|
||||
app,
|
||||
event=hydrate_event,
|
||||
event=on_load_internal,
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
update = await process_coro.__anext__() # type: ignore
|
||||
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
||||
update = await process_coro.__anext__()
|
||||
# route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
|
||||
assert update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
arg_name: exp_val,
|
||||
f"comp_{arg_name}": exp_val,
|
||||
constants.CompileVars.IS_HYDRATED: False,
|
||||
"loaded": exp_index,
|
||||
"counter": exp_index,
|
||||
"router": exp_router,
|
||||
# "side_effect_counter": exp_index,
|
||||
}
|
||||
},
|
||||
@ -987,13 +983,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
_dynamic_state_event(
|
||||
name="on_load",
|
||||
val=exp_val,
|
||||
router_data=exp_router_data,
|
||||
),
|
||||
_dynamic_state_event(
|
||||
name="set_is_hydrated",
|
||||
payload={"value": True},
|
||||
val=exp_val,
|
||||
router_data=exp_router_data,
|
||||
router_data={},
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -1004,7 +999,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
await process_coro.__anext__()
|
||||
|
||||
# check that router data was written to the state_manager store
|
||||
state = await app.state_manager.get_state(token)
|
||||
@ -1017,7 +1012,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
on_load_update = await process_coro.__anext__() # type: ignore
|
||||
on_load_update = await process_coro.__anext__()
|
||||
assert on_load_update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -1031,7 +1026,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
)
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
await process_coro.__anext__()
|
||||
process_coro = process(
|
||||
app,
|
||||
event=_dynamic_state_event(
|
||||
@ -1041,7 +1036,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore
|
||||
on_set_is_hydrated_update = await process_coro.__anext__()
|
||||
assert on_set_is_hydrated_update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -1055,7 +1050,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
)
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
await process_coro.__anext__()
|
||||
|
||||
# a simple state update event should NOT trigger on_load or route var side effects
|
||||
process_coro = process(
|
||||
@ -1065,7 +1060,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
)
|
||||
update = await process_coro.__anext__() # type: ignore
|
||||
update = await process_coro.__anext__()
|
||||
assert update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -1079,7 +1074,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
)
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
await process_coro.__anext__()
|
||||
|
||||
prev_exp_val = exp_val
|
||||
state = await app.state_manager.get_state(token)
|
||||
@ -1116,7 +1111,7 @@ async def test_process_events(mocker, token: str):
|
||||
token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
|
||||
)
|
||||
|
||||
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore
|
||||
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
|
||||
pass
|
||||
|
||||
assert (await app.state_manager.get_state(token)).value == 5
|
||||
|
@ -7,7 +7,7 @@ import functools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Generator, List, Optional, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
@ -24,6 +24,7 @@ from reflex.state import (
|
||||
LockExpiredError,
|
||||
MutableProxy,
|
||||
RouterData,
|
||||
State,
|
||||
StateManager,
|
||||
StateManagerMemory,
|
||||
StateManagerRedis,
|
||||
@ -1374,8 +1375,13 @@ def test_error_on_state_method_shadow():
|
||||
)
|
||||
|
||||
|
||||
def test_state_with_invalid_yield():
|
||||
"""Test that an error is thrown when a state yields an invalid value."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_with_invalid_yield(capsys):
|
||||
"""Test that an error is thrown when a state yields an invalid value.
|
||||
|
||||
Args:
|
||||
capsys: Pytest fixture for capture standard streams.
|
||||
"""
|
||||
|
||||
class StateWithInvalidYield(BaseState):
|
||||
"""A state that yields an invalid value."""
|
||||
@ -1389,15 +1395,16 @@ def test_state_with_invalid_yield():
|
||||
yield 1
|
||||
|
||||
invalid_state = StateWithInvalidYield()
|
||||
with pytest.raises(TypeError) as err:
|
||||
invalid_state._check_valid(
|
||||
invalid_state.event_handlers["invalid_handler"],
|
||||
rx.event.Event(token="fake_token", name="invalid_handler"),
|
||||
async for update in invalid_state._process(
|
||||
rx.event.Event(token="fake_token", name="invalid_handler")
|
||||
):
|
||||
assert not update.delta
|
||||
assert update.events == rx.event.fix_events(
|
||||
[rx.window_alert("An error occurred. See logs for details.")],
|
||||
token="",
|
||||
)
|
||||
assert (
|
||||
"must only return/yield: None, Events or other EventHandlers"
|
||||
in err.value.args[0]
|
||||
)
|
||||
captured = capsys.readouterr()
|
||||
assert "must only return/yield: None, Events or other EventHandlers" in captured.out
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", params=["in_process", "redis"])
|
||||
@ -2303,3 +2310,150 @@ def test_state_union_optional():
|
||||
assert UnionState.custom_union.c2r is not None # type: ignore
|
||||
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
|
||||
assert types.is_union(UnionState.int_float._var_type) # type: ignore
|
||||
|
||||
|
||||
def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
|
||||
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
|
||||
|
||||
Args:
|
||||
state: the State that is hydrated.
|
||||
is_hydrated: whether the state is hydrated.
|
||||
|
||||
Returns:
|
||||
dict similar to that returned by `State.get_delta` with IS_HYDRATED: is_hydrated
|
||||
"""
|
||||
return {state.get_full_name(): {CompileVars.IS_HYDRATED: is_hydrated}}
|
||||
|
||||
|
||||
class OnLoadState(State):
|
||||
"""A test state with no return in handler."""
|
||||
|
||||
num: int = 0
|
||||
|
||||
def test_handler(self):
|
||||
"""Test handler."""
|
||||
self.num += 1
|
||||
|
||||
|
||||
class OnLoadState2(State):
|
||||
"""A test state with return in handler."""
|
||||
|
||||
num: int = 0
|
||||
name: str
|
||||
|
||||
def test_handler(self):
|
||||
"""Test handler that calls another handler.
|
||||
|
||||
Returns:
|
||||
Chain of EventHandlers
|
||||
"""
|
||||
self.num += 1
|
||||
return self.change_name
|
||||
|
||||
def change_name(self):
|
||||
"""Test handler to change name."""
|
||||
self.name = "random"
|
||||
|
||||
|
||||
class OnLoadState3(State):
|
||||
"""A test state with async handler."""
|
||||
|
||||
num: int = 0
|
||||
|
||||
async def test_handler(self):
|
||||
"""Test handler."""
|
||||
self.num += 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_state, expected",
|
||||
[
|
||||
(OnLoadState, {"on_load_state": {"num": 1}}),
|
||||
(OnLoadState2, {"on_load_state2": {"num": 1}}),
|
||||
(OnLoadState3, {"on_load_state3": {"num": 1}}),
|
||||
],
|
||||
)
|
||||
async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
|
||||
"""Test that a state hydrate event is processed correctly.
|
||||
|
||||
Args:
|
||||
app_module_mock: The app module that will be returned by get_app().
|
||||
token: A token.
|
||||
test_state: State to process event.
|
||||
expected: Expected delta.
|
||||
mocker: pytest mock object.
|
||||
"""
|
||||
mocker.patch("reflex.state.State.class_subclasses", {test_state})
|
||||
app = app_module_mock.app = App(
|
||||
state=State, load_events={"index": [test_state.test_handler]}
|
||||
)
|
||||
state = State()
|
||||
|
||||
updates = []
|
||||
async for update in rx.app.process(
|
||||
app=app,
|
||||
event=Event(
|
||||
token=token,
|
||||
name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}",
|
||||
router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
|
||||
),
|
||||
sid="sid",
|
||||
headers={},
|
||||
client_ip="",
|
||||
):
|
||||
assert isinstance(update, StateUpdate)
|
||||
updates.append(update)
|
||||
assert len(updates) == 1
|
||||
assert updates[0].delta == exp_is_hydrated(state, False)
|
||||
|
||||
events = updates[0].events
|
||||
assert len(events) == 2
|
||||
assert (await state._process(events[0]).__anext__()).delta == {
|
||||
test_state.get_full_name(): {"num": 1}
|
||||
}
|
||||
assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
|
||||
"""Test that a state hydrate event for multiple on-load events is processed correctly.
|
||||
|
||||
Args:
|
||||
app_module_mock: The app module that will be returned by get_app().
|
||||
token: A token.
|
||||
mocker: pytest mock object.
|
||||
"""
|
||||
mocker.patch("reflex.state.State.class_subclasses", {OnLoadState})
|
||||
app = app_module_mock.app = App(
|
||||
state=State,
|
||||
load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
|
||||
)
|
||||
state = State()
|
||||
|
||||
updates = []
|
||||
async for update in rx.app.process(
|
||||
app=app,
|
||||
event=Event(
|
||||
token=token,
|
||||
name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}",
|
||||
router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
|
||||
),
|
||||
sid="sid",
|
||||
headers={},
|
||||
client_ip="",
|
||||
):
|
||||
assert isinstance(update, StateUpdate)
|
||||
updates.append(update)
|
||||
assert len(updates) == 1
|
||||
assert updates[0].delta == exp_is_hydrated(state, False)
|
||||
|
||||
events = updates[0].events
|
||||
assert len(events) == 3
|
||||
assert (await state._process(events[0]).__anext__()).delta == {
|
||||
OnLoadState.get_full_name(): {"num": 1}
|
||||
}
|
||||
assert (await state._process(events[1]).__anext__()).delta == {
|
||||
OnLoadState.get_full_name(): {"num": 2}
|
||||
}
|
||||
assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
|
||||
|
Loading…
Reference in New Issue
Block a user