From f037df09774726a43dad86c7c015b627b3b6d0b9 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Mon, 24 Jun 2024 16:03:30 -0700 Subject: [PATCH] [REF-3056] Config knob for redis StateManager expiration times (#3523) --- reflex/config.py | 6 ++++++ reflex/state.py | 15 ++++++++++++--- tests/test_state.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/reflex/config.py b/reflex/config.py index 08663aa04..95e6e54d8 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -219,6 +219,12 @@ class Config(Base): # Number of gunicorn workers from user gunicorn_workers: Optional[int] = None + # Maximum expiration lock time for redis state manager + redis_lock_expiration: int = constants.Expiration.LOCK + + # Token expiration time for redis state manager + redis_token_expiration: int = constants.Expiration.TOKEN + # Attributes that were explicitly set by the user. _non_default_attributes: Set[str] = pydantic.PrivateAttr(set()) diff --git a/reflex/state.py b/reflex/state.py index 2a4494e65..ce88b7f5a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -40,6 +40,7 @@ from redis.exceptions import ResponseError from reflex import constants from reflex.base import Base +from reflex.config import get_config from reflex.event import ( BACKGROUND_TASK_MARKER, Event, @@ -60,6 +61,7 @@ if TYPE_CHECKING: Delta = Dict[str, Any] var = computed_var +config = get_config() # If the state is this large, it's considered a performance issue. @@ -2202,7 +2204,14 @@ class StateManager(Base, ABC): """ redis = prerequisites.get_redis() if redis is not None: - return StateManagerRedis(state=state, redis=redis) + # make sure expiration values are obtained only from the config object on creation + config = get_config() + return StateManagerRedis( + state=state, + redis=redis, + token_expiration=config.redis_token_expiration, + lock_expiration=config.redis_lock_expiration, + ) return StateManagerMemory(state=state) @abstractmethod @@ -2333,10 +2342,10 @@ class StateManagerRedis(StateManager): redis: Redis # The token expiration time (s). - token_expiration: int = constants.Expiration.TOKEN + token_expiration: int = config.redis_token_expiration # The maximum time to hold a lock (ms). - lock_expiration: int = constants.Expiration.LOCK + lock_expiration: int = config.redis_lock_expiration # The keyspace subscription string when redis is waiting for lock to be released _redis_notify_keyspace_events: str = ( diff --git a/tests/test_state.py b/tests/test_state.py index 0a83e58a6..8b91f1cfe 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -7,6 +7,7 @@ import functools import json import os import sys +from textwrap import dedent from typing import Any, Dict, Generator, List, Optional, Union from unittest.mock import AsyncMock, Mock @@ -14,6 +15,8 @@ import pytest from plotly.graph_objects import Figure import reflex as rx +import reflex.config +from reflex import constants from reflex.app import App from reflex.base import Base from reflex.constants import CompileVars, RouteVar, SocketEvent @@ -33,6 +36,7 @@ from reflex.state import ( StateUpdate, _substate_key, ) +from reflex.testing import chdir from reflex.utils import format, prerequisites, types from reflex.utils.format import json_dumps from reflex.vars import BaseVar, ComputedVar @@ -2925,3 +2929,43 @@ async def test_setvar(mock_app: rx.App, token: str): # Cannot setvar with non-string with pytest.raises(ValueError): TestState.setvar(42, 42) + + +@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis") +@pytest.mark.parametrize( + "expiration_kwargs, expected_values", + [ + ({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)), + ( + {"redis_lock_expiration": 50000, "redis_token_expiration": 5600}, + (50000, 5600), + ), + ({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)), + ], +) +def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values): + proj_root = tmp_path / "project1" + proj_root.mkdir() + + config_items = ",\n ".join( + f"{key} = {value}" for key, value in expiration_kwargs.items() + ) + + config_string = f""" +import reflex as rx +config = rx.Config( + app_name="project1", + redis_url="redis://localhost:6379", + {config_items} +) +""" + (proj_root / "rxconfig.py").write_text(dedent(config_string)) + + with chdir(proj_root): + # reload config for each parameter to avoid stale values + reflex.config.get_config(reload=True) + from reflex.state import State, StateManager + + state_manager = StateManager.create(state=State) + assert state_manager.lock_expiration == expected_values[0] # type: ignore + assert state_manager.token_expiration == expected_values[1] # type: ignore