let users pick state manager mode

This commit is contained in:
Lendemor 2024-10-02 18:13:39 +02:00
parent c08720ed1a
commit fa564989cf
5 changed files with 60 additions and 12 deletions

View File

@ -8,6 +8,8 @@ import sys
import urllib.parse import urllib.parse
from typing import Any, Dict, List, Optional, Set from typing import Any, Dict, List, Optional, Set
from reflex.utils.exceptions import ConfigError
try: try:
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
except ModuleNotFoundError: except ModuleNotFoundError:
@ -219,6 +221,9 @@ class Config(Base):
# Number of gunicorn workers from user # Number of gunicorn workers from user
gunicorn_workers: Optional[int] = None gunicorn_workers: Optional[int] = None
# Indicate which type of state manager to use
state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK
# Maximum expiration lock time for redis state manager # Maximum expiration lock time for redis state manager
redis_lock_expiration: int = constants.Expiration.LOCK redis_lock_expiration: int = constants.Expiration.LOCK
@ -234,6 +239,9 @@ class Config(Base):
Args: Args:
*args: The args to pass to the Pydantic init method. *args: The args to pass to the Pydantic init method.
**kwargs: The kwargs to pass to the Pydantic init method. **kwargs: The kwargs to pass to the Pydantic init method.
Raises:
ConfigError: If some values in the config are invalid.
""" """
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -247,6 +255,14 @@ class Config(Base):
self._non_default_attributes.update(kwargs) self._non_default_attributes.update(kwargs)
self._replace_defaults(**kwargs) self._replace_defaults(**kwargs)
if (
self.state_manager_mode == constants.StateManagerMode.REDIS
and not self.redis_url
):
raise ConfigError(
"REDIS_URL is required when using the redis state manager."
)
@property @property
def module(self) -> str: def module(self) -> str:
"""Get the module name of the app. """Get the module name of the app.

View File

@ -63,6 +63,7 @@ from .route import (
RouteRegex, RouteRegex,
RouteVar, RouteVar,
) )
from .state import StateManagerMode
from .style import Tailwind from .style import Tailwind
__ALL__ = [ __ALL__ = [
@ -115,6 +116,7 @@ __ALL__ = [
SETTER_PREFIX, SETTER_PREFIX,
SKIP_COMPILE_ENV_VAR, SKIP_COMPILE_ENV_VAR,
SocketEvent, SocketEvent,
StateManagerMode,
Tailwind, Tailwind,
Templates, Templates,
CompileVars, CompileVars,

11
reflex/constants/state.py Normal file
View File

@ -0,0 +1,11 @@
"""State-related constants."""
from enum import Enum
class StateManagerMode(str, Enum):
"""State manager constants."""
DISK = "disk"
MEMORY = "memory"
REDIS = "redis"

View File

@ -73,6 +73,7 @@ from reflex.utils.exceptions import (
DynamicRouteArgShadowsStateVar, DynamicRouteArgShadowsStateVar,
EventHandlerShadowsBuiltInStateMethod, EventHandlerShadowsBuiltInStateMethod,
ImmutableStateError, ImmutableStateError,
InvalidStateManagerMode,
LockExpiredError, LockExpiredError,
) )
from reflex.utils.exec import is_testing_env from reflex.utils.exec import is_testing_env
@ -2490,20 +2491,30 @@ class StateManager(Base, ABC):
Args: Args:
state: The state class to use. state: The state class to use.
Raises:
InvalidStateManagerMode: If the state manager mode is invalid.
Returns: Returns:
The state manager (either memory or redis). The state manager (either disk, memory or redis).
""" """
redis = prerequisites.get_redis() config = get_config()
if redis is not None: if config.state_manager_mode == constants.StateManagerMode.DISK:
# make sure expiration values are obtained only from the config object on creation return StateManagerMemory(state=state)
config = get_config() if config.state_manager_mode == constants.StateManagerMode.MEMORY:
return StateManagerRedis( return StateManagerDisk(state=state)
state=state, if config.state_manager_mode == constants.StateManagerMode.REDIS:
redis=redis, redis = prerequisites.get_redis()
token_expiration=config.redis_token_expiration, if redis is not None:
lock_expiration=config.redis_lock_expiration, # make sure expiration values are obtained only from the config object on creation
) return StateManagerRedis(
return StateManagerDisk(state=state) state=state,
redis=redis,
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
raise InvalidStateManagerMode(
f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}"
)
@abstractmethod @abstractmethod
async def get_state(self, token: str) -> BaseState: async def get_state(self, token: str) -> BaseState:

View File

@ -5,6 +5,14 @@ class ReflexError(Exception):
"""Base exception for all Reflex exceptions.""" """Base exception for all Reflex exceptions."""
class ConfigError(ReflexError):
"""Custom exception for config related errors."""
class InvalidStateManagerMode(ReflexError, ValueError):
"""Raised when an invalid state manager mode is provided."""
class ReflexRuntimeError(ReflexError, RuntimeError): class ReflexRuntimeError(ReflexError, RuntimeError):
"""Custom RuntimeError for Reflex.""" """Custom RuntimeError for Reflex."""