implement disk state manager ()

* implement disk state manager

* put states inside of a folder

* return root state all the time

* factor out code

* add docs for token expiration

* cache states directory

* call absolute on web directory

* change dir to app path when testing the backend

* remove accidental 🥒

* test disk for now

* modify schema

* only serialize specific stuff

* fix issue in types

* what is a kilometer

* create folder if it doesn't exist in write

* this code hates me

* check if the file isn't empty

* add try except clause

* add check for directory again
This commit is contained in:
Khaleel Al-Adhami 2024-08-30 17:26:10 -07:00 committed by GitHub
parent c457b43ab1
commit 629850162a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 290 additions and 24 deletions

View File

@ -45,6 +45,8 @@ class Dirs(SimpleNamespace):
REFLEX_JSON = "reflex.json" REFLEX_JSON = "reflex.json"
# The name of the postcss config file. # The name of the postcss config file.
POSTCSS_JS = "postcss.config.js" POSTCSS_JS = "postcss.config.js"
# The name of the states directory.
STATES = "states"
class Reflex(SimpleNamespace): class Reflex(SimpleNamespace):

View File

@ -11,6 +11,7 @@ import os
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from pathlib import Path
from types import FunctionType, MethodType from types import FunctionType, MethodType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -23,6 +24,7 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple,
Type, Type,
Union, Union,
cast, cast,
@ -52,7 +54,7 @@ from reflex.event import (
EventSpec, EventSpec,
fix_events, fix_events,
) )
from reflex.utils import console, format, prerequisites, types from reflex.utils import console, format, path_ops, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.exec import is_testing_env from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.utils.serializers import SerializedType, serialize, serializer
@ -2339,7 +2341,7 @@ class StateManager(Base, ABC):
token_expiration=config.redis_token_expiration, token_expiration=config.redis_token_expiration,
lock_expiration=config.redis_lock_expiration, lock_expiration=config.redis_lock_expiration,
) )
return StateManagerMemory(state=state) return StateManagerDisk(state=state)
@abstractmethod @abstractmethod
async def get_state(self, token: str) -> BaseState: async def get_state(self, token: str) -> BaseState:
@ -2446,6 +2448,244 @@ class StateManagerMemory(StateManager):
await self.set_state(token, state) await self.set_state(token, state)
def _default_token_expiration() -> int:
    """Look up the default expiration time for state tokens.

    Returns:
        The ``redis_token_expiration`` value of the current config.
    """
    config = get_config()
    return config.redis_token_expiration
def state_to_schema(
    state: BaseState,
) -> List[
    Tuple[
        str,
        str,
        Any,
        Union[bool, None],
    ]
]:
    """Convert a state to a schema.

    The schema is one tuple per entry of ``state.__fields__``, holding the
    mapping key, the model field's ``name``, its ``type_`` and its
    ``required`` flag. The tuples are returned in sorted order, so two states
    with the same fields produce equal schemas regardless of the order in
    which the fields were declared.

    Args:
        state: The state to convert to a schema.

    Returns:
        The schema.
    """
    # `sorted` already builds a new list, so no extra `list(...)` is needed.
    return sorted(
        (
            field_name,
            model_field.name,
            model_field.type_,
            (
                # Anything that is not a plain bool is normalised to None.
                model_field.required
                if isinstance(model_field.required, bool)
                else None
            ),
        )
        for field_name, model_field in state.__fields__.items()
    )
