This commit is contained in:
benedikt-bartscher 2024-12-21 14:45:13 -08:00 committed by GitHub
commit 8de2da2ce5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 560 additions and 170 deletions

View File

@ -0,0 +1,309 @@
"""Benchmark for the state manager redis."""
import asyncio
from uuid import uuid4
import pytest
from pytest_benchmark.fixture import BenchmarkFixture
from reflex.state import State, StateManagerRedis
from reflex.utils.prerequisites import get_redis
from reflex.vars.base import computed_var
class RootState(State):
"""Root state class for testing."""
counter: int = 0
int_dict: dict[str, int] = {}
class ChildState(RootState):
"""Child state class for testing."""
child_counter: int = 0
@computed_var
def str_dict(self):
"""Convert the int dict to a string dict.
Returns:
A dictionary with string keys and integer values.
"""
return {str(k): v for k, v in self.int_dict.items()}
class ChildState2(RootState):
"""Child state 2 class for testing."""
child2_counter: int = 0
class GrandChildState(ChildState):
"""Grandchild state class for testing."""
grand_child_counter: int = 0
float_dict: dict[str, float] = {}
@computed_var
def double_counter(self):
"""Double the counter.
Returns:
The counter value multiplied by 2.
"""
return self.counter * 2
@pytest.fixture
def state_manager() -> StateManagerRedis:
"""Fixture for the redis state manager.
Returns:
An instance of StateManagerRedis.
"""
redis = get_redis()
if redis is None:
pytest.skip("Redis is not available")
return StateManagerRedis(redis=redis, state=State)
@pytest.fixture
def token() -> str:
"""Fixture for the token.
Returns:
A unique token string.
"""
return str(uuid4())
@pytest.fixture
def grand_child_state_token(token: str) -> str:
"""Fixture for the grand child state token.
Args:
token: The token fixture.
Returns:
A string combining the token and the grandchild state name.
"""
return f"{token}_{GrandChildState.get_full_name()}"
@pytest.fixture
def state_token(token: str) -> str:
"""Fixture for the base state token.
Args:
token: The token fixture.
Returns:
A string combining the token and the base state name.
"""
return f"{token}_{State.get_full_name()}"
@pytest.fixture
def grand_child_state() -> GrandChildState:
"""Fixture for the grand child state.
Returns:
An instance of GrandChildState.
"""
state = State()
root = RootState()
root.parent_state = state
state.substates[root.get_name()] = root
child = ChildState()
child.parent_state = root
root.substates[child.get_name()] = child
child2 = ChildState2()
child2.parent_state = root
root.substates[child2.get_name()] = child2
gcs = GrandChildState()
gcs.parent_state = child
child.substates[gcs.get_name()] = gcs
return gcs
@pytest.fixture
def grand_child_state_big(grand_child_state: GrandChildState) -> GrandChildState:
"""Fixture for the grand child state with big data.
Args:
grand_child_state: The grand child state fixture.
Returns:
An instance of GrandChildState with large data.
"""
grand_child_state.counter = 100
grand_child_state.child_counter = 200
grand_child_state.grand_child_counter = 300
grand_child_state.int_dict = {str(i): i for i in range(10000)}
grand_child_state.float_dict = {str(i): i + 0.5 for i in range(10000)}
return grand_child_state
def test_set_state(
benchmark: BenchmarkFixture,
state_manager: StateManagerRedis,
event_loop: asyncio.AbstractEventLoop,
token: str,
) -> None:
"""Benchmark setting state with minimal data.
Args:
benchmark: The benchmark fixture.
state_manager: The state manager fixture.
event_loop: The event loop fixture.
token: The token fixture.
"""
state = State()
def func():
event_loop.run_until_complete(state_manager.set_state(token=token, state=state))
benchmark(func)
def test_get_state(
benchmark: BenchmarkFixture,
state_manager: StateManagerRedis,
event_loop: asyncio.AbstractEventLoop,
state_token: str,
) -> None:
"""Benchmark getting state with minimal data.
Args:
benchmark: The benchmark fixture.
state_manager: The state manager fixture.
event_loop: The event loop fixture.
state_token: The base state token fixture.
"""
state = State()
event_loop.run_until_complete(
state_manager.set_state(token=state_token, state=state)
)
def func():
_ = event_loop.run_until_complete(state_manager.get_state(token=state_token))
benchmark(func)
def test_set_state_tree_minimal(
benchmark: BenchmarkFixture,
state_manager: StateManagerRedis,
event_loop: asyncio.AbstractEventLoop,
grand_child_state_token: str,
grand_child_state: GrandChildState,
) -> None:
"""Benchmark setting state with minimal data.
Args:
benchmark: The benchmark fixture.
state_manager: The state manager fixture.
event_loop: The event loop fixture.
grand_child_state_token: The grand child state token fixture.
grand_child_state: The grand child state fixture.
"""
def func():
event_loop.run_until_complete(
state_manager.set_state(
token=grand_child_state_token, state=grand_child_state
)
)
benchmark(func)
def test_get_state_tree_minimal(
benchmark: BenchmarkFixture,
state_manager: StateManagerRedis,
event_loop: asyncio.AbstractEventLoop,
grand_child_state_token: str,
grand_child_state: GrandChildState,
) -> None:
"""Benchmark getting state with minimal data.
Args:
benchmark: The benchmark fixture.
state_manager: The state manager fixture.
event_loop: The event loop fixture.
grand_child_state_token: The grand child state token fixture.
grand_child_state: The grand child state fixture.
"""
event_loop.run_until_complete(
state_manager.set_state(token=grand_child_state_token, state=grand_child_state)
)
def func():
_ = event_loop.run_until_complete(
state_manager.get_state(token=grand_child_state_token)
)
benchmark(func)
def test_set_state_tree_big(
benchmark: BenchmarkFixture,
state_manager: StateManagerRedis,
event_loop: asyncio.AbstractEventLoop,
grand_child_state_token: str,
grand_child_state_big: GrandChildState,
) -> None:
"""Benchmark setting state with minimal data.
Args:
benchmark: The benchmark fixture.
state_manager: The state manager fixture.
event_loop: The event loop fixture.
grand_child_state_token: The grand child state token fixture.
grand_child_state_big: The grand child state fixture.
"""
def func():
event_loop.run_until_complete(
state_manager.set_state(
token=grand_child_state_token, state=grand_child_state_big
)
)
benchmark(func)
def test_get_state_tree_big(
benchmark: BenchmarkFixture,
state_manager: StateManagerRedis,
event_loop: asyncio.AbstractEventLoop,
grand_child_state_token: str,
grand_child_state_big: GrandChildState,
) -> None:
"""Benchmark getting state with minimal data.
Args:
benchmark: The benchmark fixture.
state_manager: The state manager fixture.
event_loop: The event loop fixture.
grand_child_state_token: The grand child state token fixture.
grand_child_state_big: The grand child state fixture.
"""
event_loop.run_until_complete(
state_manager.set_state(
token=grand_child_state_token, state=grand_child_state_big
)
)
def func():
_ = event_loop.run_until_complete(
state_manager.get_state(token=grand_child_state_token)
)
benchmark(func)

