diff --git a/reflex/config.py b/reflex/config.py index 88230cefe..a010fce32 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -564,6 +564,9 @@ class EnvironmentVariables: # The maximum size of the reflex state in kilobytes. REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000) + # Optional redis key prefix for the state manager. + REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None) + environment = EnvironmentVariables() diff --git a/reflex/state.py b/reflex/state.py index 349dc59e9..f12faaddf 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -43,7 +43,7 @@ from sqlalchemy.orm import DeclarativeBase from typing_extensions import Self from reflex import event -from reflex.config import PerformanceMode, get_config +from reflex.config import EnvironmentVariables, PerformanceMode, get_config from reflex.istate.data import RouterData from reflex.istate.storage import ClientStorageBase from reflex.model import Model @@ -3155,6 +3155,25 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration +TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes) + + +@functools.lru_cache +def prefix_redis_token_str(token: str) -> str: + """Prefix the token with the redis prefix. + + Args: + token: The token to prefix. + + Returns: + The prefixed token. + """ + prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get() + if not prefix: + return token + return f"{prefix}{token}" + + class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" @@ -3292,7 +3311,7 @@ class StateManagerRedis(StateManager): state = None # Fetch the serialized substate from redis. - redis_state = await self.redis.get(token) + redis_state = await self.redis.get(prefix_redis_token_str(token)) if redis_state is not None: # Deserialize the substate. @@ -3378,7 +3397,7 @@ class StateManagerRedis(StateManager): pickle_state = state._serialize() if pickle_state: await self.redis.set( - _substate_key(client_token, state), + prefix_redis_token_str(_substate_key(client_token, state)), pickle_state, ex=self.token_expiration, ) @@ -3415,7 +3434,7 @@ class StateManagerRedis(StateManager): """ # All substates share the same lock domain, so ignore any substate path suffix. client_token = _split_substate_key(token)[0] - return f"{client_token}_lock".encode() + return prefix_redis_token_str(f"{client_token}_lock").encode() async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None: """Try to get a redis lock for a token. diff --git a/tests/units/test_state.py b/tests/units/test_state.py index c8a52e6c0..b0ab0b804 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -10,7 +10,7 @@ import os import sys import threading from textwrap import dedent -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional, Union from unittest.mock import AsyncMock, Mock import pytest @@ -42,6 +42,7 @@ from reflex.state import ( StateProxy, StateUpdate, _substate_key, + prefix_redis_token_str, ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types @@ -1673,7 +1674,9 @@ async def test_state_manager_modify_state( """ async with state_manager.modify_state(substate_token) as state: if isinstance(state_manager, StateManagerRedis): - assert await state_manager.redis.get(f"{token}_lock") + assert await state_manager.redis.get( + prefix_redis_token_str(f"{token}_lock") + ) elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert state_manager._states_locks[token].locked() @@ -1683,7 +1686,9 @@ async def test_state_manager_modify_state( state.complex[3] = complex_1 # lock should be dropped after exiting the context if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + assert ( + await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock")) + ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert not state_manager._states_locks[token].locked() @@ -1725,7 +1730,9 @@ async def test_state_manager_contend( assert (await state_manager.get_state(substate_token)).num1 == exp_num1 if isinstance(state_manager, StateManagerRedis): - assert (await state_manager.redis.get(f"{token}_lock")) is None + assert ( + await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock")) + ) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert not state_manager._states_locks[token].locked() @@ -1785,7 +1792,7 @@ async def test_state_manager_lock_expire( @pytest.mark.asyncio async def test_state_manager_lock_expire_contend( - state_manager_redis: StateManager, token: str, substate_token_redis: str + state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str ): """Test that the state manager lock expires and queued waiters proceed. @@ -1827,6 +1834,41 @@ async def test_state_manager_lock_expire_contend( assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +@pytest.fixture(scope="function") +def redis_prefix() -> Generator[str, None, None]: + """Fixture for redis prefix. + + Yields: + A redis prefix. + """ + prefix = "test_prefix" + reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix) + yield prefix + reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(None) + + +@pytest.mark.asyncio +async def test_state_manager_redis_prefix( + state_manager_redis: StateManagerRedis, + substate_token_redis: str, + redis_prefix: str, +): + """Test that the state manager redis prefix is applied correctly. + + Args: + state_manager_redis: A state manager instance. + substate_token_redis: A token + substate name for looking up in state manager. + redis_prefix: A redis prefix. + """ + async with state_manager_redis.modify_state(substate_token_redis) as state: + state.num1 = 42 + + prefixed_token = prefix_redis_token_str(substate_token_redis) + assert prefixed_token == f"{redis_prefix}{substate_token_redis}" + + assert await state_manager_redis.redis.get(prefixed_token) + + @pytest.fixture(scope="function") def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: """Mock app fixture.