wip
This commit is contained in:
parent
39cdce6960
commit
82c82d9bd9
288
reflex/state.py
288
reflex/state.py
@ -938,7 +938,20 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
for substate in cls.get_substates():
|
for substate in cls.get_substates():
|
||||||
if path[0] == substate.get_name():
|
if path[0] == substate.get_name():
|
||||||
return substate.get_class_substate(path[1:])
|
return substate.get_class_substate(path[1:])
|
||||||
raise ValueError(f"Invalid path: {path}")
|
raise ValueError(f"Invalid path: {cls.get_full_name()=} {path=}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
# @functools.lru_cache()
|
||||||
|
def get_all_substate_classes(cls) -> set[Type[BaseState]]:
|
||||||
|
"""Get all substate classes of the state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The set of all substate classes.
|
||||||
|
"""
|
||||||
|
substates = set(cls.get_substates())
|
||||||
|
for substate in cls.get_substates():
|
||||||
|
substates.update(substate.get_all_substate_classes())
|
||||||
|
return substates
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_class_var(cls, path: Sequence[str]) -> Any:
|
def get_class_var(cls, path: Sequence[str]) -> Any:
|
||||||
@ -1393,7 +1406,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
return self
|
return self
|
||||||
path = path[1:]
|
path = path[1:]
|
||||||
if path[0] not in self.substates:
|
if path[0] not in self.substates:
|
||||||
raise ValueError(f"Invalid path: {path}")
|
raise ValueError(
|
||||||
|
f"Invalid path: {path=} {self.get_full_name()=} {self.substates.keys()=}"
|
||||||
|
)
|
||||||
return self.substates[path[0]].get_substate(path[1:])
|
return self.substates[path[0]].get_substate(path[1:])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1455,6 +1470,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
||||||
return parent_states_with_name
|
return parent_states_with_name
|
||||||
|
|
||||||
|
def _get_all_loaded_states(self) -> dict[str, BaseState]:
|
||||||
|
"""Get all loaded states in the state tree.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all loaded states in the state tree.
|
||||||
|
"""
|
||||||
|
root_state = self._get_root_state()
|
||||||
|
d = {root_state.get_full_name(): root_state}
|
||||||
|
d.update(root_state._get_loaded_substates())
|
||||||
|
return d
|
||||||
|
|
||||||
|
def _get_loaded_substates(self) -> dict[str, BaseState]:
|
||||||
|
"""Get all loaded substates of this state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of all loaded substates of this state.
|
||||||
|
"""
|
||||||
|
loaded_substates = {}
|
||||||
|
for substate in self.substates.values():
|
||||||
|
loaded_substates[substate.get_full_name()] = substate
|
||||||
|
loaded_substates.update(substate._get_loaded_substates())
|
||||||
|
return loaded_substates
|
||||||
|
|
||||||
def _get_root_state(self) -> BaseState:
|
def _get_root_state(self) -> BaseState:
|
||||||
"""Get the root state of the state tree.
|
"""Get the root state of the state tree.
|
||||||
|
|
||||||
@ -1861,6 +1899,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if include_backend or not self.computed_vars[cvar]._backend
|
if include_backend or not self.computed_vars[cvar]._backend
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: just return full name? cache?
|
||||||
@classmethod
|
@classmethod
|
||||||
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
||||||
"""Determine substates which could be affected by dirty vars in this state.
|
"""Determine substates which could be affected by dirty vars in this state.
|
||||||
@ -1882,6 +1921,22 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
)
|
)
|
||||||
return fetch_substates
|
return fetch_substates
|
||||||
|
|
||||||
|
# TODO: just return full name? cache?
|
||||||
|
# this only needs to be computed once, and only for the root state?
|
||||||
|
@classmethod
|
||||||
|
def _recursive_potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
||||||
|
"""Recursively determine substates which could be affected by dirty vars in this state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of State classes that may need to be fetched to recalc computed vars.
|
||||||
|
"""
|
||||||
|
fetch_substates = cls._potentially_dirty_substates()
|
||||||
|
for substate_cls in cls.get_substates():
|
||||||
|
fetch_substates.update(
|
||||||
|
substate_cls._recursive_potentially_dirty_substates()
|
||||||
|
)
|
||||||
|
return fetch_substates
|
||||||
|
|
||||||
def get_delta(self) -> Delta:
|
def get_delta(self) -> Delta:
|
||||||
"""Get the delta for the state.
|
"""Get the delta for the state.
|
||||||
|
|
||||||
@ -3190,77 +3245,6 @@ class StateManagerRedis(StateManager):
|
|||||||
b"evicted",
|
b"evicted",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _get_parent_state(
|
|
||||||
self, token: str, state: BaseState | None = None
|
|
||||||
) -> BaseState | None:
|
|
||||||
"""Get the parent state for the state requested in the token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: The token to get the state for (_substate_key).
|
|
||||||
state: The state instance to get parent state for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The parent state for the state requested by the token or None if there is no such parent.
|
|
||||||
"""
|
|
||||||
parent_state = None
|
|
||||||
client_token, state_path = _split_substate_key(token)
|
|
||||||
parent_state_name = state_path.rpartition(".")[0]
|
|
||||||
if parent_state_name:
|
|
||||||
cached_substates = None
|
|
||||||
if state is not None:
|
|
||||||
cached_substates = [state]
|
|
||||||
# Retrieve the parent state to populate event handlers onto this substate.
|
|
||||||
parent_state = await self.get_state(
|
|
||||||
token=_substate_key(client_token, parent_state_name),
|
|
||||||
top_level=False,
|
|
||||||
get_substates=False,
|
|
||||||
cached_substates=cached_substates,
|
|
||||||
)
|
|
||||||
return parent_state
|
|
||||||
|
|
||||||
async def _populate_substates(
|
|
||||||
self,
|
|
||||||
token: str,
|
|
||||||
state: BaseState,
|
|
||||||
all_substates: bool = False,
|
|
||||||
):
|
|
||||||
"""Fetch and link substates for the given state instance.
|
|
||||||
|
|
||||||
There is no return value; the side-effect is that `state` will have `substates` populated,
|
|
||||||
and each substate will have its `parent_state` set to `state`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: The token to get the state for.
|
|
||||||
state: The state instance to populate substates for.
|
|
||||||
all_substates: Whether to fetch all substates or just required substates.
|
|
||||||
"""
|
|
||||||
client_token, _ = _split_substate_key(token)
|
|
||||||
|
|
||||||
if all_substates:
|
|
||||||
# All substates are requested.
|
|
||||||
fetch_substates = state.get_substates()
|
|
||||||
else:
|
|
||||||
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
|
||||||
fetch_substates = state._potentially_dirty_substates()
|
|
||||||
|
|
||||||
tasks = {}
|
|
||||||
# Retrieve the necessary substates from redis.
|
|
||||||
for substate_cls in fetch_substates:
|
|
||||||
if substate_cls.get_name() in state.substates:
|
|
||||||
continue
|
|
||||||
substate_name = substate_cls.get_name()
|
|
||||||
tasks[substate_name] = asyncio.create_task(
|
|
||||||
self.get_state(
|
|
||||||
token=_substate_key(client_token, substate_cls),
|
|
||||||
top_level=False,
|
|
||||||
get_substates=all_substates,
|
|
||||||
parent_state=state,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for substate_name, substate_task in tasks.items():
|
|
||||||
state.substates[substate_name] = await substate_task
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def get_state(
|
async def get_state(
|
||||||
self,
|
self,
|
||||||
@ -3268,7 +3252,6 @@ class StateManagerRedis(StateManager):
|
|||||||
top_level: bool = True,
|
top_level: bool = True,
|
||||||
get_substates: bool = True,
|
get_substates: bool = True,
|
||||||
parent_state: BaseState | None = None,
|
parent_state: BaseState | None = None,
|
||||||
cached_substates: list[BaseState] | None = None,
|
|
||||||
) -> BaseState:
|
) -> BaseState:
|
||||||
"""Get the state for a token.
|
"""Get the state for a token.
|
||||||
|
|
||||||
@ -3277,7 +3260,6 @@ class StateManagerRedis(StateManager):
|
|||||||
top_level: If true, return an instance of the top-level state (self.state).
|
top_level: If true, return an instance of the top-level state (self.state).
|
||||||
get_substates: If true, also retrieve substates.
|
get_substates: If true, also retrieve substates.
|
||||||
parent_state: If provided, use this parent_state instead of getting it from redis.
|
parent_state: If provided, use this parent_state instead of getting it from redis.
|
||||||
cached_substates: If provided, attach these substates to the state.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
@ -3285,8 +3267,8 @@ class StateManagerRedis(StateManager):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: when the state_cls is not specified in the token
|
RuntimeError: when the state_cls is not specified in the token
|
||||||
"""
|
"""
|
||||||
# Split the actual token from the fully qualified substate name.
|
# new impl from top to bottomA
|
||||||
_, state_path = _split_substate_key(token)
|
client_token, state_path = _split_substate_key(token)
|
||||||
if state_path:
|
if state_path:
|
||||||
# Get the State class associated with the given path.
|
# Get the State class associated with the given path.
|
||||||
state_cls = self.state.get_class_substate(state_path)
|
state_cls = self.state.get_class_substate(state_path)
|
||||||
@ -3295,44 +3277,92 @@ class StateManagerRedis(StateManager):
|
|||||||
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
|
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# The deserialized or newly created (sub)state instance.
|
state_tokens = {state_path}
|
||||||
state = None
|
|
||||||
|
|
||||||
# Fetch the serialized substate from redis.
|
# walk up the state path
|
||||||
redis_state = await self.redis.get(token)
|
walk_state_path = state_path
|
||||||
|
while "." in walk_state_path:
|
||||||
|
walk_state_path = walk_state_path.rpartition(".")[0]
|
||||||
|
state_tokens.add(walk_state_path)
|
||||||
|
|
||||||
if redis_state is not None:
|
state_tokens.update(
|
||||||
# Deserialize the substate.
|
{
|
||||||
with contextlib.suppress(StateSchemaMismatchError):
|
substate.get_full_name()
|
||||||
state = BaseState._deserialize(data=redis_state)
|
for substate in self.state._recursive_potentially_dirty_substates()
|
||||||
if state is None:
|
}
|
||||||
# Key didn't exist or schema mismatch so create a new instance for this token.
|
)
|
||||||
state = state_cls(
|
if get_substates:
|
||||||
init_substates=False,
|
state_tokens.update(
|
||||||
_reflex_internal_init=True,
|
{
|
||||||
|
substate.get_full_name()
|
||||||
|
for substate in state_cls.get_all_substate_classes()
|
||||||
|
}
|
||||||
)
|
)
|
||||||
# Populate parent state if missing and requested.
|
|
||||||
if parent_state is None:
|
|
||||||
parent_state = await self._get_parent_state(token, state)
|
|
||||||
# Set up Bidirectional linkage between this state and its parent.
|
|
||||||
if parent_state is not None:
|
|
||||||
parent_state.substates[state.get_name()] = state
|
|
||||||
state.parent_state = parent_state
|
|
||||||
# Avoid fetching substates multiple times.
|
|
||||||
if cached_substates:
|
|
||||||
for substate in cached_substates:
|
|
||||||
state.substates[substate.get_name()] = substate
|
|
||||||
if substate.parent_state is None:
|
|
||||||
substate.parent_state = state
|
|
||||||
# Populate substates if requested.
|
|
||||||
await self._populate_substates(token, state, all_substates=get_substates)
|
|
||||||
|
|
||||||
# To retain compatibility with previous implementation, by default, we return
|
loaded_states = {}
|
||||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
if parent_state is not None:
|
||||||
|
loaded_states = parent_state._get_all_loaded_states()
|
||||||
|
# remove all states that are already loaded
|
||||||
|
state_tokens = state_tokens.difference(loaded_states.keys())
|
||||||
|
|
||||||
|
redis_states = await self.hmget(name=client_token, keys=list(state_tokens))
|
||||||
|
redis_states.update(loaded_states)
|
||||||
|
root_state = redis_states[self.state.get_full_name()]
|
||||||
|
self.recursive_link_substates(state=root_state, substates=redis_states)
|
||||||
|
|
||||||
if top_level:
|
if top_level:
|
||||||
return state._get_root_state()
|
return root_state
|
||||||
|
|
||||||
|
state = redis_states[state_path]
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def recursive_link_substates(
|
||||||
|
self,
|
||||||
|
state: BaseState,
|
||||||
|
substates: dict[str, BaseState],
|
||||||
|
):
|
||||||
|
"""Recursively link substates to a state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state to link substates to.
|
||||||
|
substates: The substates to link.
|
||||||
|
"""
|
||||||
|
for substate_cls in state.get_substates():
|
||||||
|
if substate_cls.get_full_name() not in substates:
|
||||||
|
continue
|
||||||
|
substate = substates[substate_cls.get_full_name()]
|
||||||
|
state.substates[substate.get_name()] = substate
|
||||||
|
substate.parent_state = state
|
||||||
|
self.recursive_link_substates(
|
||||||
|
state=substate,
|
||||||
|
substates=substates,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def hmget(self, name: str, keys: List[str]) -> dict[str, BaseState]:
|
||||||
|
"""Get multiple values from a hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the hash.
|
||||||
|
keys: The keys to get.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The values.
|
||||||
|
"""
|
||||||
|
d = {}
|
||||||
|
for state in await self.redis.hmget(name=name, keys=keys): # type: ignore
|
||||||
|
key = keys.pop(0)
|
||||||
|
if state is not None:
|
||||||
|
with contextlib.suppress(StateSchemaMismatchError):
|
||||||
|
state = BaseState._deserialize(data=state)
|
||||||
|
if state is None:
|
||||||
|
state_cls = self.state.get_class_substate(key)
|
||||||
|
state = state_cls(
|
||||||
|
init_substates=False,
|
||||||
|
_reflex_internal_init=True,
|
||||||
|
)
|
||||||
|
d[state.get_full_name()] = state
|
||||||
|
return d
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def set_state(
|
async def set_state(
|
||||||
self,
|
self,
|
||||||
@ -3368,31 +3398,25 @@ class StateManagerRedis(StateManager):
|
|||||||
f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
|
f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Recursively set_state on all known substates.
|
redis_hashset = {}
|
||||||
tasks = []
|
|
||||||
for substate in state.substates.values():
|
|
||||||
tasks.append(
|
|
||||||
asyncio.create_task(
|
|
||||||
self.set_state(
|
|
||||||
token=_substate_key(client_token, substate),
|
|
||||||
state=substate,
|
|
||||||
lock_id=lock_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
|
||||||
if state._get_was_touched():
|
|
||||||
pickle_state = state._serialize()
|
|
||||||
if pickle_state:
|
|
||||||
await self.redis.set(
|
|
||||||
_substate_key(client_token, state),
|
|
||||||
pickle_state,
|
|
||||||
ex=self.token_expiration,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait for substates to be persisted.
|
for state_name, substate in state._get_all_loaded_states().items():
|
||||||
for t in tasks:
|
if not substate._get_was_touched():
|
||||||
await t
|
continue
|
||||||
|
pickle_state = substate._serialize()
|
||||||
|
if not pickle_state:
|
||||||
|
continue
|
||||||
|
redis_hashset[state_name] = pickle_state
|
||||||
|
|
||||||
|
if not redis_hashset:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.redis.hmset(name=client_token, mapping=redis_hashset) # type: ignore
|
||||||
|
await self.redis.hexpire(
|
||||||
|
client_token,
|
||||||
|
self.token_expiration,
|
||||||
|
*redis_hashset.keys(),
|
||||||
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
|
@ -11,7 +11,6 @@ from selenium.webdriver.common.by import By
|
|||||||
from selenium.webdriver.remote.webdriver import WebDriver
|
from selenium.webdriver.remote.webdriver import WebDriver
|
||||||
|
|
||||||
from reflex.state import (
|
from reflex.state import (
|
||||||
State,
|
|
||||||
StateManagerDisk,
|
StateManagerDisk,
|
||||||
StateManagerMemory,
|
StateManagerMemory,
|
||||||
StateManagerRedis,
|
StateManagerRedis,
|
||||||
@ -278,6 +277,7 @@ async def test_client_side_state(
|
|||||||
set_sub_sub_state_button.click()
|
set_sub_sub_state_button.click()
|
||||||
|
|
||||||
token = poll_for_token()
|
token = poll_for_token()
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
# get a reference to all cookie and local storage elements
|
# get a reference to all cookie and local storage elements
|
||||||
c1 = driver.find_element(By.ID, "c1")
|
c1 = driver.find_element(By.ID, "c1")
|
||||||
@ -613,16 +613,7 @@ async def test_client_side_state(
|
|||||||
|
|
||||||
# Simulate state expiration
|
# Simulate state expiration
|
||||||
if isinstance(client_side.state_manager, StateManagerRedis):
|
if isinstance(client_side.state_manager, StateManagerRedis):
|
||||||
await client_side.state_manager.redis.delete(
|
await client_side.state_manager.redis.delete(token)
|
||||||
_substate_key(token, State.get_full_name())
|
|
||||||
)
|
|
||||||
await client_side.state_manager.redis.delete(_substate_key(token, state_name))
|
|
||||||
await client_side.state_manager.redis.delete(
|
|
||||||
_substate_key(token, sub_state_name)
|
|
||||||
)
|
|
||||||
await client_side.state_manager.redis.delete(
|
|
||||||
_substate_key(token, sub_sub_state_name)
|
|
||||||
)
|
|
||||||
elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
|
elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||||
del client_side.state_manager.states[token]
|
del client_side.state_manager.states[token]
|
||||||
if isinstance(client_side.state_manager, StateManagerDisk):
|
if isinstance(client_side.state_manager, StateManagerDisk):
|
||||||
@ -679,9 +670,8 @@ async def test_client_side_state(
|
|||||||
|
|
||||||
# Get the backend state and ensure the values are still set
|
# Get the backend state and ensure the values are still set
|
||||||
async def get_sub_state():
|
async def get_sub_state():
|
||||||
root_state = await client_side.get_state(
|
assert token is not None
|
||||||
_substate_key(token or "", sub_state_name)
|
root_state = await client_side.get_state(_substate_key(token, sub_state_name))
|
||||||
)
|
|
||||||
state = root_state.substates[client_side.get_state_name("_client_side_state")]
|
state = root_state.substates[client_side.get_state_name("_client_side_state")]
|
||||||
sub_state = state.substates[
|
sub_state = state.substates[
|
||||||
client_side.get_state_name("_client_side_sub_state")
|
client_side.get_state_name("_client_side_sub_state")
|
||||||
|
@ -354,11 +354,11 @@ async def state_manager_redis(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_get_state_tree(
|
async def test_get_state_tree(
|
||||||
state_manager_redis,
|
state_manager_redis: StateManagerRedis,
|
||||||
token,
|
token: str,
|
||||||
substate_cls,
|
substate_cls: type[BaseState],
|
||||||
exp_root_substates,
|
exp_root_substates: list[str],
|
||||||
exp_root_dict_keys,
|
exp_root_dict_keys: list[str],
|
||||||
):
|
):
|
||||||
"""Test getting state trees and assert on which branches are retrieved.
|
"""Test getting state trees and assert on which branches are retrieved.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user