let users pick state manager mode (#4041)

This commit is contained in:
Thomas Brandého 2024-10-10 12:22:35 -07:00 committed by GitHub
parent 1aed39a848
commit 6f586c8b8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 61 additions and 12 deletions

View File

@ -9,6 +9,8 @@ import urllib.parse
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from reflex.utils.exceptions import ConfigError
try:
import pydantic.v1 as pydantic
except ModuleNotFoundError:
@ -220,6 +222,9 @@ class Config(Base):
# Number of gunicorn workers from user
gunicorn_workers: Optional[int] = None
# Indicate which type of state manager to use
state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK
# Maximum expiration lock time for redis state manager
redis_lock_expiration: int = constants.Expiration.LOCK
@ -235,6 +240,9 @@ class Config(Base):
Args:
*args: The args to pass to the Pydantic init method.
**kwargs: The kwargs to pass to the Pydantic init method.
Raises:
ConfigError: If some values in the config are invalid.
"""
super().__init__(*args, **kwargs)
@ -248,6 +256,14 @@ class Config(Base):
self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs)
if (
self.state_manager_mode == constants.StateManagerMode.REDIS
and not self.redis_url
):
raise ConfigError(
"REDIS_URL is required when using the redis state manager."
)
@property
def module(self) -> str:
"""Get the module name of the app.

View File

@ -63,6 +63,7 @@ from .route import (
RouteRegex,
RouteVar,
)
from .state import StateManagerMode
from .style import Tailwind
__ALL__ = [
@ -115,6 +116,7 @@ __ALL__ = [
SETTER_PREFIX,
SKIP_COMPILE_ENV_VAR,
SocketEvent,
StateManagerMode,
Tailwind,
Templates,
CompileVars,

11
reflex/constants/state.py Normal file
View File

@ -0,0 +1,11 @@
"""State-related constants."""
from enum import Enum
class StateManagerMode(str, Enum):
"""State manager constants."""
DISK = "disk"
MEMORY = "memory"
REDIS = "redis"

View File

@ -76,6 +76,7 @@ from reflex.utils.exceptions import (
DynamicRouteArgShadowsStateVar,
EventHandlerShadowsBuiltInStateMethod,
ImmutableStateError,
InvalidStateManagerMode,
LockExpiredError,
SetUndefinedStateVarError,
StateSchemaMismatchError,
@ -2514,20 +2515,30 @@ class StateManager(Base, ABC):
Args:
state: The state class to use.
Raises:
InvalidStateManagerMode: If the state manager mode is invalid.
Returns:
The state manager (either disk or redis).
The state manager (either disk, memory or redis).
"""
redis = prerequisites.get_redis()
if redis is not None:
# make sure expiration values are obtained only from the config object on creation
config = get_config()
return StateManagerRedis(
state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
return StateManagerDisk(state=state)
config = get_config()
if config.state_manager_mode == constants.StateManagerMode.DISK:
return StateManagerMemory(state=state)
if config.state_manager_mode == constants.StateManagerMode.MEMORY:
return StateManagerDisk(state=state)
if config.state_manager_mode == constants.StateManagerMode.REDIS:
redis = prerequisites.get_redis()
if redis is not None:
# make sure expiration values are obtained only from the config object on creation
return StateManagerRedis(
state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
)
@abstractmethod
async def get_state(self, token: str) -> BaseState:

View File

@ -5,6 +5,14 @@ class ReflexError(Exception):
"""Base exception for all Reflex exceptions."""
class ConfigError(ReflexError):
"""Custom exception for config related errors."""
class InvalidStateManagerMode(ReflexError, ValueError):
"""Raised when an invalid state manager mode is provided."""
class ReflexRuntimeError(ReflexError, RuntimeError):
"""Custom RuntimeError for Reflex."""

View File

@ -3201,6 +3201,7 @@ import reflex as rx
config = rx.Config(
app_name="project1",
redis_url="redis://localhost:6379",
state_manager_mode="redis",
{config_items}
)
"""