implement redis key prefix for StateManagerRedis
This commit is contained in:
parent
01e3844ac4
commit
9bbb04ecfc
@ -545,6 +545,9 @@ class EnvironmentVariables:
|
|||||||
# Where to save screenshots when tests fail.
|
# Where to save screenshots when tests fail.
|
||||||
SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None)
|
SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None)
|
||||||
|
|
||||||
|
# Optional redis key prefix for the state manager.
|
||||||
|
REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None)
|
||||||
|
|
||||||
|
|
||||||
environment = EnvironmentVariables()
|
environment = EnvironmentVariables()
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ from sqlalchemy.orm import DeclarativeBase
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from reflex import event
|
from reflex import event
|
||||||
from reflex.config import get_config
|
from reflex.config import EnvironmentVariables, get_config
|
||||||
from reflex.istate.data import RouterData
|
from reflex.istate.data import RouterData
|
||||||
from reflex.istate.storage import (
|
from reflex.istate.storage import (
|
||||||
ClientStorageBase,
|
ClientStorageBase,
|
||||||
@ -3076,6 +3076,26 @@ def _default_lock_expiration() -> int:
|
|||||||
return get_config().redis_lock_expiration
|
return get_config().redis_lock_expiration
|
||||||
|
|
||||||
|
|
||||||
|
TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE:
|
||||||
|
"""Prefix the token with the redis prefix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to prefix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The prefixed token.
|
||||||
|
"""
|
||||||
|
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
|
||||||
|
if not prefix:
|
||||||
|
return token
|
||||||
|
if isinstance(token, bytes):
|
||||||
|
return prefix.encode() + token
|
||||||
|
return f"{prefix}{token}"
|
||||||
|
|
||||||
|
|
||||||
class StateManagerRedis(StateManager):
|
class StateManagerRedis(StateManager):
|
||||||
"""A state manager that stores states in redis."""
|
"""A state manager that stores states in redis."""
|
||||||
|
|
||||||
@ -3213,7 +3233,7 @@ class StateManagerRedis(StateManager):
|
|||||||
state = None
|
state = None
|
||||||
|
|
||||||
# Fetch the serialized substate from redis.
|
# Fetch the serialized substate from redis.
|
||||||
redis_state = await self.redis.get(token)
|
redis_state = await self.redis.get(prefix_redis_token(token))
|
||||||
|
|
||||||
if redis_state is not None:
|
if redis_state is not None:
|
||||||
# Deserialize the substate.
|
# Deserialize the substate.
|
||||||
@ -3268,7 +3288,8 @@ class StateManagerRedis(StateManager):
|
|||||||
# Check that we're holding the lock.
|
# Check that we're holding the lock.
|
||||||
if (
|
if (
|
||||||
lock_id is not None
|
lock_id is not None
|
||||||
and await self.redis.get(self._lock_key(token)) != lock_id
|
and await self.redis.get(prefix_redis_token(self._lock_key(token)))
|
||||||
|
!= lock_id
|
||||||
):
|
):
|
||||||
raise LockExpiredError(
|
raise LockExpiredError(
|
||||||
f"Lock expired for token {token} while processing. Consider increasing "
|
f"Lock expired for token {token} while processing. Consider increasing "
|
||||||
@ -3299,7 +3320,7 @@ class StateManagerRedis(StateManager):
|
|||||||
pickle_state = state._serialize()
|
pickle_state = state._serialize()
|
||||||
if pickle_state:
|
if pickle_state:
|
||||||
await self.redis.set(
|
await self.redis.set(
|
||||||
_substate_key(client_token, state),
|
prefix_redis_token(_substate_key(client_token, state)),
|
||||||
pickle_state,
|
pickle_state,
|
||||||
ex=self.token_expiration,
|
ex=self.token_expiration,
|
||||||
)
|
)
|
||||||
@ -3349,7 +3370,7 @@ class StateManagerRedis(StateManager):
|
|||||||
True if the lock was obtained.
|
True if the lock was obtained.
|
||||||
"""
|
"""
|
||||||
return await self.redis.set(
|
return await self.redis.set(
|
||||||
lock_key,
|
prefix_redis_token(lock_key),
|
||||||
lock_id,
|
lock_id,
|
||||||
px=self.lock_expiration,
|
px=self.lock_expiration,
|
||||||
nx=True, # only set if it doesn't exist
|
nx=True, # only set if it doesn't exist
|
||||||
|
@ -40,6 +40,7 @@ from reflex.state import (
|
|||||||
StateProxy,
|
StateProxy,
|
||||||
StateUpdate,
|
StateUpdate,
|
||||||
_substate_key,
|
_substate_key,
|
||||||
|
prefix_redis_token,
|
||||||
)
|
)
|
||||||
from reflex.testing import chdir
|
from reflex.testing import chdir
|
||||||
from reflex.utils import format, prerequisites, types
|
from reflex.utils import format, prerequisites, types
|
||||||
@ -1671,7 +1672,7 @@ async def test_state_manager_modify_state(
|
|||||||
"""
|
"""
|
||||||
async with state_manager.modify_state(substate_token) as state:
|
async with state_manager.modify_state(substate_token) as state:
|
||||||
if isinstance(state_manager, StateManagerRedis):
|
if isinstance(state_manager, StateManagerRedis):
|
||||||
assert await state_manager.redis.get(f"{token}_lock")
|
assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
|
||||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||||
assert token in state_manager._states_locks
|
assert token in state_manager._states_locks
|
||||||
assert state_manager._states_locks[token].locked()
|
assert state_manager._states_locks[token].locked()
|
||||||
@ -1681,7 +1682,9 @@ async def test_state_manager_modify_state(
|
|||||||
state.complex[3] = complex_1
|
state.complex[3] = complex_1
|
||||||
# lock should be dropped after exiting the context
|
# lock should be dropped after exiting the context
|
||||||
if isinstance(state_manager, StateManagerRedis):
|
if isinstance(state_manager, StateManagerRedis):
|
||||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
assert (
|
||||||
|
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
|
||||||
|
) is None
|
||||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||||
assert not state_manager._states_locks[token].locked()
|
assert not state_manager._states_locks[token].locked()
|
||||||
|
|
||||||
@ -1723,7 +1726,9 @@ async def test_state_manager_contend(
|
|||||||
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
|
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
|
||||||
|
|
||||||
if isinstance(state_manager, StateManagerRedis):
|
if isinstance(state_manager, StateManagerRedis):
|
||||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
assert (
|
||||||
|
await state_manager.redis.get(prefix_redis_token(f"{token}_lock"))
|
||||||
|
) is None
|
||||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||||
assert token in state_manager._states_locks
|
assert token in state_manager._states_locks
|
||||||
assert not state_manager._states_locks[token].locked()
|
assert not state_manager._states_locks[token].locked()
|
||||||
@ -1783,7 +1788,7 @@ async def test_state_manager_lock_expire(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_state_manager_lock_expire_contend(
|
async def test_state_manager_lock_expire_contend(
|
||||||
state_manager_redis: StateManager, token: str, substate_token_redis: str
|
state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
|
||||||
):
|
):
|
||||||
"""Test that the state manager lock expires and queued waiters proceed.
|
"""Test that the state manager lock expires and queued waiters proceed.
|
||||||
|
|
||||||
@ -1825,6 +1830,28 @@ async def test_state_manager_lock_expire_contend(
|
|||||||
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_manager_redis_prefix(
|
||||||
|
state_manager_redis: StateManagerRedis, substate_token_redis: str
|
||||||
|
):
|
||||||
|
"""Test that the state manager redis prefix is applied correctly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager_redis: A state manager instance.
|
||||||
|
substate_token_redis: A token + substate name for looking up in state manager.
|
||||||
|
"""
|
||||||
|
prefix = "test_prefix"
|
||||||
|
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix)
|
||||||
|
|
||||||
|
async with state_manager_redis.modify_state(substate_token_redis) as state:
|
||||||
|
state.num1 = 42
|
||||||
|
|
||||||
|
prefixed_token = prefix_redis_token(substate_token_redis)
|
||||||
|
assert prefixed_token == f"{prefix}{substate_token_redis}"
|
||||||
|
|
||||||
|
assert await state_manager_redis.redis.get(prefixed_token)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
|
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
|
||||||
"""Mock app fixture.
|
"""Mock app fixture.
|
||||||
|
@ -6,7 +6,12 @@ import pytest
|
|||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
import reflex as rx
|
import reflex as rx
|
||||||
from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key
|
from reflex.state import (
|
||||||
|
BaseState,
|
||||||
|
StateManager,
|
||||||
|
StateManagerRedis,
|
||||||
|
_substate_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Root(BaseState):
|
class Root(BaseState):
|
||||||
|
Loading…
Reference in New Issue
Block a user