diff --git a/reflex/config.py b/reflex/config.py index 049cc2e83..a57dfd731 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -545,6 +545,9 @@ class EnvironmentVariables: # Where to save screenshots when tests fail. SCREENSHOT_DIR: EnvVar[Optional[Path]] = env_var(None) + # Optional redis key prefix for the state manager. + REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None) + environment = EnvironmentVariables() diff --git a/reflex/state.py b/reflex/state.py index a53df7b6f..9fbd370b8 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -42,7 +42,7 @@ from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self from reflex import event -from reflex.config import get_config +from reflex.config import EnvironmentVariables, get_config from reflex.istate.data import RouterData from reflex.istate.storage import ( ClientStorageBase, @@ -3076,6 +3076,26 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration +TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes) + + +def prefix_redis_token(token: TOKEN_TYPE) -> TOKEN_TYPE: + """Prefix the token with the redis prefix. + + Args: + token: The token to prefix. + + Returns: + The prefixed token. + """ + prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get() + if not prefix: + return token + if isinstance(token, bytes): + return prefix.encode() + token + return f"{prefix}{token}" + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3213,7 +3233,7 @@ class StateManagerRedis(StateManager): state = None # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) + redis_state = await self.redis.get(prefix_redis_token(token)) if redis_state is not None: # Deserialize the substate. @@ -3268,7 +3288,8 @@ class StateManagerRedis(StateManager): # Check that we're holding the lock. if ( lock_id is not None - and await self.redis.get(self._lock_key(token)) != lock_id + and await self.redis.get(prefix_redis_token(self._lock_key(token))) + != lock_id ): raise LockExpiredError( f"Lock expired for token {token} while processing. Consider increasing " @@ -3299,7 +3320,7 @@ class StateManagerRedis(StateManager): pickle_state = state._serialize() if pickle_state: await self.redis.set( - _substate_key(client_token, state), + prefix_redis_token(_substate_key(client_token, state)), pickle_state, ex=self.token_expiration, ) @@ -3349,7 +3370,7 @@ class StateManagerRedis(StateManager): True if the lock was obtained. """ return await self.redis.set( - lock_key, + prefix_redis_token(lock_key), lock_id, px=self.lock_expiration, nx=True, # only set if it doesn't exist diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 83e348cd2..685b0ccfc 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -40,6 +40,7 @@ from reflex.state import ( StateProxy, StateUpdate, _substate_key, + prefix_redis_token, ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types @@ -1671,7 +1672,7 @@ async def test_state_manager_modify_state( """ async with state_manager.modify_state(substate_token) as state: if isinstance(state_manager, StateManagerRedis): - assert await state_manager.redis.get(f"{token}_lock") + assert await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert state_manager._states_locks[token].locked() @@ -1681,7 +1682,9 @@ async def test_state_manager_modify_state( state.complex[3] = complex_1 # lock should be dropped after exiting the context if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + assert ( + await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) + ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert not state_manager._states_locks[token].locked() @@ -1723,7 +1726,9 @@ async def test_state_manager_contend( assert (await state_manager.get_state(substate_token)).num1 == exp_num1 if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + assert ( + await state_manager.redis.get(prefix_redis_token(f"{token}_lock")) + ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert not state_manager._states_locks[token].locked() @@ -1783,7 +1788,7 @@ async def test_state_manager_lock_expire( @pytest.mark.asyncio async def test_state_manager_lock_expire_contend( - state_manager_redis: StateManager, token: str, substate_token_redis: str + state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str ): """Test that the state manager lock expires and queued waiters proceed. @@ -1825,6 +1830,28 @@ async def test_state_manager_lock_expire_contend( assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +@pytest.mark.asyncio +async def test_state_manager_redis_prefix( + state_manager_redis: StateManagerRedis, substate_token_redis: str +): + """Test that the state manager redis prefix is applied correctly. + + Args: + state_manager_redis: A state manager instance. + substate_token_redis: A token + substate name for looking up in state manager. + """ + prefix = "test_prefix" + reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix) + + async with state_manager_redis.modify_state(substate_token_redis) as state: + state.num1 = 42 + + prefixed_token = prefix_redis_token(substate_token_redis) + assert prefixed_token == f"{prefix}{substate_token_redis}" + + assert await state_manager_redis.redis.get(prefixed_token) + + @pytest.fixture(scope="function") def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: """Mock app fixture. diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py index ebdd877de..aa853af9b 100644 --- a/tests/units/test_state_tree.py +++ b/tests/units/test_state_tree.py @@ -6,7 +6,12 @@ import pytest import pytest_asyncio import reflex as rx -from reflex.state import BaseState, StateManager, StateManagerRedis, _substate_key +from reflex.state import ( + BaseState, + StateManager, + StateManagerRedis, + _substate_key, +) class Root(BaseState):