From 629850162ab3f7fadaf72ea94518270e232e9c66 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 30 Aug 2024 17:26:10 -0700 Subject: [PATCH] implement disk state manager (#3826) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * implement disk state manager * put states inside of a folder * return root state all the time * factor out code * add docs for token expiration * cache states directory * call absolute on web directory * change dir to app path when testing the backend * remove accidental 🥒 * test disk for now * modify schema * only serialize specific stuff * fix issue in types * what is a kilometer * create folder if it doesn't exist in write * this code hates me * check if the file isn't empty * add try except clause * add check for directory again --- reflex/constants/base.py | 2 + reflex/state.py | 253 +++++++++++++++++++++++++++++++++++++-- reflex/testing.py | 11 +- reflex/utils/path_ops.py | 12 ++ tests/test_app.py | 5 +- tests/test_state.py | 31 +++-- 6 files changed, 290 insertions(+), 24 deletions(-) diff --git a/reflex/constants/base.py b/reflex/constants/base.py index a858a69b1..c9be9a462 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -45,6 +45,8 @@ class Dirs(SimpleNamespace): REFLEX_JSON = "reflex.json" # The name of the postcss config file. POSTCSS_JS = "postcss.config.js" + # The name of the states directory. + STATES = "states" class Reflex(SimpleNamespace): diff --git a/reflex/state.py b/reflex/state.py index 6712edd81..c0435d665 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -11,6 +11,7 @@ import os import uuid from abc import ABC, abstractmethod from collections import defaultdict +from pathlib import Path from types import FunctionType, MethodType from typing import ( TYPE_CHECKING, @@ -23,6 +24,7 @@ from typing import ( Optional, Sequence, Set, + Tuple, Type, Union, cast, @@ -52,7 +54,7 @@ from reflex.event import ( EventSpec, fix_events, ) -from reflex.utils import console, format, prerequisites, types +from reflex.utils import console, format, path_ops, prerequisites, types from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.utils.exec import is_testing_env from reflex.utils.serializers import SerializedType, serialize, serializer @@ -2339,7 +2341,7 @@ class StateManager(Base, ABC): token_expiration=config.redis_token_expiration, lock_expiration=config.redis_lock_expiration, ) - return StateManagerMemory(state=state) + return StateManagerDisk(state=state) @abstractmethod async def get_state(self, token: str) -> BaseState: @@ -2446,6 +2448,244 @@ class StateManagerMemory(StateManager): await self.set_state(token, state) +def _default_token_expiration() -> int: + """Get the default token expiration time. + + Returns: + The default token expiration time. + """ + return get_config().redis_token_expiration + + +def state_to_schema( + state: BaseState, +) -> List[ + Tuple[ + str, + str, + Any, + Union[bool, None], + ] +]: + """Convert a state to a schema. + + Args: + state: The state to convert to a schema. + + Returns: + The schema. + """ + return list( + sorted( + ( + field_name, + model_field.name, + model_field.type_, + ( + model_field.required + if isinstance(model_field.required, bool) + else None + ), + ) + for field_name, model_field in state.__fields__.items() + ) + ) + + +class StateManagerDisk(StateManager): + """A state manager that stores states in memory.""" + + # The mapping of client ids to states. + states: Dict[str, BaseState] = {} + + # The mutex ensures the dict of mutexes is updated exclusively + _state_manager_lock = asyncio.Lock() + + # The dict of mutexes for each client + _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({}) + + # The token expiration time (s). + token_expiration: int = pydantic.Field(default_factory=_default_token_expiration) + + class Config: + """The Pydantic config.""" + + fields = { + "_states_locks": {"exclude": True}, + } + keep_untouched = (functools.cached_property,) + + def __init__(self, state: Type[BaseState]): + """Create a new state manager. + + Args: + state: The state class to use. + """ + super().__init__(state=state) + + path_ops.mkdir(self.states_directory) + + self._purge_expired_states() + + @functools.cached_property + def states_directory(self) -> Path: + """Get the states directory. + + Returns: + The states directory. + """ + return prerequisites.get_web_dir() / constants.Dirs.STATES + + def _purge_expired_states(self): + """Purge expired states from the disk.""" + import time + + for path in path_ops.ls(self.states_directory): + # check path is a pickle file + if path.suffix != ".pkl": + continue + + # load last edited field from file + last_edited = path.stat().st_mtime + + # check if the file is older than the token expiration time + if time.time() - last_edited > self.token_expiration: + # remove the file + path.unlink() + + def token_path(self, token: str) -> Path: + """Get the path for a token. + + Args: + token: The token to get the path for. + + Returns: + The path for the token. + """ + return (self.states_directory / f"{token}.pkl").absolute() + + async def load_state(self, token: str, root_state: BaseState) -> BaseState: + """Load a state object based on the provided token. + + Args: + token: The token used to identify the state object. + root_state: The root state object. + + Returns: + The loaded state object. + """ + if token in self.states: + return self.states[token] + + client_token, substate_address = _split_substate_key(token) + + token_path = self.token_path(token) + + if token_path.exists(): + try: + with token_path.open(mode="rb") as file: + (substate_schema, substate) = dill.load(file) + if substate_schema == state_to_schema(substate): + await self.populate_substates(client_token, substate, root_state) + return substate + except Exception: + pass + + return root_state.get_substate(substate_address.split(".")[1:]) + + async def populate_substates( + self, client_token: str, state: BaseState, root_state: BaseState + ): + """Populate the substates of a state object. + + Args: + client_token: The client token. + state: The state object to populate. + root_state: The root state object. + """ + for substate in state.get_substates(): + substate_token = _substate_key(client_token, substate) + + substate = await self.load_state(substate_token, root_state) + + state.substates[substate.get_name()] = substate + substate.parent_state = state + + @override + async def get_state( + self, + token: str, + ) -> BaseState: + """Get the state for a token. + + Args: + token: The token to get the state for. + + Returns: + The state for the token. + """ + client_token, substate_address = _split_substate_key(token) + + root_state_token = _substate_key(client_token, substate_address.split(".")[0]) + + return await self.load_state( + root_state_token, self.state(_reflex_internal_init=True) + ) + + async def set_state_for_substate(self, client_token: str, substate: BaseState): + """Set the state for a substate. + + Args: + client_token: The client token. + substate: The substate to set. + """ + substate_token = _substate_key(client_token, substate) + + self.states[substate_token] = substate + + state_dilled = dill.dumps((state_to_schema(substate), substate), byref=True) + if not self.states_directory.exists(): + self.states_directory.mkdir(parents=True, exist_ok=True) + self.token_path(substate_token).write_bytes(state_dilled) + + for substate_substate in substate.substates.values(): + await self.set_state_for_substate(client_token, substate_substate) + + @override + async def set_state(self, token: str, state: BaseState): + """Set the state for a token. + + Args: + token: The token to set the state for. + state: The state to set. + """ + client_token, substate = _split_substate_key(token) + await self.set_state_for_substate(client_token, state) + + @override + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[BaseState]: + """Modify the state for a token while holding exclusive lock. + + Args: + token: The token to modify the state for. + + Yields: + The state for the token. + """ + # Memory state manager ignores the substate suffix and always returns the top-level state. + client_token, substate = _split_substate_key(token) + if client_token not in self._states_locks: + async with self._state_manager_lock: + if client_token not in self._states_locks: + self._states_locks[client_token] = asyncio.Lock() + + async with self._states_locks[client_token]: + state = await self.get_state(token) + yield state + await self.set_state(token, state) + + # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes if not isinstance(State.validate.__func__, FunctionType): cython_function_or_method = type(State.validate.__func__) @@ -2474,15 +2714,6 @@ def _default_lock_expiration() -> int: return get_config().redis_lock_expiration -def _default_token_expiration() -> int: - """Get the default token expiration time. - - Returns: - The default token expiration time. - """ - return get_config().redis_token_expiration - - class StateManagerRedis(StateManager): """A state manager that stores states in redis.""" diff --git a/reflex/testing.py b/reflex/testing.py index c52396fc6..503db6c2f 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -45,6 +45,8 @@ import reflex.utils.prerequisites import reflex.utils.processes from reflex.state import ( BaseState, + StateManager, + StateManagerDisk, StateManagerMemory, StateManagerRedis, reload_state_module, @@ -126,7 +128,7 @@ class AppHarness: frontend_output_thread: Optional[threading.Thread] = None backend_thread: Optional[threading.Thread] = None backend: Optional[uvicorn.Server] = None - state_manager: Optional[StateManagerMemory | StateManagerRedis] = None + state_manager: Optional[StateManager] = None _frontends: list["WebDriver"] = dataclasses.field(default_factory=list) _decorated_pages: list = dataclasses.field(default_factory=list) @@ -290,6 +292,8 @@ class AppHarness: if isinstance(self.app_instance._state_manager, StateManagerRedis): # Create our own redis connection for testing. self.state_manager = StateManagerRedis.create(self.app_instance.state) + elif isinstance(self.app_instance._state_manager, StateManagerDisk): + self.state_manager = StateManagerDisk.create(self.app_instance.state) else: self.state_manager = self.app_instance._state_manager @@ -327,7 +331,8 @@ class AppHarness: ) ) self.backend.shutdown = self._get_backend_shutdown_handler() - self.backend_thread = threading.Thread(target=self.backend.run) + with chdir(self.app_path): + self.backend_thread = threading.Thread(target=self.backend.run) self.backend_thread.start() async def _reset_backend_state_manager(self): @@ -787,7 +792,7 @@ class AppHarness: raise RuntimeError("App is not running.") state_manager = self.app_instance.state_manager assert isinstance( - state_manager, StateManagerMemory + state_manager, (StateManagerMemory, StateManagerDisk) ), "Only works with memory state manager" if not self._poll_for( target=lambda: state_manager.states, diff --git a/reflex/utils/path_ops.py b/reflex/utils/path_ops.py index 39f2138f8..d38239a83 100644 --- a/reflex/utils/path_ops.py +++ b/reflex/utils/path_ops.py @@ -81,6 +81,18 @@ def mkdir(path: str | Path): Path(path).mkdir(parents=True, exist_ok=True) +def ls(path: str | Path) -> list[Path]: + """List the contents of a directory. + + Args: + path: The path to the directory. + + Returns: + A list of paths to the contents of the directory. + """ + return list(Path(path).iterdir()) + + def ln(src: str | Path, dest: str | Path, overwrite: bool = False) -> bool: """Create a symbolic link. diff --git a/tests/test_app.py b/tests/test_app.py index efaca4234..f933149f1 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -42,6 +42,7 @@ from reflex.state import ( OnLoadInternalState, RouterData, State, + StateManagerDisk, StateManagerMemory, StateManagerRedis, StateUpdate, @@ -1395,7 +1396,9 @@ def test_app_state_manager(): app.state_manager app._enable_state() assert app.state_manager is not None - assert isinstance(app.state_manager, (StateManagerMemory, StateManagerRedis)) + assert isinstance( + app.state_manager, (StateManagerMemory, StateManagerDisk, StateManagerRedis) + ) def test_generate_component(): diff --git a/tests/test_state.py b/tests/test_state.py index d34e771cb..ba8fc592f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -31,6 +31,7 @@ from reflex.state import ( RouterData, State, StateManager, + StateManagerDisk, StateManagerMemory, StateManagerRedis, StateProxy, @@ -1586,7 +1587,7 @@ async def test_state_with_invalid_yield(capsys, mock_app): assert "must only return/yield: None, Events or other EventHandlers" in captured.out -@pytest.fixture(scope="function", params=["in_process", "redis"]) +@pytest.fixture(scope="function", params=["in_process", "disk", "redis"]) def state_manager(request) -> Generator[StateManager, None, None]: """Instance of state manager parametrized for redis and in-process. @@ -1600,8 +1601,11 @@ def state_manager(request) -> Generator[StateManager, None, None]: if request.param == "redis": if not isinstance(state_manager, StateManagerRedis): pytest.skip("Test requires redis") - else: + elif request.param == "disk": # explicitly NOT using redis + state_manager = StateManagerDisk(state=TestState) + assert not state_manager._states_locks + else: state_manager = StateManagerMemory(state=TestState) assert not state_manager._states_locks @@ -1639,7 +1643,7 @@ async def test_state_manager_modify_state( async with state_manager.modify_state(substate_token) as state: if isinstance(state_manager, StateManagerRedis): assert await state_manager.redis.get(f"{token}_lock") - elif isinstance(state_manager, StateManagerMemory): + elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert state_manager._states_locks[token].locked() # Should be able to write proxy objects inside mutables @@ -1649,11 +1653,11 @@ async def test_state_manager_modify_state( # lock should be dropped after exiting the context if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None - elif isinstance(state_manager, StateManagerMemory): + elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert not state_manager._states_locks[token].locked() # separate instances should NOT share locks - sm2 = StateManagerMemory(state=TestState) + sm2 = state_manager.__class__(state=TestState) assert sm2._state_manager_lock is state_manager._state_manager_lock assert not sm2._states_locks if state_manager._states_locks: @@ -1691,7 +1695,7 @@ async def test_state_manager_contend( if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None - elif isinstance(state_manager, StateManagerMemory): + elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): assert token in state_manager._states_locks assert not state_manager._states_locks[token].locked() @@ -1831,7 +1835,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): assert child_state is not None parent_state = child_state.parent_state assert parent_state is not None - if isinstance(mock_app.state_manager, StateManagerMemory): + if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): mock_app.state_manager.states[parent_state.router.session.client_token] = ( parent_state ) @@ -1874,7 +1878,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): # For in-process store, only one instance of the state exists assert sp.__wrapped__ is grandchild_state else: - # When redis is used, a new+updated instance is assigned to the proxy + # When redis or disk is used, a new+updated instance is assigned to the proxy assert sp.__wrapped__ is not grandchild_state sp.value2 = "42" assert not sp._self_mutable # proxy is not mutable after exiting context @@ -2837,7 +2841,7 @@ async def test_get_state(mock_app: rx.App, token: str): _substate_key(token, ChildState2) ) assert isinstance(test_state, TestState) - if isinstance(mock_app.state_manager, StateManagerMemory): + if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): # All substates are available assert tuple(sorted(test_state.substates)) == ( ChildState.get_name(), @@ -2916,6 +2920,15 @@ async def test_get_state(mock_app: rx.App, token: str): ChildState2.get_name(), ChildState3.get_name(), ) + elif isinstance(mock_app.state_manager, StateManagerDisk): + # On disk, it's a new instance + assert new_test_state is not test_state + # All substates are available + assert tuple(sorted(new_test_state.substates)) == ( + ChildState.get_name(), + ChildState2.get_name(), + ChildState3.get_name(), + ) else: # With redis, we get a whole new instance assert new_test_state is not test_state