class StateManagerDisk(StateManager):
    """A state manager that stores states on disk.

    Every (sub)state is pickled to its own ``<token>.pkl`` file inside the
    states directory. States written through this manager are also kept in
    the in-process ``states`` mapping, which is consulted before the disk.
    """

    # The mapping of substate tokens to the states written by this manager.
    states: Dict[str, BaseState] = {}

    # The mutex ensures the dict of mutexes is updated exclusively
    _state_manager_lock = asyncio.Lock()

    # The dict of mutexes for each client
    _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})

    # The token expiration time (s).
    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)

    class Config:
        """The Pydantic config."""

        fields = {
            "_states_locks": {"exclude": True},
        }
        keep_untouched = (functools.cached_property,)

    def __init__(self, state: Type[BaseState]):
        """Create a new state manager.

        Creates the states directory and removes expired state files from it.

        Args:
            state: The state class to use.
        """
        super().__init__(state=state)

        path_ops.mkdir(self.states_directory)

        self._purge_expired_states()

    @functools.cached_property
    def states_directory(self) -> Path:
        """Get the states directory.

        Returns:
            The states directory.
        """
        return prerequisites.get_web_dir() / constants.Dirs.STATES

    def _purge_expired_states(self):
        """Purge expired states from the disk.

        A ``.pkl`` file is expired when its modification time is more than
        ``token_expiration`` seconds in the past.
        """
        import time

        # One timestamp for the whole sweep keeps the cutoff consistent.
        now = time.time()

        for path in path_ops.ls(self.states_directory):
            # Only pickle files belong to this manager.
            if path.suffix != ".pkl":
                continue

            if now - path.stat().st_mtime > self.token_expiration:
                path.unlink()

    def token_path(self, token: str) -> Path:
        """Get the path for a token.

        Args:
            token: The token to get the path for.

        Returns:
            The path for the token.
        """
        # NOTE(review): the token is used verbatim as a file name. If tokens can
        # contain path separators this may resolve outside the states directory;
        # confirm how tokens are generated/validated upstream.
        return (self.states_directory / f"{token}.pkl").absolute()

    async def load_state(self, token: str, root_state: BaseState) -> BaseState:
        """Load a state object based on the provided token.

        The in-process mapping is checked first, then the pickle file for the
        token. A file is only used when the schema stored alongside the state
        matches the schema of the unpickled state; otherwise, or when loading
        fails, the matching substate of ``root_state`` is returned instead.

        Args:
            token: The token used to identify the state object.
            root_state: The root state object.

        Returns:
            The loaded state object.
        """
        if token in self.states:
            return self.states[token]

        client_token, substate_address = _split_substate_key(token)

        token_path = self.token_path(token)

        if token_path.exists():
            try:
                # NOTE(review): unpickling executes arbitrary code; the states
                # directory must only ever contain files written by this app.
                with token_path.open(mode="rb") as file:
                    (substate_schema, substate) = dill.load(file)
                if substate_schema == state_to_schema(substate):
                    await self.populate_substates(client_token, substate, root_state)
                    return substate
            except Exception as err:
                # Best effort: fall back to the fresh state below, but leave a
                # trace instead of silently discarding the persisted state.
                console.warn(f"Failed to load state from {token_path}: {err!r}")

        return root_state.get_substate(substate_address.split(".")[1:])

    async def populate_substates(
        self, client_token: str, state: BaseState, root_state: BaseState
    ):
        """Populate the substates of a state object.

        Each substate is loaded via ``load_state`` and linked to ``state`` in
        both directions.

        Args:
            client_token: The client token.
            state: The state object to populate.
            root_state: The root state object.
        """
        for substate_cls in state.get_substates():
            substate_token = _substate_key(client_token, substate_cls)

            substate = await self.load_state(substate_token, root_state)

            state.substates[substate.get_name()] = substate
            substate.parent_state = state

    @override
    async def get_state(
        self,
        token: str,
    ) -> BaseState:
        """Get the state for a token.

        Only the first component of the token's substate address is used, so
        the state loaded is always the one at the root of that address.

        Args:
            token: The token to get the state for.

        Returns:
            The state for the token.
        """
        client_token, substate_address = _split_substate_key(token)

        root_state_token = _substate_key(client_token, substate_address.split(".")[0])

        return await self.load_state(
            root_state_token, self.state(_reflex_internal_init=True)
        )

    async def set_state_for_substate(self, client_token: str, substate: BaseState):
        """Set the state for a substate.

        Stores the substate in the in-process mapping, writes it to its pickle
        file together with its schema, then recurses into its own substates.

        Args:
            client_token: The client token.
            substate: The substate to set.
        """
        substate_token = _substate_key(client_token, substate)

        self.states[substate_token] = substate

        state_dilled = dill.dumps((state_to_schema(substate), substate), byref=True)
        # The directory may have been removed since __init__; `exist_ok` makes
        # this a no-op when it is still there.
        self.states_directory.mkdir(parents=True, exist_ok=True)
        self.token_path(substate_token).write_bytes(state_dilled)

        for child_state in substate.substates.values():
            await self.set_state_for_substate(client_token, child_state)

    @override
    async def set_state(self, token: str, state: BaseState):
        """Set the state for a token.

        Only the client part of the token is used; the substate address is
        derived from ``state`` itself.

        Args:
            token: The token to set the state for.
            state: The state to set.
        """
        client_token, _ = _split_substate_key(token)
        await self.set_state_for_substate(client_token, state)

    @override
    @contextlib.asynccontextmanager
    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
        """Modify the state for a token while holding exclusive lock.

        The state is written back only when the ``async with`` body finishes
        without raising.

        Args:
            token: The token to modify the state for.

        Yields:
            The state for the token.
        """
        # Locks are per client: the substate suffix of the token is ignored here.
        client_token, _ = _split_substate_key(token)
        if client_token not in self._states_locks:
            async with self._state_manager_lock:
                # Re-check under the manager lock so only one lock is created.
                if client_token not in self._states_locks:
                    self._states_locks[client_token] = asyncio.Lock()

        async with self._states_locks[client_token]:
            state = await self.get_state(token)
            yield state
            await self.set_state(token, state)
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes # Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
if not isinstance(State.validate.__func__, FunctionType): if not isinstance(State.validate.__func__, FunctionType):
cython_function_or_method = type(State.validate.__func__) cython_function_or_method = type(State.validate.__func__)
@ -2474,15 +2714,6 @@ def _default_lock_expiration() -> int:
return get_config().redis_lock_expiration return get_config().redis_lock_expiration
def _default_token_expiration() -> int:
    """Look up the default expiration time for state tokens.

    Returns:
        The ``redis_token_expiration`` value of the current config.
    """
    config = get_config()
    return config.redis_token_expiration
