implement disk state manager (#3826)
* implement disk state manager
* put states inside of a folder
* return root state all the time
* factor out code
* add docs for token expiration
* cache states directory
* call absolute on web directory
* change dir to app path when testing the backend
* remove accidental 🥒
* test disk for now
* modify schema
* only serialize specific stuff
* fix issue in types
* what is a kilometer
* create folder if it doesn't exist in write
* this code hates me
* check if the file isn't empty
* add try except clause
* add check for directory again
This commit is contained in:
parent
c457b43ab1
commit
629850162a
@ -45,6 +45,8 @@ class Dirs(SimpleNamespace):
|
||||
REFLEX_JSON = "reflex.json"
|
||||
# The name of the postcss config file.
|
||||
POSTCSS_JS = "postcss.config.js"
|
||||
# The name of the states directory.
|
||||
STATES = "states"
|
||||
|
||||
|
||||
class Reflex(SimpleNamespace):
|
||||
|
253
reflex/state.py
253
reflex/state.py
@ -11,6 +11,7 @@ import os
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from types import FunctionType, MethodType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -23,6 +24,7 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
@ -52,7 +54,7 @@ from reflex.event import (
|
||||
EventSpec,
|
||||
fix_events,
|
||||
)
|
||||
from reflex.utils import console, format, prerequisites, types
|
||||
from reflex.utils import console, format, path_ops, prerequisites, types
|
||||
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
||||
from reflex.utils.exec import is_testing_env
|
||||
from reflex.utils.serializers import SerializedType, serialize, serializer
|
||||
@ -2339,7 +2341,7 @@ class StateManager(Base, ABC):
|
||||
token_expiration=config.redis_token_expiration,
|
||||
lock_expiration=config.redis_lock_expiration,
|
||||
)
|
||||
return StateManagerMemory(state=state)
|
||||
return StateManagerDisk(state=state)
|
||||
|
||||
@abstractmethod
|
||||
async def get_state(self, token: str) -> BaseState:
|
||||
@ -2446,6 +2448,244 @@ class StateManagerMemory(StateManager):
|
||||
await self.set_state(token, state)
|
||||
|
||||
|
||||
def _default_token_expiration() -> int:
    """Look up the token expiration configured for the app.

    Returns:
        The ``redis_token_expiration`` value of the current config.
    """
    config = get_config()
    return config.redis_token_expiration
|
||||
|
||||
|
||||
def state_to_schema(
    state: BaseState,
) -> List[
    Tuple[
        str,
        str,
        Any,
        Union[bool, None],
    ]
]:
    """Convert a state to a schema.

    The schema has one entry per item in ``state.__fields__``. It is stored
    next to a pickled state so that a later load can detect that the fields
    of the state have changed.

    Args:
        state: The state to convert to a schema.

    Returns:
        A sorted list of ``(field_name, model_field.name, model_field.type_,
        required)`` tuples, where ``required`` is the field's ``required``
        attribute when that is a bool and ``None`` otherwise.
    """
    # sorted() already returns a new list, so no extra list() wrapper is needed.
    return sorted(
        (
            field_name,
            model_field.name,
            model_field.type_,
            (
                model_field.required
                if isinstance(model_field.required, bool)
                else None
            ),
        )
        for field_name, model_field in state.__fields__.items()
    )
