[REF-201] Separate on_load handler from initial hydration (#1847)
This commit is contained in:
parent
3c7af9fad4
commit
60147dec65
@ -23,10 +23,15 @@ export const clientStorage = {}
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% if state_name %}
|
{% if state_name %}
|
||||||
|
export const onLoadInternalEvent = () => [Event('{{state_name}}.{{const.on_load_internal}}')]
|
||||||
|
|
||||||
export const initialEvents = () => [
|
export const initialEvents = () => [
|
||||||
Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
|
Event('{{state_name}}.{{const.hydrate}}', hydrateClientStorage(clientStorage)),
|
||||||
|
...onLoadInternalEvent()
|
||||||
]
|
]
|
||||||
{% else %}
|
{% else %}
|
||||||
|
export const onLoadInternalEvent = () => []
|
||||||
|
|
||||||
export const initialEvents = () => []
|
export const initialEvents = () => []
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ import env from "env.json";
|
|||||||
import Cookies from "universal-cookie";
|
import Cookies from "universal-cookie";
|
||||||
import { useEffect, useReducer, useRef, useState } from "react";
|
import { useEffect, useReducer, useRef, useState } from "react";
|
||||||
import Router, { useRouter } from "next/router";
|
import Router, { useRouter } from "next/router";
|
||||||
import { initialEvents, initialState } from "utils/context.js"
|
import { initialEvents, initialState, onLoadInternalEvent } from "utils/context.js"
|
||||||
|
|
||||||
// Endpoint URLs.
|
// Endpoint URLs.
|
||||||
const EVENTURL = env.EVENT
|
const EVENTURL = env.EVENT
|
||||||
@ -529,10 +529,15 @@ export const useEventLoop = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode
|
const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode
|
||||||
// initial state hydrate
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (router.isReady && !sentHydrate.current) {
|
if (router.isReady && !sentHydrate.current) {
|
||||||
addEvents(initial_events())
|
const events = initial_events()
|
||||||
|
addEvents(events.map((e) => (
|
||||||
|
{
|
||||||
|
...e,
|
||||||
|
router_data: (({ pathname, query, asPath }) => ({ pathname, query, asPath }))(router)
|
||||||
|
}
|
||||||
|
)))
|
||||||
sentHydrate.current = true
|
sentHydrate.current = true
|
||||||
}
|
}
|
||||||
}, [router.isReady])
|
}, [router.isReady])
|
||||||
@ -560,7 +565,7 @@ export const useEventLoop = (
|
|||||||
|
|
||||||
// Route after the initial page hydration.
|
// Route after the initial page hydration.
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const change_complete = () => addEvents(initial_events())
|
const change_complete = () => addEvents(onLoadInternalEvent())
|
||||||
router.events.on('routeChangeComplete', change_complete)
|
router.events.on('routeChangeComplete', change_complete)
|
||||||
return () => {
|
return () => {
|
||||||
router.events.off('routeChangeComplete', change_complete)
|
router.events.off('routeChangeComplete', change_complete)
|
||||||
|
@ -125,7 +125,7 @@ class App(Base):
|
|||||||
self, state: State, event: Event
|
self, state: State, event: Event
|
||||||
) -> asyncio.Task | None: ...
|
) -> asyncio.Task | None: ...
|
||||||
|
|
||||||
async def process(
|
def process(
|
||||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||||
) -> AsyncIterator[StateUpdate]: ...
|
) -> AsyncIterator[StateUpdate]: ...
|
||||||
async def ping() -> str: ...
|
async def ping() -> str: ...
|
||||||
|
@ -40,6 +40,7 @@ class ReflexJinjaEnvironment(Environment):
|
|||||||
"toggle_color_mode": constants.ColorMode.TOGGLE,
|
"toggle_color_mode": constants.ColorMode.TOGGLE,
|
||||||
"use_color_mode": constants.ColorMode.USE,
|
"use_color_mode": constants.ColorMode.USE,
|
||||||
"hydrate": constants.CompileVars.HYDRATE,
|
"hydrate": constants.CompileVars.HYDRATE,
|
||||||
|
"on_load_internal": constants.CompileVars.ON_LOAD_INTERNAL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,6 +48,7 @@ from .route import (
|
|||||||
ROUTE_NOT_FOUND,
|
ROUTE_NOT_FOUND,
|
||||||
ROUTER,
|
ROUTER,
|
||||||
ROUTER_DATA,
|
ROUTER_DATA,
|
||||||
|
ROUTER_DATA_INCLUDE,
|
||||||
DefaultPage,
|
DefaultPage,
|
||||||
Page404,
|
Page404,
|
||||||
RouteArgType,
|
RouteArgType,
|
||||||
@ -97,6 +98,7 @@ __ALL__ = [
|
|||||||
RouteVar,
|
RouteVar,
|
||||||
ROUTER,
|
ROUTER,
|
||||||
ROUTER_DATA,
|
ROUTER_DATA,
|
||||||
|
ROUTER_DATA_INCLUDE,
|
||||||
ROUTE_NOT_FOUND,
|
ROUTE_NOT_FOUND,
|
||||||
SETTER_PREFIX,
|
SETTER_PREFIX,
|
||||||
SKIP_COMPILE_ENV_VAR,
|
SKIP_COMPILE_ENV_VAR,
|
||||||
|
@ -58,6 +58,8 @@ class CompileVars(SimpleNamespace):
|
|||||||
CONNECT_ERROR = "connectError"
|
CONNECT_ERROR = "connectError"
|
||||||
# The name of the function for converting a dict to an event.
|
# The name of the function for converting a dict to an event.
|
||||||
TO_EVENT = "Event"
|
TO_EVENT = "Event"
|
||||||
|
# The name of the internal on_load event.
|
||||||
|
ON_LOAD_INTERNAL = "on_load_internal"
|
||||||
|
|
||||||
|
|
||||||
class PageNames(SimpleNamespace):
|
class PageNames(SimpleNamespace):
|
||||||
|
@ -30,6 +30,10 @@ class RouteVar(SimpleNamespace):
|
|||||||
COOKIE = "cookie"
|
COOKIE = "cookie"
|
||||||
|
|
||||||
|
|
||||||
|
# This subset of router_data is included in chained on_load events.
|
||||||
|
ROUTER_DATA_INCLUDE = set((RouteVar.PATH, RouteVar.ORIGIN, RouteVar.QUERY))
|
||||||
|
|
||||||
|
|
||||||
class RouteRegex(SimpleNamespace):
|
class RouteRegex(SimpleNamespace):
|
||||||
"""Regex used for extracting route args in route."""
|
"""Regex used for extracting route args in route."""
|
||||||
|
|
||||||
|
@ -826,6 +826,10 @@ def fix_events(
|
|||||||
# Fix the events created by the handler.
|
# Fix the events created by the handler.
|
||||||
out = []
|
out = []
|
||||||
for e in events:
|
for e in events:
|
||||||
|
if isinstance(e, Event):
|
||||||
|
# If the event is already an event, append it to the list.
|
||||||
|
out.append(e)
|
||||||
|
continue
|
||||||
if not isinstance(e, (EventHandler, EventSpec)):
|
if not isinstance(e, (EventHandler, EventSpec)):
|
||||||
e = EventHandler(fn=e)
|
e = EventHandler(fn=e)
|
||||||
# Otherwise, create an event from the event spec.
|
# Otherwise, create an event from the event spec.
|
||||||
@ -835,13 +839,19 @@ def fix_events(
|
|||||||
name = format.format_event_handler(e.handler)
|
name = format.format_event_handler(e.handler)
|
||||||
payload = {k._var_name: v._decode() for k, v in e.args} # type: ignore
|
payload = {k._var_name: v._decode() for k, v in e.args} # type: ignore
|
||||||
|
|
||||||
|
# Filter router_data to reduce payload size
|
||||||
|
event_router_data = {
|
||||||
|
k: v
|
||||||
|
for k, v in (router_data or {}).items()
|
||||||
|
if k in constants.route.ROUTER_DATA_INCLUDE
|
||||||
|
}
|
||||||
# Create an event and append it to the list.
|
# Create an event and append it to the list.
|
||||||
out.append(
|
out.append(
|
||||||
Event(
|
Event(
|
||||||
token=token,
|
token=token,
|
||||||
name=name,
|
name=name,
|
||||||
payload=payload,
|
payload=payload,
|
||||||
router_data=router_data or {},
|
router_data=event_router_data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.event import Event, fix_events, get_hydrate_event
|
from reflex.event import Event, get_hydrate_event
|
||||||
from reflex.middleware.middleware import Middleware
|
from reflex.middleware.middleware import Middleware
|
||||||
from reflex.state import BaseState, StateUpdate
|
from reflex.state import BaseState, StateUpdate
|
||||||
from reflex.utils import format
|
from reflex.utils import format
|
||||||
@ -52,11 +52,5 @@ class HydrateMiddleware(Middleware):
|
|||||||
# since a full dict was captured, clean any dirtiness
|
# since a full dict was captured, clean any dirtiness
|
||||||
state._clean()
|
state._clean()
|
||||||
|
|
||||||
# Get the route for on_load events.
|
|
||||||
route = event.router_data.get(constants.RouteVar.PATH, "")
|
|
||||||
# Add the on_load events and set is_hydrated to True.
|
|
||||||
events = [*app.get_load_events(route), type(state).set_is_hydrated(True)] # type: ignore
|
|
||||||
events = fix_events(events, event.token, router_data=event.router_data)
|
|
||||||
|
|
||||||
# Return the state update.
|
# Return the state update.
|
||||||
return StateUpdate(delta=delta, events=events)
|
return StateUpdate(delta=delta, events=[])
|
||||||
|
@ -1016,7 +1016,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _is_valid_type(events: Any) -> bool:
|
def _is_valid_type(events: Any) -> bool:
|
||||||
return isinstance(events, (EventHandler, EventSpec))
|
return isinstance(events, (Event, EventHandler, EventSpec))
|
||||||
|
|
||||||
if events is None or _is_valid_type(events):
|
if events is None or _is_valid_type(events):
|
||||||
return events
|
return events
|
||||||
@ -1313,6 +1313,26 @@ class State(BaseState):
|
|||||||
# The hydrated bool.
|
# The hydrated bool.
|
||||||
is_hydrated: bool = False
|
is_hydrated: bool = False
|
||||||
|
|
||||||
|
def on_load_internal(self) -> list[Event | EventSpec] | None:
|
||||||
|
"""Queue on_load handlers for the current page.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of events to queue for on load handling.
|
||||||
|
"""
|
||||||
|
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
|
||||||
|
load_events = app.get_load_events(self.router.page.path)
|
||||||
|
if not load_events and self.is_hydrated:
|
||||||
|
return # Fast path for page-to-page navigation
|
||||||
|
self.is_hydrated = False
|
||||||
|
return [
|
||||||
|
*fix_events(
|
||||||
|
load_events,
|
||||||
|
self.router.session.client_token,
|
||||||
|
router_data=self.router_data,
|
||||||
|
),
|
||||||
|
type(self).set_is_hydrated(True), # type: ignore
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class StateProxy(wrapt.ObjectProxy):
|
class StateProxy(wrapt.ObjectProxy):
|
||||||
"""Proxy of a state instance to control mutability of vars for a background task.
|
"""Proxy of a state instance to control mutability of vars for a background task.
|
||||||
|
@ -123,9 +123,17 @@ def get_app(reload: bool = False) -> ModuleType:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The app based on the default config.
|
The app based on the default config.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the app name is not set in the config.
|
||||||
"""
|
"""
|
||||||
os.environ[constants.RELOAD_CONFIG] = str(reload)
|
os.environ[constants.RELOAD_CONFIG] = str(reload)
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
if not config.app_name:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot get the app module because `app_name` is not set in rxconfig! "
|
||||||
|
"If this error occurs in a reflex test case, ensure that `get_app` is mocked."
|
||||||
|
)
|
||||||
module = ".".join([config.app_name, config.app_name])
|
module = ".".join([config.app_name, config.app_name])
|
||||||
sys.path.insert(0, os.getcwd())
|
sys.path.insert(0, os.getcwd())
|
||||||
app = __import__(module, fromlist=(constants.CompileVars.APP,))
|
app = __import__(module, fromlist=(constants.CompileVars.APP,))
|
||||||
|
@ -5,11 +5,13 @@ import platform
|
|||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Generator
|
from typing import Dict, Generator
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reflex.app import App
|
from reflex.app import App
|
||||||
from reflex.event import EventSpec
|
from reflex.event import EventSpec
|
||||||
|
from reflex.utils import prerequisites
|
||||||
|
|
||||||
from .states import (
|
from .states import (
|
||||||
DictMutationTestState,
|
DictMutationTestState,
|
||||||
@ -30,6 +32,26 @@ def app() -> App:
|
|||||||
return App()
|
return App()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_module_mock(monkeypatch) -> mock.Mock:
|
||||||
|
"""Mock the app module.
|
||||||
|
|
||||||
|
This overwrites prerequisites.get_app to return the mock for the app module.
|
||||||
|
|
||||||
|
To use this in your test, assign `app_module_mock.app = rx.App(...)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
monkeypatch: pytest monkeypatch fixture.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The mock for the main app module.
|
||||||
|
"""
|
||||||
|
app_module_mock = mock.Mock()
|
||||||
|
get_app_mock = mock.Mock(return_value=app_module_mock)
|
||||||
|
monkeypatch.setattr(prerequisites, "get_app", get_app_mock)
|
||||||
|
return app_module_mock
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def windows_platform() -> Generator:
|
def windows_platform() -> Generator:
|
||||||
"""Check if system is windows.
|
"""Check if system is windows.
|
||||||
|
@ -21,14 +21,4 @@ def create_event(name):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def event1():
|
def event1():
|
||||||
return create_event("test_state.hydrate")
|
return create_event("state.hydrate")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def event2():
|
|
||||||
return create_event("test_state2.hydrate")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def event3():
|
|
||||||
return create_event("test_state3.hydrate")
|
|
||||||
|
@ -1,27 +1,13 @@
|
|||||||
from typing import Any, Dict
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from reflex import constants
|
|
||||||
from reflex.app import App
|
from reflex.app import App
|
||||||
from reflex.constants import CompileVars
|
|
||||||
from reflex.middleware.hydrate_middleware import HydrateMiddleware
|
from reflex.middleware.hydrate_middleware import HydrateMiddleware
|
||||||
from reflex.state import BaseState, StateUpdate
|
from reflex.state import State, StateUpdate
|
||||||
|
|
||||||
|
|
||||||
def exp_is_hydrated(state: BaseState) -> Dict[str, Any]:
|
class TestState(State):
|
||||||
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: the State that is hydrated
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict similar to that returned by `State.get_delta` with IS_HYDRATED: True
|
|
||||||
"""
|
|
||||||
return {state.get_name(): {CompileVars.IS_HYDRATED: True}}
|
|
||||||
|
|
||||||
|
|
||||||
class TestState(BaseState):
|
|
||||||
"""A test state with no return in handler."""
|
"""A test state with no return in handler."""
|
||||||
|
|
||||||
__test__ = False
|
__test__ = False
|
||||||
@ -33,40 +19,6 @@ class TestState(BaseState):
|
|||||||
self.num += 1
|
self.num += 1
|
||||||
|
|
||||||
|
|
||||||
class TestState2(BaseState):
|
|
||||||
"""A test state with return in handler."""
|
|
||||||
|
|
||||||
__test__ = False
|
|
||||||
|
|
||||||
num: int = 0
|
|
||||||
name: str
|
|
||||||
|
|
||||||
def test_handler(self):
|
|
||||||
"""Test handler that calls another handler.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Chain of EventHandlers
|
|
||||||
"""
|
|
||||||
self.num += 1
|
|
||||||
return self.change_name
|
|
||||||
|
|
||||||
def change_name(self):
|
|
||||||
"""Test handler to change name."""
|
|
||||||
self.name = "random"
|
|
||||||
|
|
||||||
|
|
||||||
class TestState3(BaseState):
|
|
||||||
"""A test state with async handler."""
|
|
||||||
|
|
||||||
__test__ = False
|
|
||||||
|
|
||||||
num: int = 0
|
|
||||||
|
|
||||||
async def test_handler(self):
|
|
||||||
"""Test handler."""
|
|
||||||
self.num += 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hydrate_middleware() -> HydrateMiddleware:
|
def hydrate_middleware() -> HydrateMiddleware:
|
||||||
"""Fixture creates an instance of HydrateMiddleware per test case.
|
"""Fixture creates an instance of HydrateMiddleware per test case.
|
||||||
@ -78,98 +30,21 @@ def hydrate_middleware() -> HydrateMiddleware:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
async def test_preprocess_no_events(hydrate_middleware, event1, mocker):
|
||||||
"test_state, expected, event_fixture",
|
|
||||||
[
|
|
||||||
(TestState, {"test_state": {"num": 1}}, "event1"),
|
|
||||||
(TestState2, {"test_state2": {"num": 1}}, "event2"),
|
|
||||||
(TestState3, {"test_state3": {"num": 1}}, "event3"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
async def test_preprocess(
|
|
||||||
test_state, hydrate_middleware, request, event_fixture, expected
|
|
||||||
):
|
|
||||||
"""Test that a state hydrate event is processed correctly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
test_state: State to process event.
|
|
||||||
hydrate_middleware: Instance of HydrateMiddleware.
|
|
||||||
request: Pytest fixture request.
|
|
||||||
event_fixture: The event fixture(an Event).
|
|
||||||
expected: Expected delta.
|
|
||||||
"""
|
|
||||||
test_state.add_var(
|
|
||||||
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
|
|
||||||
)
|
|
||||||
app = App(state=test_state, load_events={"index": [test_state.test_handler]})
|
|
||||||
state = test_state()
|
|
||||||
|
|
||||||
update = await hydrate_middleware.preprocess(
|
|
||||||
app=app, event=request.getfixturevalue(event_fixture), state=state
|
|
||||||
)
|
|
||||||
assert isinstance(update, StateUpdate)
|
|
||||||
assert update.delta == state.dict()
|
|
||||||
events = update.events
|
|
||||||
assert len(events) == 2
|
|
||||||
|
|
||||||
# Apply the on_load event.
|
|
||||||
update = await state._process(events[0]).__anext__()
|
|
||||||
assert update.delta == expected
|
|
||||||
|
|
||||||
# Apply the hydrate event.
|
|
||||||
update = await state._process(events[1]).__anext__()
|
|
||||||
assert update.delta == exp_is_hydrated(state)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preprocess_multiple_load_events(hydrate_middleware, event1):
|
|
||||||
"""Test that a state hydrate event for multiple on-load events is processed correctly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hydrate_middleware: Instance of HydrateMiddleware
|
|
||||||
event1: An Event.
|
|
||||||
"""
|
|
||||||
app = App(
|
|
||||||
state=TestState,
|
|
||||||
load_events={"index": [TestState.test_handler, TestState.test_handler]},
|
|
||||||
)
|
|
||||||
state = TestState()
|
|
||||||
|
|
||||||
update = await hydrate_middleware.preprocess(app=app, event=event1, state=state)
|
|
||||||
assert isinstance(update, StateUpdate)
|
|
||||||
assert update.delta == state.dict()
|
|
||||||
assert len(update.events) == 3
|
|
||||||
|
|
||||||
# Apply the events.
|
|
||||||
events = update.events
|
|
||||||
update = await state._process(events[0]).__anext__()
|
|
||||||
assert update.delta == {"test_state": {"num": 1}}
|
|
||||||
|
|
||||||
update = await state._process(events[1]).__anext__()
|
|
||||||
assert update.delta == {"test_state": {"num": 2}}
|
|
||||||
|
|
||||||
update = await state._process(events[2]).__anext__()
|
|
||||||
assert update.delta == exp_is_hydrated(state)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preprocess_no_events(hydrate_middleware, event1):
|
|
||||||
"""Test that app without on_load is processed correctly.
|
"""Test that app without on_load is processed correctly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hydrate_middleware: Instance of HydrateMiddleware
|
hydrate_middleware: Instance of HydrateMiddleware
|
||||||
event1: An Event.
|
event1: An Event.
|
||||||
|
mocker: pytest mock object.
|
||||||
"""
|
"""
|
||||||
state = TestState()
|
mocker.patch("reflex.state.State.class_subclasses", {TestState})
|
||||||
|
state = State()
|
||||||
update = await hydrate_middleware.preprocess(
|
update = await hydrate_middleware.preprocess(
|
||||||
app=App(state=TestState),
|
app=App(state=State),
|
||||||
event=event1,
|
event=event1,
|
||||||
state=state,
|
state=state,
|
||||||
)
|
)
|
||||||
assert isinstance(update, StateUpdate)
|
assert isinstance(update, StateUpdate)
|
||||||
assert update.delta == state.dict()
|
assert update.delta == state.dict()
|
||||||
assert len(update.events) == 1
|
assert not update.events
|
||||||
assert isinstance(update, StateUpdate)
|
|
||||||
|
|
||||||
update = await state._process(update.events[0]).__anext__()
|
|
||||||
assert update.delta == exp_is_hydrated(state)
|
|
||||||
|
@ -25,10 +25,10 @@ from reflex.app import (
|
|||||||
upload,
|
upload,
|
||||||
)
|
)
|
||||||
from reflex.components import Box, Component, Cond, Fragment, Text
|
from reflex.components import Box, Component, Cond, Fragment, Text
|
||||||
from reflex.event import Event, get_hydrate_event
|
from reflex.event import Event
|
||||||
from reflex.middleware import HydrateMiddleware
|
from reflex.middleware import HydrateMiddleware
|
||||||
from reflex.model import Model
|
from reflex.model import Model
|
||||||
from reflex.state import BaseState, RouterData, State, StateManagerRedis, StateUpdate
|
from reflex.state import BaseState, State, StateManagerRedis, StateUpdate
|
||||||
from reflex.style import Style
|
from reflex.style import Style
|
||||||
from reflex.utils import format
|
from reflex.utils import format
|
||||||
from reflex.vars import ComputedVar
|
from reflex.vars import ComputedVar
|
||||||
@ -870,6 +870,7 @@ class DynamicState(BaseState):
|
|||||||
recalculated when the dynamic route var was dirty
|
recalculated when the dynamic route var was dirty
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
is_hydrated: bool = False
|
||||||
loaded: int = 0
|
loaded: int = 0
|
||||||
counter: int = 0
|
counter: int = 0
|
||||||
|
|
||||||
@ -893,10 +894,16 @@ class DynamicState(BaseState):
|
|||||||
# self.side_effect_counter = self.side_effect_counter + 1
|
# self.side_effect_counter = self.side_effect_counter + 1
|
||||||
return self.dynamic
|
return self.dynamic
|
||||||
|
|
||||||
|
on_load_internal = State.on_load_internal.fn
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_dynamic_route_var_route_change_completed_on_load(
|
async def test_dynamic_route_var_route_change_completed_on_load(
|
||||||
index_page, windows_platform: bool, token: str, mocker
|
index_page,
|
||||||
|
windows_platform: bool,
|
||||||
|
token: str,
|
||||||
|
app_module_mock: unittest.mock.Mock,
|
||||||
|
mocker,
|
||||||
):
|
):
|
||||||
"""Create app with dynamic route var, and simulate navigation.
|
"""Create app with dynamic route var, and simulate navigation.
|
||||||
|
|
||||||
@ -907,17 +914,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
index_page: The index page.
|
index_page: The index page.
|
||||||
windows_platform: Whether the system is windows.
|
windows_platform: Whether the system is windows.
|
||||||
token: a Token.
|
token: a Token.
|
||||||
|
app_module_mock: Mocked app module.
|
||||||
mocker: pytest mocker object.
|
mocker: pytest mocker object.
|
||||||
"""
|
"""
|
||||||
mocker.patch("reflex.state.State.class_subclasses", {DynamicState})
|
|
||||||
DynamicState.add_var(
|
|
||||||
constants.CompileVars.IS_HYDRATED, type_=bool, default_value=False
|
|
||||||
)
|
|
||||||
arg_name = "dynamic"
|
arg_name = "dynamic"
|
||||||
route = f"/test/[{arg_name}]"
|
route = f"/test/[{arg_name}]"
|
||||||
if windows_platform:
|
if windows_platform:
|
||||||
route.lstrip("/").replace("/", "\\")
|
route.lstrip("/").replace("/", "\\")
|
||||||
app = App(state=DynamicState)
|
app = app_module_mock.app = App(state=DynamicState)
|
||||||
assert arg_name not in app.state.vars
|
assert arg_name not in app.state.vars
|
||||||
app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore
|
app.add_page(index_page, route=route, on_load=DynamicState.on_load) # type: ignore
|
||||||
assert arg_name in app.state.vars
|
assert arg_name in app.state.vars
|
||||||
@ -953,33 +957,25 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
|
|
||||||
prev_exp_val = ""
|
prev_exp_val = ""
|
||||||
for exp_index, exp_val in enumerate(exp_vals):
|
for exp_index, exp_val in enumerate(exp_vals):
|
||||||
hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
|
on_load_internal = _event(
|
||||||
exp_router_data = {
|
name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL}",
|
||||||
"headers": {},
|
val=exp_val,
|
||||||
"ip": client_ip,
|
)
|
||||||
"sid": sid,
|
|
||||||
"token": token,
|
|
||||||
**hydrate_event.router_data,
|
|
||||||
}
|
|
||||||
exp_router = RouterData(exp_router_data)
|
|
||||||
process_coro = process(
|
process_coro = process(
|
||||||
app,
|
app,
|
||||||
event=hydrate_event,
|
event=on_load_internal,
|
||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
)
|
||||||
update = await process_coro.__anext__() # type: ignore
|
update = await process_coro.__anext__()
|
||||||
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
# route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)]
|
||||||
assert update == StateUpdate(
|
assert update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
arg_name: exp_val,
|
arg_name: exp_val,
|
||||||
f"comp_{arg_name}": exp_val,
|
f"comp_{arg_name}": exp_val,
|
||||||
constants.CompileVars.IS_HYDRATED: False,
|
constants.CompileVars.IS_HYDRATED: False,
|
||||||
"loaded": exp_index,
|
|
||||||
"counter": exp_index,
|
|
||||||
"router": exp_router,
|
|
||||||
# "side_effect_counter": exp_index,
|
# "side_effect_counter": exp_index,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -987,13 +983,12 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
_dynamic_state_event(
|
_dynamic_state_event(
|
||||||
name="on_load",
|
name="on_load",
|
||||||
val=exp_val,
|
val=exp_val,
|
||||||
router_data=exp_router_data,
|
|
||||||
),
|
),
|
||||||
_dynamic_state_event(
|
_dynamic_state_event(
|
||||||
name="set_is_hydrated",
|
name="set_is_hydrated",
|
||||||
payload={"value": True},
|
payload={"value": True},
|
||||||
val=exp_val,
|
val=exp_val,
|
||||||
router_data=exp_router_data,
|
router_data={},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1004,7 +999,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
|
|
||||||
# complete the processing
|
# complete the processing
|
||||||
with pytest.raises(StopAsyncIteration):
|
with pytest.raises(StopAsyncIteration):
|
||||||
await process_coro.__anext__() # type: ignore
|
await process_coro.__anext__()
|
||||||
|
|
||||||
# check that router data was written to the state_manager store
|
# check that router data was written to the state_manager store
|
||||||
state = await app.state_manager.get_state(token)
|
state = await app.state_manager.get_state(token)
|
||||||
@ -1017,7 +1012,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
)
|
||||||
on_load_update = await process_coro.__anext__() # type: ignore
|
on_load_update = await process_coro.__anext__()
|
||||||
assert on_load_update == StateUpdate(
|
assert on_load_update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -1031,7 +1026,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
)
|
)
|
||||||
# complete the processing
|
# complete the processing
|
||||||
with pytest.raises(StopAsyncIteration):
|
with pytest.raises(StopAsyncIteration):
|
||||||
await process_coro.__anext__() # type: ignore
|
await process_coro.__anext__()
|
||||||
process_coro = process(
|
process_coro = process(
|
||||||
app,
|
app,
|
||||||
event=_dynamic_state_event(
|
event=_dynamic_state_event(
|
||||||
@ -1041,7 +1036,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
)
|
||||||
on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore
|
on_set_is_hydrated_update = await process_coro.__anext__()
|
||||||
assert on_set_is_hydrated_update == StateUpdate(
|
assert on_set_is_hydrated_update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -1055,7 +1050,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
)
|
)
|
||||||
# complete the processing
|
# complete the processing
|
||||||
with pytest.raises(StopAsyncIteration):
|
with pytest.raises(StopAsyncIteration):
|
||||||
await process_coro.__anext__() # type: ignore
|
await process_coro.__anext__()
|
||||||
|
|
||||||
# a simple state update event should NOT trigger on_load or route var side effects
|
# a simple state update event should NOT trigger on_load or route var side effects
|
||||||
process_coro = process(
|
process_coro = process(
|
||||||
@ -1065,7 +1060,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
)
|
)
|
||||||
update = await process_coro.__anext__() # type: ignore
|
update = await process_coro.__anext__()
|
||||||
assert update == StateUpdate(
|
assert update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -1079,7 +1074,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
)
|
)
|
||||||
# complete the processing
|
# complete the processing
|
||||||
with pytest.raises(StopAsyncIteration):
|
with pytest.raises(StopAsyncIteration):
|
||||||
await process_coro.__anext__() # type: ignore
|
await process_coro.__anext__()
|
||||||
|
|
||||||
prev_exp_val = exp_val
|
prev_exp_val = exp_val
|
||||||
state = await app.state_manager.get_state(token)
|
state = await app.state_manager.get_state(token)
|
||||||
@ -1116,7 +1111,7 @@ async def test_process_events(mocker, token: str):
|
|||||||
token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
|
token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
|
||||||
)
|
)
|
||||||
|
|
||||||
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore
|
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert (await app.state_manager.get_state(token)).value == 5
|
assert (await app.state_manager.get_state(token)).value == 5
|
||||||
|
@ -7,7 +7,7 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, Generator, List, Optional, Union
|
from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -24,6 +24,7 @@ from reflex.state import (
|
|||||||
LockExpiredError,
|
LockExpiredError,
|
||||||
MutableProxy,
|
MutableProxy,
|
||||||
RouterData,
|
RouterData,
|
||||||
|
State,
|
||||||
StateManager,
|
StateManager,
|
||||||
StateManagerMemory,
|
StateManagerMemory,
|
||||||
StateManagerRedis,
|
StateManagerRedis,
|
||||||
@ -1374,8 +1375,13 @@ def test_error_on_state_method_shadow():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_state_with_invalid_yield():
|
@pytest.mark.asyncio
|
||||||
"""Test that an error is thrown when a state yields an invalid value."""
|
async def test_state_with_invalid_yield(capsys):
|
||||||
|
"""Test that an error is thrown when a state yields an invalid value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
capsys: Pytest fixture for capture standard streams.
|
||||||
|
"""
|
||||||
|
|
||||||
class StateWithInvalidYield(BaseState):
|
class StateWithInvalidYield(BaseState):
|
||||||
"""A state that yields an invalid value."""
|
"""A state that yields an invalid value."""
|
||||||
@ -1389,15 +1395,16 @@ def test_state_with_invalid_yield():
|
|||||||
yield 1
|
yield 1
|
||||||
|
|
||||||
invalid_state = StateWithInvalidYield()
|
invalid_state = StateWithInvalidYield()
|
||||||
with pytest.raises(TypeError) as err:
|
async for update in invalid_state._process(
|
||||||
invalid_state._check_valid(
|
rx.event.Event(token="fake_token", name="invalid_handler")
|
||||||
invalid_state.event_handlers["invalid_handler"],
|
):
|
||||||
rx.event.Event(token="fake_token", name="invalid_handler"),
|
assert not update.delta
|
||||||
|
assert update.events == rx.event.fix_events(
|
||||||
|
[rx.window_alert("An error occurred. See logs for details.")],
|
||||||
|
token="",
|
||||||
)
|
)
|
||||||
assert (
|
captured = capsys.readouterr()
|
||||||
"must only return/yield: None, Events or other EventHandlers"
|
assert "must only return/yield: None, Events or other EventHandlers" in captured.out
|
||||||
in err.value.args[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", params=["in_process", "redis"])
|
@pytest.fixture(scope="function", params=["in_process", "redis"])
|
||||||
@ -2303,3 +2310,150 @@ def test_state_union_optional():
|
|||||||
assert UnionState.custom_union.c2r is not None # type: ignore
|
assert UnionState.custom_union.c2r is not None # type: ignore
|
||||||
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
|
assert types.is_optional(UnionState.opt_int._var_type) # type: ignore
|
||||||
assert types.is_union(UnionState.int_float._var_type) # type: ignore
|
assert types.is_union(UnionState.int_float._var_type) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]:
|
||||||
|
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: the State that is hydrated.
|
||||||
|
is_hydrated: whether the state is hydrated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict similar to that returned by `State.get_delta` with IS_HYDRATED: is_hydrated
|
||||||
|
"""
|
||||||
|
return {state.get_full_name(): {CompileVars.IS_HYDRATED: is_hydrated}}
|
||||||
|
|
||||||
|
|
||||||
|
class OnLoadState(State):
|
||||||
|
"""A test state with no return in handler."""
|
||||||
|
|
||||||
|
num: int = 0
|
||||||
|
|
||||||
|
def test_handler(self):
|
||||||
|
"""Test handler."""
|
||||||
|
self.num += 1
|
||||||
|
|
||||||
|
|
||||||
|
class OnLoadState2(State):
|
||||||
|
"""A test state with return in handler."""
|
||||||
|
|
||||||
|
num: int = 0
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def test_handler(self):
|
||||||
|
"""Test handler that calls another handler.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Chain of EventHandlers
|
||||||
|
"""
|
||||||
|
self.num += 1
|
||||||
|
return self.change_name
|
||||||
|
|
||||||
|
def change_name(self):
|
||||||
|
"""Test handler to change name."""
|
||||||
|
self.name = "random"
|
||||||
|
|
||||||
|
|
||||||
|
class OnLoadState3(State):
|
||||||
|
"""A test state with async handler."""
|
||||||
|
|
||||||
|
num: int = 0
|
||||||
|
|
||||||
|
async def test_handler(self):
|
||||||
|
"""Test handler."""
|
||||||
|
self.num += 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_state, expected",
|
||||||
|
[
|
||||||
|
(OnLoadState, {"on_load_state": {"num": 1}}),
|
||||||
|
(OnLoadState2, {"on_load_state2": {"num": 1}}),
|
||||||
|
(OnLoadState3, {"on_load_state3": {"num": 1}}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
|
||||||
|
"""Test that a state hydrate event is processed correctly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_module_mock: The app module that will be returned by get_app().
|
||||||
|
token: A token.
|
||||||
|
test_state: State to process event.
|
||||||
|
expected: Expected delta.
|
||||||
|
mocker: pytest mock object.
|
||||||
|
"""
|
||||||
|
mocker.patch("reflex.state.State.class_subclasses", {test_state})
|
||||||
|
app = app_module_mock.app = App(
|
||||||
|
state=State, load_events={"index": [test_state.test_handler]}
|
||||||
|
)
|
||||||
|
state = State()
|
||||||
|
|
||||||
|
updates = []
|
||||||
|
async for update in rx.app.process(
|
||||||
|
app=app,
|
||||||
|
event=Event(
|
||||||
|
token=token,
|
||||||
|
name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}",
|
||||||
|
router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
|
||||||
|
),
|
||||||
|
sid="sid",
|
||||||
|
headers={},
|
||||||
|
client_ip="",
|
||||||
|
):
|
||||||
|
assert isinstance(update, StateUpdate)
|
||||||
|
updates.append(update)
|
||||||
|
assert len(updates) == 1
|
||||||
|
assert updates[0].delta == exp_is_hydrated(state, False)
|
||||||
|
|
||||||
|
events = updates[0].events
|
||||||
|
assert len(events) == 2
|
||||||
|
assert (await state._process(events[0]).__anext__()).delta == {
|
||||||
|
test_state.get_full_name(): {"num": 1}
|
||||||
|
}
|
||||||
|
assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
|
||||||
|
"""Test that a state hydrate event for multiple on-load events is processed correctly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_module_mock: The app module that will be returned by get_app().
|
||||||
|
token: A token.
|
||||||
|
mocker: pytest mock object.
|
||||||
|
"""
|
||||||
|
mocker.patch("reflex.state.State.class_subclasses", {OnLoadState})
|
||||||
|
app = app_module_mock.app = App(
|
||||||
|
state=State,
|
||||||
|
load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
|
||||||
|
)
|
||||||
|
state = State()
|
||||||
|
|
||||||
|
updates = []
|
||||||
|
async for update in rx.app.process(
|
||||||
|
app=app,
|
||||||
|
event=Event(
|
||||||
|
token=token,
|
||||||
|
name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}",
|
||||||
|
router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}},
|
||||||
|
),
|
||||||
|
sid="sid",
|
||||||
|
headers={},
|
||||||
|
client_ip="",
|
||||||
|
):
|
||||||
|
assert isinstance(update, StateUpdate)
|
||||||
|
updates.append(update)
|
||||||
|
assert len(updates) == 1
|
||||||
|
assert updates[0].delta == exp_is_hydrated(state, False)
|
||||||
|
|
||||||
|
events = updates[0].events
|
||||||
|
assert len(events) == 3
|
||||||
|
assert (await state._process(events[0]).__anext__()).delta == {
|
||||||
|
OnLoadState.get_full_name(): {"num": 1}
|
||||||
|
}
|
||||||
|
assert (await state._process(events[1]).__anext__()).delta == {
|
||||||
|
OnLoadState.get_full_name(): {"num": 2}
|
||||||
|
}
|
||||||
|
assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
|
||||||
|
Loading…
Reference in New Issue
Block a user