let users pick state manager mode (#4041)

This commit is contained in:
Thomas Brandého 2024-10-10 12:22:35 -07:00 committed by GitHub
parent 1aed39a848
commit 6f586c8b8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 61 additions and 12 deletions

View File

@ -9,6 +9,8 @@ import urllib.parse
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
from reflex.utils.exceptions import ConfigError
try: try:
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
except ModuleNotFoundError: except ModuleNotFoundError:
@ -220,6 +222,9 @@ class Config(Base):
# Number of gunicorn workers from user # Number of gunicorn workers from user
gunicorn_workers: Optional[int] = None gunicorn_workers: Optional[int] = None
# Indicate which type of state manager to use
state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK
# Maximum expiration lock time for redis state manager # Maximum expiration lock time for redis state manager
redis_lock_expiration: int = constants.Expiration.LOCK redis_lock_expiration: int = constants.Expiration.LOCK
@ -235,6 +240,9 @@ class Config(Base):
Args: Args:
*args: The args to pass to the Pydantic init method. *args: The args to pass to the Pydantic init method.
**kwargs: The kwargs to pass to the Pydantic init method. **kwargs: The kwargs to pass to the Pydantic init method.
Raises:
ConfigError: If some values in the config are invalid.
""" """
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -248,6 +256,14 @@ class Config(Base):
self._non_default_attributes.update(kwargs) self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs) self._replace_defaults(**kwargs)
if (
self.state_manager_mode == constants.StateManagerMode.REDIS
and not self.redis_url
):
raise ConfigError(
"REDIS_URL is required when using the redis state manager."
)
@property @property
def module(self) -> str: def module(self) -> str:
"""Get the module name of the app. """Get the module name of the app.

View File

@ -63,6 +63,7 @@ from .route import (
RouteRegex, RouteRegex,
RouteVar, RouteVar,
) )
from .state import StateManagerMode
from .style import Tailwind from .style import Tailwind
__ALL__ = [ __ALL__ = [
@ -115,6 +116,7 @@ __ALL__ = [
SETTER_PREFIX, SETTER_PREFIX,
SKIP_COMPILE_ENV_VAR, SKIP_COMPILE_ENV_VAR,
SocketEvent, SocketEvent,
StateManagerMode,
Tailwind, Tailwind,
Templates, Templates,
CompileVars, CompileVars,

11
reflex/constants/state.py Normal file
View File

@ -0,0 +1,11 @@
"""State-related constants."""
from enum import Enum
class StateManagerMode(str, Enum):
"""State manager constants."""
DISK = "disk"
MEMORY = "memory"
REDIS = "redis"

View File

@ -76,6 +76,7 @@ from reflex.utils.exceptions import (
DynamicRouteArgShadowsStateVar, DynamicRouteArgShadowsStateVar,
EventHandlerShadowsBuiltInStateMethod, EventHandlerShadowsBuiltInStateMethod,
ImmutableStateError, ImmutableStateError,
InvalidStateManagerMode,
LockExpiredError, LockExpiredError,
SetUndefinedStateVarError, SetUndefinedStateVarError,
StateSchemaMismatchError, StateSchemaMismatchError,
@ -2514,20 +2515,30 @@ class StateManager(Base, ABC):
Args: Args:
state: The state class to use. state: The state class to use.
Raises:
InvalidStateManagerMode: If the state manager mode is invalid.
Returns: Returns:
The state manager (either disk or redis). The state manager (either disk, memory or redis).
""" """
redis = prerequisites.get_redis() config = get_config()
if redis is not None: if config.state_manager_mode == constants.StateManagerMode.DISK:
# make sure expiration values are obtained only from the config object on creation return StateManagerMemory(state=state)
config = get_config() if config.state_manager_mode == constants.StateManagerMode.MEMORY:
return StateManagerRedis( return StateManagerDisk(state=state)
state=state, if config.state_manager_mode == constants.StateManagerMode.REDIS:
redis=redis, redis = prerequisites.get_redis()
token_expiration=config.redis_token_expiration, if redis is not None:
lock_expiration=config.redis_lock_expiration, # make sure expiration values are obtained only from the config object on creation
) return StateManagerRedis(
return StateManagerDisk(state=state) state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
)
@abstractmethod @abstractmethod
async def get_state(self, token: str) -> BaseState: async def get_state(self, token: str) -> BaseState:

View File

@ -5,6 +5,14 @@ class ReflexError(Exception):
"""Base exception for all Reflex exceptions.""" """Base exception for all Reflex exceptions."""
class ConfigError(ReflexError):
"""Custom exception for config related errors."""
class InvalidStateManagerMode(ReflexError, ValueError):
"""Raised when an invalid state manager mode is provided."""
class ReflexRuntimeError(ReflexError, RuntimeError): class ReflexRuntimeError(ReflexError, RuntimeError):
"""Custom RuntimeError for Reflex.""" """Custom RuntimeError for Reflex."""

View File

@ -3201,6 +3201,7 @@ import reflex as rx
config = rx.Config( config = rx.Config(
app_name="project1", app_name="project1",
redis_url="redis://localhost:6379", redis_url="redis://localhost:6379",
state_manager_mode="redis",
{config_items} {config_items}
) )
""" """