[REF-3056] Config knob for redis StateManager expiration times (#3523)

This commit is contained in:
Elijah Ahianyo 2024-06-24 16:03:30 -07:00 committed by GitHub
parent 44233c9cd1
commit f037df0977
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 62 additions and 3 deletions

View File

@ -219,6 +219,12 @@ class Config(Base):
# Number of gunicorn workers from user # Number of gunicorn workers from user
gunicorn_workers: Optional[int] = None gunicorn_workers: Optional[int] = None
# Maximum expiration lock time for redis state manager
redis_lock_expiration: int = constants.Expiration.LOCK
# Token expiration time for redis state manager
redis_token_expiration: int = constants.Expiration.TOKEN
# Attributes that were explicitly set by the user. # Attributes that were explicitly set by the user.
_non_default_attributes: Set[str] = pydantic.PrivateAttr(set()) _non_default_attributes: Set[str] = pydantic.PrivateAttr(set())

View File

@ -40,6 +40,7 @@ from redis.exceptions import ResponseError
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.config import get_config
from reflex.event import ( from reflex.event import (
BACKGROUND_TASK_MARKER, BACKGROUND_TASK_MARKER,
Event, Event,
@ -60,6 +61,7 @@ if TYPE_CHECKING:
Delta = Dict[str, Any] Delta = Dict[str, Any]
var = computed_var var = computed_var
config = get_config()
# If the state is this large, it's considered a performance issue. # If the state is this large, it's considered a performance issue.
@ -2202,7 +2204,14 @@ class StateManager(Base, ABC):
""" """
redis = prerequisites.get_redis() redis = prerequisites.get_redis()
if redis is not None: if redis is not None:
return StateManagerRedis(state=state, redis=redis) # make sure expiration values are obtained only from the config object on creation
config = get_config()
return StateManagerRedis(
state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
return StateManagerMemory(state=state) return StateManagerMemory(state=state)
@abstractmethod @abstractmethod
@ -2333,10 +2342,10 @@ class StateManagerRedis(StateManager):
redis: Redis redis: Redis
# The token expiration time (s). # The token expiration time (s).
token_expiration: int = constants.Expiration.TOKEN token_expiration: int = config.redis_token_expiration
# The maximum time to hold a lock (ms). # The maximum time to hold a lock (ms).
lock_expiration: int = constants.Expiration.LOCK lock_expiration: int = config.redis_lock_expiration
# The keyspace subscription string when redis is waiting for lock to be released # The keyspace subscription string when redis is waiting for lock to be released
_redis_notify_keyspace_events: str = ( _redis_notify_keyspace_events: str = (

View File

@ -7,6 +7,7 @@ import functools
import json import json
import os import os
import sys import sys
from textwrap import dedent
from typing import Any, Dict, Generator, List, Optional, Union from typing import Any, Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
@ -14,6 +15,8 @@ import pytest
from plotly.graph_objects import Figure from plotly.graph_objects import Figure
import reflex as rx import reflex as rx
import reflex.config
from reflex import constants
from reflex.app import App from reflex.app import App
from reflex.base import Base from reflex.base import Base
from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.constants import CompileVars, RouteVar, SocketEvent
@ -33,6 +36,7 @@ from reflex.state import (
StateUpdate, StateUpdate,
_substate_key, _substate_key,
) )
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types from reflex.utils import format, prerequisites, types
from reflex.utils.format import json_dumps from reflex.utils.format import json_dumps
from reflex.vars import BaseVar, ComputedVar from reflex.vars import BaseVar, ComputedVar
@ -2925,3 +2929,43 @@ async def test_setvar(mock_app: rx.App, token: str):
# Cannot setvar with non-string # Cannot setvar with non-string
with pytest.raises(ValueError): with pytest.raises(ValueError):
TestState.setvar(42, 42) TestState.setvar(42, 42)
@pytest.mark.skipif("REDIS_URL" not in os.environ, reason="Test requires redis")
@pytest.mark.parametrize(
"expiration_kwargs, expected_values",
[
({"redis_lock_expiration": 20000}, (20000, constants.Expiration.TOKEN)),
(
{"redis_lock_expiration": 50000, "redis_token_expiration": 5600},
(50000, 5600),
),
({"redis_token_expiration": 7600}, (constants.Expiration.LOCK, 7600)),
],
)
def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_values):
proj_root = tmp_path / "project1"
proj_root.mkdir()
config_items = ",\n ".join(
f"{key} = {value}" for key, value in expiration_kwargs.items()
)
config_string = f"""
import reflex as rx
config = rx.Config(
app_name="project1",
redis_url="redis://localhost:6379",
{config_items}
)
"""
(proj_root / "rxconfig.py").write_text(dedent(config_string))
with chdir(proj_root):
# reload config for each parameter to avoid stale values
reflex.config.get_config(reload=True)
from reflex.state import State, StateManager
state_manager = StateManager.create(state=State)
assert state_manager.lock_expiration == expected_values[0] # type: ignore
assert state_manager.token_expiration == expected_values[1] # type: ignore