diff --git a/reflex/config.py b/reflex/config.py index a5d66cb52..719d5c21f 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -9,6 +9,8 @@ import urllib.parse from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union +from reflex.utils.exceptions import ConfigError + try: import pydantic.v1 as pydantic except ModuleNotFoundError: @@ -220,6 +222,9 @@ class Config(Base): # Number of gunicorn workers from user gunicorn_workers: Optional[int] = None + # Indicate which type of state manager to use + state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK + # Maximum expiration lock time for redis state manager redis_lock_expiration: int = constants.Expiration.LOCK @@ -235,6 +240,9 @@ class Config(Base): Args: *args: The args to pass to the Pydantic init method. **kwargs: The kwargs to pass to the Pydantic init method. + + Raises: + ConfigError: If some values in the config are invalid. """ super().__init__(*args, **kwargs) @@ -248,6 +256,14 @@ class Config(Base): self._non_default_attributes.update(kwargs) self._replace_defaults(**kwargs) + if ( + self.state_manager_mode == constants.StateManagerMode.REDIS + and not self.redis_url + ): + raise ConfigError( + "REDIS_URL is required when using the redis state manager." + ) + @property def module(self) -> str: """Get the module name of the app. diff --git a/reflex/constants/__init__.py b/reflex/constants/__init__.py index e974ab915..8e61a3717 100644 --- a/reflex/constants/__init__.py +++ b/reflex/constants/__init__.py @@ -63,6 +63,7 @@ from .route import ( RouteRegex, RouteVar, ) +from .state import StateManagerMode from .style import Tailwind __ALL__ = [ @@ -115,6 +116,7 @@ __ALL__ = [ SETTER_PREFIX, SKIP_COMPILE_ENV_VAR, SocketEvent, + StateManagerMode, Tailwind, Templates, CompileVars, diff --git a/reflex/constants/state.py b/reflex/constants/state.py new file mode 100644 index 000000000..aa0e2f97f --- /dev/null +++ b/reflex/constants/state.py @@ -0,0 +1,11 @@ +"""State-related constants.""" + +from enum import Enum + + +class StateManagerMode(str, Enum): + """State manager constants.""" + + DISK = "disk" + MEMORY = "memory" + REDIS = "redis" diff --git a/reflex/state.py b/reflex/state.py index f488602fa..38289d081 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -76,6 +76,7 @@ from reflex.utils.exceptions import ( DynamicRouteArgShadowsStateVar, EventHandlerShadowsBuiltInStateMethod, ImmutableStateError, + InvalidStateManagerMode, LockExpiredError, SetUndefinedStateVarError, StateSchemaMismatchError, @@ -2514,20 +2515,30 @@ class StateManager(Base, ABC): Args: state: The state class to use. + Raises: + InvalidStateManagerMode: If the state manager mode is invalid. + Returns: - The state manager (either disk or redis). + The state manager (either disk, memory or redis). """ - redis = prerequisites.get_redis() - if redis is not None: - # make sure expiration values are obtained only from the config object on creation - config = get_config() - return StateManagerRedis( - state=state, - redis=redis, - token_expiration=config.redis_token_expiration, - lock_expiration=config.redis_lock_expiration, - ) - return StateManagerDisk(state=state) + config = get_config() + if config.state_manager_mode == constants.StateManagerMode.DISK: + return StateManagerMemory(state=state) + if config.state_manager_mode == constants.StateManagerMode.MEMORY: + return StateManagerDisk(state=state) + if config.state_manager_mode == constants.StateManagerMode.REDIS: + redis = prerequisites.get_redis() + if redis is not None: + # make sure expiration values are obtained only from the config object on creation + return StateManagerRedis( + state=state, + redis=redis, + token_expiration=config.redis_token_expiration, + lock_expiration=config.redis_lock_expiration, + ) + raise InvalidStateManagerMode( + f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" + ) @abstractmethod async def get_state(self, token: str) -> BaseState: diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 8bce605b5..35f59a0e1 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -5,6 +5,14 @@ class ReflexError(Exception): """Base exception for all Reflex exceptions.""" +class ConfigError(ReflexError): + """Custom exception for config related errors.""" + + +class InvalidStateManagerMode(ReflexError, ValueError): + """Raised when an invalid state manager mode is provided.""" + + class ReflexRuntimeError(ReflexError, RuntimeError): """Custom RuntimeError for Reflex.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 6246618f6..ae74cacce 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3201,6 +3201,7 @@ import reflex as rx config = rx.Config( app_name="project1", redis_url="redis://localhost:6379", + state_manager_mode="redis", {config_items} ) """