diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index aac67f7a6..f743b7cbd 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -80,7 +80,7 @@ jobs: fail-fast: false matrix: # Show OS combos first in GUI - os: [ubuntu-latest, windows-latest, macos-12] + os: [ubuntu-latest, windows-latest, macos-latest] python-version: ['3.9.18', '3.10.13', '3.11.5', '3.12.0'] exclude: - os: windows-latest @@ -92,7 +92,7 @@ jobs: python-version: '3.9.18' - os: macos-latest python-version: '3.10.13' - - os: macos-12 + - os: macos-latest python-version: '3.12.0' include: - os: windows-latest @@ -155,7 +155,7 @@ jobs: fail-fast: false matrix: # Show OS combos first in GUI - os: [ubuntu-latest, windows-latest, macos-12] + os: [ubuntu-latest, windows-latest, macos-latest] python-version: ['3.11.5'] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 2ed68ad9f..b2304d463 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -198,7 +198,7 @@ jobs: fail-fast: false matrix: python-version: ['3.11.5', '3.12.0'] - runs-on: macos-12 + runs-on: macos-latest steps: - uses: actions/checkout@v4 - uses: ./.github/actions/setup_build_env diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index a6e39354c..25f5723f3 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -88,8 +88,9 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.9.18', '3.10.13', '3.11.5', '3.12.0', '3.13.0'] - runs-on: macos-12 + # Note: py39, py310 versions chosen due to available arm64 darwin builds. + python-version: ['3.9.13', '3.10.11', '3.11.5', '3.12.0', '3.13.0'] + runs-on: macos-latest steps: - uses: actions/checkout@v4 - uses: ./.github/actions/setup_build_env diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index e135c7c0b..608df084a 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -40,9 +40,6 @@ let event_processing = false; // Array holding pending events to be processed. const event_queue = []; -// Pending upload promises, by id -const upload_controllers = {}; - /** * Generate a UUID (Used for session tokens). * Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid @@ -300,7 +297,7 @@ export const applyEvent = async (event, socket) => { if (socket) { socket.emit( "event", - JSON.stringify(event, (k, v) => (v === undefined ? null : v)) + event, ); return true; } @@ -407,6 +404,8 @@ export const connect = async ( transports: transports, autoUnref: false, }); + // Ensure undefined fields in events are sent as null instead of removed + socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v) function checkVisibility() { if (document.visibilityState === "visible") { @@ -443,8 +442,7 @@ export const connect = async ( }); // On each received message, queue the updates and events. - socket.current.on("event", async (message) => { - const update = JSON5.parse(message); + socket.current.on("event", async (update) => { for (const substate in update.delta) { dispatch[substate](update.delta[substate]); } @@ -456,7 +454,7 @@ export const connect = async ( }); socket.current.on("reload", async (event) => { event_processing = false; - queueEvents([...initialEvents(), JSON5.parse(event)], socket); + queueEvents([...initialEvents(), event], socket); }); document.addEventListener("visibilitychange", checkVisibility); @@ -485,7 +483,9 @@ export const uploadFiles = async ( return false; } - if (upload_controllers[upload_id]) { + const upload_ref_name = `__upload_controllers_${upload_id}` + + if (refs[upload_ref_name]) { console.log("Upload already in progress for ", upload_id); return false; } @@ -497,23 +497,31 @@ export const uploadFiles = async ( // Whenever called, responseText will contain the entire response so far. const chunks = progressEvent.event.target.responseText.trim().split("\n"); // So only process _new_ chunks beyond resp_idx. - chunks.slice(resp_idx).map((chunk) => { - event_callbacks.map((f, ix) => { - f(chunk) - .then(() => { - if (ix === event_callbacks.length - 1) { - // Mark this chunk as processed. - resp_idx += 1; - } - }) - .catch((e) => { - if (progressEvent.progress === 1) { - // Chunk may be incomplete, so only report errors when full response is available. - console.log("Error parsing chunk", chunk, e); - } - return; - }); - }); + chunks.slice(resp_idx).map((chunk_json) => { + try { + const chunk = JSON5.parse(chunk_json); + event_callbacks.map((f, ix) => { + f(chunk) + .then(() => { + if (ix === event_callbacks.length - 1) { + // Mark this chunk as processed. + resp_idx += 1; + } + }) + .catch((e) => { + if (progressEvent.progress === 1) { + // Chunk may be incomplete, so only report errors when full response is available. + console.log("Error processing chunk", chunk, e); + } + return; + }); + }); + } catch (e) { + if (progressEvent.progress === 1) { + console.log("Error parsing chunk", chunk_json, e); + } + return; + } }); }; @@ -537,7 +545,7 @@ export const uploadFiles = async ( }); // Send the file to the server. - upload_controllers[upload_id] = controller; + refs[upload_ref_name] = controller; try { return await axios.post(getBackendURL(UPLOADURL), formdata, config); @@ -557,7 +565,7 @@ export const uploadFiles = async ( } return false; } finally { - delete upload_controllers[upload_id]; + delete refs[upload_ref_name]; } }; diff --git a/reflex/app.py b/reflex/app.py index 42808823a..10dd889b3 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -17,6 +17,7 @@ import sys import traceback from datetime import datetime from pathlib import Path +from types import SimpleNamespace from typing import ( TYPE_CHECKING, Any, @@ -363,6 +364,10 @@ class App(MiddlewareMixin, LifespanMixin): max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, ping_interval=constants.Ping.INTERVAL, ping_timeout=constants.Ping.TIMEOUT, + json=SimpleNamespace( + dumps=staticmethod(format.json_dumps), + loads=staticmethod(json.loads), + ), transports=["websocket"], ) elif getattr(self.sio, "async_mode", "") != "asgi": @@ -1290,7 +1295,7 @@ async def process( await asyncio.create_task( app.event_namespace.emit( "reload", - data=format.json_dumps(event), + data=event, to=sid, ) ) @@ -1543,7 +1548,7 @@ class EventNamespace(AsyncNamespace): """ # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) + self.emit(str(constants.SocketEvent.EVENT), update, to=sid) ) async def on_event(self, sid, data): @@ -1556,7 +1561,7 @@ class EventNamespace(AsyncNamespace): sid: The Socket.IO session id. data: The event data. """ - fields = json.loads(data) + fields = data # Get the event. event = Event( **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index b5b701d6d..14205cc6b 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -29,7 +29,7 @@ from reflex.event import ( from reflex.utils import format from reflex.utils.imports import ImportVar from reflex.vars import VarData -from reflex.vars.base import CallableVar, LiteralVar, Var, get_unique_variable_name +from reflex.vars.base import CallableVar, Var, get_unique_variable_name from reflex.vars.sequence import LiteralStringVar DEFAULT_UPLOAD_ID: str = "default" @@ -108,7 +108,8 @@ def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: # UploadFilesProvider assigns a special function to clear selected files # into the shared global refs object to make it accessible outside a React # component via `run_script` (otherwise backend could never clear files). - return run_script(f"refs['__clear_selected_files']({id_!r})") + func = Var("__clear_selected_files")._as_ref() + return run_script(f"{func}({id_!r})") def cancel_upload(upload_id: str) -> EventSpec: @@ -120,7 +121,8 @@ def cancel_upload(upload_id: str) -> EventSpec: Returns: An event spec that cancels the upload when triggered. """ - return run_script(f"upload_controllers[{LiteralVar.create(upload_id)!s}]?.abort()") + controller = Var(f"__upload_controllers_{upload_id}")._as_ref() + return run_script(f"{controller}?.abort()") def get_upload_dir() -> Path: diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 205aae267..61ded4fd3 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -18,6 +18,7 @@ from reflex.event import ( prevent_default, ) from reflex.utils.imports import ImportDict +from reflex.utils.types import is_optional from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -382,6 +383,33 @@ class Input(BaseHTML): # Fired when a key is released on_key_up: EventHandler[key_event] + @classmethod + def create(cls, *children, **props): + """Create an Input component. + + Args: + *children: The children of the component. + **props: The properties of the component. + + Returns: + The component. + """ + from reflex.vars.number import ternary_operation + + value = props.get("value") + + # React expects an empty string(instead of null) for controlled inputs. + if value is not None and is_optional( + (value_var := Var.create(value))._var_type + ): + props["value"] = ternary_operation( + (value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues] + & (value_var != Var(_js_expr="undefined")), + value, + Var.create(""), + ) + return super().create(*children, **props) + class Label(BaseHTML): """Display the label element.""" diff --git a/reflex/components/el/elements/forms.pyi b/reflex/components/el/elements/forms.pyi index 5870d4b22..dfab40b21 100644 --- a/reflex/components/el/elements/forms.pyi +++ b/reflex/components/el/elements/forms.pyi @@ -512,7 +512,7 @@ class Input(BaseHTML): on_unmount: Optional[EventType[[], BASE_STATE]] = None, **props, ) -> "Input": - """Create the component. + """Create an Input component. Args: *children: The children of the component. @@ -576,7 +576,7 @@ class Input(BaseHTML): class_name: The class name for the component. autofocus: Whether the component should take the focus once the page is loaded custom_attrs: custom attribute - **props: The props of the component. + **props: The properties of the component. Returns: The component. diff --git a/reflex/components/radix/themes/components/context_menu.py b/reflex/components/radix/themes/components/context_menu.py index ea4902233..f8512a902 100644 --- a/reflex/components/radix/themes/components/context_menu.py +++ b/reflex/components/radix/themes/components/context_menu.py @@ -8,6 +8,7 @@ from reflex.event import EventHandler, no_args_event_spec, passthrough_event_spe from reflex.vars.base import Var from ..base import LiteralAccentColor, RadixThemesComponent +from .checkbox import Checkbox LiteralDirType = Literal["ltr", "rtl"] @@ -232,6 +233,15 @@ class ContextMenuSeparator(RadixThemesComponent): tag = "ContextMenu.Separator" +class ContextMenuCheckbox(Checkbox): + """The component that contains the checkbox.""" + + tag = "ContextMenu.CheckboxItem" + + # Text to render as shortcut. + shortcut: Var[str] + + class ContextMenu(ComponentNamespace): """Menu representing a set of actions, displayed at the origin of a pointer right-click or long-press.""" @@ -243,6 +253,7 @@ class ContextMenu(ComponentNamespace): sub_content = staticmethod(ContextMenuSubContent.create) item = staticmethod(ContextMenuItem.create) separator = staticmethod(ContextMenuSeparator.create) + checkbox = staticmethod(ContextMenuCheckbox.create) context_menu = ContextMenu() diff --git a/reflex/components/radix/themes/components/context_menu.pyi b/reflex/components/radix/themes/components/context_menu.pyi index c5ef757d1..2d3ffbebc 100644 --- a/reflex/components/radix/themes/components/context_menu.pyi +++ b/reflex/components/radix/themes/components/context_menu.pyi @@ -12,6 +12,7 @@ from reflex.style import Style from reflex.vars.base import Var from ..base import RadixThemesComponent +from .checkbox import Checkbox LiteralDirType = Literal["ltr", "rtl"] LiteralSizeType = Literal["1", "2"] @@ -672,6 +673,159 @@ class ContextMenuSeparator(RadixThemesComponent): """ ... +class ContextMenuCheckbox(Checkbox): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + shortcut: Optional[Union[Var[str], str]] = None, + as_child: Optional[Union[Var[bool], bool]] = None, + size: Optional[ + Union[ + Breakpoints[str, Literal["1", "2", "3"]], + Literal["1", "2", "3"], + Var[ + Union[ + Breakpoints[str, Literal["1", "2", "3"]], Literal["1", "2", "3"] + ] + ], + ] + ] = None, + variant: Optional[ + Union[ + Literal["classic", "soft", "surface"], + Var[Literal["classic", "soft", "surface"]], + ] + ] = None, + color_scheme: Optional[ + Union[ + Literal[ + "amber", + "blue", + "bronze", + "brown", + "crimson", + "cyan", + "gold", + "grass", + "gray", + "green", + "indigo", + "iris", + "jade", + "lime", + "mint", + "orange", + "pink", + "plum", + "purple", + "red", + "ruby", + "sky", + "teal", + "tomato", + "violet", + "yellow", + ], + Var[ + Literal[ + "amber", + "blue", + "bronze", + "brown", + "crimson", + "cyan", + "gold", + "grass", + "gray", + "green", + "indigo", + "iris", + "jade", + "lime", + "mint", + "orange", + "pink", + "plum", + "purple", + "red", + "ruby", + "sky", + "teal", + "tomato", + "violet", + "yellow", + ] + ], + ] + ] = None, + high_contrast: Optional[Union[Var[bool], bool]] = None, + default_checked: Optional[Union[Var[bool], bool]] = None, + checked: Optional[Union[Var[bool], bool]] = None, + disabled: Optional[Union[Var[bool], bool]] = None, + required: Optional[Union[Var[bool], bool]] = None, + name: Optional[Union[Var[str], str]] = None, + value: Optional[Union[Var[str], str]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, + on_blur: Optional[EventType[[], BASE_STATE]] = None, + on_change: Optional[ + Union[EventType[[], BASE_STATE], EventType[[bool], BASE_STATE]] + ] = None, + on_click: Optional[EventType[[], BASE_STATE]] = None, + on_context_menu: Optional[EventType[[], BASE_STATE]] = None, + on_double_click: Optional[EventType[[], BASE_STATE]] = None, + on_focus: Optional[EventType[[], BASE_STATE]] = None, + on_mount: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_down: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_enter: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_leave: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_move: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_out: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_over: Optional[EventType[[], BASE_STATE]] = None, + on_mouse_up: Optional[EventType[[], BASE_STATE]] = None, + on_scroll: Optional[EventType[[], BASE_STATE]] = None, + on_unmount: Optional[EventType[[], BASE_STATE]] = None, + **props, + ) -> "ContextMenuCheckbox": + """Create a new component instance. + + Will prepend "RadixThemes" to the component tag to avoid conflicts with + other UI libraries for common names, like Text and Button. + + Args: + *children: Child components. + shortcut: Text to render as shortcut. + as_child: Change the default rendered element for the one passed as a child, merging their props and behavior. + size: Checkbox size "1" - "3" + variant: Variant of checkbox: "classic" | "surface" | "soft" + color_scheme: Override theme color for checkbox + high_contrast: Whether to render the checkbox with higher contrast color against background + default_checked: Whether the checkbox is checked by default + checked: Whether the checkbox is checked + disabled: Whether the checkbox is disabled + required: Whether the checkbox is required + name: The name of the checkbox control when submitting the form. + value: The value of the checkbox control when submitting the form. + on_change: Fired when the checkbox is checked or unchecked. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: Component properties. + + Returns: + A new component instance. + """ + ... + class ContextMenu(ComponentNamespace): root = staticmethod(ContextMenuRoot.create) trigger = staticmethod(ContextMenuTrigger.create) @@ -681,5 +835,6 @@ class ContextMenu(ComponentNamespace): sub_content = staticmethod(ContextMenuSubContent.create) item = staticmethod(ContextMenuItem.create) separator = staticmethod(ContextMenuSeparator.create) + checkbox = staticmethod(ContextMenuCheckbox.create) context_menu = ContextMenu() diff --git a/reflex/components/radix/themes/components/text_field.py b/reflex/components/radix/themes/components/text_field.py index 3dabe0936..7e6dfe85c 100644 --- a/reflex/components/radix/themes/components/text_field.py +++ b/reflex/components/radix/themes/components/text_field.py @@ -9,7 +9,9 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.core.debounce import DebounceInput from reflex.components.el import elements from reflex.event import EventHandler, input_event, key_event +from reflex.utils.types import is_optional from reflex.vars.base import Var +from reflex.vars.number import ternary_operation from ..base import LiteralAccentColor, LiteralRadius, RadixThemesComponent @@ -96,6 +98,19 @@ class TextFieldRoot(elements.Div, RadixThemesComponent): Returns: The component. """ + value = props.get("value") + + # React expects an empty string(instead of null) for controlled inputs. + if value is not None and is_optional( + (value_var := Var.create(value))._var_type + ): + props["value"] = ternary_operation( + (value_var != Var.create(None)) # pyright: ignore [reportGeneralTypeIssues] + & (value_var != Var(_js_expr="undefined")), + value, + Var.create(""), + ) + component = super().create(*children, **props) if props.get("value") is not None and props.get("on_change") is not None: # create a debounced input if the user requests full control to avoid typing jank diff --git a/reflex/config.py b/reflex/config.py index ae2c0ea0e..bbea6a5d0 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -684,6 +684,9 @@ class Config(Base): # Maximum expiration lock time for redis state manager redis_lock_expiration: int = constants.Expiration.LOCK + # Maximum lock time before warning for redis state manager. + redis_lock_warning_threshold: int = constants.Expiration.LOCK_WARNING_THRESHOLD + # Token expiration time for redis state manager redis_token_expiration: int = constants.Expiration.TOKEN diff --git a/reflex/constants/config.py b/reflex/constants/config.py index 970e67844..7425fd864 100644 --- a/reflex/constants/config.py +++ b/reflex/constants/config.py @@ -29,6 +29,8 @@ class Expiration(SimpleNamespace): LOCK = 10000 # The PING timeout PING = 120 + # The maximum time in milliseconds to hold a lock before throwing a warning. + LOCK_WARNING_THRESHOLD = 1000 class GitIgnore(SimpleNamespace): diff --git a/reflex/state.py b/reflex/state.py index e113f6cb9..5d1f9df73 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -11,6 +11,7 @@ import inspect import json import pickle import sys +import time import typing import uuid from abc import ABC, abstractmethod @@ -39,6 +40,7 @@ from typing import ( get_type_hints, ) +from redis.asyncio.client import PubSub from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self @@ -69,6 +71,11 @@ try: except ModuleNotFoundError: BaseModelV1 = BaseModelV2 +try: + from pydantic.v1 import validator +except ModuleNotFoundError: + from pydantic import validator + import wrapt from redis.asyncio import Redis from redis.exceptions import ResponseError @@ -92,6 +99,7 @@ from reflex.utils.exceptions import ( DynamicRouteArgShadowsStateVar, EventHandlerShadowsBuiltInStateMethod, ImmutableStateError, + InvalidLockWarningThresholdError, InvalidStateManagerMode, LockExpiredError, ReflexRuntimeError, @@ -1107,6 +1115,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if ( not field.required and field.default is None + and field.default_factory is None and not types.is_optional(prop._var_type) ): # Ensure frontend uses null coalescing when accessing. @@ -2173,14 +2182,26 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): state["__dict__"].pop("router", None) state["__dict__"].pop("router_data", None) # Never serialize parent_state or substates. - state["__dict__"]["parent_state"] = None - state["__dict__"]["substates"] = {} + state["__dict__"].pop("parent_state", None) + state["__dict__"].pop("substates", None) state["__dict__"].pop("_was_touched", None) # Remove all inherited vars. for inherited_var_name in self.inherited_vars: state["__dict__"].pop(inherited_var_name, None) return state + def __setstate__(self, state: dict[str, Any]): + """Set the state from redis deserialization. + + This method is called by pickle to deserialize the object. + + Args: + state: The state dict for deserialization. + """ + state["__dict__"]["parent_state"] = None + state["__dict__"]["substates"] = {} + super().__setstate__(state) + def _check_state_size( self, pickle_state_size: int, @@ -2870,6 +2891,7 @@ class StateManager(Base, ABC): redis=redis, token_expiration=config.redis_token_expiration, lock_expiration=config.redis_lock_expiration, + lock_warning_threshold=config.redis_lock_warning_threshold, ) raise InvalidStateManagerMode( f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" @@ -3239,6 +3261,15 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration +def _default_lock_warning_threshold() -> int: + """Get the default lock warning threshold. + + Returns: + The default lock warning threshold. + """ + return get_config().redis_lock_warning_threshold + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3251,6 +3282,11 @@ class StateManagerRedis(StateManager): # The maximum time to hold a lock (ms). lock_expiration: int = pydantic.Field(default_factory=_default_lock_expiration) + # The maximum time to hold a lock (ms) before warning. + lock_warning_threshold: int = pydantic.Field( + default_factory=_default_lock_warning_threshold + ) + # If HEXPIRE is not supported, use EXPIRE instead. _hexpire_not_supported: Optional[bool] = pydantic.PrivateAttr(None) @@ -3413,6 +3449,17 @@ class StateManagerRedis(StateManager): f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " "or use `@rx.event(background=True)` decorator for long-running tasks." ) + elif lock_id is not None: + time_taken = self.lock_expiration / 1000 - ( + await self.redis.ttl(self._lock_key(token)) + ) + if time_taken > self.lock_warning_threshold / 1000: + console.warn( + f"Lock for token {token} was held too long {time_taken=}s, " + f"use `@rx.event(background=True)` decorator for long-running tasks.", + dedupe=True, + ) + client_token, substate_name = _split_substate_key(token) # If the substate name on the token doesn't match the instance name, it cannot have a parent. if state.parent_state is not None and state.get_full_name() != substate_name: @@ -3477,6 +3524,27 @@ class StateManagerRedis(StateManager): yield state await self.set_state(token, state, lock_id) + @validator("lock_warning_threshold") + @classmethod + def validate_lock_warning_threshold(cls, lock_warning_threshold: int, values): + """Validate the lock warning threshold. + + Args: + lock_warning_threshold: The lock warning threshold. + values: The validated attributes. + + Returns: + The lock warning threshold. + + Raises: + InvalidLockWarningThresholdError: If the lock warning threshold is invalid. + """ + if lock_warning_threshold >= (lock_expiration := values["lock_expiration"]): + raise InvalidLockWarningThresholdError( + f"The lock warning threshold({lock_warning_threshold}) must be less than the lock expiration time({lock_expiration})." + ) + return lock_warning_threshold + @staticmethod def _lock_key(token: str) -> bytes: """Get the redis key for a token's lock. @@ -3508,6 +3576,35 @@ class StateManagerRedis(StateManager): nx=True, # only set if it doesn't exist ) + async def _get_pubsub_message( + self, pubsub: PubSub, timeout: float | None = None + ) -> None: + """Get lock release events from the pubsub. + + Args: + pubsub: The pubsub to get a message from. + timeout: Remaining time to wait for a message. + + Returns: + The message. + """ + if timeout is None: + timeout = self.lock_expiration / 1000.0 + + started = time.time() + message = await pubsub.get_message( + ignore_subscribe_messages=True, + timeout=timeout, + ) + if ( + message is None + or message["data"] not in self._redis_keyspace_lock_release_events + ): + remaining = timeout - (time.time() - started) + if remaining <= 0: + return + await self._get_pubsub_message(pubsub, timeout=remaining) + async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: """Wait for a redis lock to be released via pubsub. @@ -3520,7 +3617,6 @@ class StateManagerRedis(StateManager): Raises: ResponseError: when the keyspace config cannot be set. """ - state_is_locked = False lock_key_channel = f"__keyspace@0__:{lock_key.decode()}" # Enable keyspace notifications for the lock key, so we know when it is available. try: @@ -3534,20 +3630,13 @@ class StateManagerRedis(StateManager): raise async with self.redis.pubsub() as pubsub: await pubsub.psubscribe(lock_key_channel) - while not state_is_locked: - # wait for the lock to be released - while True: - if not await self.redis.exists(lock_key): - break # key was removed, try to get the lock again - message = await pubsub.get_message( - ignore_subscribe_messages=True, - timeout=self.lock_expiration / 1000.0, - ) - if message is None: - continue - if message["data"] in self._redis_keyspace_lock_release_events: - break - state_is_locked = await self._try_get_lock(lock_key, lock_id) + # wait for the lock to be released + while True: + # fast path + if await self._try_get_lock(lock_key, lock_id): + return + # wait for lock events + await self._get_pubsub_message(pubsub) @contextlib.asynccontextmanager async def _lock(self, token: str): diff --git a/reflex/utils/console.py b/reflex/utils/console.py index b3ba7163d..be545140a 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -20,6 +20,24 @@ _EMITTED_DEPRECATION_WARNINGS = set() # Info messages which have been printed. _EMITTED_INFO = set() +# Warnings which have been printed. +_EMIITED_WARNINGS = set() + +# Errors which have been printed. +_EMITTED_ERRORS = set() + +# Success messages which have been printed. +_EMITTED_SUCCESS = set() + +# Debug messages which have been printed. +_EMITTED_DEBUG = set() + +# Logs which have been printed. +_EMITTED_LOGS = set() + +# Prints which have been printed. +_EMITTED_PRINTS = set() + def set_log_level(log_level: LogLevel): """Set the log level. @@ -55,25 +73,37 @@ def is_debug() -> bool: return _LOG_LEVEL <= LogLevel.DEBUG -def print(msg: str, **kwargs): +def print(msg: str, dedupe: bool = False, **kwargs): """Print a message. Args: msg: The message to print. + dedupe: If True, suppress multiple console logs of print message. kwargs: Keyword arguments to pass to the print function. """ + if dedupe: + if msg in _EMITTED_PRINTS: + return + else: + _EMITTED_PRINTS.add(msg) _console.print(msg, **kwargs) -def debug(msg: str, **kwargs): +def debug(msg: str, dedupe: bool = False, **kwargs): """Print a debug message. Args: msg: The debug message. + dedupe: If True, suppress multiple console logs of debug message. kwargs: Keyword arguments to pass to the print function. """ if is_debug(): msg_ = f"[purple]Debug: {msg}[/purple]" + if dedupe: + if msg_ in _EMITTED_DEBUG: + return + else: + _EMITTED_DEBUG.add(msg_) if progress := kwargs.pop("progress", None): progress.console.print(msg_, **kwargs) else: @@ -97,25 +127,37 @@ def info(msg: str, dedupe: bool = False, **kwargs): print(f"[cyan]Info: {msg}[/cyan]", **kwargs) -def success(msg: str, **kwargs): +def success(msg: str, dedupe: bool = False, **kwargs): """Print a success message. Args: msg: The success message. + dedupe: If True, suppress multiple console logs of success message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.INFO: + if dedupe: + if msg in _EMITTED_SUCCESS: + return + else: + _EMITTED_SUCCESS.add(msg) print(f"[green]Success: {msg}[/green]", **kwargs) -def log(msg: str, **kwargs): +def log(msg: str, dedupe: bool = False, **kwargs): """Takes a string and logs it to the console. Args: msg: The message to log. + dedupe: If True, suppress multiple console logs of log message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.INFO: + if dedupe: + if msg in _EMITTED_LOGS: + return + else: + _EMITTED_LOGS.add(msg) _console.log(msg, **kwargs) @@ -129,14 +171,20 @@ def rule(title: str, **kwargs): _console.rule(title, **kwargs) -def warn(msg: str, **kwargs): +def warn(msg: str, dedupe: bool = False, **kwargs): """Print a warning message. Args: msg: The warning message. + dedupe: If True, suppress multiple console logs of warning message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.WARNING: + if dedupe: + if msg in _EMIITED_WARNINGS: + return + else: + _EMIITED_WARNINGS.add(msg) print(f"[orange1]Warning: {msg}[/orange1]", **kwargs) @@ -169,14 +217,20 @@ def deprecate( _EMITTED_DEPRECATION_WARNINGS.add(feature_name) -def error(msg: str, **kwargs): +def error(msg: str, dedupe: bool = False, **kwargs): """Print an error message. Args: msg: The error message. + dedupe: If True, suppress multiple console logs of error message. kwargs: Keyword arguments to pass to the print function. """ if _LOG_LEVEL <= LogLevel.ERROR: + if dedupe: + if msg in _EMITTED_ERRORS: + return + else: + _EMITTED_ERRORS.add(msg) print(f"[red]{msg}[/red]", **kwargs) diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 6c378e159..ae5ec0168 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -183,3 +183,7 @@ def raise_system_package_missing_error(package: str) -> NoReturn: " Please install it through your system package manager." + (f" You can do so by running 'brew install {package}'." if IS_MACOS else "") ) + + +class InvalidLockWarningThresholdError(ReflexError): + """Raised when an invalid lock warning threshold is provided.""" diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 0159a17c3..1d6671a0b 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -664,18 +664,22 @@ def format_library_name(library_fullname: str): return lib -def json_dumps(obj: Any) -> str: +def json_dumps(obj: Any, **kwargs) -> str: """Takes an object and returns a jsonified string. Args: obj: The object to be serialized. + kwargs: Additional keyword arguments to pass to json.dumps. Returns: A string """ from reflex.utils import serializers - return json.dumps(obj, ensure_ascii=False, default=serializers.serialize) + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("default", serializers.serialize) + + return json.dumps(obj, **kwargs) def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f7b825f16..d11344903 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +import reflex.app from reflex.config import environment from reflex.testing import AppHarness, AppHarnessProd @@ -76,3 +77,25 @@ def app_harness_env(request): The AppHarness class to use for the test. """ return request.param + + +@pytest.fixture(autouse=True) +def raise_console_error(request, mocker): + """Spy on calls to `console.error` used by the framework. + + Help catch spurious error conditions that might otherwise go unnoticed. + + If a test is marked with `ignore_console_error`, the spy will be ignored + after the test. + + Args: + request: The pytest request object. + mocker: The pytest mocker object. + + Yields: + control to the test function. + """ + spy = mocker.spy(reflex.app.console, "error") + yield + if "ignore_console_error" not in request.keywords: + spy.assert_not_called() diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 649b236a4..0fd523957 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -628,8 +628,7 @@ async def test_client_side_state( assert await AppHarness._poll_for_async(poll_for_not_hydrated) # Trigger event to get a new instance of the state since the old was expired. - state_var_input = driver.find_element(By.ID, "state_var") - state_var_input.send_keys("re-triggering") + set_sub("c1", "c1 post expire") # get new references to all cookie and local storage elements (again) c1 = driver.find_element(By.ID, "c1") @@ -650,7 +649,7 @@ async def test_client_side_state( l1s = driver.find_element(By.ID, "l1s") s1s = driver.find_element(By.ID, "s1s") - assert c1.text == "c1 value" + assert c1.text == "c1 post expire" assert c2.text == "c2 value" assert c3.text == "" # temporary cookie expired after reset state! assert c4.text == "c4 value" @@ -680,11 +679,11 @@ async def test_client_side_state( async def poll_for_c1_set(): sub_state = await get_sub_state() - return sub_state.c1 == "c1 value" + return sub_state.c1 == "c1 post expire" assert await AppHarness._poll_for_async(poll_for_c1_set) sub_state = await get_sub_state() - assert sub_state.c1 == "c1 value" + assert sub_state.c1 == "c1 post expire" assert sub_state.c2 == "c2 value" assert sub_state.c3 == "" assert sub_state.c4 == "c4 value" diff --git a/tests/integration/test_exception_handlers.py b/tests/integration/test_exception_handlers.py index 406c21e5d..a645d1de6 100644 --- a/tests/integration/test_exception_handlers.py +++ b/tests/integration/test_exception_handlers.py @@ -13,6 +13,8 @@ from selenium.webdriver.support.ui import WebDriverWait from reflex.testing import AppHarness, AppHarnessProd +pytestmark = [pytest.mark.ignore_console_error] + def TestApp(): """A test app for event exception handler integration.""" diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index b7f14b03d..156cf0e45 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -381,9 +381,22 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive await asyncio.sleep(0.3) cancel_button.click() - # look up the backend state and assert on progress + # Wait a bit for the upload to get cancelled. + await asyncio.sleep(0.5) + + # Get interim progress dicts saved in the on_upload_progress handler. + async def _progress_dicts(): + state = await upload_file.get_state(substate_token) + return state.substates[state_name].progress_dicts + + # We should have _some_ progress + assert await AppHarness._poll_for_async(_progress_dicts) + + # But there should never be a final progress record for a cancelled upload. + for p in await _progress_dicts(): + assert p["progress"] != 1 + state = await upload_file.get_state(substate_token) - assert state.substates[state_name].progress_dicts file_data = state.substates[state_name]._file_data assert isinstance(file_data, dict) normalized_file_data = {Path(k).name: v for k, v in file_data.items()} diff --git a/tests/units/components/core/test_foreach.py b/tests/units/components/core/test_foreach.py index 228165d3e..094f6029d 100644 --- a/tests/units/components/core/test_foreach.py +++ b/tests/units/components/core/test_foreach.py @@ -1,8 +1,10 @@ from typing import Dict, List, Set, Tuple, Union +import pydantic.v1 import pytest from reflex import el +from reflex.base import Base from reflex.components.component import Component from reflex.components.core.foreach import ( Foreach, @@ -18,6 +20,12 @@ from reflex.vars.number import NumberVar from reflex.vars.sequence import ArrayVar +class ForEachTag(Base): + """A tag for testing the ForEach component.""" + + name: str = "" + + class ForEachState(BaseState): """A state for testing the ForEach component.""" @@ -46,6 +54,8 @@ class ForEachState(BaseState): bad_annotation_list: list = [["red", "orange"], ["yellow", "blue"]] color_index_tuple: Tuple[int, str] = (0, "red") + default_factory_list: list[ForEachTag] = pydantic.v1.Field(default_factory=list) + class ComponentStateTest(ComponentState): """A test component state.""" @@ -290,3 +300,11 @@ def test_foreach_component_state(): ForEachState.colors_list, ComponentStateTest.create, ) + + +def test_foreach_default_factory(): + """Test that the default factory is called.""" + _ = Foreach.create( + ForEachState.default_factory_list, + lambda tag: text(tag.name), + ) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index a1db3838d..1e17fc653 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -56,6 +56,7 @@ from reflex.state import ( from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.exceptions import ( + InvalidLockWarningThresholdError, ReflexRuntimeError, SetUndefinedStateVarError, StateSerializationError, @@ -67,7 +68,9 @@ from tests.units.states.mutation import MutableSQLAModel, MutableTestState from .states import GenState CI = bool(os.environ.get("CI", False)) -LOCK_EXPIRATION = 2000 if CI else 300 +LOCK_EXPIRATION = 2500 if CI else 300 +LOCK_WARNING_THRESHOLD = 1000 if CI else 100 +LOCK_WARN_SLEEP = 1.5 if CI else 0.15 LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 @@ -1787,6 +1790,7 @@ async def test_state_manager_lock_expire( substate_token_redis: A token + substate name for looking up in state manager. """ state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD async with state_manager_redis.modify_state(substate_token_redis): await asyncio.sleep(0.01) @@ -1811,6 +1815,7 @@ async def test_state_manager_lock_expire_contend( unexp_num1 = 666 state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD order = [] @@ -1840,6 +1845,57 @@ async def test_state_manager_lock_expire_contend( assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +@pytest.mark.asyncio +async def test_state_manager_lock_warning_threshold_contend( + state_manager_redis: StateManager, token: str, substate_token_redis: str, mocker +): + """Test that the state manager triggers a warning when lock contention exceeds the warning threshold. + + Args: + state_manager_redis: A state manager instance. + token: A token. + substate_token_redis: A token + substate name for looking up in state manager. + mocker: Pytest mocker object. + """ + console_warn = mocker.patch("reflex.utils.console.warn") + + state_manager_redis.lock_expiration = LOCK_EXPIRATION + state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + + order = [] + + async def _coro_blocker(): + async with state_manager_redis.modify_state(substate_token_redis): + order.append("blocker") + await asyncio.sleep(LOCK_WARN_SLEEP) + + tasks = [ + asyncio.create_task(_coro_blocker()), + ] + + await tasks[0] + console_warn.assert_called() + assert console_warn.call_count == 7 + + +class CopyingAsyncMock(AsyncMock): + """An AsyncMock, but deepcopy the args and kwargs first.""" + + def __call__(self, *args, **kwargs): + """Call the mock. + + Args: + args: the arguments passed to the mock + kwargs: the keyword arguments passed to the mock + + Returns: + The result of the mock call + """ + args = copy.deepcopy(args) + kwargs = copy.deepcopy(kwargs) + return super().__call__(*args, **kwargs) + + @pytest.fixture(scope="function") def mock_app_simple(monkeypatch) -> rx.App: """Simple Mock app fixture. @@ -1856,7 +1912,7 @@ def mock_app_simple(monkeypatch) -> rx.App: setattr(app_module, CompileVars.APP, app) app.state = TestState - app.event_namespace.emit = AsyncMock() # type: ignore + app.event_namespace.emit = CopyingAsyncMock() # type: ignore def _mock_get_app(*args, **kwargs): return app_module @@ -1960,21 +2016,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): mock_app.event_namespace.emit.assert_called_once() mcall = mock_app.event_namespace.emit.mock_calls[0] assert mcall.args[0] == str(SocketEvent.EVENT) - assert json.loads(mcall.args[1]) == dataclasses.asdict( - StateUpdate( - delta={ - parent_state.get_full_name(): { - "upper": "", - "sum": 3.14, - }, - grandchild_state.get_full_name(): { - "value2": "42", - }, - GrandchildState3.get_full_name(): { - "computed": "", - }, - } - ) + assert mcall.args[1] == StateUpdate( + delta={ + parent_state.get_full_name(): { + "upper": "", + "sum": 3.14, + }, + grandchild_state.get_full_name(): { + "value2": "42", + }, + GrandchildState3.get_full_name(): { + "computed": "", + }, + } ) assert mcall.kwargs["to"] == grandchild_state.router.session.session_id @@ -2156,51 +2210,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit - first_ws_message = json.loads(emit_mock.mock_calls[0].args[1]) + first_ws_message = emit_mock.mock_calls[0].args[1] assert ( - first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router") + first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router") is not None ) - assert first_ws_message == { - "delta": { + assert first_ws_message == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "order": ["background_task:start"], "computed_order": ["background_task:start"], } }, - "events": [], - "final": True, - } + events=[], + final=True, + ) for call in emit_mock.mock_calls[1:5]: - assert json.loads(call.args[1]) == { - "delta": { + assert call.args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "computed_order": ["background_task:start"], } }, - "events": [], - "final": True, - } - assert json.loads(emit_mock.mock_calls[-2].args[1]) == { - "delta": { + events=[], + final=True, + ) + assert emit_mock.mock_calls[-2].args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "order": exp_order, "computed_order": exp_order, "dict_list": {}, } }, - "events": [], - "final": True, - } - assert json.loads(emit_mock.mock_calls[-1].args[1]) == { - "delta": { + events=[], + final=True, + ) + assert emit_mock.mock_calls[-1].args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "computed_order": exp_order, }, }, - "events": [], - "final": True, - } + events=[], + final=True, + ) @pytest.mark.asyncio @@ -3246,12 +3300,42 @@ async def test_setvar_async_setter(): @pytest.mark.parametrize( "expiration_kwargs, expected_values", [ - ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)), + ( + {"redis_lock_expiration": 20000}, + ( + 20000, + constants.Expiration.TOKEN, + constants.Expiration.LOCK_WARNING_THRESHOLD, + ), + ), ( {"redis_lock_expiration": 50000, "redis_token_expiration": 5600}, - (50000, 5600), + (50000, 5600, constants.Expiration.LOCK_WARNING_THRESHOLD), + ), + ( + {"redis_token_expiration": 7600}, + ( + constants.Expiration.LOCK, + 7600, + constants.Expiration.LOCK_WARNING_THRESHOLD, + ), + ), + ( + {"redis_lock_expiration": 50000, "redis_lock_warning_threshold": 1500}, + (50000, constants.Expiration.TOKEN, 1500), + ), + ( + {"redis_token_expiration": 5600, "redis_lock_warning_threshold": 3000}, + (constants.Expiration.LOCK, 5600, 3000), + ), + ( + { + "redis_lock_expiration": 50000, + "redis_token_expiration": 5600, + "redis_lock_warning_threshold": 2000, + }, + (50000, 5600, 2000), ), - ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)), ], ) def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values): @@ -3281,6 +3365,44 @@ config = rx.Config( state_manager = StateManager.create(state=State) assert state_manager.lock_expiration == expected_values[0] # type: ignore assert state_manager.token_expiration == expected_values[1] # type: ignore + assert state_manager.lock_warning_threshold == expected_values[2] # type: ignore + + +@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") +@pytest.mark.parametrize( + "redis_lock_expiration, redis_lock_warning_threshold", + [ + (10000, 10000), + (20000, 30000), + ], +) +def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( + tmp_path, redis_lock_expiration, redis_lock_warning_threshold +): + proj_root = tmp_path / "project1" + proj_root.mkdir() + + config_string = f""" +import reflex as rx +config = rx.Config( + app_name="project1", + redis_url="redis://localhost:6379", + state_manager_mode="redis", + redis_lock_expiration = {redis_lock_expiration}, + redis_lock_warning_threshold = {redis_lock_warning_threshold}, +) + """ + + (proj_root / "rxconfig.py").write_text(dedent(config_string)) + + with chdir(proj_root): + # reload config for each parameter to avoid stale values + reflex.config.get_config(reload=True) + from reflex.state import State, StateManager + + with pytest.raises(InvalidLockWarningThresholdError): + StateManager.create(state=State) + del sys.modules[constants.Config.MODULE] class MixinState(State, mixin=True):