This commit is contained in:
benedikt-bartscher 2024-11-22 12:32:52 -08:00 committed by GitHub
commit b5a15f1179
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 73 additions and 9 deletions

View File

@ -564,6 +564,9 @@ class EnvironmentVariables:
# The maximum size of the reflex state in kilobytes.
REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
# Optional redis key prefix for the state manager.
REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None)
environment = EnvironmentVariables()

View File

@ -43,7 +43,7 @@ from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self
from reflex import event
from reflex.config import PerformanceMode, get_config
from reflex.config import EnvironmentVariables, PerformanceMode, get_config
from reflex.istate.data import RouterData
from reflex.istate.storage import ClientStorageBase
from reflex.model import Model
@ -3155,6 +3155,25 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration
TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes)
@functools.lru_cache
def prefix_redis_token_str(token: str) -> str:
"""Prefix the token with the redis prefix.
Args:
token: The token to prefix.
Returns:
The prefixed token.
"""
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
if not prefix:
return token
return f"{prefix}{token}"
class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""
@ -3292,7 +3311,7 @@ class StateManagerRedis(StateManager):
state = None
# Fetch the serialized substate from redis.
redis_state = await self.redis.get(token)
redis_state = await self.redis.get(prefix_redis_token_str(token))
if redis_state is not None:
# Deserialize the substate.
@ -3378,7 +3397,7 @@ class StateManagerRedis(StateManager):
pickle_state = state._serialize()
if pickle_state:
await self.redis.set(
_substate_key(client_token, state),
prefix_redis_token_str(_substate_key(client_token, state)),
pickle_state,
ex=self.token_expiration,
)
@ -3415,7 +3434,7 @@ class StateManagerRedis(StateManager):
"""
# All substates share the same lock domain, so ignore any substate path suffix.
client_token = _split_substate_key(token)[0]
return f"{client_token}_lock".encode()
return prefix_redis_token_str(f"{client_token}_lock").encode()
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
"""Try to get a redis lock for a token.

View File

@ -10,7 +10,7 @@ import os
import sys
import threading
from textwrap import dedent
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock
import pytest
@ -42,6 +42,7 @@ from reflex.state import (
StateProxy,
StateUpdate,
_substate_key,
prefix_redis_token_str,
)
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
@ -1673,7 +1674,9 @@ async def test_state_manager_modify_state(
"""
async with state_manager.modify_state(substate_token) as state:
if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(f"{token}_lock")
assert await state_manager.redis.get(
prefix_redis_token_str(f"{token}_lock")
)
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert state_manager._states_locks[token].locked()
@ -1683,7 +1686,9 @@ async def test_state_manager_modify_state(
state.complex[3] = complex_1
# lock should be dropped after exiting the context
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
assert (
await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock"))
) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert not state_manager._states_locks[token].locked()
@ -1725,7 +1730,9 @@ async def test_state_manager_contend(
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
assert (
await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock"))
) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert not state_manager._states_locks[token].locked()
@ -1785,7 +1792,7 @@ async def test_state_manager_lock_expire(
@pytest.mark.asyncio
async def test_state_manager_lock_expire_contend(
state_manager_redis: StateManager, token: str, substate_token_redis: str
state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
):
"""Test that the state manager lock expires and queued waiters proceed.
@ -1827,6 +1834,41 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
@pytest.fixture(scope="function")
def redis_prefix() -> Generator[str, None, None]:
"""Fixture for redis prefix.
Yields:
A redis prefix.
"""
prefix = "test_prefix"
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix)
yield prefix
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(None)
@pytest.mark.asyncio
async def test_state_manager_redis_prefix(
state_manager_redis: StateManagerRedis,
substate_token_redis: str,
redis_prefix: str,
):
"""Test that the state manager redis prefix is applied correctly.
Args:
state_manager_redis: A state manager instance.
substate_token_redis: A token + substate name for looking up in state manager.
redis_prefix: A redis prefix.
"""
async with state_manager_redis.modify_state(substate_token_redis) as state:
state.num1 = 42
prefixed_token = prefix_redis_token_str(substate_token_redis)
assert prefixed_token == f"{redis_prefix}{substate_token_redis}"
assert await state_manager_redis.redis.get(prefixed_token)
@pytest.fixture(scope="function")
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
"""Mock app fixture.