[REF-1885] Shard Substates when serializing to Redis ()

* Move sharding internal to StateManager

Avoid leaking sharding implementation details all over the State class and
breaking the API

* WiP StateManager based sharding

* Copy the state __dict__ when serializing to avoid breaking the instance

* State tests need to pass the correct substate token for redis

* state: when getting parent_state, set top_level=False

ensure that we don't end up with a broken tree

* test_app: get tests passing with redis by passing the correct token

refactor upload tests to suck less

* test_client_storage: look up substate key

* state.py: pass static checks

* test_dynamic_routes: working with redis state shard

* Update the remaining AppHarness tests to pass {token}_{state.get_full_name()}

* test_app: pass all tokens with state suffix

* StateManagerRedis: clean up commentary
This commit is contained in:
Masen Furer 2024-02-21 01:50:25 -08:00 committed by GitHub
parent f9d219407f
commit 756bf9b0f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 290 additions and 78 deletions

View File

@ -449,7 +449,7 @@ async def test_client_side_state(
assert l1s.text == "l1s value"
# reset the backend state to force refresh from client storage
async with client_side.modify_state(token) as state:
async with client_side.modify_state(f"{token}_state.client_side_state") as state:
state.reset()
driver.refresh()

View File

@ -85,6 +85,7 @@ def dynamic_route(
"""
with app_harness_env.create(
root=tmp_path_factory.mktemp(f"dynamic_route"),
app_name=f"dynamicroute_{app_harness_env.__name__.lower()}",
app_source=DynamicRoute, # type: ignore
) as harness:
yield harness
@ -146,7 +147,7 @@ def poll_for_order(
async def _poll_for_order(exp_order: list[str]):
async def _backend_state():
return await dynamic_route.get_state(token)
return await dynamic_route.get_state(f"{token}_state.dynamic_state")
async def _check():
return (await _backend_state()).substates[
@ -194,7 +195,9 @@ async def test_on_load_navigate(
assert link
assert page_id_input
assert dynamic_route.poll_for_value(page_id_input) == str(ix)
assert dynamic_route.poll_for_value(
page_id_input, exp_not_equal=str(ix - 1)
) == str(ix)
assert dynamic_route.poll_for_value(raw_path_input) == f"/page/{ix}/"
await poll_for_order(exp_order)
@ -220,7 +223,9 @@ async def test_on_load_navigate(
with poll_for_navigation(driver):
driver.get(f"{driver.current_url}?foo=bar")
await poll_for_order(exp_order)
assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar"
assert (
await dynamic_route.get_state(f"{token}_state.dynamic_state")
).router.page.params["foo"] == "bar"
# hit a 404 and ensure we still hydrate
exp_order += ["/404-no page id"]

View File

@ -207,7 +207,7 @@ def poll_for_order(
async def _poll_for_order(exp_order: list[str]):
async def _backend_state():
return await event_action.get_state(token)
return await event_action.get_state(f"{token}_state.event_action_state")
async def _check():
return (await _backend_state()).substates[

View File

@ -298,7 +298,7 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str:
token = event_chain.poll_for_value(token_input)
assert token is not None
return token
return f"{token}_state.state"
@pytest.mark.parametrize(

View File

@ -221,7 +221,11 @@ async def test_submit(driver, form_submit: AppHarness):
submit_input.click()
async def get_form_data():
return (await form_submit.get_state(token)).substates["form_state"].form_data
return (
(await form_submit.get_state(f"{token}_state.form_state"))
.substates["form_state"]
.form_data
)
# wait for the form data to arrive at the backend
form_data = await AppHarness._poll_for_async(get_form_data)

View File

@ -76,6 +76,10 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
token = fully_controlled_input.poll_for_value(token_input)
assert token
async def get_state_text():
state = await fully_controlled_input.get_state(f"{token}_state.state")
return state.substates["state"].text
# find the input and wait for it to have the initial state value
debounce_input = driver.find_element(By.ID, "debounce_input_input")
value_input = driver.find_element(By.ID, "value_input")
@ -95,16 +99,14 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("foo")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "ifoonitial"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "ifoonitial"
assert await get_state_text() == "ifoonitial"
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial"
# clear the input on the backend
async with fully_controlled_input.modify_state(token) as state:
async with fully_controlled_input.modify_state(f"{token}_state.state") as state:
state.substates["state"].text = ""
assert (await fully_controlled_input.get_state(token)).substates["state"].text == ""
assert await get_state_text() == ""
assert (
fully_controlled_input.poll_for_value(
debounce_input, exp_not_equal="ifoonitial"
@ -116,9 +118,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("getting testing done")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "getting testing done"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "getting testing done"
assert await get_state_text() == "getting testing done"
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
assert (
fully_controlled_input.poll_for_value(plain_value_input)
@ -130,9 +130,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "overwrite the state"
assert on_change_input.get_attribute("value") == "overwrite the state"
assert (await fully_controlled_input.get_state(token)).substates[
"state"
].text == "overwrite the state"
assert await get_state_text() == "overwrite the state"
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
assert (
fully_controlled_input.poll_for_value(plain_value_input)

View File

@ -171,6 +171,7 @@ async def test_upload_file(
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
substate_token = f"{token}_state.upload_state"
suffix = "_secondary" if secondary else ""
@ -191,7 +192,11 @@ async def test_upload_file(
# look up the backend state and assert on uploaded contents
async def get_file_data():
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
return (
(await upload_file.get_state(substate_token))
.substates["upload_state"]
._file_data
)
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
@ -201,7 +206,7 @@ async def test_upload_file(
selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == exp_name
state = await upload_file.get_state(token)
state = await upload_file.get_state(substate_token)
if secondary:
# only the secondary form tracks progress and chain events
assert state.substates["upload_state"].event_order.count("upload_progress") == 1
@ -223,6 +228,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
substate_token = f"{token}_state.upload_state"
upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
assert upload_box
@ -250,7 +256,11 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
# look up the backend state and assert on uploaded contents
async def get_file_data():
return (await upload_file.get_state(token)).substates["upload_state"]._file_data
return (
(await upload_file.get_state(substate_token))
.substates["upload_state"]
._file_data
)
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
@ -330,6 +340,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
substate_token = f"{token}_state.upload_state"
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
upload_button = driver.find_element(By.ID, f"upload_button_secondary")
@ -347,7 +358,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive
cancel_button.click()
# look up the backend state and assert on progress
state = await upload_file.get_state(token)
state = await upload_file.get_state(substate_token)
assert state.substates["upload_state"].progress_dicts
assert exp_name not in state.substates["upload_state"]._file_data

View File

@ -926,7 +926,7 @@ async def process(
}
)
# Get the state for the session exclusively.
async with app.state_manager.modify_state(event.token) as state:
async with app.state_manager.modify_state(event.substate_token) as state:
# re-assign only when the value is different
if state.router_data != router_data:
# assignment will recurse into substates and force recalculation of
@ -1002,7 +1002,8 @@ def upload(app: App):
)
# Get the state for the session.
state = await app.state_manager.get_state(token)
substate_token = token + "_" + handler.rpartition(".")[0]
state = await app.state_manager.get_state(substate_token)
# get the current session ID
# get the current state(parent state/substate)
@ -1049,7 +1050,7 @@ def upload(app: App):
Each state update as JSON followed by a new line.
"""
# Process the event.
async with app.state_manager.modify_state(token) as state:
async with app.state_manager.modify_state(event.substate_token) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)

View File

@ -41,6 +41,16 @@ class Event(Base):
# The event payload.
payload: Dict[str, Any] = {}
@property
def substate_token(self) -> str:
"""Get the substate token for the event.
Returns:
The substate token.
"""
substate = self.name.rpartition(".")[0]
return f"{self.token}_{substate}"
BACKGROUND_TASK_MARKER = "_reflex_background_task"

View File

@ -213,21 +213,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# The router data for the current page
router: RouterData = RouterData()
def __init__(self, *args, parent_state: BaseState | None = None, **kwargs):
def __init__(
self,
*args,
parent_state: BaseState | None = None,
init_substates: bool = True,
**kwargs,
):
"""Initialize the state.
Args:
*args: The args to pass to the Pydantic init method.
parent_state: The parent state.
init_substates: Whether to initialize the substates in this instance.
**kwargs: The kwargs to pass to the Pydantic init method.
"""
kwargs["parent_state"] = parent_state
super().__init__(*args, **kwargs)
# Setup the substates.
for substate in self.get_substates():
self.substates[substate.get_name()] = substate(parent_state=self)
# Setup the substates (for memory state manager only).
if init_substates:
for substate in self.get_substates():
self.substates[substate.get_name()] = substate(parent_state=self)
# Convert the event handlers to functions.
self._init_event_handlers()
@ -1005,7 +1013,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for substate in self.substates.values():
substate._reset_client_storage()
def get_substate(self, path: Sequence[str]) -> BaseState | None:
def get_substate(self, path: Sequence[str]) -> BaseState:
"""Get the substate.
Args:
@ -1260,6 +1268,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Recursively find the substate deltas.
substates = self.substates
for substate in self.dirty_substates.union(self._always_dirty_substates):
if substate not in substates:
continue # substate not loaded at this time, no delta
delta.update(substates[substate].get_delta())
# Format the delta.
@ -1287,6 +1297,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
for var in self.dirty_vars:
for substate_name in self._substate_var_dependencies[var]:
self.dirty_substates.add(substate_name)
if substate_name not in substates:
continue
substate = substates[substate_name]
substate.dirty_vars.add(var)
substate._mark_dirty()
@ -1295,6 +1307,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""Reset the dirty vars."""
# Recursively clean the substates.
for substate in self.dirty_substates:
if substate not in self.substates:
continue
self.substates[substate]._clean()
# Clean this state.
@ -1380,6 +1394,24 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
"""
pass
def __getstate__(self):
"""Get the state for redis serialization.
This method is called by cloudpickle to serialize the object.
It explicitly removes parent_state and substates because those are serialized separately
by the StateManagerRedis to allow for better horizontal scaling as state size increases.
Returns:
The state dict for serialization.
"""
state = super().__getstate__()
# Never serialize parent_state or substates
state["__dict__"] = state["__dict__"].copy()
state["__dict__"]["parent_state"] = None
state["__dict__"]["substates"] = {}
return state
class State(BaseState):
"""The app Base State."""
@ -1479,6 +1511,8 @@ class StateProxy(wrapt.ObjectProxy):
"""
self._self_actx = self._self_app.modify_state(
self.__wrapped__.router.session.client_token
+ "_"
+ ".".join(self._self_substate_path)
)
mutable_state = await self._self_actx.__aenter__()
super().__setattr__(
@ -1675,6 +1709,8 @@ class StateManagerMemory(StateManager):
Returns:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
if token not in self.states:
self.states[token] = self.state()
return self.states[token]
@ -1698,6 +1734,8 @@ class StateManagerMemory(StateManager):
Yields:
The state for the token.
"""
# Memory state manager ignores the substate suffix and always returns the top-level state.
token = token.partition("_")[0]
if token not in self._states_locks:
async with self._state_manager_lock:
if token not in self._states_locks:
@ -1737,23 +1775,104 @@ class StateManagerRedis(StateManager):
b"evicted",
}
async def get_state(self, token: str) -> BaseState:
async def get_state(
self,
token: str,
top_level: bool = True,
get_substates: bool = True,
parent_state: BaseState | None = None,
) -> BaseState:
"""Get the state for a token.
Args:
token: The token to get the state for.
top_level: If true, return an instance of the top-level state.
get_substates: If true, also retrieve substates
parent_state: If provided, use this parent_state instead of getting it from redis.
Returns:
The state for the token.
Raises:
RuntimeError: when the state_cls is not specified in the token
"""
# Split the actual token from the fully qualified substate name.
client_token, _, state_path = token.partition("_")
if state_path:
# Get the State class associated with the given path.
state_cls = self.state.get_class_substate(tuple(state_path.split(".")))
else:
raise RuntimeError(
"StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}"
)
# Fetch the serialized substate from redis.
redis_state = await self.redis.get(token)
if redis_state is None:
await self.set_state(token, self.state())
return await self.get_state(token)
return cloudpickle.loads(redis_state)
if redis_state is not None:
# Deserialize the substate.
state = cloudpickle.loads(redis_state)
# Populate parent and substates if requested.
if parent_state is None:
# Retrieve the parent state from redis.
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
parent_state_key = token.rpartition(".")[0]
parent_state = await self.get_state(
parent_state_key, top_level=False, get_substates=False
)
# Set up Bidirectional linkage between this state and its parent.
if parent_state is not None:
parent_state.substates[state.get_name()] = state
state.parent_state = parent_state
if get_substates:
# Retrieve all substates from redis.
for substate_cls in state_cls.get_substates():
substate_name = substate_cls.get_name()
substate_key = token + "." + substate_name
state.substates[substate_name] = await self.get_state(
substate_key, top_level=False, parent_state=state
)
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
while type(state) != self.state and state.parent_state is not None:
state = state.parent_state
return state
# Key didn't exist so we have to create a new entry for this token.
if parent_state is None:
parent_state_name = state_path.rpartition(".")[0]
if parent_state_name:
# Retrieve the parent state to populate event handlers onto this substate.
parent_state_key = client_token + "_" + parent_state_name
parent_state = await self.get_state(
parent_state_key, top_level=False, get_substates=False
)
# Persist the new state class to redis.
await self.set_state(
token,
state_cls(
parent_state=parent_state,
init_substates=False,
),
)
# After creating the state key, recursively call `get_state` to populate substates.
return await self.get_state(
token,
top_level=top_level,
get_substates=get_substates,
parent_state=parent_state,
)
async def set_state(
self, token: str, state: BaseState, lock_id: bytes | None = None
self,
token: str,
state: BaseState,
lock_id: bytes | None = None,
set_substates: bool = True,
set_parent_state: bool = True,
):
"""Set the state for a token.
@ -1761,11 +1880,13 @@ class StateManagerRedis(StateManager):
token: The token to set the state for.
state: The state to set.
lock_id: If provided, the lock_key must be set to this value to set the state.
set_substates: If True, write substates to redis
set_parent_state: If True, write parent state to redis
Raises:
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
"""
# check that we're holding the lock
# Check that we're holding the lock.
if (
lock_id is not None
and await self.redis.get(self._lock_key(token)) != lock_id
@ -1775,6 +1896,27 @@ class StateManagerRedis(StateManager):
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.background` decorator for long-running tasks."
)
# Find the substate associated with the token.
state_path = token.partition("_")[2]
if state_path and state.get_full_name() != state_path:
state = state.get_substate(tuple(state_path.split(".")))
# Persist the parent state separately, if requested.
if state.parent_state is not None and set_parent_state:
parent_state_key = token.rpartition(".")[0]
await self.set_state(
parent_state_key,
state.parent_state,
lock_id=lock_id,
set_substates=False,
)
# Persist the substates separately, if requested.
if set_substates:
for substate_name, substate in state.substates.items():
substate_key = token + "." + substate_name
await self.set_state(
substate_key, substate, lock_id=lock_id, set_parent_state=False
)
# Persist only the given state (parents or substates are excluded by BaseState.__getstate__).
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
@contextlib.asynccontextmanager
@ -1802,7 +1944,9 @@ class StateManagerRedis(StateManager):
Returns:
The redis lock key for the token.
"""
return f"{token}_lock".encode()
# All substates share the same lock domain, so ignore any substate path suffix.
client_token = token.partition("_")[0]
return f"{client_token}_lock".encode()
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
"""Try to get a redis lock for a token.

View File

@ -220,6 +220,7 @@ class AppHarness:
reflex.config.get_config(reload=True)
# reset rx.State subclasses
State.class_subclasses.clear()
State.get_class_substate.cache_clear()
# Ensure the AppHarness test does not skip State assignment due to running via pytest
os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None)
# self.app_module.app.

View File

@ -340,7 +340,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str):
assert app.state == test_state
# Get a state for a given token.
state = await app.state_manager.get_state(token)
state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}")
assert isinstance(state, test_state)
assert state.var == 0 # type: ignore
@ -358,8 +358,8 @@ async def test_set_and_get_state(test_state):
app = App(state=test_state)
# Create two tokens.
token1 = str(uuid.uuid4())
token2 = str(uuid.uuid4())
token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}"
# Get the default state for each token.
state1 = await app.state_manager.get_state(token1)
@ -744,18 +744,18 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
# The App state must be the "root" of the state tree
app = App(state=State)
app.event_namespace.emit = AsyncMock() # type: ignore
current_state = await app.state_manager.get_state(token)
substate_token = f"{token}_{state.get_full_name()}"
current_state = await app.state_manager.get_state(substate_token)
data = b"This is binary data"
# Create a binary IO object and write data to it
bio = io.BytesIO()
bio.write(data)
state_name = state.get_full_name().partition(".")[2] or state.get_name()
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"state.{state_name}.multi_handle_upload",
"reflex-event-handler": f"{state.get_full_name()}.multi_handle_upload",
}
file1 = UploadFile(
@ -774,7 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker):
== StateUpdate(delta=delta, events=[], final=True).json() + "\n"
)
current_state = await app.state_manager.get_state(token)
current_state = await app.state_manager.get_state(substate_token)
state_dict = current_state.dict()[state.get_full_name()]
assert state_dict["img_list"] == [
"image1.jpg",
@ -799,14 +799,12 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
token: a Token.
"""
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1)
app = App(state=State)
state_name = state.get_full_name().partition(".")[2] or state.get_name()
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.handle_upload2",
"reflex-event-handler": f"{state.get_full_name()}.handle_upload2",
}
file_mock = unittest.mock.Mock(filename="image1.jpg")
fn = upload(app)
@ -814,7 +812,7 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
await fn(request_mock, [file_mock])
assert (
err.value.args[0]
== f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
== f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
)
if isinstance(app.state_manager, StateManagerRedis):
@ -835,14 +833,12 @@ async def test_upload_file_background(state, tmp_path, token):
token: a Token.
"""
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1)
app = App(state=State)
state_name = state.get_full_name().partition(".")[2] or state.get_name()
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.bg_upload",
"reflex-event-handler": f"{state.get_full_name()}.bg_upload",
}
file_mock = unittest.mock.Mock(filename="image1.jpg")
fn = upload(app)
@ -850,7 +846,7 @@ async def test_upload_file_background(state, tmp_path, token):
await fn(request_mock, [file_mock])
assert (
err.value.args[0]
== f"@rx.background is not supported for upload handler `{state_name}.bg_upload`."
== f"@rx.background is not supported for upload handler `{state.get_full_name()}.bg_upload`."
)
if isinstance(app.state_manager, StateManagerRedis):
@ -932,9 +928,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
}
assert constants.ROUTER in app.state()._computed_var_dependencies
substate_token = f"{token}_{DynamicState.get_full_name()}"
sid = "mock_sid"
client_ip = "127.0.0.1"
state = await app.state_manager.get_state(token)
state = await app.state_manager.get_state(substate_token)
assert state.dynamic == ""
exp_vals = ["foo", "foobar", "baz"]
@ -1004,7 +1001,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
)
if isinstance(app.state_manager, StateManagerRedis):
# When redis is used, the state is not updated until the processing is complete
state = await app.state_manager.get_state(token)
state = await app.state_manager.get_state(substate_token)
assert state.dynamic == prev_exp_val
# complete the processing
@ -1012,7 +1009,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
await process_coro.__anext__()
# check that router data was written to the state_manager store
state = await app.state_manager.get_state(token)
state = await app.state_manager.get_state(substate_token)
assert state.dynamic == exp_val
process_coro = process(
@ -1087,7 +1084,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
await process_coro.__anext__()
prev_exp_val = exp_val
state = await app.state_manager.get_state(token)
state = await app.state_manager.get_state(substate_token)
assert state.loaded == len(exp_vals)
assert state.counter == len(exp_vals)
# print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
@ -1124,7 +1121,7 @@ async def test_process_events(mocker, token: str):
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
pass
assert (await app.state_manager.get_state(token)).value == 5
assert (await app.state_manager.get_state(event.substate_token)).value == 5
assert app.postprocess.call_count == 6
if isinstance(app.state_manager, StateManagerRedis):

View File

@ -38,8 +38,8 @@ from reflex.vars import BaseVar, ComputedVar
from .states import GenState
CI = bool(os.environ.get("CI", False))
LOCK_EXPIRATION = 2000 if CI else 100
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
LOCK_EXPIRATION = 2000 if CI else 300
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4
formatted_router = {
@ -1432,15 +1432,32 @@ def state_manager(request) -> Generator[StateManager, None, None]:
asyncio.get_event_loop().run_until_complete(state_manager.close())
@pytest.fixture()
def substate_token(state_manager, token):
"""A token + substate name for looking up in state manager.
Args:
state_manager: A state manager instance.
token: A token.
Returns:
Token concatenated with the state_manager's state full_name.
"""
return f"{token}_{state_manager.state.get_full_name()}"
@pytest.mark.asyncio
async def test_state_manager_modify_state(state_manager: StateManager, token: str):
async def test_state_manager_modify_state(
state_manager: StateManager, token: str, substate_token: str
):
"""Test that the state manager can modify a state exclusively.
Args:
state_manager: A state manager instance.
token: A token.
substate_token: A token + substate name for looking up in state manager.
"""
async with state_manager.modify_state(token):
async with state_manager.modify_state(substate_token):
if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(f"{token}_lock")
elif isinstance(state_manager, StateManagerMemory):
@ -1461,21 +1478,24 @@ async def test_state_manager_modify_state(state_manager: StateManager, token: st
@pytest.mark.asyncio
async def test_state_manager_contend(state_manager: StateManager, token: str):
async def test_state_manager_contend(
state_manager: StateManager, token: str, substate_token: str
):
"""Multiple coroutines attempting to access the same state.
Args:
state_manager: A state manager instance.
token: A token.
substate_token: A token + substate name for looking up in state manager.
"""
n_coroutines = 10
exp_num1 = 10
async with state_manager.modify_state(token) as state:
async with state_manager.modify_state(substate_token) as state:
state.num1 = 0
async def _coro():
async with state_manager.modify_state(token) as state:
async with state_manager.modify_state(substate_token) as state:
await asyncio.sleep(0.01)
state.num1 += 1
@ -1484,7 +1504,7 @@ async def test_state_manager_contend(state_manager: StateManager, token: str):
for f in asyncio.as_completed(tasks):
await f
assert (await state_manager.get_state(token)).num1 == exp_num1
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
@ -1510,33 +1530,51 @@ def state_manager_redis() -> Generator[StateManager, None, None]:
asyncio.get_event_loop().run_until_complete(state_manager.close())
@pytest.fixture()
def substate_token_redis(state_manager_redis, token):
"""A token + substate name for looking up in state manager.
Args:
state_manager_redis: A state manager instance.
token: A token.
Returns:
Token concatenated with the state_manager's state full_name.
"""
return f"{token}_{state_manager_redis.state.get_full_name()}"
@pytest.mark.asyncio
async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
async def test_state_manager_lock_expire(
state_manager_redis: StateManager, token: str, substate_token_redis: str
):
"""Test that the state manager lock expires and raises exception exiting context.
Args:
state_manager_redis: A state manager instance.
token: A token.
substate_token_redis: A token + substate name for looking up in state manager.
"""
state_manager_redis.lock_expiration = LOCK_EXPIRATION
async with state_manager_redis.modify_state(token):
async with state_manager_redis.modify_state(substate_token_redis):
await asyncio.sleep(0.01)
with pytest.raises(LockExpiredError):
async with state_manager_redis.modify_state(token):
async with state_manager_redis.modify_state(substate_token_redis):
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
@pytest.mark.asyncio
async def test_state_manager_lock_expire_contend(
state_manager_redis: StateManager, token: str
state_manager_redis: StateManager, token: str, substate_token_redis: str
):
"""Test that the state manager lock expires and queued waiters proceed.
Args:
state_manager_redis: A state manager instance.
token: A token.
substate_token_redis: A token + substate name for looking up in state manager.
"""
exp_num1 = 4252
unexp_num1 = 666
@ -1546,7 +1584,7 @@ async def test_state_manager_lock_expire_contend(
order = []
async def _coro_blocker():
async with state_manager_redis.modify_state(token) as state:
async with state_manager_redis.modify_state(substate_token_redis) as state:
order.append("blocker")
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
state.num1 = unexp_num1
@ -1554,7 +1592,7 @@ async def test_state_manager_lock_expire_contend(
async def _coro_waiter():
while "blocker" not in order:
await asyncio.sleep(0.005)
async with state_manager_redis.modify_state(token) as state:
async with state_manager_redis.modify_state(substate_token_redis) as state:
order.append("waiter")
assert state.num1 != unexp_num1
state.num1 = exp_num1
@ -1568,7 +1606,7 @@ async def test_state_manager_lock_expire_contend(
await tasks[1]
assert order == ["blocker", "waiter"]
assert (await state_manager_redis.get_state(token)).num1 == exp_num1
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
@pytest.fixture(scope="function")
@ -1643,7 +1681,8 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
assert sp.value2 == 42
# Get the state from the state manager directly and check that the value is updated
gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}"
gotten_state = await mock_app.state_manager.get_state(gc_token)
if isinstance(mock_app.state_manager, StateManagerMemory):
# For in-process store, only one instance of the state exists
assert gotten_state is parent_state
@ -1836,7 +1875,8 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
"private",
]
assert (await mock_app.state_manager.get_state(token)).order == exp_order
substate_token = f"{token}_{BackgroundTaskState.get_name()}"
assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
@ -1913,7 +1953,8 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
await task
assert not mock_app.background_tasks
assert (await mock_app.state_manager.get_state(token)).order == [
substate_token = f"{token}_{BackgroundTaskState.get_name()}"
assert (await mock_app.state_manager.get_state(substate_token)).order == [
"reset",
]