Merge remote-tracking branch 'upstream/main' into redis-hash
This commit is contained in:
commit
f8bfc78f8f
6
.github/workflows/benchmarks.yml
vendored
6
.github/workflows/benchmarks.yml
vendored
@ -80,7 +80,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# Show OS combos first in GUI
|
||||
os: [ubuntu-latest, windows-latest, macos-12]
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
python-version: ['3.9.18', '3.10.13', '3.11.5', '3.12.0']
|
||||
exclude:
|
||||
- os: windows-latest
|
||||
@ -92,7 +92,7 @@ jobs:
|
||||
python-version: '3.9.18'
|
||||
- os: macos-latest
|
||||
python-version: '3.10.13'
|
||||
- os: macos-12
|
||||
- os: macos-latest
|
||||
python-version: '3.12.0'
|
||||
include:
|
||||
- os: windows-latest
|
||||
@ -155,7 +155,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# Show OS combos first in GUI
|
||||
os: [ubuntu-latest, windows-latest, macos-12]
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
python-version: ['3.11.5']
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
2
.github/workflows/integration_tests.yml
vendored
2
.github/workflows/integration_tests.yml
vendored
@ -198,7 +198,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ['3.11.5', '3.12.0']
|
||||
runs-on: macos-12
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: ./.github/actions/setup_build_env
|
||||
|
5
.github/workflows/unit_tests.yml
vendored
5
.github/workflows/unit_tests.yml
vendored
@ -88,8 +88,9 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ['3.9.18', '3.10.13', '3.11.5', '3.12.0', '3.13.0']
|
||||
runs-on: macos-12
|
||||
# Note: py39, py310 versions chosen due to available arm64 darwin builds.
|
||||
python-version: ['3.9.13', '3.10.11', '3.11.5', '3.12.0', '3.13.0']
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: ./.github/actions/setup_build_env
|
||||
|
@ -40,9 +40,6 @@ let event_processing = false;
|
||||
// Array holding pending events to be processed.
|
||||
const event_queue = [];
|
||||
|
||||
// Pending upload promises, by id
|
||||
const upload_controllers = {};
|
||||
|
||||
/**
|
||||
* Generate a UUID (Used for session tokens).
|
||||
* Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid
|
||||
@ -300,7 +297,7 @@ export const applyEvent = async (event, socket) => {
|
||||
if (socket) {
|
||||
socket.emit(
|
||||
"event",
|
||||
JSON.stringify(event, (k, v) => (v === undefined ? null : v))
|
||||
event,
|
||||
);
|
||||
return true;
|
||||
}
|
||||
@ -407,6 +404,8 @@ export const connect = async (
|
||||
transports: transports,
|
||||
autoUnref: false,
|
||||
});
|
||||
// Ensure undefined fields in events are sent as null instead of removed
|
||||
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
|
||||
|
||||
function checkVisibility() {
|
||||
if (document.visibilityState === "visible") {
|
||||
@ -443,8 +442,7 @@ export const connect = async (
|
||||
});
|
||||
|
||||
// On each received message, queue the updates and events.
|
||||
socket.current.on("event", async (message) => {
|
||||
const update = JSON5.parse(message);
|
||||
socket.current.on("event", async (update) => {
|
||||
for (const substate in update.delta) {
|
||||
dispatch[substate](update.delta[substate]);
|
||||
}
|
||||
@ -456,7 +454,7 @@ export const connect = async (
|
||||
});
|
||||
socket.current.on("reload", async (event) => {
|
||||
event_processing = false;
|
||||
queueEvents([...initialEvents(), JSON5.parse(event)], socket);
|
||||
queueEvents([...initialEvents(), event], socket);
|
||||
});
|
||||
|
||||
document.addEventListener("visibilitychange", checkVisibility);
|
||||
@ -485,7 +483,9 @@ export const uploadFiles = async (
|
||||
return false;
|
||||
}
|
||||
|
||||
if (upload_controllers[upload_id]) {
|
||||
const upload_ref_name = `__upload_controllers_${upload_id}`
|
||||
|
||||
if (refs[upload_ref_name]) {
|
||||
console.log("Upload already in progress for ", upload_id);
|
||||
return false;
|
||||
}
|
||||
@ -497,7 +497,9 @@ export const uploadFiles = async (
|
||||
// Whenever called, responseText will contain the entire response so far.
|
||||
const chunks = progressEvent.event.target.responseText.trim().split("\n");
|
||||
// So only process _new_ chunks beyond resp_idx.
|
||||
chunks.slice(resp_idx).map((chunk) => {
|
||||
chunks.slice(resp_idx).map((chunk_json) => {
|
||||
try {
|
||||
const chunk = JSON5.parse(chunk_json);
|
||||
event_callbacks.map((f, ix) => {
|
||||
f(chunk)
|
||||
.then(() => {
|
||||
@ -509,11 +511,17 @@ export const uploadFiles = async (
|
||||
.catch((e) => {
|
||||
if (progressEvent.progress === 1) {
|
||||
// Chunk may be incomplete, so only report errors when full response is available.
|
||||
console.log("Error parsing chunk", chunk, e);
|
||||
console.log("Error processing chunk", chunk, e);
|
||||
}
|
||||
return;
|
||||
});
|
||||
});
|
||||
} catch (e) {
|
||||
if (progressEvent.progress === 1) {
|
||||
console.log("Error parsing chunk", chunk_json, e);
|
||||
}
|
||||
return;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
@ -537,7 +545,7 @@ export const uploadFiles = async (
|
||||
});
|
||||
|
||||
// Send the file to the server.
|
||||
upload_controllers[upload_id] = controller;
|
||||
refs[upload_ref_name] = controller;
|
||||
|
||||
try {
|
||||
return await axios.post(getBackendURL(UPLOADURL), formdata, config);
|
||||
@ -557,7 +565,7 @@ export const uploadFiles = async (
|
||||
}
|
||||
return false;
|
||||
} finally {
|
||||
delete upload_controllers[upload_id];
|
||||
delete refs[upload_ref_name];
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -17,6 +17,7 @@ import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -363,6 +364,10 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
|
||||
ping_interval=constants.Ping.INTERVAL,
|
||||
ping_timeout=constants.Ping.TIMEOUT,
|
||||
json=SimpleNamespace(
|
||||
dumps=staticmethod(format.json_dumps),
|
||||
loads=staticmethod(json.loads),
|
||||
),
|
||||
transports=["websocket"],
|
||||
)
|
||||
elif getattr(self.sio, "async_mode", "") != "asgi":
|
||||
@ -1290,7 +1295,7 @@ async def process(
|
||||
await asyncio.create_task(
|
||||
app.event_namespace.emit(
|
||||
"reload",
|
||||
data=format.json_dumps(event),
|
||||
data=event,
|
||||
to=sid,
|
||||
)
|
||||
)
|
||||
@ -1543,7 +1548,7 @@ class EventNamespace(AsyncNamespace):
|
||||
"""
|
||||
# Creating a task prevents the update from being blocked behind other coroutines.
|
||||
await asyncio.create_task(
|
||||
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
||||
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
|
||||
)
|
||||
|
||||
async def on_event(self, sid, data):
|
||||
@ -1556,7 +1561,7 @@ class EventNamespace(AsyncNamespace):
|
||||
sid: The Socket.IO session id.
|
||||
data: The event data.
|
||||
"""
|
||||
fields = json.loads(data)
|
||||
fields = data
|
||||
# Get the event.
|
||||
event = Event(
|
||||
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}
|
||||
|
@ -29,7 +29,7 @@ from reflex.event import (
|
||||
from reflex.utils import format
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import CallableVar, LiteralVar, Var, get_unique_variable_name
|
||||
from reflex.vars.base import CallableVar, Var, get_unique_variable_name
|
||||
from reflex.vars.sequence import LiteralStringVar
|
||||
|
||||
DEFAULT_UPLOAD_ID: str = "default"
|
||||
@ -108,7 +108,8 @@ def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec:
|
||||
# UploadFilesProvider assigns a special function to clear selected files
|
||||
# into the shared global refs object to make it accessible outside a React
|
||||
# component via `run_script` (otherwise backend could never clear files).
|
||||
return run_script(f"refs['__clear_selected_files']({id_!r})")
|
||||
func = Var("__clear_selected_files")._as_ref()
|
||||
return run_script(f"{func}({id_!r})")
|
||||
|
||||
|
||||
def cancel_upload(upload_id: str) -> EventSpec:
|
||||
@ -120,7 +121,8 @@ def cancel_upload(upload_id: str) -> EventSpec:
|
||||
Returns:
|
||||
An event spec that cancels the upload when triggered.
|
||||
"""
|
||||
return run_script(f"upload_controllers[{LiteralVar.create(upload_id)!s}]?.abort()")
|
||||
controller = Var(f"__upload_controllers_{upload_id}")._as_ref()
|
||||
return run_script(f"{controller}?.abort()")
|
||||
|
||||
|
||||
def get_upload_dir() -> Path:
|
||||
|
@ -18,6 +18,7 @@ from reflex.event import (
|
||||
prevent_default,
|
||||
)
|
||||
from reflex.utils.imports import ImportDict
|
||||
from reflex.utils.types import is_optional
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
|
||||
@ -382,6 +383,33 @@ class Input(BaseHTML):
|
||||
# Fired when a key is released
|
||||
on_key_up: EventHandler[key_event]
|
||||
|
||||
@classmethod
|
||||
def create(cls, *children, **props):
|
||||
"""Create an Input component.
|
||||
|
||||
Args:
|
||||
*children: The children of the component.
|
||||
**props: The properties of the component.
|
||||
|
||||
Returns:
|
||||
The component.
|
||||
"""
|
||||
from reflex.vars.number import ternary_operation
|
||||
|
||||
value = props.get("value")
|
||||
|
||||
# React expects an empty string(instead of null) for controlled inputs.
|
||||
if value is not None and is_optional(
|
||||
(value_var := Var.create(value))._var_type
|
||||
):
|
||||
props["value"] = ternary_operation(
|
||||
(value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues]
|
||||
& (value_var != Var(_js_expr="undefined")),
|
||||
value,
|
||||
Var.create(""),
|
||||
)
|
||||
return super().create(*children, **props)
|
||||
|
||||
|
||||
class Label(BaseHTML):
|
||||
"""Display the label element."""
|
||||
|
@ -512,7 +512,7 @@ class Input(BaseHTML):
|
||||
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
|
||||
**props,
|
||||
) -> "Input":
|
||||
"""Create the component.
|
||||
"""Create an Input component.
|
||||
|
||||
Args:
|
||||
*children: The children of the component.
|
||||
@ -576,7 +576,7 @@ class Input(BaseHTML):
|
||||
class_name: The class name for the component.
|
||||
autofocus: Whether the component should take the focus once the page is loaded
|
||||
custom_attrs: custom attribute
|
||||
**props: The props of the component.
|
||||
**props: The properties of the component.
|
||||
|
||||
Returns:
|
||||
The component.
|
||||
|
@ -8,6 +8,7 @@ from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spe
|
||||
from reflex.vars.base import Var
|
||||
|
||||
from ..base import LiteralAccentColor, RadixThemesComponent
|
||||
from .checkbox import Checkbox
|
||||
|
||||
LiteralDirType = Literal["ltr", "rtl"]
|
||||
|
||||
@ -232,6 +233,15 @@ class ContextMenuSeparator(RadixThemesComponent):
|
||||
tag = "ContextMenu.Separator"
|
||||
|
||||
|
||||
class ContextMenuCheckbox(Checkbox):
|
||||
"""The component that contains the checkbox."""
|
||||
|
||||
tag = "ContextMenu.CheckboxItem"
|
||||
|
||||
# Text to render as shortcut.
|
||||
shortcut: Var[str]
|
||||
|
||||
|
||||
class ContextMenu(ComponentNamespace):
|
||||
"""Menu representing a set of actions, displayed at the origin of a pointer right-click or long-press."""
|
||||
|
||||
@ -243,6 +253,7 @@ class ContextMenu(ComponentNamespace):
|
||||
sub_content = staticmethod(ContextMenuSubContent.create)
|
||||
item = staticmethod(ContextMenuItem.create)
|
||||
separator = staticmethod(ContextMenuSeparator.create)
|
||||
checkbox = staticmethod(ContextMenuCheckbox.create)
|
||||
|
||||
|
||||
context_menu = ContextMenu()
|
||||
|
@ -12,6 +12,7 @@ from reflex.style import Style
|
||||
from reflex.vars.base import Var
|
||||
|
||||
from ..base import RadixThemesComponent
|
||||
from .checkbox import Checkbox
|
||||
|
||||
LiteralDirType = Literal["ltr", "rtl"]
|
||||
LiteralSizeType = Literal["1", "2"]
|
||||
@ -672,6 +673,159 @@ class ContextMenuSeparator(RadixThemesComponent):
|
||||
"""
|
||||
...
|
||||
|
||||
class ContextMenuCheckbox(Checkbox):
|
||||
@overload
|
||||
@classmethod
|
||||
def create( # type: ignore
|
||||
cls,
|
||||
*children,
|
||||
shortcut: Optional[Union[Var[str], str]] = None,
|
||||
as_child: Optional[Union[Var[bool], bool]] = None,
|
||||
size: Optional[
|
||||
Union[
|
||||
Breakpoints[str, Literal["1", "2", "3"]],
|
||||
Literal["1", "2", "3"],
|
||||
Var[
|
||||
Union[
|
||||
Breakpoints[str, Literal["1", "2", "3"]], Literal["1", "2", "3"]
|
||||
]
|
||||
],
|
||||
]
|
||||
] = None,
|
||||
variant: Optional[
|
||||
Union[
|
||||
Literal["classic", "soft", "surface"],
|
||||
Var[Literal["classic", "soft", "surface"]],
|
||||
]
|
||||
] = None,
|
||||
color_scheme: Optional[
|
||||
Union[
|
||||
Literal[
|
||||
"amber",
|
||||
"blue",
|
||||
"bronze",
|
||||
"brown",
|
||||
"crimson",
|
||||
"cyan",
|
||||
"gold",
|
||||
"grass",
|
||||
"gray",
|
||||
"green",
|
||||
"indigo",
|
||||
"iris",
|
||||
"jade",
|
||||
"lime",
|
||||
"mint",
|
||||
"orange",
|
||||
"pink",
|
||||
"plum",
|
||||
"purple",
|
||||
"red",
|
||||
"ruby",
|
||||
"sky",
|
||||
"teal",
|
||||
"tomato",
|
||||
"violet",
|
||||
"yellow",
|
||||
],
|
||||
Var[
|
||||
Literal[
|
||||
"amber",
|
||||
"blue",
|
||||
"bronze",
|
||||
"brown",
|
||||
"crimson",
|
||||
"cyan",
|
||||
"gold",
|
||||
"grass",
|
||||
"gray",
|
||||
"green",
|
||||
"indigo",
|
||||
"iris",
|
||||
"jade",
|
||||
"lime",
|
||||
"mint",
|
||||
"orange",
|
||||
"pink",
|
||||
"plum",
|
||||
"purple",
|
||||
"red",
|
||||
"ruby",
|
||||
"sky",
|
||||
"teal",
|
||||
"tomato",
|
||||
"violet",
|
||||
"yellow",
|
||||
]
|
||||
],
|
||||
]
|
||||
] = None,
|
||||
high_contrast: Optional[Union[Var[bool], bool]] = None,
|
||||
default_checked: Optional[Union[Var[bool], bool]] = None,
|
||||
checked: Optional[Union[Var[bool], bool]] = None,
|
||||
disabled: Optional[Union[Var[bool], bool]] = None,
|
||||
required: Optional[Union[Var[bool], bool]] = None,
|
||||
name: Optional[Union[Var[str], str]] = None,
|
||||
value: Optional[Union[Var[str], str]] = None,
|
||||
style: Optional[Style] = None,
|
||||
key: Optional[Any] = None,
|
||||
id: Optional[Any] = None,
|
||||
class_name: Optional[Any] = None,
|
||||
autofocus: Optional[bool] = None,
|
||||
custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None,
|
||||
on_blur: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_change: Optional[
|
||||
Union[EventType[[], BASE_STATE], EventType[[bool], BASE_STATE]]
|
||||
] = None,
|
||||
on_click: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_context_menu: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_double_click: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_focus: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mount: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mouse_down: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mouse_move: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mouse_out: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mouse_over: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_mouse_up: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_scroll: Optional[EventType[[], BASE_STATE]] = None,
|
||||
on_unmount: Optional[EventType[[], BASE_STATE]] = None,
|
||||
**props,
|
||||
) -> "ContextMenuCheckbox":
|
||||
"""Create a new component instance.
|
||||
|
||||
Will prepend "RadixThemes" to the component tag to avoid conflicts with
|
||||
other UI libraries for common names, like Text and Button.
|
||||
|
||||
Args:
|
||||
*children: Child components.
|
||||
shortcut: Text to render as shortcut.
|
||||
as_child: Change the default rendered element for the one passed as a child, merging their props and behavior.
|
||||
size: Checkbox size "1" - "3"
|
||||
variant: Variant of checkbox: "classic" | "surface" | "soft"
|
||||
color_scheme: Override theme color for checkbox
|
||||
high_contrast: Whether to render the checkbox with higher contrast color against background
|
||||
default_checked: Whether the checkbox is checked by default
|
||||
checked: Whether the checkbox is checked
|
||||
disabled: Whether the checkbox is disabled
|
||||
required: Whether the checkbox is required
|
||||
name: The name of the checkbox control when submitting the form.
|
||||
value: The value of the checkbox control when submitting the form.
|
||||
on_change: Fired when the checkbox is checked or unchecked.
|
||||
style: The style of the component.
|
||||
key: A unique key for the component.
|
||||
id: The id for the component.
|
||||
class_name: The class name for the component.
|
||||
autofocus: Whether the component should take the focus once the page is loaded
|
||||
custom_attrs: custom attribute
|
||||
**props: Component properties.
|
||||
|
||||
Returns:
|
||||
A new component instance.
|
||||
"""
|
||||
...
|
||||
|
||||
class ContextMenu(ComponentNamespace):
|
||||
root = staticmethod(ContextMenuRoot.create)
|
||||
trigger = staticmethod(ContextMenuTrigger.create)
|
||||
@ -681,5 +835,6 @@ class ContextMenu(ComponentNamespace):
|
||||
sub_content = staticmethod(ContextMenuSubContent.create)
|
||||
item = staticmethod(ContextMenuItem.create)
|
||||
separator = staticmethod(ContextMenuSeparator.create)
|
||||
checkbox = staticmethod(ContextMenuCheckbox.create)
|
||||
|
||||
context_menu = ContextMenu()
|
||||
|
@ -9,7 +9,9 @@ from reflex.components.core.breakpoints import Responsive
|
||||
from reflex.components.core.debounce import DebounceInput
|
||||
from reflex.components.el import elements
|
||||
from reflex.event import EventHandler, input_event, key_event
|
||||
from reflex.utils.types import is_optional
|
||||
from reflex.vars.base import Var
|
||||
from reflex.vars.number import ternary_operation
|
||||
|
||||
from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent
|
||||
|
||||
@ -96,6 +98,19 @@ class TextFieldRoot(elements.Div, RadixThemesComponent):
|
||||
Returns:
|
||||
The component.
|
||||
"""
|
||||
value = props.get("value")
|
||||
|
||||
# React expects an empty string(instead of null) for controlled inputs.
|
||||
if value is not None and is_optional(
|
||||
(value_var := Var.create(value))._var_type
|
||||
):
|
||||
props["value"] = ternary_operation(
|
||||
(value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues]
|
||||
& (value_var != Var(_js_expr="undefined")),
|
||||
value,
|
||||
Var.create(""),
|
||||
)
|
||||
|
||||
component = super().create(*children, **props)
|
||||
if props.get("value") is not None and props.get("on_change") is not None:
|
||||
# create a debounced input if the user requests full control to avoid typing jank
|
||||
|
@ -684,6 +684,9 @@ class Config(Base):
|
||||
# Maximum expiration lock time for redis state manager
|
||||
redis_lock_expiration: int = constants.Expiration.LOCK
|
||||
|
||||
# Maximum lock time before warning for redis state manager.
|
||||
redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD
|
||||
|
||||
# Token expiration time for redis state manager
|
||||
redis_token_expiration: int = constants.Expiration.TOKEN
|
||||
|
||||
|
@ -29,6 +29,8 @@ class Expiration(SimpleNamespace):
|
||||
LOCK = 10000
|
||||
# The PING timeout
|
||||
PING = 120
|
||||
# The maximum time in milliseconds to hold a lock before throwing a warning.
|
||||
LOCK_WARNING_THRESHOLD = 1000
|
||||
|
||||
|
||||
class GitIgnore(SimpleNamespace):
|
||||
|
119
reflex/state.py
119
reflex/state.py
@ -11,6 +11,7 @@ import inspect
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@ -39,6 +40,7 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from redis.asyncio.client import PubSub
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from typing_extensions import Self
|
||||
|
||||
@ -69,6 +71,11 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
BaseModelV1 = BaseModelV2
|
||||
|
||||
try:
|
||||
from pydantic.v1 import validator
|
||||
except ModuleNotFoundError:
|
||||
from pydantic import validator
|
||||
|
||||
import wrapt
|
||||
from redis.asyncio import Redis
|
||||
from redis.exceptions import ResponseError
|
||||
@ -92,6 +99,7 @@ from reflex.utils.exceptions import (
|
||||
DynamicRouteArgShadowsStateVar,
|
||||
EventHandlerShadowsBuiltInStateMethod,
|
||||
ImmutableStateError,
|
||||
InvalidLockWarningThresholdError,
|
||||
InvalidStateManagerMode,
|
||||
LockExpiredError,
|
||||
ReflexRuntimeError,
|
||||
@ -1107,6 +1115,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if (
|
||||
not field.required
|
||||
and field.default is None
|
||||
and field.default_factory is None
|
||||
and not types.is_optional(prop._var_type)
|
||||
):
|
||||
# Ensure frontend uses null coalescing when accessing.
|
||||
@ -2173,14 +2182,26 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
state["__dict__"].pop("router", None)
|
||||
state["__dict__"].pop("router_data", None)
|
||||
# Never serialize parent_state or substates.
|
||||
state["__dict__"]["parent_state"] = None
|
||||
state["__dict__"]["substates"] = {}
|
||||
state["__dict__"].pop("parent_state", None)
|
||||
state["__dict__"].pop("substates", None)
|
||||
state["__dict__"].pop("_was_touched", None)
|
||||
# Remove all inherited vars.
|
||||
for inherited_var_name in self.inherited_vars:
|
||||
state["__dict__"].pop(inherited_var_name, None)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]):
|
||||
"""Set the state from redis deserialization.
|
||||
|
||||
This method is called by pickle to deserialize the object.
|
||||
|
||||
Args:
|
||||
state: The state dict for deserialization.
|
||||
"""
|
||||
state["__dict__"]["parent_state"] = None
|
||||
state["__dict__"]["substates"] = {}
|
||||
super().__setstate__(state)
|
||||
|
||||
def _check_state_size(
|
||||
self,
|
||||
pickle_state_size: int,
|
||||
@ -2870,6 +2891,7 @@ class StateManager(Base, ABC):
|
||||
redis=redis,
|
||||
token_expiration=config.redis_token_expiration,
|
||||
lock_expiration=config.redis_lock_expiration,
|
||||
lock_warning_threshold=config.redis_lock_warning_threshold,
|
||||
)
|
||||
raise InvalidStateManagerMode(
|
||||
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
|
||||
@ -3239,6 +3261,15 @@ def _default_lock_expiration() -> int:
|
||||
return get_config().redis_lock_expiration
|
||||
|
||||
|
||||
def _default_lock_warning_threshold() -> int:
|
||||
"""Get the default lock warning threshold.
|
||||
|
||||
Returns:
|
||||
The default lock warning threshold.
|
||||
"""
|
||||
return get_config().redis_lock_warning_threshold
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
"""A state manager that stores states in redis."""
|
||||
|
||||
@ -3251,6 +3282,11 @@ class StateManagerRedis(StateManager):
|
||||
# The maximum time to hold a lock (ms).
|
||||
lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration)
|
||||
|
||||
# The maximum time to hold a lock (ms) before warning.
|
||||
lock_warning_threshold: int = pydantic.Field(
|
||||
default_factory=_default_lock_warning_threshold
|
||||
)
|
||||
|
||||
# If HEXPIRE is not supported, use EXPIRE instead.
|
||||
_hexpire_not_supported: Optional[bool] = pydantic.PrivateAttr(None)
|
||||
|
||||
@ -3413,6 +3449,17 @@ class StateManagerRedis(StateManager):
|
||||
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
||||
"or use `@rx.event(background=True)` decorator for long-running tasks."
|
||||
)
|
||||
elif lock_id is not None:
|
||||
time_taken = self.lock_expiration / 1000 - (
|
||||
await self.redis.ttl(self._lock_key(token))
|
||||
)
|
||||
if time_taken > self.lock_warning_threshold / 1000:
|
||||
console.warn(
|
||||
f"Lock for token {token} was held too long {time_taken=}s, "
|
||||
f"use `@rx.event(background=True)` decorator for long-running tasks.",
|
||||
dedupe=True,
|
||||
)
|
||||
|
||||
client_token, substate_name = _split_substate_key(token)
|
||||
# If the substate name on the token doesn't match the instance name, it cannot have a parent.
|
||||
if state.parent_state is not None and state.get_full_name() != substate_name:
|
||||
@ -3477,6 +3524,27 @@ class StateManagerRedis(StateManager):
|
||||
yield state
|
||||
await self.set_state(token, state, lock_id)
|
||||
|
||||
@validator("lock_warning_threshold")
|
||||
@classmethod
|
||||
def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values):
|
||||
"""Validate the lock warning threshold.
|
||||
|
||||
Args:
|
||||
lock_warning_threshold: The lock warning threshold.
|
||||
values: The validated attributes.
|
||||
|
||||
Returns:
|
||||
The lock warning threshold.
|
||||
|
||||
Raises:
|
||||
InvalidLockWarningThresholdError: If the lock warning threshold is invalid.
|
||||
"""
|
||||
if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]):
|
||||
raise InvalidLockWarningThresholdError(
|
||||
f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})."
|
||||
)
|
||||
return lock_warning_threshold
|
||||
|
||||
@staticmethod
|
||||
def _lock_key(token: str) -> bytes:
|
||||
"""Get the redis key for a token's lock.
|
||||
@ -3508,6 +3576,35 @@ class StateManagerRedis(StateManager):
|
||||
nx=True, # only set if it doesn't exist
|
||||
)
|
||||
|
||||
async def _get_pubsub_message(
|
||||
self, pubsub: PubSub, timeout: float | None = None
|
||||
) -> None:
|
||||
"""Get lock release events from the pubsub.
|
||||
|
||||
Args:
|
||||
pubsub: The pubsub to get a message from.
|
||||
timeout: Remaining time to wait for a message.
|
||||
|
||||
Returns:
|
||||
The message.
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = self.lock_expiration / 1000.0
|
||||
|
||||
started = time.time()
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
if (
|
||||
message is None
|
||||
or message["data"] not in self._redis_keyspace_lock_release_events
|
||||
):
|
||||
remaining = timeout - (time.time() - started)
|
||||
if remaining <= 0:
|
||||
return
|
||||
await self._get_pubsub_message(pubsub, timeout=remaining)
|
||||
|
||||
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
||||
"""Wait for a redis lock to be released via pubsub.
|
||||
|
||||
@ -3520,7 +3617,6 @@ class StateManagerRedis(StateManager):
|
||||
Raises:
|
||||
ResponseError: when the keyspace config cannot be set.
|
||||
"""
|
||||
state_is_locked = False
|
||||
lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
|
||||
# Enable keyspace notifications for the lock key, so we know when it is available.
|
||||
try:
|
||||
@ -3534,20 +3630,13 @@ class StateManagerRedis(StateManager):
|
||||
raise
|
||||
async with self.redis.pubsub() as pubsub:
|
||||
await pubsub.psubscribe(lock_key_channel)
|
||||
while not state_is_locked:
|
||||
# wait for the lock to be released
|
||||
while True:
|
||||
if not await self.redis.exists(lock_key):
|
||||
break # key was removed, try to get the lock again
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=self.lock_expiration / 1000.0,
|
||||
)
|
||||
if message is None:
|
||||
continue
|
||||
if message["data"] in self._redis_keyspace_lock_release_events:
|
||||
break
|
||||
state_is_locked = await self._try_get_lock(lock_key, lock_id)
|
||||
# fast path
|
||||
if await self._try_get_lock(lock_key, lock_id):
|
||||
return
|
||||
# wait for lock events
|
||||
await self._get_pubsub_message(pubsub)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _lock(self, token: str):
|
||||
|
@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set()
|
||||
# Info messages which have been printed.
|
||||
_EMITTED_INFO = set()
|
||||
|
||||
# Warnings which have been printed.
|
||||
_EMIITED_WARNINGS = set()
|
||||
|
||||
# Errors which have been printed.
|
||||
_EMITTED_ERRORS = set()
|
||||
|
||||
# Success messages which have been printed.
|
||||
_EMITTED_SUCCESS = set()
|
||||
|
||||
# Debug messages which have been printed.
|
||||
_EMITTED_DEBUG = set()
|
||||
|
||||
# Logs which have been printed.
|
||||
_EMITTED_LOGS = set()
|
||||
|
||||
# Prints which have been printed.
|
||||
_EMITTED_PRINTS = set()
|
||||
|
||||
|
||||
def set_log_level(log_level: LogLevel):
|
||||
"""Set the log level.
|
||||
@ -55,25 +73,37 @@ def is_debug() -> bool:
|
||||
return _LOG_LEVEL <= LogLevel.DEBUG
|
||||
|
||||
|
||||
def print(msg: str, **kwargs):
|
||||
def print(msg: str, dedupe: bool = False, **kwargs):
|
||||
"""Print a message.
|
||||
|
||||
Args:
|
||||
msg: The message to print.
|
||||
dedupe: If True, suppress multiple console logs of print message.
|
||||
kwargs: Keyword arguments to pass to the print function.
|
||||
"""
|
||||
if dedupe:
|
||||
if msg in _EMITTED_PRINTS:
|
||||
return
|
||||
else:
|
||||
_EMITTED_PRINTS.add(msg)
|
||||
_console.print(msg, **kwargs)
|
||||
|
||||
|
||||
def debug(msg: str, **kwargs):
|
||||
def debug(msg: str, dedupe: bool = False, **kwargs):
|
||||
"""Print a debug message.
|
||||
|
||||
Args:
|
||||
msg: The debug message.
|
||||
dedupe: If True, suppress multiple console logs of debug message.
|
||||
kwargs: Keyword arguments to pass to the print function.
|
||||
"""
|
||||
if is_debug():
|
||||
msg_ = f"[purple]Debug: {msg}[/purple]"
|
||||
if dedupe:
|
||||
if msg_ in _EMITTED_DEBUG:
|
||||
return
|
||||
else:
|
||||
_EMITTED_DEBUG.add(msg_)
|
||||
if progress := kwargs.pop("progress", None):
|
||||
progress.console.print(msg_, **kwargs)
|
||||
else:
|
||||
@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs):
|
||||
print(f"[cyan]Info: {msg}[/cyan]", **kwargs)
|
||||
|
||||
|
||||
def success(msg: str, **kwargs):
|
||||
def success(msg: str, dedupe: bool = False, **kwargs):
|
||||
"""Print a success message.
|
||||
|
||||
Args:
|
||||
msg: The success message.
|
||||
dedupe: If True, suppress multiple console logs of success message.
|
||||
kwargs: Keyword arguments to pass to the print function.
|
||||
"""
|
||||
if _LOG_LEVEL <= LogLevel.INFO:
|
||||
if dedupe:
|
||||
if msg in _EMITTED_SUCCESS:
|
||||
return
|
||||
else:
|
||||
_EMITTED_SUCCESS.add(msg)
|
||||
print(f"[green]Success: {msg}[/green]", **kwargs)
|
||||
|
||||
|
||||
def log(msg: str, **kwargs):
|
||||
def log(msg: str, dedupe: bool = False, **kwargs):
|
||||
"""Takes a string and logs it to the console.
|
||||
|
||||
Args:
|
||||
msg: The message to log.
|
||||
dedupe: If True, suppress multiple console logs of log message.
|
||||
kwargs: Keyword arguments to pass to the print function.
|
||||
"""
|
||||
if _LOG_LEVEL <= LogLevel.INFO:
|
||||
if dedupe:
|
||||
if msg in _EMITTED_LOGS:
|
||||
return
|
||||
else:
|
||||
_EMITTED_LOGS.add(msg)
|
||||
_console.log(msg, **kwargs)
|
||||
|
||||
|
||||
@ -129,14 +171,20 @@ def rule(title: str, **kwargs):
|
||||
_console.rule(title, **kwargs)
|
||||
|
||||
|
||||
def warn(msg: str, **kwargs):
|
||||
def warn(msg: str, dedupe: bool = False, **kwargs):
|
||||
"""Print a warning message.
|
||||
|
||||
Args:
|
||||
msg: The warning message.
|
||||
dedupe: If True, suppress multiple console logs of warning message.
|
||||
kwargs: Keyword arguments to pass to the print function.
|
||||
"""
|
||||
if _LOG_LEVEL <= LogLevel.WARNING:
|
||||
if dedupe:
|
||||
if msg in _EMIITED_WARNINGS:
|
||||
return
|
||||
else:
|
||||
_EMIITED_WARNINGS.add(msg)
|
||||
print(f"[orange1]Warning: {msg}[/orange1]", **kwargs)
|
||||
|
||||
|
||||
@ -169,14 +217,20 @@ def deprecate(
|
||||
_EMITTED_DEPRECATION_WARNINGS.add(feature_name)
|
||||
|
||||
|
||||
def error(msg: str, **kwargs):
|
||||
def error(msg: str, dedupe: bool = False, **kwargs):
|
||||
"""Print an error message.
|
||||
|
||||
Args:
|
||||
msg: The error message.
|
||||
dedupe: If True, suppress multiple console logs of error message.
|
||||
kwargs: Keyword arguments to pass to the print function.
|
||||
"""
|
||||
if _LOG_LEVEL <= LogLevel.ERROR:
|
||||
if dedupe:
|
||||
if msg in _EMITTED_ERRORS:
|
||||
return
|
||||
else:
|
||||
_EMITTED_ERRORS.add(msg)
|
||||
print(f"[red]{msg}[/red]", **kwargs)
|
||||
|
||||
|
||||
|
@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn:
|
||||
" Please install it through your system package manager."
|
||||
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
|
||||
)
|
||||
|
||||
|
||||
class InvalidLockWarningThresholdError(ReflexError):
|
||||
"""Raised when an invalid lock warning threshold is provided."""
|
||||
|
@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
|
||||
return lib
|
||||
|
||||
|
||||
def json_dumps(obj: Any) -> str:
|
||||
def json_dumps(obj: Any, **kwargs) -> str:
|
||||
"""Takes an object and returns a jsonified string.
|
||||
|
||||
Args:
|
||||
obj: The object to be serialized.
|
||||
kwargs: Additional keyword arguments to pass to json.dumps.
|
||||
|
||||
Returns:
|
||||
A string
|
||||
"""
|
||||
from reflex.utils import serializers
|
||||
|
||||
return json.dumps(obj, ensure_ascii=False, default=serializers.serialize)
|
||||
kwargs.setdefault("ensure_ascii", False)
|
||||
kwargs.setdefault("default", serializers.serialize)
|
||||
|
||||
return json.dumps(obj, **kwargs)
|
||||
|
||||
|
||||
def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
|
@ -6,6 +6,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import reflex.app
|
||||
from reflex.config import environment
|
||||
from reflex.testing import AppHarness, AppHarnessProd
|
||||
|
||||
@ -76,3 +77,25 @@ def app_harness_env(request):
|
||||
The AppHarness class to use for the test.
|
||||
"""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def raise_console_error(request, mocker):
|
||||
"""Spy on calls to `console.error` used by the framework.
|
||||
|
||||
Help catch spurious error conditions that might otherwise go unnoticed.
|
||||
|
||||
If a test is marked with `ignore_console_error`, the spy will be ignored
|
||||
after the test.
|
||||
|
||||
Args:
|
||||
request: The pytest request object.
|
||||
mocker: The pytest mocker object.
|
||||
|
||||
Yields:
|
||||
control to the test function.
|
||||
"""
|
||||
spy = mocker.spy(reflex.app.console, "error")
|
||||
yield
|
||||
if "ignore_console_error" not in request.keywords:
|
||||
spy.assert_not_called()
|
||||
|
@ -628,8 +628,7 @@ async def test_client_side_state(
|
||||
assert await AppHarness._poll_for_async(poll_for_not_hydrated)
|
||||
|
||||
# Trigger event to get a new instance of the state since the old was expired.
|
||||
state_var_input = driver.find_element(By.ID, "state_var")
|
||||
state_var_input.send_keys("re-triggering")
|
||||
set_sub("c1", "c1 post expire")
|
||||
|
||||
# get new references to all cookie and local storage elements (again)
|
||||
c1 = driver.find_element(By.ID, "c1")
|
||||
@ -650,7 +649,7 @@ async def test_client_side_state(
|
||||
l1s = driver.find_element(By.ID, "l1s")
|
||||
s1s = driver.find_element(By.ID, "s1s")
|
||||
|
||||
assert c1.text == "c1 value"
|
||||
assert c1.text == "c1 post expire"
|
||||
assert c2.text == "c2 value"
|
||||
assert c3.text == "" # temporary cookie expired after reset state!
|
||||
assert c4.text == "c4 value"
|
||||
@ -680,11 +679,11 @@ async def test_client_side_state(
|
||||
|
||||
async def poll_for_c1_set():
|
||||
sub_state = await get_sub_state()
|
||||
return sub_state.c1 == "c1 value"
|
||||
return sub_state.c1 == "c1 post expire"
|
||||
|
||||
assert await AppHarness._poll_for_async(poll_for_c1_set)
|
||||
sub_state = await get_sub_state()
|
||||
assert sub_state.c1 == "c1 value"
|
||||
assert sub_state.c1 == "c1 post expire"
|
||||
assert sub_state.c2 == "c2 value"
|
||||
assert sub_state.c3 == ""
|
||||
assert sub_state.c4 == "c4 value"
|
||||
|
@ -13,6 +13,8 @@ from selenium.webdriver.support.ui import WebDriverWait
|
||||
|
||||
from reflex.testing import AppHarness, AppHarnessProd
|
||||
|
||||
pytestmark = [pytest.mark.ignore_console_error]
|
||||
|
||||
|
||||
def TestApp():
|
||||
"""A test app for event exception handler integration."""
|
||||
|
@ -381,9 +381,22 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
|
||||
await asyncio.sleep(0.3)
|
||||
cancel_button.click()
|
||||
|
||||
# look up the backend state and assert on progress
|
||||
# Wait a bit for the upload to get cancelled.
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get interim progress dicts saved in the on_upload_progress handler.
|
||||
async def _progress_dicts():
|
||||
state = await upload_file.get_state(substate_token)
|
||||
return state.substates[state_name].progress_dicts
|
||||
|
||||
# We should have _some_ progress
|
||||
assert await AppHarness._poll_for_async(_progress_dicts)
|
||||
|
||||
# But there should never be a final progress record for a cancelled upload.
|
||||
for p in await _progress_dicts():
|
||||
assert p["progress"] != 1
|
||||
|
||||
state = await upload_file.get_state(substate_token)
|
||||
assert state.substates[state_name].progress_dicts
|
||||
file_data = state.substates[state_name]._file_data
|
||||
assert isinstance(file_data, dict)
|
||||
normalized_file_data = {Path(k).name: v for k, v in file_data.items()}
|
||||
|
@ -1,8 +1,10 @@
|
||||
from typing import Dict, List, Set, Tuple, Union
|
||||
|
||||
import pydantic.v1
|
||||
import pytest
|
||||
|
||||
from reflex import el
|
||||
from reflex.base import Base
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.foreach import (
|
||||
Foreach,
|
||||
@ -18,6 +20,12 @@ from reflex.vars.number import NumberVar
|
||||
from reflex.vars.sequence import ArrayVar
|
||||
|
||||
|
||||
class ForEachTag(Base):
|
||||
"""A tag for testing the ForEach component."""
|
||||
|
||||
name: str = ""
|
||||
|
||||
|
||||
class ForEachState(BaseState):
|
||||
"""A state for testing the ForEach component."""
|
||||
|
||||
@ -46,6 +54,8 @@ class ForEachState(BaseState):
|
||||
bad_annotation_list: list = [["red", "orange"], ["yellow", "blue"]]
|
||||
color_index_tuple: Tuple[int, str] = (0, "red")
|
||||
|
||||
default_factory_list: list[ForEachTag] = pydantic.v1.Field(default_factory=list)
|
||||
|
||||
|
||||
class ComponentStateTest(ComponentState):
|
||||
"""A test component state."""
|
||||
@ -290,3 +300,11 @@ def test_foreach_component_state():
|
||||
ForEachState.colors_list,
|
||||
ComponentStateTest.create,
|
||||
)
|
||||
|
||||
|
||||
def test_foreach_default_factory():
|
||||
"""Test that the default factory is called."""
|
||||
_ = Foreach.create(
|
||||
ForEachState.default_factory_list,
|
||||
lambda tag: text(tag.name),
|
||||
)
|
||||
|
@ -56,6 +56,7 @@ from reflex.state import (
|
||||
from reflex.testing import chdir
|
||||
from reflex.utils import format, prerequisites, types
|
||||
from reflex.utils.exceptions import (
|
||||
InvalidLockWarningThresholdError,
|
||||
ReflexRuntimeError,
|
||||
SetUndefinedStateVarError,
|
||||
StateSerializationError,
|
||||
@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState
|
||||
from .states import GenState
|
||||
|
||||
CI = bool(os.environ.get("CI", False))
|
||||
LOCK_EXPIRATION = 2000 if CI else 300
|
||||
LOCK_EXPIRATION = 2500 if CI else 300
|
||||
LOCK_WARNING_THRESHOLD = 1000 if CI else 100
|
||||
LOCK_WARN_SLEEP = 1.5 if CI else 0.15
|
||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
||||
|
||||
|
||||
@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire(
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
|
||||
async with state_manager_redis.modify_state(substate_token_redis):
|
||||
await asyncio.sleep(0.01)
|
||||
@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend(
|
||||
unexp_num1 = 666
|
||||
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
|
||||
order = []
|
||||
|
||||
@ -1840,6 +1845,57 @@ async def test_state_manager_lock_expire_contend(
|
||||
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_lock_warning_threshold_contend(
|
||||
state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker
|
||||
):
|
||||
"""Test that the state manager triggers a warning when lock contention exceeds the warning threshold.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
token: A token.
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
mocker: Pytest mocker object.
|
||||
"""
|
||||
console_warn = mocker.patch("reflex.utils.console.warn")
|
||||
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
|
||||
order = []
|
||||
|
||||
async def _coro_blocker():
|
||||
async with state_manager_redis.modify_state(substate_token_redis):
|
||||
order.append("blocker")
|
||||
await asyncio.sleep(LOCK_WARN_SLEEP)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(_coro_blocker()),
|
||||
]
|
||||
|
||||
await tasks[0]
|
||||
console_warn.assert_called()
|
||||
assert console_warn.call_count == 7
|
||||
|
||||
|
||||
class CopyingAsyncMock(AsyncMock):
|
||||
"""An AsyncMock, but deepcopy the args and kwargs first."""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Call the mock.
|
||||
|
||||
Args:
|
||||
args: the arguments passed to the mock
|
||||
kwargs: the keyword arguments passed to the mock
|
||||
|
||||
Returns:
|
||||
The result of the mock call
|
||||
"""
|
||||
args = copy.deepcopy(args)
|
||||
kwargs = copy.deepcopy(kwargs)
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_app_simple(monkeypatch) -> rx.App:
|
||||
"""Simple Mock app fixture.
|
||||
@ -1856,7 +1912,7 @@ def mock_app_simple(monkeypatch) -> rx.App:
|
||||
|
||||
setattr(app_module, CompileVars.APP, app)
|
||||
app.state = TestState
|
||||
app.event_namespace.emit = AsyncMock() # type: ignore
|
||||
app.event_namespace.emit = CopyingAsyncMock() # type: ignore
|
||||
|
||||
def _mock_get_app(*args, **kwargs):
|
||||
return app_module
|
||||
@ -1960,8 +2016,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
mock_app.event_namespace.emit.assert_called_once()
|
||||
mcall = mock_app.event_namespace.emit.mock_calls[0]
|
||||
assert mcall.args[0] == str(SocketEvent.EVENT)
|
||||
assert json.loads(mcall.args[1]) == dataclasses.asdict(
|
||||
StateUpdate(
|
||||
assert mcall.args[1] == StateUpdate(
|
||||
delta={
|
||||
parent_state.get_full_name(): {
|
||||
"upper": "",
|
||||
@ -1975,7 +2030,6 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
|
||||
|
||||
|
||||
@ -2156,51 +2210,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||
assert mock_app.event_namespace is not None
|
||||
emit_mock = mock_app.event_namespace.emit
|
||||
|
||||
first_ws_message = json.loads(emit_mock.mock_calls[0].args[1])
|
||||
first_ws_message = emit_mock.mock_calls[0].args[1]
|
||||
assert (
|
||||
first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router")
|
||||
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
|
||||
is not None
|
||||
)
|
||||
assert first_ws_message == {
|
||||
"delta": {
|
||||
assert first_ws_message == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_full_name(): {
|
||||
"order": ["background_task:start"],
|
||||
"computed_order": ["background_task:start"],
|
||||
}
|
||||
},
|
||||
"events": [],
|
||||
"final": True,
|
||||
}
|
||||
events=[],
|
||||
final=True,
|
||||
)
|
||||
for call in emit_mock.mock_calls[1:5]:
|
||||
assert json.loads(call.args[1]) == {
|
||||
"delta": {
|
||||
assert call.args[1] == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_full_name(): {
|
||||
"computed_order": ["background_task:start"],
|
||||
}
|
||||
},
|
||||
"events": [],
|
||||
"final": True,
|
||||
}
|
||||
assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
|
||||
"delta": {
|
||||
events=[],
|
||||
final=True,
|
||||
)
|
||||
assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_full_name(): {
|
||||
"order": exp_order,
|
||||
"computed_order": exp_order,
|
||||
"dict_list": {},
|
||||
}
|
||||
},
|
||||
"events": [],
|
||||
"final": True,
|
||||
}
|
||||
assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
|
||||
"delta": {
|
||||
events=[],
|
||||
final=True,
|
||||
)
|
||||
assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_full_name(): {
|
||||
"computed_order": exp_order,
|
||||
},
|
||||
},
|
||||
"events": [],
|
||||
"final": True,
|
||||
}
|
||||
events=[],
|
||||
final=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -3246,12 +3300,42 @@ async def test_setvar_async_setter():
|
||||
@pytest.mark.parametrize(
|
||||
"expiration_kwargs, expected_values",
|
||||
[
|
||||
({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
|
||||
(
|
||||
{"redis_lock_expiration": 20000},
|
||||
(
|
||||
20000,
|
||||
constants.Expiration.TOKEN,
|
||||
constants.Expiration.LOCK_WARNING_THRESHOLD,
|
||||
),
|
||||
),
|
||||
(
|
||||
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
|
||||
(50000, 5600),
|
||||
(50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD),
|
||||
),
|
||||
(
|
||||
{"redis_token_expiration": 7600},
|
||||
(
|
||||
constants.Expiration.LOCK,
|
||||
7600,
|
||||
constants.Expiration.LOCK_WARNING_THRESHOLD,
|
||||
),
|
||||
),
|
||||
(
|
||||
{"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500},
|
||||
(50000, constants.Expiration.TOKEN, 1500),
|
||||
),
|
||||
(
|
||||
{"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000},
|
||||
(constants.Expiration.LOCK, 5600, 3000),
|
||||
),
|
||||
(
|
||||
{
|
||||
"redis_lock_expiration": 50000,
|
||||
"redis_token_expiration": 5600,
|
||||
"redis_lock_warning_threshold": 2000,
|
||||
},
|
||||
(50000, 5600, 2000),
|
||||
),
|
||||
({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
|
||||
],
|
||||
)
|
||||
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
|
||||
@ -3281,6 +3365,44 @@ config = rx.Config(
|
||||
state_manager = StateManager.create(state=State)
|
||||
assert state_manager.lock_expiration == expected_values[0] # type: ignore
|
||||
assert state_manager.token_expiration == expected_values[1] # type: ignore
|
||||
assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
|
||||
@pytest.mark.parametrize(
|
||||
"redis_lock_expiration, redis_lock_warning_threshold",
|
||||
[
|
||||
(10000, 10000),
|
||||
(20000, 30000),
|
||||
],
|
||||
)
|
||||
def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold(
|
||||
tmp_path, redis_lock_expiration, redis_lock_warning_threshold
|
||||
):
|
||||
proj_root = tmp_path / "project1"
|
||||
proj_root.mkdir()
|
||||
|
||||
config_string = f"""
|
||||
import reflex as rx
|
||||
config = rx.Config(
|
||||
app_name="project1",
|
||||
redis_url="redis://localhost:6379",
|
||||
state_manager_mode="redis",
|
||||
redis_lock_expiration = {redis_lock_expiration},
|
||||
redis_lock_warning_threshold = {redis_lock_warning_threshold},
|
||||
)
|
||||
"""
|
||||
|
||||
(proj_root / "rxconfig.py").write_text(dedent(config_string))
|
||||
|
||||
with chdir(proj_root):
|
||||
# reload config for each parameter to avoid stale values
|
||||
reflex.config.get_config(reload=True)
|
||||
from reflex.state import State, StateManager
|
||||
|
||||
with pytest.raises(InvalidLockWarningThresholdError):
|
||||
StateManager.create(state=State)
|
||||
del sys.modules[constants.Config.MODULE]
|
||||
|
||||
|
||||
class MixinState(State, mixin=True):
|
||||
|
Loading…
Reference in New Issue
Block a user