Merge remote-tracking branch 'upstream/main' into redis-hash

Benedikt Bartscher 2024-12-13 11:20:08 +01:00
commit f8bfc78f8f
23 changed files with 673 additions and 115 deletions

View File

@@ -80,7 +80,7 @@ jobs:
 fail-fast: false
 matrix:
 # Show OS combos first in GUI
-os: [ubuntu-latest, windows-latest, macos-12]
+os: [ubuntu-latest, windows-latest, macos-latest]
 python-version: ['3.9.18', '3.10.13', '3.11.5', '3.12.0']
 exclude:
 - os: windows-latest
@@ -92,7 +92,7 @@ jobs:
 python-version: '3.9.18'
 - os: macos-latest
 python-version: '3.10.13'
-- os: macos-12
+- os: macos-latest
 python-version: '3.12.0'
 include:
 - os: windows-latest
@@ -155,7 +155,7 @@ jobs:
 fail-fast: false
 matrix:
 # Show OS combos first in GUI
-os: [ubuntu-latest, windows-latest, macos-12]
+os: [ubuntu-latest, windows-latest, macos-latest]
 python-version: ['3.11.5']
 runs-on: ${{ matrix.os }}

View File

@@ -198,7 +198,7 @@ jobs:
 fail-fast: false
 matrix:
 python-version: ['3.11.5', '3.12.0']
-runs-on: macos-12
+runs-on: macos-latest
 steps:
 - uses: actions/checkout@v4
 - uses: ./.github/actions/setup_build_env

View File

@@ -88,8 +88,9 @@ jobs:
 strategy:
 fail-fast: false
 matrix:
-python-version: ['3.9.18', '3.10.13', '3.11.5', '3.12.0', '3.13.0']
-runs-on: macos-12
+# Note: py39, py310 versions chosen due to available arm64 darwin builds.
+python-version: ['3.9.13', '3.10.11', '3.11.5', '3.12.0', '3.13.0']
+runs-on: macos-latest
 steps:
 - uses: actions/checkout@v4
 - uses: ./.github/actions/setup_build_env

View File

@@ -40,9 +40,6 @@ let event_processing = false;
 // Array holding pending events to be processed.
 const event_queue = [];
-// Pending upload promises, by id
-const upload_controllers = {};
 /**
 * Generate a UUID (Used for session tokens).
 * Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid
@@ -300,7 +297,7 @@ export const applyEvent = async (event, socket) => {
 if (socket) {
 socket.emit(
 "event",
-JSON.stringify(event, (k, v) => (v === undefined ? null : v))
+event,
 );
 return true;
 }
@@ -407,6 +404,8 @@ export const connect = async (
 transports: transports,
 autoUnref: false,
 });
+// Ensure undefined fields in events are sent as null instead of removed
+socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
 function checkVisibility() {
 if (document.visibilityState === "visible") {
@@ -443,8 +442,7 @@ export const connect = async (
 });
 // On each received message, queue the updates and events.
-socket.current.on("event", async (message) => {
-const update = JSON5.parse(message);
+socket.current.on("event", async (update) => {
 for (const substate in update.delta) {
 dispatch[substate](update.delta[substate]);
 }
@@ -456,7 +454,7 @@ export const connect = async (
 });
 socket.current.on("reload", async (event) => {
 event_processing = false;
-queueEvents([...initialEvents(), JSON5.parse(event)], socket);
+queueEvents([...initialEvents(), event], socket);
 });
 document.addEventListener("visibilitychange", checkVisibility);
@@ -485,7 +483,9 @@ export const uploadFiles = async (
 return false;
 }
-if (upload_controllers[upload_id]) {
+const upload_ref_name = `__upload_controllers_${upload_id}`
+if (refs[upload_ref_name]) {
 console.log("Upload already in progress for ", upload_id);
 return false;
 }
@@ -497,23 +497,31 @@ export const uploadFiles = async (
 // Whenever called, responseText will contain the entire response so far.
 const chunks = progressEvent.event.target.responseText.trim().split("\n");
 // So only process _new_ chunks beyond resp_idx.
-chunks.slice(resp_idx).map((chunk) => {
-event_callbacks.map((f, ix) => {
-f(chunk)
-.then(() => {
-if (ix === event_callbacks.length - 1) {
-// Mark this chunk as processed.
-resp_idx += 1;
-}
-})
-.catch((e) => {
-if (progressEvent.progress === 1) {
-// Chunk may be incomplete, so only report errors when full response is available.
-console.log("Error parsing chunk", chunk, e);
-}
-return;
-});
-});
+chunks.slice(resp_idx).map((chunk_json) => {
+try {
+const chunk = JSON5.parse(chunk_json);
+event_callbacks.map((f, ix) => {
+f(chunk)
+.then(() => {
+if (ix === event_callbacks.length - 1) {
+// Mark this chunk as processed.
+resp_idx += 1;
+}
+})
+.catch((e) => {
+if (progressEvent.progress === 1) {
+// Chunk may be incomplete, so only report errors when full response is available.
+console.log("Error processing chunk", chunk, e);
+}
+return;
+});
+});
+} catch (e) {
+if (progressEvent.progress === 1) {
+console.log("Error parsing chunk", chunk_json, e);
+}
+return;
+}
 });
 };
@@ -537,7 +545,7 @@ export const uploadFiles = async (
 });
 // Send the file to the server.
-upload_controllers[upload_id] = controller;
+refs[upload_ref_name] = controller;
 try {
 return await axios.post(getBackendURL(UPLOADURL), formdata, config);
@@ -557,7 +565,7 @@ export const uploadFiles = async (
 }
 return false;
 } finally {
-delete upload_controllers[upload_id];
+delete refs[upload_ref_name];
 }
 };

View File

@@ -17,6 +17,7 @@ import sys
 import traceback
 from datetime import datetime
 from pathlib import Path
+from types import SimpleNamespace
 from typing import (
 TYPE_CHECKING,
 Any,
@@ -363,6 +364,10 @@ class App(MiddlewareMixin, LifespanMixin):
 max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
 ping_interval=constants.Ping.INTERVAL,
 ping_timeout=constants.Ping.TIMEOUT,
+json=SimpleNamespace(
+dumps=staticmethod(format.json_dumps),
+loads=staticmethod(json.loads),
+),
 transports=["websocket"],
 )
 elif getattr(self.sio, "async_mode", "") != "asgi":
@@ -1290,7 +1295,7 @@ async def process(
 await asyncio.create_task(
 app.event_namespace.emit(
 "reload",
-data=format.json_dumps(event),
+data=event,
 to=sid,
 )
 )
@@ -1543,7 +1548,7 @@ class EventNamespace(AsyncNamespace):
 """
 # Creating a task prevents the update from being blocked behind other coroutines.
 await asyncio.create_task(
-self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
+self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
 )
 async def on_event(self, sid, data):
@@ -1556,7 +1561,7 @@ class EventNamespace(AsyncNamespace):
 sid: The Socket.IO session id.
 data: The event data.
 """
-fields = json.loads(data)
+fields = data
 # Get the event.
 event = Event(
 **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
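Note on the serializer change above: python-socketio accepts any object exposing `dumps`/`loads` via its `json` argument, which is what the `SimpleNamespace` passed to `AsyncServer` provides, so events and updates can be emitted as plain objects while `format.json_dumps` stays the single serialization point. A minimal standalone sketch of the same pattern, with illustrative names that are not part of this commit:

import json
from types import SimpleNamespace

import socketio  # python-socketio


def dumps_keeping_none(obj) -> str:
    # Serialize with non-ASCII passthrough; None values are emitted as JSON null.
    return json.dumps(obj, ensure_ascii=False, default=str)


# Any object that quacks like the stdlib json module is accepted here.
sio = socketio.AsyncServer(
    async_mode="asgi",
    json=SimpleNamespace(dumps=dumps_keeping_none, loads=json.loads),
)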

View File

@@ -29,7 +29,7 @@ from reflex.event import (
 from reflex.utils import format
 from reflex.utils.imports import ImportVar
 from reflex.vars import VarData
-from reflex.vars.base import CallableVar, LiteralVar, Var, get_unique_variable_name
+from reflex.vars.base import CallableVar, Var, get_unique_variable_name
 from reflex.vars.sequence import LiteralStringVar
 DEFAULT_UPLOAD_ID: str = "default"
@@ -108,7 +108,8 @@ def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec:
 # UploadFilesProvider assigns a special function to clear selected files
 # into the shared global refs object to make it accessible outside a React
 # component via `run_script` (otherwise backend could never clear files).
-return run_script(f"refs['__clear_selected_files']({id_!r})")
+func = Var("__clear_selected_files")._as_ref()
+return run_script(f"{func}({id_!r})")
 def cancel_upload(upload_id: str) -> EventSpec:
@@ -120,7 +121,8 @@ def cancel_upload(upload_id: str) -> EventSpec:
 Returns:
 An event spec that cancels the upload when triggered.
 """
-return run_script(f"upload_controllers[{LiteralVar.create(upload_id)!s}]?.abort()")
+controller = Var(f"__upload_controllers_{upload_id}")._as_ref()
+return run_script(f"{controller}?.abort()")
 def get_upload_dir() -> Path:
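The public cancel API is unchanged by the refs-based controller lookup; a hedged usage sketch, where the state class, upload id, and layout are invented for illustration:

import reflex as rx


class UploadState(rx.State):
    def stop_upload(self):
        # Aborts the axios controller the frontend stored under this upload id.
        return rx.cancel_upload("my_upload")


def upload_zone() -> rx.Component:
    return rx.vstack(
        rx.upload(rx.text("Drop files here"), id="my_upload"),
        rx.button("Cancel upload", on_click=UploadState.stop_upload),
    )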

View File

@@ -18,6 +18,7 @@ from reflex.event import (
 prevent_default,
 )
 from reflex.utils.imports import ImportDict
+from reflex.utils.types import is_optional
 from reflex.vars import VarData
 from reflex.vars.base import LiteralVar, Var
@@ -382,6 +383,33 @@ class Input(BaseHTML):
 # Fired when a key is released
 on_key_up: EventHandler[key_event]
@classmethod
def create(cls, *children, **props):
"""Create an Input component.
Args:
*children: The children of the component.
**props: The properties of the component.
Returns:
The component.
"""
from reflex.vars.number import ternary_operation
value = props.get("value")
# React expects an empty string(instead of null) for controlled inputs.
if value is not None and is_optional(
(value_var := Var.create(value))._var_type
):
props["value"] = ternary_operation(
(value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues]
& (value_var != Var(_js_expr="undefined")),
value,
Var.create(""),
)
return super().create(*children, **props)
 class Label(BaseHTML):
 """Display the label element."""

View File

@@ -512,7 +512,7 @@ class Input(BaseHTML):
 on_unmount: Optional[EventType[[], BASE_STATE]] = None,
 **props,
 ) -> "Input":
-"""Create the component.
+"""Create an Input component.
 Args:
 *children: The children of the component.
@@ -576,7 +576,7 @@ class Input(BaseHTML):
 class_name: The class name for the component.
 autofocus: Whether the component should take the focus once the page is loaded
 custom_attrs: custom attribute
-**props: The props of the component.
+**props: The properties of the component.
 Returns:
 The component.

View File

@@ -8,6 +8,7 @@ from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spe
 from reflex.vars.base import Var
 from ..base import LiteralAccentColor, RadixThemesComponent
+from .checkbox import Checkbox
 LiteralDirType = Literal["ltr", "rtl"]
@@ -232,6 +233,15 @@ class ContextMenuSeparator(RadixThemesComponent):
 tag = "ContextMenu.Separator"
+class ContextMenuCheckbox(Checkbox):
+"""The component that contains the checkbox."""
+tag = "ContextMenu.CheckboxItem"
+# Text to render as shortcut.
+shortcut: Var[str]
 class ContextMenu(ComponentNamespace):
 """Menu representing a set of actions, displayed at the origin of a pointer right-click or long-press."""
@@ -243,6 +253,7 @@ class ContextMenu(ComponentNamespace):
 sub_content = staticmethod(ContextMenuSubContent.create)
 item = staticmethod(ContextMenuItem.create)
 separator = staticmethod(ContextMenuSeparator.create)
+checkbox = staticmethod(ContextMenuCheckbox.create)
 context_menu = ContextMenu()
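A hedged usage sketch for the new namespace entry; the labels, state, shortcut string, and whether a text child renders as the item label are assumptions for illustration:

import reflex as rx


class MenuState(rx.State):
    show_hidden: bool = False


def file_context_menu() -> rx.Component:
    return rx.context_menu.root(
        rx.context_menu.trigger(rx.text("Right click me")),
        rx.context_menu.content(
            rx.context_menu.item("Open"),
            rx.context_menu.separator(),
            # Renders a ContextMenu.CheckboxItem with an optional shortcut hint.
            rx.context_menu.checkbox(
                "Show hidden files",
                shortcut="⌘ H",
                checked=MenuState.show_hidden,
                on_change=MenuState.set_show_hidden,
            ),
        ),
    )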

View File

@@ -12,6 +12,7 @@ from reflex.style import Style
 from reflex.vars.base import Var
 from ..base import RadixThemesComponent
+from .checkbox import Checkbox
 LiteralDirType = Literal["ltr", "rtl"]
 LiteralSizeType = Literal["1", "2"]
@@ -672,6 +673,159 @@ class ContextMenuSeparator(RadixThemesComponent):
 """
 ...
class ContextMenuCheckbox(Checkbox):
@overload
@classmethod
def create( # type: ignore
cls,
*children,
shortcut: Optional[Union[Var[str], str]] = None,
as_child: Optional[Union[Var[bool], bool]] = None,
size: Optional[
Union[
Breakpoints[str, Literal["1", "2", "3"]],
Literal["1", "2", "3"],
Var[
Union[
Breakpoints[str, Literal["1", "2", "3"]], Literal["1", "2", "3"]
]
],
]
] = None,
variant: Optional[
Union[
Literal["classic", "soft", "surface"],
Var[Literal["classic", "soft", "surface"]],
]
] = None,
color_scheme: Optional[
Union[
Literal[
"amber",
"blue",
"bronze",
"brown",
"crimson",
"cyan",
"gold",
"grass",
"gray",
"green",
"indigo",
"iris",
"jade",
"lime",
"mint",
"orange",
"pink",
"plum",
"purple",
"red",
"ruby",
"sky",
"teal",
"tomato",
"violet",
"yellow",
],
Var[
Literal[
"amber",
"blue",
"bronze",
"brown",
"crimson",
"cyan",
"gold",
"grass",
"gray",
"green",
"indigo",
"iris",
"jade",
"lime",
"mint",
"orange",
"pink",
"plum",
"purple",
"red",
"ruby",
"sky",
"teal",
"tomato",
"violet",
"yellow",
]
],
]
] = None,
high_contrast: Optional[Union[Var[bool], bool]] = None,
default_checked: Optional[Union[Var[bool], bool]] = None,
checked: Optional[Union[Var[bool], bool]] = None,
disabled: Optional[Union[Var[bool], bool]] = None,
required: Optional[Union[Var[bool], bool]] = None,
name: Optional[Union[Var[str], str]] = None,
value: Optional[Union[Var[str], str]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
on_blur: Optional[EventType[[], BASE_STATE]] = None,
on_change: Optional[
Union[EventType[[], BASE_STATE], EventType[[bool], BASE_STATE]]
] = None,
on_click: Optional[EventType[[], BASE_STATE]] = None,
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
on_focus: Optional[EventType[[], BASE_STATE]] = None,
on_mount: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
**props,
) -> "ContextMenuCheckbox":
"""Create a new component instance.
Will prepend "RadixThemes" to the component tag to avoid conflicts with
other UI libraries for common names, like Text and Button.
Args:
*children: Child components.
shortcut: Text to render as shortcut.
as_child: Change the default rendered element for the one passed as a child, merging their props and behavior.
size: Checkbox size "1" - "3"
variant: Variant of checkbox: "classic" | "surface" | "soft"
color_scheme: Override theme color for checkbox
high_contrast: Whether to render the checkbox with higher contrast color against background
default_checked: Whether the checkbox is checked by default
checked: Whether the checkbox is checked
disabled: Whether the checkbox is disabled
required: Whether the checkbox is required
name: The name of the checkbox control when submitting the form.
value: The value of the checkbox control when submitting the form.
on_change: Fired when the checkbox is checked or unchecked.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: Component properties.
Returns:
A new component instance.
"""
...
 class ContextMenu(ComponentNamespace):
 root = staticmethod(ContextMenuRoot.create)
 trigger = staticmethod(ContextMenuTrigger.create)
@@ -681,5 +835,6 @@ class ContextMenu(ComponentNamespace):
 sub_content = staticmethod(ContextMenuSubContent.create)
 item = staticmethod(ContextMenuItem.create)
 separator = staticmethod(ContextMenuSeparator.create)
+checkbox = staticmethod(ContextMenuCheckbox.create)
 context_menu = ContextMenu()

View File

@@ -9,7 +9,9 @@ from reflex.components.core.breakpoints import Responsive
 from reflex.components.core.debounce import DebounceInput
 from reflex.components.el import elements
 from reflex.event import EventHandler, input_event, key_event
+from reflex.utils.types import is_optional
 from reflex.vars.base import Var
+from reflex.vars.number import ternary_operation
 from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent
@@ -96,6 +98,19 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
 Returns:
 The component.
 """
value = props.get("value")
# React expects an empty string(instead of null) for controlled inputs.
if value is not None and is_optional(
(value_var := Var.create(value))._var_type
):
props["value"] = ternary_operation(
(value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues]
& (value_var != Var(_js_expr="undefined")),
value,
Var.create(""),
)
 component = super().create(*children, **props)
 if props.get("value") is not None and props.get("on_change") is not None:
 # create a debounced input if the user requests full control to avoid typing jank

View File

@@ -684,6 +684,9 @@ class Config(Base):
 # Maximum expiration lock time for redis state manager
 redis_lock_expiration: int = constants.Expiration.LOCK
+# Maximum lock time before warning for redis state manager.
+redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
 # Token expiration time for redis state manager
 redis_token_expiration: int = constants.Expiration.TOKEN

View File

@@ -29,6 +29,8 @@ class Expiration(SimpleNamespace):
 LOCK = 10000
 # The PING timeout
 PING = 120
+# The maximum time in milliseconds to hold a lock before throwing a warning.
+LOCK_WARNING_THRESHOLD = 1000
 class GitIgnore(SimpleNamespace):
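A hedged rxconfig.py sketch showing the new knob next to the existing expiration settings; the values are arbitrary, and the threshold must stay below redis_lock_expiration or StateManager.create raises InvalidLockWarningThresholdError:

import reflex as rx

config = rx.Config(
    app_name="my_app",
    state_manager_mode="redis",
    redis_url="redis://localhost:6379",
    # Locks are force-expired after 10s ...
    redis_lock_expiration=10000,
    # ... and a deduped warning is logged once a handler holds one for over 1s.
    redis_lock_warning_threshold=1000,
)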

View File

@@ -11,6 +11,7 @@ import inspect
 import json
 import pickle
 import sys
+import time
 import typing
 import uuid
 from abc import ABC, abstractmethod
@@ -39,6 +40,7 @@ from typing import (
 get_type_hints,
 )
+from redis.asyncio.client import PubSub
 from sqlalchemy.orm import DeclarativeBase
 from typing_extensions import Self
@@ -69,6 +71,11 @@ try:
 except ModuleNotFoundError:
 BaseModelV1 = BaseModelV2
+try:
+from pydantic.v1 import validator
+except ModuleNotFoundError:
+from pydantic import validator
 import wrapt
 from redis.asyncio import Redis
 from redis.exceptions import ResponseError
@@ -92,6 +99,7 @@ from reflex.utils.exceptions import (
 DynamicRouteArgShadowsStateVar,
 EventHandlerShadowsBuiltInStateMethod,
 ImmutableStateError,
+InvalidLockWarningThresholdError,
 InvalidStateManagerMode,
 LockExpiredError,
 ReflexRuntimeError,
@@ -1107,6 +1115,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 if (
 not field.required
 and field.default is None
+and field.default_factory is None
 and not types.is_optional(prop._var_type)
 ):
 # Ensure frontend uses null coalescing when accessing.
@@ -2173,14 +2182,26 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
 state["__dict__"].pop("router", None)
 state["__dict__"].pop("router_data", None)
 # Never serialize parent_state or substates.
-state["__dict__"]["parent_state"] = None
-state["__dict__"]["substates"] = {}
+state["__dict__"].pop("parent_state", None)
+state["__dict__"].pop("substates", None)
 state["__dict__"].pop("_was_touched", None)
 # Remove all inherited vars.
 for inherited_var_name in self.inherited_vars:
 state["__dict__"].pop(inherited_var_name, None)
 return state
def __setstate__(self, state: dict[str, Any]):
"""Set the state from redis deserialization.
This method is called by pickle to deserialize the object.
Args:
state: The state dict for deserialization.
"""
state["__dict__"]["parent_state"] = None
state["__dict__"]["substates"] = {}
super().__setstate__(state)
 def _check_state_size(
 self,
 pickle_state_size: int,
@@ -2870,6 +2891,7 @@ class StateManager(Base, ABC):
 redis=redis,
 token_expiration=config.redis_token_expiration,
 lock_expiration=config.redis_lock_expiration,
+lock_warning_threshold=config.redis_lock_warning_threshold,
 )
 raise InvalidStateManagerMode(
 f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
@@ -3239,6 +3261,15 @@ def _default_lock_expiration() -> int:
 return get_config().redis_lock_expiration
+def _default_lock_warning_threshold() -> int:
+"""Get the default lock warning threshold.
+Returns:
+The default lock warning threshold.
+"""
+return get_config().redis_lock_warning_threshold
 class StateManagerRedis(StateManager):
 """A state manager that stores states in redis."""
@@ -3251,6 +3282,11 @@ class StateManagerRedis(StateManager):
 # The maximum time to hold a lock (ms).
 lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
+# The maximum time to hold a lock (ms) before warning.
+lock_warning_threshold: int = pydantic.Field(
+default_factory=_default_lock_warning_threshold
+)
 # If HEXPIRE is not supported, use EXPIRE instead.
 _hexpire_not_supported: Optional[bool] = pydantic.PrivateAttr(None)
@@ -3413,6 +3449,17 @@ class StateManagerRedis(StateManager):
 f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
 "or use `@rx.event(background=True)` decorator for long-running tasks."
 )
elif lock_id is not None:
time_taken = self.lock_expiration / 1000 - (
await self.redis.ttl(self._lock_key(token))
)
if time_taken > self.lock_warning_threshold / 1000:
console.warn(
f"Lock for token {token} was held too long {time_taken=}s, "
f"use `@rx.event(background=True)` decorator for long-running tasks.",
dedupe=True,
)
 client_token, substate_name = _split_substate_key(token)
 # If the substate name on the token doesn't match the instance name, it cannot have a parent.
 if state.parent_state is not None and state.get_full_name() != substate_name:
@@ -3477,6 +3524,27 @@ class StateManagerRedis(StateManager):
 yield state
 await self.set_state(token, state, lock_id)
@validator("lock_warning_threshold")
@classmethod
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
"""Validate the lock warning threshold.
Args:
lock_warning_threshold: The lock warning threshold.
values: The validated attributes.
Returns:
The lock warning threshold.
Raises:
InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
"""
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
raise InvalidLockWarningThresholdError(
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
)
return lock_warning_threshold
 @staticmethod
 def _lock_key(token: str) -> bytes:
 """Get the redis key for a token's lock.
@@ -3508,6 +3576,35 @@ class StateManagerRedis(StateManager):
 nx=True, # only set if it doesn't exist
 )
async def _get_pubsub_message(
self, pubsub: PubSub, timeout: float | None = None
) -> None:
"""Get lock release events from the pubsub.
Args:
pubsub: The pubsub to get a message from.
timeout: Remaining time to wait for a message.
Returns:
The message.
"""
if timeout is None:
timeout = self.lock_expiration / 1000.0
started = time.time()
message = await pubsub.get_message(
ignore_subscribe_messages=True,
timeout=timeout,
)
if (
message is None
or message["data"] not in self._redis_keyspace_lock_release_events
):
remaining = timeout - (time.time() - started)
if remaining <= 0:
return
await self._get_pubsub_message(pubsub, timeout=remaining)
 async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
 """Wait for a redis lock to be released via pubsub.
@@ -3520,7 +3617,6 @@
 Raises:
 ResponseError: when the keyspace config cannot be set.
 """
-state_is_locked = False
 lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
 # Enable keyspace notifications for the lock key, so we know when it is available.
 try:
@@ -3534,20 +3630,13 @@
 raise
 async with self.redis.pubsub() as pubsub:
 await pubsub.psubscribe(lock_key_channel)
-while not state_is_locked:
-# wait for the lock to be released
-while True:
-if not await self.redis.exists(lock_key):
-break # key was removed, try to get the lock again
-message = await pubsub.get_message(
-ignore_subscribe_messages=True,
-timeout=self.lock_expiration / 1000.0,
-)
-if message is None:
-continue
-if message["data"] in self._redis_keyspace_lock_release_events:
-break
-state_is_locked = await self._try_get_lock(lock_key, lock_id)
+# wait for the lock to be released
+while True:
+# fast path
+if await self._try_get_lock(lock_key, lock_id):
+return
+# wait for lock events
+await self._get_pubsub_message(pubsub)
 @contextlib.asynccontextmanager
 async def _lock(self, token: str):

View File

@@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set()
 # Info messages which have been printed.
 _EMITTED_INFO = set()
# Warnings which have been printed.
_EMIITED_WARNINGS = set()
# Errors which have been printed.
_EMITTED_ERRORS = set()
# Success messages which have been printed.
_EMITTED_SUCCESS = set()
# Debug messages which have been printed.
_EMITTED_DEBUG = set()
# Logs which have been printed.
_EMITTED_LOGS = set()
# Prints which have been printed.
_EMITTED_PRINTS = set()
 def set_log_level(log_level: LogLevel):
 """Set the log level.
@@ -55,25 +73,37 @@ def is_debug() -> bool:
 return _LOG_LEVEL <= LogLevel.DEBUG
-def print(msg: str, **kwargs):
+def print(msg: str, dedupe: bool = False, **kwargs):
 """Print a message.
 Args:
 msg: The message to print.
+dedupe: If True, suppress multiple console logs of print message.
 kwargs: Keyword arguments to pass to the print function.
 """
+if dedupe:
+if msg in _EMITTED_PRINTS:
+return
+else:
+_EMITTED_PRINTS.add(msg)
 _console.print(msg, **kwargs)
-def debug(msg: str, **kwargs):
+def debug(msg: str, dedupe: bool = False, **kwargs):
 """Print a debug message.
 Args:
 msg: The debug message.
+dedupe: If True, suppress multiple console logs of debug message.
 kwargs: Keyword arguments to pass to the print function.
 """
 if is_debug():
 msg_ = f"[purple]Debug: {msg}[/purple]"
+if dedupe:
+if msg_ in _EMITTED_DEBUG:
+return
+else:
+_EMITTED_DEBUG.add(msg_)
 if progress := kwargs.pop("progress", None):
 progress.console.print(msg_, **kwargs)
 else:
@@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs):
 print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
-def success(msg: str, **kwargs):
+def success(msg: str, dedupe: bool = False, **kwargs):
 """Print a success message.
 Args:
 msg: The success message.
+dedupe: If True, suppress multiple console logs of success message.
 kwargs: Keyword arguments to pass to the print function.
 """
 if _LOG_LEVEL <= LogLevel.INFO:
+if dedupe:
+if msg in _EMITTED_SUCCESS:
+return
+else:
+_EMITTED_SUCCESS.add(msg)
 print(f"[green]Success: {msg}[/green]", **kwargs)
-def log(msg: str, **kwargs):
+def log(msg: str, dedupe: bool = False, **kwargs):
 """Takes a string and logs it to the console.
 Args:
 msg: The message to log.
+dedupe: If True, suppress multiple console logs of log message.
 kwargs: Keyword arguments to pass to the print function.
 """
 if _LOG_LEVEL <= LogLevel.INFO:
+if dedupe:
+if msg in _EMITTED_LOGS:
+return
+else:
+_EMITTED_LOGS.add(msg)
 _console.log(msg, **kwargs)
@@ -129,14 +171,20 @@ def rule(title: str, **kwargs):
 _console.rule(title, **kwargs)
-def warn(msg: str, **kwargs):
+def warn(msg: str, dedupe: bool = False, **kwargs):
 """Print a warning message.
 Args:
 msg: The warning message.
+dedupe: If True, suppress multiple console logs of warning message.
 kwargs: Keyword arguments to pass to the print function.
 """
 if _LOG_LEVEL <= LogLevel.WARNING:
+if dedupe:
+if msg in _EMIITED_WARNINGS:
+return
+else:
+_EMIITED_WARNINGS.add(msg)
 print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
@@ -169,14 +217,20 @@ def deprecate(
 _EMITTED_DEPRECATION_WARNINGS.add(feature_name)
-def error(msg: str, **kwargs):
+def error(msg: str, dedupe: bool = False, **kwargs):
 """Print an error message.
 Args:
 msg: The error message.
+dedupe: If True, suppress multiple console logs of error message.
 kwargs: Keyword arguments to pass to the print function.
 """
 if _LOG_LEVEL <= LogLevel.ERROR:
+if dedupe:
+if msg in _EMITTED_ERRORS:
+return
+else:
+_EMITTED_ERRORS.add(msg)
 print(f"[red]{msg}[/red]", **kwargs)

View File

@@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
 " Please install it through your system package manager."
 + (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
 )
+class InvalidLockWarningThresholdError(ReflexError):
+"""Raised when an invalid lock warning threshold is provided."""

View File

@@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
 return lib
-def json_dumps(obj: Any) -> str:
+def json_dumps(obj: Any, **kwargs) -> str:
 """Takes an object and returns a jsonified string.
 Args:
 obj: The object to be serialized.
+kwargs: Additional keyword arguments to pass to json.dumps.
 Returns:
 A string
 """
 from reflex.utils import serializers
-return json.dumps(obj, ensure_ascii=False, default=serializers.serialize)
+kwargs.setdefault("ensure_ascii", False)
+kwargs.setdefault("default", serializers.serialize)
+return json.dumps(obj, **kwargs)
 def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:
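With the pass-through **kwargs, callers can forward any json.dumps option while keeping the Reflex serializer defaults; a small sketch (the chosen options are just examples):

from reflex.utils import format

payload = {"b": 2, "a": 1, "emoji": "✨"}

# ensure_ascii=False and default=serializers.serialize are still applied via
# setdefault; indent/sort_keys are passed straight through to json.dumps.
print(format.json_dumps(payload, indent=2, sort_keys=True))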

View File

@@ -6,6 +6,7 @@ from pathlib import Path
 import pytest
+import reflex.app
 from reflex.config import environment
 from reflex.testing import AppHarness, AppHarnessProd
@@ -76,3 +77,25 @@ def app_harness_env(request):
 The AppHarness class to use for the test.
 """
 return request.param
@pytest.fixture(autouse=True)
def raise_console_error(request, mocker):
"""Spy on calls to `console.error` used by the framework.
Help catch spurious error conditions that might otherwise go unnoticed.
If a test is marked with `ignore_console_error`, the spy will be ignored
after the test.
Args:
request: The pytest request object.
mocker: The pytest mocker object.
Yields:
control to the test function.
"""
spy = mocker.spy(reflex.app.console, "error")
yield
if "ignore_console_error" not in request.keywords:
spy.assert_not_called()

View File

@@ -628,8 +628,7 @@ async def test_client_side_state(
 assert await AppHarness._poll_for_async(poll_for_not_hydrated)
 # Trigger event to get a new instance of the state since the old was expired.
-state_var_input = driver.find_element(By.ID, "state_var")
-state_var_input.send_keys("re-triggering")
+set_sub("c1", "c1 post expire")
 # get new references to all cookie and local storage elements (again)
 c1 = driver.find_element(By.ID, "c1")
@@ -650,7 +649,7 @@ async def test_client_side_state(
 l1s = driver.find_element(By.ID, "l1s")
 s1s = driver.find_element(By.ID, "s1s")
-assert c1.text == "c1 value"
+assert c1.text == "c1 post expire"
 assert c2.text == "c2 value"
 assert c3.text == "" # temporary cookie expired after reset state!
 assert c4.text == "c4 value"
@@ -680,11 +679,11 @@ async def test_client_side_state(
 async def poll_for_c1_set():
 sub_state = await get_sub_state()
-return sub_state.c1 == "c1 value"
+return sub_state.c1 == "c1 post expire"
 assert await AppHarness._poll_for_async(poll_for_c1_set)
 sub_state = await get_sub_state()
-assert sub_state.c1 == "c1 value"
+assert sub_state.c1 == "c1 post expire"
 assert sub_state.c2 == "c2 value"
 assert sub_state.c3 == ""
 assert sub_state.c4 == "c4 value"

View File

@@ -13,6 +13,8 @@ from selenium.webdriver.support.ui import WebDriverWait
 from reflex.testing import AppHarness, AppHarnessProd
+pytestmark = [pytest.mark.ignore_console_error]
 def TestApp():
 """A test app for event exception handler integration."""

View File

@@ -381,9 +381,22 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
 await asyncio.sleep(0.3)
 cancel_button.click()
-# look up the backend state and assert on progress
+# Wait a bit for the upload to get cancelled.
+await asyncio.sleep(0.5)
+# Get interim progress dicts saved in the on_upload_progress handler.
+async def _progress_dicts():
+state = await upload_file.get_state(substate_token)
+return state.substates[state_name].progress_dicts
+# We should have _some_ progress
+assert await AppHarness._poll_for_async(_progress_dicts)
+# But there should never be a final progress record for a cancelled upload.
+for p in await _progress_dicts():
+assert p["progress"] != 1
 state = await upload_file.get_state(substate_token)
-assert state.substates[state_name].progress_dicts
 file_data = state.substates[state_name]._file_data
 assert isinstance(file_data, dict)
 normalized_file_data = {Path(k).name: v for k, v in file_data.items()}

View File

@@ -1,8 +1,10 @@
 from typing import Dict, List, Set, Tuple, Union
+import pydantic.v1
 import pytest
 from reflex import el
+from reflex.base import Base
 from reflex.components.component import Component
 from reflex.components.core.foreach import (
 Foreach,
@@ -18,6 +20,12 @@ from reflex.vars.number import NumberVar
 from reflex.vars.sequence import ArrayVar
+class ForEachTag(Base):
+"""A tag for testing the ForEach component."""
+name: str = ""
 class ForEachState(BaseState):
 """A state for testing the ForEach component."""
@@ -46,6 +54,8 @@ class ForEachState(BaseState):
 bad_annotation_list: list = [["red", "orange"], ["yellow", "blue"]]
 color_index_tuple: Tuple[int, str] = (0, "red")
+default_factory_list: list[ForEachTag] = pydantic.v1.Field(default_factory=list)
 class ComponentStateTest(ComponentState):
 """A test component state."""
@@ -290,3 +300,11 @@ def test_foreach_component_state():
 ForEachState.colors_list,
 ComponentStateTest.create,
 )
+def test_foreach_default_factory():
+"""Test that the default factory is called."""
+_ = Foreach.create(
+ForEachState.default_factory_list,
+lambda tag: text(tag.name),
+)

View File

@@ -56,6 +56,7 @@ from reflex.state import (
 from reflex.testing import chdir
 from reflex.utils import format, prerequisites, types
 from reflex.utils.exceptions import (
+InvalidLockWarningThresholdError,
 ReflexRuntimeError,
 SetUndefinedStateVarError,
 StateSerializationError,
@@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState
 from .states import GenState
 CI = bool(os.environ.get("CI", False))
-LOCK_EXPIRATION = 2000 if CI else 300
+LOCK_EXPIRATION = 2500 if CI else 300
+LOCK_WARNING_THRESHOLD = 1000 if CI else 100
+LOCK_WARN_SLEEP = 1.5 if CI else 0.15
 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
@@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire(
 substate_token_redis: A token + substate name for looking up in state manager.
 """
 state_manager_redis.lock_expiration = LOCK_EXPIRATION
+state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
 async with state_manager_redis.modify_state(substate_token_redis):
 await asyncio.sleep(0.01)
@@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
 unexp_num1 = 666
 state_manager_redis.lock_expiration = LOCK_EXPIRATION
+state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
 order = []
@@ -1840,6 +1845,57 @@
 assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
@pytest.mark.asyncio
async def test_state_manager_lock_warning_threshold_contend(
state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
):
"""Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
Args:
state_manager_redis: A state manager instance.
token: A token.
substate_token_redis: A token + substate name for looking up in state manager.
mocker: Pytest mocker object.
"""
console_warn = mocker.patch("reflex.utils.console.warn")
state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
order = []
async def _coro_blocker():
async with state_manager_redis.modify_state(substate_token_redis):
order.append("blocker")
await asyncio.sleep(LOCK_WARN_SLEEP)
tasks = [
asyncio.create_task(_coro_blocker()),
]
await tasks[0]
console_warn.assert_called()
assert console_warn.call_count == 7
class CopyingAsyncMock(AsyncMock):
"""An AsyncMock, but deepcopy the args and kwargs first."""
def __call__(self, *args, **kwargs):
"""Call the mock.
Args:
args: the arguments passed to the mock
kwargs: the keyword arguments passed to the mock
Returns:
The result of the mock call
"""
args = copy.deepcopy(args)
kwargs = copy.deepcopy(kwargs)
return super().__call__(*args, **kwargs)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_app_simple(monkeypatch) -> rx.App: def mock_app_simple(monkeypatch) -> rx.App:
"""Simple Mock app fixture. """Simple Mock app fixture.
@@ -1856,7 +1912,7 @@ def mock_app_simple(monkeypatch) -> rx.App:
 setattr(app_module, CompileVars.APP, app)
 app.state = TestState
-app.event_namespace.emit = AsyncMock() # type: ignore
+app.event_namespace.emit = CopyingAsyncMock() # type: ignore
 def _mock_get_app(*args, **kwargs):
 return app_module
@@ -1960,21 +2016,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
 mock_app.event_namespace.emit.assert_called_once()
 mcall = mock_app.event_namespace.emit.mock_calls[0]
 assert mcall.args[0] == str(SocketEvent.EVENT)
-assert json.loads(mcall.args[1]) == dataclasses.asdict(
-StateUpdate(
-delta={
-parent_state.get_full_name(): {
-"upper": "",
-"sum": 3.14,
-},
-grandchild_state.get_full_name(): {
-"value2": "42",
-},
-GrandchildState3.get_full_name(): {
-"computed": "",
-},
-}
-)
+assert mcall.args[1] == StateUpdate(
+delta={
+parent_state.get_full_name(): {
+"upper": "",
+"sum": 3.14,
+},
+grandchild_state.get_full_name(): {
+"value2": "42",
+},
+GrandchildState3.get_full_name(): {
+"computed": "",
+},
+}
 )
 assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
@@ -2156,51 +2210,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
 assert mock_app.event_namespace is not None
 emit_mock = mock_app.event_namespace.emit
-first_ws_message = json.loads(emit_mock.mock_calls[0].args[1])
+first_ws_message = emit_mock.mock_calls[0].args[1]
 assert (
-first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router")
+first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
 is not None
 )
-assert first_ws_message == {
-"delta": {
+assert first_ws_message == StateUpdate(
+delta={
 BackgroundTaskState.get_full_name(): {
 "order": ["background_task:start"],
 "computed_order": ["background_task:start"],
 }
 },
-"events": [],
-"final": True,
-}
+events=[],
+final=True,
+)
 for call in emit_mock.mock_calls[1:5]:
-assert json.loads(call.args[1]) == {
-"delta": {
+assert call.args[1] == StateUpdate(
+delta={
 BackgroundTaskState.get_full_name(): {
 "computed_order": ["background_task:start"],
 }
 },
-"events": [],
-"final": True,
-}
+events=[],
+final=True,
+)
-assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
-"delta": {
+assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+delta={
 BackgroundTaskState.get_full_name(): {
 "order": exp_order,
 "computed_order": exp_order,
 "dict_list": {},
 }
 },
-"events": [],
-"final": True,
-}
+events=[],
+final=True,
+)
-assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
-"delta": {
+assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+delta={
 BackgroundTaskState.get_full_name(): {
 "computed_order": exp_order,
 },
 },
-"events": [],
-"final": True,
-}
+events=[],
+final=True,
+)
 @pytest.mark.asyncio
@@ -3246,12 +3300,42 @@ async def test_setvar_async_setter():
 @pytest.mark.parametrize(
 "expiration_kwargs, expected_values",
 [
-({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
+(
+{"redis_lock_expiration": 20000},
+(
+20000,
+constants.Expiration.TOKEN,
+constants.Expiration.LOCK_WARNING_THRESHOLD,
+),
+),
 (
 {"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
-(50000, 5600),
+(50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
+),
+(
+{"redis_token_expiration": 7600},
+(
+constants.Expiration.LOCK,
+7600,
+constants.Expiration.LOCK_WARNING_THRESHOLD,
+),
+),
+(
+{"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
+(50000, constants.Expiration.TOKEN, 1500),
+),
+(
+{"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
+(constants.Expiration.LOCK, 5600, 3000),
+),
+(
+{
+"redis_lock_expiration": 50000,
+"redis_token_expiration": 5600,
+"redis_lock_warning_threshold": 2000,
+},
+(50000, 5600, 2000),
 ),
-({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
 ],
 )
 def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
@@ -3281,6 +3365,44 @@ config = rx.Config(
 state_manager = StateManager.create(state=State)
 assert state_manager.lock_expiration == expected_values[0] # type: ignore
 assert state_manager.token_expiration == expected_values[1] # type: ignore
+assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize(
"redis_lock_expiration, redis_lock_warning_threshold",
[
(10000, 10000),
(20000, 30000),
],
)
def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
tmp_path, redis_lock_expiration, redis_lock_warning_threshold
):
proj_root = tmp_path / "project1"
proj_root.mkdir()
config_string = f"""
import reflex as rx
config = rx.Config(
app_name="project1",
redis_url="redis://localhost:6379",
state_manager_mode="redis",
redis_lock_expiration = {redis_lock_expiration},
redis_lock_warning_threshold = {redis_lock_warning_threshold},
)
"""
(proj_root / "rxconfig.py").write_text(dedent(config_string))
with chdir(proj_root):
# reload config for each parameter to avoid stale values
reflex.config.get_config(reload=True)
from reflex.state import State, StateManager
with pytest.raises(InvalidLockWarningThresholdError):
StateManager.create(state=State)
del sys.modules[constants.Config.MODULE]
 class MixinState(State, mixin=True):