implement disk state manager (#3826)

* implement disk state manager

* put states inside of a folder

* return root state all the time

* factor out code

* add docs for token expiration

* cache states directory

* call absolute on web directory

* change dir to app path when testing the backend

* remove accidental 🥒

* test disk for now

* modify schema

* only serialize specific stuff

* fix issue in types

* what is a kilometer

* create folder if it doesn't exist in write

* this code hates me

* check if the file isn't empty

* add try except clause

* add check for directory again
This commit is contained in:
Khaleel Al-Adhami 2024-08-30 17:26:10 -07:00 committed by GitHub
parent c457b43ab1
commit 629850162a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 290 additions and 24 deletions

View File

@ -45,6 +45,8 @@ class Dirs(SimpleNamespace):
REFLEX_JSON = "reflex.json"
# The name of the postcss config file.
POSTCSS_JS = "postcss.config.js"
# The name of the states directory.
STATES = "states"
class Reflex(SimpleNamespace):

View File

@ -11,6 +11,7 @@ import os
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from types import FunctionType, MethodType
from typing import (
TYPE_CHECKING,
@ -23,6 +24,7 @@ from typing import (
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
cast,
@ -52,7 +54,7 @@ from reflex.event import (
EventSpec,
fix_events,
)
from reflex.utils import console, format, prerequisites, types
from reflex.utils import console, format, path_ops, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer
@ -2339,7 +2341,7 @@ class StateManager(Base, ABC):
token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration,
)
return StateManagerMemory(state=state)
return StateManagerDisk(state=state)
@abstractmethod
async def get_state(self, token: str) -> BaseState:
@ -2446,6 +2448,244 @@ class StateManagerMemory(StateManager):
await self.set_state(token, state)
def _default_token_expiration() -> int:
    """Look up the token expiration configured for the app.

    Returns:
        The ``redis_token_expiration`` value of the current config.
    """
    config = get_config()
    return config.redis_token_expiration
def state_to_schema(
    state: BaseState,
) -> List[
    Tuple[
        str,
        str,
        Any,
        Union[bool, None],
    ]
]:
    """Convert a state to a schema.

    The schema has one entry per item of ``state.__fields__``, sorted in
    ascending tuple order (the keys of ``__fields__`` are unique, so in
    practice this is ordering by field name).

    Args:
        state: The state to convert to a schema.

    Returns:
        A sorted list of ``(field_name, model_field.name, model_field.type_,
        required)`` tuples, where ``required`` is ``model_field.required``
        when that is a bool and ``None`` otherwise.
    """
    # sorted() already returns a fresh list, so no extra list() wrapper is needed.
    return sorted(
        (
            field_name,
            model_field.name,
            model_field.type_,
            (
                model_field.required
                if isinstance(model_field.required, bool)
                else None
            ),
        )
        for field_name, model_field in state.__fields__.items()
    )
class StateManagerDisk(StateManager):
    """A state manager that stores states on disk.

    Every (sub)state is serialized with dill into its own ``<token>.pkl``
    file inside the states directory, together with the schema it was
    written with. States written through this manager are also kept in the
    in-process ``states`` cache under the same token.
    """

    # The in-process cache of states, keyed by substate token.
    states: Dict[str, BaseState] = {}

    # The mutex ensures the dict of mutexes is updated exclusively
    _state_manager_lock = asyncio.Lock()

    # The dict of mutexes for each client
    _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})

    # The token expiration time (s).
    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)

    class Config:
        """The Pydantic config."""

        fields = {
            "_states_locks": {"exclude": True},
        }
        keep_untouched = (functools.cached_property,)

    def __init__(self, state: Type[BaseState]):
        """Create a new state manager.

        Creates the states directory if needed and removes state files that
        are older than the token expiration.

        Args:
            state: The state class to use.
        """
        super().__init__(state=state)

        path_ops.mkdir(self.states_directory)

        self._purge_expired_states()

    @functools.cached_property
    def states_directory(self) -> Path:
        """Get the states directory.

        Returns:
            The states directory.
        """
        return prerequisites.get_web_dir() / constants.Dirs.STATES

    def _purge_expired_states(self):
        """Purge expired states from the disk.

        A ``.pkl`` file is expired when its modification time is more than
        ``token_expiration`` seconds in the past. Other files are ignored.
        """
        import time

        # Take a single timestamp so every file is judged against the same instant.
        now = time.time()

        for path in path_ops.ls(self.states_directory):
            # check path is a pickle file
            if path.suffix != ".pkl":
                continue

            # the file's modification time serves as the "last edited" marker
            last_edited = path.stat().st_mtime

            # remove the file if it is older than the token expiration time
            if now - last_edited > self.token_expiration:
                path.unlink()

    def token_path(self, token: str) -> Path:
        """Get the path for a token.

        Args:
            token: The token to get the path for.

        Returns:
            The absolute path of the pickle file for the token.
        """
        return (self.states_directory / f"{token}.pkl").absolute()

    async def load_state(self, token: str, root_state: BaseState) -> BaseState:
        """Load a state object based on the provided token.

        Lookup order: the in-process cache, then the pickle file on disk
        (only when its stored schema still matches the loaded state), and
        finally the matching substate of ``root_state``.

        Args:
            token: The token used to identify the state object.
            root_state: The root state object.

        Returns:
            The loaded state object.
        """
        if token in self.states:
            return self.states[token]

        client_token, substate_address = _split_substate_key(token)

        token_path = self.token_path(token)

        if token_path.exists():
            try:
                # NOTE(review): dill.load executes arbitrary code from the file;
                # the states directory must only contain files written by this app.
                with token_path.open(mode="rb") as file:
                    (substate_schema, substate) = dill.load(file)
                # Only reuse the stored state when its schema is unchanged.
                if substate_schema == state_to_schema(substate):
                    await self.populate_substates(client_token, substate, root_state)
                    return substate
            except Exception as err:
                # Best effort: an unreadable state file falls back to the fresh
                # state below, but the failure is logged instead of hidden.
                console.debug(f"Failed to load state from {token_path}: {err!r}")

        return root_state.get_substate(substate_address.split(".")[1:])

    async def populate_substates(
        self, client_token: str, state: BaseState, root_state: BaseState
    ):
        """Populate the substates of a state object.

        Each substate is loaded via ``load_state`` and linked to ``state`` in
        both directions (``state.substates`` and ``parent_state``).

        Args:
            client_token: The client token.
            state: The state object to populate.
            root_state: The root state object.
        """
        for substate_cls in state.get_substates():
            substate_token = _substate_key(client_token, substate_cls)

            loaded = await self.load_state(substate_token, root_state)

            state.substates[loaded.get_name()] = loaded
            loaded.parent_state = state

    @override
    async def get_state(
        self,
        token: str,
    ) -> BaseState:
        """Get the state for a token.

        The substate part of the token is reduced to its first path segment,
        so the root state is always what gets loaded and returned.

        Args:
            token: The token to get the state for.

        Returns:
            The state for the token.
        """
        client_token, substate_address = _split_substate_key(token)

        root_state_token = _substate_key(client_token, substate_address.split(".")[0])

        return await self.load_state(
            root_state_token, self.state(_reflex_internal_init=True)
        )

    async def set_state_for_substate(self, client_token: str, substate: BaseState):
        """Set the state for a substate.

        Caches the substate, writes it (with its schema) to its pickle file
        and then does the same recursively for all of its substates.

        Args:
            client_token: The client token.
            substate: The substate to set.
        """
        substate_token = _substate_key(client_token, substate)

        self.states[substate_token] = substate

        state_dilled = dill.dumps((state_to_schema(substate), substate), byref=True)

        # The directory may have been removed since __init__; recreate it the
        # same way __init__ does (a no-op when it already exists).
        path_ops.mkdir(self.states_directory)

        self.token_path(substate_token).write_bytes(state_dilled)

        for substate_substate in substate.substates.values():
            await self.set_state_for_substate(client_token, substate_substate)

    @override
    async def set_state(self, token: str, state: BaseState):
        """Set the state for a token.

        Args:
            token: The token to set the state for.
            state: The state to set.
        """
        # Only the client part of the token is needed; each substate derives
        # its own key in set_state_for_substate.
        client_token, _ = _split_substate_key(token)
        await self.set_state_for_substate(client_token, state)

    @override
    @contextlib.asynccontextmanager
    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
        """Modify the state for a token while holding exclusive lock.

        The state is written back with ``set_state`` after the caller's block
        finishes without raising.

        Args:
            token: The token to modify the state for.

        Yields:
            The state for the token.
        """
        # Disk state manager ignores the substate suffix and always returns the top-level state.
        client_token, _ = _split_substate_key(token)
        if client_token not in self._states_locks:
            async with self._state_manager_lock:
                # Re-check under the manager lock so only one lock is created per client.
                if client_token not in self._states_locks:
                    self._states_locks[client_token] = asyncio.Lock()

        async with self._states_locks[client_token]:
            state = await self.get_state(token)
            yield state
            await self.set_state(token, state)
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
if not isinstance(State.validate.__func__, FunctionType):
cython_function_or_method = type(State.validate.__func__)
@ -2474,15 +2714,6 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration
def _default_token_expiration() -> int:
    """Look up the token expiration configured for the app.

    Returns:
        The ``redis_token_expiration`` value of the current config.
    """
    config = get_config()
    return config.redis_token_expiration
