Optimize StateManagerDisk (#4056)
* Simplify StateManagerDisk implementation * Act more like the memory state manager and only track the root state in self.states * .load_state always loads a single state or returns None * .populate_states is the new entry point in loading from disk and it only occurs when the root state is not known * much fast * StateManagerDisk now acts much more like StateManagerMemory Treat StateManagerDisk like StateManagerMemory for AppHarness * Handle root_state deserialized from disk In this case, we need to initialize the whole state tree, so any non-persistent states will still get default values, whereas on-disk states will overwrite the defaults. * Cache root_state under client_token for StateManagerMemory compatibility Mainly this just makes it easier for us to write tests that work against either Disk or Memory state managers.
This commit is contained in:
parent
1f3be6340c
commit
aa69234b76
@ -2711,34 +2711,24 @@ class StateManagerDisk(StateManager):
|
|||||||
self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
|
self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
|
||||||
).absolute()
|
).absolute()
|
||||||
|
|
||||||
async def load_state(self, token: str, root_state: BaseState) -> BaseState:
|
async def load_state(self, token: str) -> BaseState | None:
|
||||||
"""Load a state object based on the provided token.
|
"""Load a state object based on the provided token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The token used to identify the state object.
|
token: The token used to identify the state object.
|
||||||
root_state: The root state object.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The loaded state object.
|
The loaded state object or None.
|
||||||
"""
|
"""
|
||||||
if token in self.states:
|
|
||||||
return self.states[token]
|
|
||||||
|
|
||||||
client_token, substate_address = _split_substate_key(token)
|
|
||||||
|
|
||||||
token_path = self.token_path(token)
|
token_path = self.token_path(token)
|
||||||
|
|
||||||
if token_path.exists():
|
if token_path.exists():
|
||||||
try:
|
try:
|
||||||
with token_path.open(mode="rb") as file:
|
with token_path.open(mode="rb") as file:
|
||||||
substate = BaseState._deserialize(fp=file)
|
return BaseState._deserialize(fp=file)
|
||||||
await self.populate_substates(client_token, substate, root_state)
|
|
||||||
return substate
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return root_state.get_substate(substate_address.split(".")[1:])
|
|
||||||
|
|
||||||
async def populate_substates(
|
async def populate_substates(
|
||||||
self, client_token: str, state: BaseState, root_state: BaseState
|
self, client_token: str, state: BaseState, root_state: BaseState
|
||||||
):
|
):
|
||||||
@ -2752,10 +2742,13 @@ class StateManagerDisk(StateManager):
|
|||||||
for substate in state.get_substates():
|
for substate in state.get_substates():
|
||||||
substate_token = _substate_key(client_token, substate)
|
substate_token = _substate_key(client_token, substate)
|
||||||
|
|
||||||
substate = await self.load_state(substate_token, root_state)
|
instance = await self.load_state(substate_token)
|
||||||
|
if instance is None:
|
||||||
|
instance = await root_state.get_state(substate)
|
||||||
|
state.substates[substate.get_name()] = instance
|
||||||
|
instance.parent_state = state
|
||||||
|
|
||||||
state.substates[substate.get_name()] = substate
|
await self.populate_substates(client_token, instance, root_state)
|
||||||
substate.parent_state = state
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def get_state(
|
async def get_state(
|
||||||
@ -2770,15 +2763,24 @@ class StateManagerDisk(StateManager):
|
|||||||
Returns:
|
Returns:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
"""
|
"""
|
||||||
client_token, substate_address = _split_substate_key(token)
|
client_token = _split_substate_key(token)[0]
|
||||||
|
root_state = self.states.get(client_token)
|
||||||
|
if root_state is not None:
|
||||||
|
# Retrieved state from memory.
|
||||||
|
return root_state
|
||||||
|
|
||||||
root_state_token = _substate_key(client_token, substate_address.split(".")[0])
|
# Deserialize root state from disk.
|
||||||
root_state = self.states.get(root_state_token)
|
root_state = await self.load_state(_substate_key(client_token, self.state))
|
||||||
|
# Create a new root state tree with all substates instantiated.
|
||||||
|
fresh_root_state = self.state(_reflex_internal_init=True)
|
||||||
if root_state is None:
|
if root_state is None:
|
||||||
# Create a new root state which will be persisted in the next set_state call.
|
root_state = fresh_root_state
|
||||||
root_state = self.state(_reflex_internal_init=True)
|
else:
|
||||||
|
# Ensure all substates exist, even if they were not serialized previously.
|
||||||
return await self.load_state(root_state_token, root_state)
|
root_state.substates = fresh_root_state.substates
|
||||||
|
self.states[client_token] = root_state
|
||||||
|
await self.populate_substates(client_token, root_state, root_state)
|
||||||
|
return root_state
|
||||||
|
|
||||||
async def set_state_for_substate(self, client_token: str, substate: BaseState):
|
async def set_state_for_substate(self, client_token: str, substate: BaseState):
|
||||||
"""Set the state for a substate.
|
"""Set the state for a substate.
|
||||||
@ -2789,12 +2791,12 @@ class StateManagerDisk(StateManager):
|
|||||||
"""
|
"""
|
||||||
substate_token = _substate_key(client_token, substate)
|
substate_token = _substate_key(client_token, substate)
|
||||||
|
|
||||||
self.states[substate_token] = substate
|
if substate._get_was_touched():
|
||||||
|
substate._was_touched = False # Reset the touched flag after serializing.
|
||||||
state_dilled = substate._serialize()
|
pickle_state = substate._serialize()
|
||||||
if not self.states_directory.exists():
|
if not self.states_directory.exists():
|
||||||
self.states_directory.mkdir(parents=True, exist_ok=True)
|
self.states_directory.mkdir(parents=True, exist_ok=True)
|
||||||
self.token_path(substate_token).write_bytes(state_dilled)
|
self.token_path(substate_token).write_bytes(pickle_state)
|
||||||
|
|
||||||
for substate_substate in substate.substates.values():
|
for substate_substate in substate.substates.values():
|
||||||
await self.set_state_for_substate(client_token, substate_substate)
|
await self.set_state_for_substate(client_token, substate_substate)
|
||||||
|
@ -292,8 +292,6 @@ class AppHarness:
|
|||||||
if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
if isinstance(self.app_instance._state_manager, StateManagerRedis):
|
||||||
# Create our own redis connection for testing.
|
# Create our own redis connection for testing.
|
||||||
self.state_manager = StateManagerRedis.create(self.app_instance.state)
|
self.state_manager = StateManagerRedis.create(self.app_instance.state)
|
||||||
elif isinstance(self.app_instance._state_manager, StateManagerDisk):
|
|
||||||
self.state_manager = StateManagerDisk.create(self.app_instance.state)
|
|
||||||
else:
|
else:
|
||||||
self.state_manager = self.app_instance._state_manager
|
self.state_manager = self.app_instance._state_manager
|
||||||
|
|
||||||
|
@ -1884,11 +1884,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
|||||||
async with sp:
|
async with sp:
|
||||||
assert sp._self_actx is not None
|
assert sp._self_actx is not None
|
||||||
assert sp._self_mutable # proxy is mutable inside context
|
assert sp._self_mutable # proxy is mutable inside context
|
||||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||||
# For in-process store, only one instance of the state exists
|
# For in-process store, only one instance of the state exists
|
||||||
assert sp.__wrapped__ is grandchild_state
|
assert sp.__wrapped__ is grandchild_state
|
||||||
else:
|
else:
|
||||||
# When redis or disk is used, a new+updated instance is assigned to the proxy
|
# When redis is used, a new+updated instance is assigned to the proxy
|
||||||
assert sp.__wrapped__ is not grandchild_state
|
assert sp.__wrapped__ is not grandchild_state
|
||||||
sp.value2 = "42"
|
sp.value2 = "42"
|
||||||
assert not sp._self_mutable # proxy is not mutable after exiting context
|
assert not sp._self_mutable # proxy is not mutable after exiting context
|
||||||
@ -1899,7 +1899,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
|||||||
gotten_state = await mock_app.state_manager.get_state(
|
gotten_state = await mock_app.state_manager.get_state(
|
||||||
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
|
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
|
||||||
)
|
)
|
||||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||||
# For in-process store, only one instance of the state exists
|
# For in-process store, only one instance of the state exists
|
||||||
assert gotten_state is parent_state
|
assert gotten_state is parent_state
|
||||||
else:
|
else:
|
||||||
@ -2922,7 +2922,7 @@ async def test_get_state(mock_app: rx.App, token: str):
|
|||||||
_substate_key(token, ChildState2)
|
_substate_key(token, ChildState2)
|
||||||
)
|
)
|
||||||
assert isinstance(new_test_state, TestState)
|
assert isinstance(new_test_state, TestState)
|
||||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||||
# In memory, it's the same instance
|
# In memory, it's the same instance
|
||||||
assert new_test_state is test_state
|
assert new_test_state is test_state
|
||||||
test_state._clean()
|
test_state._clean()
|
||||||
@ -2932,15 +2932,6 @@ async def test_get_state(mock_app: rx.App, token: str):
|
|||||||
ChildState2.get_name(),
|
ChildState2.get_name(),
|
||||||
ChildState3.get_name(),
|
ChildState3.get_name(),
|
||||||
)
|
)
|
||||||
elif isinstance(mock_app.state_manager, StateManagerDisk):
|
|
||||||
# On disk, it's a new instance
|
|
||||||
assert new_test_state is not test_state
|
|
||||||
# All substates are available
|
|
||||||
assert tuple(sorted(new_test_state.substates)) == (
|
|
||||||
ChildState.get_name(),
|
|
||||||
ChildState2.get_name(),
|
|
||||||
ChildState3.get_name(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# With redis, we get a whole new instance
|
# With redis, we get a whole new instance
|
||||||
assert new_test_state is not test_state
|
assert new_test_state is not test_state
|
||||||
|
Loading…
Reference in New Issue
Block a user