|
||||
|
||||
|
||||
class StateManagerDisk(StateManager):
    """A state manager that stores states on disk.

    Every (sub)state is serialized with ``dill`` into its own
    ``<token>.pkl`` file inside the states directory, together with the
    schema of the state at the time it was written. States written through
    this manager are additionally kept in the ``states`` mapping.
    """

    # The states set through this manager, keyed by substate token.
    states: Dict[str, BaseState] = {}

    # The mutex ensures the dict of mutexes is updated exclusively
    _state_manager_lock = asyncio.Lock()

    # The dict of mutexes for each client
    _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})

    # The token expiration time (s).
    token_expiration: int = pydantic.Field(default_factory=_default_token_expiration)

    class Config:
        """The Pydantic config."""

        fields = {
            "_states_locks": {"exclude": True},
        }
        keep_untouched = (functools.cached_property,)

    def __init__(self, state: Type[BaseState]):
        """Create a new state manager.

        Ensures the states directory exists and removes state files that are
        older than the token expiration.

        Args:
            state: The state class to use.
        """
        super().__init__(state=state)

        path_ops.mkdir(self.states_directory)

        self._purge_expired_states()

    @functools.cached_property
    def states_directory(self) -> Path:
        """Get the states directory.

        Returns:
            The states directory.
        """
        return prerequisites.get_web_dir() / constants.Dirs.STATES

    def _purge_expired_states(self):
        """Purge expired states from the disk.

        A ``.pkl`` file is expired when its modification time is more than
        ``token_expiration`` seconds in the past.
        """
        import time

        # One timestamp for the whole sweep keeps the cutoff consistent.
        now = time.time()

        for path in path_ops.ls(self.states_directory):
            # Only pickle files belong to this manager.
            if path.suffix != ".pkl":
                continue

            # The modification time doubles as the "last edited" marker.
            last_edited = path.stat().st_mtime

            if now - last_edited > self.token_expiration:
                path.unlink()

    def token_path(self, token: str) -> Path:
        """Get the path for a token.

        Args:
            token: The token to get the path for.

        Returns:
            The absolute path of the pickle file for the token.
        """
        # NOTE(review): the token is interpolated into the file name as-is;
        # if tokens can be influenced by clients, confirm that path separators
        # are rejected before they reach this point.
        return (self.states_directory / f"{token}.pkl").absolute()

    async def load_state(self, token: str, root_state: BaseState) -> BaseState:
        """Load a state object based on the provided token.

        Lookup order: the ``states`` mapping, then the pickle file for the
        token (only if its stored schema still matches), then the matching
        substate of ``root_state`` as a fresh fallback.

        Args:
            token: The token used to identify the state object.
            root_state: The root state object.

        Returns:
            The loaded state object.
        """
        if token in self.states:
            return self.states[token]

        client_token, substate_address = _split_substate_key(token)

        token_path = self.token_path(token)

        if token_path.exists():
            try:
                # NOTE(review): dill.load executes arbitrary code from the
                # file; the states directory must only contain trusted data.
                with token_path.open(mode="rb") as file:
                    (substate_schema, substate) = dill.load(file)
                # Discard pickles whose fields no longer match the state.
                if substate_schema == state_to_schema(substate):
                    await self.populate_substates(client_token, substate, root_state)
                    return substate
            except Exception as err:
                # Deliberately best effort: an empty, truncated or otherwise
                # unreadable file must not break the app, so fall back to the
                # fresh state below. Leave a trace instead of hiding it.
                console.debug(f"Unable to load state from {token_path}: {err!r}")

        return root_state.get_substate(substate_address.split(".")[1:])

    async def populate_substates(
        self, client_token: str, state: BaseState, root_state: BaseState
    ):
        """Populate the substates of a state object.

        Each substate is loaded through ``load_state`` and linked to ``state``
        in both directions.

        Args:
            client_token: The client token.
            state: The state object to populate.
            root_state: The root state object.
        """
        for child in state.get_substates():
            substate_token = _substate_key(client_token, child)

            loaded = await self.load_state(substate_token, root_state)

            state.substates[loaded.get_name()] = loaded
            loaded.parent_state = state

    @override
    async def get_state(
        self,
        token: str,
    ) -> BaseState:
        """Get the state for a token.

        The substate part of the token is reduced to its first component, so
        the root state is what gets loaded and returned.

        Args:
            token: The token to get the state for.

        Returns:
            The state for the token.
        """
        client_token, substate_address = _split_substate_key(token)

        root_state_token = _substate_key(client_token, substate_address.split(".")[0])

        return await self.load_state(
            root_state_token, self.state(_reflex_internal_init=True)
        )

    async def set_state_for_substate(self, client_token: str, substate: BaseState):
        """Set the state for a substate.

        Stores the substate in ``states``, writes it to its pickle file and
        then does the same recursively for all of its substates.

        Args:
            client_token: The client token.
            substate: The substate to set.
        """
        substate_token = _substate_key(client_token, substate)

        self.states[substate_token] = substate

        state_dilled = dill.dumps((state_to_schema(substate), substate), byref=True)
        # The directory may have been removed since __init__; recreate it the
        # same way __init__ does (a no-op when it already exists).
        path_ops.mkdir(self.states_directory)
        self.token_path(substate_token).write_bytes(state_dilled)

        for substate_substate in substate.substates.values():
            await self.set_state_for_substate(client_token, substate_substate)

    @override
    async def set_state(self, token: str, state: BaseState):
        """Set the state for a token.

        Args:
            token: The token to set the state for.
            state: The state to set.
        """
        # Only the client part of the token is needed; the substate key is
        # derived again from the state itself.
        client_token, _ = _split_substate_key(token)
        await self.set_state_for_substate(client_token, state)

    @override
    @contextlib.asynccontextmanager
    async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
        """Modify the state for a token while holding exclusive lock.

        Args:
            token: The token to modify the state for.

        Yields:
            The state for the token.
        """
        # Locks are kept per client token; the substate suffix is not part of
        # the lock key.
        client_token, _ = _split_substate_key(token)
        if client_token not in self._states_locks:
            async with self._state_manager_lock:
                # Re-check under the manager lock so only one lock is created.
                if client_token not in self._states_locks:
                    self._states_locks[client_token] = asyncio.Lock()

        async with self._states_locks[client_token]:
            state = await self.get_state(token)
            yield state
            # Persist whatever the caller changed before releasing the lock.
            await self.set_state(token, state)
