implement redis key prefix for StateManagerRedis

This commit is contained in:
Benedikt Bartscher 2024-11-06 00:15:04 +01:00
parent 01e3844ac4
commit 9bbb04ecfc
No known key found for this signature in database
4 changed files with 66 additions and 10 deletions

View File

@ -545,6 +545,9 @@ class EnvironmentVariables:
# Where to save screenshots when tests fail.
SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None)
# Optional redis key prefix for the state manager.
REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None)
environment = EnvironmentVariables()

View File

@ -42,7 +42,7 @@ from sqlalchemy.orm import DeclarativeBase
from typing_extensions import Self
from reflex import event
from reflex.config import get_config
from reflex.config import EnvironmentVariables, get_config
from reflex.istate.data import RouterData
from reflex.istate.storage import (
ClientStorageBase,
@ -3076,6 +3076,26 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration
TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes)
def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE:
"""Prefix the token with the redis prefix.
Args:
token: The token to prefix.
Returns:
The prefixed token.
"""
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
if not prefix:
return token
if isinstance(token, bytes):
return prefix.encode() + token
return f"{prefix}{token}"
class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""
@ -3213,7 +3233,7 @@ class StateManagerRedis(StateManager):
state = None
# Fetch the serialized substate from redis.
redis_state = await self.redis.get(token)
redis_state = await self.redis.get(prefix_redis_token(token))
if redis_state is not None:
# Deserialize the substate.
@ -3268,7 +3288,8 @@ class StateManagerRedis(StateManager):
# Check that we're holding the lock.
if (
lock_id is not None
and await self.redis.get(self._lock_key(token)) != lock_id
and await self.redis.get(prefix_redis_token(self._lock_key(token)))
!= lock_id
):
raise LockExpiredError(
f"Lock expired for token {token} while processing. Consider increasing "
@ -3299,7 +3320,7 @@ class StateManagerRedis(StateManager):
pickle_state = state._serialize()
if pickle_state:
await self.redis.set(
_substate_key(client_token, state),
prefix_redis_token(_substate_key(client_token, state)),
pickle_state,
ex=self.token_expiration,
)
@ -3349,7 +3370,7 @@ class StateManagerRedis(StateManager):
True if the lock was obtained.
"""
return await self.redis.set(
lock_key,
prefix_redis_token(lock_key),
lock_id,
px=self.lock_expiration,
nx=True, # only set if it doesn't exist

View File

@ -40,6 +40,7 @@ from reflex.state import (
StateProxy,
StateUpdate,
_substate_key,
prefix_redis_token,
)
from reflex.testing import chdir
from reflex.utils import format, prerequisites, types
@ -1671,7 +1672,7 @@ async def test_state_manager_modify_state(
"""
async with state_manager.modify_state(substate_token) as state:
if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(f"{token}_lock")
assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert state_manager._states_locks[token].locked()
@ -1681,7 +1682,9 @@ async def test_state_manager_modify_state(
state.complex[3] = complex_1
# lock should be dropped after exiting the context
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
assert (
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert not state_manager._states_locks[token].locked()
@ -1723,7 +1726,9 @@ async def test_state_manager_contend(
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
assert (
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
) is None
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert not state_manager._states_locks[token].locked()
@ -1783,7 +1788,7 @@ async def test_state_manager_lock_expire(
@pytest.mark.asyncio
async def test_state_manager_lock_expire_contend(
state_manager_redis: StateManager, token: str, substate_token_redis: str
state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
):
"""Test that the state manager lock expires and queued waiters proceed.
@ -1825,6 +1830,28 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
@pytest.mark.asyncio
async def test_state_manager_redis_prefix(
state_manager_redis: StateManagerRedis, substate_token_redis: str
):
"""Test that the state manager redis prefix is applied correctly.
Args:
state_manager_redis: A state manager instance.
substate_token_redis: A token + substate name for looking up in state manager.
"""
prefix = "test_prefix"
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix)
async with state_manager_redis.modify_state(substate_token_redis) as state:
state.num1 = 42
prefixed_token = prefix_redis_token(substate_token_redis)
assert prefixed_token == f"{prefix}{substate_token_redis}"
assert await state_manager_redis.redis.get(prefixed_token)
@pytest.fixture(scope="function")
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
"""Mock app fixture.

View File

@ -6,7 +6,12 @@ import pytest
import pytest_asyncio
import reflex as rx
from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
from reflex.state import (
BaseState,
StateManager,
StateManagerRedis,
_substate_key,
)
class Root(BaseState):