From 756bf9b0f4a0f38acc2c631bce3ebf227aafc882 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 21 Feb 2024 01:50:25 -0800 Subject: [PATCH] [REF-1885] Shard Substates when serializing to Redis (#2574) * Move sharding internal to StateManager Avoid leaking sharding implementation details all over the State class and breaking the API * WiP StateManager based sharding * Copy the state __dict__ when serializing to avoid breaking the instance * State tests need to pass the correct substate token for redis * state: when getting parent_state, set top_level=False ensure that we don't end up with a broken tree * test_app: get tests passing with redis by passing the correct token refactor upload tests to suck less * test_client_storage: look up substate key * state.py: pass static checks * test_dynamic_routes: working with redis state shard * Update the remaining AppHarness tests to pass {token}_{state.get_full_name()} * test_app: pass all tokens with state suffix * StateManagerRedis: clean up commentary --- integration/test_client_storage.py | 2 +- integration/test_dynamic_routes.py | 11 +- integration/test_event_actions.py | 2 +- integration/test_event_chain.py | 2 +- integration/test_form_submit.py | 6 +- integration/test_input.py | 20 ++-- integration/test_upload.py | 19 +++- reflex/app.py | 7 +- reflex/event.py | 10 ++ reflex/state.py | 170 ++++++++++++++++++++++++++--- reflex/testing.py | 1 + tests/test_app.py | 41 ++++--- tests/test_state.py | 77 ++++++++++--- 13 files changed, 290 insertions(+), 78 deletions(-) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index c381d3a3e..3f7ff33f8 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -449,7 +449,7 @@ async def test_client_side_state( assert l1s.text == "l1s value" # reset the backend state to force refresh from client storage - async with client_side.modify_state(token) as state: + async with client_side.modify_state(f"{token}_state.client_side_state") as state: state.reset() driver.refresh() diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 9fa597d2b..97a2a7070 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -85,6 +85,7 @@ def dynamic_route( """ with app_harness_env.create( root=tmp_path_factory.mktemp(f"dynamic_route"), + app_name=f"dynamicroute_{app_harness_env.__name__.lower()}", app_source=DynamicRoute, # type: ignore ) as harness: yield harness @@ -146,7 +147,7 @@ def poll_for_order( async def _poll_for_order(exp_order: list[str]): async def _backend_state(): - return await dynamic_route.get_state(token) + return await dynamic_route.get_state(f"{token}_state.dynamic_state") async def _check(): return (await _backend_state()).substates[ @@ -194,7 +195,9 @@ async def test_on_load_navigate( assert link assert page_id_input - assert dynamic_route.poll_for_value(page_id_input) == str(ix) + assert dynamic_route.poll_for_value( + page_id_input, exp_not_equal=str(ix - 1) + ) == str(ix) assert dynamic_route.poll_for_value(raw_path_input) == f"/page/{ix}/" await poll_for_order(exp_order) @@ -220,7 +223,9 @@ async def test_on_load_navigate( with poll_for_navigation(driver): driver.get(f"{driver.current_url}?foo=bar") await poll_for_order(exp_order) - assert (await dynamic_route.get_state(token)).router.page.params["foo"] == "bar" + assert ( + await dynamic_route.get_state(f"{token}_state.dynamic_state") + ).router.page.params["foo"] == "bar" # hit a 404 and ensure we still hydrate exp_order += ["/404-no page id"] diff --git a/integration/test_event_actions.py b/integration/test_event_actions.py index 05e333f3c..67605639a 100644 --- a/integration/test_event_actions.py +++ b/integration/test_event_actions.py @@ -207,7 +207,7 @@ def poll_for_order( async def _poll_for_order(exp_order: list[str]): async def _backend_state(): - return await event_action.get_state(token) + return await event_action.get_state(f"{token}_state.event_action_state") async def _check(): return (await _backend_state()).substates[ diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index 587e0f63d..da9a197b3 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -298,7 +298,7 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str: token = event_chain.poll_for_value(token_input) assert token is not None - return token + return f"{token}_state.state" @pytest.mark.parametrize( diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index 655630478..ef5631cd7 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -221,7 +221,11 @@ async def test_submit(driver, form_submit: AppHarness): submit_input.click() async def get_form_data(): - return (await form_submit.get_state(token)).substates["form_state"].form_data + return ( + (await form_submit.get_state(f"{token}_state.form_state")) + .substates["form_state"] + .form_data + ) # wait for the form data to arrive at the backend form_data = await AppHarness._poll_for_async(get_form_data) diff --git a/integration/test_input.py b/integration/test_input.py index cbb12de24..dc1e7475a 100644 --- a/integration/test_input.py +++ b/integration/test_input.py @@ -76,6 +76,10 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): token = fully_controlled_input.poll_for_value(token_input) assert token + async def get_state_text(): + state = await fully_controlled_input.get_state(f"{token}_state.state") + return state.substates["state"].text + # find the input and wait for it to have the initial state value debounce_input = driver.find_element(By.ID, "debounce_input_input") value_input = driver.find_element(By.ID, "value_input") @@ -95,16 +99,14 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("foo") time.sleep(0.5) assert debounce_input.get_attribute("value") == "ifoonitial" - assert (await fully_controlled_input.get_state(token)).substates[ - "state" - ].text == "ifoonitial" + assert await get_state_text() == "ifoonitial" assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial" assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial" # clear the input on the backend - async with fully_controlled_input.modify_state(token) as state: + async with fully_controlled_input.modify_state(f"{token}_state.state") as state: state.substates["state"].text = "" - assert (await fully_controlled_input.get_state(token)).substates["state"].text == "" + assert await get_state_text() == "" assert ( fully_controlled_input.poll_for_value( debounce_input, exp_not_equal="ifoonitial" @@ -116,9 +118,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("getting testing done") time.sleep(0.5) assert debounce_input.get_attribute("value") == "getting testing done" - assert (await fully_controlled_input.get_state(token)).substates[ - "state" - ].text == "getting testing done" + assert await get_state_text() == "getting testing done" assert fully_controlled_input.poll_for_value(value_input) == "getting testing done" assert ( fully_controlled_input.poll_for_value(plain_value_input) @@ -130,9 +130,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): time.sleep(0.5) assert debounce_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state" - assert (await fully_controlled_input.get_state(token)).substates[ - "state" - ].text == "overwrite the state" + assert await get_state_text() == "overwrite the state" assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state" assert ( fully_controlled_input.poll_for_value(plain_value_input) diff --git a/integration/test_upload.py b/integration/test_upload.py index f498c6dd4..c703a6747 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -171,6 +171,7 @@ async def test_upload_file( # wait for the backend connection to send the token token = upload_file.poll_for_value(token_input) assert token is not None + substate_token = f"{token}_state.upload_state" suffix = "_secondary" if secondary else "" @@ -191,7 +192,11 @@ async def test_upload_file( # look up the backend state and assert on uploaded contents async def get_file_data(): - return (await upload_file.get_state(token)).substates["upload_state"]._file_data + return ( + (await upload_file.get_state(substate_token)) + .substates["upload_state"] + ._file_data + ) file_data = await AppHarness._poll_for_async(get_file_data) assert isinstance(file_data, dict) @@ -201,7 +206,7 @@ async def test_upload_file( selected_files = driver.find_element(By.ID, f"selected_files{suffix}") assert selected_files.text == exp_name - state = await upload_file.get_state(token) + state = await upload_file.get_state(substate_token) if secondary: # only the secondary form tracks progress and chain events assert state.substates["upload_state"].event_order.count("upload_progress") == 1 @@ -223,6 +228,7 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): # wait for the backend connection to send the token token = upload_file.poll_for_value(token_input) assert token is not None + substate_token = f"{token}_state.upload_state" upload_box = driver.find_element(By.XPATH, "//input[@type='file']") assert upload_box @@ -250,7 +256,11 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): # look up the backend state and assert on uploaded contents async def get_file_data(): - return (await upload_file.get_state(token)).substates["upload_state"]._file_data + return ( + (await upload_file.get_state(substate_token)) + .substates["upload_state"] + ._file_data + ) file_data = await AppHarness._poll_for_async(get_file_data) assert isinstance(file_data, dict) @@ -330,6 +340,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # wait for the backend connection to send the token token = upload_file.poll_for_value(token_input) assert token is not None + substate_token = f"{token}_state.upload_state" upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1] upload_button = driver.find_element(By.ID, f"upload_button_secondary") @@ -347,7 +358,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive cancel_button.click() # look up the backend state and assert on progress - state = await upload_file.get_state(token) + state = await upload_file.get_state(substate_token) assert state.substates["upload_state"].progress_dicts assert exp_name not in state.substates["upload_state"]._file_data diff --git a/reflex/app.py b/reflex/app.py index 7ab4b5abe..245a52214 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -926,7 +926,7 @@ async def process( } ) # Get the state for the session exclusively. - async with app.state_manager.modify_state(event.token) as state: + async with app.state_manager.modify_state(event.substate_token) as state: # re-assign only when the value is different if state.router_data != router_data: # assignment will recurse into substates and force recalculation of @@ -1002,7 +1002,8 @@ def upload(app: App): ) # Get the state for the session. - state = await app.state_manager.get_state(token) + substate_token = token + "_" + handler.rpartition(".")[0] + state = await app.state_manager.get_state(substate_token) # get the current session ID # get the current state(parent state/substate) @@ -1049,7 +1050,7 @@ def upload(app: App): Each state update as JSON followed by a new line. """ # Process the event. - async with app.state_manager.modify_state(token) as state: + async with app.state_manager.modify_state(event.substate_token) as state: async for update in state._process(event): # Postprocess the event. update = await app.postprocess(state, event, update) diff --git a/reflex/event.py b/reflex/event.py index 499b4877a..d81e257f9 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -41,6 +41,16 @@ class Event(Base): # The event payload. payload: Dict[str, Any] = {} + @property + def substate_token(self) -> str: + """Get the substate token for the event. + + Returns: + The substate token. + """ + substate = self.name.rpartition(".")[0] + return f"{self.token}_{substate}" + BACKGROUND_TASK_MARKER = "_reflex_background_task" diff --git a/reflex/state.py b/reflex/state.py index e15df6daf..a42f85a64 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -213,21 +213,29 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # The router data for the current page router: RouterData = RouterData() - def __init__(self, *args, parent_state: BaseState | None = None, **kwargs): + def __init__( + self, + *args, + parent_state: BaseState | None = None, + init_substates: bool = True, + **kwargs, + ): """Initialize the state. Args: *args: The args to pass to the Pydantic init method. parent_state: The parent state. + init_substates: Whether to initialize the substates in this instance. **kwargs: The kwargs to pass to the Pydantic init method. """ kwargs["parent_state"] = parent_state super().__init__(*args, **kwargs) - # Setup the substates. - for substate in self.get_substates(): - self.substates[substate.get_name()] = substate(parent_state=self) + # Setup the substates (for memory state manager only). + if init_substates: + for substate in self.get_substates(): + self.substates[substate.get_name()] = substate(parent_state=self) # Convert the event handlers to functions. self._init_event_handlers() @@ -1005,7 +1013,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for substate in self.substates.values(): substate._reset_client_storage() - def get_substate(self, path: Sequence[str]) -> BaseState | None: + def get_substate(self, path: Sequence[str]) -> BaseState: """Get the substate. Args: @@ -1260,6 +1268,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Recursively find the substate deltas. substates = self.substates for substate in self.dirty_substates.union(self._always_dirty_substates): + if substate not in substates: + continue # substate not loaded at this time, no delta delta.update(substates[substate].get_delta()) # Format the delta. @@ -1287,6 +1297,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for var in self.dirty_vars: for substate_name in self._substate_var_dependencies[var]: self.dirty_substates.add(substate_name) + if substate_name not in substates: + continue substate = substates[substate_name] substate.dirty_vars.add(var) substate._mark_dirty() @@ -1295,6 +1307,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """Reset the dirty vars.""" # Recursively clean the substates. for substate in self.dirty_substates: + if substate not in self.substates: + continue self.substates[substate]._clean() # Clean this state. @@ -1380,6 +1394,24 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): """ pass + def __getstate__(self): + """Get the state for redis serialization. + + This method is called by cloudpickle to serialize the object. + + It explicitly removes parent_state and substates because those are serialized separately + by the StateManagerRedis to allow for better horizontal scaling as state size increases. + + Returns: + The state dict for serialization. + """ + state = super().__getstate__() + # Never serialize parent_state or substates + state["__dict__"] = state["__dict__"].copy() + state["__dict__"]["parent_state"] = None + state["__dict__"]["substates"] = {} + return state + class State(BaseState): """The app Base State.""" @@ -1479,6 +1511,8 @@ class StateProxy(wrapt.ObjectProxy): """ self._self_actx = self._self_app.modify_state( self.__wrapped__.router.session.client_token + + "_" + + ".".join(self._self_substate_path) ) mutable_state = await self._self_actx.__aenter__() super().__setattr__( @@ -1675,6 +1709,8 @@ class StateManagerMemory(StateManager): Returns: The state for the token. """ + # Memory state manager ignores the substate suffix and always returns the top-level state. + token = token.partition("_")[0] if token not in self.states: self.states[token] = self.state() return self.states[token] @@ -1698,6 +1734,8 @@ class StateManagerMemory(StateManager): Yields: The state for the token. """ + # Memory state manager ignores the substate suffix and always returns the top-level state. + token = token.partition("_")[0] if token not in self._states_locks: async with self._state_manager_lock: if token not in self._states_locks: @@ -1737,23 +1775,104 @@ class StateManagerRedis(StateManager): b"evicted", } - async def get_state(self, token: str) -> BaseState: + async def get_state( + self, + token: str, + top_level: bool = True, + get_substates: bool = True, + parent_state: BaseState | None = None, + ) -> BaseState: """Get the state for a token. Args: token: The token to get the state for. + top_level: If true, return an instance of the top-level state. + get_substates: If true, also retrieve substates + parent_state: If provided, use this parent_state instead of getting it from redis. Returns: The state for the token. + + Raises: + RuntimeError: when the state_cls is not specified in the token """ + # Split the actual token from the fully qualified substate name. + client_token, _, state_path = token.partition("_") + if state_path: + # Get the State class associated with the given path. + state_cls = self.state.get_class_substate(tuple(state_path.split("."))) + else: + raise RuntimeError( + "StateManagerRedis requires token to be specified in the form of {token}_{state_full_name}" + ) + + # Fetch the serialized substate from redis. redis_state = await self.redis.get(token) - if redis_state is None: - await self.set_state(token, self.state()) - return await self.get_state(token) - return cloudpickle.loads(redis_state) + + if redis_state is not None: + # Deserialize the substate. + state = cloudpickle.loads(redis_state) + + # Populate parent and substates if requested. + if parent_state is None: + # Retrieve the parent state from redis. + parent_state_name = state_path.rpartition(".")[0] + if parent_state_name: + parent_state_key = token.rpartition(".")[0] + parent_state = await self.get_state( + parent_state_key, top_level=False, get_substates=False + ) + # Set up Bidirectional linkage between this state and its parent. + if parent_state is not None: + parent_state.substates[state.get_name()] = state + state.parent_state = parent_state + if get_substates: + # Retrieve all substates from redis. + for substate_cls in state_cls.get_substates(): + substate_name = substate_cls.get_name() + substate_key = token + "." + substate_name + state.substates[substate_name] = await self.get_state( + substate_key, top_level=False, parent_state=state + ) + # To retain compatibility with previous implementation, by default, we return + # the top-level state by chasing `parent_state` pointers up the tree. + if top_level: + while type(state) != self.state and state.parent_state is not None: + state = state.parent_state + return state + + # Key didn't exist so we have to create a new entry for this token. + if parent_state is None: + parent_state_name = state_path.rpartition(".")[0] + if parent_state_name: + # Retrieve the parent state to populate event handlers onto this substate. + parent_state_key = client_token + "_" + parent_state_name + parent_state = await self.get_state( + parent_state_key, top_level=False, get_substates=False + ) + # Persist the new state class to redis. + await self.set_state( + token, + state_cls( + parent_state=parent_state, + init_substates=False, + ), + ) + # After creating the state key, recursively call `get_state` to populate substates. + return await self.get_state( + token, + top_level=top_level, + get_substates=get_substates, + parent_state=parent_state, + ) async def set_state( - self, token: str, state: BaseState, lock_id: bytes | None = None + self, + token: str, + state: BaseState, + lock_id: bytes | None = None, + set_substates: bool = True, + set_parent_state: bool = True, ): """Set the state for a token. @@ -1761,11 +1880,13 @@ class StateManagerRedis(StateManager): token: The token to set the state for. state: The state to set. lock_id: If provided, the lock_key must be set to this value to set the state. + set_substates: If True, write substates to redis + set_parent_state: If True, write parent state to redis Raises: LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. """ - # check that we're holding the lock + # Check that we're holding the lock. if ( lock_id is not None and await self.redis.get(self._lock_key(token)) != lock_id @@ -1775,6 +1896,27 @@ class StateManagerRedis(StateManager): f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " "or use `@rx.background` decorator for long-running tasks." ) + # Find the substate associated with the token. + state_path = token.partition("_")[2] + if state_path and state.get_full_name() != state_path: + state = state.get_substate(tuple(state_path.split("."))) + # Persist the parent state separately, if requested. + if state.parent_state is not None and set_parent_state: + parent_state_key = token.rpartition(".")[0] + await self.set_state( + parent_state_key, + state.parent_state, + lock_id=lock_id, + set_substates=False, + ) + # Persist the substates separately, if requested. + if set_substates: + for substate_name, substate in state.substates.items(): + substate_key = token + "." + substate_name + await self.set_state( + substate_key, substate, lock_id=lock_id, set_parent_state=False + ) + # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) @contextlib.asynccontextmanager @@ -1802,7 +1944,9 @@ class StateManagerRedis(StateManager): Returns: The redis lock key for the token. """ - return f"{token}_lock".encode() + # All substates share the same lock domain, so ignore any substate path suffix. + client_token = token.partition("_")[0] + return f"{client_token}_lock".encode() async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None: """Try to get a redis lock for a token. diff --git a/reflex/testing.py b/reflex/testing.py index 9b9ea1c06..78d220847 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -220,6 +220,7 @@ class AppHarness: reflex.config.get_config(reload=True) # reset rx.State subclasses State.class_subclasses.clear() + State.get_class_substate.cache_clear() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) # self.app_module.app. diff --git a/tests/test_app.py b/tests/test_app.py index 08b24b2ea..de7e5cdd9 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -340,7 +340,7 @@ async def test_initialize_with_state(test_state: Type[ATestState], token: str): assert app.state == test_state # Get a state for a given token. - state = await app.state_manager.get_state(token) + state = await app.state_manager.get_state(f"{token}_{test_state.get_full_name()}") assert isinstance(state, test_state) assert state.var == 0 # type: ignore @@ -358,8 +358,8 @@ async def test_set_and_get_state(test_state): app = App(state=test_state) # Create two tokens. - token1 = str(uuid.uuid4()) - token2 = str(uuid.uuid4()) + token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" + token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" # Get the default state for each token. state1 = await app.state_manager.get_state(token1) @@ -744,18 +744,18 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): # The App state must be the "root" of the state tree app = App(state=State) app.event_namespace.emit = AsyncMock() # type: ignore - current_state = await app.state_manager.get_state(token) + substate_token = f"{token}_{state.get_full_name()}" + current_state = await app.state_manager.get_state(substate_token) data = b"This is binary data" # Create a binary IO object and write data to it bio = io.BytesIO() bio.write(data) - state_name = state.get_full_name().partition(".")[2] or state.get_name() request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, - "reflex-event-handler": f"state.{state_name}.multi_handle_upload", + "reflex-event-handler": f"{state.get_full_name()}.multi_handle_upload", } file1 = UploadFile( @@ -774,7 +774,7 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker): == StateUpdate(delta=delta, events=[], final=True).json() + "\n" ) - current_state = await app.state_manager.get_state(token) + current_state = await app.state_manager.get_state(substate_token) state_dict = current_state.dict()[state.get_full_name()] assert state_dict["img_list"] == [ "image1.jpg", @@ -799,14 +799,12 @@ async def test_upload_file_without_annotation(state, tmp_path, token): token: a Token. """ state._tmp_path = tmp_path - # The App state must be the "root" of the state tree - app = App(state=state if state is FileUploadState else FileStateBase1) + app = App(state=State) - state_name = state.get_full_name().partition(".")[2] or state.get_name() request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, - "reflex-event-handler": f"{state_name}.handle_upload2", + "reflex-event-handler": f"{state.get_full_name()}.handle_upload2", } file_mock = unittest.mock.Mock(filename="image1.jpg") fn = upload(app) @@ -814,7 +812,7 @@ async def test_upload_file_without_annotation(state, tmp_path, token): await fn(request_mock, [file_mock]) assert ( err.value.args[0] - == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" + == f"`{state.get_full_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" ) if isinstance(app.state_manager, StateManagerRedis): @@ -835,14 +833,12 @@ async def test_upload_file_background(state, tmp_path, token): token: a Token. """ state._tmp_path = tmp_path - # The App state must be the "root" of the state tree - app = App(state=state if state is FileUploadState else FileStateBase1) + app = App(state=State) - state_name = state.get_full_name().partition(".")[2] or state.get_name() request_mock = unittest.mock.Mock() request_mock.headers = { "reflex-client-token": token, - "reflex-event-handler": f"{state_name}.bg_upload", + "reflex-event-handler": f"{state.get_full_name()}.bg_upload", } file_mock = unittest.mock.Mock(filename="image1.jpg") fn = upload(app) @@ -850,7 +846,7 @@ async def test_upload_file_background(state, tmp_path, token): await fn(request_mock, [file_mock]) assert ( err.value.args[0] - == f"@rx.background is not supported for upload handler `{state_name}.bg_upload`." + == f"@rx.background is not supported for upload handler `{state.get_full_name()}.bg_upload`." ) if isinstance(app.state_manager, StateManagerRedis): @@ -932,9 +928,10 @@ async def test_dynamic_route_var_route_change_completed_on_load( } assert constants.ROUTER in app.state()._computed_var_dependencies + substate_token = f"{token}_{DynamicState.get_full_name()}" sid = "mock_sid" client_ip = "127.0.0.1" - state = await app.state_manager.get_state(token) + state = await app.state_manager.get_state(substate_token) assert state.dynamic == "" exp_vals = ["foo", "foobar", "baz"] @@ -1004,7 +1001,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( ) if isinstance(app.state_manager, StateManagerRedis): # When redis is used, the state is not updated until the processing is complete - state = await app.state_manager.get_state(token) + state = await app.state_manager.get_state(substate_token) assert state.dynamic == prev_exp_val # complete the processing @@ -1012,7 +1009,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( await process_coro.__anext__() # check that router data was written to the state_manager store - state = await app.state_manager.get_state(token) + state = await app.state_manager.get_state(substate_token) assert state.dynamic == exp_val process_coro = process( @@ -1087,7 +1084,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( await process_coro.__anext__() prev_exp_val = exp_val - state = await app.state_manager.get_state(token) + state = await app.state_manager.get_state(substate_token) assert state.loaded == len(exp_vals) assert state.counter == len(exp_vals) # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}") @@ -1124,7 +1121,7 @@ async def test_process_events(mocker, token: str): async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): pass - assert (await app.state_manager.get_state(token)).value == 5 + assert (await app.state_manager.get_state(event.substate_token)).value == 5 assert app.postprocess.call_count == 6 if isinstance(app.state_manager, StateManagerRedis): diff --git a/tests/test_state.py b/tests/test_state.py index 423886655..dbcc1afce 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -38,8 +38,8 @@ from reflex.vars import BaseVar, ComputedVar from .states import GenState CI = bool(os.environ.get("CI", False)) -LOCK_EXPIRATION = 2000 if CI else 100 -LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2 +LOCK_EXPIRATION = 2000 if CI else 300 +LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.4 formatted_router = { @@ -1432,15 +1432,32 @@ def state_manager(request) -> Generator[StateManager, None, None]: asyncio.get_event_loop().run_until_complete(state_manager.close()) +@pytest.fixture() +def substate_token(state_manager, token): + """A token + substate name for looking up in state manager. + + Args: + state_manager: A state manager instance. + token: A token. + + Returns: + Token concatenated with the state_manager's state full_name. + """ + return f"{token}_{state_manager.state.get_full_name()}" + + @pytest.mark.asyncio -async def test_state_manager_modify_state(state_manager: StateManager, token: str): +async def test_state_manager_modify_state( + state_manager: StateManager, token: str, substate_token: str +): """Test that the state manager can modify a state exclusively. Args: state_manager: A state manager instance. token: A token. + substate_token: A token + substate name for looking up in state manager. """ - async with state_manager.modify_state(token): + async with state_manager.modify_state(substate_token): if isinstance(state_manager, StateManagerRedis): assert await state_manager.redis.get(f"{token}_lock") elif isinstance(state_manager, StateManagerMemory): @@ -1461,21 +1478,24 @@ async def test_state_manager_modify_state(state_manager: StateManager, token: st @pytest.mark.asyncio -async def test_state_manager_contend(state_manager: StateManager, token: str): +async def test_state_manager_contend( + state_manager: StateManager, token: str, substate_token: str +): """Multiple coroutines attempting to access the same state. Args: state_manager: A state manager instance. token: A token. + substate_token: A token + substate name for looking up in state manager. """ n_coroutines = 10 exp_num1 = 10 - async with state_manager.modify_state(token) as state: + async with state_manager.modify_state(substate_token) as state: state.num1 = 0 async def _coro(): - async with state_manager.modify_state(token) as state: + async with state_manager.modify_state(substate_token) as state: await asyncio.sleep(0.01) state.num1 += 1 @@ -1484,7 +1504,7 @@ async def test_state_manager_contend(state_manager: StateManager, token: str): for f in asyncio.as_completed(tasks): await f - assert (await state_manager.get_state(token)).num1 == exp_num1 + assert (await state_manager.get_state(substate_token)).num1 == exp_num1 if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None @@ -1510,33 +1530,51 @@ def state_manager_redis() -> Generator[StateManager, None, None]: asyncio.get_event_loop().run_until_complete(state_manager.close()) +@pytest.fixture() +def substate_token_redis(state_manager_redis, token): + """A token + substate name for looking up in state manager. + + Args: + state_manager_redis: A state manager instance. + token: A token. + + Returns: + Token concatenated with the state_manager's state full_name. + """ + return f"{token}_{state_manager_redis.state.get_full_name()}" + + @pytest.mark.asyncio -async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str): +async def test_state_manager_lock_expire( + state_manager_redis: StateManager, token: str, substate_token_redis: str +): """Test that the state manager lock expires and raises exception exiting context. Args: state_manager_redis: A state manager instance. token: A token. + substate_token_redis: A token + substate name for looking up in state manager. """ state_manager_redis.lock_expiration = LOCK_EXPIRATION - async with state_manager_redis.modify_state(token): + async with state_manager_redis.modify_state(substate_token_redis): await asyncio.sleep(0.01) with pytest.raises(LockExpiredError): - async with state_manager_redis.modify_state(token): + async with state_manager_redis.modify_state(substate_token_redis): await asyncio.sleep(LOCK_EXPIRE_SLEEP) @pytest.mark.asyncio async def test_state_manager_lock_expire_contend( - state_manager_redis: StateManager, token: str + state_manager_redis: StateManager, token: str, substate_token_redis: str ): """Test that the state manager lock expires and queued waiters proceed. Args: state_manager_redis: A state manager instance. token: A token. + substate_token_redis: A token + substate name for looking up in state manager. """ exp_num1 = 4252 unexp_num1 = 666 @@ -1546,7 +1584,7 @@ async def test_state_manager_lock_expire_contend( order = [] async def _coro_blocker(): - async with state_manager_redis.modify_state(token) as state: + async with state_manager_redis.modify_state(substate_token_redis) as state: order.append("blocker") await asyncio.sleep(LOCK_EXPIRE_SLEEP) state.num1 = unexp_num1 @@ -1554,7 +1592,7 @@ async def test_state_manager_lock_expire_contend( async def _coro_waiter(): while "blocker" not in order: await asyncio.sleep(0.005) - async with state_manager_redis.modify_state(token) as state: + async with state_manager_redis.modify_state(substate_token_redis) as state: order.append("waiter") assert state.num1 != unexp_num1 state.num1 = exp_num1 @@ -1568,7 +1606,7 @@ async def test_state_manager_lock_expire_contend( await tasks[1] assert order == ["blocker", "waiter"] - assert (await state_manager_redis.get_state(token)).num1 == exp_num1 + assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 @pytest.fixture(scope="function") @@ -1643,7 +1681,8 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): assert sp.value2 == 42 # Get the state from the state manager directly and check that the value is updated - gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token()) + gc_token = f"{grandchild_state.get_token()}_{grandchild_state.get_full_name()}" + gotten_state = await mock_app.state_manager.get_state(gc_token) if isinstance(mock_app.state_manager, StateManagerMemory): # For in-process store, only one instance of the state exists assert gotten_state is parent_state @@ -1836,7 +1875,8 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "private", ] - assert (await mock_app.state_manager.get_state(token)).order == exp_order + substate_token = f"{token}_{BackgroundTaskState.get_name()}" + assert (await mock_app.state_manager.get_state(substate_token)).order == exp_order assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit @@ -1913,7 +1953,8 @@ async def test_background_task_reset(mock_app: rx.App, token: str): await task assert not mock_app.background_tasks - assert (await mock_app.state_manager.get_state(token)).order == [ + substate_token = f"{token}_{BackgroundTaskState.get_name()}" + assert (await mock_app.state_manager.get_state(substate_token)).order == [ "reset", ]