Merge 1079708031
into 000938414f
This commit is contained in:
commit
b5a15f1179
@ -564,6 +564,9 @@ class EnvironmentVariables:
|
||||
# The maximum size of the reflex state in kilobytes.
|
||||
REFLEX_STATE_SIZE_LIMIT: EnvVar[int] = env_var(1000)
|
||||
|
||||
# Optional redis key prefix for the state manager.
|
||||
REFLEX_REDIS_PREFIX: EnvVar[Optional[str]] = env_var(None)
|
||||
|
||||
|
||||
environment = EnvironmentVariables()
|
||||
|
||||
|
@ -43,7 +43,7 @@ from sqlalchemy.orm import DeclarativeBase
|
||||
from typing_extensions import Self
|
||||
|
||||
from reflex import event
|
||||
from reflex.config import PerformanceMode, get_config
|
||||
from reflex.config import EnvironmentVariables, PerformanceMode, get_config
|
||||
from reflex.istate.data import RouterData
|
||||
from reflex.istate.storage import ClientStorageBase
|
||||
from reflex.model import Model
|
||||
@ -3155,6 +3155,25 @@ def _default_lock_expiration() -> int:
|
||||
return get_config().redis_lock_expiration
|
||||
|
||||
|
||||
TOKEN_TYPE = TypeVar("TOKEN_TYPE", str, bytes)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def prefix_redis_token_str(token: str) -> str:
|
||||
"""Prefix the token with the redis prefix.
|
||||
|
||||
Args:
|
||||
token: The token to prefix.
|
||||
|
||||
Returns:
|
||||
The prefixed token.
|
||||
"""
|
||||
prefix = EnvironmentVariables.REFLEX_REDIS_PREFIX.get()
|
||||
if not prefix:
|
||||
return token
|
||||
return f"{prefix}{token}"
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
"""A state manager that stores states in redis."""
|
||||
|
||||
@ -3292,7 +3311,7 @@ class StateManagerRedis(StateManager):
|
||||
state = None
|
||||
|
||||
# Fetch the serialized substate from redis.
|
||||
redis_state = await self.redis.get(token)
|
||||
redis_state = await self.redis.get(prefix_redis_token_str(token))
|
||||
|
||||
if redis_state is not None:
|
||||
# Deserialize the substate.
|
||||
@ -3378,7 +3397,7 @@ class StateManagerRedis(StateManager):
|
||||
pickle_state = state._serialize()
|
||||
if pickle_state:
|
||||
await self.redis.set(
|
||||
_substate_key(client_token, state),
|
||||
prefix_redis_token_str(_substate_key(client_token, state)),
|
||||
pickle_state,
|
||||
ex=self.token_expiration,
|
||||
)
|
||||
@ -3415,7 +3434,7 @@ class StateManagerRedis(StateManager):
|
||||
"""
|
||||
# All substates share the same lock domain, so ignore any substate path suffix.
|
||||
client_token = _split_substate_key(token)[0]
|
||||
return f"{client_token}_lock".encode()
|
||||
return prefix_redis_token_str(f"{client_token}_lock").encode()
|
||||
|
||||
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
||||
"""Try to get a redis lock for a token.
|
||||
|
@ -10,7 +10,7 @@ import os
|
||||
import sys
|
||||
import threading
|
||||
from textwrap import dedent
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, Generator, List, Optional, Union
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
@ -42,6 +42,7 @@ from reflex.state import (
|
||||
StateProxy,
|
||||
StateUpdate,
|
||||
_substate_key,
|
||||
prefix_redis_token_str,
|
||||
)
|
||||
from reflex.testing import chdir
|
||||
from reflex.utils import format, prerequisites, types
|
||||
@ -1673,7 +1674,9 @@ async def test_state_manager_modify_state(
|
||||
"""
|
||||
async with state_manager.modify_state(substate_token) as state:
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert await state_manager.redis.get(f"{token}_lock")
|
||||
assert await state_manager.redis.get(
|
||||
prefix_redis_token_str(f"{token}_lock")
|
||||
)
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert token in state_manager._states_locks
|
||||
assert state_manager._states_locks[token].locked()
|
||||
@ -1683,7 +1686,9 @@ async def test_state_manager_modify_state(
|
||||
state.complex[3] = complex_1
|
||||
# lock should be dropped after exiting the context
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
assert (
|
||||
await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock"))
|
||||
) is None
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
|
||||
@ -1725,7 +1730,9 @@ async def test_state_manager_contend(
|
||||
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
|
||||
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
assert (
|
||||
await state_manager.redis.get(prefix_redis_token_str(f"{token}_lock"))
|
||||
) is None
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert token in state_manager._states_locks
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
@ -1785,7 +1792,7 @@ async def test_state_manager_lock_expire(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_lock_expire_contend(
|
||||
state_manager_redis: StateManager, token: str, substate_token_redis: str
|
||||
state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str
|
||||
):
|
||||
"""Test that the state manager lock expires and queued waiters proceed.
|
||||
|
||||
@ -1827,6 +1834,41 @@ async def test_state_manager_lock_expire_contend(
|
||||
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def redis_prefix() -> Generator[str, None, None]:
|
||||
"""Fixture for redis prefix.
|
||||
|
||||
Yields:
|
||||
A redis prefix.
|
||||
"""
|
||||
prefix = "test_prefix"
|
||||
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(prefix)
|
||||
yield prefix
|
||||
reflex.config.EnvironmentVariables.REFLEX_REDIS_PREFIX.set(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_redis_prefix(
|
||||
state_manager_redis: StateManagerRedis,
|
||||
substate_token_redis: str,
|
||||
redis_prefix: str,
|
||||
):
|
||||
"""Test that the state manager redis prefix is applied correctly.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
redis_prefix: A redis prefix.
|
||||
"""
|
||||
async with state_manager_redis.modify_state(substate_token_redis) as state:
|
||||
state.num1 = 42
|
||||
|
||||
prefixed_token = prefix_redis_token_str(substate_token_redis)
|
||||
assert prefixed_token == f"{redis_prefix}{substate_token_redis}"
|
||||
|
||||
assert await state_manager_redis.redis.get(prefixed_token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
|
||||
"""Mock app fixture.
|
||||
|
Loading…
Reference in New Issue
Block a user