3
poetry.lock generated
View File

@ -1147,6 +1147,7 @@ files = [
{file = "nh3-0.2.19-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:00810cd5275f5c3f44b9eb0e521d1a841ee2f8023622de39ffc7d88bd533d8e0"}, {file = "nh3-0.2.19-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:00810cd5275f5c3f44b9eb0e521d1a841ee2f8023622de39ffc7d88bd533d8e0"},
{file = "nh3-0.2.19-cp38-abi3-win32.whl", hash = "sha256:7e98621856b0a911c21faa5eef8f8ea3e691526c2433f9afc2be713cb6fbdb48"}, {file = "nh3-0.2.19-cp38-abi3-win32.whl", hash = "sha256:7e98621856b0a911c21faa5eef8f8ea3e691526c2433f9afc2be713cb6fbdb48"},
{file = "nh3-0.2.19-cp38-abi3-win_amd64.whl", hash = "sha256:75c7cafb840f24430b009f7368945cb5ca88b2b54bb384ebfba495f16bc9c121"}, {file = "nh3-0.2.19-cp38-abi3-win_amd64.whl", hash = "sha256:75c7cafb840f24430b009f7368945cb5ca88b2b54bb384ebfba495f16bc9c121"},
{file = "nh3-0.2.19.tar.gz", hash = "sha256:790056b54c068ff8dceb443eaefb696b84beff58cca6c07afd754d17692a4804"},
] ]
[[package]] [[package]]
@ -3076,4 +3077,4 @@ type = ["pytest-mypy"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "d62cd1897d8f73e9aad9e907beb82be509dc5e33d8f37b36ebf26ad1f3075a9f" content-hash = "4ef559dcc4b3fd0d88c908cb4df4d7a14e3d021498d3034ad1b9481131abe686"

View File

@ -27,7 +27,7 @@ psutil = ">=5.9.4,<7.0"
pydantic = ">=1.10.2,<3.0" pydantic = ">=1.10.2,<3.0"
python-multipart = ">=0.0.5,<0.1" python-multipart = ">=0.0.5,<0.1"
python-socketio = ">=5.7.0,<6.0" python-socketio = ">=5.7.0,<6.0"
redis = ">=4.3.5,<6.0" redis = ">=5.1.0,<6.0"
rich = ">=13.0.0,<14.0" rich = ">=13.0.0,<14.0"
sqlmodel = ">=0.0.14,<0.1" sqlmodel = ">=0.0.14,<0.1"
typer = ">=0.4.2,<1.0" typer = ">=0.4.2,<1.0"

View File

@ -956,7 +956,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for substate in cls.get_substates(): for substate in cls.get_substates():
if path[0] == substate.get_name(): if path[0] == substate.get_name():
return substate.get_class_substate(path[1:]) return substate.get_class_substate(path[1:])
raise ValueError(f"Invalid path: {path}") raise ValueError(f"Invalid path: {cls.get_full_name()=} {path=}")
@classmethod
def get_all_substate_classes(cls) -> set[Type[BaseState]]:
"""Get all substate classes of the state.
Returns:
The set of all substate classes.
"""
substates = set(cls.get_substates())
for substate in cls.get_substates():
substates.update(substate.get_all_substate_classes())
return substates
@classmethod @classmethod
def get_class_var(cls, path: Sequence[str]) -> Any: def get_class_var(cls, path: Sequence[str]) -> Any:
@ -1414,7 +1426,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
return self return self
path = path[1:] path = path[1:]
if path[0] not in self.substates: if path[0] not in self.substates:
raise ValueError(f"Invalid path: {path}") raise ValueError(
f"Invalid path: {path=} {self.get_full_name()=} {self.substates.keys()=}"
)
return self.substates[path[0]].get_substate(path[1:]) return self.substates[path[0]].get_substate(path[1:])
@classmethod @classmethod
@ -1476,6 +1490,63 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
parent_states_with_name.append((parent_state.get_full_name(), parent_state)) parent_states_with_name.append((parent_state.get_full_name(), parent_state))
return parent_states_with_name return parent_states_with_name
def _get_loaded_states(self) -> dict[str, BaseState]:
"""Get all loaded states in the state tree.
Returns:
A list of all loaded states in the state tree.
"""
root_state = self._get_root_state()
d = {root_state.get_full_name(): root_state}
root_state._get_loaded_substates(d)
return d
def _get_loaded_substates(
self,
loaded_substates: dict[str, BaseState],
) -> None:
"""Get all loaded substates of this state.
Args:
loaded_substates: A dictionary of loaded substates which will be updated with the substates of this state.
"""
for substate in self.substates.values():
loaded_substates[substate.get_full_name()] = substate
substate._get_loaded_substates(loaded_substates)
def _serialize_touched_states(self) -> dict[str, bytes]:
"""Serialize all touched states in the state tree.
Returns:
The serialized states.
"""
root_state = self._get_root_state()
d = {}
if root_state._get_was_touched():
serialized = root_state._serialize()
if serialized:
d[root_state.get_full_name()] = serialized
root_state._serialize_touched_substates(d)
return d
def _serialize_touched_substates(
self,
touched_substates: dict[str, bytes],
) -> None:
"""Serialize all touched substates of this state.
Args:
touched_substates: A dictionary of touched substates which will be updated with the substates of this state.
"""
for substate in self.substates.values():
substate._serialize_touched_substates(touched_substates)
if not substate._get_was_touched():
continue
serialized = substate._serialize()
if not serialized:
continue
touched_substates[substate.get_full_name()] = serialized
def _get_root_state(self) -> BaseState: def _get_root_state(self) -> BaseState:
"""Get the root state of the state tree. """Get the root state of the state tree.
@ -1883,26 +1954,48 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
} }
@classmethod @classmethod
def _potentially_dirty_substates(cls) -> set[Type[BaseState]]: def _potentially_dirty_substates(cls) -> set[str]:
"""Determine substates which could be affected by dirty vars in this state. """Determine substates which could be affected by dirty vars in this state.
Returns: Returns:
Set of State classes that may need to be fetched to recalc computed vars. Set of State full names that may need to be fetched to recalc computed vars.
""" """
# _always_dirty_substates need to be fetched to recalc computed vars. # _always_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = { fetch_substates = {
cls.get_class_substate((cls.get_name(), *substate_name.split("."))) f"{cls.get_full_name()}.{substate_name}"
for substate_name in cls._always_dirty_substates for substate_name in cls._always_dirty_substates
} }
for dependent_substates in cls._substate_var_dependencies.values(): for dependent_substates in cls._substate_var_dependencies.values():
fetch_substates.update( fetch_substates.update(
{ {
cls.get_class_substate((cls.get_name(), *substate_name.split("."))) f"{cls.get_full_name()}.{substate_name}"
for substate_name in dependent_substates for substate_name in dependent_substates
} }
) )
return fetch_substates return fetch_substates
@classmethod
def _recursive_potentially_dirty_substates(
cls,
already_selected: Type[BaseState] | None = None,
) -> set[str]:
"""Recursively determine substates which could be affected by dirty vars in this state.
Args:
already_selected: The class of the state that has already been selected and needs no further processing.
Returns:
Set of full state names that may need to be fetched to recalc computed vars.
"""
if already_selected is not None and already_selected == cls:
return set()
fetch_substates = cls._potentially_dirty_substates()
for substate_cls in cls.get_substates():
fetch_substates.update(
substate_cls._recursive_potentially_dirty_substates(already_selected)
)
return fetch_substates
def get_delta(self) -> Delta: def get_delta(self) -> Delta:
"""Get the delta for the state. """Get the delta for the state.
@ -3231,6 +3324,9 @@ class StateManagerRedis(StateManager):
default_factory=_default_lock_warning_threshold default_factory=_default_lock_warning_threshold
) )
# If HEXPIRE is not supported, use EXPIRE instead.
_hexpire_not_supported: Optional[bool] = pydantic.PrivateAttr(None)
# The keyspace subscription string when redis is waiting for lock to be released # The keyspace subscription string when redis is waiting for lock to be released
_redis_notify_keyspace_events: str = ( _redis_notify_keyspace_events: str = (
"K" # Enable keyspace notifications (target a particular key) "K" # Enable keyspace notifications (target a particular key)
@ -3247,77 +3343,6 @@ class StateManagerRedis(StateManager):
b"evicted", b"evicted",
} }
async def _get_parent_state(
self, token: str, state: BaseState | None = None
) -> BaseState | None:
"""Get the parent state for the state requested in the token.
Args:
token: The token to get the state for (_substate_key).
state: The state instance to get parent state for.
Returns:
The parent state for the state requested by the token or None if there is no such parent.
"""
parent_state = None
client_token, state_path = _split_substate_key(token)
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
cached_substates = None
if state is not None:
cached_substates = [state]
# Retrieve the parent state to populate event handlers onto this substate.
parent_state = await self.get_state(
token=_substate_key(client_token, parent_state_name),
top_level=False,
get_substates=False,
cached_substates=cached_substates,
)
return parent_state
async def _populate_substates(
self,
token: str,
state: BaseState,
all_substates: bool = False,
):
"""Fetch and link substates for the given state instance.
There is no return value; the side-effect is that `state` will have `substates` populated,
and each substate will have its `parent_state` set to `state`.
Args:
token: The token to get the state for.
state: The state instance to populate substates for.
all_substates: Whether to fetch all substates or just required substates.
"""
client_token, _ = _split_substate_key(token)
if all_substates:
# All substates are requested.
fetch_substates = state.get_substates()
else:
# Only _potentially_dirty_substates need to be fetched to recalc computed vars.
fetch_substates = state._potentially_dirty_substates()
tasks = {}
# Retrieve the necessary substates from redis.
for substate_cls in fetch_substates:
if substate_cls.get_name() in state.substates:
continue
substate_name = substate_cls.get_name()
tasks[substate_name] = asyncio.create_task(
self.get_state(
token=_substate_key(client_token, substate_cls),
top_level=False,
get_substates=all_substates,
parent_state=state,
)
)
for substate_name, substate_task in tasks.items():
state.substates[substate_name] = await substate_task
@override @override
async def get_state( async def get_state(
self, self,
@ -3325,7 +3350,6 @@ class StateManagerRedis(StateManager):
top_level: bool = True, top_level: bool = True,
get_substates: bool = True, get_substates: bool = True,
parent_state: BaseState | None = None, parent_state: BaseState | None = None,
cached_substates: list[BaseState] | None = None,
) -> BaseState: ) -> BaseState:
"""Get the state for a token. """Get the state for a token.
@ -3334,7 +3358,6 @@ class StateManagerRedis(StateManager):
top_level: If true, return an instance of the top-level state (self.state). top_level: If true, return an instance of the top-level state (self.state).
get_substates: If true, also retrieve substates. get_substates: If true, also retrieve substates.
parent_state: If provided, use this parent_state instead of getting it from redis. parent_state: If provided, use this parent_state instead of getting it from redis.
cached_substates: If provided, attach these substates to the state.
Returns: Returns:
The state for the token. The state for the token.
@ -3342,8 +3365,8 @@ class StateManagerRedis(StateManager):
Raises: Raises:
RuntimeError: when the state_cls is not specified in the token RuntimeError: when the state_cls is not specified in the token
""" """
# Split the actual token from the fully qualified substate name. # new impl from top to bottomA
_, state_path = _split_substate_key(token) client_token, state_path = _split_substate_key(token)
if state_path: if state_path:
# Get the State class associated with the given path. # Get the State class associated with the given path.
state_cls = self.state.get_class_substate(state_path) state_cls = self.state.get_class_substate(state_path)
@ -3352,44 +3375,94 @@ class StateManagerRedis(StateManager):
f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}"
) )
# The deserialized or newly created (sub)state instance. state_tokens = {state_path}
state = None
# Fetch the serialized substate from redis. # walk up the state path
redis_state = await self.redis.get(token) walk_state_path = state_path
while "." in walk_state_path:
walk_state_path = walk_state_path.rpartition(".")[0]
state_tokens.add(walk_state_path)
if redis_state is not None: if get_substates:
# Deserialize the substate. state_tokens.update(
with contextlib.suppress(StateSchemaMismatchError): {
state = BaseState._deserialize(data=redis_state) substate.get_full_name()
if state is None: for substate in state_cls.get_all_substate_classes()
# Key didn't exist or schema mismatch so create a new instance for this token. }
state = state_cls(
init_substates=False,
_reflex_internal_init=True,
) )
# Populate parent state if missing and requested. state_tokens.update(
if parent_state is None: self.state._recursive_potentially_dirty_substates(
parent_state = await self._get_parent_state(token, state) already_selected=state_cls,
# Set up Bidirectional linkage between this state and its parent. )
if parent_state is not None: )
parent_state.substates[state.get_name()] = state else:
state.parent_state = parent_state state_tokens.update(self.state._recursive_potentially_dirty_substates())
# Avoid fetching substates multiple times.
if cached_substates: loaded_states = {}
for substate in cached_substates: if parent_state is not None:
state.substates[substate.get_name()] = substate loaded_states = parent_state._get_loaded_states()
if substate.parent_state is None: # remove all states that are already loaded
substate.parent_state = state state_tokens = state_tokens.difference(loaded_states.keys())
# Populate substates if requested.
await self._populate_substates(token, state, all_substates=get_substates) redis_states = await self.hmget(name=client_token, keys=list(state_tokens))
redis_states.update(loaded_states)
root_state = redis_states[self.state.get_full_name()]
self.recursive_link_substates(state=root_state, substates=redis_states)
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level: if top_level:
return state._get_root_state() return root_state
state = redis_states[state_path]
return state return state
def recursive_link_substates(
self,
state: BaseState,
substates: dict[str, BaseState],
):
"""Recursively link substates to a state.
Args:
state: The state to link substates to.
substates: The substates to link.
"""
for substate_cls in state.get_substates():
if substate_cls.get_full_name() not in substates:
continue
substate = substates[substate_cls.get_full_name()]
state.substates[substate.get_name()] = substate
substate.parent_state = state
self.recursive_link_substates(
state=substate,
substates=substates,
)
async def hmget(self, name: str, keys: List[str]) -> dict[str, BaseState]:
"""Get multiple values from a hash.
Args:
name: The name of the hash.
keys: The keys to get.
Returns:
The values.
"""
d = {}
for redis_state in await self.redis.hmget(name=name, keys=keys): # type: ignore
key = keys.pop(0)
state = None
if redis_state is not None:
with contextlib.suppress(StateSchemaMismatchError):
state = BaseState._deserialize(data=redis_state)
if state is None:
state_cls = self.state.get_class_substate(key)
state = state_cls(
init_substates=False,
_reflex_internal_init=True,
)
d[state.get_full_name()] = state
return d
@override @override
async def set_state( async def set_state(
self, self,
@ -3407,6 +3480,7 @@ class StateManagerRedis(StateManager):
Raises: Raises:
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
RuntimeError: If the state instance doesn't match the state name in the token. RuntimeError: If the state instance doesn't match the state name in the token.
ResponseError: If the redis command fails.
""" """
# Check that we're holding the lock. # Check that we're holding the lock.
if ( if (
@ -3436,30 +3510,38 @@ class StateManagerRedis(StateManager):
f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}."
) )
# Recursively set_state on all known substates. redis_hashset = state._serialize_touched_states()
tasks = [
asyncio.create_task(
self.set_state(
_substate_key(client_token, substate),
substate,
lock_id,
)
)
for substate in state.substates.values()
]
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
if state._get_was_touched():
pickle_state = state._serialize()
if pickle_state:
await self.redis.set(
_substate_key(client_token, state),
pickle_state,
ex=self.token_expiration,
)
# Wait for substates to be persisted. if not redis_hashset:
for t in tasks: return
await t
try:
await self._hset_pipeline(client_token, redis_hashset)
except ResponseError as re:
if "unknown command 'HEXPIRE'" not in str(re):
raise
# HEXPIRE not supported, try again with fallback expire.
self._hexpire_not_supported = True
await self._hset_pipeline(client_token, redis_hashset)
async def _hset_pipeline(self, client_token: str, redis_hashset: dict[str, bytes]):
"""Set multiple fields in a hash with expiration.
Args:
client_token: The name of the hash.
redis_hashset: The keys and values to set.
"""
pipe = self.redis.pipeline(transaction=False)
pipe.hset(name=client_token, mapping=redis_hashset)
if self._hexpire_not_supported:
pipe.expire(client_token, self.token_expiration)
else:
pipe.hexpire(
client_token,
self.token_expiration,
*redis_hashset.keys(),
)
await pipe.execute()
@override @override
@contextlib.asynccontextmanager @contextlib.asynccontextmanager

View File

@ -333,10 +333,9 @@ def get_redis() -> Redis | None:
Returns: Returns:
The asynchronous redis client. The asynchronous redis client.
""" """
if isinstance((redis_url_or_options := parse_redis_url()), str): redis_url = parse_redis_url()
return Redis.from_url(redis_url_or_options) if redis_url is not None:
elif isinstance(redis_url_or_options, dict): return Redis.from_url(redis_url)
return Redis(**redis_url_or_options)
return None return None
@ -346,14 +345,13 @@ def get_redis_sync() -> RedisSync | None:
Returns: Returns:
The synchronous redis client. The synchronous redis client.
""" """
if isinstance((redis_url_or_options := parse_redis_url()), str): redis_url = parse_redis_url()
return RedisSync.from_url(redis_url_or_options) if redis_url is not None:
elif isinstance(redis_url_or_options, dict): return RedisSync.from_url(redis_url)
return RedisSync(**redis_url_or_options)
return None return None
def parse_redis_url() -> str | dict | None: def parse_redis_url() -> str | None:
"""Parse the REDIS_URL in config if applicable. """Parse the REDIS_URL in config if applicable.
Returns: Returns:

View File

@ -11,7 +11,6 @@ from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webdriver import WebDriver from selenium.webdriver.remote.webdriver import WebDriver
from reflex.state import ( from reflex.state import (
State,
StateManagerDisk, StateManagerDisk,
StateManagerMemory, StateManagerMemory,
StateManagerRedis, StateManagerRedis,
@ -278,6 +277,7 @@ async def test_client_side_state(
set_sub_sub_state_button.click() set_sub_sub_state_button.click()
token = poll_for_token() token = poll_for_token()
assert token is not None
# get a reference to all cookie and local storage elements # get a reference to all cookie and local storage elements
c1 = driver.find_element(By.ID, "c1") c1 = driver.find_element(By.ID, "c1")
@ -613,16 +613,7 @@ async def test_client_side_state(
# Simulate state expiration # Simulate state expiration
if isinstance(client_side.state_manager, StateManagerRedis): if isinstance(client_side.state_manager, StateManagerRedis):
await client_side.state_manager.redis.delete( await client_side.state_manager.redis.delete(token)
_substate_key(token, State.get_full_name())
)
await client_side.state_manager.redis.delete(_substate_key(token, state_name))
await client_side.state_manager.redis.delete(
_substate_key(token, sub_state_name)
)
await client_side.state_manager.redis.delete(
_substate_key(token, sub_sub_state_name)
)
elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)): elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)):
del client_side.state_manager.states[token] del client_side.state_manager.states[token]
if isinstance(client_side.state_manager, StateManagerDisk): if isinstance(client_side.state_manager, StateManagerDisk):
@ -678,9 +669,8 @@ async def test_client_side_state(
# Get the backend state and ensure the values are still set # Get the backend state and ensure the values are still set
async def get_sub_state(): async def get_sub_state():
root_state = await client_side.get_state( assert token is not None
_substate_key(token or "", sub_state_name) root_state = await client_side.get_state(_substate_key(token, sub_state_name))
)
state = root_state.substates[client_side.get_state_name("_client_side_state")] state = root_state.substates[client_side.get_state_name("_client_side_state")]
sub_state = state.substates[ sub_state = state.substates[
client_side.get_state_name("_client_side_sub_state") client_side.get_state_name("_client_side_sub_state")

View File

@ -31,6 +31,7 @@ from pydantic.v1 import BaseModel as BaseModelV1
import reflex as rx import reflex as rx
import reflex.config import reflex.config
import reflex.utils.console
from reflex import constants from reflex import constants
from reflex.app import App from reflex.app import App
from reflex.base import Base from reflex.base import Base
@ -1857,7 +1858,7 @@ async def test_state_manager_lock_warning_threshold_contend(
substate_token_redis: A token + substate name for looking up in state manager. substate_token_redis: A token + substate name for looking up in state manager.
mocker: Pytest mocker object. mocker: Pytest mocker object.
""" """
console_warn = mocker.patch("reflex.utils.console.warn") console_warn = mocker.spy(reflex.utils.console, "warn")
state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_expiration = LOCK_EXPIRATION
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
@ -1875,7 +1876,7 @@ async def test_state_manager_lock_warning_threshold_contend(
await tasks[0] await tasks[0]
console_warn.assert_called() console_warn.assert_called()
assert console_warn.call_count == 7 assert console_warn.call_count == 1
class CopyingAsyncMock(AsyncMock): class CopyingAsyncMock(AsyncMock):
@ -3192,10 +3193,17 @@ def test_potentially_dirty_substates():
def bar(self) -> str: def bar(self) -> str:
return "" return ""
assert RxState._potentially_dirty_substates() == {State} assert RxState._potentially_dirty_substates() == {State.get_full_name()}
assert State._potentially_dirty_substates() == {C1} assert State._potentially_dirty_substates() == {C1.get_full_name()}
assert C1._potentially_dirty_substates() == set() assert C1._potentially_dirty_substates() == set()
assert RxState._recursive_potentially_dirty_substates() == {
State.get_full_name(),
C1.get_full_name(),
}
assert State._recursive_potentially_dirty_substates() == {C1.get_full_name()}
assert C1._recursive_potentially_dirty_substates() == set()
def test_router_var_dep() -> None: def test_router_var_dep() -> None:
"""Test that router var dependencies are correctly tracked.""" """Test that router var dependencies are correctly tracked."""
@ -3216,7 +3224,9 @@ def test_router_var_dep() -> None:
State._init_var_dependency_dicts() State._init_var_dependency_dicts()
assert foo._deps(objclass=RouterVarDepState) == {"router"} assert foo._deps(objclass=RouterVarDepState) == {"router"}
assert RouterVarParentState._potentially_dirty_substates() == {RouterVarDepState} assert RouterVarParentState._potentially_dirty_substates() == {
RouterVarDepState.get_full_name()
}
assert RouterVarParentState._substate_var_dependencies == { assert RouterVarParentState._substate_var_dependencies == {
"router": {RouterVarDepState.get_name()} "router": {RouterVarDepState.get_name()}
} }

View File

@ -354,11 +354,11 @@ async def state_manager_redis(
], ],
) )
async def test_get_state_tree( async def test_get_state_tree(
state_manager_redis, state_manager_redis: StateManagerRedis,
token, token: str,
substate_cls, substate_cls: type[BaseState],
exp_root_substates, exp_root_substates: list[str],
exp_root_dict_keys, exp_root_dict_keys: list[str],
): ):
"""Test getting state trees and assert on which branches are retrieved. """Test getting state trees and assert on which branches are retrieved.