class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""

View File

@ -45,6 +45,8 @@ import reflex.utils.prerequisites
import reflex.utils.processes
from reflex.state import (
BaseState,
StateManager,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
reload_state_module,
@ -126,7 +128,7 @@ class AppHarness:
frontend_output_thread: Optional[threading.Thread] = None
backend_thread: Optional[threading.Thread] = None
backend: Optional[uvicorn.Server] = None
state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
state_manager: Optional[StateManager] = None
_frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
_decorated_pages: list = dataclasses.field(default_factory=list)
@ -290,6 +292,8 @@ class AppHarness:
if isinstance(self.app_instance._state_manager, StateManagerRedis):
# Create our own redis connection for testing.
self.state_manager = StateManagerRedis.create(self.app_instance.state)
elif isinstance(self.app_instance._state_manager, StateManagerDisk):
self.state_manager = StateManagerDisk.create(self.app_instance.state)
else:
self.state_manager = self.app_instance._state_manager
@ -327,7 +331,8 @@ class AppHarness:
)
)
self.backend.shutdown = self._get_backend_shutdown_handler()
self.backend_thread = threading.Thread(target=self.backend.run)
with chdir(self.app_path):
self.backend_thread = threading.Thread(target=self.backend.run)
self.backend_thread.start()
async def _reset_backend_state_manager(self):
@ -787,7 +792,7 @@ class AppHarness:
raise RuntimeError("App is not running.")
state_manager = self.app_instance.state_manager
assert isinstance(
state_manager, StateManagerMemory
state_manager, (StateManagerMemory, StateManagerDisk)
), "Only works with memory state manager"
if not self._poll_for(
target=lambda: state_manager.states,

View File

@ -81,6 +81,18 @@ def mkdir(path: str | Path):
Path(path).mkdir(parents=True, exist_ok=True)
def ls(path: str | Path) -> list[Path]:
"""List the contents of a directory.
Args:
path: The path to the directory.
Returns:
A list of paths to the contents of the directory.
"""
return list(Path(path).iterdir())
def ln(src: str | Path, dest: str | Path, overwrite: bool = False) -> bool:
"""Create a symbolic link.

View File

@ -42,6 +42,7 @@ from reflex.state import (
OnLoadInternalState,
RouterData,
State,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
StateUpdate,
@ -1395,7 +1396,9 @@ def test_app_state_manager():
app.state_manager
app._enable_state()
assert app.state_manager is not None
assert isinstance(app.state_manager, (StateManagerMemory, StateManagerRedis))
assert isinstance(
app.state_manager, (StateManagerMemory, StateManagerDisk, StateManagerRedis)
)
def test_generate_component():

View File

@ -31,6 +31,7 @@ from reflex.state import (
RouterData,
State,
StateManager,
StateManagerDisk,
StateManagerMemory,
StateManagerRedis,
StateProxy,
@ -1586,7 +1587,7 @@ async def test_state_with_invalid_yield(capsys, mock_app):
assert "must only return/yield: None, Events or other EventHandlers" in captured.out
@pytest.fixture(scope="function", params=["in_process", "redis"])
@pytest.fixture(scope="function", params=["in_process", "disk", "redis"])
def state_manager(request) -> Generator[StateManager, None, None]:
"""Instance of state manager parametrized for redis and in-process.
@ -1600,8 +1601,11 @@ def state_manager(request) -> Generator[StateManager, None, None]:
if request.param == "redis":
if not isinstance(state_manager, StateManagerRedis):
pytest.skip("Test requires redis")
else:
elif request.param == "disk":
# explicitly NOT using redis
state_manager = StateManagerDisk(state=TestState)
assert not state_manager._states_locks
else:
state_manager = StateManagerMemory(state=TestState)
assert not state_manager._states_locks
@ -1639,7 +1643,7 @@ async def test_state_manager_modify_state(
async with state_manager.modify_state(substate_token) as state:
if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(f"{token}_lock")
elif isinstance(state_manager, StateManagerMemory):
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert state_manager._states_locks[token].locked()
# Should be able to write proxy objects inside mutables
@ -1649,11 +1653,11 @@ async def test_state_manager_modify_state(
# lock should be dropped after exiting the context
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
elif isinstance(state_manager, StateManagerMemory):
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert not state_manager._states_locks[token].locked()
# separate instances should NOT share locks
sm2 = StateManagerMemory(state=TestState)
sm2 = state_manager.__class__(state=TestState)
assert sm2._state_manager_lock is state_manager._state_manager_lock
assert not sm2._states_locks
if state_manager._states_locks:
@ -1691,7 +1695,7 @@ async def test_state_manager_contend(
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
elif isinstance(state_manager, StateManagerMemory):
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks
assert not state_manager._states_locks[token].locked()
@ -1831,7 +1835,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert child_state is not None
parent_state = child_state.parent_state
assert parent_state is not None
if isinstance(mock_app.state_manager, StateManagerMemory):
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
mock_app.state_manager.states[parent_state.router.session.client_token] = (
parent_state
)
@ -1874,7 +1878,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# For in-process store, only one instance of the state exists
assert sp.__wrapped__ is grandchild_state
else:
# When redis is used, a new+updated instance is assigned to the proxy
# When redis or disk is used, a new+updated instance is assigned to the proxy
assert sp.__wrapped__ is not grandchild_state
sp.value2 = "42"
assert not sp._self_mutable # proxy is not mutable after exiting context
@ -2837,7 +2841,7 @@ async def test_get_state(mock_app: rx.App, token: str):
_substate_key(token, ChildState2)
)
assert isinstance(test_state, TestState)
if isinstance(mock_app.state_manager, StateManagerMemory):
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# All substates are available
assert tuple(sorted(test_state.substates)) == (
ChildState.get_name(),
@ -2916,6 +2920,15 @@ async def test_get_state(mock_app: rx.App, token: str):
ChildState2.get_name(),
ChildState3.get_name(),
)
elif isinstance(mock_app.state_manager, StateManagerDisk):
# On disk, it's a new instance
assert new_test_state is not test_state
# All substates are available
assert tuple(sorted(new_test_state.substates)) == (
ChildState.get_name(),
ChildState2.get_name(),
ChildState3.get_name(),
)
else:
# With redis, we get a whole new instance
assert new_test_state is not test_state