From 351611ca255af54c45a718e57cb186c87d8bc3f8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 21 Sep 2023 11:42:11 -0700 Subject: [PATCH] rx.background and StateManager.modify_state provides safe exclusive access to state (#1676) --- .github/workflows/integration_app_harness.yml | 17 + .github/workflows/unit_tests.yml | 20 + .pre-commit-config.yaml | 12 +- integration/test_background_task.py | 214 ++++++ integration/test_client_storage.py | 37 +- integration/test_dynamic_routes.py | 69 +- integration/test_event_chain.py | 132 +++- integration/test_form_submit.py | 50 +- integration/test_input.py | 30 +- integration/test_server_side_event.py | 12 +- integration/test_upload.py | 25 +- integration/test_var_operations.py | 12 +- reflex/__init__.py | 1 + reflex/app.py | 229 +++++-- reflex/app.pyi | 8 + reflex/constants.py | 2 + reflex/event.py | 86 ++- reflex/state.py | 639 +++++++++++++++--- reflex/testing.py | 152 ++++- reflex/utils/exceptions.py | 8 + reflex/utils/prerequisites.py | 8 +- tests/conftest.py | 403 +---------- tests/states/__init__.py | 30 + tests/states/mutation.py | 172 +++++ tests/states/upload.py | 175 +++++ tests/test_app.py | 358 ++++++---- tests/test_state.py | 428 +++++++++++- 27 files changed, 2517 insertions(+), 812 deletions(-) create mode 100644 integration/test_background_task.py create mode 100644 tests/states/__init__.py create mode 100644 tests/states/mutation.py create mode 100644 tests/states/upload.py diff --git a/.github/workflows/integration_app_harness.yml b/.github/workflows/integration_app_harness.yml index 11108a8fb..55a042bb8 100644 --- a/.github/workflows/integration_app_harness.yml +++ b/.github/workflows/integration_app_harness.yml @@ -15,7 +15,23 @@ permissions: jobs: integration-app-harness: + strategy: + matrix: + state_manager: [ "redis", "memory" ] runs-on: ubuntu-latest + services: + # Label used to access the service container + redis: + image: ${{ matrix.state_manager == 'redis' && 'redis' || '' }} + # Set health checks to wait until redis has started + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + # Maps port 6379 on service container to the host + - 6379:6379 steps: - uses: actions/checkout@v4 - uses: ./.github/actions/setup_build_env @@ -27,6 +43,7 @@ jobs: - name: Run app harness tests env: SCREENSHOT_DIR: /tmp/screenshots + REDIS_URL: ${{ matrix.state_manager == 'redis' && 'localhost:6379' || '' }} run: | poetry run pytest integration - uses: actions/upload-artifact@v3 diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index cb852cdf2..9945507fa 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -40,6 +40,20 @@ jobs: - os: windows-latest python-version: "3.8.10" runs-on: ${{ matrix.os }} + # Service containers to run with `runner-job` + services: + # Label used to access the service container + redis: + image: ${{ matrix.os == 'ubuntu-latest' && 'redis' || '' }} + # Set health checks to wait until redis has started + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + # Maps port 6379 on service container to the host + - 6379:6379 steps: - uses: actions/checkout@v4 - uses: ./.github/actions/setup_build_env @@ -51,4 +65,10 @@ jobs: run: | export PYTHONUNBUFFERED=1 poetry run pytest tests --cov --no-cov-on-fail --cov-report= + - name: Run unit tests w/ redis + if: ${{ matrix.os == 'ubuntu-latest' }} + run: | + export PYTHONUNBUFFERED=1 + export REDIS_URL=localhost:6379 + poetry run pytest tests --cov --no-cov-on-fail --cov-report= - run: poetry run coverage html diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 211cb4d5e..a5fcf4b6b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,10 @@ repos: + - repo: https://github.com/psf/black + rev: 22.10.0 + hooks: + - id: black + args: [integration, reflex, tests] + - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.244 hooks: @@ -17,9 +23,3 @@ repos: hooks: - id: darglint exclude: '^reflex/reflex.py' - - - repo: https://github.com/psf/black - rev: 22.10.0 - hooks: - - id: black - args: [integration, reflex, tests] diff --git a/integration/test_background_task.py b/integration/test_background_task.py new file mode 100644 index 000000000..349f3b555 --- /dev/null +++ b/integration/test_background_task.py @@ -0,0 +1,214 @@ +"""Test @rx.background task functionality.""" + +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By + +from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver + + +def BackgroundTask(): + """Test that background tasks work as expected.""" + import asyncio + + import reflex as rx + + class State(rx.State): + counter: int = 0 + _task_id: int = 0 + iterations: int = 10 + + @rx.background + async def handle_event(self): + async with self: + self._task_id += 1 + for _ix in range(int(self.iterations)): + async with self: + self.counter += 1 + await asyncio.sleep(0.005) + + @rx.background + async def handle_event_yield_only(self): + async with self: + self._task_id += 1 + for ix in range(int(self.iterations)): + if ix % 2 == 0: + yield State.increment_arbitrary(1) # type: ignore + else: + yield State.increment() # type: ignore + await asyncio.sleep(0.005) + + def increment(self): + self.counter += 1 + + @rx.background + async def increment_arbitrary(self, amount: int): + async with self: + self.counter += int(amount) + + def reset_counter(self): + self.counter = 0 + + async def blocking_pause(self): + await asyncio.sleep(0.02) + + @rx.background + async def non_blocking_pause(self): + await asyncio.sleep(0.02) + + @rx.cached_var + def token(self) -> str: + return self.get_token() + + def index() -> rx.Component: + return rx.vstack( + rx.input(id="token", value=State.token, is_read_only=True), + rx.heading(State.counter, id="counter"), + rx.input( + id="iterations", + placeholder="Iterations", + value=State.iterations.to_string(), # type: ignore + on_change=State.set_iterations, # type: ignore + ), + rx.button( + "Delayed Increment", + on_click=State.handle_event, + id="delayed-increment", + ), + rx.button( + "Yield Increment", + on_click=State.handle_event_yield_only, + id="yield-increment", + ), + rx.button("Increment 1", on_click=State.increment, id="increment"), + rx.button( + "Blocking Pause", + on_click=State.blocking_pause, + id="blocking-pause", + ), + rx.button( + "Non-Blocking Pause", + on_click=State.non_blocking_pause, + id="non-blocking-pause", + ), + rx.button("Reset", on_click=State.reset_counter, id="reset"), + ) + + app = rx.App(state=State) + app.add_page(index) + app.compile() + + +@pytest.fixture(scope="session") +def background_task( + tmp_path_factory, +) -> Generator[AppHarness, None, None]: + """Start BackgroundTask app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp(f"background_task"), + app_source=BackgroundTask, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the background_task app. + + Args: + background_task: harness for BackgroundTask app + + Yields: + WebDriver instance. + """ + assert background_task.app_instance is not None, "app is not running" + driver = background_task.frontend() + try: + yield driver + finally: + driver.quit() + + +@pytest.fixture() +def token(background_task: AppHarness, driver: WebDriver) -> str: + """Get a function that returns the active token. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + + Returns: + The token for the connected client + """ + assert background_task.app_instance is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2) + assert token is not None + + return token + + +def test_background_task( + background_task: AppHarness, + driver: WebDriver, + token: str, +): + """Test that background tasks work as expected. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + delayed_increment_button = driver.find_element(By.ID, "delayed-increment") + yield_increment_button = driver.find_element(By.ID, "yield-increment") + increment_button = driver.find_element(By.ID, "increment") + blocking_pause_button = driver.find_element(By.ID, "blocking-pause") + non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause") + driver.find_element(By.ID, "reset") + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + + # get a reference to the iterations input + iterations_input = driver.find_element(By.ID, "iterations") + + # kick off background tasks + iterations_input.clear() + iterations_input.send_keys("50") + delayed_increment_button.click() + blocking_pause_button.click() + delayed_increment_button.click() + for _ in range(10): + increment_button.click() + blocking_pause_button.click() + delayed_increment_button.click() + delayed_increment_button.click() + yield_increment_button.click() + non_blocking_pause_button.click() + yield_increment_button.click() + blocking_pause_button.click() + yield_increment_button.click() + for _ in range(10): + increment_button.click() + yield_increment_button.click() + blocking_pause_button.click() + assert background_task._poll_for(lambda: counter.text == "420", timeout=40) + # all tasks should have exited and cleaned up + assert background_task._poll_for( + lambda: not background_task.app_instance.background_tasks # type: ignore + ) diff --git a/integration/test_client_storage.py b/integration/test_client_storage.py index 269725ed7..2c9d9dc14 100644 --- a/integration/test_client_storage.py +++ b/integration/test_client_storage.py @@ -133,7 +133,6 @@ def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]: assert client_side.app_instance is not None, "app is not running" driver = client_side.frontend() try: - assert client_side.poll_for_clients() yield driver finally: driver.quit() @@ -168,7 +167,20 @@ def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]: driver.delete_all_cookies() -def test_client_side_state( +def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]: + """Get a map of cookie names to cookie info. + + Args: + driver: WebDriver instance. + + Returns: + A map of cookie names to cookie info. + """ + return {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()} + + +@pytest.mark.asyncio +async def test_client_side_state( client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage ): """Test client side state. @@ -187,8 +199,6 @@ def test_client_side_state( token = client_side.poll_for_value(token_input) assert token is not None - backend_state = client_side.app_instance.state_manager.states[token] - # get a reference to the cookie manipulation form state_var_input = driver.find_element(By.ID, "state_var") input_value_input = driver.find_element(By.ID, "input_value") @@ -274,7 +284,7 @@ def test_client_side_state( input_value_input.send_keys("l1s value") set_sub_sub_state_button.click() - cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()} + cookies = cookie_info_map(driver) assert cookies.pop("client_side_state.client_side_sub_state.c1") == { "domain": "localhost", "httpOnly": False, @@ -338,8 +348,10 @@ def test_client_side_state( state_var_input.send_keys("c3") input_value_input.send_keys("c3 value") set_sub_state_button.click() - cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()} - c3_cookie = cookies["client_side_state.client_side_sub_state.c3"] + AppHarness._poll_for( + lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver) + ) + c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"] assert c3_cookie.pop("expiry") is not None assert c3_cookie == { "domain": "localhost", @@ -351,9 +363,7 @@ def test_client_side_state( "value": "c3%20value", } time.sleep(2) # wait for c3 to expire - assert "client_side_state.client_side_sub_state.c3" not in { - cookie_info["name"] for cookie_info in driver.get_cookies() - } + assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver) local_storage_items = local_storage.items() local_storage_items.pop("chakra-ui-color-mode", None) @@ -426,7 +436,8 @@ def test_client_side_state( assert l1s.text == "l1s value" # reset the backend state to force refresh from client storage - backend_state.reset() + async with client_side.modify_state(token) as state: + state.reset() driver.refresh() # wait for the backend connection to send the token (again) @@ -465,9 +476,7 @@ def test_client_side_state( assert l1s.text == "l1s value" # make sure c5 cookie shows up on the `/foo` route - cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()} - - assert cookies["client_side_state.client_side_sub_state.c5"] == { + assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == { "domain": "localhost", "httpOnly": False, "name": "client_side_state.client_side_sub_state.c5", diff --git a/integration/test_dynamic_routes.py b/integration/test_dynamic_routes.py index 6adb41237..ca94fa3ba 100644 --- a/integration/test_dynamic_routes.py +++ b/integration/test_dynamic_routes.py @@ -1,11 +1,10 @@ """Integration tests for dynamic route page behavior.""" -from typing import Callable, Generator, Type +from typing import Callable, Coroutine, Generator, Type from urllib.parse import urlsplit import pytest from selenium.webdriver.common.by import By -from reflex import State from reflex.testing import AppHarness, AppHarnessProd, WebDriver from .utils import poll_for_navigation @@ -100,22 +99,21 @@ def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]: assert dynamic_route.app_instance is not None, "app is not running" driver = dynamic_route.frontend() try: - assert dynamic_route.poll_for_clients() yield driver finally: driver.quit() @pytest.fixture() -def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State: - """Get the backend state. +def token(dynamic_route: AppHarness, driver: WebDriver) -> str: + """Get the token associated with backend state. Args: dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. Returns: - The backend state associated with the token visible in the driver browser. + The token visible in the driver browser. """ assert dynamic_route.app_instance is not None token_input = driver.find_element(By.ID, "token") @@ -125,43 +123,49 @@ def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State: token = dynamic_route.poll_for_value(token_input) assert token is not None - # look up the backend state from the state manager - return dynamic_route.app_instance.state_manager.states[token] + return token @pytest.fixture() def poll_for_order( - dynamic_route: AppHarness, backend_state: State -) -> Callable[[list[str]], None]: + dynamic_route: AppHarness, token: str +) -> Callable[[list[str]], Coroutine[None, None, None]]: """Poll for the order list to match the expected order. Args: dynamic_route: harness for DynamicRoute app. - backend_state: The backend state associated with the token visible in the driver browser. + token: The token visible in the driver browser. Returns: - A function that polls for the order list to match the expected order. + An async function that polls for the order list to match the expected order. """ - def _poll_for_order(exp_order: list[str]): - dynamic_route._poll_for(lambda: backend_state.order == exp_order) - assert backend_state.order == exp_order + async def _poll_for_order(exp_order: list[str]): + async def _backend_state(): + return await dynamic_route.get_state(token) + + async def _check(): + return (await _backend_state()).order == exp_order + + await AppHarness._poll_for_async(_check) + assert (await _backend_state()).order == exp_order return _poll_for_order -def test_on_load_navigate( +@pytest.mark.asyncio +async def test_on_load_navigate( dynamic_route: AppHarness, driver: WebDriver, - backend_state: State, - poll_for_order: Callable[[list[str]], None], + token: str, + poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links to navigate between dynamic pages with on_load event. Args: dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. - backend_state: The backend state associated with the token visible in the driver browser. + token: The token visible in the driver browser. poll_for_order: function that polls for the order list to match the expected order. """ assert dynamic_route.app_instance is not None @@ -184,7 +188,7 @@ def test_on_load_navigate( assert page_id_input assert dynamic_route.poll_for_value(page_id_input) == str(ix) - poll_for_order(exp_order) + await poll_for_order(exp_order) # manually load the next page to trigger client side routing in prod mode if is_prod: @@ -192,14 +196,14 @@ def test_on_load_navigate( exp_order += ["/page/[page_id]-10"] with poll_for_navigation(driver): driver.get(f"{dynamic_route.frontend_url}/page/10/") - poll_for_order(exp_order) + await poll_for_order(exp_order) # make sure internal nav still hydrates after redirect exp_order += ["/page/[page_id]-11"] link = driver.find_element(By.ID, "link_page_next") with poll_for_navigation(driver): link.click() - poll_for_order(exp_order) + await poll_for_order(exp_order) # load same page with a query param and make sure it passes through if is_prod: @@ -207,14 +211,14 @@ def test_on_load_navigate( exp_order += ["/page/[page_id]-11"] with poll_for_navigation(driver): driver.get(f"{driver.current_url}?foo=bar") - poll_for_order(exp_order) - assert backend_state.get_query_params()["foo"] == "bar" + await poll_for_order(exp_order) + assert (await dynamic_route.get_state(token)).get_query_params()["foo"] == "bar" # hit a 404 and ensure we still hydrate exp_order += ["/404-no page id"] with poll_for_navigation(driver): driver.get(f"{dynamic_route.frontend_url}/missing") - poll_for_order(exp_order) + await poll_for_order(exp_order) # browser nav should still trigger hydration if is_prod: @@ -222,14 +226,14 @@ def test_on_load_navigate( exp_order += ["/page/[page_id]-11"] with poll_for_navigation(driver): driver.back() - poll_for_order(exp_order) + await poll_for_order(exp_order) # next/link to a 404 and ensure we still hydrate exp_order += ["/404-no page id"] link = driver.find_element(By.ID, "link_missing") with poll_for_navigation(driver): link.click() - poll_for_order(exp_order) + await poll_for_order(exp_order) # hit a page that redirects back to dynamic page if is_prod: @@ -237,15 +241,16 @@ def test_on_load_navigate( exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"] with poll_for_navigation(driver): driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar") - poll_for_order(exp_order) + await poll_for_order(exp_order) # should have redirected back to page 0 assert urlsplit(driver.current_url).path == "/page/0/" -def test_on_load_navigate_non_dynamic( +@pytest.mark.asyncio +async def test_on_load_navigate_non_dynamic( dynamic_route: AppHarness, driver: WebDriver, - poll_for_order: Callable[[list[str]], None], + poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links to navigate between static pages with on_load event. @@ -261,7 +266,7 @@ def test_on_load_navigate_non_dynamic( with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path == "/static/x/" - poll_for_order(["/static/x-no page id"]) + await poll_for_order(["/static/x-no page id"]) # go back to the index and navigate back to the static route link = driver.find_element(By.ID, "link_index") @@ -273,4 +278,4 @@ def test_on_load_navigate_non_dynamic( with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path == "/static/x/" - poll_for_order(["/static/x-no page id", "/static/x-no page id"]) + await poll_for_order(["/static/x-no page id", "/static/x-no page id"]) diff --git a/integration/test_event_chain.py b/integration/test_event_chain.py index d562bc340..454cf51a6 100644 --- a/integration/test_event_chain.py +++ b/integration/test_event_chain.py @@ -1,18 +1,20 @@ """Ensure that Event Chains are properly queued and handled between frontend and backend.""" -import time from typing import Generator import pytest from selenium.webdriver.common.by import By -from reflex.testing import AppHarness +from reflex.testing import AppHarness, WebDriver MANY_EVENTS = 50 def EventChain(): """App with chained event handlers.""" + import asyncio + import time + import reflex as rx # repeated here since the outer global isn't exported into the App module @@ -20,6 +22,7 @@ def EventChain(): class State(rx.State): event_order: list[str] = [] + interim_value: str = "" @rx.var def token(self) -> str: @@ -111,12 +114,25 @@ def EventChain(): self.event_order.append("click_return_dict_type") return State.event_arg_repr_type({"a": 1}) # type: ignore + async def click_yield_interim_value_async(self): + self.interim_value = "interim" + yield + await asyncio.sleep(0.5) + self.interim_value = "final" + + def click_yield_interim_value(self): + self.interim_value = "interim" + yield + time.sleep(0.5) + self.interim_value = "final" + app = rx.App(state=State) @app.add_page def index(): return rx.fragment( rx.input(value=State.token, readonly=True, id="token"), + rx.input(value=State.interim_value, readonly=True, id="interim_value"), rx.button( "Return Event", id="return_event", @@ -172,6 +188,16 @@ def EventChain(): id="return_dict_type", on_click=State.click_return_dict_type, ), + rx.button( + "Click Yield Interim Value (Async)", + id="click_yield_interim_value_async", + on_click=State.click_yield_interim_value_async, + ), + rx.button( + "Click Yield Interim Value", + id="click_yield_interim_value", + on_click=State.click_yield_interim_value, + ), ) def on_load_return_chain(): @@ -237,7 +263,7 @@ def event_chain(tmp_path_factory) -> Generator[AppHarness, None, None]: @pytest.fixture -def driver(event_chain: AppHarness): +def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]: """Get an instance of the browser open to the event_chain app. Args: @@ -249,7 +275,6 @@ def driver(event_chain: AppHarness): assert event_chain.app_instance is not None, "app is not running" driver = event_chain.frontend() try: - assert event_chain.poll_for_clients() yield driver finally: driver.quit() @@ -335,7 +360,13 @@ def driver(event_chain: AppHarness): ), ], ) -def test_event_chain_click(event_chain, driver, button_id, exp_event_order): +@pytest.mark.asyncio +async def test_event_chain_click( + event_chain: AppHarness, + driver: WebDriver, + button_id: str, + exp_event_order: list[str], +): """Click the button, assert that the events are handled in the correct order. Args: @@ -350,17 +381,18 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order): assert btn token = event_chain.poll_for_value(token_input) + assert token is not None btn.click() - if "redirect" in button_id: - # wait a bit longer if we're redirecting - time.sleep(1) - if "many_events" in button_id: - # wait a bit longer if we have loads of events - time.sleep(1) - time.sleep(0.5) - backend_state = event_chain.app_instance.state_manager.states[token] - assert backend_state.event_order == exp_event_order + + async def _has_all_events(): + return len((await event_chain.get_state(token)).event_order) == len( + exp_event_order + ) + + await AppHarness._poll_for_async(_has_all_events) + event_order = (await event_chain.get_state(token)).event_order + assert event_order == exp_event_order @pytest.mark.parametrize( @@ -386,7 +418,13 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order): ), ], ) -def test_event_chain_on_load(event_chain, driver, uri, exp_event_order): +@pytest.mark.asyncio +async def test_event_chain_on_load( + event_chain: AppHarness, + driver: WebDriver, + uri: str, + exp_event_order: list[str], +): """Load the URI, assert that the events are handled in the correct order. Args: @@ -395,16 +433,23 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order): uri: the page to load exp_event_order: the expected events recorded in the State """ + assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url + uri) token_input = driver.find_element(By.ID, "token") assert token_input token = event_chain.poll_for_value(token_input) + assert token is not None - time.sleep(0.5) - backend_state = event_chain.app_instance.state_manager.states[token] - assert backend_state.is_hydrated is True + async def _has_all_events(): + return len((await event_chain.get_state(token)).event_order) == len( + exp_event_order + ) + + await AppHarness._poll_for_async(_has_all_events) + backend_state = await event_chain.get_state(token) assert backend_state.event_order == exp_event_order + assert backend_state.is_hydrated is True @pytest.mark.parametrize( @@ -444,7 +489,13 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order): ), ], ) -def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order): +@pytest.mark.asyncio +async def test_event_chain_on_mount( + event_chain: AppHarness, + driver: WebDriver, + uri: str, + exp_event_order: list[str], +): """Load the URI, assert that the events are handled in the correct order. These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode @@ -458,16 +509,53 @@ def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order): uri: the page to load exp_event_order: the expected events recorded in the State """ + assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url + uri) token_input = driver.find_element(By.ID, "token") assert token_input token = event_chain.poll_for_value(token_input) + assert token is not None unmount_button = driver.find_element(By.ID, "unmount") assert unmount_button unmount_button.click() - time.sleep(1) - backend_state = event_chain.app_instance.state_manager.states[token] - assert backend_state.event_order == exp_event_order + async def _has_all_events(): + return len((await event_chain.get_state(token)).event_order) == len( + exp_event_order + ) + + await AppHarness._poll_for_async(_has_all_events) + event_order = (await event_chain.get_state(token)).event_order + assert event_order == exp_event_order + + +@pytest.mark.parametrize( + ("button_id",), + [ + ("click_yield_interim_value_async",), + ("click_yield_interim_value",), + ], +) +def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str): + """Click the button, assert that the interim value is set, then final value is set. + + Args: + event_chain: AppHarness for the event_chain app + driver: selenium WebDriver open to the app + button_id: the ID of the button to click + """ + token_input = driver.find_element(By.ID, "token") + interim_value_input = driver.find_element(By.ID, "interim_value") + assert event_chain.poll_for_value(token_input) + + btn = driver.find_element(By.ID, button_id) + btn.click() + assert ( + event_chain.poll_for_value(interim_value_input, exp_not_equal="") == "interim" + ) + assert ( + event_chain.poll_for_value(interim_value_input, exp_not_equal="interim") + == "final" + ) diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index 2cddfb32c..74f60a35c 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -19,11 +19,16 @@ def FormSubmit(): def form_submit(self, form_data: dict): self.form_data = form_data + @rx.var + def token(self) -> str: + return self.get_token() + app = rx.App(state=FormState) @app.add_page def index(): return rx.vstack( + rx.input(value=FormState.token, is_read_only=True, id="token"), rx.form( rx.vstack( rx.input(id="name_input"), @@ -82,13 +87,13 @@ def driver(form_submit: AppHarness): """ driver = form_submit.frontend() try: - assert form_submit.poll_for_clients() yield driver finally: driver.quit() -def test_submit(driver, form_submit: AppHarness): +@pytest.mark.asyncio +async def test_submit(driver, form_submit: AppHarness): """Fill a form with various different output, submit it to backend and verify the output. @@ -97,7 +102,14 @@ def test_submit(driver, form_submit: AppHarness): form_submit: harness for FormSubmit app """ assert form_submit.app_instance is not None, "app is not running" - _, backend_state = list(form_submit.app_instance.state_manager.states.items())[0] + + # get a reference to the connected client + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = form_submit.poll_for_value(token_input) + assert token name_input = driver.find_element(By.ID, "name_input") name_input.send_keys("foo") @@ -132,19 +144,21 @@ def test_submit(driver, form_submit: AppHarness): submit_input = driver.find_element(By.CLASS_NAME, "chakra-button") submit_input.click() - # wait for the form data to arrive at the backend - AppHarness._poll_for( - lambda: backend_state.form_data != {}, - ) + async def get_form_data(): + return (await form_submit.get_state(token)).form_data - assert backend_state.form_data["name_input"] == "foo" - assert backend_state.form_data["pin_input"] == pin_values - assert backend_state.form_data["number_input"] == "-3" - assert backend_state.form_data["bool_input"] is True - assert backend_state.form_data["bool_input2"] is True - assert backend_state.form_data["slider_input"] == "50" - assert backend_state.form_data["range_input"] == ["25", "75"] - assert backend_state.form_data["radio_input"] == "option2" - assert backend_state.form_data["select_input"] == "option1" - assert backend_state.form_data["text_area_input"] == "Some\nText" - assert backend_state.form_data["debounce_input"] == "bar baz" + # wait for the form data to arrive at the backend + form_data = await AppHarness._poll_for_async(get_form_data) + assert isinstance(form_data, dict) + + assert form_data["name_input"] == "foo" + assert form_data["pin_input"] == pin_values + assert form_data["number_input"] == "-3" + assert form_data["bool_input"] is True + assert form_data["bool_input2"] is True + assert form_data["slider_input"] == "50" + assert form_data["range_input"] == ["25", "75"] + assert form_data["radio_input"] == "option2" + assert form_data["select_input"] == "option1" + assert form_data["text_area_input"] == "Some\nText" + assert form_data["debounce_input"] == "bar baz" diff --git a/integration/test_input.py b/integration/test_input.py index 9085e8bba..ccdc83efa 100644 --- a/integration/test_input.py +++ b/integration/test_input.py @@ -16,11 +16,16 @@ def FullyControlledInput(): class State(rx.State): text: str = "initial" + @rx.var + def token(self) -> str: + return self.get_token() + app = rx.App(state=State) @app.add_page def index(): return rx.fragment( + rx.input(value=State.token, is_read_only=True, id="token"), rx.input( id="debounce_input_input", on_change=State.set_text, # type: ignore @@ -62,10 +67,12 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): driver = fully_controlled_input.frontend() # get a reference to the connected client - assert len(fully_controlled_input.poll_for_clients()) == 1 - token, backend_state = list( - fully_controlled_input.app_instance.state_manager.states.items() - )[0] + token_input = driver.find_element(By.ID, "token") + assert token_input + + # wait for the backend connection to send the token + token = fully_controlled_input.poll_for_value(token_input) + assert token # find the input and wait for it to have the initial state value debounce_input = driver.find_element(By.ID, "debounce_input_input") @@ -80,14 +87,13 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("foo") time.sleep(0.5) assert debounce_input.get_attribute("value") == "ifoonitial" - assert backend_state.text == "ifoonitial" + assert (await fully_controlled_input.get_state(token)).text == "ifoonitial" assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial" # clear the input on the backend - backend_state.text = "" - fully_controlled_input.app_instance.state_manager.set_state(token, backend_state) - await fully_controlled_input.emit_state_updates() - assert backend_state.text == "" + async with fully_controlled_input.modify_state(token) as state: + state.text = "" + assert (await fully_controlled_input.get_state(token)).text == "" assert ( fully_controlled_input.poll_for_value( debounce_input, exp_not_equal="ifoonitial" @@ -99,7 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): debounce_input.send_keys("getting testing done") time.sleep(0.5) assert debounce_input.get_attribute("value") == "getting testing done" - assert backend_state.text == "getting testing done" + assert ( + await fully_controlled_input.get_state(token) + ).text == "getting testing done" assert fully_controlled_input.poll_for_value(value_input) == "getting testing done" # type into the on_change input @@ -107,7 +115,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): time.sleep(0.5) assert debounce_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state" - assert backend_state.text == "overwrite the state" + assert (await fully_controlled_input.get_state(token)).text == "overwrite the state" assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state" clear_button.click() diff --git a/integration/test_server_side_event.py b/integration/test_server_side_event.py index 87260d299..cbb93e781 100644 --- a/integration/test_server_side_event.py +++ b/integration/test_server_side_event.py @@ -33,11 +33,16 @@ def ServerSideEvent(): def set_value_return_c(self): return rx.set_value("c", "") + @rx.var + def token(self) -> str: + return self.get_token() + app = rx.App(state=SSState) @app.add_page def index(): return rx.fragment( + rx.input(id="token", value=SSState.token, is_read_only=True), rx.input(default_value="a", id="a"), rx.input(default_value="b", id="b"), rx.input(default_value="c", id="c"), @@ -106,7 +111,12 @@ def driver(server_side_event: AppHarness): assert server_side_event.app_instance is not None, "app is not running" driver = server_side_event.frontend() try: - assert server_side_event.poll_for_clients() + token_input = driver.find_element(By.ID, "token") + assert token_input + # wait for the backend connection to send the token + token = server_side_event.poll_for_value(token_input) + assert token is not None + yield driver finally: driver.quit() diff --git a/integration/test_upload.py b/integration/test_upload.py index 0e0917fca..29599afb5 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -89,13 +89,13 @@ def driver(upload_file: AppHarness): assert upload_file.app_instance is not None, "app is not running" driver = upload_file.frontend() try: - assert upload_file.poll_for_clients() yield driver finally: driver.quit() -def test_upload_file(tmp_path, upload_file: AppHarness, driver): +@pytest.mark.asyncio +async def test_upload_file(tmp_path, upload_file: AppHarness, driver): """Submit a file upload and check that it arrived on the backend. Args: @@ -124,16 +124,20 @@ def test_upload_file(tmp_path, upload_file: AppHarness, driver): upload_button.click() # look up the backend state and assert on uploaded contents - backend_state = upload_file.app_instance.state_manager.states[token] - time.sleep(0.5) - assert backend_state._file_data[exp_name] == exp_contents + async def get_file_data(): + return (await upload_file.get_state(token))._file_data + + file_data = await AppHarness._poll_for_async(get_file_data) + assert isinstance(file_data, dict) + assert file_data[exp_name] == exp_contents # check that the selected files are displayed selected_files = driver.find_element(By.ID, "selected_files") assert selected_files.text == exp_name -def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): +@pytest.mark.asyncio +async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): """Submit several file uploads and check that they arrived on the backend. Args: @@ -173,10 +177,13 @@ def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): upload_button.click() # look up the backend state and assert on uploaded contents - backend_state = upload_file.app_instance.state_manager.states[token] - time.sleep(0.5) + async def get_file_data(): + return (await upload_file.get_state(token))._file_data + + file_data = await AppHarness._poll_for_async(get_file_data) + assert isinstance(file_data, dict) for exp_name, exp_contents in exp_files.items(): - assert backend_state._file_data[exp_name] == exp_contents + assert file_data[exp_name] == exp_contents def test_clear_files(tmp_path, upload_file: AppHarness, driver): diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index 8f79927b8..2c8fb08e2 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -26,11 +26,16 @@ def VarOperations(): dict1: dict = {1: 2} dict2: dict = {3: 4} + @rx.var + def token(self) -> str: + return self.get_token() + app = rx.App(state=VarOperationState) @app.add_page def index(): return rx.vstack( + rx.input(id="token", value=VarOperationState.token, is_read_only=True), # INT INT rx.text( VarOperationState.int_var1 + VarOperationState.int_var2, @@ -544,7 +549,12 @@ def driver(var_operations: AppHarness): """ driver = var_operations.frontend() try: - assert var_operations.poll_for_clients() + token_input = driver.find_element(By.ID, "token") + assert token_input + # wait for the backend connection to send the token + token = var_operations.poll_for_value(token_input) + assert token is not None + yield driver finally: driver.quit() diff --git a/reflex/__init__.py b/reflex/__init__.py index bba295a14..f2bb4a964 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -21,6 +21,7 @@ from .constants import Env as Env from .event import EVENT_ARG as EVENT_ARG from .event import EventChain as EventChain from .event import FileUpload as upload_files +from .event import background as background from .event import clear_local_storage as clear_local_storage from .event import console_log as console_log from .event import download as download diff --git a/reflex/app.py b/reflex/app.py index 2018bf541..5c698d1a7 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +import contextlib import inspect import os from multiprocessing.pool import ThreadPool @@ -13,6 +14,7 @@ from typing import ( Dict, List, Optional, + Set, Type, Union, ) @@ -49,7 +51,13 @@ from reflex.route import ( get_route_args, verify_route_validity, ) -from reflex.state import DefaultState, State, StateManager, StateUpdate +from reflex.state import ( + DefaultState, + State, + StateManager, + StateManagerMemory, + StateUpdate, +) from reflex.utils import console, format, prerequisites, types from reflex.vars import ImportVar @@ -89,7 +97,7 @@ class App(Base): state: Type[State] = DefaultState # Class to manage many client states. - state_manager: StateManager = StateManager() + state_manager: StateManager = StateManagerMemory(state=DefaultState) # The styling to apply to each component. style: ComponentStyle = {} @@ -104,13 +112,16 @@ class App(Base): admin_dash: Optional[AdminDash] = None # The async server name space - event_namespace: Optional[AsyncNamespace] = None + event_namespace: Optional[EventNamespace] = None # A component that is present on every page. overlay_component: Optional[ Union[Component, ComponentCallable] ] = default_overlay_component + # Background tasks that are currently running + background_tasks: Set[asyncio.Task] = set() + def __init__(self, *args, **kwargs): """Initialize the app. @@ -154,7 +165,7 @@ class App(Base): self.middleware.append(HydrateMiddleware()) # Set up the state manager. - self.state_manager.setup(state=self.state) + self.state_manager = StateManager.create(state=self.state) # Set up the API. self.api = FastAPI() @@ -646,6 +657,76 @@ class App(Base): thread_pool.close() thread_pool.join() + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[State]: + """Modify the state out of band. + + Args: + token: The token to modify the state for. + + Yields: + The state to modify. + + Raises: + RuntimeError: If the app has not been initialized yet. + """ + if self.event_namespace is None: + raise RuntimeError("App has not been initialized yet.") + # Get exclusive access to the state. + async with self.state_manager.modify_state(token) as state: + # No other event handler can modify the state while in this context. + yield state + delta = state.get_delta() + if delta: + # When the state is modified reset dirty status and emit the delta to the frontend. + state._clean() + await self.event_namespace.emit_update( + update=StateUpdate(delta=delta), + sid=state.get_sid(), + ) + + def _process_background(self, state: State, event: Event) -> asyncio.Task | None: + """Process an event in the background and emit updates as they arrive. + + Args: + state: The state to process the event for. + event: The event to process. + + Returns: + Task if the event was backgroundable, otherwise None + """ + substate, handler = state._get_event_handler(event) + if not handler.is_background: + return None + + async def _coro(): + """Coroutine to process the event and emit updates inside an asyncio.Task. + + Raises: + RuntimeError: If the app has not been initialized yet. + """ + if self.event_namespace is None: + raise RuntimeError("App has not been initialized yet.") + + # Process the event. + async for update in state._process_event( + handler=handler, state=substate, payload=event.payload + ): + # Postprocess the event. + update = await self.postprocess(state, event, update) + + # Send the update to the client. + await self.event_namespace.emit_update( + update=update, + sid=state.get_sid(), + ) + + task = asyncio.create_task(_coro()) + self.background_tasks.add(task) + # Clean up task from background_tasks set when complete. + task.add_done_callback(self.background_tasks.discard) + return task + async def process( app: App, event: Event, sid: str, headers: Dict, client_ip: str @@ -662,9 +743,6 @@ async def process( Yields: The state updates after processing the event. """ - # Get the state for the session. - state = app.state_manager.get_state(event.token) - # Add request data to the state. router_data = event.router_data router_data.update( @@ -676,31 +754,35 @@ async def process( constants.RouteVar.CLIENT_IP: client_ip, } ) - # re-assign only when the value is different - if state.router_data != router_data: - # assignment will recurse into substates and force recalculation of - # dependent ComputedVar (dynamic route variables) - state.router_data = router_data + # Get the state for the session exclusively. + async with app.state_manager.modify_state(event.token) as state: + # re-assign only when the value is different + if state.router_data != router_data: + # assignment will recurse into substates and force recalculation of + # dependent ComputedVar (dynamic route variables) + state.router_data = router_data - # Preprocess the event. - update = await app.preprocess(state, event) + # Preprocess the event. + update = await app.preprocess(state, event) - # If there was an update, yield it. - if update is not None: - yield update - - # Only process the event if there is no update. - else: - # Process the event. - async for update in state._process(event): - # Postprocess the event. - update = await app.postprocess(state, event, update) - - # Yield the update. + # If there was an update, yield it. + if update is not None: yield update - # Set the state for the session. - app.state_manager.set_state(event.token, state) + # Only process the event if there is no update. + else: + if app._process_background(state, event) is not None: + # `final=True` allows the frontend send more events immediately. + yield StateUpdate(final=True) + return + + # Process the event synchronously. + async for update in state._process(event): + # Postprocess the event. + update = await app.postprocess(state, event, update) + + # Yield the update. + yield update async def ping() -> str: @@ -737,47 +819,46 @@ def upload(app: App): assert file.filename is not None file.filename = file.filename.split(":")[-1] # Get the state for the session. - state = app.state_manager.get_state(token) - # get the current session ID - sid = state.get_sid() - # get the current state(parent state/substate) - path = handler.split(".")[:-1] - current_state = state.get_substate(path) - handler_upload_param = () + async with app.state_manager.modify_state(token) as state: + # get the current session ID + sid = state.get_sid() + # get the current state(parent state/substate) + path = handler.split(".")[:-1] + current_state = state.get_substate(path) + handler_upload_param = () - # get handler function - func = getattr(current_state, handler.split(".")[-1]) + # get handler function + func = getattr(current_state, handler.split(".")[-1]) - # check if there exists any handler args with annotation, List[UploadFile] - for k, v in inspect.getfullargspec( - func.fn if isinstance(func, EventHandler) else func - ).annotations.items(): - if types.is_generic_alias(v) and types._issubclass( - v.__args__[0], UploadFile - ): - handler_upload_param = (k, v) - break + # check if there exists any handler args with annotation, List[UploadFile] + for k, v in inspect.getfullargspec( + func.fn if isinstance(func, EventHandler) else func + ).annotations.items(): + if types.is_generic_alias(v) and types._issubclass( + v.__args__[0], UploadFile + ): + handler_upload_param = (k, v) + break - if not handler_upload_param: - raise ValueError( - f"`{handler}` handler should have a parameter annotated as List[" - f"rx.UploadFile]" + if not handler_upload_param: + raise ValueError( + f"`{handler}` handler should have a parameter annotated as List[" + f"rx.UploadFile]" + ) + + event = Event( + token=token, + name=handler, + payload={handler_upload_param[0]: files}, ) - - event = Event( - token=token, - name=handler, - payload={handler_upload_param[0]: files}, - ) - async for update in state._process(event): - # Postprocess the event. - update = await app.postprocess(state, event, update) - # Send update to client - await asyncio.create_task( - app.event_namespace.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore - ) - # Set the state for the session. - app.state_manager.set_state(event.token, state) + async for update in state._process(event): + # Postprocess the event. + update = await app.postprocess(state, event, update) + # Send update to client + await app.event_namespace.emit_update( # type: ignore + update=update, + sid=sid, + ) return upload_file @@ -815,6 +896,18 @@ class EventNamespace(AsyncNamespace): """ pass + async def emit_update(self, update: StateUpdate, sid: str) -> None: + """Emit an update to the client. + + Args: + update: The state update to send. + sid: The Socket.IO session id. + """ + # Creating a task prevents the update from being blocked behind other coroutines. + await asyncio.create_task( + self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) + ) + async def on_event(self, sid, data): """Event for receiving front-end websocket events. @@ -841,10 +934,8 @@ class EventNamespace(AsyncNamespace): # Process the events. async for update in process(self.app, event, sid, headers, client_ip): - # Emit the event. - await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) - ) + # Emit the update from processing the event. + await self.emit_update(update=update, sid=sid) async def on_ping(self, sid): """Event for testing the API endpoint. diff --git a/reflex/app.pyi b/reflex/app.pyi index ebb94d0ff..74ea757d2 100644 --- a/reflex/app.pyi +++ b/reflex/app.pyi @@ -1,5 +1,6 @@ """ Generated with stubgen from mypy, then manually edited, do not regen.""" +import asyncio from fastapi import FastAPI from fastapi import UploadFile as UploadFile from reflex import constants as constants @@ -45,12 +46,14 @@ from reflex.utils import ( from socketio import ASGIApp, AsyncNamespace, AsyncServer from typing import ( Any, + AsyncContextManager, AsyncIterator, Callable, Coroutine, Dict, List, Optional, + Set, Type, Union, overload, @@ -75,6 +78,7 @@ class App(Base): admin_dash: Optional[AdminDash] event_namespace: Optional[AsyncNamespace] overlay_component: Optional[Union[Component, ComponentCallable]] + background_tasks: Set[asyncio.Task] = set() def __init__( self, *args, @@ -116,6 +120,10 @@ class App(Base): def setup_admin_dash(self) -> None: ... def get_frontend_packages(self, imports: Dict[str, str]): ... def compile(self) -> None: ... + def modify_state(self, token: str) -> AsyncContextManager[State]: ... + def _process_background( + self, state: State, event: Event + ) -> asyncio.Task | None: ... async def process( app: App, event: Event, sid: str, headers: Dict, client_ip: str diff --git a/reflex/constants.py b/reflex/constants.py index ab4942874..ae78b5428 100644 --- a/reflex/constants.py +++ b/reflex/constants.py @@ -219,6 +219,8 @@ OLD_CONFIG_FILE = f"pcconfig{PY_EXT}" PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app" # Token expiration time in seconds. TOKEN_EXPIRATION = 60 * 60 +# Maximum time in milliseconds that a state can be locked for exclusive access. +LOCK_EXPIRATION = 10000 # Testing variables. # Testing os env set by pytest when running a test case. diff --git a/reflex/event.py b/reflex/event.py index 04e37e67c..557b717dd 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -2,7 +2,17 @@ from __future__ import annotations import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) from reflex import constants from reflex.base import Base @@ -10,6 +20,9 @@ from reflex.utils import console, format from reflex.utils.types import ArgsSpec from reflex.vars import BaseVar, Var +if TYPE_CHECKING: + from reflex.state import State + class Event(Base): """An event that describes any state change in the app.""" @@ -27,6 +40,66 @@ class Event(Base): payload: Dict[str, Any] = {} +BACKGROUND_TASK_MARKER = "_reflex_background_task" + + +def background(fn): + """Decorator to mark event handler as running in the background. + + Args: + fn: The function to decorate. + + Returns: + The same function, but with a marker set. + + + Raises: + TypeError: If the function is not a coroutine function or async generator. + """ + if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn): + raise TypeError("Background task must be async function or generator.") + setattr(fn, BACKGROUND_TASK_MARKER, True) + return fn + + +def _no_chain_background_task( + state_cls: Type["State"], name: str, fn: Callable +) -> Callable: + """Protect against directly chaining a background task from another event handler. + + Args: + state_cls: The state class that the event handler is in. + name: The name of the background task. + fn: The background task coroutine function / generator. + + Returns: + A compatible coroutine function / generator that raises a runtime error. + + Raises: + TypeError: If the background task is not async. + """ + call = f"{state_cls.__name__}.{name}" + message = ( + f"Cannot directly call background task {name!r}, use " + f"`yield {call}` or `return {call}` instead." + ) + if inspect.iscoroutinefunction(fn): + + async def _no_chain_background_task_co(*args, **kwargs): + raise RuntimeError(message) + + return _no_chain_background_task_co + if inspect.isasyncgenfunction(fn): + + async def _no_chain_background_task_gen(*args, **kwargs): + yield + raise RuntimeError(message) + + return _no_chain_background_task_gen + + raise TypeError(f"{fn} is marked as a background task, but is not async.") + + class EventHandler(Base): """An event handler responds to an event to update the state.""" @@ -39,6 +112,15 @@ class EventHandler(Base): # Needed to allow serialization of Callable. frozen = True + @property + def is_background(self) -> bool: + """Whether the event handler is a background task. + + Returns: + True if the event handler is marked as a background task. + """ + return getattr(self.fn, BACKGROUND_TASK_MARKER, False) + def __call__(self, *args: Var) -> EventSpec: """Pass arguments to the handler to get an event spec. @@ -530,7 +612,7 @@ def get_handler_args(event_spec: EventSpec) -> tuple[tuple[Var, Var], ...]: def fix_events( - events: list[EventHandler | EventSpec], + events: list[EventHandler | EventSpec] | None, token: str, router_data: dict[str, Any] | None = None, ) -> list[Event]: diff --git a/reflex/state.py b/reflex/state.py index dfe9803c7..f7ef2577f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2,13 +2,15 @@ from __future__ import annotations import asyncio +import contextlib import copy import functools import inspect import json import traceback import urllib.parse -from abc import ABC +import uuid +from abc import ABC, abstractmethod from collections import defaultdict from types import FunctionType from typing import ( @@ -27,12 +29,20 @@ from typing import ( import cloudpickle import pydantic import wrapt -from redis import Redis +from redis.asyncio import Redis from reflex import constants from reflex.base import Base -from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert +from reflex.event import ( + Event, + EventHandler, + EventSpec, + _no_chain_background_task, + fix_events, + window_alert, +) from reflex.utils import format, prerequisites, types +from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.vars import BaseVar, ComputedVar, Var Delta = Dict[str, Any] @@ -152,7 +162,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Convert the event handlers to functions. for name, event_handler in state.event_handlers.items(): - fn = functools.partial(event_handler.fn, self) + if event_handler.is_background: + fn = _no_chain_background_task(type(state), name, event_handler.fn) + else: + fn = functools.partial(event_handler.fn, self) fn.__module__ = event_handler.fn.__module__ # type: ignore fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore setattr(self, name, fn) @@ -711,6 +724,37 @@ class State(Base, ABC, extra=pydantic.Extra.allow): raise ValueError(f"Invalid path: {path}") return self.substates[path[0]].get_substate(path[1:]) + def _get_event_handler( + self, event: Event + ) -> tuple[State | StateProxy, EventHandler]: + """Get the event handler for the given event. + + Args: + event: The event to get the handler for. + + + Returns: + The event handler. + + Raises: + ValueError: If the event handler or substate is not found. + """ + # Get the event handler. + path = event.name.split(".") + path, name = path[:-1], path[-1] + substate = self.get_substate(path) + if not substate: + raise ValueError( + "The value of state cannot be None when processing an event." + ) + handler = substate.event_handlers[name] + + # For background tasks, proxy the state + if handler.is_background: + substate = StateProxy(substate) + + return substate, handler + async def _process(self, event: Event) -> AsyncIterator[StateUpdate]: """Obtain event info and process event. @@ -719,44 +763,17 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Yields: The state update after processing the event. - - Raises: - ValueError: If the state value is None. """ # Get the event handler. - path = event.name.split(".") - path, name = path[:-1], path[-1] - substate = self.get_substate(path) - handler = substate.event_handlers[name] # type: ignore + substate, handler = self._get_event_handler(event) - if not substate: - raise ValueError( - "The value of state cannot be None when processing an event." - ) - - # Get the event generator. - event_iter = self._process_event( + # Run the event generator and yield state updates. + async for update in self._process_event( handler=handler, state=substate, payload=event.payload, - ) - - # Clean the state before processing the event. - self._clean() - - # Run the event generator and return state updates. - async for events, final in event_iter: - # Fix the returned events. - events = fix_events(events, event.token) # type: ignore - - # Get the delta after processing the event. - delta = self.get_delta() - - # Yield the state update. - yield StateUpdate(delta=delta, events=events, final=final) - - # Clean the state to prepare for the next event. - self._clean() + ): + yield update def _check_valid(self, handler: EventHandler, events: Any) -> Any: """Check if the events yielded are valid. They must be EventHandlers or EventSpecs. @@ -787,9 +804,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow): f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)" ) + def _as_state_update( + self, + handler: EventHandler, + events: EventSpec | list[EventSpec] | None, + final: bool, + ) -> StateUpdate: + """Convert the events to a StateUpdate. + + Fixes the events and checks for validity before converting. + + Args: + handler: The handler where the events originated from. + events: The events to queue with the update. + final: Whether the handler is done processing. + + Returns: + The valid StateUpdate containing the events and final flag. + """ + token = self.get_token() + + # Convert valid EventHandler and EventSpec into Event + fixed_events = fix_events(self._check_valid(handler, events), token) + + # Get the delta after processing the event. + delta = self.get_delta() + self._clean() + + return StateUpdate( + delta=delta, + events=fixed_events, + final=final if not handler.is_background else True, + ) + async def _process_event( - self, handler: EventHandler, state: State, payload: Dict - ) -> AsyncIterator[tuple[list[EventSpec] | None, bool]]: + self, handler: EventHandler, state: State | StateProxy, payload: Dict + ) -> AsyncIterator[StateUpdate]: """Process event. Args: @@ -798,13 +848,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow): payload: The event payload. Yields: - Tuple containing: - 0: The state update after processing the event. - 1: Whether the event is the final event. + StateUpdate object """ # Get the function to process the event. fn = functools.partial(handler.fn, state) + # Clean the state before processing the event. + self._clean() + # Wrap the function in a try/except block. try: # Handle async functions. @@ -817,30 +868,34 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Handle async generators. if inspect.isasyncgen(events): async for event in events: - yield self._check_valid(handler, event), False - yield None, True + yield self._as_state_update(handler, event, final=False) + yield self._as_state_update(handler, events=None, final=True) # Handle regular generators. elif inspect.isgenerator(events): try: while True: - yield self._check_valid(handler, next(events)), False + yield self._as_state_update(handler, next(events), final=False) except StopIteration as si: # the "return" value of the generator is not available # in the loop, we must catch StopIteration to access it if si.value is not None: - yield self._check_valid(handler, si.value), False - yield None, True + yield self._as_state_update(handler, si.value, final=False) + yield self._as_state_update(handler, events=None, final=True) # Handle regular event chains. else: - yield self._check_valid(handler, events), True + yield self._as_state_update(handler, events, final=True) # If an error occurs, throw a window alert. except Exception: error = traceback.format_exc() print(error) - yield [window_alert("An error occurred. See logs for details.")], True + yield self._as_state_update( + handler, + window_alert("An error occurred. See logs for details."), + final=True, + ) def _always_dirty_computed_vars(self) -> set[str]: """The set of ComputedVars that always need to be recalculated. @@ -989,6 +1044,160 @@ class State(Base, ABC, extra=pydantic.Extra.allow): variables = {**base_vars, **computed_vars, **substate_vars} return {k: variables[k] for k in sorted(variables)} + async def __aenter__(self) -> State: + """Enter the async context manager protocol. + + This should not be used for the State class, but exists for + type-compatibility with StateProxy. + + Raises: + TypeError: always, because async contextmanager protocol is only supported for background task. + """ + raise TypeError( + "Only background task should use `async with self` to modify state." + ) + + async def __aexit__(self, *exc_info: Any) -> None: + """Exit the async context manager protocol. + + This should not be used for the State class, but exists for + type-compatibility with StateProxy. + + Args: + exc_info: The exception info tuple. + """ + pass + + +class StateProxy(wrapt.ObjectProxy): + """Proxy of a state instance to control mutability of vars for a background task. + + Since a background task runs against a state instance without holding the + state_manager lock for the token, the reference may become stale if the same + state is modified by another event handler. + + The proxy object ensures that writes to the state are blocked unless + explicitly entering a context which refreshes the state from state_manager + and holds the lock for the token until exiting the context. After exiting + the context, a StateUpdate may be emitted to the frontend to notify the + client of the state change. + + A background task will be passed the `StateProxy` as `self`, so mutability + can be safely performed inside an `async with self` block. + + class State(rx.State): + counter: int = 0 + + @rx.background + async def bg_increment(self): + await asyncio.sleep(1) + async with self: + self.counter += 1 + """ + + def __init__(self, state_instance): + """Create a proxy for a state instance. + + Args: + state_instance: The state instance to proxy. + """ + super().__init__(state_instance) + self._self_app = getattr(prerequisites.get_app(), constants.APP_VAR) + self._self_substate_path = state_instance.get_full_name().split(".") + self._self_actx = None + self._self_mutable = False + + async def __aenter__(self) -> StateProxy: + """Enter the async context manager protocol. + + Sets mutability to True and enters the `App.modify_state` async context, + which refreshes the state from state_manager and holds the lock for the + given state token until exiting the context. + + Background tasks should avoid blocking calls while inside the context. + + Returns: + This StateProxy instance in mutable mode. + """ + self._self_actx = self._self_app.modify_state(self.__wrapped__.get_token()) + mutable_state = await self._self_actx.__aenter__() + super().__setattr__( + "__wrapped__", mutable_state.get_substate(self._self_substate_path) + ) + self._self_mutable = True + return self + + async def __aexit__(self, *exc_info: Any) -> None: + """Exit the async context manager protocol. + + Sets proxy mutability to False and persists any state changes. + + Args: + exc_info: The exception info tuple. + """ + if self._self_actx is None: + return + self._self_mutable = False + await self._self_actx.__aexit__(*exc_info) + self._self_actx = None + + def __enter__(self): + """Enter the regular context manager protocol. + + This is not supported for background tasks, and exists only to raise a more useful exception + when the StateProxy is used incorrectly. + + Raises: + TypeError: always, because only async contextmanager protocol is supported. + """ + raise TypeError("Background task must use `async with self` to modify state.") + + def __exit__(self, *exc_info: Any) -> None: + """Exit the regular context manager protocol. + + Args: + exc_info: The exception info tuple. + """ + pass + + def __getattr__(self, name: str) -> Any: + """Get the attribute from the underlying state instance. + + Args: + name: The name of the attribute. + + Returns: + The value of the attribute. + """ + value = super().__getattr__(name) + if not name.startswith("_self_") and isinstance(value, MutableProxy): + # ensure mutations to these containers are blocked unless proxy is _mutable + return ImmutableMutableProxy( + wrapped=value.__wrapped__, + state=self, # type: ignore + field_name=value._self_field_name, + ) + return value + + def __setattr__(self, name: str, value: Any) -> None: + """Set the attribute on the underlying state instance. + + If the attribute is internal, set it on the proxy instance instead. + + Args: + name: The name of the attribute. + value: The value of the attribute. + + Raises: + ImmutableStateError: If the state is not in mutable mode. + """ + if not name.startswith("_self_") and not self._self_mutable: + raise ImmutableStateError( + "Background task StateProxy is immutable outside of a context " + "manager. Use `async with self` to modify state." + ) + super().__setattr__(name, value) + class DefaultState(State): """The default empty state.""" @@ -1009,31 +1218,29 @@ class StateUpdate(Base): final: bool = True -class StateManager(Base): +class StateManager(Base, ABC): """A class to manage many client states.""" # The state class to use. - state: Type[State] = DefaultState + state: Type[State] - # The mapping of client ids to states. - states: Dict[str, State] = {} - - # The token expiration time (s). - token_expiration: int = constants.TOKEN_EXPIRATION - - # The redis client to use. - redis: Optional[Redis] = None - - def setup(self, state: Type[State]): - """Set up the state manager. + @classmethod + def create(cls, state: Type[State] = DefaultState): + """Create a new state manager. Args: state: The state class to use. - """ - self.state = state - self.redis = prerequisites.get_redis() - def get_state(self, token: str) -> State: + Returns: + The state manager (either memory or redis). + """ + redis = prerequisites.get_redis() + if redis is not None: + return StateManagerRedis(state=state, redis=redis) + return StateManagerMemory(state=state) + + @abstractmethod + async def get_state(self, token: str) -> State: """Get the state for a token. Args: @@ -1042,27 +1249,266 @@ class StateManager(Base): Returns: The state for the token. """ - if self.redis is not None: - redis_state = self.redis.get(token) - if redis_state is None: - self.set_state(token, self.state()) - return self.get_state(token) - return cloudpickle.loads(redis_state) + pass - if token not in self.states: - self.states[token] = self.state() - return self.states[token] - - def set_state(self, token: str, state: State): + @abstractmethod + async def set_state(self, token: str, state: State): """Set the state for a token. Args: token: The token to set the state for. state: The state to set. """ - if self.redis is None: - return - self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) + pass + + @abstractmethod + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[State]: + """Modify the state for a token while holding exclusive lock. + + Args: + token: The token to modify the state for. + + Yields: + The state for the token. + """ + yield self.state() + + +class StateManagerMemory(StateManager): + """A state manager that stores states in memory.""" + + # The mapping of client ids to states. + states: Dict[str, State] = {} + + # The mutex ensures the dict of mutexes is updated exclusively + _state_manager_lock = asyncio.Lock() + + # The dict of mutexes for each client + _states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({}) + + class Config: + """The Pydantic config.""" + + fields = { + "_states_locks": {"exclude": True}, + } + + async def get_state(self, token: str) -> State: + """Get the state for a token. + + Args: + token: The token to get the state for. + + Returns: + The state for the token. + """ + if token not in self.states: + self.states[token] = self.state() + return self.states[token] + + async def set_state(self, token: str, state: State): + """Set the state for a token. + + Args: + token: The token to set the state for. + state: The state to set. + """ + pass + + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[State]: + """Modify the state for a token while holding exclusive lock. + + Args: + token: The token to modify the state for. + + Yields: + The state for the token. + """ + if token not in self._states_locks: + async with self._state_manager_lock: + if token not in self._states_locks: + self._states_locks[token] = asyncio.Lock() + + async with self._states_locks[token]: + state = await self.get_state(token) + yield state + await self.set_state(token, state) + + +class StateManagerRedis(StateManager): + """A state manager that stores states in redis.""" + + # The redis client to use. + redis: Redis + + # The token expiration time (s). + token_expiration: int = constants.TOKEN_EXPIRATION + + # The maximum time to hold a lock (ms). + lock_expiration: int = constants.LOCK_EXPIRATION + + # The keyspace subscription string when redis is waiting for lock to be released + _redis_notify_keyspace_events: str = ( + "K" # Enable keyspace notifications (target a particular key) + "g" # For generic commands (DEL, EXPIRE, etc) + "x" # For expired events + "e" # For evicted events (i.e. maxmemory exceeded) + ) + + # These events indicate that a lock is no longer held + _redis_keyspace_lock_release_events: Set[bytes] = { + b"del", + b"expire", + b"expired", + b"evicted", + } + + async def get_state(self, token: str) -> State: + """Get the state for a token. + + Args: + token: The token to get the state for. + + Returns: + The state for the token. + """ + redis_state = await self.redis.get(token) + if redis_state is None: + await self.set_state(token, self.state()) + return await self.get_state(token) + return cloudpickle.loads(redis_state) + + async def set_state(self, token: str, state: State, lock_id: bytes | None = None): + """Set the state for a token. + + Args: + token: The token to set the state for. + state: The state to set. + lock_id: If provided, the lock_key must be set to this value to set the state. + + Raises: + LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. + """ + # check that we're holding the lock + if ( + lock_id is not None + and await self.redis.get(self._lock_key(token)) != lock_id + ): + raise LockExpiredError( + f"Lock expired for token {token} while processing. Consider increasing " + f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " + "or use `@rx.background` decorator for long-running tasks." + ) + await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration) + + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[State]: + """Modify the state for a token while holding exclusive lock. + + Args: + token: The token to modify the state for. + + Yields: + The state for the token. + """ + async with self._lock(token) as lock_id: + state = await self.get_state(token) + yield state + await self.set_state(token, state, lock_id) + + @staticmethod + def _lock_key(token: str) -> bytes: + """Get the redis key for a token's lock. + + Args: + token: The token to get the lock key for. + + Returns: + The redis lock key for the token. + """ + return f"{token}_lock".encode() + + async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None: + """Try to get a redis lock for a token. + + Args: + lock_key: The redis key for the lock. + lock_id: The ID of the lock. + + Returns: + True if the lock was obtained. + """ + return await self.redis.set( + lock_key, + lock_id, + px=self.lock_expiration, + nx=True, # only set if it doesn't exist + ) + + async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: + """Wait for a redis lock to be released via pubsub. + + Coroutine will not return until the lock is obtained. + + Args: + lock_key: The redis key for the lock. + lock_id: The ID of the lock. + """ + state_is_locked = False + lock_key_channel = f"__keyspace@0__:{lock_key.decode()}" + # Enable keyspace notifications for the lock key, so we know when it is available. + await self.redis.config_set( + "notify-keyspace-events", self._redis_notify_keyspace_events + ) + async with self.redis.pubsub() as pubsub: + await pubsub.psubscribe(lock_key_channel) + while not state_is_locked: + # wait for the lock to be released + while True: + if not await self.redis.exists(lock_key): + break # key was removed, try to get the lock again + message = await pubsub.get_message( + ignore_subscribe_messages=True, + timeout=self.lock_expiration / 1000.0, + ) + if message is None: + continue + if message["data"] in self._redis_keyspace_lock_release_events: + break + state_is_locked = await self._try_get_lock(lock_key, lock_id) + + @contextlib.asynccontextmanager + async def _lock(self, token: str): + """Obtain a redis lock for a token. + + Args: + token: The token to obtain a lock for. + + Yields: + The ID of the lock (to be passed to set_state). + + Raises: + LockExpiredError: If the lock has expired while processing the event. + """ + lock_key = self._lock_key(token) + lock_id = uuid.uuid4().hex.encode() + + if not await self._try_get_lock(lock_key, lock_id): + # Missed the fast-path to get lock, subscribe for lock delete/expire events + await self._wait_lock(lock_key, lock_id) + state_is_locked = True + + try: + yield lock_id + except LockExpiredError: + state_is_locked = False + raise + finally: + if state_is_locked: + # only delete our lock + await self.redis.delete(lock_key) class ClientStorageBase: @@ -1246,7 +1692,7 @@ class MutableProxy(wrapt.ObjectProxy): value, super().__getattribute__("__mutable_types__") ) and __name not in ("__wrapped__", "_self_state"): # Recursively wrap mutable attribute values retrieved through this proxy. - return MutableProxy( + return type(self)( wrapped=value, state=self._self_state, field_name=self._self_field_name, @@ -1266,7 +1712,7 @@ class MutableProxy(wrapt.ObjectProxy): value = super().__getitem__(key) if isinstance(value, self.__mutable_types__): # Recursively wrap mutable items retrieved through this proxy. - return MutableProxy( + return type(self)( wrapped=value, state=self._self_state, field_name=self._self_field_name, @@ -1332,3 +1778,34 @@ class MutableProxy(wrapt.ObjectProxy): A deepcopy of the wrapped object, unconnected to the proxy. """ return copy.deepcopy(self.__wrapped__, memo=memo) + + +class ImmutableMutableProxy(MutableProxy): + """A proxy for a mutable object that tracks changes. + + This wrapper comes from StateProxy, and will raise an exception if an attempt is made + to modify the wrapped object when the StateProxy is immutable. + """ + + def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None): + """Raise an exception when an attempt is made to modify the object. + + Intended for use with `FunctionWrapper` from the `wrapt` library. + + Args: + wrapped: The wrapped function. + instance: The instance of the wrapped function. + args: The args for the wrapped function. + kwargs: The kwargs for the wrapped function. + + Raises: + ImmutableStateError: if the StateProxy is not mutable. + """ + if not self._self_state._self_mutable: + raise ImmutableStateError( + "Background task StateProxy is immutable outside of a context " + "manager. Use `async with self` to modify state." + ) + super()._mark_dirty( + wrapped=wrapped, instance=instance, args=args, kwargs=kwargs + ) diff --git a/reflex/testing.py b/reflex/testing.py index 953f1ddbe..9486823f9 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -1,6 +1,7 @@ """reflex.testing - tools for testing reflex apps.""" from __future__ import annotations +import asyncio import contextlib import dataclasses import inspect @@ -19,14 +20,13 @@ import types from http.server import SimpleHTTPRequestHandler from typing import ( TYPE_CHECKING, - Any, + AsyncIterator, Callable, Coroutine, Optional, Type, TypeVar, Union, - cast, ) import psutil @@ -38,7 +38,7 @@ import reflex.utils.build import reflex.utils.exec import reflex.utils.prerequisites import reflex.utils.processes -from reflex.app import EventNamespace +from reflex.state import State, StateManagerMemory, StateManagerRedis try: from selenium import webdriver # pyright: ignore [reportMissingImports] @@ -109,6 +109,7 @@ class AppHarness: frontend_url: Optional[str] = None backend_thread: Optional[threading.Thread] = None backend: Optional[uvicorn.Server] = None + state_manager: Optional[StateManagerMemory | StateManagerRedis] = None _frontends: list["WebDriver"] = dataclasses.field(default_factory=list) @classmethod @@ -162,6 +163,27 @@ class AppHarness: reflex.config.get_config(reload=True) self.app_module = reflex.utils.prerequisites.get_app(reload=True) self.app_instance = self.app_module.app + if isinstance(self.app_instance.state_manager, StateManagerRedis): + # Create our own redis connection for testing. + self.state_manager = StateManagerRedis.create(self.app_instance.state) + else: + self.state_manager = self.app_instance.state_manager + + def _get_backend_shutdown_handler(self): + if self.backend is None: + raise RuntimeError("Backend was not initialized.") + + original_shutdown = self.backend.shutdown + + async def _shutdown_redis(*args, **kwargs) -> None: + # ensure redis is closed before event loop + if self.app_instance is not None and isinstance( + self.app_instance.state_manager, StateManagerRedis + ): + await self.app_instance.state_manager.redis.close() + await original_shutdown(*args, **kwargs) + + return _shutdown_redis def _start_backend(self, port=0): if self.app_instance is None: @@ -173,6 +195,7 @@ class AppHarness: port=port, ) ) + self.backend.shutdown = self._get_backend_shutdown_handler() self.backend_thread = threading.Thread(target=self.backend.run) self.backend_thread.start() @@ -296,6 +319,35 @@ class AppHarness: time.sleep(step) return False + @staticmethod + async def _poll_for_async( + target: Callable[[], Coroutine[None, None, T]], + timeout: TimeoutType = None, + step: TimeoutType = None, + ) -> T | bool: + """Generic polling logic for async functions. + + Args: + target: callable that returns truthy if polling condition is met. + timeout: max polling time + step: interval between checking target() + + Returns: + return value of target() if truthy within timeout + False if timeout elapses + """ + if timeout is None: + timeout = DEFAULT_TIMEOUT + if step is None: + step = POLL_INTERVAL + deadline = time.time() + timeout + while time.time() < deadline: + success = await target() + if success: + return success + await asyncio.sleep(step) + return False + def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket: """Poll backend server for listening sockets. @@ -351,39 +403,76 @@ class AppHarness: self._frontends.append(driver) return driver - async def emit_state_updates(self) -> list[Any]: - """Send any backend state deltas to the frontend. + async def get_state(self, token: str) -> State: + """Get the state associated with the given token. + + Args: + token: The state token to look up. Returns: - List of awaited response from each EventNamespace.emit() call. + The state instance associated with the given token Raises: RuntimeError: when the app hasn't started running """ - if self.app_instance is None or self.app_instance.sio is None: + if self.state_manager is None: + raise RuntimeError("state_manager is not set.") + try: + return await self.state_manager.get_state(token) + finally: + if isinstance(self.state_manager, StateManagerRedis): + await self.state_manager.redis.close() + + async def set_state(self, token: str, **kwargs) -> None: + """Set the state associated with the given token. + + Args: + token: The state token to set. + kwargs: Attributes to set on the state. + + Raises: + RuntimeError: when the app hasn't started running + """ + if self.state_manager is None: + raise RuntimeError("state_manager is not set.") + state = await self.get_state(token) + for key, value in kwargs.items(): + setattr(state, key, value) + try: + await self.state_manager.set_state(token, state) + finally: + if isinstance(self.state_manager, StateManagerRedis): + await self.state_manager.redis.close() + + @contextlib.asynccontextmanager + async def modify_state(self, token: str) -> AsyncIterator[State]: + """Modify the state associated with the given token and send update to frontend. + + Args: + token: The state token to modify + + Yields: + The state instance associated with the given token + + Raises: + RuntimeError: when the app hasn't started running + """ + if self.state_manager is None: + raise RuntimeError("state_manager is not set.") + if self.app_instance is None: raise RuntimeError("App is not running.") - event_ns: EventNamespace = cast( - EventNamespace, - self.app_instance.event_namespace, - ) - pending: list[Coroutine[Any, Any, Any]] = [] - for state in self.app_instance.state_manager.states.values(): - delta = state.get_delta() - if delta: - update = reflex.state.StateUpdate(delta=delta, events=[], final=True) - state._clean() - # Emit the event. - pending.append( - event_ns.emit( - str(reflex.constants.SocketEvent.EVENT), - update.json(), - to=state.get_sid(), - ), - ) - responses = [] - for request in pending: - responses.append(await request) - return responses + app_state_manager = self.app_instance.state_manager + if isinstance(self.state_manager, StateManagerRedis): + # Temporarily replace the app's state manager with our own, since + # the redis connection is on the backend_thread event loop + self.app_instance.state_manager = self.state_manager + try: + async with self.app_instance.modify_state(token) as state: + yield state + finally: + if isinstance(self.state_manager, StateManagerRedis): + self.app_instance.state_manager = app_state_manager + await self.state_manager.redis.close() def poll_for_content( self, @@ -457,6 +546,9 @@ class AppHarness: if self.app_instance is None: raise RuntimeError("App is not running.") state_manager = self.app_instance.state_manager + assert isinstance( + state_manager, StateManagerMemory + ), "Only works with memory state manager" if not self._poll_for( target=lambda: state_manager.states, timeout=timeout, @@ -534,7 +626,6 @@ class Subdir404TCPServer(socketserver.TCPServer): request: the requesting socket client_address: (host, port) referring to the client’s address. """ - print(client_address, type(client_address)) self.RequestHandlerClass( request, client_address, @@ -605,6 +696,7 @@ class AppHarnessProd(AppHarness): workers=reflex.utils.processes.get_num_workers(), ), ) + self.backend.shutdown = self._get_backend_shutdown_handler() self.backend_thread = threading.Thread(target=self.backend.run) self.backend_thread.start() diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index 542fe5a54..878f6fb16 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -5,3 +5,11 @@ class InvalidStylePropError(TypeError): """Custom Type Error when style props have invalid values.""" pass + + +class ImmutableStateError(AttributeError): + """Raised when a background task attempts to modify state outside of context.""" + + +class LockExpiredError(Exception): + """Raised when the state lock expires while an event is being processed.""" diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 664a4ed2f..540955673 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -21,7 +21,7 @@ import httpx import typer from alembic.util.exc import CommandError from packaging import version -from redis import Redis +from redis.asyncio import Redis from reflex import constants, model from reflex.compiler import templates @@ -124,9 +124,11 @@ def get_redis() -> Redis | None: The redis client. """ config = get_config() - if config.redis_url is None: + if not config.redis_url: return None - redis_url, redis_port = config.redis_url.split(":") + redis_url, has_port, redis_port = config.redis_url.partition(":") + if not has_port: + redis_port = 6379 console.info(f"Using redis at {config.redis_url}") return Redis(host=redis_url, port=int(redis_port), db=0) diff --git a/tests/conftest.py b/tests/conftest.py index 895954afa..d2dd301f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,9 @@ import contextlib import os import platform +import uuid from pathlib import Path -from typing import Dict, Generator, List, Set, Union +from typing import Dict, Generator import pytest @@ -11,6 +12,14 @@ import reflex as rx from reflex.app import App from reflex.event import EventSpec +from .states import ( + DictMutationTestState, + ListMutationTestState, + MutableTestState, + SubUploadState, + UploadState, +) + @pytest.fixture def app() -> App: @@ -39,60 +48,7 @@ def list_mutation_state(): Returns: A state with list mutation features. """ - - class TestState(rx.State): - """The test state.""" - - # plain list - plain_friends = ["Tommy"] - - def make_friend(self): - self.plain_friends.append("another-fd") - - def change_first_friend(self): - self.plain_friends[0] = "Jenny" - - def unfriend_all_friends(self): - self.plain_friends.clear() - - def unfriend_first_friend(self): - del self.plain_friends[0] - - def remove_last_friend(self): - self.plain_friends.pop() - - def make_friends_with_colleagues(self): - colleagues = ["Peter", "Jimmy"] - self.plain_friends.extend(colleagues) - - def remove_tommy(self): - self.plain_friends.remove("Tommy") - - # list in dict - friends_in_dict = {"Tommy": ["Jenny"]} - - def remove_jenny_from_tommy(self): - self.friends_in_dict["Tommy"].remove("Jenny") - - def add_jimmy_to_tommy_friends(self): - self.friends_in_dict["Tommy"].append("Jimmy") - - def tommy_has_no_fds(self): - self.friends_in_dict["Tommy"].clear() - - # nested list - friends_in_nested_list = [["Tommy"], ["Jenny"]] - - def remove_first_group(self): - self.friends_in_nested_list.pop(0) - - def remove_first_person_from_first_group(self): - self.friends_in_nested_list[0].pop(0) - - def add_jimmy_to_second_group(self): - self.friends_in_nested_list[1].append("Jimmy") - - return TestState() + return ListMutationTestState() @pytest.fixture @@ -102,85 +58,7 @@ def dict_mutation_state(): Returns: A state with dict mutation features. """ - - class TestState(rx.State): - """The test state.""" - - # plain dict - details = {"name": "Tommy"} - - def add_age(self): - self.details.update({"age": 20}) # type: ignore - - def change_name(self): - self.details["name"] = "Jenny" - - def remove_last_detail(self): - self.details.popitem() - - def clear_details(self): - self.details.clear() - - def remove_name(self): - del self.details["name"] - - def pop_out_age(self): - self.details.pop("age") - - # dict in list - address = [{"home": "home address"}, {"work": "work address"}] - - def remove_home_address(self): - self.address[0].pop("home") - - def add_street_to_home_address(self): - self.address[0]["street"] = "street address" - - # nested dict - friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}} - - def change_friend_name(self): - self.friend_in_nested_dict["friend"]["name"] = "Tommy" - - def remove_friend(self): - self.friend_in_nested_dict.pop("friend") - - def add_friend_age(self): - self.friend_in_nested_dict["friend"]["age"] = 30 - - return TestState() - - -class UploadState(rx.State): - """The base state for uploading a file.""" - - async def handle_upload1(self, files: List[rx.UploadFile]): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - pass - - -class BaseState(rx.State): - """The test base state.""" - - pass - - -class SubUploadState(BaseState): - """The test substate.""" - - img: str - - async def handle_upload(self, files: List[rx.UploadFile]): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - pass + return DictMutationTestState() @pytest.fixture @@ -203,187 +81,6 @@ def upload_event_spec(): return EventSpec(handler=UploadState.handle_upload1, upload=True) # type: ignore -@pytest.fixture -def upload_state(tmp_path): - """Create upload state. - - Args: - tmp_path: pytest tmp_path - - Returns: - The state - - """ - - class FileUploadState(rx.State): - """The base state for uploading a file.""" - - img_list: List[str] - - async def handle_upload2(self, files): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - outfile = f"{tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - self.img_list.append(file.filename) - - async def multi_handle_upload(self, files: List[rx.UploadFile]): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - outfile = f"{tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - assert file.filename is not None - self.img_list.append(file.filename) - - return FileUploadState - - -@pytest.fixture -def upload_sub_state(tmp_path): - """Create upload substate. - - Args: - tmp_path: pytest tmp_path - - Returns: - The state - - """ - - class FileState(rx.State): - """The base state.""" - - pass - - class FileUploadState(FileState): - """The substate for uploading a file.""" - - img_list: List[str] - - async def handle_upload2(self, files): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - outfile = f"{tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - self.img_list.append(file.filename) - - async def multi_handle_upload(self, files: List[rx.UploadFile]): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - outfile = f"{tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - assert file.filename is not None - self.img_list.append(file.filename) - - return FileUploadState - - -@pytest.fixture -def upload_grand_sub_state(tmp_path): - """Create upload grand-state. - - Args: - tmp_path: pytest tmp_path - - Returns: - The state - - """ - - class BaseFileState(rx.State): - """The base state.""" - - pass - - class FileSubState(BaseFileState): - """The substate.""" - - pass - - class FileUploadState(FileSubState): - """The grand-substate for uploading a file.""" - - img_list: List[str] - - async def handle_upload2(self, files): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - outfile = f"{tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - assert file.filename is not None - self.img_list.append(file.filename) - - async def multi_handle_upload(self, files: List[rx.UploadFile]): - """Handle the upload of a file. - - Args: - files: The uploaded files. - """ - for file in files: - upload_data = await file.read() - outfile = f"{tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - assert file.filename is not None - self.img_list.append(file.filename) - - return FileUploadState - - @pytest.fixture def base_config_values() -> Dict: """Get base config values. @@ -418,35 +115,6 @@ def sqlite_db_config_values(base_db_config_values) -> Dict: return base_db_config_values -class GenState(rx.State): - """A state with event handlers that generate multiple updates.""" - - value: int - - def go(self, c: int): - """Increment the value c times and update each time. - - Args: - c: The number of times to increment. - - Yields: - After each increment. - """ - for _ in range(c): - self.value += 1 - yield - - -@pytest.fixture -def gen_state() -> GenState: - """A state. - - Returns: - A test state. - """ - return GenState # type: ignore - - @pytest.fixture def router_data_headers() -> Dict[str, str]: """Router data headers. @@ -546,46 +214,19 @@ def mutable_state(): Returns: A state object. """ - - class OtherBase(rx.Base): - bar: str = "" - - class CustomVar(rx.Base): - foo: str = "" - array: List[str] = [] - hashmap: Dict[str, str] = {} - test_set: Set[str] = set() - custom: OtherBase = OtherBase() - - class MutableTestState(rx.State): - """A test state.""" - - array: List[Union[str, List, Dict[str, str]]] = [ - "value", - [1, 2, 3], - {"key": "value"}, - ] - hashmap: Dict[str, Union[List, str, Dict[str, str]]] = { - "key": ["list", "of", "values"], - "another_key": "another_value", - "third_key": {"key": "value"}, - } - test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"} - custom: CustomVar = CustomVar() - _be_custom: CustomVar = CustomVar() - - def reassign_mutables(self): - self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}] - self.hashmap = { - "mod_key": ["list", "of", "values"], - "mod_another_key": "another_value", - "mod_third_key": {"key": "value"}, - } - self.test_set = {1, 2, 3, 4, "five"} - return MutableTestState() +@pytest.fixture(scope="function") +def token() -> str: + """Create a token. + + Returns: + A fresh/unique token string. + """ + return str(uuid.uuid4()) + + @pytest.fixture def duplicate_substate(): """Create a Test state that has duplicate child substates. diff --git a/tests/states/__init__.py b/tests/states/__init__.py new file mode 100644 index 000000000..bcbb337af --- /dev/null +++ b/tests/states/__init__.py @@ -0,0 +1,30 @@ +"""Common rx.State subclasses for use in tests.""" +import reflex as rx + +from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState +from .upload import ( + ChildFileUploadState, + FileUploadState, + GrandChildFileUploadState, + SubUploadState, + UploadState, +) + + +class GenState(rx.State): + """A state with event handlers that generate multiple updates.""" + + value: int + + def go(self, c: int): + """Increment the value c times and update each time. + + Args: + c: The number of times to increment. + + Yields: + After each increment. + """ + for _ in range(c): + self.value += 1 + yield diff --git a/tests/states/mutation.py b/tests/states/mutation.py new file mode 100644 index 000000000..b3d98301f --- /dev/null +++ b/tests/states/mutation.py @@ -0,0 +1,172 @@ +"""Test states for mutable vars.""" + +from typing import Dict, List, Set, Union + +import reflex as rx + + +class DictMutationTestState(rx.State): + """A state for testing ReflexDict mutation.""" + + # plain dict + details = {"name": "Tommy"} + + def add_age(self): + """Add an age to the dict.""" + self.details.update({"age": 20}) # type: ignore + + def change_name(self): + """Change the name in the dict.""" + self.details["name"] = "Jenny" + + def remove_last_detail(self): + """Remove the last item in the dict.""" + self.details.popitem() + + def clear_details(self): + """Clear the dict.""" + self.details.clear() + + def remove_name(self): + """Remove the name from the dict.""" + del self.details["name"] + + def pop_out_age(self): + """Pop out the age from the dict.""" + self.details.pop("age") + + # dict in list + address = [{"home": "home address"}, {"work": "work address"}] + + def remove_home_address(self): + """Remove the home address from dict in the list.""" + self.address[0].pop("home") + + def add_street_to_home_address(self): + """Set street key in the dict in the list.""" + self.address[0]["street"] = "street address" + + # nested dict + friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}} + + def change_friend_name(self): + """Change the friend's name in the nested dict.""" + self.friend_in_nested_dict["friend"]["name"] = "Tommy" + + def remove_friend(self): + """Remove the friend from the nested dict.""" + self.friend_in_nested_dict.pop("friend") + + def add_friend_age(self): + """Add an age to the friend in the nested dict.""" + self.friend_in_nested_dict["friend"]["age"] = 30 + + +class ListMutationTestState(rx.State): + """A state for testing ReflexList mutation.""" + + # plain list + plain_friends = ["Tommy"] + + def make_friend(self): + """Add a friend to the list.""" + self.plain_friends.append("another-fd") + + def change_first_friend(self): + """Change the first friend in the list.""" + self.plain_friends[0] = "Jenny" + + def unfriend_all_friends(self): + """Unfriend all friends in the list.""" + self.plain_friends.clear() + + def unfriend_first_friend(self): + """Unfriend the first friend in the list.""" + del self.plain_friends[0] + + def remove_last_friend(self): + """Remove the last friend in the list.""" + self.plain_friends.pop() + + def make_friends_with_colleagues(self): + """Add list of friends to the list.""" + colleagues = ["Peter", "Jimmy"] + self.plain_friends.extend(colleagues) + + def remove_tommy(self): + """Remove Tommy from the list.""" + self.plain_friends.remove("Tommy") + + # list in dict + friends_in_dict = {"Tommy": ["Jenny"]} + + def remove_jenny_from_tommy(self): + """Remove Jenny from Tommy's friends list.""" + self.friends_in_dict["Tommy"].remove("Jenny") + + def add_jimmy_to_tommy_friends(self): + """Add Jimmy to Tommy's friends list.""" + self.friends_in_dict["Tommy"].append("Jimmy") + + def tommy_has_no_fds(self): + """Clear Tommy's friends list.""" + self.friends_in_dict["Tommy"].clear() + + # nested list + friends_in_nested_list = [["Tommy"], ["Jenny"]] + + def remove_first_group(self): + """Remove the first group of friends from the nested list.""" + self.friends_in_nested_list.pop(0) + + def remove_first_person_from_first_group(self): + """Remove the first person from the first group of friends in the nested list.""" + self.friends_in_nested_list[0].pop(0) + + def add_jimmy_to_second_group(self): + """Add Jimmy to the second group of friends in the nested list.""" + self.friends_in_nested_list[1].append("Jimmy") + + +class OtherBase(rx.Base): + """A Base model with a str field.""" + + bar: str = "" + + +class CustomVar(rx.Base): + """A Base model with multiple fields.""" + + foo: str = "" + array: List[str] = [] + hashmap: Dict[str, str] = {} + test_set: Set[str] = set() + custom: OtherBase = OtherBase() + + +class MutableTestState(rx.State): + """A test state.""" + + array: List[Union[str, List, Dict[str, str]]] = [ + "value", + [1, 2, 3], + {"key": "value"}, + ] + hashmap: Dict[str, Union[List, str, Dict[str, str]]] = { + "key": ["list", "of", "values"], + "another_key": "another_value", + "third_key": {"key": "value"}, + } + test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"} + custom: CustomVar = CustomVar() + _be_custom: CustomVar = CustomVar() + + def reassign_mutables(self): + """Assign mutable fields to different values.""" + self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}] + self.hashmap = { + "mod_key": ["list", "of", "values"], + "mod_another_key": "another_value", + "mod_third_key": {"key": "value"}, + } + self.test_set = {1, 2, 3, 4, "five"} diff --git a/tests/states/upload.py b/tests/states/upload.py new file mode 100644 index 000000000..893947930 --- /dev/null +++ b/tests/states/upload.py @@ -0,0 +1,175 @@ +"""Test states for upload-related tests.""" +from pathlib import Path +from typing import ClassVar, List + +import reflex as rx + + +class UploadState(rx.State): + """The base state for uploading a file.""" + + async def handle_upload1(self, files: List[rx.UploadFile]): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + pass + + +class BaseState(rx.State): + """The test base state.""" + + pass + + +class SubUploadState(BaseState): + """The test substate.""" + + img: str + + async def handle_upload(self, files: List[rx.UploadFile]): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + pass + + +class FileUploadState(rx.State): + """The base state for uploading a file.""" + + img_list: List[str] + _tmp_path: ClassVar[Path] + + async def handle_upload2(self, files): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + for file in files: + upload_data = await file.read() + outfile = f"{self._tmp_path}/{file.filename}" + + # Save the file. + with open(outfile, "wb") as file_object: + file_object.write(upload_data) + + # Update the img var. + self.img_list.append(file.filename) + + async def multi_handle_upload(self, files: List[rx.UploadFile]): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + for file in files: + upload_data = await file.read() + outfile = f"{self._tmp_path}/{file.filename}" + + # Save the file. + with open(outfile, "wb") as file_object: + file_object.write(upload_data) + + # Update the img var. + assert file.filename is not None + self.img_list.append(file.filename) + + +class FileStateBase1(rx.State): + """The base state for a child FileUploadState.""" + + pass + + +class ChildFileUploadState(FileStateBase1): + """The child state for uploading a file.""" + + img_list: List[str] + _tmp_path: ClassVar[Path] + + async def handle_upload2(self, files): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + for file in files: + upload_data = await file.read() + outfile = f"{self._tmp_path}/{file.filename}" + + # Save the file. + with open(outfile, "wb") as file_object: + file_object.write(upload_data) + + # Update the img var. + self.img_list.append(file.filename) + + async def multi_handle_upload(self, files: List[rx.UploadFile]): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + for file in files: + upload_data = await file.read() + outfile = f"{self._tmp_path}/{file.filename}" + + # Save the file. + with open(outfile, "wb") as file_object: + file_object.write(upload_data) + + # Update the img var. + assert file.filename is not None + self.img_list.append(file.filename) + + +class FileStateBase2(FileStateBase1): + """The parent state for a grandchild FileUploadState.""" + + pass + + +class GrandChildFileUploadState(FileStateBase2): + """The child state for uploading a file.""" + + img_list: List[str] + _tmp_path: ClassVar[Path] + + async def handle_upload2(self, files): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + for file in files: + upload_data = await file.read() + outfile = f"{self._tmp_path}/{file.filename}" + + # Save the file. + with open(outfile, "wb") as file_object: + file_object.write(upload_data) + + # Update the img var. + self.img_list.append(file.filename) + + async def multi_handle_upload(self, files: List[rx.UploadFile]): + """Handle the upload of a file. + + Args: + files: The uploaded files. + """ + for file in files: + upload_data = await file.read() + outfile = f"{self._tmp_path}/{file.filename}" + + # Save the file. + with open(outfile, "wb") as file_object: + file_object.write(upload_data) + + # Update the img var. + assert file.filename is not None + self.img_list.append(file.filename) diff --git a/tests/test_app.py b/tests/test_app.py index 51f45621e..0044c2c32 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -3,6 +3,7 @@ from __future__ import annotations import io import os.path import sys +import uuid from typing import List, Tuple, Type if sys.version_info.major >= 3 and sys.version_info.minor > 7: @@ -30,11 +31,18 @@ from reflex.components import Box, Component, Cond, Fragment, Text from reflex.event import Event, get_hydrate_event from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import State, StateUpdate +from reflex.state import State, StateManagerRedis, StateUpdate from reflex.style import Style from reflex.utils import format from reflex.vars import ComputedVar +from .states import ( + ChildFileUploadState, + FileUploadState, + GenState, + GrandChildFileUploadState, +) + @pytest.fixture def index_page(): @@ -64,6 +72,12 @@ def about_page(): return about +class ATestState(State): + """A simple state for testing.""" + + var: int + + @pytest.fixture() def test_state() -> Type[State]: """A default state. @@ -71,11 +85,7 @@ def test_state() -> Type[State]: Returns: A default state. """ - - class TestState(State): - var: int - - return TestState + return ATestState @pytest.fixture() @@ -313,23 +323,28 @@ def test_initialize_admin_dashboard_with_view_overrides(test_model): assert app.admin_dash.view_overrides[test_model] == TestModelView -def test_initialize_with_state(test_state): +@pytest.mark.asyncio +async def test_initialize_with_state(test_state: Type[ATestState], token: str): """Test setting the state of an app. Args: test_state: The default state. + token: a Token. """ app = App(state=test_state) assert app.state == test_state # Get a state for a given token. - token = "token" - state = app.state_manager.get_state(token) + state = await app.state_manager.get_state(token) assert isinstance(state, test_state) assert state.var == 0 # type: ignore + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.redis.close() -def test_set_and_get_state(test_state): + +@pytest.mark.asyncio +async def test_set_and_get_state(test_state): """Test setting and getting the state of an app with different tokens. Args: @@ -338,47 +353,51 @@ def test_set_and_get_state(test_state): app = App(state=test_state) # Create two tokens. - token1 = "token1" - token2 = "token2" + token1 = str(uuid.uuid4()) + token2 = str(uuid.uuid4()) # Get the default state for each token. - state1 = app.state_manager.get_state(token1) - state2 = app.state_manager.get_state(token2) + state1 = await app.state_manager.get_state(token1) + state2 = await app.state_manager.get_state(token2) assert state1.var == 0 # type: ignore assert state2.var == 0 # type: ignore # Set the vars to different values. state1.var = 1 state2.var = 2 - app.state_manager.set_state(token1, state1) - app.state_manager.set_state(token2, state2) + await app.state_manager.set_state(token1, state1) + await app.state_manager.set_state(token2, state2) # Get the states again and check the values. - state1 = app.state_manager.get_state(token1) - state2 = app.state_manager.get_state(token2) + state1 = await app.state_manager.get_state(token1) + state2 = await app.state_manager.get_state(token2) assert state1.var == 1 # type: ignore assert state2.var == 2 # type: ignore + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.redis.close() + @pytest.mark.asyncio -async def test_dynamic_var_event(test_state): +async def test_dynamic_var_event(test_state: Type[ATestState], token: str): """Test that the default handler of a dynamic generated var works as expected. Args: test_state: State Fixture. + token: a Token. """ - test_state = test_state() - test_state.add_var("int_val", int, 0) - result = await test_state._process( + state = test_state() # type: ignore + state.add_var("int_val", int, 0) + result = await state._process( Event( - token="fake-token", - name="test_state.set_int_val", + token=token, + name=f"{test_state.get_name()}.set_int_val", router_data={"pathname": "/", "query": {}}, payload={"value": 50}, ) ).__anext__() - assert result.delta == {"test_state": {"int_val": 50}} + assert result.delta == {test_state.get_name(): {"int_val": 50}} @pytest.mark.asyncio @@ -388,12 +407,20 @@ async def test_dynamic_var_event(test_state): pytest.param( [ ( - "test_state.make_friend", - {"test_state": {"plain_friends": ["Tommy", "another-fd"]}}, + "list_mutation_test_state.make_friend", + { + "list_mutation_test_state": { + "plain_friends": ["Tommy", "another-fd"] + } + }, ), ( - "test_state.change_first_friend", - {"test_state": {"plain_friends": ["Jenny", "another-fd"]}}, + "list_mutation_test_state.change_first_friend", + { + "list_mutation_test_state": { + "plain_friends": ["Jenny", "another-fd"] + } + }, ), ], id="append then __setitem__", @@ -401,12 +428,12 @@ async def test_dynamic_var_event(test_state): pytest.param( [ ( - "test_state.unfriend_first_friend", - {"test_state": {"plain_friends": []}}, + "list_mutation_test_state.unfriend_first_friend", + {"list_mutation_test_state": {"plain_friends": []}}, ), ( - "test_state.make_friend", - {"test_state": {"plain_friends": ["another-fd"]}}, + "list_mutation_test_state.make_friend", + {"list_mutation_test_state": {"plain_friends": ["another-fd"]}}, ), ], id="delitem then append", @@ -414,20 +441,24 @@ async def test_dynamic_var_event(test_state): pytest.param( [ ( - "test_state.make_friends_with_colleagues", - {"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}}, + "list_mutation_test_state.make_friends_with_colleagues", + { + "list_mutation_test_state": { + "plain_friends": ["Tommy", "Peter", "Jimmy"] + } + }, ), ( - "test_state.remove_tommy", - {"test_state": {"plain_friends": ["Peter", "Jimmy"]}}, + "list_mutation_test_state.remove_tommy", + {"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}}, ), ( - "test_state.remove_last_friend", - {"test_state": {"plain_friends": ["Peter"]}}, + "list_mutation_test_state.remove_last_friend", + {"list_mutation_test_state": {"plain_friends": ["Peter"]}}, ), ( - "test_state.unfriend_all_friends", - {"test_state": {"plain_friends": []}}, + "list_mutation_test_state.unfriend_all_friends", + {"list_mutation_test_state": {"plain_friends": []}}, ), ], id="extend, remove, pop, clear", @@ -435,24 +466,28 @@ async def test_dynamic_var_event(test_state): pytest.param( [ ( - "test_state.add_jimmy_to_second_group", + "list_mutation_test_state.add_jimmy_to_second_group", { - "test_state": { + "list_mutation_test_state": { "friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]] } }, ), ( - "test_state.remove_first_person_from_first_group", + "list_mutation_test_state.remove_first_person_from_first_group", { - "test_state": { + "list_mutation_test_state": { "friends_in_nested_list": [[], ["Jenny", "Jimmy"]] } }, ), ( - "test_state.remove_first_group", - {"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}}, + "list_mutation_test_state.remove_first_group", + { + "list_mutation_test_state": { + "friends_in_nested_list": [["Jenny", "Jimmy"]] + } + }, ), ], id="nested list", @@ -460,16 +495,24 @@ async def test_dynamic_var_event(test_state): pytest.param( [ ( - "test_state.add_jimmy_to_tommy_friends", - {"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}}, + "list_mutation_test_state.add_jimmy_to_tommy_friends", + { + "list_mutation_test_state": { + "friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]} + } + }, ), ( - "test_state.remove_jenny_from_tommy", - {"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}}, + "list_mutation_test_state.remove_jenny_from_tommy", + { + "list_mutation_test_state": { + "friends_in_dict": {"Tommy": ["Jimmy"]} + } + }, ), ( - "test_state.tommy_has_no_fds", - {"test_state": {"friends_in_dict": {"Tommy": []}}}, + "list_mutation_test_state.tommy_has_no_fds", + {"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}}, ), ], id="list in dict", @@ -477,7 +520,9 @@ async def test_dynamic_var_event(test_state): ], ) async def test_list_mutation_detection__plain_list( - event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State + event_tuples: List[Tuple[str, List[str]]], + list_mutation_state: State, + token: str, ): """Test list mutation detection when reassignment is not explicitly included in the logic. @@ -485,11 +530,12 @@ async def test_list_mutation_detection__plain_list( Args: event_tuples: From parametrization. list_mutation_state: A state with list mutation features. + token: a Token. """ for event_name, expected_delta in event_tuples: result = await list_mutation_state._process( Event( - token="fake-token", + token=token, name=event_name, router_data={"pathname": "/", "query": {}}, payload={}, @@ -506,16 +552,24 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "test_state.add_age", - {"test_state": {"details": {"name": "Tommy", "age": 20}}}, + "dict_mutation_test_state.add_age", + { + "dict_mutation_test_state": { + "details": {"name": "Tommy", "age": 20} + } + }, ), ( - "test_state.change_name", - {"test_state": {"details": {"name": "Jenny", "age": 20}}}, + "dict_mutation_test_state.change_name", + { + "dict_mutation_test_state": { + "details": {"name": "Jenny", "age": 20} + } + }, ), ( - "test_state.remove_last_detail", - {"test_state": {"details": {"name": "Jenny"}}}, + "dict_mutation_test_state.remove_last_detail", + {"dict_mutation_test_state": {"details": {"name": "Jenny"}}}, ), ], id="update then __setitem__", @@ -523,12 +577,12 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "test_state.clear_details", - {"test_state": {"details": {}}}, + "dict_mutation_test_state.clear_details", + {"dict_mutation_test_state": {"details": {}}}, ), ( - "test_state.add_age", - {"test_state": {"details": {"age": 20}}}, + "dict_mutation_test_state.add_age", + {"dict_mutation_test_state": {"details": {"age": 20}}}, ), ], id="delitem then update", @@ -536,16 +590,20 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "test_state.add_age", - {"test_state": {"details": {"name": "Tommy", "age": 20}}}, + "dict_mutation_test_state.add_age", + { + "dict_mutation_test_state": { + "details": {"name": "Tommy", "age": 20} + } + }, ), ( - "test_state.remove_name", - {"test_state": {"details": {"age": 20}}}, + "dict_mutation_test_state.remove_name", + {"dict_mutation_test_state": {"details": {"age": 20}}}, ), ( - "test_state.pop_out_age", - {"test_state": {"details": {}}}, + "dict_mutation_test_state.pop_out_age", + {"dict_mutation_test_state": {"details": {}}}, ), ], id="add, remove, pop", @@ -553,13 +611,17 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "test_state.remove_home_address", - {"test_state": {"address": [{}, {"work": "work address"}]}}, + "dict_mutation_test_state.remove_home_address", + { + "dict_mutation_test_state": { + "address": [{}, {"work": "work address"}] + } + }, ), ( - "test_state.add_street_to_home_address", + "dict_mutation_test_state.add_street_to_home_address", { - "test_state": { + "dict_mutation_test_state": { "address": [ {"street": "street address"}, {"work": "work address"}, @@ -573,9 +635,9 @@ async def test_list_mutation_detection__plain_list( pytest.param( [ ( - "test_state.change_friend_name", + "dict_mutation_test_state.change_friend_name", { - "test_state": { + "dict_mutation_test_state": { "friend_in_nested_dict": { "name": "Nikhil", "friend": {"name": "Tommy"}, @@ -584,9 +646,9 @@ async def test_list_mutation_detection__plain_list( }, ), ( - "test_state.add_friend_age", + "dict_mutation_test_state.add_friend_age", { - "test_state": { + "dict_mutation_test_state": { "friend_in_nested_dict": { "name": "Nikhil", "friend": {"name": "Tommy", "age": 30}, @@ -595,8 +657,12 @@ async def test_list_mutation_detection__plain_list( }, ), ( - "test_state.remove_friend", - {"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}}, + "dict_mutation_test_state.remove_friend", + { + "dict_mutation_test_state": { + "friend_in_nested_dict": {"name": "Nikhil"} + } + }, ), ], id="nested dict", @@ -604,7 +670,9 @@ async def test_list_mutation_detection__plain_list( ], ) async def test_dict_mutation_detection__plain_list( - event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State + event_tuples: List[Tuple[str, List[str]]], + dict_mutation_state: State, + token: str, ): """Test dict mutation detection when reassignment is not explicitly included in the logic. @@ -612,11 +680,12 @@ async def test_dict_mutation_detection__plain_list( Args: event_tuples: From parametrization. dict_mutation_state: A state with dict mutation features. + token: a Token. """ for event_name, expected_delta in event_tuples: result = await dict_mutation_state._process( Event( - token="fake-token", + token=token, name=event_name, router_data={"pathname": "/", "query": {}}, payload={}, @@ -628,41 +697,43 @@ async def test_dict_mutation_detection__plain_list( @pytest.mark.asyncio @pytest.mark.parametrize( - "fixture, delta", + ("state", "delta"), [ ( - "upload_state", + FileUploadState, {"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}}, ), ( - "upload_sub_state", + ChildFileUploadState, { - "file_state.file_upload_state": { + "file_state_base1.child_file_upload_state": { "img_list": ["image1.jpg", "image2.jpg"] } }, ), ( - "upload_grand_sub_state", + GrandChildFileUploadState, { - "base_file_state.file_sub_state.file_upload_state": { + "file_state_base1.file_state_base2.grand_child_file_upload_state": { "img_list": ["image1.jpg", "image2.jpg"] } }, ), ], ) -async def test_upload_file(fixture, request, delta): +async def test_upload_file(tmp_path, state, delta, token: str): """Test that file upload works correctly. Args: - fixture: The state. - request: Fixture request. + tmp_path: Temporary path. + state: The state class. delta: Expected delta + token: a Token. """ - app = App(state=request.getfixturevalue(fixture)) + state._tmp_path = tmp_path + app = App(state=state) app.event_namespace.emit = AsyncMock() # type: ignore - current_state = app.state_manager.get_state("token") + current_state = await app.state_manager.get_state(token) data = b"This is binary data" # Create a binary IO object and write data to it @@ -670,11 +741,11 @@ async def test_upload_file(fixture, request, delta): bio.write(data) file1 = UploadFile( - filename="token:file_upload_state.multi_handle_upload:True:image1.jpg", + filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image1.jpg", file=bio, ) file2 = UploadFile( - filename="token:file_upload_state.multi_handle_upload:True:image2.jpg", + filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image2.jpg", file=bio, ) upload_fn = upload(app) @@ -684,22 +755,27 @@ async def test_upload_file(fixture, request, delta): app.event_namespace.emit.assert_called_with( # type: ignore "event", state_update.json(), to=current_state.get_sid() ) - assert app.state_manager.get_state("token").dict()["img_list"] == [ + assert (await app.state_manager.get_state(token)).dict()["img_list"] == [ "image1.jpg", "image2.jpg", ] + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.redis.close() + @pytest.mark.asyncio @pytest.mark.parametrize( - "fixture", ["upload_state", "upload_sub_state", "upload_grand_sub_state"] + "state", + [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) -async def test_upload_file_without_annotation(fixture, request): +async def test_upload_file_without_annotation(state, tmp_path, token): """Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile]. Args: - fixture: The state. - request: Fixture request. + state: The state class. + tmp_path: Temporary path. + token: a Token. """ data = b"This is binary data" @@ -707,14 +783,15 @@ async def test_upload_file_without_annotation(fixture, request): bio = io.BytesIO() bio.write(data) - app = App(state=request.getfixturevalue(fixture)) + state._tmp_path = tmp_path + app = App(state=state) file1 = UploadFile( - filename="token:file_upload_state.handle_upload2:True:image1.jpg", + filename=f"{token}:{state.get_name()}.handle_upload2:True:image1.jpg", file=bio, ) file2 = UploadFile( - filename="token:file_upload_state.handle_upload2:True:image2.jpg", + filename=f"{token}:{state.get_name()}.handle_upload2:True:image2.jpg", file=bio, ) fn = upload(app) @@ -722,9 +799,12 @@ async def test_upload_file_without_annotation(fixture, request): await fn([file1, file2]) assert ( err.value.args[0] - == "`file_upload_state.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" + == f"`{state.get_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" ) + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.redis.close() + class DynamicState(State): """State class for testing dynamic route var. @@ -768,6 +848,7 @@ class DynamicState(State): async def test_dynamic_route_var_route_change_completed_on_load( index_page, windows_platform: bool, + token: str, ): """Create app with dynamic route var, and simulate navigation. @@ -777,6 +858,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( Args: index_page: The index page. windows_platform: Whether the system is windows. + token: a Token. """ arg_name = "dynamic" route = f"/test/[{arg_name}]" @@ -792,10 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load( } assert constants.ROUTER_DATA in app.state().computed_var_dependencies - token = "mock_token" sid = "mock_sid" client_ip = "127.0.0.1" - state = app.state_manager.get_state(token) + state = await app.state_manager.get_state(token) assert state.dynamic == "" exp_vals = ["foo", "foobar", "baz"] @@ -817,6 +898,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( **kwargs, ) + prev_exp_val = "" for exp_index, exp_val in enumerate(exp_vals): hydrate_event = _event(name=get_hydrate_event(state), val=exp_val) exp_router_data = { @@ -826,13 +908,14 @@ async def test_dynamic_route_var_route_change_completed_on_load( "token": token, **hydrate_event.router_data, } - update = await process( + process_coro = process( app, event=hydrate_event, sid=sid, headers={}, client_ip=client_ip, - ).__anext__() # type: ignore + ) + update = await process_coro.__anext__() # type: ignore # route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)] assert update == StateUpdate( @@ -860,14 +943,27 @@ async def test_dynamic_route_var_route_change_completed_on_load( ), ], ) + if isinstance(app.state_manager, StateManagerRedis): + # When redis is used, the state is not updated until the processing is complete + state = await app.state_manager.get_state(token) + assert state.dynamic == prev_exp_val + + # complete the processing + with pytest.raises(StopAsyncIteration): + await process_coro.__anext__() # type: ignore + + # check that router data was written to the state_manager store + state = await app.state_manager.get_state(token) assert state.dynamic == exp_val - on_load_update = await process( + + process_coro = process( app, event=_dynamic_state_event(name="on_load", val=exp_val), sid=sid, headers={}, client_ip=client_ip, - ).__anext__() # type: ignore + ) + on_load_update = await process_coro.__anext__() # type: ignore assert on_load_update == StateUpdate( delta={ state.get_name(): { @@ -879,7 +975,10 @@ async def test_dynamic_route_var_route_change_completed_on_load( }, events=[], ) - on_set_is_hydrated_update = await process( + # complete the processing + with pytest.raises(StopAsyncIteration): + await process_coro.__anext__() # type: ignore + process_coro = process( app, event=_dynamic_state_event( name="set_is_hydrated", payload={"value": True}, val=exp_val @@ -887,7 +986,8 @@ async def test_dynamic_route_var_route_change_completed_on_load( sid=sid, headers={}, client_ip=client_ip, - ).__anext__() # type: ignore + ) + on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore assert on_set_is_hydrated_update == StateUpdate( delta={ state.get_name(): { @@ -899,15 +999,19 @@ async def test_dynamic_route_var_route_change_completed_on_load( }, events=[], ) + # complete the processing + with pytest.raises(StopAsyncIteration): + await process_coro.__anext__() # type: ignore # a simple state update event should NOT trigger on_load or route var side effects - update = await process( + process_coro = process( app, event=_dynamic_state_event(name="on_counter", val=exp_val), sid=sid, headers={}, client_ip=client_ip, - ).__anext__() # type: ignore + ) + update = await process_coro.__anext__() # type: ignore assert update == StateUpdate( delta={ state.get_name(): { @@ -919,42 +1023,54 @@ async def test_dynamic_route_var_route_change_completed_on_load( }, events=[], ) + # complete the processing + with pytest.raises(StopAsyncIteration): + await process_coro.__anext__() # type: ignore + + prev_exp_val = exp_val + state = await app.state_manager.get_state(token) assert state.loaded == len(exp_vals) assert state.counter == len(exp_vals) # print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}") # assert state.side_effect_counter == len(exp_vals) + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.redis.close() + @pytest.mark.asyncio -async def test_process_events(gen_state, mocker): +async def test_process_events(mocker, token: str): """Test that an event is processed properly and that it is postprocessed n+1 times. Also check that the processing flag of the last stateupdate is set to False. Args: - gen_state: The state. mocker: mocker object. + token: a Token. """ router_data = { "pathname": "/", "query": {}, - "token": "mock_token", + "token": token, "sid": "mock_sid", "headers": {}, "ip": "127.0.0.1", } - app = App(state=gen_state) + app = App(state=GenState) mocker.patch.object(app, "postprocess", AsyncMock()) event = Event( - token="token", name="gen_state.go", payload={"c": 5}, router_data=router_data + token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data ) async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore pass - assert app.state_manager.get_state("token").value == 5 + assert (await app.state_manager.get_state(token)).value == 5 assert app.postprocess.call_count == 6 + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.redis.close() + @pytest.mark.parametrize( ("state", "overlay_component", "exp_page_child"), diff --git a/tests/test_state.py b/tests/test_state.py index cf59e5eb9..e24985bea 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,22 +1,42 @@ from __future__ import annotations +import asyncio import copy import datetime import functools +import json +import os import sys -from typing import Dict, List +from typing import Dict, Generator, List +from unittest.mock import AsyncMock, Mock import pytest from plotly.graph_objects import Figure import reflex as rx from reflex.base import Base -from reflex.constants import IS_HYDRATED, RouteVar +from reflex.constants import APP_VAR, IS_HYDRATED, RouteVar, SocketEvent from reflex.event import Event, EventHandler -from reflex.state import MutableProxy, State -from reflex.utils import format +from reflex.state import ( + ImmutableStateError, + LockExpiredError, + MutableProxy, + State, + StateManager, + StateManagerMemory, + StateManagerRedis, + StateProxy, + StateUpdate, +) +from reflex.utils import format, prerequisites from reflex.vars import BaseVar, ComputedVar +from .states import GenState + +CI = bool(os.environ.get("CI", False)) +LOCK_EXPIRATION = 2000 if CI else 100 +LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2 + class Object(Base): """A test object fixture.""" @@ -704,13 +724,9 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) @pytest.mark.asyncio -async def test_process_event_generator(gen_state): - """Test event handlers that generate multiple updates. - - Args: - gen_state: A state. - """ - gen_state = gen_state() +async def test_process_event_generator(): + """Test event handlers that generate multiple updates.""" + gen_state = GenState() # type: ignore event = Event( token="t", name="go", @@ -1402,6 +1418,396 @@ def test_state_with_invalid_yield(): ) +@pytest.fixture(scope="function", params=["in_process", "redis"]) +def state_manager(request) -> Generator[StateManager, None, None]: + """Instance of state manager parametrized for redis and in-process. + + Args: + request: pytest request object. + + Yields: + A state manager instance + """ + state_manager = StateManager.create(state=TestState) + if request.param == "redis": + if not isinstance(state_manager, StateManagerRedis): + pytest.skip("Test requires redis") + else: + # explicitly NOT using redis + state_manager = StateManagerMemory(state=TestState) + assert not state_manager._states_locks + + yield state_manager + + if isinstance(state_manager, StateManagerRedis): + asyncio.get_event_loop().run_until_complete(state_manager.redis.close()) + + +@pytest.mark.asyncio +async def test_state_manager_modify_state(state_manager: StateManager, token: str): + """Test that the state manager can modify a state exclusively. + + Args: + state_manager: A state manager instance. + token: A token. + """ + async with state_manager.modify_state(token): + if isinstance(state_manager, StateManagerRedis): + assert await state_manager.redis.get(f"{token}_lock") + elif isinstance(state_manager, StateManagerMemory): + assert token in state_manager._states_locks + assert state_manager._states_locks[token].locked() + # lock should be dropped after exiting the context + if isinstance(state_manager, StateManagerRedis): + assert (await state_manager.redis.get(f"{token}_lock")) is None + elif isinstance(state_manager, StateManagerMemory): + assert not state_manager._states_locks[token].locked() + + # separate instances should NOT share locks + sm2 = StateManagerMemory(state=TestState) + assert sm2._state_manager_lock is state_manager._state_manager_lock + assert not sm2._states_locks + if state_manager._states_locks: + assert sm2._states_locks != state_manager._states_locks + + +@pytest.mark.asyncio +async def test_state_manager_contend(state_manager: StateManager, token: str): + """Multiple coroutines attempting to access the same state. + + Args: + state_manager: A state manager instance. + token: A token. + """ + n_coroutines = 10 + exp_num1 = 10 + + async with state_manager.modify_state(token) as state: + state.num1 = 0 + + async def _coro(): + async with state_manager.modify_state(token) as state: + await asyncio.sleep(0.01) + state.num1 += 1 + + tasks = [asyncio.create_task(_coro()) for _ in range(n_coroutines)] + + for f in asyncio.as_completed(tasks): + await f + + assert (await state_manager.get_state(token)).num1 == exp_num1 + + if isinstance(state_manager, StateManagerRedis): + assert (await state_manager.redis.get(f"{token}_lock")) is None + elif isinstance(state_manager, StateManagerMemory): + assert token in state_manager._states_locks + assert not state_manager._states_locks[token].locked() + + +@pytest.fixture(scope="function") +def state_manager_redis() -> Generator[StateManager, None, None]: + """Instance of state manager for redis only. + + Yields: + A state manager instance + """ + state_manager = StateManager.create(TestState) + + if not isinstance(state_manager, StateManagerRedis): + pytest.skip("Test requires redis") + + yield state_manager + + asyncio.get_event_loop().run_until_complete(state_manager.redis.close()) + + +@pytest.mark.asyncio +async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str): + """Test that the state manager lock expires and raises exception exiting context. + + Args: + state_manager_redis: A state manager instance. + token: A token. + """ + state_manager_redis.lock_expiration = LOCK_EXPIRATION + + async with state_manager_redis.modify_state(token): + await asyncio.sleep(0.01) + + with pytest.raises(LockExpiredError): + async with state_manager_redis.modify_state(token): + await asyncio.sleep(LOCK_EXPIRE_SLEEP) + + +@pytest.mark.asyncio +async def test_state_manager_lock_expire_contend( + state_manager_redis: StateManager, token: str +): + """Test that the state manager lock expires and queued waiters proceed. + + Args: + state_manager_redis: A state manager instance. + token: A token. + """ + exp_num1 = 4252 + unexp_num1 = 666 + + state_manager_redis.lock_expiration = LOCK_EXPIRATION + + order = [] + + async def _coro_blocker(): + async with state_manager_redis.modify_state(token) as state: + order.append("blocker") + await asyncio.sleep(LOCK_EXPIRE_SLEEP) + state.num1 = unexp_num1 + + async def _coro_waiter(): + while "blocker" not in order: + await asyncio.sleep(0.005) + async with state_manager_redis.modify_state(token) as state: + order.append("waiter") + assert state.num1 != unexp_num1 + state.num1 = exp_num1 + + tasks = [ + asyncio.create_task(_coro_blocker()), + asyncio.create_task(_coro_waiter()), + ] + with pytest.raises(LockExpiredError): + await tasks[0] + await tasks[1] + + assert order == ["blocker", "waiter"] + assert (await state_manager_redis.get_state(token)).num1 == exp_num1 + + +@pytest.fixture(scope="function") +def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App: + """Mock app fixture. + + Args: + monkeypatch: Pytest monkeypatch object. + app: An app. + state_manager: A state manager. + + Returns: + The app, after mocking out prerequisites.get_app() + """ + app_module = Mock() + setattr(app_module, APP_VAR, app) + app.state = TestState + app.state_manager = state_manager + assert app.event_namespace is not None + app.event_namespace.emit = AsyncMock() + monkeypatch.setattr(prerequisites, "get_app", lambda: app_module) + return app + + +@pytest.mark.asyncio +async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): + """Test that the state proxy works. + + Args: + grandchild_state: A grandchild state. + mock_app: An app that will be returned by `get_app()` + """ + child_state = grandchild_state.parent_state + assert child_state is not None + parent_state = child_state.parent_state + assert parent_state is not None + if isinstance(mock_app.state_manager, StateManagerMemory): + mock_app.state_manager.states[parent_state.get_token()] = parent_state + + sp = StateProxy(grandchild_state) + assert sp.__wrapped__ == grandchild_state + assert sp._self_substate_path == grandchild_state.get_full_name().split(".") + assert sp._self_app is mock_app + assert not sp._self_mutable + assert sp._self_actx is None + + # cannot use normal contextmanager protocol + with pytest.raises(TypeError), sp: + pass + + with pytest.raises(ImmutableStateError): + # cannot directly modify state proxy outside of async context + sp.value2 = 16 + + async with sp: + assert sp._self_actx is not None + assert sp._self_mutable # proxy is mutable inside context + if isinstance(mock_app.state_manager, StateManagerMemory): + # For in-process store, only one instance of the state exists + assert sp.__wrapped__ is grandchild_state + else: + # When redis is used, a new+updated instance is assigned to the proxy + assert sp.__wrapped__ is not grandchild_state + sp.value2 = 42 + assert not sp._self_mutable # proxy is not mutable after exiting context + assert sp._self_actx is None + assert sp.value2 == 42 + + # Get the state from the state manager directly and check that the value is updated + gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token()) + if isinstance(mock_app.state_manager, StateManagerMemory): + # For in-process store, only one instance of the state exists + assert gotten_state is parent_state + else: + assert gotten_state is not parent_state + gotten_grandchild_state = gotten_state.get_substate(sp._self_substate_path) + assert gotten_grandchild_state is not None + assert gotten_grandchild_state.value2 == 42 + + # ensure state update was emitted + assert mock_app.event_namespace is not None + mock_app.event_namespace.emit.assert_called_once() + mcall = mock_app.event_namespace.emit.mock_calls[0] + assert mcall.args[0] == str(SocketEvent.EVENT) + assert json.loads(mcall.args[1]) == StateUpdate( + delta={ + parent_state.get_full_name(): { + "upper": "", + "sum": 3.14, + }, + grandchild_state.get_full_name(): { + "value2": 42, + }, + } + ) + assert mcall.kwargs["to"] == grandchild_state.get_sid() + + +class BackgroundTaskState(State): + """A state with a background task.""" + + order: List[str] = [] + dict_list: Dict[str, List[int]] = {"foo": []} + + @rx.background + async def background_task(self): + """A background task that updates the state.""" + async with self: + assert not self.order + self.order.append("background_task:start") + + assert isinstance(self, StateProxy) + with pytest.raises(ImmutableStateError): + self.order.append("bad idea") + + with pytest.raises(ImmutableStateError): + # Even nested access to mutables raises an exception. + self.dict_list["foo"].append(42) + + # wait for some other event to happen + while len(self.order) == 1: + await asyncio.sleep(0.01) + async with self: + pass # update proxy instance + + async with self: + self.order.append("background_task:stop") + + @rx.background + async def background_task_generator(self): + """A background task generator that does nothing. + + Yields: + None + """ + yield + + def other(self): + """Some other event that updates the state.""" + self.order.append("other") + + async def bad_chain1(self): + """Test that a background task cannot be chained.""" + await self.background_task() + + async def bad_chain2(self): + """Test that a background task generator cannot be chained.""" + async for _foo in self.background_task_generator(): + pass + + +@pytest.mark.asyncio +async def test_background_task_no_block(mock_app: rx.App, token: str): + """Test that a background task does not block other events. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + router_data = {"query": {}} + mock_app.state_manager.state = mock_app.state = BackgroundTaskState + async for update in rx.app.process( # type: ignore + mock_app, + Event( + token=token, + name=f"{BackgroundTaskState.get_name()}.background_task", + router_data=router_data, + payload={}, + ), + sid="", + headers={}, + client_ip="", + ): + # background task returns empty update immediately + assert update == StateUpdate() + assert len(mock_app.background_tasks) == 1 + + # wait for the coroutine to start + await asyncio.sleep(0.5 if CI else 0.1) + assert len(mock_app.background_tasks) == 1 + + # Process another normal event + async for update in rx.app.process( # type: ignore + mock_app, + Event( + token=token, + name=f"{BackgroundTaskState.get_name()}.other", + router_data=router_data, + payload={}, + ), + sid="", + headers={}, + client_ip="", + ): + # other task returns delta + assert update == StateUpdate( + delta={ + BackgroundTaskState.get_name(): { + "order": [ + "background_task:start", + "other", + ], + } + } + ) + + # Explicit wait for background tasks + for task in tuple(mock_app.background_tasks): + await task + assert not mock_app.background_tasks + + assert (await mock_app.state_manager.get_state(token)).order == [ + "background_task:start", + "other", + "background_task:stop", + ] + + +@pytest.mark.asyncio +async def test_background_task_no_chain(): + """Test that a background task cannot be chained.""" + bts = BackgroundTaskState() + with pytest.raises(RuntimeError): + await bts.bad_chain1() + with pytest.raises(RuntimeError): + await bts.bad_chain2() + + def test_mutable_list(mutable_state): """Test that mutable lists are tracked correctly.