|
||||
|
||||
|
||||
# Workaround https://github.com/cloudpipe/cloudpickle/issues/408 for dynamic pydantic classes
|
||||
if not isinstance(State.validate.__func__, FunctionType):
|
||||
cython_function_or_method = type(State.validate.__func__)
|
||||
@ -2474,15 +2714,6 @@ def _default_lock_expiration() -> int:
|
||||
return get_config().redis_lock_expiration
|
||||
|
||||
|
||||
def _default_token_expiration() -> int:
    """Look up the token expiration configured for the app.

    Returns:
        The ``redis_token_expiration`` value of the current config.
    """
    config = get_config()
    return config.redis_token_expiration
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
"""A state manager that stores states in redis."""
|
||||
|
||||
|
@ -45,6 +45,8 @@ import reflex.utils.prerequisites
|
||||
import reflex.utils.processes
|
||||
from reflex.state import (
|
||||
BaseState,
|
||||
StateManager,
|
||||
StateManagerDisk,
|
||||
StateManagerMemory,
|
||||
StateManagerRedis,
|
||||
reload_state_module,
|
||||
@ -126,7 +128,7 @@ class AppHarness:
|
||||
frontend_output_thread: Optional[threading.Thread] = None
|
||||
backend_thread: Optional[threading.Thread] = None
|
||||
backend: Optional[uvicorn.Server] = None
|
||||
state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
|
||||
state_manager: Optional[StateManager] = None
|
||||
_frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
|
||||
_decorated_pages: list = dataclasses.field(default_factory=list)
|
||||
|
||||
@ -290,6 +292,8 @@ class AppHarness:
|
||||
if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
||||
# Create our own redis connection for testing.
|
||||
self.state_manager = StateManagerRedis.create(self.app_instance.state)
|
||||
elif isinstance(self.app_instance._state_manager, StateManagerDisk):
|
||||
self.state_manager = StateManagerDisk.create(self.app_instance.state)
|
||||
else:
|
||||
self.state_manager = self.app_instance._state_manager
|
||||
|
||||
@ -327,7 +331,8 @@ class AppHarness:
|
||||
)
|
||||
)
|
||||
self.backend.shutdown = self._get_backend_shutdown_handler()
|
||||
self.backend_thread = threading.Thread(target=self.backend.run)
|
||||
with chdir(self.app_path):
|
||||
self.backend_thread = threading.Thread(target=self.backend.run)
|
||||
self.backend_thread.start()
|
||||
|
||||
async def _reset_backend_state_manager(self):
|
||||
@ -787,7 +792,7 @@ class AppHarness:
|
||||
raise RuntimeError("App is not running.")
|
||||
state_manager = self.app_instance.state_manager
|
||||
assert isinstance(
|
||||
state_manager, StateManagerMemory
|
||||
state_manager, (StateManagerMemory, StateManagerDisk)
|
||||
), "Only works with memory state manager"
|
||||
if not self._poll_for(
|
||||
target=lambda: state_manager.states,
|
||||
|
@ -81,6 +81,18 @@ def mkdir(path: str | Path):
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def ls(path: str | Path) -> list[Path]:
    """Enumerate the entries of a directory.

    Args:
        path: The directory to enumerate.

    Returns:
        One path per entry found directly inside the directory.
    """
    directory = Path(path)
    return [*directory.iterdir()]
|
||||
|
||||
|
||||
def ln(src: str | Path, dest: str | Path, overwrite: bool = False) -> bool:
|
||||
"""Create a symbolic link.
|
||||
|
||||
|
@ -42,6 +42,7 @@ from reflex.state import (
|
||||
OnLoadInternalState,
|
||||
RouterData,
|
||||
State,
|
||||
StateManagerDisk,
|
||||
StateManagerMemory,
|
||||
StateManagerRedis,
|
||||
StateUpdate,
|
||||
@ -1395,7 +1396,9 @@ def test_app_state_manager():
|
||||
app.state_manager
|
||||
app._enable_state()
|
||||
assert app.state_manager is not None
|
||||
assert isinstance(app.state_manager, (StateManagerMemory, StateManagerRedis))
|
||||
assert isinstance(
|
||||
app.state_manager, (StateManagerMemory, StateManagerDisk, StateManagerRedis)
|
||||
)
|
||||
|
||||
|
||||
def test_generate_component():
|
||||
|
@ -31,6 +31,7 @@ from reflex.state import (
|
||||
RouterData,
|
||||
State,
|
||||
StateManager,
|
||||
StateManagerDisk,
|
||||
StateManagerMemory,
|
||||
StateManagerRedis,
|
||||
StateProxy,
|
||||
@ -1586,7 +1587,7 @@ async def test_state_with_invalid_yield(capsys, mock_app):
|
||||
assert "must only return/yield: None, Events or other EventHandlers" in captured.out
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", params=["in_process", "redis"])
|
||||
@pytest.fixture(scope="function", params=["in_process", "disk", "redis"])
|
||||
def state_manager(request) -> Generator[StateManager, None, None]:
|
||||
"""Instance of state manager parametrized for redis, disk, and in-process.
|
||||
|
||||
@ -1600,8 +1601,11 @@ def state_manager(request) -> Generator[StateManager, None, None]:
|
||||
if request.param == "redis":
|
||||
if not isinstance(state_manager, StateManagerRedis):
|
||||
pytest.skip("Test requires redis")
|
||||
else:
|
||||
elif request.param == "disk":
|
||||
# explicitly NOT using redis
|
||||
state_manager = StateManagerDisk(state=TestState)
|
||||
assert not state_manager._states_locks
|
||||
else:
|
||||
state_manager = StateManagerMemory(state=TestState)
|
||||
assert not state_manager._states_locks
|
||||
|
||||
@ -1639,7 +1643,7 @@ async def test_state_manager_modify_state(
|
||||
async with state_manager.modify_state(substate_token) as state:
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert await state_manager.redis.get(f"{token}_lock")
|
||||
elif isinstance(state_manager, StateManagerMemory):
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert token in state_manager._states_locks
|
||||
assert state_manager._states_locks[token].locked()
|
||||
# Should be able to write proxy objects inside mutables
|
||||
@ -1649,11 +1653,11 @@ async def test_state_manager_modify_state(
|
||||
# lock should be dropped after exiting the context
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
elif isinstance(state_manager, StateManagerMemory):
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
|
||||
# separate instances should NOT share locks
|
||||
sm2 = StateManagerMemory(state=TestState)
|
||||
sm2 = state_manager.__class__(state=TestState)
|
||||
assert sm2._state_manager_lock is state_manager._state_manager_lock
|
||||
assert not sm2._states_locks
|
||||
if state_manager._states_locks:
|
||||
@ -1691,7 +1695,7 @@ async def test_state_manager_contend(
|
||||
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
elif isinstance(state_manager, StateManagerMemory):
|
||||
elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
assert token in state_manager._states_locks
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
|
||||
@ -1831,7 +1835,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
assert child_state is not None
|
||||
parent_state = child_state.parent_state
|
||||
assert parent_state is not None
|
||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
mock_app.state_manager.states[parent_state.router.session.client_token] = (
|
||||
parent_state
|
||||
)
|
||||
@ -1874,7 +1878,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
# For in-process store, only one instance of the state exists
|
||||
assert sp.__wrapped__ is grandchild_state
|
||||
else:
|
||||
# When redis is used, a new+updated instance is assigned to the proxy
|
||||
# When redis or disk is used, a new+updated instance is assigned to the proxy
|
||||
assert sp.__wrapped__ is not grandchild_state
|
||||
sp.value2 = "42"
|
||||
assert not sp._self_mutable # proxy is not mutable after exiting context
|
||||
@ -2837,7 +2841,7 @@ async def test_get_state(mock_app: rx.App, token: str):
|
||||
_substate_key(token, ChildState2)
|
||||
)
|
||||
assert isinstance(test_state, TestState)
|
||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
# All substates are available
|
||||
assert tuple(sorted(test_state.substates)) == (
|
||||
ChildState.get_name(),
|
||||
@ -2916,6 +2920,15 @@ async def test_get_state(mock_app: rx.App, token: str):
|
||||
ChildState2.get_name(),
|
||||
ChildState3.get_name(),
|
||||
)
|
||||
elif isinstance(mock_app.state_manager, StateManagerDisk):
|
||||
# On disk, it's a new instance
|
||||
assert new_test_state is not test_state
|
||||
# All substates are available
|
||||
assert tuple(sorted(new_test_state.substates)) == (
|
||||
ChildState.get_name(),
|
||||
ChildState2.get_name(),
|
||||
ChildState3.get_name(),
|
||||
)
|
||||
else:
|
||||
# With redis, we get a whole new instance
|
||||
assert new_test_state is not test_state
|
||||
|
Loading…
Reference in New Issue
Block a user