class StateManagerRedis(StateManager): class StateManagerRedis(StateManager):
"""A state manager that stores states in redis.""" """A state manager that stores states in redis."""

View File

@ -45,6 +45,8 @@ import reflex.utils.prerequisites
import reflex.utils.processes import reflex.utils.processes
from reflex.state import ( from reflex.state import (
BaseState, BaseState,
StateManager,
StateManagerDisk,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
reload_state_module, reload_state_module,
@ -126,7 +128,7 @@ class AppHarness:
frontend_output_thread: Optional[threading.Thread] = None frontend_output_thread: Optional[threading.Thread] = None
backend_thread: Optional[threading.Thread] = None backend_thread: Optional[threading.Thread] = None
backend: Optional[uvicorn.Server] = None backend: Optional[uvicorn.Server] = None
state_manager: Optional[StateManagerMemory | StateManagerRedis] = None state_manager: Optional[StateManager] = None
_frontends: list["WebDriver"] = dataclasses.field(default_factory=list) _frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
_decorated_pages: list = dataclasses.field(default_factory=list) _decorated_pages: list = dataclasses.field(default_factory=list)
@ -290,6 +292,8 @@ class AppHarness:
if isinstance(self.app_instance._state_manager, StateManagerRedis): if isinstance(self.app_instance._state_manager, StateManagerRedis):
# Create our own redis connection for testing. # Create our own redis connection for testing.
self.state_manager = StateManagerRedis.create(self.app_instance.state) self.state_manager = StateManagerRedis.create(self.app_instance.state)
elif isinstance(self.app_instance._state_manager, StateManagerDisk):
self.state_manager = StateManagerDisk.create(self.app_instance.state)
else: else:
self.state_manager = self.app_instance._state_manager self.state_manager = self.app_instance._state_manager
@ -327,7 +331,8 @@ class AppHarness:
) )
) )
self.backend.shutdown = self._get_backend_shutdown_handler() self.backend.shutdown = self._get_backend_shutdown_handler()
self.backend_thread = threading.Thread(target=self.backend.run) with chdir(self.app_path):
self.backend_thread = threading.Thread(target=self.backend.run)
self.backend_thread.start() self.backend_thread.start()
async def _reset_backend_state_manager(self): async def _reset_backend_state_manager(self):
@ -787,7 +792,7 @@ class AppHarness:
raise RuntimeError("App is not running.") raise RuntimeError("App is not running.")
state_manager = self.app_instance.state_manager state_manager = self.app_instance.state_manager
assert isinstance( assert isinstance(
state_manager, StateManagerMemory state_manager, (StateManagerMemory, StateManagerDisk)
), "Only works with memory state manager" ), "Only works with memory state manager"
if not self._poll_for( if not self._poll_for(
target=lambda: state_manager.states, target=lambda: state_manager.states,

View File

@ -81,6 +81,18 @@ def mkdir(path: str | Path):
Path(path).mkdir(parents=True, exist_ok=True) Path(path).mkdir(parents=True, exist_ok=True)
def ls(path: str | Path) -> list[Path]:
    """Enumerate the entries found directly inside a directory.

    Args:
        path: The directory to enumerate.

    Returns:
        One path per entry in the directory.
    """
    directory = Path(path)
    return [entry for entry in directory.iterdir()]
def ln(src: str | Path, dest: str | Path, overwrite: bool = False) -> bool: def ln(src: str | Path, dest: str | Path, overwrite: bool = False) -> bool:
"""Create a symbolic link. """Create a symbolic link.

View File

@ -42,6 +42,7 @@ from reflex.state import (
OnLoadInternalState, OnLoadInternalState,
RouterData, RouterData,
State, State,
StateManagerDisk,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
StateUpdate, StateUpdate,
@ -1395,7 +1396,9 @@ def test_app_state_manager():
app.state_manager app.state_manager
app._enable_state() app._enable_state()
assert app.state_manager is not None assert app.state_manager is not None
assert isinstance(app.state_manager, (StateManagerMemory, StateManagerRedis)) assert isinstance(
app.state_manager, (StateManagerMemory, StateManagerDisk, StateManagerRedis)
)
def test_generate_component(): def test_generate_component():

View File

@ -31,6 +31,7 @@ from reflex.state import (
RouterData, RouterData,
State, State,
StateManager, StateManager,
StateManagerDisk,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
StateProxy, StateProxy,
@ -1586,7 +1587,7 @@ async def test_state_with_invalid_yield(capsys, mock_app):
assert "must only return/yield: None, Events or other EventHandlers" in captured.out assert "must only return/yield: None, Events or other EventHandlers" in captured.out
@pytest.fixture(scope="function", params=["in_process", "redis"]) @pytest.fixture(scope="function", params=["in_process", "disk", "redis"])
def state_manager(request) -> Generator[StateManager, None, None]: def state_manager(request) -> Generator[StateManager, None, None]:
"""Instance of state manager parametrized for redis and in-process. """Instance of state manager parametrized for redis and in-process.
@ -1600,8 +1601,11 @@ def state_manager(request) -> Generator[StateManager, None, None]:
if request.param == "redis": if request.param == "redis":
if not isinstance(state_manager, StateManagerRedis): if not isinstance(state_manager, StateManagerRedis):
pytest.skip("Test requires redis") pytest.skip("Test requires redis")
else: elif request.param == "disk":
# explicitly NOT using redis # explicitly NOT using redis
state_manager = StateManagerDisk(state=TestState)
assert not state_manager._states_locks
else:
state_manager = StateManagerMemory(state=TestState) state_manager = StateManagerMemory(state=TestState)
assert not state_manager._states_locks assert not state_manager._states_locks
@ -1639,7 +1643,7 @@ async def test_state_manager_modify_state(
async with state_manager.modify_state(substate_token) as state: async with state_manager.modify_state(substate_token) as state:
if isinstance(state_manager, StateManagerRedis): if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(f"{token}_lock") assert await state_manager.redis.get(f"{token}_lock")
elif isinstance(state_manager, StateManagerMemory): elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks assert token in state_manager._states_locks
assert state_manager._states_locks[token].locked() assert state_manager._states_locks[token].locked()
# Should be able to write proxy objects inside mutables # Should be able to write proxy objects inside mutables
@ -1649,11 +1653,11 @@ async def test_state_manager_modify_state(
# lock should be dropped after exiting the context # lock should be dropped after exiting the context
if isinstance(state_manager, StateManagerRedis): if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None assert (await state_manager.redis.get(f"{token}_lock")) is None
elif isinstance(state_manager, StateManagerMemory): elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert not state_manager._states_locks[token].locked() assert not state_manager._states_locks[token].locked()
# separate instances should NOT share locks # separate instances should NOT share locks
sm2 = StateManagerMemory(state=TestState) sm2 = state_manager.__class__(state=TestState)
assert sm2._state_manager_lock is state_manager._state_manager_lock assert sm2._state_manager_lock is state_manager._state_manager_lock
assert not sm2._states_locks assert not sm2._states_locks
if state_manager._states_locks: if state_manager._states_locks:
@ -1691,7 +1695,7 @@ async def test_state_manager_contend(
if isinstance(state_manager, StateManagerRedis): if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None assert (await state_manager.redis.get(f"{token}_lock")) is None
elif isinstance(state_manager, StateManagerMemory): elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
assert token in state_manager._states_locks assert token in state_manager._states_locks
assert not state_manager._states_locks[token].locked() assert not state_manager._states_locks[token].locked()
@ -1831,7 +1835,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert child_state is not None assert child_state is not None
parent_state = child_state.parent_state parent_state = child_state.parent_state
assert parent_state is not None assert parent_state is not None
if isinstance(mock_app.state_manager, StateManagerMemory): if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
mock_app.state_manager.states[parent_state.router.session.client_token] = ( mock_app.state_manager.states[parent_state.router.session.client_token] = (
parent_state parent_state
) )
@ -1874,7 +1878,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# For in-process store, only one instance of the state exists # For in-process store, only one instance of the state exists
assert sp.__wrapped__ is grandchild_state assert sp.__wrapped__ is grandchild_state
else: else:
# When redis is used, a new+updated instance is assigned to the proxy # When redis or disk is used, a new+updated instance is assigned to the proxy
assert sp.__wrapped__ is not grandchild_state assert sp.__wrapped__ is not grandchild_state
sp.value2 = "42" sp.value2 = "42"
assert not sp._self_mutable # proxy is not mutable after exiting context assert not sp._self_mutable # proxy is not mutable after exiting context
@ -2837,7 +2841,7 @@ async def test_get_state(mock_app: rx.App, token: str):
_substate_key(token, ChildState2) _substate_key(token, ChildState2)
) )
assert isinstance(test_state, TestState) assert isinstance(test_state, TestState)
if isinstance(mock_app.state_manager, StateManagerMemory): if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
# All substates are available # All substates are available
assert tuple(sorted(test_state.substates)) == ( assert tuple(sorted(test_state.substates)) == (
ChildState.get_name(), ChildState.get_name(),
@ -2916,6 +2920,15 @@ async def test_get_state(mock_app: rx.App, token: str):
ChildState2.get_name(), ChildState2.get_name(),
ChildState3.get_name(), ChildState3.get_name(),
) )
elif isinstance(mock_app.state_manager, StateManagerDisk):
# On disk, it's a new instance
assert new_test_state is not test_state
# All substates are available
assert tuple(sorted(new_test_state.substates)) == (
ChildState.get_name(),
ChildState2.get_name(),
ChildState3.get_name(),
)
else: else:
# With redis, we get a whole new instance # With redis, we get a whole new instance
assert new_test_state is not test_state assert new_test_state is not test_state