Compare commits

...

12 Commits
main ... v0.6.6

Author SHA1 Message Date
Masen Furer
0c8192222f
rx.upload must include _var_data from props (#4463)
* rx.upload must include _var_data from props

str-casting the dropzone arguments removed any VarData they depended on, like
the state context.

update test_upload to include passing a prop from a state var

* Handle large payload delta from upload event handler

Fix update chunk chaining logic; try/catch wasn't catching errors from the
async inner function.
2024-12-02 20:57:41 -08:00
Masen Furer
9c9602363c
pyproject.toml: bump to 0.6.6 (final) 2024-12-02 20:55:13 -08:00
Masen Furer
27bad2575e
bump to 0.6.6a3 2024-12-02 10:24:37 -08:00
Thomas Brandého
3a7497c852
enable css props via wrapperStyle for recharts components (#4447) 2024-12-02 09:24:10 -08:00
Masen Furer
77594bd0ea
[HOS-333] Send a "reload" message to the frontend after state expiry (#4442)
* Unit test updates

* test_client_storage: simulate backend state expiry

* [HOS-333] Send a "reload" message to the frontend after state expiry

1. a state instance expires on the backing store
2. frontend attempts to process an event against the expired token and gets a
   fresh instance of the state without router_data set
3. backend sends a "reload" message on the websocket containing the event and
   immediately stops processing
4. in response to the "reload" message, frontend sends
   [hydrate, update client storage, on_load, <previous_event>]

This allows the frontend and backend to re-syncronize on the state of the app
before continuing to process regular events.

If the event in (2) is a special hydrate event, then it is processed normally
by the middleware and the "reload" logic is skipped since this indicates an
initial load or a browser refresh.

* unit tests working with redis
2024-12-02 09:24:09 -08:00
Elijah Ahianyo
e4ccba7aee
Remove invitation code logic from reflex logoutv2 (#4433)
* what happened there?

* we should do this for v2 instead
2024-11-25 12:59:32 -08:00
Masen Furer
09b2d92466
bump to 0.6.6a2 2024-11-25 11:32:12 -08:00
Thomas Brandého
cb087acbeb
follow up to #4426 (#4436) 2024-11-25 11:29:08 -08:00
Thomas Brandého
7129bfb513
allow for 'go.Figure | None' annotation in State (#4426) 2024-11-25 10:46:30 -08:00
Thomas Brandého
b41b1f364a
fix mutable default in EventNamespace (#4420) 2024-11-25 10:42:40 -08:00
Masen Furer
9ebf16c140
[ENG-4137] Handle generic alias passing inspect.isclass check (#4427)
On py3.9 and py3.10, `dict[str, str]` and other typing forms are kinda
considered classes, but they still fail when doing `issubclass`, so
specifically exclude generic aliases before calling issubclass.

Fix #4424

Bonus fix: support upcasting of pydantic v1 and v2 models
2024-11-25 10:42:39 -08:00
Masen Furer
6d0fae36e6
pyproject.toml: bump to 0.6.6a1 2024-11-21 19:37:07 -08:00
13 changed files with 335 additions and 54 deletions

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "reflex" name = "reflex"
version = "0.6.6dev1" version = "0.6.6"
description = "Web apps in pure Python." description = "Web apps in pure Python."
license = "Apache-2.0" license = "Apache-2.0"
authors = [ authors = [

View File

@ -454,6 +454,10 @@ export const connect = async (
queueEvents(update.events, socket); queueEvents(update.events, socket);
} }
}); });
socket.current.on("reload", async (event) => {
event_processing = false;
queueEvents([...initialEvents(), JSON5.parse(event)], socket);
});
document.addEventListener("visibilitychange", checkVisibility); document.addEventListener("visibilitychange", checkVisibility);
}; };
@ -486,23 +490,30 @@ export const uploadFiles = async (
return false; return false;
} }
// Track how many partial updates have been processed for this upload.
let resp_idx = 0; let resp_idx = 0;
const eventHandler = (progressEvent) => { const eventHandler = (progressEvent) => {
// handle any delta / event streamed from the upload event handler const event_callbacks = socket._callbacks.$event;
// Whenever called, responseText will contain the entire response so far.
const chunks = progressEvent.event.target.responseText.trim().split("\n"); const chunks = progressEvent.event.target.responseText.trim().split("\n");
// So only process _new_ chunks beyond resp_idx.
chunks.slice(resp_idx).map((chunk) => { chunks.slice(resp_idx).map((chunk) => {
try { event_callbacks.map((f, ix) => {
socket._callbacks.$event.map((f) => { f(chunk)
f(chunk); .then(() => {
}); if (ix === event_callbacks.length - 1) {
resp_idx += 1; // Mark this chunk as processed.
} catch (e) { resp_idx += 1;
if (progressEvent.progress === 1) { }
// Chunk may be incomplete, so only report errors when full response is available. })
console.log("Error parsing chunk", chunk, e); .catch((e) => {
} if (progressEvent.progress === 1) {
return; // Chunk may be incomplete, so only report errors when full response is available.
} console.log("Error parsing chunk", chunk, e);
}
return;
});
});
}); });
}; };
@ -707,7 +718,7 @@ export const useEventLoop = (
const combined_name = events.map((e) => e.name).join("+++"); const combined_name = events.map((e) => e.name).join("+++");
if (event_actions?.temporal) { if (event_actions?.temporal) {
if (!socket.current || !socket.current.connected) { if (!socket.current || !socket.current.connected) {
return; // don't queue when the backend is not connected return; // don't queue when the backend is not connected
} }
} }
if (event_actions?.throttle) { if (event_actions?.throttle) {
@ -848,7 +859,7 @@ export const useEventLoop = (
if (router.components[router.pathname].error) { if (router.components[router.pathname].error) {
delete router.components[router.pathname].error; delete router.components[router.pathname].error;
} }
} };
router.events.on("routeChangeStart", change_start); router.events.on("routeChangeStart", change_start);
router.events.on("routeChangeComplete", change_complete); router.events.on("routeChangeComplete", change_complete);
router.events.on("routeChangeError", change_error); router.events.on("routeChangeError", change_error);

View File

@ -73,6 +73,7 @@ from reflex.event import (
EventSpec, EventSpec,
EventType, EventType,
IndividualEventType, IndividualEventType,
get_hydrate_event,
window_alert, window_alert,
) )
from reflex.model import Model, get_db_status from reflex.model import Model, get_db_status
@ -1259,6 +1260,21 @@ async def process(
) )
# Get the state for the session exclusively. # Get the state for the session exclusively.
async with app.state_manager.modify_state(event.substate_token) as state: async with app.state_manager.modify_state(event.substate_token) as state:
# When this is a brand new instance of the state, signal the
# frontend to reload before processing it.
if (
not state.router_data
and event.name != get_hydrate_event(state)
and app.event_namespace is not None
):
await asyncio.create_task(
app.event_namespace.emit(
"reload",
data=format.json_dumps(event),
to=sid,
)
)
return
# re-assign only when the value is different # re-assign only when the value is different
if state.router_data != router_data: if state.router_data != router_data:
# assignment will recurse into substates and force recalculation of # assignment will recurse into substates and force recalculation of
@ -1462,10 +1478,10 @@ class EventNamespace(AsyncNamespace):
app: App app: App
# Keep a mapping between socket ID and client token. # Keep a mapping between socket ID and client token.
token_to_sid: dict[str, str] = {} token_to_sid: dict[str, str]
# Keep a mapping between client token and socket ID. # Keep a mapping between client token and socket ID.
sid_to_token: dict[str, str] = {} sid_to_token: dict[str, str]
def __init__(self, namespace: str, app: App): def __init__(self, namespace: str, app: App):
"""Initialize the event namespace. """Initialize the event namespace.
@ -1475,6 +1491,8 @@ class EventNamespace(AsyncNamespace):
app: The application object. app: The application object.
""" """
super().__init__(namespace) super().__init__(namespace)
self.token_to_sid = {}
self.sid_to_token = {}
self.app = app self.app = app
def on_connect(self, sid, environ): def on_connect(self, sid, environ):

View File

@ -293,13 +293,15 @@ class Upload(MemoizationLeaf):
format.to_camel_case(key): value for key, value in upload_props.items() format.to_camel_case(key): value for key, value in upload_props.items()
} }
use_dropzone_arguments = { use_dropzone_arguments = Var.create(
"onDrop": event_var, {
**upload_props, "onDrop": event_var,
} **upload_props,
}
)
left_side = f"const {{getRootProps: {root_props_unique_name}, getInputProps: {input_props_unique_name}}} " left_side = f"const {{getRootProps: {root_props_unique_name}, getInputProps: {input_props_unique_name}}} "
right_side = f"useDropzone({str(Var.create(use_dropzone_arguments))})" right_side = f"useDropzone({str(use_dropzone_arguments)})"
var_data = VarData.merge( var_data = VarData.merge(
VarData( VarData(
@ -307,6 +309,7 @@ class Upload(MemoizationLeaf):
hooks={Hooks.EVENTS: None}, hooks={Hooks.EVENTS: None},
), ),
event_var._get_all_var_data(), event_var._get_all_var_data(),
use_dropzone_arguments._get_all_var_data(),
VarData( VarData(
hooks={ hooks={
callback_str: None, callback_str: None,

View File

@ -255,7 +255,7 @@ const extractPoints = (points) => {
def _render(self): def _render(self):
tag = super()._render() tag = super()._render()
figure = self.data.to(dict) figure = self.data.to(dict) if self.data is not None else Var.create({})
merge_dicts = [] # Data will be merged and spread from these dict Vars merge_dicts = [] # Data will be merged and spread from these dict Vars
if self.layout is not None: if self.layout is not None:
# Why is this not a literal dict? Great question... it didn't work # Why is this not a literal dict? Great question... it didn't work

View File

@ -3,7 +3,6 @@
from typing import Dict, Literal from typing import Dict, Literal
from reflex.components.component import Component, MemoizationLeaf, NoSSRComponent from reflex.components.component import Component, MemoizationLeaf, NoSSRComponent
from reflex.utils import console
class Recharts(Component): class Recharts(Component):
@ -11,19 +10,8 @@ class Recharts(Component):
library = "recharts@2.13.0" library = "recharts@2.13.0"
def render(self) -> Dict: def _get_style(self) -> Dict:
"""Render the tag. return {"wrapperStyle": self.style}
Returns:
The rendered tag.
"""
tag = super().render()
if any(p.startswith("css") for p in tag["props"]):
console.warn(
f"CSS props do not work for {self.__class__.__name__}. Consult docs to style it with its own prop."
)
tag["props"] = [p for p in tag["props"] if not p.startswith("css")]
return tag
class RechartsCharts(NoSSRComponent, MemoizationLeaf): class RechartsCharts(NoSSRComponent, MemoizationLeaf):

View File

@ -11,7 +11,6 @@ from reflex.style import Style
from reflex.vars.base import Var from reflex.vars.base import Var
class Recharts(Component): class Recharts(Component):
def render(self) -> Dict: ...
@overload @overload
@classmethod @classmethod
def create( # type: ignore def create( # type: ignore

View File

@ -404,7 +404,7 @@ def logoutv2(
hosting.log_out_on_browser() hosting.log_out_on_browser()
console.debug("Deleting access token from config locally") console.debug("Deleting access token from config locally")
hosting.delete_token_from_config(include_invitation_code=True) hosting.delete_token_from_config()
db_cli = typer.Typer() db_cli = typer.Typer()

View File

@ -1748,7 +1748,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if value is None: if value is None:
continue continue
hinted_args = value_inside_optional(hinted_args) hinted_args = value_inside_optional(hinted_args)
if isinstance(value, dict) and inspect.isclass(hinted_args): if (
isinstance(value, dict)
and inspect.isclass(hinted_args)
and not types.is_generic_alias(hinted_args) # py3.9-py3.10
):
if issubclass(hinted_args, Model): if issubclass(hinted_args, Model):
# Remove non-fields from the payload # Remove non-fields from the payload
payload[arg] = hinted_args( payload[arg] = hinted_args(
@ -1759,7 +1763,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
} }
) )
elif dataclasses.is_dataclass(hinted_args) or issubclass( elif dataclasses.is_dataclass(hinted_args) or issubclass(
hinted_args, Base hinted_args, (Base, BaseModelV1, BaseModelV2)
): ):
payload[arg] = hinted_args(**value) payload[arg] = hinted_args(**value)
if isinstance(value, list) and (hinted_args is set or hinted_args is Set): if isinstance(value, list) and (hinted_args is set or hinted_args is Set):
@ -1955,6 +1959,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if var in self.base_vars or var in self._backend_vars: if var in self.base_vars or var in self._backend_vars:
self._was_touched = True self._was_touched = True
break break
if var == constants.ROUTER_DATA and self.parent_state is None:
self._was_touched = True
break
def _get_was_touched(self) -> bool: def _get_was_touched(self) -> bool:
"""Check current dirty_vars and flag to determine if state instance was modified. """Check current dirty_vars and flag to determine if state instance was modified.

View File

@ -10,6 +10,13 @@ from selenium.webdriver import Firefox
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webdriver import WebDriver from selenium.webdriver.remote.webdriver import WebDriver
from reflex.state import (
State,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
_substate_key,
)
from reflex.testing import AppHarness from reflex.testing import AppHarness
from . import utils from . import utils
@ -74,7 +81,7 @@ def ClientSide():
return rx.fragment( return rx.fragment(
rx.input( rx.input(
value=ClientSideState.router.session.client_token, value=ClientSideState.router.session.client_token,
is_read_only=True, read_only=True,
id="token", id="token",
), ),
rx.input( rx.input(
@ -604,6 +611,110 @@ async def test_client_side_state(
assert s2.text == "s2 value" assert s2.text == "s2 value"
assert s3.text == "s3 value" assert s3.text == "s3 value"
# Simulate state expiration
if isinstance(client_side.state_manager, StateManagerRedis):
await client_side.state_manager.redis.delete(
_substate_key(token, State.get_full_name())
)
await client_side.state_manager.redis.delete(_substate_key(token, state_name))
await client_side.state_manager.redis.delete(
_substate_key(token, sub_state_name)
)
await client_side.state_manager.redis.delete(
_substate_key(token, sub_sub_state_name)
)
elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
del client_side.state_manager.states[token]
if isinstance(client_side.state_manager, StateManagerDisk):
client_side.state_manager.token_expiration = 0
client_side.state_manager._purge_expired_states()
# Ensure the state is gone (not hydrated)
async def poll_for_not_hydrated():
state = await client_side.get_state(_substate_key(token or "", state_name))
return not state.is_hydrated
assert await AppHarness._poll_for_async(poll_for_not_hydrated)
# Trigger event to get a new instance of the state since the old was expired.
state_var_input = driver.find_element(By.ID, "state_var")
state_var_input.send_keys("re-triggering")
# get new references to all cookie and local storage elements (again)
c1 = driver.find_element(By.ID, "c1")
c2 = driver.find_element(By.ID, "c2")
c3 = driver.find_element(By.ID, "c3")
c4 = driver.find_element(By.ID, "c4")
c5 = driver.find_element(By.ID, "c5")
c6 = driver.find_element(By.ID, "c6")
c7 = driver.find_element(By.ID, "c7")
l1 = driver.find_element(By.ID, "l1")
l2 = driver.find_element(By.ID, "l2")
l3 = driver.find_element(By.ID, "l3")
l4 = driver.find_element(By.ID, "l4")
s1 = driver.find_element(By.ID, "s1")
s2 = driver.find_element(By.ID, "s2")
s3 = driver.find_element(By.ID, "s3")
c1s = driver.find_element(By.ID, "c1s")
l1s = driver.find_element(By.ID, "l1s")
s1s = driver.find_element(By.ID, "s1s")
assert c1.text == "c1 value"
assert c2.text == "c2 value"
assert c3.text == "" # temporary cookie expired after reset state!
assert c4.text == "c4 value"
assert c5.text == "c5 value"
assert c6.text == "c6 value"
assert c7.text == "c7 value"
assert l1.text == "l1 value"
assert l2.text == "l2 value"
assert l3.text == "l3 value"
assert l4.text == "l4 value"
assert s1.text == "s1 value"
assert s2.text == "s2 value"
assert s3.text == "s3 value"
assert c1s.text == "c1s value"
assert l1s.text == "l1s value"
assert s1s.text == "s1s value"
# Get the backend state and ensure the values are still set
async def get_sub_state():
root_state = await client_side.get_state(
_substate_key(token or "", sub_state_name)
)
state = root_state.substates[client_side.get_state_name("_client_side_state")]
sub_state = state.substates[
client_side.get_state_name("_client_side_sub_state")
]
return sub_state
async def poll_for_c1_set():
sub_state = await get_sub_state()
return sub_state.c1 == "c1 value"
assert await AppHarness._poll_for_async(poll_for_c1_set)
sub_state = await get_sub_state()
assert sub_state.c1 == "c1 value"
assert sub_state.c2 == "c2 value"
assert sub_state.c3 == ""
assert sub_state.c4 == "c4 value"
assert sub_state.c5 == "c5 value"
assert sub_state.c6 == "c6 value"
assert sub_state.c7 == "c7 value"
assert sub_state.l1 == "l1 value"
assert sub_state.l2 == "l2 value"
assert sub_state.l3 == "l3 value"
assert sub_state.l4 == "l4 value"
assert sub_state.s1 == "s1 value"
assert sub_state.s2 == "s2 value"
assert sub_state.s3 == "s3 value"
sub_sub_state = sub_state.substates[
client_side.get_state_name("_client_side_sub_sub_state")
]
assert sub_sub_state.c1s == "c1s value"
assert sub_sub_state.l1s == "l1s value"
assert sub_sub_state.s1s == "s1s value"
# clear the cookie jar and local storage, ensure state reset to default # clear the cookie jar and local storage, ensure state reset to default
driver.delete_all_cookies() driver.delete_all_cookies()
local_storage.clear() local_storage.clear()

View File

@ -19,10 +19,14 @@ def UploadFile():
import reflex as rx import reflex as rx
LARGE_DATA = "DUMMY" * 1024 * 512
class UploadState(rx.State): class UploadState(rx.State):
_file_data: Dict[str, str] = {} _file_data: Dict[str, str] = {}
event_order: List[str] = [] event_order: List[str] = []
progress_dicts: List[dict] = [] progress_dicts: List[dict] = []
disabled: bool = False
large_data: str = ""
async def handle_upload(self, files: List[rx.UploadFile]): async def handle_upload(self, files: List[rx.UploadFile]):
for file in files: for file in files:
@ -33,6 +37,7 @@ def UploadFile():
for file in files: for file in files:
upload_data = await file.read() upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8") self._file_data[file.filename or ""] = upload_data.decode("utf-8")
self.large_data = LARGE_DATA
yield UploadState.chain_event yield UploadState.chain_event
def upload_progress(self, progress): def upload_progress(self, progress):
@ -41,13 +46,15 @@ def UploadFile():
self.progress_dicts.append(progress) self.progress_dicts.append(progress)
def chain_event(self): def chain_event(self):
assert self.large_data == LARGE_DATA
self.large_data = ""
self.event_order.append("chain_event") self.event_order.append("chain_event")
def index(): def index():
return rx.vstack( return rx.vstack(
rx.input( rx.input(
value=UploadState.router.session.client_token, value=UploadState.router.session.client_token,
is_read_only=True, read_only=True,
id="token", id="token",
), ),
rx.heading("Default Upload"), rx.heading("Default Upload"),
@ -56,6 +63,7 @@ def UploadFile():
rx.button("Select File"), rx.button("Select File"),
rx.text("Drag and drop files here or click to select files"), rx.text("Drag and drop files here or click to select files"),
), ),
disabled=UploadState.disabled,
), ),
rx.button( rx.button(
"Upload", "Upload",

View File

@ -1007,8 +1007,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
substate_token = _substate_key(token, DynamicState) substate_token = _substate_key(token, DynamicState)
sid = "mock_sid" sid = "mock_sid"
client_ip = "127.0.0.1" client_ip = "127.0.0.1"
state = await app.state_manager.get_state(substate_token) async with app.state_manager.modify_state(substate_token) as state:
assert state.dynamic == "" state.router_data = {"simulate": "hydrated"}
assert state.dynamic == ""
exp_vals = ["foo", "foobar", "baz"] exp_vals = ["foo", "foobar", "baz"]
def _event(name, val, **kwargs): def _event(name, val, **kwargs):
@ -1180,6 +1181,7 @@ async def test_process_events(mocker, token: str):
"ip": "127.0.0.1", "ip": "127.0.0.1",
} }
app = App(state=GenState) app = App(state=GenState)
mocker.patch.object(app, "_postprocess", AsyncMock()) mocker.patch.object(app, "_postprocess", AsyncMock())
event = Event( event = Event(
token=token, token=token,
@ -1187,6 +1189,8 @@ async def test_process_events(mocker, token: str):
payload={"c": 5}, payload={"c": 5},
router_data=router_data, router_data=router_data,
) )
async with app.state_manager.modify_state(event.substate_token) as state:
state.router_data = {"simulate": "hydrated"}
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
pass pass

View File

@ -10,7 +10,17 @@ import os
import sys import sys
import threading import threading
from textwrap import dedent from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
)
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
import pytest import pytest
@ -1828,12 +1838,11 @@ async def test_state_manager_lock_expire_contend(
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: def mock_app_simple(monkeypatch) -> rx.App:
"""Mock app fixture. """Simple Mock app fixture.
Args: Args:
monkeypatch: Pytest monkeypatch object. monkeypatch: Pytest monkeypatch object.
state_manager: A state manager.
Returns: Returns:
The app, after mocking out prerequisites.get_app() The app, after mocking out prerequisites.get_app()
@ -1844,7 +1853,6 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
setattr(app_module, CompileVars.APP, app) setattr(app_module, CompileVars.APP, app)
app.state = TestState app.state = TestState
app._state_manager = state_manager
app.event_namespace.emit = AsyncMock() # type: ignore app.event_namespace.emit = AsyncMock() # type: ignore
def _mock_get_app(*args, **kwargs): def _mock_get_app(*args, **kwargs):
@ -1854,6 +1862,21 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
return app return app
@pytest.fixture(scope="function")
def mock_app(mock_app_simple: rx.App, state_manager: StateManager) -> rx.App:
"""Mock app fixture.
Args:
mock_app_simple: A simple mock app.
state_manager: A state manager.
Returns:
The app, after mocking out prerequisites.get_app()
"""
mock_app_simple._state_manager = state_manager
return mock_app_simple
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
"""Test that the state proxy works. """Test that the state proxy works.
@ -1959,6 +1982,10 @@ class BackgroundTaskState(BaseState):
order: List[str] = [] order: List[str] = []
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]} dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
def __init__(self, **kwargs): # noqa: D107
super().__init__(**kwargs)
self.router_data = {"simulate": "hydrate"}
@rx.var @rx.var
def computed_order(self) -> List[str]: def computed_order(self) -> List[str]:
"""Get the order as a computed var. """Get the order as a computed var.
@ -2709,7 +2736,7 @@ def test_set_base_field_via_setter():
assert "c2" in bfss.dirty_vars assert "c2" in bfss.dirty_vars
def exp_is_hydrated(state: State, is_hydrated: bool = True) -> Dict[str, Any]: def exp_is_hydrated(state: BaseState, is_hydrated: bool = True) -> Dict[str, Any]:
"""Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware.
Args: Args:
@ -2788,7 +2815,8 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
app = app_module_mock.app = App( app = app_module_mock.app = App(
state=State, load_events={"index": [test_state.test_handler]} state=State, load_events={"index": [test_state.test_handler]}
) )
state = State() async with app.state_manager.modify_state(_substate_key(token, State)) as state:
state.router_data = {"simulate": "hydrate"}
updates = [] updates = []
async for update in rx.app.process( async for update in rx.app.process(
@ -2835,7 +2863,8 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
state=State, state=State,
load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]}, load_events={"index": [OnLoadState.test_handler, OnLoadState.test_handler]},
) )
state = State() async with app.state_manager.modify_state(_substate_key(token, State)) as state:
state.router_data = {"simulate": "hydrate"}
updates = [] updates = []
async for update in rx.app.process( async for update in rx.app.process(
@ -3506,3 +3535,106 @@ def test_init_mixin() -> None:
with pytest.raises(ReflexRuntimeError): with pytest.raises(ReflexRuntimeError):
SubMixin() SubMixin()
class ReflexModel(rx.Model):
"""A model for testing."""
foo: str
class UpcastState(rx.State):
"""A state for testing upcasting."""
passed: bool = False
def rx_model(self, m: ReflexModel): # noqa: D102
assert isinstance(m, ReflexModel)
self.passed = True
def rx_base(self, o: Object): # noqa: D102
assert isinstance(o, Object)
self.passed = True
def rx_base_or_none(self, o: Optional[Object]): # noqa: D102
if o is not None:
assert isinstance(o, Object)
self.passed = True
def rx_basemodelv1(self, m: ModelV1): # noqa: D102
assert isinstance(m, ModelV1)
self.passed = True
def rx_basemodelv2(self, m: ModelV2): # noqa: D102
assert isinstance(m, ModelV2)
self.passed = True
def rx_dataclass(self, dc: ModelDC): # noqa: D102
assert isinstance(dc, ModelDC)
self.passed = True
def py_set(self, s: set): # noqa: D102
assert isinstance(s, set)
self.passed = True
def py_Set(self, s: Set): # noqa: D102
assert isinstance(s, Set)
self.passed = True
def py_tuple(self, t: tuple): # noqa: D102
assert isinstance(t, tuple)
self.passed = True
def py_Tuple(self, t: Tuple): # noqa: D102
assert isinstance(t, tuple)
self.passed = True
def py_dict(self, d: dict[str, str]): # noqa: D102
assert isinstance(d, dict)
self.passed = True
def py_list(self, ls: list[str]): # noqa: D102
assert isinstance(ls, list)
self.passed = True
def py_Any(self, a: Any): # noqa: D102
assert isinstance(a, list)
self.passed = True
def py_unresolvable(self, u: "Unresolvable"): # noqa: D102, F821 # type: ignore
assert isinstance(u, list)
self.passed = True
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_app_simple")
@pytest.mark.parametrize(
("handler", "payload"),
[
(UpcastState.rx_model, {"m": {"foo": "bar"}}),
(UpcastState.rx_base, {"o": {"foo": "bar"}}),
(UpcastState.rx_base_or_none, {"o": {"foo": "bar"}}),
(UpcastState.rx_base_or_none, {"o": None}),
(UpcastState.rx_basemodelv1, {"m": {"foo": "bar"}}),
(UpcastState.rx_basemodelv2, {"m": {"foo": "bar"}}),
(UpcastState.rx_dataclass, {"dc": {"foo": "bar"}}),
(UpcastState.py_set, {"s": ["foo", "foo"]}),
(UpcastState.py_Set, {"s": ["foo", "foo"]}),
(UpcastState.py_tuple, {"t": ["foo", "foo"]}),
(UpcastState.py_Tuple, {"t": ["foo", "foo"]}),
(UpcastState.py_dict, {"d": {"foo": "bar"}}),
(UpcastState.py_list, {"ls": ["foo", "foo"]}),
(UpcastState.py_Any, {"a": ["foo"]}),
(UpcastState.py_unresolvable, {"u": ["foo"]}),
],
)
async def test_upcast_event_handler_arg(handler, payload):
"""Test that upcast event handler args work correctly.
Args:
handler: The handler to test.
payload: The payload to test.
"""
state = UpcastState()
async for update in state._process_event(handler, state, payload):
assert update.delta == {UpcastState.get_full_name(): {"passed": True}}