Simplify StateManagerDisk implementation

* Act more like the memory state manager and only track the root state in self.states
* .load_state always loads a single state or returns None
* .populate_states is the new entry point in loading from disk and it only occurs
  when the root state is not known
* much fast
This commit is contained in:
Masen Furer 2024-10-03 16:52:03 -07:00
parent d77b900bd7
commit eebcbc1054
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -2711,34 +2711,24 @@ class StateManagerDisk(StateManager):
self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl" self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl"
).absolute() ).absolute()
async def load_state(self, token: str, root_state: BaseState) -> BaseState: async def load_state(self, token: str) -> BaseState | None:
"""Load a state object based on the provided token. """Load a state object based on the provided token.
Args: Args:
token: The token used to identify the state object. token: The token used to identify the state object.
root_state: The root state object.
Returns: Returns:
The loaded state object. The loaded state object or None.
""" """
if token in self.states:
return self.states[token]
client_token, substate_address = _split_substate_key(token)
token_path = self.token_path(token) token_path = self.token_path(token)
if token_path.exists(): if token_path.exists():
try: try:
with token_path.open(mode="rb") as file: with token_path.open(mode="rb") as file:
substate = BaseState._deserialize(fp=file) return BaseState._deserialize(fp=file)
await self.populate_substates(client_token, substate, root_state)
return substate
except Exception: except Exception:
pass pass
return root_state.get_substate(substate_address.split(".")[1:])
async def populate_substates( async def populate_substates(
self, client_token: str, state: BaseState, root_state: BaseState self, client_token: str, state: BaseState, root_state: BaseState
): ):
@ -2752,10 +2742,14 @@ class StateManagerDisk(StateManager):
for substate in state.get_substates(): for substate in state.get_substates():
substate_token = _substate_key(client_token, substate) substate_token = _substate_key(client_token, substate)
substate = await self.load_state(substate_token, root_state) instance = await self.load_state(substate_token)
if instance is not None:
state.substates[substate.get_name()] = substate state.substates[substate.get_name()] = instance
substate.parent_state = state substate.parent_state = state
else:
instance = await root_state.get_state(substate)
await self.populate_substates(client_token, instance, root_state)
@override @override
async def get_state( async def get_state(
@ -2770,15 +2764,14 @@ class StateManagerDisk(StateManager):
Returns: Returns:
The state for the token. The state for the token.
""" """
client_token, substate_address = _split_substate_key(token) client_token = _split_substate_key(token)[0]
root_state = self.states.get(client_token)
root_state_token = _substate_key(client_token, substate_address.split(".")[0])
root_state = self.states.get(root_state_token)
if root_state is None: if root_state is None:
# Create a new root state which will be persisted in the next set_state call. # Create a new root state which will be persisted in the next set_state call.
root_state = self.state(_reflex_internal_init=True) root_state = self.state(_reflex_internal_init=True)
self.states[client_token] = root_state
return await self.load_state(root_state_token, root_state) await self.populate_substates(client_token, root_state, root_state)
return root_state
async def set_state_for_substate(self, client_token: str, substate: BaseState): async def set_state_for_substate(self, client_token: str, substate: BaseState):
"""Set the state for a substate. """Set the state for a substate.
@ -2789,12 +2782,12 @@ class StateManagerDisk(StateManager):
""" """
substate_token = _substate_key(client_token, substate) substate_token = _substate_key(client_token, substate)
self.states[substate_token] = substate if substate._get_was_touched():
substate._was_touched = False # Reset the touched flag after serializing.
state_dilled = substate._serialize() pickle_state = substate._serialize()
if not self.states_directory.exists(): if not self.states_directory.exists():
self.states_directory.mkdir(parents=True, exist_ok=True) self.states_directory.mkdir(parents=True, exist_ok=True)
self.token_path(substate_token).write_bytes(state_dilled) self.token_path(substate_token).write_bytes(pickle_state)
for substate_substate in substate.substates.values(): for substate_substate in substate.substates.values():
await self.set_state_for_substate(client_token, substate_substate) await self.set_state_for_substate(client_token, substate_substate)