[REF-1885] Shard Substates when serializing to Redis (#2574)
* Move sharding internal to StateManager Avoid leaking sharding implementation details all over the State class and breaking the API * WiP StateManager based sharding * Copy the state __dict__ when serializing to avoid breaking the instance * State tests need to pass the correct substate token for redis * state: when getting parent_state, set top_level=False ensure that we don't end up with a broken tree * test_app: get tests passing with redis by passing the correct token refactor upload tests to suck less * test_client_storage: look up substate key * state.py: pass static checks * test_dynamic_routes: working with redis state shard * Update the remaining AppHarness tests to pass {token}_{state.get_full_name()} * test_app: pass all tokens with state suffix * StateManagerRedis: clean up commentary
This commit is contained in:
parent
f9d219407f
commit
756bf9b0f4
@ -449,7 +449,7 @@ async def test_client_side_state(
|
||||
assert l1s.text == "l1s value"
|
||||
|
||||
# reset the backend state to force refresh from client storage
|
||||
async with client_side.modify_state(token) as state:
|
||||
async with client_side.modify_state(f"{token}_state.client_side_state") as state:
|
||||
state.reset()
|
||||
driver.refresh()
|
||||
|
||||
|
@ -85,6 +85,7 @@ def dynamic_route(
|
||||
"""
|
||||
with app_harness_env.create(
|
||||
root=tmp_path_factory.mktemp(f"dynamic_route"),
|
||||
app_name=f"dynamicroute_{app_harness_env.__name__.lower()}",
|
||||
app_source=DynamicRoute, # type: ignore
|
||||
) as harness:
|
||||
yield harness
|
||||
@ -146,7 +147,7 @@ def poll_for_order(
|
||||
|
||||
async def _poll_for_order(exp_order: list[str]):
|
||||
async def _backend_state():
|
||||
return await dynamic_route.get_state(token)
|
||||
return await dynamic_route.get_state(f"{token}_state.dynamic_state")
|
||||
|
||||
async def _check():
|
||||
return (await _backend_state()).substates[
|
||||
@ -194,7 +195,9 @@ async def test_on_load_navigate(
|
||||
assert link
|
||||
assert page_id_input
|
||||
|
||||
assert dynamic_route.poll_for_value(page_id_input) == str(ix)
|
||||
assert dynamic_route.poll_for_value(
|
||||
page_id_input, exp_not_equal=str(ix - 1)
|
||||
) == str(ix)
|
||||
assert dynamic_route.poll_for_value(raw_path_input) == f"/page/{ix}/"
|
||||
await poll_for_order(exp_order)
|
||||
|
||||
@ -220,7 +223,9 @@ async def test_on_load_navigate(
|
||||
with poll_for_navigation(driver):
|
||||
driver.get(f"{driver.current_url}?foo=bar")
|
||||
await poll_for_order(exp_order)
|
||||
assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar"
|
||||
assert (
|
||||
await dynamic_route.get_state(f"{token}_state.dynamic_state")
|
||||
).router.page.params["foo"] == "bar"
|
||||
|
||||
# hit a 404 and ensure we still hydrate
|
||||
exp_order += ["/404-no page id"]
|
||||
|
@ -207,7 +207,7 @@ def poll_for_order(
|
||||
|
||||
async def _poll_for_order(exp_order: list[str]):
|
||||
async def _backend_state():
|
||||
return await event_action.get_state(token)
|
||||
return await event_action.get_state(f"{token}_state.event_action_state")
|
||||
|
||||
async def _check():
|
||||
return (await _backend_state()).substates[
|
||||
|
@ -298,7 +298,7 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str:
|
||||
token = event_chain.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
return token
|
||||
return f"{token}_state.state"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -221,7 +221,11 @@ async def test_submit(driver, form_submit: AppHarness):
|
||||
submit_input.click()
|
||||
|
||||
async def get_form_data():
|
||||
return (await form_submit.get_state(token)).substates["form_state"].form_data
|
||||
return (
|
||||
(await form_submit.get_state(f"{token}_state.form_state"))
|
||||
.substates["form_state"]
|
||||
.form_data
|
||||
)
|
||||
|
||||
# wait for the form data to arrive at the backend
|
||||
form_data = await AppHarness._poll_for_async(get_form_data)
|
||||
|
@ -76,6 +76,10 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
token = fully_controlled_input.poll_for_value(token_input)
|
||||
assert token
|
||||
|
||||
async def get_state_text():
|
||||
state = await fully_controlled_input.get_state(f"{token}_state.state")
|
||||
return state.substates["state"].text
|
||||
|
||||
# find the input and wait for it to have the initial state value
|
||||
debounce_input = driver.find_element(By.ID, "debounce_input_input")
|
||||
value_input = driver.find_element(By.ID, "value_input")
|
||||
@ -95,16 +99,14 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
debounce_input.send_keys("foo")
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "ifoonitial"
|
||||
assert (await fully_controlled_input.get_state(token)).substates[
|
||||
"state"
|
||||
].text == "ifoonitial"
|
||||
assert await get_state_text() == "ifoonitial"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
|
||||
assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial"
|
||||
|
||||
# clear the input on the backend
|
||||
async with fully_controlled_input.modify_state(token) as state:
|
||||
async with fully_controlled_input.modify_state(f"{token}_state.state") as state:
|
||||
state.substates["state"].text = ""
|
||||
assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
|
||||
assert await get_state_text() == ""
|
||||
assert (
|
||||
fully_controlled_input.poll_for_value(
|
||||
debounce_input, exp_not_equal="ifoonitial"
|
||||
@ -116,9 +118,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
debounce_input.send_keys("getting testing done")
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "getting testing done"
|
||||
assert (await fully_controlled_input.get_state(token)).substates[
|
||||
"state"
|
||||
].text == "getting testing done"
|
||||
assert await get_state_text() == "getting testing done"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
|
||||
assert (
|
||||
fully_controlled_input.poll_for_value(plain_value_input)
|
||||
@ -130,9 +130,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "overwrite the state"
|
||||
assert on_change_input.get_attribute("value") == "overwrite the state"
|
||||
assert (await fully_controlled_input.get_state(token)).substates[
|
||||
"state"
|
||||
].text == "overwrite the state"
|
||||
assert await get_state_text() == "overwrite the state"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
|
||||
assert (
|
||||
fully_controlled_input.poll_for_value(plain_value_input)
|
||||
|
@ -171,6 +171,7 @@ async def test_upload_file(
|
||||
# wait for the backend connection to send the token
|
||||
token = upload_file.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
substate_token = f"{token}_state.upload_state"
|
||||
|
||||
suffix = "_secondary" if secondary else ""
|
||||
|
||||
@ -191,7 +192,11 @@ async def test_upload_file(
|
||||
|
||||
# look up the backend state and assert on uploaded contents
|
||||
async def get_file_data():
|
||||
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
|
||||
return (
|
||||
(await upload_file.get_state(substate_token))
|
||||
.substates["upload_state"]
|
||||
._file_data
|
||||
)
|
||||
|
||||
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||
assert isinstance(file_data, dict)
|
||||
@ -201,7 +206,7 @@ async def test_upload_file(
|
||||
selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
|
||||
assert selected_files.text == exp_name
|
||||
|
||||
state = await upload_file.get_state(token)
|
||||
state = await upload_file.get_state(substate_token)
|
||||
if secondary:
|
||||
# only the secondary form tracks progress and chain events
|
||||
assert state.substates["upload_state"].event_order.count("upload_progress") == 1
|
||||
@ -223,6 +228,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
||||
# wait for the backend connection to send the token
|
||||
token = upload_file.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
substate_token = f"{token}_state.upload_state"
|
||||
|
||||
upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
|
||||
assert upload_box
|
||||
@ -250,7 +256,11 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
||||
|
||||
# look up the backend state and assert on uploaded contents
|
||||
async def get_file_data():
|
||||
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
|
||||
return (
|
||||
(await upload_file.get_state(substate_token))
|
||||
.substates["upload_state"]
|
||||
._file_data
|
||||
)
|
||||
|
||||
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||
assert isinstance(file_data, dict)
|
||||
@ -330,6 +340,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
|
||||
# wait for the backend connection to send the token
|
||||
token = upload_file.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
substate_token = f"{token}_state.upload_state"
|
||||
|
||||
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
|
||||
upload_button = driver.find_element(By.ID, f"upload_button_secondary")
|
||||
@ -347,7 +358,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
|
||||
cancel_button.click()
|
||||
|
||||
# look up the backend state and assert on progress
|
||||
state = await upload_file.get_state(token)
|
||||
state = await upload_file.get_state(substate_token)
|
||||
assert state.substates["upload_state"].progress_dicts
|
||||
assert exp_name not in state.substates["upload_state"]._file_data
|
||||
|
||||
|
@ -926,7 +926,7 @@ async def process(
|
||||
}
|
||||
)
|
||||
# Get the state for the session exclusively.
|
||||
async with app.state_manager.modify_state(event.token) as state:
|
||||
async with app.state_manager.modify_state(event.substate_token) as state:
|
||||
# re-assign only when the value is different
|
||||
if state.router_data != router_data:
|
||||
# assignment will recurse into substates and force recalculation of
|
||||
@ -1002,7 +1002,8 @@ def upload(app: App):
|
||||
)
|
||||
|
||||
# Get the state for the session.
|
||||
state = await app.state_manager.get_state(token)
|
||||
substate_token = token + "_" + handler.rpartition(".")[0]
|
||||
state = await app.state_manager.get_state(substate_token)
|
||||
|
||||
# get the current session ID
|
||||
# get the current state(parent state/substate)
|
||||
@ -1049,7 +1050,7 @@ def upload(app: App):
|
||||
Each state update as JSON followed by a new line.
|
||||
"""
|
||||
# Process the event.
|
||||
async with app.state_manager.modify_state(token) as state:
|
||||
async with app.state_manager.modify_state(event.substate_token) as state:
|
||||
async for update in state._process(event):
|
||||
# Postprocess the event.
|
||||
update = await app.postprocess(state, event, update)
|
||||
|
@ -41,6 +41,16 @@ class Event(Base):
|
||||
# The event payload.
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def substate_token(self) -> str:
|
||||
"""Get the substate token for the event.
|
||||
|
||||
Returns:
|
||||
The substate token.
|
||||
"""
|
||||
substate = self.name.rpartition(".")[0]
|
||||
return f"{self.token}_{substate}"
|
||||
|
||||
|
||||
BACKGROUND_TASK_MARKER = "_reflex_background_task"
|
||||
|
||||
|
166
reflex/state.py
166
reflex/state.py
@ -213,19 +213,27 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# The router data for the current page
|
||||
router: RouterData = RouterData()
|
||||
|
||||
def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
parent_state: BaseState | None = None,
|
||||
init_substates: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the state.
|
||||
|
||||
Args:
|
||||
*args: The args to pass to the Pydantic init method.
|
||||
parent_state: The parent state.
|
||||
init_substates: Whether to initialize the substates in this instance.
|
||||
**kwargs: The kwargs to pass to the Pydantic init method.
|
||||
|
||||
"""
|
||||
kwargs["parent_state"] = parent_state
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Setup the substates.
|
||||
# Setup the substates (for memory state manager only).
|
||||
if init_substates:
|
||||
for substate in self.get_substates():
|
||||
self.substates[substate.get_name()] = substate(parent_state=self)
|
||||
# Convert the event handlers to functions.
|
||||
@ -1005,7 +1013,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for substate in self.substates.values():
|
||||
substate._reset_client_storage()
|
||||
|
||||
def get_substate(self, path: Sequence[str]) -> BaseState | None:
|
||||
def get_substate(self, path: Sequence[str]) -> BaseState:
|
||||
"""Get the substate.
|
||||
|
||||
Args:
|
||||
@ -1260,6 +1268,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Recursively find the substate deltas.
|
||||
substates = self.substates
|
||||
for substate in self.dirty_substates.union(self._always_dirty_substates):
|
||||
if substate not in substates:
|
||||
continue # substate not loaded at this time, no delta
|
||||
delta.update(substates[substate].get_delta())
|
||||
|
||||
# Format the delta.
|
||||
@ -1287,6 +1297,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
for var in self.dirty_vars:
|
||||
for substate_name in self._substate_var_dependencies[var]:
|
||||
self.dirty_substates.add(substate_name)
|
||||
if substate_name not in substates:
|
||||
continue
|
||||
substate = substates[substate_name]
|
||||
substate.dirty_vars.add(var)
|
||||
substate._mark_dirty()
|
||||
@ -1295,6 +1307,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""Reset the dirty vars."""
|
||||
# Recursively clean the substates.
|
||||
for substate in self.dirty_substates:
|
||||
if substate not in self.substates:
|
||||
continue
|
||||
self.substates[substate]._clean()
|
||||
|
||||
# Clean this state.
|
||||
@ -1380,6 +1394,24 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
"""
|
||||
pass
|
||||
|
||||
def __getstate__(self):
|
||||
"""Get the state for redis serialization.
|
||||
|
||||
This method is called by cloudpickle to serialize the object.
|
||||
|
||||
It explicitly removes parent_state and substates because those are serialized separately
|
||||
by the StateManagerRedis to allow for better horizontal scaling as state size increases.
|
||||
|
||||
Returns:
|
||||
The state dict for serialization.
|
||||
"""
|
||||
state = super().__getstate__()
|
||||
# Never serialize parent_state or substates
|
||||
state["__dict__"] = state["__dict__"].copy()
|
||||
state["__dict__"]["parent_state"] = None
|
||||
state["__dict__"]["substates"] = {}
|
||||
return state
|
||||
|
||||
|
||||
class State(BaseState):
|
||||
"""The app Base State."""
|
||||
@ -1479,6 +1511,8 @@ class StateProxy(wrapt.ObjectProxy):
|
||||
"""
|
||||
self._self_actx = self._self_app.modify_state(
|
||||
self.__wrapped__.router.session.client_token
|
||||
+ "_"
|
||||
+ ".".join(self._self_substate_path)
|
||||
)
|
||||
mutable_state = await self._self_actx.__aenter__()
|
||||
super().__setattr__(
|
||||
@ -1675,6 +1709,8 @@ class StateManagerMemory(StateManager):
|
||||
Returns:
|
||||
The state for the token.
|
||||
"""
|
||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
||||
token = token.partition("_")[0]
|
||||
if token not in self.states:
|
||||
self.states[token] = self.state()
|
||||
return self.states[token]
|
||||
@ -1698,6 +1734,8 @@ class StateManagerMemory(StateManager):
|
||||
Yields:
|
||||
The state for the token.
|
||||
"""
|
||||
# Memory state manager ignores the substate suffix and always returns the top-level state.
|
||||
token = token.partition("_")[0]
|
||||
if token not in self._states_locks:
|
||||
async with self._state_manager_lock:
|
||||
if token not in self._states_locks:
|
||||
@ -1737,23 +1775,104 @@ class StateManagerRedis(StateManager):
|
||||
b"evicted",
|
||||
}
|
||||
|
||||
async def get_state(self, token: str) -> BaseState:
|
||||
async def get_state(
|
||||
self,
|
||||
token: str,
|
||||
top_level: bool = True,
|
||||
get_substates: bool = True,
|
||||
parent_state: BaseState | None = None,
|
||||
) -> BaseState:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for.
|
||||
top_level: If true, return an instance of the top-level state.
|
||||
get_substates: If true, also retrieve substates
|
||||
parent_state: If provided, use this parent_state instead of getting it from redis.
|
||||
|
||||
Returns:
|
||||
The state for the token.
|
||||
|
||||
Raises:
|
||||
RuntimeError: when the state_cls is not specified in the token
|
||||
"""
|
||||
# Split the actual token from the fully qualified substate name.
|
||||
client_token, _, state_path = token.partition("_")
|
||||
if state_path:
|
||||
# Get the State class associated with the given path.
|
||||
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
|
||||
)
|
||||
|
||||
# Fetch the serialized substate from redis.
|
||||
redis_state = await self.redis.get(token)
|
||||
if redis_state is None:
|
||||
await self.set_state(token, self.state())
|
||||
return await self.get_state(token)
|
||||
return cloudpickle.loads(redis_state)
|
||||
|
||||
if redis_state is not None:
|
||||
# Deserialize the substate.
|
||||
state = cloudpickle.loads(redis_state)
|
||||
|
||||
# Populate parent and substates if requested.
|
||||
if parent_state is None:
|
||||
# Retrieve the parent state from redis.
|
||||
parent_state_name = state_path.rpartition(".")[0]
|
||||
if parent_state_name:
|
||||
parent_state_key = token.rpartition(".")[0]
|
||||
parent_state = await self.get_state(
|
||||
parent_state_key, top_level=False, get_substates=False
|
||||
)
|
||||
# Set up Bidirectional linkage between this state and its parent.
|
||||
if parent_state is not None:
|
||||
parent_state.substates[state.get_name()] = state
|
||||
state.parent_state = parent_state
|
||||
if get_substates:
|
||||
# Retrieve all substates from redis.
|
||||
for substate_cls in state_cls.get_substates():
|
||||
substate_name = substate_cls.get_name()
|
||||
substate_key = token + "." + substate_name
|
||||
state.substates[substate_name] = await self.get_state(
|
||||
substate_key, top_level=False, parent_state=state
|
||||
)
|
||||
# To retain compatibility with previous implementation, by default, we return
|
||||
# the top-level state by chasing `parent_state` pointers up the tree.
|
||||
if top_level:
|
||||
while type(state) != self.state and state.parent_state is not None:
|
||||
state = state.parent_state
|
||||
return state
|
||||
|
||||
# Key didn't exist so we have to create a new entry for this token.
|
||||
if parent_state is None:
|
||||
parent_state_name = state_path.rpartition(".")[0]
|
||||
if parent_state_name:
|
||||
# Retrieve the parent state to populate event handlers onto this substate.
|
||||
parent_state_key = client_token + "_" + parent_state_name
|
||||
parent_state = await self.get_state(
|
||||
parent_state_key, top_level=False, get_substates=False
|
||||
)
|
||||
# Persist the new state class to redis.
|
||||
await self.set_state(
|
||||
token,
|
||||
state_cls(
|
||||
parent_state=parent_state,
|
||||
init_substates=False,
|
||||
),
|
||||
)
|
||||
# After creating the state key, recursively call `get_state` to populate substates.
|
||||
return await self.get_state(
|
||||
token,
|
||||
top_level=top_level,
|
||||
get_substates=get_substates,
|
||||
parent_state=parent_state,
|
||||
)
|
||||
|
||||
async def set_state(
|
||||
self, token: str, state: BaseState, lock_id: bytes | None = None
|
||||
self,
|
||||
token: str,
|
||||
state: BaseState,
|
||||
lock_id: bytes | None = None,
|
||||
set_substates: bool = True,
|
||||
set_parent_state: bool = True,
|
||||
):
|
||||
"""Set the state for a token.
|
||||
|
||||
@ -1761,11 +1880,13 @@ class StateManagerRedis(StateManager):
|
||||
token: The token to set the state for.
|
||||
state: The state to set.
|
||||
lock_id: If provided, the lock_key must be set to this value to set the state.
|
||||
set_substates: If True, write substates to redis
|
||||
set_parent_state: If True, write parent state to redis
|
||||
|
||||
Raises:
|
||||
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
||||
"""
|
||||
# check that we're holding the lock
|
||||
# Check that we're holding the lock.
|
||||
if (
|
||||
lock_id is not None
|
||||
and await self.redis.get(self._lock_key(token)) != lock_id
|
||||
@ -1775,6 +1896,27 @@ class StateManagerRedis(StateManager):
|
||||
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
||||
"or use `@rx.background` decorator for long-running tasks."
|
||||
)
|
||||
# Find the substate associated with the token.
|
||||
state_path = token.partition("_")[2]
|
||||
if state_path and state.get_full_name() != state_path:
|
||||
state = state.get_substate(tuple(state_path.split(".")))
|
||||
# Persist the parent state separately, if requested.
|
||||
if state.parent_state is not None and set_parent_state:
|
||||
parent_state_key = token.rpartition(".")[0]
|
||||
await self.set_state(
|
||||
parent_state_key,
|
||||
state.parent_state,
|
||||
lock_id=lock_id,
|
||||
set_substates=False,
|
||||
)
|
||||
# Persist the substates separately, if requested.
|
||||
if set_substates:
|
||||
for substate_name, substate in state.substates.items():
|
||||
substate_key = token + "." + substate_name
|
||||
await self.set_state(
|
||||
substate_key, substate, lock_id=lock_id, set_parent_state=False
|
||||
)
|
||||
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
|
||||
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
@ -1802,7 +1944,9 @@ class StateManagerRedis(StateManager):
|
||||
Returns:
|
||||
The redis lock key for the token.
|
||||
"""
|
||||
return f"{token}_lock".encode()
|
||||
# All substates share the same lock domain, so ignore any substate path suffix.
|
||||
client_token = token.partition("_")[0]
|
||||
return f"{client_token}_lock".encode()
|
||||
|
||||
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
||||
"""Try to get a redis lock for a token.
|
||||
|
@ -220,6 +220,7 @@ class AppHarness:
|
||||
reflex.config.get_config(reload=True)
|
||||
# reset rx.State subclasses
|
||||
State.class_subclasses.clear()
|
||||
State.get_class_substate.cache_clear()
|
||||
# Ensure the AppHarness test does not skip State assignment due to running via pytest
|
||||
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
|
||||
# self.app_module.app.
|
||||
|
@ -340,7 +340,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
|
||||
assert app.state == test_state
|
||||
|
||||
# Get a state for a given token.
|
||||
state = await app.state_manager.get_state(token)
|
||||
state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}")
|
||||
assert isinstance(state, test_state)
|
||||
assert state.var == 0 # type: ignore
|
||||
|
||||
@ -358,8 +358,8 @@ async def test_set_and_get_state(test_state):
|
||||
app = App(state=test_state)
|
||||
|
||||
# Create two tokens.
|
||||
token1 = str(uuid.uuid4())
|
||||
token2 = str(uuid.uuid4())
|
||||
token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
|
||||
token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
|
||||
|
||||
# Get the default state for each token.
|
||||
state1 = await app.state_manager.get_state(token1)
|
||||
@ -744,18 +744,18 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
|
||||
# The App state must be the "root" of the state tree
|
||||
app = App(state=State)
|
||||
app.event_namespace.emit = AsyncMock() # type: ignore
|
||||
current_state = await app.state_manager.get_state(token)
|
||||
substate_token = f"{token}_{state.get_full_name()}"
|
||||
current_state = await app.state_manager.get_state(substate_token)
|
||||
data = b"This is binary data"
|
||||
|
||||
# Create a binary IO object and write data to it
|
||||
bio = io.BytesIO()
|
||||
bio.write(data)
|
||||
|
||||
state_name = state.get_full_name().partition(".")[2] or state.get_name()
|
||||
request_mock = unittest.mock.Mock()
|
||||
request_mock.headers = {
|
||||
"reflex-client-token": token,
|
||||
"reflex-event-handler": f"state.{state_name}.multi_handle_upload",
|
||||
"reflex-event-handler": f"{state.get_full_name()}.multi_handle_upload",
|
||||
}
|
||||
|
||||
file1 = UploadFile(
|
||||
@ -774,7 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
|
||||
== StateUpdate(delta=delta, events=[], final=True).json() + "\n"
|
||||
)
|
||||
|
||||
current_state = await app.state_manager.get_state(token)
|
||||
current_state = await app.state_manager.get_state(substate_token)
|
||||
state_dict = current_state.dict()[state.get_full_name()]
|
||||
assert state_dict["img_list"] == [
|
||||
"image1.jpg",
|
||||
@ -799,14 +799,12 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
|
||||
token: a Token.
|
||||
"""
|
||||
state._tmp_path = tmp_path
|
||||
# The App state must be the "root" of the state tree
|
||||
app = App(state=state if state is FileUploadState else FileStateBase1)
|
||||
app = App(state=State)
|
||||
|
||||
state_name = state.get_full_name().partition(".")[2] or state.get_name()
|
||||
request_mock = unittest.mock.Mock()
|
||||
request_mock.headers = {
|
||||
"reflex-client-token": token,
|
||||
"reflex-event-handler": f"{state_name}.handle_upload2",
|
||||
"reflex-event-handler": f"{state.get_full_name()}.handle_upload2",
|
||||
}
|
||||
file_mock = unittest.mock.Mock(filename="image1.jpg")
|
||||
fn = upload(app)
|
||||
@ -814,7 +812,7 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
|
||||
await fn(request_mock, [file_mock])
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
|
||||
== f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
|
||||
)
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
@ -835,14 +833,12 @@ async def test_upload_file_background(state, tmp_path, token):
|
||||
token: a Token.
|
||||
"""
|
||||
state._tmp_path = tmp_path
|
||||
# The App state must be the "root" of the state tree
|
||||
app = App(state=state if state is FileUploadState else FileStateBase1)
|
||||
app = App(state=State)
|
||||
|
||||
state_name = state.get_full_name().partition(".")[2] or state.get_name()
|
||||
request_mock = unittest.mock.Mock()
|
||||
request_mock.headers = {
|
||||
"reflex-client-token": token,
|
||||
"reflex-event-handler": f"{state_name}.bg_upload",
|
||||
"reflex-event-handler": f"{state.get_full_name()}.bg_upload",
|
||||
}
|
||||
file_mock = unittest.mock.Mock(filename="image1.jpg")
|
||||
fn = upload(app)
|
||||
@ -850,7 +846,7 @@ async def test_upload_file_background(state, tmp_path, token):
|
||||
await fn(request_mock, [file_mock])
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== f"@rx.background is not supported for upload handler `{state_name}.bg_upload`."
|
||||
== f"@rx.background is not supported for upload handler `{state.get_full_name()}.bg_upload`."
|
||||
)
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
@ -932,9 +928,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
}
|
||||
assert constants.ROUTER in app.state()._computed_var_dependencies
|
||||
|
||||
substate_token = f"{token}_{DynamicState.get_full_name()}"
|
||||
sid = "mock_sid"
|
||||
client_ip = "127.0.0.1"
|
||||
state = await app.state_manager.get_state(token)
|
||||
state = await app.state_manager.get_state(substate_token)
|
||||
assert state.dynamic == ""
|
||||
exp_vals = ["foo", "foobar", "baz"]
|
||||
|
||||
@ -1004,7 +1001,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
)
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
# When redis is used, the state is not updated until the processing is complete
|
||||
state = await app.state_manager.get_state(token)
|
||||
state = await app.state_manager.get_state(substate_token)
|
||||
assert state.dynamic == prev_exp_val
|
||||
|
||||
# complete the processing
|
||||
@ -1012,7 +1009,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
await process_coro.__anext__()
|
||||
|
||||
# check that router data was written to the state_manager store
|
||||
state = await app.state_manager.get_state(token)
|
||||
state = await app.state_manager.get_state(substate_token)
|
||||
assert state.dynamic == exp_val
|
||||
|
||||
process_coro = process(
|
||||
@ -1087,7 +1084,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
await process_coro.__anext__()
|
||||
|
||||
prev_exp_val = exp_val
|
||||
state = await app.state_manager.get_state(token)
|
||||
state = await app.state_manager.get_state(substate_token)
|
||||
assert state.loaded == len(exp_vals)
|
||||
assert state.counter == len(exp_vals)
|
||||
# print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
|
||||
@ -1124,7 +1121,7 @@ async def test_process_events(mocker, token: str):
|
||||
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
|
||||
pass
|
||||
|
||||
assert (await app.state_manager.get_state(token)).value == 5
|
||||
assert (await app.state_manager.get_state(event.substate_token)).value == 5
|
||||
assert app.postprocess.call_count == 6
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
|
@ -38,8 +38,8 @@ from reflex.vars import BaseVar, ComputedVar
|
||||
from .states import GenState
|
||||
|
||||
CI = bool(os.environ.get("CI", False))
|
||||
LOCK_EXPIRATION = 2000 if CI else 100
|
||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
|
||||
LOCK_EXPIRATION = 2000 if CI else 300
|
||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
|
||||
|
||||
|
||||
formatted_router = {
|
||||
@ -1432,15 +1432,32 @@ def state_manager(request) -> Generator[StateManager, None, None]:
|
||||
asyncio.get_event_loop().run_until_complete(state_manager.close())
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def substate_token(state_manager, token):
|
||||
"""A token + substate name for looking up in state manager.
|
||||
|
||||
Args:
|
||||
state_manager: A state manager instance.
|
||||
token: A token.
|
||||
|
||||
Returns:
|
||||
Token concatenated with the state_manager's state full_name.
|
||||
"""
|
||||
return f"{token}_{state_manager.state.get_full_name()}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_modify_state(state_manager: StateManager, token: str):
|
||||
async def test_state_manager_modify_state(
|
||||
state_manager: StateManager, token: str, substate_token: str
|
||||
):
|
||||
"""Test that the state manager can modify a state exclusively.
|
||||
|
||||
Args:
|
||||
state_manager: A state manager instance.
|
||||
token: A token.
|
||||
substate_token: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
async with state_manager.modify_state(token):
|
||||
async with state_manager.modify_state(substate_token):
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert await state_manager.redis.get(f"{token}_lock")
|
||||
elif isinstance(state_manager, StateManagerMemory):
|
||||
@ -1461,21 +1478,24 @@ async def test_state_manager_modify_state(state_manager: StateManager, token: st
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_contend(state_manager: StateManager, token: str):
|
||||
async def test_state_manager_contend(
|
||||
state_manager: StateManager, token: str, substate_token: str
|
||||
):
|
||||
"""Multiple coroutines attempting to access the same state.
|
||||
|
||||
Args:
|
||||
state_manager: A state manager instance.
|
||||
token: A token.
|
||||
substate_token: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
n_coroutines = 10
|
||||
exp_num1 = 10
|
||||
|
||||
async with state_manager.modify_state(token) as state:
|
||||
async with state_manager.modify_state(substate_token) as state:
|
||||
state.num1 = 0
|
||||
|
||||
async def _coro():
|
||||
async with state_manager.modify_state(token) as state:
|
||||
async with state_manager.modify_state(substate_token) as state:
|
||||
await asyncio.sleep(0.01)
|
||||
state.num1 += 1
|
||||
|
||||
@ -1484,7 +1504,7 @@ async def test_state_manager_contend(state_manager: StateManager, token: str):
|
||||
for f in asyncio.as_completed(tasks):
|
||||
await f
|
||||
|
||||
assert (await state_manager.get_state(token)).num1 == exp_num1
|
||||
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
|
||||
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
@ -1510,33 +1530,51 @@ def state_manager_redis() -> Generator[StateManager, None, None]:
|
||||
asyncio.get_event_loop().run_until_complete(state_manager.close())
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def substate_token_redis(state_manager_redis, token):
|
||||
"""A token + substate name for looking up in state manager.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
token: A token.
|
||||
|
||||
Returns:
|
||||
Token concatenated with the state_manager's state full_name.
|
||||
"""
|
||||
return f"{token}_{state_manager_redis.state.get_full_name()}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
|
||||
async def test_state_manager_lock_expire(
|
||||
state_manager_redis: StateManager, token: str, substate_token_redis: str
|
||||
):
|
||||
"""Test that the state manager lock expires and raises exception exiting context.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
token: A token.
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
|
||||
async with state_manager_redis.modify_state(token):
|
||||
async with state_manager_redis.modify_state(substate_token_redis):
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
with pytest.raises(LockExpiredError):
|
||||
async with state_manager_redis.modify_state(token):
|
||||
async with state_manager_redis.modify_state(substate_token_redis):
|
||||
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_lock_expire_contend(
|
||||
state_manager_redis: StateManager, token: str
|
||||
state_manager_redis: StateManager, token: str, substate_token_redis: str
|
||||
):
|
||||
"""Test that the state manager lock expires and queued waiters proceed.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
token: A token.
|
||||
substate_token_redis: A token + substate name for looking up in state manager.
|
||||
"""
|
||||
exp_num1 = 4252
|
||||
unexp_num1 = 666
|
||||
@ -1546,7 +1584,7 @@ async def test_state_manager_lock_expire_contend(
|
||||
order = []
|
||||
|
||||
async def _coro_blocker():
|
||||
async with state_manager_redis.modify_state(token) as state:
|
||||
async with state_manager_redis.modify_state(substate_token_redis) as state:
|
||||
order.append("blocker")
|
||||
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
|
||||
state.num1 = unexp_num1
|
||||
@ -1554,7 +1592,7 @@ async def test_state_manager_lock_expire_contend(
|
||||
async def _coro_waiter():
|
||||
while "blocker" not in order:
|
||||
await asyncio.sleep(0.005)
|
||||
async with state_manager_redis.modify_state(token) as state:
|
||||
async with state_manager_redis.modify_state(substate_token_redis) as state:
|
||||
order.append("waiter")
|
||||
assert state.num1 != unexp_num1
|
||||
state.num1 = exp_num1
|
||||
@ -1568,7 +1606,7 @@ async def test_state_manager_lock_expire_contend(
|
||||
await tasks[1]
|
||||
|
||||
assert order == ["blocker", "waiter"]
|
||||
assert (await state_manager_redis.get_state(token)).num1 == exp_num1
|
||||
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@ -1643,7 +1681,8 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
assert sp.value2 == 42
|
||||
|
||||
# Get the state from the state manager directly and check that the value is updated
|
||||
gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
|
||||
gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}"
|
||||
gotten_state = await mock_app.state_manager.get_state(gc_token)
|
||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||
# For in-process store, only one instance of the state exists
|
||||
assert gotten_state is parent_state
|
||||
@ -1836,7 +1875,8 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||
"private",
|
||||
]
|
||||
|
||||
assert (await mock_app.state_manager.get_state(token)).order == exp_order
|
||||
substate_token = f"{token}_{BackgroundTaskState.get_name()}"
|
||||
assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order
|
||||
|
||||
assert mock_app.event_namespace is not None
|
||||
emit_mock = mock_app.event_namespace.emit
|
||||
@ -1913,7 +1953,8 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
|
||||
await task
|
||||
assert not mock_app.background_tasks
|
||||
|
||||
assert (await mock_app.state_manager.get_state(token)).order == [
|
||||
substate_token = f"{token}_{BackgroundTaskState.get_name()}"
|
||||
assert (await mock_app.state_manager.get_state(substate_token)).order == [
|
||||
"reset",
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user