Merge ffd99ec00e
into 848b87070c
This commit is contained in:
commit
8de2da2ce5
309
benchmarks/benchmark_state_manager_redis.py
Normal file
309
benchmarks/benchmark_state_manager_redis.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""Benchmark for the state manager redis."""
|
||||
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pytest_benchmark.fixture import BenchmarkFixture
|
||||
|
||||
from reflex.state import State, StateManagerRedis
|
||||
from reflex.utils.prerequisites import get_redis
|
||||
from reflex.vars.base import computed_var
|
||||
|
||||
|
||||
class RootState(State):
|
||||
"""Root state class for testing."""
|
||||
|
||||
counter: int = 0
|
||||
int_dict: dict[str, int] = {}
|
||||
|
||||
|
||||
class ChildState(RootState):
|
||||
"""Child state class for testing."""
|
||||
|
||||
child_counter: int = 0
|
||||
|
||||
@computed_var
|
||||
def str_dict(self):
|
||||
"""Convert the int dict to a string dict.
|
||||
|
||||
Returns:
|
||||
A dictionary with string keys and integer values.
|
||||
"""
|
||||
return {str(k): v for k, v in self.int_dict.items()}
|
||||
|
||||
|
||||
class ChildState2(RootState):
|
||||
"""Child state 2 class for testing."""
|
||||
|
||||
child2_counter: int = 0
|
||||
|
||||
|
||||
class GrandChildState(ChildState):
|
||||
"""Grandchild state class for testing."""
|
||||
|
||||
grand_child_counter: int = 0
|
||||
float_dict: dict[str, float] = {}
|
||||
|
||||
@computed_var
|
||||
def double_counter(self):
|
||||
"""Double the counter.
|
||||
|
||||
Returns:
|
||||
The counter value multiplied by 2.
|
||||
"""
|
||||
return self.counter * 2
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def state_manager() -> StateManagerRedis:
|
||||
"""Fixture for the redis state manager.
|
||||
|
||||
Returns:
|
||||
An instance of StateManagerRedis.
|
||||
"""
|
||||
redis = get_redis()
|
||||
if redis is None:
|
||||
pytest.skip("Redis is not available")
|
||||
return StateManagerRedis(redis=redis, state=State)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token() -> str:
|
||||
"""Fixture for the token.
|
||||
|
||||
Returns:
|
||||
A unique token string.
|
||||
"""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def grand_child_state_token(token: str) -> str:
|
||||
"""Fixture for the grand child state token.
|
||||
|
||||
Args:
|
||||
token: The token fixture.
|
||||
|
||||
Returns:
|
||||
A string combining the token and the grandchild state name.
|
||||
"""
|
||||
return f"{token}_{GrandChildState.get_full_name()}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def state_token(token: str) -> str:
|
||||
"""Fixture for the base state token.
|
||||
|
||||
Args:
|
||||
token: The token fixture.
|
||||
|
||||
Returns:
|
||||
A string combining the token and the base state name.
|
||||
"""
|
||||
return f"{token}_{State.get_full_name()}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def grand_child_state() -> GrandChildState:
|
||||
"""Fixture for the grand child state.
|
||||
|
||||
Returns:
|
||||
An instance of GrandChildState.
|
||||
"""
|
||||
state = State()
|
||||
|
||||
root = RootState()
|
||||
root.parent_state = state
|
||||
state.substates[root.get_name()] = root
|
||||
|
||||
child = ChildState()
|
||||
child.parent_state = root
|
||||
root.substates[child.get_name()] = child
|
||||
|
||||
child2 = ChildState2()
|
||||
child2.parent_state = root
|
||||
root.substates[child2.get_name()] = child2
|
||||
|
||||
gcs = GrandChildState()
|
||||
gcs.parent_state = child
|
||||
child.substates[gcs.get_name()] = gcs
|
||||
|
||||
return gcs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def grand_child_state_big(grand_child_state: GrandChildState) -> GrandChildState:
|
||||
"""Fixture for the grand child state with big data.
|
||||
|
||||
Args:
|
||||
grand_child_state: The grand child state fixture.
|
||||
|
||||
Returns:
|
||||
An instance of GrandChildState with large data.
|
||||
"""
|
||||
grand_child_state.counter = 100
|
||||
grand_child_state.child_counter = 200
|
||||
grand_child_state.grand_child_counter = 300
|
||||
grand_child_state.int_dict = {str(i): i for i in range(10000)}
|
||||
grand_child_state.float_dict = {str(i): i + 0.5 for i in range(10000)}
|
||||
return grand_child_state
|
||||
|
||||
|
||||
def test_set_state(
|
||||
benchmark: BenchmarkFixture,
|
||||
state_manager: StateManagerRedis,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
token: str,
|
||||
) -> None:
|
||||
"""Benchmark setting state with minimal data.
|
||||
|
||||
Args:
|
||||
benchmark: The benchmark fixture.
|
||||
state_manager: The state manager fixture.
|
||||
event_loop: The event loop fixture.
|
||||
token: The token fixture.
|
||||
"""
|
||||
state = State()
|
||||
|
||||
def func():
|
||||
event_loop.run_until_complete(state_manager.set_state(token=token, state=state))
|
||||
|
||||
benchmark(func)
|
||||
|
||||
|
||||
def test_get_state(
|
||||
benchmark: BenchmarkFixture,
|
||||
state_manager: StateManagerRedis,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
state_token: str,
|
||||
) -> None:
|
||||
"""Benchmark getting state with minimal data.
|
||||
|
||||
Args:
|
||||
benchmark: The benchmark fixture.
|
||||
state_manager: The state manager fixture.
|
||||
event_loop: The event loop fixture.
|
||||
state_token: The base state token fixture.
|
||||
"""
|
||||
state = State()
|
||||
event_loop.run_until_complete(
|
||||
state_manager.set_state(token=state_token, state=state)
|
||||
)
|
||||
|
||||
def func():
|
||||
_ = event_loop.run_until_complete(state_manager.get_state(token=state_token))
|
||||
|
||||
benchmark(func)
|
||||
|
||||
|
||||
def test_set_state_tree_minimal(
|
||||
benchmark: BenchmarkFixture,
|
||||
state_manager: StateManagerRedis,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
grand_child_state_token: str,
|
||||
grand_child_state: GrandChildState,
|
||||
) -> None:
|
||||
"""Benchmark setting state with minimal data.
|
||||
|
||||
Args:
|
||||
benchmark: The benchmark fixture.
|
||||
state_manager: The state manager fixture.
|
||||
event_loop: The event loop fixture.
|
||||
grand_child_state_token: The grand child state token fixture.
|
||||
grand_child_state: The grand child state fixture.
|
||||
"""
|
||||
|
||||
def func():
|
||||
event_loop.run_until_complete(
|
||||
state_manager.set_state(
|
||||
token=grand_child_state_token, state=grand_child_state
|
||||
)
|
||||
)
|
||||
|
||||
benchmark(func)
|
||||
|
||||
|
||||
def test_get_state_tree_minimal(
|
||||
benchmark: BenchmarkFixture,
|
||||
state_manager: StateManagerRedis,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
grand_child_state_token: str,
|
||||
grand_child_state: GrandChildState,
|
||||
) -> None:
|
||||
"""Benchmark getting state with minimal data.
|
||||
|
||||
Args:
|
||||
benchmark: The benchmark fixture.
|
||||
state_manager: The state manager fixture.
|
||||
event_loop: The event loop fixture.
|
||||
grand_child_state_token: The grand child state token fixture.
|
||||
grand_child_state: The grand child state fixture.
|
||||
"""
|
||||
event_loop.run_until_complete(
|
||||
state_manager.set_state(token=grand_child_state_token, state=grand_child_state)
|
||||
)
|
||||
|
||||
def func():
|
||||
_ = event_loop.run_until_complete(
|
||||
state_manager.get_state(token=grand_child_state_token)
|
||||
)
|
||||
|
||||
benchmark(func)
|
||||
|
||||
|
||||
def test_set_state_tree_big(
|
||||
benchmark: BenchmarkFixture,
|
||||
state_manager: StateManagerRedis,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
grand_child_state_token: str,
|
||||
grand_child_state_big: GrandChildState,
|
||||
) -> None:
|
||||
"""Benchmark setting state with minimal data.
|
||||
|
||||
Args:
|
||||
benchmark: The benchmark fixture.
|
||||
state_manager: The state manager fixture.
|
||||
event_loop: The event loop fixture.
|
||||
grand_child_state_token: The grand child state token fixture.
|
||||
grand_child_state_big: The grand child state fixture.
|
||||
"""
|
||||
|
||||
def func():
|
||||
event_loop.run_until_complete(
|
||||
state_manager.set_state(
|
||||
token=grand_child_state_token, state=grand_child_state_big
|
||||
)
|
||||
)
|
||||
|
||||
benchmark(func)
|
||||
|
||||
|
||||
def test_get_state_tree_big(
|
||||
benchmark: BenchmarkFixture,
|
||||
state_manager: StateManagerRedis,
|
||||
event_loop: asyncio.AbstractEventLoop,
|
||||
grand_child_state_token: str,
|
||||
grand_child_state_big: GrandChildState,
|
||||
) -> None:
|
||||
"""Benchmark getting state with minimal data.
|
||||
|
||||
Args:
|
||||
benchmark: The benchmark fixture.
|
||||
state_manager: The state manager fixture.
|
||||
event_loop: The event loop fixture.
|
||||
grand_child_state_token: The grand child state token fixture.
|
||||
grand_child_state_big: The grand child state fixture.
|
||||
"""
|
||||
event_loop.run_until_complete(
|
||||
state_manager.set_state(
|
||||
token=grand_child_state_token, state=grand_child_state_big
|
||||
)
|
||||
)
|
||||
|
||||
def func():
|
||||
_ = event_loop.run_until_complete(
|
||||
state_manager.get_state(token=grand_child_state_token)
|
||||
)
|
||||
|
||||
benchmark(func)
|
3
poetry.lock
generated
3
poetry.lock
generated
@ -1147,6 +1147,7 @@ files = [
|
||||
{file = "nh3-0.2.19-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:00810cd5275f5c3f44b9eb0e521d1a841ee2f8023622de39ffc7d88bd533d8e0"},
|
||||
{file = "nh3-0.2.19-cp38-abi3-win32.whl", hash = "sha256:7e98621856b0a911c21faa5eef8f8ea3e691526c2433f9afc2be713cb6fbdb48"},
|
||||
{file = "nh3-0.2.19-cp38-abi3-win_amd64.whl", hash = "sha256:75c7cafb840f24430b009f7368945cb5ca88b2b54bb384ebfba495f16bc9c121"},
|
||||
{file = "nh3-0.2.19.tar.gz", hash = "sha256:790056b54c068ff8dceb443eaefb696b84beff58cca6c07afd754d17692a4804"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3076,4 +3077,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "d62cd1897d8f73e9aad9e907beb82be509dc5e33d8f37b36ebf26ad1f3075a9f"
|
||||
content-hash = "4ef559dcc4b3fd0d88c908cb4df4d7a14e3d021498d3034ad1b9481131abe686"
|
||||
|
@ -27,7 +27,7 @@ psutil = ">=5.9.4,<7.0"
|
||||
pydantic = ">=1.10.2,<3.0"
|
||||
python-multipart = ">=0.0.5,<0.1"
|
||||
python-socketio = ">=5.7.0,<6.0"
|
||||
redis = ">=4.3.5,<6.0"
|
||||
redis = ">=5.1.0,<6.0"
|
||||
rich = ">=13.0.0,<14.0"
|
||||
sqlmodel = ">=0.0.14,<0.1"
|
||||
typer = ">=0.4.2,<1.0"
|
||||
|
352
reflex/state.py
352
reflex/state.py
@ -956,7 +956,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for substate in cls.get_substates():
|
||||
if path[0] == substate.get_name():
|
||||
return substate.get_class_substate(path[1:])
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
raise ValueError(f"Invalid path: {cls.get_full_name()=} {path=}")
|
||||
|
||||
@classmethod
|
||||
def get_all_substate_classes(cls) -> set[Type[BaseState]]:
|
||||
"""Get all substate classes of the state.
|
||||
|
||||
Returns:
|
||||
The set of all substate classes.
|
||||
"""
|
||||
substates = set(cls.get_substates())
|
||||
for substate in cls.get_substates():
|
||||
substates.update(substate.get_all_substate_classes())
|
||||
return substates
|
||||
|
||||
@classmethod
|
||||
def get_class_var(cls, path: Sequence[str]) -> Any:
|
||||
@ -1414,7 +1426,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
return self
|
||||
path = path[1:]
|
||||
if path[0] not in self.substates:
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
raise ValueError(
|
||||
f"Invalid path: {path=} {self.get_full_name()=} {self.substates.keys()=}"
|
||||
)
|
||||
return self.substates[path[0]].get_substate(path[1:])
|
||||
|
||||
@classmethod
|
||||
@ -1476,6 +1490,63 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
|
||||
return parent_states_with_name
|
||||
|
||||
def _get_loaded_states(self) -> dict[str, BaseState]:
|
||||
"""Get all loaded states in the state tree.
|
||||
|
||||
Returns:
|
||||
A list of all loaded states in the state tree.
|
||||
"""
|
||||
root_state = self._get_root_state()
|
||||
d = {root_state.get_full_name(): root_state}
|
||||
root_state._get_loaded_substates(d)
|
||||
return d
|
||||
|
||||
def _get_loaded_substates(
|
||||
self,
|
||||
loaded_substates: dict[str, BaseState],
|
||||
) -> None:
|
||||
"""Get all loaded substates of this state.
|
||||
|
||||
Args:
|
||||
loaded_substates: A dictionary of loaded substates which will be updated with the substates of this state.
|
||||
"""
|
||||
for substate in self.substates.values():
|
||||
loaded_substates[substate.get_full_name()] = substate
|
||||
substate._get_loaded_substates(loaded_substates)
|
||||
|
||||
def _serialize_touched_states(self) -> dict[str, bytes]:
|
||||
"""Serialize all touched states in the state tree.
|
||||
|
||||
Returns:
|
||||
The serialized states.
|
||||
"""
|
||||
root_state = self._get_root_state()
|
||||
d = {}
|
||||
if root_state._get_was_touched():
|
||||
serialized = root_state._serialize()
|
||||
if serialized:
|
||||
d[root_state.get_full_name()] = serialized
|
||||
root_state._serialize_touched_substates(d)
|
||||
return d
|
||||
|
||||
def _serialize_touched_substates(
|
||||
self,
|
||||
touched_substates: dict[str, bytes],
|
||||
) -> None:
|
||||
"""Serialize all touched substates of this state.
|
||||
|
||||
Args:
|
||||
touched_substates: A dictionary of touched substates which will be updated with the substates of this state.
|
||||
"""
|
||||
for substate in self.substates.values():
|
||||
substate._serialize_touched_substates(touched_substates)
|
||||
if not substate._get_was_touched():
|
||||
continue
|
||||
serialized = substate._serialize()
|
||||
if not serialized:
|
||||
continue
|
||||
touched_substates[substate.get_full_name()] = serialized
|
||||
|
||||
def _get_root_state(self) -> BaseState:
|
||||
"""Get the root state of the state tree.
|
||||
|
||||
@ -1883,26 +1954,48 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]:
|
||||
def _potentially_dirty_substates(cls) -> set[str]:
|
||||
"""Determine substates which could be affected by dirty vars in this state.
|
||||
|
||||
Returns:
|
||||
Set of State classes that may need to be fetched to recalc computed vars.
|
||||
Set of State full names that may need to be fetched to recalc computed vars.
|
||||
"""
|
||||
# _always_dirty_substates need to be fetched to recalc computed vars.
|
||||
fetch_substates = {
|
||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
||||
f"{cls.get_full_name()}.{substate_name}"
|
||||
for substate_name in cls._always_dirty_substates
|
||||
}
|
||||
for dependent_substates in cls._substate_var_dependencies.values():
|
||||
fetch_substates.update(
|
||||
{
|
||||
cls.get_class_substate((cls.get_name(), *substate_name.split(".")))
|
||||
f"{cls.get_full_name()}.{substate_name}"
|
||||
for substate_name in dependent_substates
|
||||
}
|
||||
)
|
||||
return fetch_substates
|
||||
|
||||
@classmethod
|
||||
def _recursive_potentially_dirty_substates(
|
||||
cls,
|
||||
already_selected: Type[BaseState] | None = None,
|
||||
) -> set[str]:
|
||||
"""Recursively determine substates which could be affected by dirty vars in this state.
|
||||
|
||||
Args:
|
||||
already_selected: The class of the state that has already been selected and needs no further processing.
|
||||
|
||||
Returns:
|
||||
Set of full state names that may need to be fetched to recalc computed vars.
|
||||
"""
|
||||
if already_selected is not None and already_selected == cls:
|
||||
return set()
|
||||
fetch_substates = cls._potentially_dirty_substates()
|
||||
for substate_cls in cls.get_substates():
|
||||
fetch_substates.update(
|
||||
substate_cls._recursive_potentially_dirty_substates(already_selected)
|
||||
)
|
||||
return fetch_substates
|
||||
|
||||
def get_delta(self) -> Delta:
|
||||
"""Get the delta for the state.
|
||||
|
||||
@ -3231,6 +3324,9 @@ class StateManagerRedis(StateManager):
|
||||
default_factory=_default_lock_warning_threshold
|
||||
)
|
||||
|
||||
# If HEXPIRE is not supported, use EXPIRE instead.
|
||||
_hexpire_not_supported: Optional[bool] = pydantic.PrivateAttr(None)
|
||||
|
||||
# The keyspace subscription string when redis is waiting for lock to be released
|
||||
_redis_notify_keyspace_events: str = (
|
||||
"K" # Enable keyspace notifications (target a particular key)
|
||||
@ -3247,77 +3343,6 @@ class StateManagerRedis(StateManager):
|
||||
b"evicted",
|
||||
}
|
||||
|
||||
async def _get_parent_state(
|
||||
self, token: str, state: BaseState | None = None
|
||||
) -> BaseState | None:
|
||||
"""Get the parent state for the state requested in the token.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for (_substate_key).
|
||||
state: The state instance to get parent state for.
|
||||
|
||||
Returns:
|
||||
The parent state for the state requested by the token or None if there is no such parent.
|
||||
"""
|
||||
parent_state = None
|
||||
client_token, state_path = _split_substate_key(token)
|
||||
parent_state_name = state_path.rpartition(".")[0]
|
||||
if parent_state_name:
|
||||
cached_substates = None
|
||||
if state is not None:
|
||||
cached_substates = [state]
|
||||
# Retrieve the parent state to populate event handlers onto this substate.
|
||||
parent_state = await self.get_state(
|
||||
token=_substate_key(client_token, parent_state_name),
|
||||
top_level=False,
|
||||
get_substates=False,
|
||||
cached_substates=cached_substates,
|
||||
)
|
||||
return parent_state
|
||||
|
||||
async def _populate_substates(
|
||||
self,
|
||||
token: str,
|
||||
state: BaseState,
|
||||
all_substates: bool = False,
|
||||
):
|
||||
"""Fetch and link substates for the given state instance.
|
||||
|
||||
There is no return value; the side-effect is that `state` will have `substates` populated,
|
||||
and each substate will have its `parent_state` set to `state`.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for.
|
||||
state: The state instance to populate substates for.
|
||||
all_substates: Whether to fetch all substates or just required substates.
|
||||
"""
|
||||
client_token, _ = _split_substate_key(token)
|
||||
|
||||
if all_substates:
|
||||
# All substates are requested.
|
||||
fetch_substates = state.get_substates()
|
||||
else:
|
||||
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
|
||||
fetch_substates = state._potentially_dirty_substates()
|
||||
|
||||
tasks = {}
|
||||
# Retrieve the necessary substates from redis.
|
||||
for substate_cls in fetch_substates:
|
||||
if substate_cls.get_name() in state.substates:
|
||||
continue
|
||||
substate_name = substate_cls.get_name()
|
||||
tasks[substate_name] = asyncio.create_task(
|
||||
self.get_state(
|
||||
token=_substate_key(client_token, substate_cls),
|
||||
top_level=False,
|
||||
get_substates=all_substates,
|
||||
parent_state=state,
|
||||
)
|
||||
)
|
||||
|
||||
for substate_name, substate_task in tasks.items():
|
||||
state.substates[substate_name] = await substate_task
|
||||
|
||||
@override
|
||||
async def get_state(
|
||||
self,
|
||||
@ -3325,7 +3350,6 @@ class StateManagerRedis(StateManager):
|
||||
top_level: bool = True,
|
||||
get_substates: bool = True,
|
||||
parent_state: BaseState | None = None,
|
||||
cached_substates: list[BaseState] | None = None,
|
||||
) -> BaseState:
|
||||
"""Get the state for a token.
|
||||
|
||||
@ -3334,7 +3358,6 @@ class StateManagerRedis(StateManager):
|
||||
top_level: If true, return an instance of the top-level state (self.state).
|
||||
get_substates: If true, also retrieve substates.
|
||||
parent_state: If provided, use this parent_state instead of getting it from redis.
|
||||
cached_substates: If provided, attach these substates to the state.
|
||||
|
||||
Returns:
|
||||
The state for the token.
|
||||
@ -3342,8 +3365,8 @@ class StateManagerRedis(StateManager):
|
||||
Raises:
|
||||
RuntimeError: when the state_cls is not specified in the token
|
||||
"""
|
||||
# Split the actual token from the fully qualified substate name.
|
||||
_, state_path = _split_substate_key(token)
|
||||
# new impl from top to bottomA
|
||||
client_token, state_path = _split_substate_key(token)
|
||||
if state_path:
|
||||
# Get the State class associated with the given path.
|
||||
state_cls = self.state.get_class_substate(state_path)
|
||||
@ -3352,44 +3375,94 @@ class StateManagerRedis(StateManager):
|
||||
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
|
||||
)
|
||||
|
||||
# The deserialized or newly created (sub)state instance.
|
||||
state = None
|
||||
state_tokens = {state_path}
|
||||
|
||||
# Fetch the serialized substate from redis.
|
||||
redis_state = await self.redis.get(token)
|
||||
# walk up the state path
|
||||
walk_state_path = state_path
|
||||
while "." in walk_state_path:
|
||||
walk_state_path = walk_state_path.rpartition(".")[0]
|
||||
state_tokens.add(walk_state_path)
|
||||
|
||||
if redis_state is not None:
|
||||
# Deserialize the substate.
|
||||
with contextlib.suppress(StateSchemaMismatchError):
|
||||
state = BaseState._deserialize(data=redis_state)
|
||||
if state is None:
|
||||
# Key didn't exist or schema mismatch so create a new instance for this token.
|
||||
state = state_cls(
|
||||
init_substates=False,
|
||||
_reflex_internal_init=True,
|
||||
if get_substates:
|
||||
state_tokens.update(
|
||||
{
|
||||
substate.get_full_name()
|
||||
for substate in state_cls.get_all_substate_classes()
|
||||
}
|
||||
)
|
||||
# Populate parent state if missing and requested.
|
||||
if parent_state is None:
|
||||
parent_state = await self._get_parent_state(token, state)
|
||||
# Set up Bidirectional linkage between this state and its parent.
|
||||
if parent_state is not None:
|
||||
parent_state.substates[state.get_name()] = state
|
||||
state.parent_state = parent_state
|
||||
# Avoid fetching substates multiple times.
|
||||
if cached_substates:
|
||||
for substate in cached_substates:
|
||||
state.substates[substate.get_name()] = substate
|
||||
if substate.parent_state is None:
|
||||
substate.parent_state = state
|
||||
# Populate substates if requested.
|
||||
await self._populate_substates(token, state, all_substates=get_substates)
|
||||
state_tokens.update(
|
||||
self.state._recursive_potentially_dirty_substates(
|
||||
already_selected=state_cls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
state_tokens.update(self.state._recursive_potentially_dirty_substates())
|
||||
|
||||
loaded_states = {}
|
||||
if parent_state is not None:
|
||||
loaded_states = parent_state._get_loaded_states()
|
||||
# remove all states that are already loaded
|
||||
state_tokens = state_tokens.difference(loaded_states.keys())
|
||||
|
||||
redis_states = await self.hmget(name=client_token, keys=list(state_tokens))
|
||||
redis_states.update(loaded_states)
|
||||
root_state = redis_states[self.state.get_full_name()]
|
||||
self.recursive_link_substates(state=root_state, substates=redis_states)
|
||||
|
||||
# To retain compatibility with previous implementation, by default, we return
|
||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
||||
if top_level:
|
||||
return state._get_root_state()
|
||||
return root_state
|
||||
|
||||
state = redis_states[state_path]
|
||||
return state
|
||||
|
||||
def recursive_link_substates(
|
||||
self,
|
||||
state: BaseState,
|
||||
substates: dict[str, BaseState],
|
||||
):
|
||||
"""Recursively link substates to a state.
|
||||
|
||||
Args:
|
||||
state: The state to link substates to.
|
||||
substates: The substates to link.
|
||||
"""
|
||||
for substate_cls in state.get_substates():
|
||||
if substate_cls.get_full_name() not in substates:
|
||||
continue
|
||||
substate = substates[substate_cls.get_full_name()]
|
||||
state.substates[substate.get_name()] = substate
|
||||
substate.parent_state = state
|
||||
self.recursive_link_substates(
|
||||
state=substate,
|
||||
substates=substates,
|
||||
)
|
||||
|
||||
async def hmget(self, name: str, keys: List[str]) -> dict[str, BaseState]:
|
||||
"""Get multiple values from a hash.
|
||||
|
||||
Args:
|
||||
name: The name of the hash.
|
||||
keys: The keys to get.
|
||||
|
||||
Returns:
|
||||
The values.
|
||||
"""
|
||||
d = {}
|
||||
for redis_state in await self.redis.hmget(name=name, keys=keys): # type: ignore
|
||||
key = keys.pop(0)
|
||||
state = None
|
||||
if redis_state is not None:
|
||||
with contextlib.suppress(StateSchemaMismatchError):
|
||||
state = BaseState._deserialize(data=redis_state)
|
||||
if state is None:
|
||||
state_cls = self.state.get_class_substate(key)
|
||||
state = state_cls(
|
||||
init_substates=False,
|
||||
_reflex_internal_init=True,
|
||||
)
|
||||
d[state.get_full_name()] = state
|
||||
return d
|
||||
|
||||
@override
|
||||
async def set_state(
|
||||
self,
|
||||
@ -3407,6 +3480,7 @@ class StateManagerRedis(StateManager):
|
||||
Raises:
|
||||
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
||||
RuntimeError: If the state instance doesn't match the state name in the token.
|
||||
ResponseError: If the redis command fails.
|
||||
"""
|
||||
# Check that we're holding the lock.
|
||||
if (
|
||||
@ -3436,30 +3510,38 @@ class StateManagerRedis(StateManager):
|
||||
f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
|
||||
)
|
||||
|
||||
# Recursively set_state on all known substates.
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
self.set_state(
|
||||
_substate_key(client_token, substate),
|
||||
substate,
|
||||
lock_id,
|
||||
)
|
||||
)
|
||||
for substate in state.substates.values()
|
||||
]
|
||||
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
||||
if state._get_was_touched():
|
||||
pickle_state = state._serialize()
|
||||
if pickle_state:
|
||||
await self.redis.set(
|
||||
_substate_key(client_token, state),
|
||||
pickle_state,
|
||||
ex=self.token_expiration,
|
||||
)
|
||||
redis_hashset = state._serialize_touched_states()
|
||||
|
||||
# Wait for substates to be persisted.
|
||||
for t in tasks:
|
||||
await t
|
||||
if not redis_hashset:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._hset_pipeline(client_token, redis_hashset)
|
||||
except ResponseError as re:
|
||||
if "unknown command 'HEXPIRE'" not in str(re):
|
||||
raise
|
||||
# HEXPIRE not supported, try again with fallback expire.
|
||||
self._hexpire_not_supported = True
|
||||
await self._hset_pipeline(client_token, redis_hashset)
|
||||
|
||||
async def _hset_pipeline(self, client_token: str, redis_hashset: dict[str, bytes]):
|
||||
"""Set multiple fields in a hash with expiration.
|
||||
|
||||
Args:
|
||||
client_token: The name of the hash.
|
||||
redis_hashset: The keys and values to set.
|
||||
"""
|
||||
pipe = self.redis.pipeline(transaction=False)
|
||||
pipe.hset(name=client_token, mapping=redis_hashset)
|
||||
if self._hexpire_not_supported:
|
||||
pipe.expire(client_token, self.token_expiration)
|
||||
else:
|
||||
pipe.hexpire(
|
||||
client_token,
|
||||
self.token_expiration,
|
||||
*redis_hashset.keys(),
|
||||
)
|
||||
await pipe.execute()
|
||||
|
||||
@override
|
||||
@contextlib.asynccontextmanager
|
||||
|
@ -333,10 +333,9 @@ def get_redis() -> Redis | None:
|
||||
Returns:
|
||||
The asynchronous redis client.
|
||||
"""
|
||||
if isinstance((redis_url_or_options := parse_redis_url()), str):
|
||||
return Redis.from_url(redis_url_or_options)
|
||||
elif isinstance(redis_url_or_options, dict):
|
||||
return Redis(**redis_url_or_options)
|
||||
redis_url = parse_redis_url()
|
||||
if redis_url is not None:
|
||||
return Redis.from_url(redis_url)
|
||||
return None
|
||||
|
||||
|
||||
@ -346,14 +345,13 @@ def get_redis_sync() -> RedisSync | None:
|
||||
Returns:
|
||||
The synchronous redis client.
|
||||
"""
|
||||
if isinstance((redis_url_or_options := parse_redis_url()), str):
|
||||
return RedisSync.from_url(redis_url_or_options)
|
||||
elif isinstance(redis_url_or_options, dict):
|
||||
return RedisSync(**redis_url_or_options)
|
||||
redis_url = parse_redis_url()
|
||||
if redis_url is not None:
|
||||
return RedisSync.from_url(redis_url)
|
||||
return None
|
||||
|
||||
|
||||
def parse_redis_url() -> str | dict | None:
|
||||
def parse_redis_url() -> str | None:
|
||||
"""Parse the REDIS_URL in config if applicable.
|
||||
|
||||
Returns:
|
||||
|
@ -11,7 +11,6 @@ from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
|
||||
from reflex.state import (
|
||||
State,
|
||||
StateManagerDisk,
|
||||
StateManagerMemory,
|
||||
StateManagerRedis,
|
||||
@ -278,6 +277,7 @@ async def test_client_side_state(
|
||||
set_sub_sub_state_button.click()
|
||||
|
||||
token = poll_for_token()
|
||||
assert token is not None
|
||||
|
||||
# get a reference to all cookie and local storage elements
|
||||
c1 = driver.find_element(By.ID, "c1")
|
||||
@ -613,16 +613,7 @@ async def test_client_side_state(
|
||||
|
||||
# Simulate state expiration
|
||||
if isinstance(client_side.state_manager, StateManagerRedis):
|
||||
await client_side.state_manager.redis.delete(
|
||||
_substate_key(token, State.get_full_name())
|
||||
)
|
||||
await client_side.state_manager.redis.delete(_substate_key(token, state_name))
|
||||
await client_side.state_manager.redis.delete(
|
||||
_substate_key(token, sub_state_name)
|
||||
)
|
||||
await client_side.state_manager.redis.delete(
|
||||
_substate_key(token, sub_sub_state_name)
|
||||
)
|
||||
await client_side.state_manager.redis.delete(token)
|
||||
elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
|
||||
del client_side.state_manager.states[token]
|
||||
if isinstance(client_side.state_manager, StateManagerDisk):
|
||||
@ -678,9 +669,8 @@ async def test_client_side_state(
|
||||
|
||||
# Get the backend state and ensure the values are still set
|
||||
async def get_sub_state():
|
||||
root_state = await client_side.get_state(
|
||||
_substate_key(token or "", sub_state_name)
|
||||
)
|
||||
assert token is not None
|
||||
root_state = await client_side.get_state(_substate_key(token, sub_state_name))
|
||||
state = root_state.substates[client_side.get_state_name("_client_side_state")]
|
||||
sub_state = state.substates[
|
||||
client_side.get_state_name("_client_side_sub_state")
|
||||
|
@ -31,6 +31,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
|
||||
|
||||
import reflex as rx
|
||||
import reflex.config
|
||||
import reflex.utils.console
|
||||
from reflex import constants
|
||||
from reflex.app import App
|
||||
from reflex.base import Base
|
||||
@ -1857,7 +1858,7 @@ async def test_state_manager_lock_warning_threshold_contend(
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
mocker: Pytest mocker object.
|
||||
"""
|
||||
console_warn = mocker.patch("reflex.utils.console.warn")
|
||||
console_warn = mocker.spy(reflex.utils.console, "warn")
|
||||
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
|
||||
@ -1875,7 +1876,7 @@ async def test_state_manager_lock_warning_threshold_contend(
|
||||
|
||||
await tasks[0]
|
||||
console_warn.assert_called()
|
||||
assert console_warn.call_count == 7
|
||||
assert console_warn.call_count == 1
|
||||
|
||||
|
||||
class CopyingAsyncMock(AsyncMock):
|
||||
@ -3192,10 +3193,17 @@ def test_potentially_dirty_substates():
|
||||
def bar(self) -> str:
|
||||
return ""
|
||||
|
||||
assert RxState._potentially_dirty_substates() == {State}
|
||||
assert State._potentially_dirty_substates() == {C1}
|
||||
assert RxState._potentially_dirty_substates() == {State.get_full_name()}
|
||||
assert State._potentially_dirty_substates() == {C1.get_full_name()}
|
||||
assert C1._potentially_dirty_substates() == set()
|
||||
|
||||
assert RxState._recursive_potentially_dirty_substates() == {
|
||||
State.get_full_name(),
|
||||
C1.get_full_name(),
|
||||
}
|
||||
assert State._recursive_potentially_dirty_substates() == {C1.get_full_name()}
|
||||
assert C1._recursive_potentially_dirty_substates() == set()
|
||||
|
||||
|
||||
def test_router_var_dep() -> None:
|
||||
"""Test that router var dependencies are correctly tracked."""
|
||||
@ -3216,7 +3224,9 @@ def test_router_var_dep() -> None:
|
||||
State._init_var_dependency_dicts()
|
||||
|
||||
assert foo._deps(objclass=RouterVarDepState) == {"router"}
|
||||
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState}
|
||||
assert RouterVarParentState._potentially_dirty_substates() == {
|
||||
RouterVarDepState.get_full_name()
|
||||
}
|
||||
assert RouterVarParentState._substate_var_dependencies == {
|
||||
"router": {RouterVarDepState.get_name()}
|
||||
}
|
||||
|
@ -354,11 +354,11 @@ async def state_manager_redis(
|
||||
],
|
||||
)
|
||||
async def test_get_state_tree(
|
||||
state_manager_redis,
|
||||
token,
|
||||
substate_cls,
|
||||
exp_root_substates,
|
||||
exp_root_dict_keys,
|
||||
state_manager_redis: StateManagerRedis,
|
||||
token: str,
|
||||
substate_cls: type[BaseState],
|
||||
exp_root_substates: list[str],
|
||||
exp_root_dict_keys: list[str],
|
||||
):
|
||||
"""Test getting state trees and assert on which branches are retrieved.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user