implement redis key prefix for StateManagerRedis
This commit is contained in:
parent
01e3844ac4
commit
9bbb04ecfc
@ -545,6 +545,9 @@ class EnvironmentVariables:
|
||||
# Where to save screenshots when tests fail.
|
||||
SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None)
|
||||
|
||||
# Optional redis key prefix for the state manager.
|
||||
REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None)
|
||||
|
||||
|
||||
environment = EnvironmentVariables()
|
||||
|
||||
|
@ -42,7 +42,7 @@ from sqlalchemy.orm import DeclarativeBase
|
||||
from typing_extensions import Self
|
||||
|
||||
from reflex import event
|
||||
from reflex.config import get_config
|
||||
from reflex.config import EnvironmentVariables, get_config
|
||||
from reflex.istate.data import RouterData
|
||||
from reflex.istate.storage import (
|
||||
ClientStorageBase,
|
||||
@ -3076,6 +3076,26 @@ def _default_lock_expiration() -> int:
|
||||
return get_config().redis_lock_expiration
|
||||
|
||||
|
||||
TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes)
|
||||
|
||||
|
||||
def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE:
|
||||
"""Prefix the token with the redis prefix.
|
||||
|
||||
Args:
|
||||
token: The token to prefix.
|
||||
|
||||
Returns:
|
||||
The prefixed token.
|
||||
"""
|
||||
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
|
||||
if not prefix:
|
||||
return token
|
||||
if isinstance(token, bytes):
|
||||
return prefix.encode() + token
|
||||
return f"{prefix}{token}"
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
"""A state manager that stores states in redis."""
|
||||
|
||||
@ -3213,7 +3233,7 @@ class StateManagerRedis(StateManager):
|
||||
state = None
|
||||
|
||||
# Fetch the serialized substate from redis.
|
||||
redis_state = await self.redis.get(token)
|
||||
redis_state = await self.redis.get(prefix_redis_token(token))
|
||||
|
||||
if redis_state is not None:
|
||||
# Deserialize the substate.
|
||||
@ -3268,7 +3288,8 @@ class StateManagerRedis(StateManager):
|
||||
# Check that we're holding the lock.
|
||||
if (
|
||||
lock_id is not None
|
||||
and await self.redis.get(self._lock_key(token)) != lock_id
|
||||
and await self.redis.get(prefix_redis_token(self._lock_key(token)))
|
||||
!= lock_id
|
||||
):
|
||||
raise LockExpiredError(
|
||||
f"Lock expired for token {token} while processing. Consider increasing "
|
||||
@ -3299,7 +3320,7 @@ class StateManagerRedis(StateManager):
|
||||
pickle_state = state._serialize()
|
||||
if pickle_state:
|
||||
await self.redis.set(
|
||||
_substate_key(client_token, state),
|
||||
prefix_redis_token(_substate_key(client_token, state)),
|
||||
pickle_state,
|
||||
ex=self.token_expiration,
|
||||
)
|
||||
@ -3349,7 +3370,7 @@ class StateManagerRedis(StateManager):
|
||||
True if the lock was obtained.
|
||||
"""
|
||||
return await self.redis.set(
|
||||
lock_key,
|
||||
prefix_redis_token(lock_key),
|
||||
lock_id,
|
||||
px=self.lock_expiration,
|
||||
nx=True, # only set if it doesn't exist
|
||||
|
@ -40,6 +40,7 @@ from reflex.state import (
|
||||
StateProxy,
|
||||
StateUpdate,
|
||||
_substate_key,
|
||||
prefix_redis_token,
|
||||
)
|
||||
from reflex.testing import chdir
|
||||
from reflex.utils import format, prerequisites, types
|
||||
@ -1671,7 +1672,7 @@ async def test_state_manager_modify_state(
|
||||
"""
|
||||
async with state_manager.modify_state(substate_token) as state:
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert await state_manager.redis.get(f"{token}_lock")
|
||||
assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert token in state_manager._states_locks
|
||||
assert state_manager._states_locks[token].locked()
|
||||
@ -1681,7 +1682,9 @@ async def test_state_manager_modify_state(
|
||||
state.complex[3] = complex_1
|
||||
# lock should be dropped after exiting the context
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
assert (
|
||||
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
|
||||
) is None
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
|
||||
@ -1723,7 +1726,9 @@ async def test_state_manager_contend(
|
||||
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
|
||||
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
assert (
|
||||
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
|
||||
) is None
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert token in state_manager._states_locks
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
@ -1783,7 +1788,7 @@ async def test_state_manager_lock_expire(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_lock_expire_contend(
|
||||
state_manager_redis: StateManager, token: str, substate_token_redis: str
|
||||
state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
|
||||
):
|
||||
"""Test that the state manager lock expires and queued waiters proceed.
|
||||
|
||||
@ -1825,6 +1830,28 @@ async def test_state_manager_lock_expire_contend(
|
||||
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_redis_prefix(
|
||||
state_manager_redis: StateManagerRedis, substate_token_redis: str
|
||||
):
|
||||
"""Test that the state manager redis prefix is applied correctly.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
prefix = "test_prefix"
|
||||
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix)
|
||||
|
||||
async with state_manager_redis.modify_state(substate_token_redis) as state:
|
||||
state.num1 = 42
|
||||
|
||||
prefixed_token = prefix_redis_token(substate_token_redis)
|
||||
assert prefixed_token == f"{prefix}{substate_token_redis}"
|
||||
|
||||
assert await state_manager_redis.redis.get(prefixed_token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
|
||||
"""Mock app fixture.
|
||||
|
@ -6,7 +6,12 @@ import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
import reflex as rx
|
||||
from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
|
||||
from reflex.state import (
|
||||
BaseState,
|
||||
StateManager,
|
||||
StateManagerRedis,
|
||||
_substate_key,
|
||||
)
|
||||
|
||||
|
||||
class Root(BaseState):
|
||||
|
Loading…
Reference in New Issue
Block a user