rx.background and StateManager.modify_state provides safe exclusive access to state (#1676)
This commit is contained in:
parent
211dc15995
commit
351611ca25
17
.github/workflows/integration_app_harness.yml
vendored
17
.github/workflows/integration_app_harness.yml
vendored
@ -15,7 +15,23 @@ permissions:
|
||||
|
||||
jobs:
|
||||
integration-app-harness:
|
||||
strategy:
|
||||
matrix:
|
||||
state_manager: [ "redis", "memory" ]
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
# Label used to access the service container
|
||||
redis:
|
||||
image: ${{ matrix.state_manager == 'redis' && 'redis' || '' }}
|
||||
# Set health checks to wait until redis has started
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
# Maps port 6379 on service container to the host
|
||||
- 6379:6379
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: ./.github/actions/setup_build_env
|
||||
@ -27,6 +43,7 @@ jobs:
|
||||
- name: Run app harness tests
|
||||
env:
|
||||
SCREENSHOT_DIR: /tmp/screenshots
|
||||
REDIS_URL: ${{ matrix.state_manager == 'redis' && 'localhost:6379' || '' }}
|
||||
run: |
|
||||
poetry run pytest integration
|
||||
- uses: actions/upload-artifact@v3
|
||||
|
20
.github/workflows/unit_tests.yml
vendored
20
.github/workflows/unit_tests.yml
vendored
@ -40,6 +40,20 @@ jobs:
|
||||
- os: windows-latest
|
||||
python-version: "3.8.10"
|
||||
runs-on: ${{ matrix.os }}
|
||||
# Service containers to run with `runner-job`
|
||||
services:
|
||||
# Label used to access the service container
|
||||
redis:
|
||||
image: ${{ matrix.os == 'ubuntu-latest' && 'redis' || '' }}
|
||||
# Set health checks to wait until redis has started
|
||||
options: >-
|
||||
--health-cmd "redis-cli ping"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
# Maps port 6379 on service container to the host
|
||||
- 6379:6379
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: ./.github/actions/setup_build_env
|
||||
@ -51,4 +65,10 @@ jobs:
|
||||
run: |
|
||||
export PYTHONUNBUFFERED=1
|
||||
poetry run pytest tests --cov --no-cov-on-fail --cov-report=
|
||||
- name: Run unit tests w/ redis
|
||||
if: ${{ matrix.os == 'ubuntu-latest' }}
|
||||
run: |
|
||||
export PYTHONUNBUFFERED=1
|
||||
export REDIS_URL=localhost:6379
|
||||
poetry run pytest tests --cov --no-cov-on-fail --cov-report=
|
||||
- run: poetry run coverage html
|
||||
|
@ -1,4 +1,10 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.10.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [integration, reflex, tests]
|
||||
|
||||
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
||||
rev: v0.0.244
|
||||
hooks:
|
||||
@ -17,9 +23,3 @@ repos:
|
||||
hooks:
|
||||
- id: darglint
|
||||
exclude: '^reflex/reflex.py'
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.10.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [integration, reflex, tests]
|
||||
|
214
integration/test_background_task.py
Normal file
214
integration/test_background_task.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""Test @rx.background task functionality."""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from selenium.webdriver.common.by import By
|
||||
|
||||
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
|
||||
|
||||
|
||||
def BackgroundTask():
|
||||
"""Test that background tasks work as expected."""
|
||||
import asyncio
|
||||
|
||||
import reflex as rx
|
||||
|
||||
class State(rx.State):
|
||||
counter: int = 0
|
||||
_task_id: int = 0
|
||||
iterations: int = 10
|
||||
|
||||
@rx.background
|
||||
async def handle_event(self):
|
||||
async with self:
|
||||
self._task_id += 1
|
||||
for _ix in range(int(self.iterations)):
|
||||
async with self:
|
||||
self.counter += 1
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
@rx.background
|
||||
async def handle_event_yield_only(self):
|
||||
async with self:
|
||||
self._task_id += 1
|
||||
for ix in range(int(self.iterations)):
|
||||
if ix % 2 == 0:
|
||||
yield State.increment_arbitrary(1) # type: ignore
|
||||
else:
|
||||
yield State.increment() # type: ignore
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
def increment(self):
|
||||
self.counter += 1
|
||||
|
||||
@rx.background
|
||||
async def increment_arbitrary(self, amount: int):
|
||||
async with self:
|
||||
self.counter += int(amount)
|
||||
|
||||
def reset_counter(self):
|
||||
self.counter = 0
|
||||
|
||||
async def blocking_pause(self):
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
@rx.background
|
||||
async def non_blocking_pause(self):
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
@rx.cached_var
|
||||
def token(self) -> str:
|
||||
return self.get_token()
|
||||
|
||||
def index() -> rx.Component:
|
||||
return rx.vstack(
|
||||
rx.input(id="token", value=State.token, is_read_only=True),
|
||||
rx.heading(State.counter, id="counter"),
|
||||
rx.input(
|
||||
id="iterations",
|
||||
placeholder="Iterations",
|
||||
value=State.iterations.to_string(), # type: ignore
|
||||
on_change=State.set_iterations, # type: ignore
|
||||
),
|
||||
rx.button(
|
||||
"Delayed Increment",
|
||||
on_click=State.handle_event,
|
||||
id="delayed-increment",
|
||||
),
|
||||
rx.button(
|
||||
"Yield Increment",
|
||||
on_click=State.handle_event_yield_only,
|
||||
id="yield-increment",
|
||||
),
|
||||
rx.button("Increment 1", on_click=State.increment, id="increment"),
|
||||
rx.button(
|
||||
"Blocking Pause",
|
||||
on_click=State.blocking_pause,
|
||||
id="blocking-pause",
|
||||
),
|
||||
rx.button(
|
||||
"Non-Blocking Pause",
|
||||
on_click=State.non_blocking_pause,
|
||||
id="non-blocking-pause",
|
||||
),
|
||||
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||
)
|
||||
|
||||
app = rx.App(state=State)
|
||||
app.add_page(index)
|
||||
app.compile()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def background_task(
|
||||
tmp_path_factory,
|
||||
) -> Generator[AppHarness, None, None]:
|
||||
"""Start BackgroundTask app at tmp_path via AppHarness.
|
||||
|
||||
Args:
|
||||
tmp_path_factory: pytest tmp_path_factory fixture
|
||||
|
||||
Yields:
|
||||
running AppHarness instance
|
||||
"""
|
||||
with AppHarness.create(
|
||||
root=tmp_path_factory.mktemp(f"background_task"),
|
||||
app_source=BackgroundTask, # type: ignore
|
||||
) as harness:
|
||||
yield harness
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
|
||||
"""Get an instance of the browser open to the background_task app.
|
||||
|
||||
Args:
|
||||
background_task: harness for BackgroundTask app
|
||||
|
||||
Yields:
|
||||
WebDriver instance.
|
||||
"""
|
||||
assert background_task.app_instance is not None, "app is not running"
|
||||
driver = background_task.frontend()
|
||||
try:
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def token(background_task: AppHarness, driver: WebDriver) -> str:
|
||||
"""Get a function that returns the active token.
|
||||
|
||||
Args:
|
||||
background_task: harness for BackgroundTask app.
|
||||
driver: WebDriver instance.
|
||||
|
||||
Returns:
|
||||
The token for the connected client
|
||||
"""
|
||||
assert background_task.app_instance is not None
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
|
||||
# wait for the backend connection to send the token
|
||||
token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
|
||||
assert token is not None
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def test_background_task(
|
||||
background_task: AppHarness,
|
||||
driver: WebDriver,
|
||||
token: str,
|
||||
):
|
||||
"""Test that background tasks work as expected.
|
||||
|
||||
Args:
|
||||
background_task: harness for BackgroundTask app.
|
||||
driver: WebDriver instance.
|
||||
token: The token for the connected client.
|
||||
"""
|
||||
assert background_task.app_instance is not None
|
||||
|
||||
# get a reference to all buttons
|
||||
delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
|
||||
yield_increment_button = driver.find_element(By.ID, "yield-increment")
|
||||
increment_button = driver.find_element(By.ID, "increment")
|
||||
blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
|
||||
non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
|
||||
driver.find_element(By.ID, "reset")
|
||||
|
||||
# get a reference to the counter
|
||||
counter = driver.find_element(By.ID, "counter")
|
||||
|
||||
# get a reference to the iterations input
|
||||
iterations_input = driver.find_element(By.ID, "iterations")
|
||||
|
||||
# kick off background tasks
|
||||
iterations_input.clear()
|
||||
iterations_input.send_keys("50")
|
||||
delayed_increment_button.click()
|
||||
blocking_pause_button.click()
|
||||
delayed_increment_button.click()
|
||||
for _ in range(10):
|
||||
increment_button.click()
|
||||
blocking_pause_button.click()
|
||||
delayed_increment_button.click()
|
||||
delayed_increment_button.click()
|
||||
yield_increment_button.click()
|
||||
non_blocking_pause_button.click()
|
||||
yield_increment_button.click()
|
||||
blocking_pause_button.click()
|
||||
yield_increment_button.click()
|
||||
for _ in range(10):
|
||||
increment_button.click()
|
||||
yield_increment_button.click()
|
||||
blocking_pause_button.click()
|
||||
assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
|
||||
# all tasks should have exited and cleaned up
|
||||
assert background_task._poll_for(
|
||||
lambda: not background_task.app_instance.background_tasks # type: ignore
|
||||
)
|
@ -133,7 +133,6 @@ def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]:
|
||||
assert client_side.app_instance is not None, "app is not running"
|
||||
driver = client_side.frontend()
|
||||
try:
|
||||
assert client_side.poll_for_clients()
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
@ -168,7 +167,20 @@ def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
|
||||
driver.delete_all_cookies()
|
||||
|
||||
|
||||
def test_client_side_state(
|
||||
def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
|
||||
"""Get a map of cookie names to cookie info.
|
||||
|
||||
Args:
|
||||
driver: WebDriver instance.
|
||||
|
||||
Returns:
|
||||
A map of cookie names to cookie info.
|
||||
"""
|
||||
return {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_side_state(
|
||||
client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
|
||||
):
|
||||
"""Test client side state.
|
||||
@ -187,8 +199,6 @@ def test_client_side_state(
|
||||
token = client_side.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
backend_state = client_side.app_instance.state_manager.states[token]
|
||||
|
||||
# get a reference to the cookie manipulation form
|
||||
state_var_input = driver.find_element(By.ID, "state_var")
|
||||
input_value_input = driver.find_element(By.ID, "input_value")
|
||||
@ -274,7 +284,7 @@ def test_client_side_state(
|
||||
input_value_input.send_keys("l1s value")
|
||||
set_sub_sub_state_button.click()
|
||||
|
||||
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
||||
cookies = cookie_info_map(driver)
|
||||
assert cookies.pop("client_side_state.client_side_sub_state.c1") == {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
@ -338,8 +348,10 @@ def test_client_side_state(
|
||||
state_var_input.send_keys("c3")
|
||||
input_value_input.send_keys("c3 value")
|
||||
set_sub_state_button.click()
|
||||
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
||||
c3_cookie = cookies["client_side_state.client_side_sub_state.c3"]
|
||||
AppHarness._poll_for(
|
||||
lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
|
||||
)
|
||||
c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
|
||||
assert c3_cookie.pop("expiry") is not None
|
||||
assert c3_cookie == {
|
||||
"domain": "localhost",
|
||||
@ -351,9 +363,7 @@ def test_client_side_state(
|
||||
"value": "c3%20value",
|
||||
}
|
||||
time.sleep(2) # wait for c3 to expire
|
||||
assert "client_side_state.client_side_sub_state.c3" not in {
|
||||
cookie_info["name"] for cookie_info in driver.get_cookies()
|
||||
}
|
||||
assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
|
||||
|
||||
local_storage_items = local_storage.items()
|
||||
local_storage_items.pop("chakra-ui-color-mode", None)
|
||||
@ -426,7 +436,8 @@ def test_client_side_state(
|
||||
assert l1s.text == "l1s value"
|
||||
|
||||
# reset the backend state to force refresh from client storage
|
||||
backend_state.reset()
|
||||
async with client_side.modify_state(token) as state:
|
||||
state.reset()
|
||||
driver.refresh()
|
||||
|
||||
# wait for the backend connection to send the token (again)
|
||||
@ -465,9 +476,7 @@ def test_client_side_state(
|
||||
assert l1s.text == "l1s value"
|
||||
|
||||
# make sure c5 cookie shows up on the `/foo` route
|
||||
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
||||
|
||||
assert cookies["client_side_state.client_side_sub_state.c5"] == {
|
||||
assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
|
||||
"domain": "localhost",
|
||||
"httpOnly": False,
|
||||
"name": "client_side_state.client_side_sub_state.c5",
|
||||
|
@ -1,11 +1,10 @@
|
||||
"""Integration tests for dynamic route page behavior."""
|
||||
from typing import Callable, Generator, Type
|
||||
from typing import Callable, Coroutine, Generator, Type
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import pytest
|
||||
from selenium.webdriver.common.by import By
|
||||
|
||||
from reflex import State
|
||||
from reflex.testing import AppHarness, AppHarnessProd, WebDriver
|
||||
|
||||
from .utils import poll_for_navigation
|
||||
@ -100,22 +99,21 @@ def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]:
|
||||
assert dynamic_route.app_instance is not None, "app is not running"
|
||||
driver = dynamic_route.frontend()
|
||||
try:
|
||||
assert dynamic_route.poll_for_clients()
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
|
||||
"""Get the backend state.
|
||||
def token(dynamic_route: AppHarness, driver: WebDriver) -> str:
|
||||
"""Get the token associated with backend state.
|
||||
|
||||
Args:
|
||||
dynamic_route: harness for DynamicRoute app.
|
||||
driver: WebDriver instance.
|
||||
|
||||
Returns:
|
||||
The backend state associated with the token visible in the driver browser.
|
||||
The token visible in the driver browser.
|
||||
"""
|
||||
assert dynamic_route.app_instance is not None
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
@ -125,43 +123,49 @@ def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
|
||||
token = dynamic_route.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
# look up the backend state from the state manager
|
||||
return dynamic_route.app_instance.state_manager.states[token]
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def poll_for_order(
|
||||
dynamic_route: AppHarness, backend_state: State
|
||||
) -> Callable[[list[str]], None]:
|
||||
dynamic_route: AppHarness, token: str
|
||||
) -> Callable[[list[str]], Coroutine[None, None, None]]:
|
||||
"""Poll for the order list to match the expected order.
|
||||
|
||||
Args:
|
||||
dynamic_route: harness for DynamicRoute app.
|
||||
backend_state: The backend state associated with the token visible in the driver browser.
|
||||
token: The token visible in the driver browser.
|
||||
|
||||
Returns:
|
||||
A function that polls for the order list to match the expected order.
|
||||
An async function that polls for the order list to match the expected order.
|
||||
"""
|
||||
|
||||
def _poll_for_order(exp_order: list[str]):
|
||||
dynamic_route._poll_for(lambda: backend_state.order == exp_order)
|
||||
assert backend_state.order == exp_order
|
||||
async def _poll_for_order(exp_order: list[str]):
|
||||
async def _backend_state():
|
||||
return await dynamic_route.get_state(token)
|
||||
|
||||
async def _check():
|
||||
return (await _backend_state()).order == exp_order
|
||||
|
||||
await AppHarness._poll_for_async(_check)
|
||||
assert (await _backend_state()).order == exp_order
|
||||
|
||||
return _poll_for_order
|
||||
|
||||
|
||||
def test_on_load_navigate(
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_load_navigate(
|
||||
dynamic_route: AppHarness,
|
||||
driver: WebDriver,
|
||||
backend_state: State,
|
||||
poll_for_order: Callable[[list[str]], None],
|
||||
token: str,
|
||||
poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
|
||||
):
|
||||
"""Click links to navigate between dynamic pages with on_load event.
|
||||
|
||||
Args:
|
||||
dynamic_route: harness for DynamicRoute app.
|
||||
driver: WebDriver instance.
|
||||
backend_state: The backend state associated with the token visible in the driver browser.
|
||||
token: The token visible in the driver browser.
|
||||
poll_for_order: function that polls for the order list to match the expected order.
|
||||
"""
|
||||
assert dynamic_route.app_instance is not None
|
||||
@ -184,7 +188,7 @@ def test_on_load_navigate(
|
||||
assert page_id_input
|
||||
|
||||
assert dynamic_route.poll_for_value(page_id_input) == str(ix)
|
||||
poll_for_order(exp_order)
|
||||
await poll_for_order(exp_order)
|
||||
|
||||
# manually load the next page to trigger client side routing in prod mode
|
||||
if is_prod:
|
||||
@ -192,14 +196,14 @@ def test_on_load_navigate(
|
||||
exp_order += ["/page/[page_id]-10"]
|
||||
with poll_for_navigation(driver):
|
||||
driver.get(f"{dynamic_route.frontend_url}/page/10/")
|
||||
poll_for_order(exp_order)
|
||||
await poll_for_order(exp_order)
|
||||
|
||||
# make sure internal nav still hydrates after redirect
|
||||
exp_order += ["/page/[page_id]-11"]
|
||||
link = driver.find_element(By.ID, "link_page_next")
|
||||
with poll_for_navigation(driver):
|
||||
link.click()
|
||||
poll_for_order(exp_order)
|
||||
await poll_for_order(exp_order)
|
||||
|
||||
# load same page with a query param and make sure it passes through
|
||||
if is_prod:
|
||||
@ -207,14 +211,14 @@ def test_on_load_navigate(
|
||||
exp_order += ["/page/[page_id]-11"]
|
||||
with poll_for_navigation(driver):
|
||||
driver.get(f"{driver.current_url}?foo=bar")
|
||||
poll_for_order(exp_order)
|
||||
assert backend_state.get_query_params()["foo"] == "bar"
|
||||
await poll_for_order(exp_order)
|
||||
assert (await dynamic_route.get_state(token)).get_query_params()["foo"] == "bar"
|
||||
|
||||
# hit a 404 and ensure we still hydrate
|
||||
exp_order += ["/404-no page id"]
|
||||
with poll_for_navigation(driver):
|
||||
driver.get(f"{dynamic_route.frontend_url}/missing")
|
||||
poll_for_order(exp_order)
|
||||
await poll_for_order(exp_order)
|
||||
|
||||
# browser nav should still trigger hydration
|
||||
if is_prod:
|
||||
@ -222,14 +226,14 @@ def test_on_load_navigate(
|
||||
exp_order += ["/page/[page_id]-11"]
|
||||
with poll_for_navigation(driver):
|
||||
driver.back()
|
||||
poll_for_order(exp_order)
|
||||
await poll_for_order(exp_order)
|
||||
|
||||
# next/link to a 404 and ensure we still hydrate
|
||||
exp_order += ["/404-no page id"]
|
||||
link = driver.find_element(By.ID, "link_missing")
|
||||
with poll_for_navigation(driver):
|
||||
link.click()
|
||||
poll_for_order(exp_order)
|
||||
await poll_for_order(exp_order)
|
||||
|
||||
# hit a page that redirects back to dynamic page
|
||||
if is_prod:
|
||||
@ -237,15 +241,16 @@ def test_on_load_navigate(
|
||||
exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"]
|
||||
with poll_for_navigation(driver):
|
||||
driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
|
||||
poll_for_order(exp_order)
|
||||
await poll_for_order(exp_order)
|
||||
# should have redirected back to page 0
|
||||
assert urlsplit(driver.current_url).path == "/page/0/"
|
||||
|
||||
|
||||
def test_on_load_navigate_non_dynamic(
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_load_navigate_non_dynamic(
|
||||
dynamic_route: AppHarness,
|
||||
driver: WebDriver,
|
||||
poll_for_order: Callable[[list[str]], None],
|
||||
poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
|
||||
):
|
||||
"""Click links to navigate between static pages with on_load event.
|
||||
|
||||
@ -261,7 +266,7 @@ def test_on_load_navigate_non_dynamic(
|
||||
with poll_for_navigation(driver):
|
||||
link.click()
|
||||
assert urlsplit(driver.current_url).path == "/static/x/"
|
||||
poll_for_order(["/static/x-no page id"])
|
||||
await poll_for_order(["/static/x-no page id"])
|
||||
|
||||
# go back to the index and navigate back to the static route
|
||||
link = driver.find_element(By.ID, "link_index")
|
||||
@ -273,4 +278,4 @@ def test_on_load_navigate_non_dynamic(
|
||||
with poll_for_navigation(driver):
|
||||
link.click()
|
||||
assert urlsplit(driver.current_url).path == "/static/x/"
|
||||
poll_for_order(["/static/x-no page id", "/static/x-no page id"])
|
||||
await poll_for_order(["/static/x-no page id", "/static/x-no page id"])
|
||||
|
@ -1,18 +1,20 @@
|
||||
"""Ensure that Event Chains are properly queued and handled between frontend and backend."""
|
||||
|
||||
import time
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from selenium.webdriver.common.by import By
|
||||
|
||||
from reflex.testing import AppHarness
|
||||
from reflex.testing import AppHarness, WebDriver
|
||||
|
||||
MANY_EVENTS = 50
|
||||
|
||||
|
||||
def EventChain():
|
||||
"""App with chained event handlers."""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import reflex as rx
|
||||
|
||||
# repeated here since the outer global isn't exported into the App module
|
||||
@ -20,6 +22,7 @@ def EventChain():
|
||||
|
||||
class State(rx.State):
|
||||
event_order: list[str] = []
|
||||
interim_value: str = ""
|
||||
|
||||
@rx.var
|
||||
def token(self) -> str:
|
||||
@ -111,12 +114,25 @@ def EventChain():
|
||||
self.event_order.append("click_return_dict_type")
|
||||
return State.event_arg_repr_type({"a": 1}) # type: ignore
|
||||
|
||||
async def click_yield_interim_value_async(self):
|
||||
self.interim_value = "interim"
|
||||
yield
|
||||
await asyncio.sleep(0.5)
|
||||
self.interim_value = "final"
|
||||
|
||||
def click_yield_interim_value(self):
|
||||
self.interim_value = "interim"
|
||||
yield
|
||||
time.sleep(0.5)
|
||||
self.interim_value = "final"
|
||||
|
||||
app = rx.App(state=State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
return rx.fragment(
|
||||
rx.input(value=State.token, readonly=True, id="token"),
|
||||
rx.input(value=State.interim_value, readonly=True, id="interim_value"),
|
||||
rx.button(
|
||||
"Return Event",
|
||||
id="return_event",
|
||||
@ -172,6 +188,16 @@ def EventChain():
|
||||
id="return_dict_type",
|
||||
on_click=State.click_return_dict_type,
|
||||
),
|
||||
rx.button(
|
||||
"Click Yield Interim Value (Async)",
|
||||
id="click_yield_interim_value_async",
|
||||
on_click=State.click_yield_interim_value_async,
|
||||
),
|
||||
rx.button(
|
||||
"Click Yield Interim Value",
|
||||
id="click_yield_interim_value",
|
||||
on_click=State.click_yield_interim_value,
|
||||
),
|
||||
)
|
||||
|
||||
def on_load_return_chain():
|
||||
@ -237,7 +263,7 @@ def event_chain(tmp_path_factory) -> Generator[AppHarness, None, None]:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def driver(event_chain: AppHarness):
|
||||
def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]:
|
||||
"""Get an instance of the browser open to the event_chain app.
|
||||
|
||||
Args:
|
||||
@ -249,7 +275,6 @@ def driver(event_chain: AppHarness):
|
||||
assert event_chain.app_instance is not None, "app is not running"
|
||||
driver = event_chain.frontend()
|
||||
try:
|
||||
assert event_chain.poll_for_clients()
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
@ -335,7 +360,13 @@ def driver(event_chain: AppHarness):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_chain_click(
|
||||
event_chain: AppHarness,
|
||||
driver: WebDriver,
|
||||
button_id: str,
|
||||
exp_event_order: list[str],
|
||||
):
|
||||
"""Click the button, assert that the events are handled in the correct order.
|
||||
|
||||
Args:
|
||||
@ -350,17 +381,18 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
|
||||
assert btn
|
||||
|
||||
token = event_chain.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
btn.click()
|
||||
if "redirect" in button_id:
|
||||
# wait a bit longer if we're redirecting
|
||||
time.sleep(1)
|
||||
if "many_events" in button_id:
|
||||
# wait a bit longer if we have loads of events
|
||||
time.sleep(1)
|
||||
time.sleep(0.5)
|
||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
||||
assert backend_state.event_order == exp_event_order
|
||||
|
||||
async def _has_all_events():
|
||||
return len((await event_chain.get_state(token)).event_order) == len(
|
||||
exp_event_order
|
||||
)
|
||||
|
||||
await AppHarness._poll_for_async(_has_all_events)
|
||||
event_order = (await event_chain.get_state(token)).event_order
|
||||
assert event_order == exp_event_order
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -386,7 +418,13 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_chain_on_load(
|
||||
event_chain: AppHarness,
|
||||
driver: WebDriver,
|
||||
uri: str,
|
||||
exp_event_order: list[str],
|
||||
):
|
||||
"""Load the URI, assert that the events are handled in the correct order.
|
||||
|
||||
Args:
|
||||
@ -395,16 +433,23 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
|
||||
uri: the page to load
|
||||
exp_event_order: the expected events recorded in the State
|
||||
"""
|
||||
assert event_chain.frontend_url is not None
|
||||
driver.get(event_chain.frontend_url + uri)
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
|
||||
token = event_chain.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
time.sleep(0.5)
|
||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
||||
assert backend_state.is_hydrated is True
|
||||
async def _has_all_events():
|
||||
return len((await event_chain.get_state(token)).event_order) == len(
|
||||
exp_event_order
|
||||
)
|
||||
|
||||
await AppHarness._poll_for_async(_has_all_events)
|
||||
backend_state = await event_chain.get_state(token)
|
||||
assert backend_state.event_order == exp_event_order
|
||||
assert backend_state.is_hydrated is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -444,7 +489,13 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_chain_on_mount(
|
||||
event_chain: AppHarness,
|
||||
driver: WebDriver,
|
||||
uri: str,
|
||||
exp_event_order: list[str],
|
||||
):
|
||||
"""Load the URI, assert that the events are handled in the correct order.
|
||||
|
||||
These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
|
||||
@ -458,16 +509,53 @@ def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
|
||||
uri: the page to load
|
||||
exp_event_order: the expected events recorded in the State
|
||||
"""
|
||||
assert event_chain.frontend_url is not None
|
||||
driver.get(event_chain.frontend_url + uri)
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
|
||||
token = event_chain.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
unmount_button = driver.find_element(By.ID, "unmount")
|
||||
assert unmount_button
|
||||
unmount_button.click()
|
||||
|
||||
time.sleep(1)
|
||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
||||
assert backend_state.event_order == exp_event_order
|
||||
async def _has_all_events():
|
||||
return len((await event_chain.get_state(token)).event_order) == len(
|
||||
exp_event_order
|
||||
)
|
||||
|
||||
await AppHarness._poll_for_async(_has_all_events)
|
||||
event_order = (await event_chain.get_state(token)).event_order
|
||||
assert event_order == exp_event_order
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("button_id",),
|
||||
[
|
||||
("click_yield_interim_value_async",),
|
||||
("click_yield_interim_value",),
|
||||
],
|
||||
)
|
||||
def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str):
|
||||
"""Click the button, assert that the interim value is set, then final value is set.
|
||||
|
||||
Args:
|
||||
event_chain: AppHarness for the event_chain app
|
||||
driver: selenium WebDriver open to the app
|
||||
button_id: the ID of the button to click
|
||||
"""
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
interim_value_input = driver.find_element(By.ID, "interim_value")
|
||||
assert event_chain.poll_for_value(token_input)
|
||||
|
||||
btn = driver.find_element(By.ID, button_id)
|
||||
btn.click()
|
||||
assert (
|
||||
event_chain.poll_for_value(interim_value_input, exp_not_equal="") == "interim"
|
||||
)
|
||||
assert (
|
||||
event_chain.poll_for_value(interim_value_input, exp_not_equal="interim")
|
||||
== "final"
|
||||
)
|
||||
|
@ -19,11 +19,16 @@ def FormSubmit():
|
||||
def form_submit(self, form_data: dict):
|
||||
self.form_data = form_data
|
||||
|
||||
@rx.var
|
||||
def token(self) -> str:
|
||||
return self.get_token()
|
||||
|
||||
app = rx.App(state=FormState)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
return rx.vstack(
|
||||
rx.input(value=FormState.token, is_read_only=True, id="token"),
|
||||
rx.form(
|
||||
rx.vstack(
|
||||
rx.input(id="name_input"),
|
||||
@ -82,13 +87,13 @@ def driver(form_submit: AppHarness):
|
||||
"""
|
||||
driver = form_submit.frontend()
|
||||
try:
|
||||
assert form_submit.poll_for_clients()
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
|
||||
|
||||
def test_submit(driver, form_submit: AppHarness):
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit(driver, form_submit: AppHarness):
|
||||
"""Fill a form with various different output, submit it to backend and verify
|
||||
the output.
|
||||
|
||||
@ -97,7 +102,14 @@ def test_submit(driver, form_submit: AppHarness):
|
||||
form_submit: harness for FormSubmit app
|
||||
"""
|
||||
assert form_submit.app_instance is not None, "app is not running"
|
||||
_, backend_state = list(form_submit.app_instance.state_manager.states.items())[0]
|
||||
|
||||
# get a reference to the connected client
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
|
||||
# wait for the backend connection to send the token
|
||||
token = form_submit.poll_for_value(token_input)
|
||||
assert token
|
||||
|
||||
name_input = driver.find_element(By.ID, "name_input")
|
||||
name_input.send_keys("foo")
|
||||
@ -132,19 +144,21 @@ def test_submit(driver, form_submit: AppHarness):
|
||||
submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
|
||||
submit_input.click()
|
||||
|
||||
# wait for the form data to arrive at the backend
|
||||
AppHarness._poll_for(
|
||||
lambda: backend_state.form_data != {},
|
||||
)
|
||||
async def get_form_data():
|
||||
return (await form_submit.get_state(token)).form_data
|
||||
|
||||
assert backend_state.form_data["name_input"] == "foo"
|
||||
assert backend_state.form_data["pin_input"] == pin_values
|
||||
assert backend_state.form_data["number_input"] == "-3"
|
||||
assert backend_state.form_data["bool_input"] is True
|
||||
assert backend_state.form_data["bool_input2"] is True
|
||||
assert backend_state.form_data["slider_input"] == "50"
|
||||
assert backend_state.form_data["range_input"] == ["25", "75"]
|
||||
assert backend_state.form_data["radio_input"] == "option2"
|
||||
assert backend_state.form_data["select_input"] == "option1"
|
||||
assert backend_state.form_data["text_area_input"] == "Some\nText"
|
||||
assert backend_state.form_data["debounce_input"] == "bar baz"
|
||||
# wait for the form data to arrive at the backend
|
||||
form_data = await AppHarness._poll_for_async(get_form_data)
|
||||
assert isinstance(form_data, dict)
|
||||
|
||||
assert form_data["name_input"] == "foo"
|
||||
assert form_data["pin_input"] == pin_values
|
||||
assert form_data["number_input"] == "-3"
|
||||
assert form_data["bool_input"] is True
|
||||
assert form_data["bool_input2"] is True
|
||||
assert form_data["slider_input"] == "50"
|
||||
assert form_data["range_input"] == ["25", "75"]
|
||||
assert form_data["radio_input"] == "option2"
|
||||
assert form_data["select_input"] == "option1"
|
||||
assert form_data["text_area_input"] == "Some\nText"
|
||||
assert form_data["debounce_input"] == "bar baz"
|
||||
|
@ -16,11 +16,16 @@ def FullyControlledInput():
|
||||
class State(rx.State):
|
||||
text: str = "initial"
|
||||
|
||||
@rx.var
|
||||
def token(self) -> str:
|
||||
return self.get_token()
|
||||
|
||||
app = rx.App(state=State)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
return rx.fragment(
|
||||
rx.input(value=State.token, is_read_only=True, id="token"),
|
||||
rx.input(
|
||||
id="debounce_input_input",
|
||||
on_change=State.set_text, # type: ignore
|
||||
@ -62,10 +67,12 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
driver = fully_controlled_input.frontend()
|
||||
|
||||
# get a reference to the connected client
|
||||
assert len(fully_controlled_input.poll_for_clients()) == 1
|
||||
token, backend_state = list(
|
||||
fully_controlled_input.app_instance.state_manager.states.items()
|
||||
)[0]
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
|
||||
# wait for the backend connection to send the token
|
||||
token = fully_controlled_input.poll_for_value(token_input)
|
||||
assert token
|
||||
|
||||
# find the input and wait for it to have the initial state value
|
||||
debounce_input = driver.find_element(By.ID, "debounce_input_input")
|
||||
@ -80,14 +87,13 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
debounce_input.send_keys("foo")
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "ifoonitial"
|
||||
assert backend_state.text == "ifoonitial"
|
||||
assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
|
||||
|
||||
# clear the input on the backend
|
||||
backend_state.text = ""
|
||||
fully_controlled_input.app_instance.state_manager.set_state(token, backend_state)
|
||||
await fully_controlled_input.emit_state_updates()
|
||||
assert backend_state.text == ""
|
||||
async with fully_controlled_input.modify_state(token) as state:
|
||||
state.text = ""
|
||||
assert (await fully_controlled_input.get_state(token)).text == ""
|
||||
assert (
|
||||
fully_controlled_input.poll_for_value(
|
||||
debounce_input, exp_not_equal="ifoonitial"
|
||||
@ -99,7 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
debounce_input.send_keys("getting testing done")
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "getting testing done"
|
||||
assert backend_state.text == "getting testing done"
|
||||
assert (
|
||||
await fully_controlled_input.get_state(token)
|
||||
).text == "getting testing done"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
|
||||
|
||||
# type into the on_change input
|
||||
@ -107,7 +115,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
||||
time.sleep(0.5)
|
||||
assert debounce_input.get_attribute("value") == "overwrite the state"
|
||||
assert on_change_input.get_attribute("value") == "overwrite the state"
|
||||
assert backend_state.text == "overwrite the state"
|
||||
assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
|
||||
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
|
||||
|
||||
clear_button.click()
|
||||
|
@ -33,11 +33,16 @@ def ServerSideEvent():
|
||||
def set_value_return_c(self):
|
||||
return rx.set_value("c", "")
|
||||
|
||||
@rx.var
|
||||
def token(self) -> str:
|
||||
return self.get_token()
|
||||
|
||||
app = rx.App(state=SSState)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
return rx.fragment(
|
||||
rx.input(id="token", value=SSState.token, is_read_only=True),
|
||||
rx.input(default_value="a", id="a"),
|
||||
rx.input(default_value="b", id="b"),
|
||||
rx.input(default_value="c", id="c"),
|
||||
@ -106,7 +111,12 @@ def driver(server_side_event: AppHarness):
|
||||
assert server_side_event.app_instance is not None, "app is not running"
|
||||
driver = server_side_event.frontend()
|
||||
try:
|
||||
assert server_side_event.poll_for_clients()
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
# wait for the backend connection to send the token
|
||||
token = server_side_event.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
|
@ -89,13 +89,13 @@ def driver(upload_file: AppHarness):
|
||||
assert upload_file.app_instance is not None, "app is not running"
|
||||
driver = upload_file.frontend()
|
||||
try:
|
||||
assert upload_file.poll_for_clients()
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
|
||||
|
||||
def test_upload_file(tmp_path, upload_file: AppHarness, driver):
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
|
||||
"""Submit a file upload and check that it arrived on the backend.
|
||||
|
||||
Args:
|
||||
@ -124,16 +124,20 @@ def test_upload_file(tmp_path, upload_file: AppHarness, driver):
|
||||
upload_button.click()
|
||||
|
||||
# look up the backend state and assert on uploaded contents
|
||||
backend_state = upload_file.app_instance.state_manager.states[token]
|
||||
time.sleep(0.5)
|
||||
assert backend_state._file_data[exp_name] == exp_contents
|
||||
async def get_file_data():
|
||||
return (await upload_file.get_state(token))._file_data
|
||||
|
||||
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||
assert isinstance(file_data, dict)
|
||||
assert file_data[exp_name] == exp_contents
|
||||
|
||||
# check that the selected files are displayed
|
||||
selected_files = driver.find_element(By.ID, "selected_files")
|
||||
assert selected_files.text == exp_name
|
||||
|
||||
|
||||
def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
||||
"""Submit several file uploads and check that they arrived on the backend.
|
||||
|
||||
Args:
|
||||
@ -173,10 +177,13 @@ def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
||||
upload_button.click()
|
||||
|
||||
# look up the backend state and assert on uploaded contents
|
||||
backend_state = upload_file.app_instance.state_manager.states[token]
|
||||
time.sleep(0.5)
|
||||
async def get_file_data():
|
||||
return (await upload_file.get_state(token))._file_data
|
||||
|
||||
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||
assert isinstance(file_data, dict)
|
||||
for exp_name, exp_contents in exp_files.items():
|
||||
assert backend_state._file_data[exp_name] == exp_contents
|
||||
assert file_data[exp_name] == exp_contents
|
||||
|
||||
|
||||
def test_clear_files(tmp_path, upload_file: AppHarness, driver):
|
||||
|
@ -26,11 +26,16 @@ def VarOperations():
|
||||
dict1: dict = {1: 2}
|
||||
dict2: dict = {3: 4}
|
||||
|
||||
@rx.var
|
||||
def token(self) -> str:
|
||||
return self.get_token()
|
||||
|
||||
app = rx.App(state=VarOperationState)
|
||||
|
||||
@app.add_page
|
||||
def index():
|
||||
return rx.vstack(
|
||||
rx.input(id="token", value=VarOperationState.token, is_read_only=True),
|
||||
# INT INT
|
||||
rx.text(
|
||||
VarOperationState.int_var1 + VarOperationState.int_var2,
|
||||
@ -544,7 +549,12 @@ def driver(var_operations: AppHarness):
|
||||
"""
|
||||
driver = var_operations.frontend()
|
||||
try:
|
||||
assert var_operations.poll_for_clients()
|
||||
token_input = driver.find_element(By.ID, "token")
|
||||
assert token_input
|
||||
# wait for the backend connection to send the token
|
||||
token = var_operations.poll_for_value(token_input)
|
||||
assert token is not None
|
||||
|
||||
yield driver
|
||||
finally:
|
||||
driver.quit()
|
||||
|
@ -21,6 +21,7 @@ from .constants import Env as Env
|
||||
from .event import EVENT_ARG as EVENT_ARG
|
||||
from .event import EventChain as EventChain
|
||||
from .event import FileUpload as upload_files
|
||||
from .event import background as background
|
||||
from .event import clear_local_storage as clear_local_storage
|
||||
from .event import console_log as console_log
|
||||
from .event import download as download
|
||||
|
229
reflex/app.py
229
reflex/app.py
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import inspect
|
||||
import os
|
||||
from multiprocessing.pool import ThreadPool
|
||||
@ -13,6 +14,7 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
@ -49,7 +51,13 @@ from reflex.route import (
|
||||
get_route_args,
|
||||
verify_route_validity,
|
||||
)
|
||||
from reflex.state import DefaultState, State, StateManager, StateUpdate
|
||||
from reflex.state import (
|
||||
DefaultState,
|
||||
State,
|
||||
StateManager,
|
||||
StateManagerMemory,
|
||||
StateUpdate,
|
||||
)
|
||||
from reflex.utils import console, format, prerequisites, types
|
||||
from reflex.vars import ImportVar
|
||||
|
||||
@ -89,7 +97,7 @@ class App(Base):
|
||||
state: Type[State] = DefaultState
|
||||
|
||||
# Class to manage many client states.
|
||||
state_manager: StateManager = StateManager()
|
||||
state_manager: StateManager = StateManagerMemory(state=DefaultState)
|
||||
|
||||
# The styling to apply to each component.
|
||||
style: ComponentStyle = {}
|
||||
@ -104,13 +112,16 @@ class App(Base):
|
||||
admin_dash: Optional[AdminDash] = None
|
||||
|
||||
# The async server name space
|
||||
event_namespace: Optional[AsyncNamespace] = None
|
||||
event_namespace: Optional[EventNamespace] = None
|
||||
|
||||
# A component that is present on every page.
|
||||
overlay_component: Optional[
|
||||
Union[Component, ComponentCallable]
|
||||
] = default_overlay_component
|
||||
|
||||
# Background tasks that are currently running
|
||||
background_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the app.
|
||||
|
||||
@ -154,7 +165,7 @@ class App(Base):
|
||||
self.middleware.append(HydrateMiddleware())
|
||||
|
||||
# Set up the state manager.
|
||||
self.state_manager.setup(state=self.state)
|
||||
self.state_manager = StateManager.create(state=self.state)
|
||||
|
||||
# Set up the API.
|
||||
self.api = FastAPI()
|
||||
@ -646,6 +657,76 @@ class App(Base):
|
||||
thread_pool.close()
|
||||
thread_pool.join()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
"""Modify the state out of band.
|
||||
|
||||
Args:
|
||||
token: The token to modify the state for.
|
||||
|
||||
Yields:
|
||||
The state to modify.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the app has not been initialized yet.
|
||||
"""
|
||||
if self.event_namespace is None:
|
||||
raise RuntimeError("App has not been initialized yet.")
|
||||
# Get exclusive access to the state.
|
||||
async with self.state_manager.modify_state(token) as state:
|
||||
# No other event handler can modify the state while in this context.
|
||||
yield state
|
||||
delta = state.get_delta()
|
||||
if delta:
|
||||
# When the state is modified reset dirty status and emit the delta to the frontend.
|
||||
state._clean()
|
||||
await self.event_namespace.emit_update(
|
||||
update=StateUpdate(delta=delta),
|
||||
sid=state.get_sid(),
|
||||
)
|
||||
|
||||
def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
|
||||
"""Process an event in the background and emit updates as they arrive.
|
||||
|
||||
Args:
|
||||
state: The state to process the event for.
|
||||
event: The event to process.
|
||||
|
||||
Returns:
|
||||
Task if the event was backgroundable, otherwise None
|
||||
"""
|
||||
substate, handler = state._get_event_handler(event)
|
||||
if not handler.is_background:
|
||||
return None
|
||||
|
||||
async def _coro():
|
||||
"""Coroutine to process the event and emit updates inside an asyncio.Task.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the app has not been initialized yet.
|
||||
"""
|
||||
if self.event_namespace is None:
|
||||
raise RuntimeError("App has not been initialized yet.")
|
||||
|
||||
# Process the event.
|
||||
async for update in state._process_event(
|
||||
handler=handler, state=substate, payload=event.payload
|
||||
):
|
||||
# Postprocess the event.
|
||||
update = await self.postprocess(state, event, update)
|
||||
|
||||
# Send the update to the client.
|
||||
await self.event_namespace.emit_update(
|
||||
update=update,
|
||||
sid=state.get_sid(),
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_coro())
|
||||
self.background_tasks.add(task)
|
||||
# Clean up task from background_tasks set when complete.
|
||||
task.add_done_callback(self.background_tasks.discard)
|
||||
return task
|
||||
|
||||
|
||||
async def process(
|
||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||
@ -662,9 +743,6 @@ async def process(
|
||||
Yields:
|
||||
The state updates after processing the event.
|
||||
"""
|
||||
# Get the state for the session.
|
||||
state = app.state_manager.get_state(event.token)
|
||||
|
||||
# Add request data to the state.
|
||||
router_data = event.router_data
|
||||
router_data.update(
|
||||
@ -676,31 +754,35 @@ async def process(
|
||||
constants.RouteVar.CLIENT_IP: client_ip,
|
||||
}
|
||||
)
|
||||
# re-assign only when the value is different
|
||||
if state.router_data != router_data:
|
||||
# assignment will recurse into substates and force recalculation of
|
||||
# dependent ComputedVar (dynamic route variables)
|
||||
state.router_data = router_data
|
||||
# Get the state for the session exclusively.
|
||||
async with app.state_manager.modify_state(event.token) as state:
|
||||
# re-assign only when the value is different
|
||||
if state.router_data != router_data:
|
||||
# assignment will recurse into substates and force recalculation of
|
||||
# dependent ComputedVar (dynamic route variables)
|
||||
state.router_data = router_data
|
||||
|
||||
# Preprocess the event.
|
||||
update = await app.preprocess(state, event)
|
||||
# Preprocess the event.
|
||||
update = await app.preprocess(state, event)
|
||||
|
||||
# If there was an update, yield it.
|
||||
if update is not None:
|
||||
yield update
|
||||
|
||||
# Only process the event if there is no update.
|
||||
else:
|
||||
# Process the event.
|
||||
async for update in state._process(event):
|
||||
# Postprocess the event.
|
||||
update = await app.postprocess(state, event, update)
|
||||
|
||||
# Yield the update.
|
||||
# If there was an update, yield it.
|
||||
if update is not None:
|
||||
yield update
|
||||
|
||||
# Set the state for the session.
|
||||
app.state_manager.set_state(event.token, state)
|
||||
# Only process the event if there is no update.
|
||||
else:
|
||||
if app._process_background(state, event) is not None:
|
||||
# `final=True` allows the frontend send more events immediately.
|
||||
yield StateUpdate(final=True)
|
||||
return
|
||||
|
||||
# Process the event synchronously.
|
||||
async for update in state._process(event):
|
||||
# Postprocess the event.
|
||||
update = await app.postprocess(state, event, update)
|
||||
|
||||
# Yield the update.
|
||||
yield update
|
||||
|
||||
|
||||
async def ping() -> str:
|
||||
@ -737,47 +819,46 @@ def upload(app: App):
|
||||
assert file.filename is not None
|
||||
file.filename = file.filename.split(":")[-1]
|
||||
# Get the state for the session.
|
||||
state = app.state_manager.get_state(token)
|
||||
# get the current session ID
|
||||
sid = state.get_sid()
|
||||
# get the current state(parent state/substate)
|
||||
path = handler.split(".")[:-1]
|
||||
current_state = state.get_substate(path)
|
||||
handler_upload_param = ()
|
||||
async with app.state_manager.modify_state(token) as state:
|
||||
# get the current session ID
|
||||
sid = state.get_sid()
|
||||
# get the current state(parent state/substate)
|
||||
path = handler.split(".")[:-1]
|
||||
current_state = state.get_substate(path)
|
||||
handler_upload_param = ()
|
||||
|
||||
# get handler function
|
||||
func = getattr(current_state, handler.split(".")[-1])
|
||||
# get handler function
|
||||
func = getattr(current_state, handler.split(".")[-1])
|
||||
|
||||
# check if there exists any handler args with annotation, List[UploadFile]
|
||||
for k, v in inspect.getfullargspec(
|
||||
func.fn if isinstance(func, EventHandler) else func
|
||||
).annotations.items():
|
||||
if types.is_generic_alias(v) and types._issubclass(
|
||||
v.__args__[0], UploadFile
|
||||
):
|
||||
handler_upload_param = (k, v)
|
||||
break
|
||||
# check if there exists any handler args with annotation, List[UploadFile]
|
||||
for k, v in inspect.getfullargspec(
|
||||
func.fn if isinstance(func, EventHandler) else func
|
||||
).annotations.items():
|
||||
if types.is_generic_alias(v) and types._issubclass(
|
||||
v.__args__[0], UploadFile
|
||||
):
|
||||
handler_upload_param = (k, v)
|
||||
break
|
||||
|
||||
if not handler_upload_param:
|
||||
raise ValueError(
|
||||
f"`{handler}` handler should have a parameter annotated as List["
|
||||
f"rx.UploadFile]"
|
||||
if not handler_upload_param:
|
||||
raise ValueError(
|
||||
f"`{handler}` handler should have a parameter annotated as List["
|
||||
f"rx.UploadFile]"
|
||||
)
|
||||
|
||||
event = Event(
|
||||
token=token,
|
||||
name=handler,
|
||||
payload={handler_upload_param[0]: files},
|
||||
)
|
||||
|
||||
event = Event(
|
||||
token=token,
|
||||
name=handler,
|
||||
payload={handler_upload_param[0]: files},
|
||||
)
|
||||
async for update in state._process(event):
|
||||
# Postprocess the event.
|
||||
update = await app.postprocess(state, event, update)
|
||||
# Send update to client
|
||||
await asyncio.create_task(
|
||||
app.event_namespace.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
||||
)
|
||||
# Set the state for the session.
|
||||
app.state_manager.set_state(event.token, state)
|
||||
async for update in state._process(event):
|
||||
# Postprocess the event.
|
||||
update = await app.postprocess(state, event, update)
|
||||
# Send update to client
|
||||
await app.event_namespace.emit_update( # type: ignore
|
||||
update=update,
|
||||
sid=sid,
|
||||
)
|
||||
|
||||
return upload_file
|
||||
|
||||
@ -815,6 +896,18 @@ class EventNamespace(AsyncNamespace):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
||||
"""Emit an update to the client.
|
||||
|
||||
Args:
|
||||
update: The state update to send.
|
||||
sid: The Socket.IO session id.
|
||||
"""
|
||||
# Creating a task prevents the update from being blocked behind other coroutines.
|
||||
await asyncio.create_task(
|
||||
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
||||
)
|
||||
|
||||
async def on_event(self, sid, data):
|
||||
"""Event for receiving front-end websocket events.
|
||||
|
||||
@ -841,10 +934,8 @@ class EventNamespace(AsyncNamespace):
|
||||
|
||||
# Process the events.
|
||||
async for update in process(self.app, event, sid, headers, client_ip):
|
||||
# Emit the event.
|
||||
await asyncio.create_task(
|
||||
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
||||
)
|
||||
# Emit the update from processing the event.
|
||||
await self.emit_update(update=update, sid=sid)
|
||||
|
||||
async def on_ping(self, sid):
|
||||
"""Event for testing the API endpoint.
|
||||
|
@ -1,5 +1,6 @@
|
||||
""" Generated with stubgen from mypy, then manually edited, do not regen."""
|
||||
|
||||
import asyncio
|
||||
from fastapi import FastAPI
|
||||
from fastapi import UploadFile as UploadFile
|
||||
from reflex import constants as constants
|
||||
@ -45,12 +46,14 @@ from reflex.utils import (
|
||||
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
overload,
|
||||
@ -75,6 +78,7 @@ class App(Base):
|
||||
admin_dash: Optional[AdminDash]
|
||||
event_namespace: Optional[AsyncNamespace]
|
||||
overlay_component: Optional[Union[Component, ComponentCallable]]
|
||||
background_tasks: Set[asyncio.Task] = set()
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
@ -116,6 +120,10 @@ class App(Base):
|
||||
def setup_admin_dash(self) -> None: ...
|
||||
def get_frontend_packages(self, imports: Dict[str, str]): ...
|
||||
def compile(self) -> None: ...
|
||||
def modify_state(self, token: str) -> AsyncContextManager[State]: ...
|
||||
def _process_background(
|
||||
self, state: State, event: Event
|
||||
) -> asyncio.Task | None: ...
|
||||
|
||||
async def process(
|
||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||
|
@ -219,6 +219,8 @@ OLD_CONFIG_FILE = f"pcconfig{PY_EXT}"
|
||||
PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
|
||||
# Token expiration time in seconds.
|
||||
TOKEN_EXPIRATION = 60 * 60
|
||||
# Maximum time in milliseconds that a state can be locked for exclusive access.
|
||||
LOCK_EXPIRATION = 10000
|
||||
|
||||
# Testing variables.
|
||||
# Testing os env set by pytest when running a test case.
|
||||
|
@ -2,7 +2,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
@ -10,6 +20,9 @@ from reflex.utils import console, format
|
||||
from reflex.utils.types import ArgsSpec
|
||||
from reflex.vars import BaseVar, Var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from reflex.state import State
|
||||
|
||||
|
||||
class Event(Base):
|
||||
"""An event that describes any state change in the app."""
|
||||
@ -27,6 +40,66 @@ class Event(Base):
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
|
||||
BACKGROUND_TASK_MARKER = "_reflex_background_task"
|
||||
|
||||
|
||||
def background(fn):
|
||||
"""Decorator to mark event handler as running in the background.
|
||||
|
||||
Args:
|
||||
fn: The function to decorate.
|
||||
|
||||
Returns:
|
||||
The same function, but with a marker set.
|
||||
|
||||
|
||||
Raises:
|
||||
TypeError: If the function is not a coroutine function or async generator.
|
||||
"""
|
||||
if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn):
|
||||
raise TypeError("Background task must be async function or generator.")
|
||||
setattr(fn, BACKGROUND_TASK_MARKER, True)
|
||||
return fn
|
||||
|
||||
|
||||
def _no_chain_background_task(
|
||||
state_cls: Type["State"], name: str, fn: Callable
|
||||
) -> Callable:
|
||||
"""Protect against directly chaining a background task from another event handler.
|
||||
|
||||
Args:
|
||||
state_cls: The state class that the event handler is in.
|
||||
name: The name of the background task.
|
||||
fn: The background task coroutine function / generator.
|
||||
|
||||
Returns:
|
||||
A compatible coroutine function / generator that raises a runtime error.
|
||||
|
||||
Raises:
|
||||
TypeError: If the background task is not async.
|
||||
"""
|
||||
call = f"{state_cls.__name__}.{name}"
|
||||
message = (
|
||||
f"Cannot directly call background task {name!r}, use "
|
||||
f"`yield {call}` or `return {call}` instead."
|
||||
)
|
||||
if inspect.iscoroutinefunction(fn):
|
||||
|
||||
async def _no_chain_background_task_co(*args, **kwargs):
|
||||
raise RuntimeError(message)
|
||||
|
||||
return _no_chain_background_task_co
|
||||
if inspect.isasyncgenfunction(fn):
|
||||
|
||||
async def _no_chain_background_task_gen(*args, **kwargs):
|
||||
yield
|
||||
raise RuntimeError(message)
|
||||
|
||||
return _no_chain_background_task_gen
|
||||
|
||||
raise TypeError(f"{fn} is marked as a background task, but is not async.")
|
||||
|
||||
|
||||
class EventHandler(Base):
|
||||
"""An event handler responds to an event to update the state."""
|
||||
|
||||
@ -39,6 +112,15 @@ class EventHandler(Base):
|
||||
# Needed to allow serialization of Callable.
|
||||
frozen = True
|
||||
|
||||
@property
|
||||
def is_background(self) -> bool:
|
||||
"""Whether the event handler is a background task.
|
||||
|
||||
Returns:
|
||||
True if the event handler is marked as a background task.
|
||||
"""
|
||||
return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
|
||||
|
||||
def __call__(self, *args: Var) -> EventSpec:
|
||||
"""Pass arguments to the handler to get an event spec.
|
||||
|
||||
@ -530,7 +612,7 @@ def get_handler_args(event_spec: EventSpec) -> tuple[tuple[Var, Var], ...]:
|
||||
|
||||
|
||||
def fix_events(
|
||||
events: list[EventHandler | EventSpec],
|
||||
events: list[EventHandler | EventSpec] | None,
|
||||
token: str,
|
||||
router_data: dict[str, Any] | None = None,
|
||||
) -> list[Event]:
|
||||
|
639
reflex/state.py
639
reflex/state.py
@ -2,13 +2,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import traceback
|
||||
import urllib.parse
|
||||
from abc import ABC
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from types import FunctionType
|
||||
from typing import (
|
||||
@ -27,12 +29,20 @@ from typing import (
|
||||
import cloudpickle
|
||||
import pydantic
|
||||
import wrapt
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
|
||||
from reflex.event import (
|
||||
Event,
|
||||
EventHandler,
|
||||
EventSpec,
|
||||
_no_chain_background_task,
|
||||
fix_events,
|
||||
window_alert,
|
||||
)
|
||||
from reflex.utils import format, prerequisites, types
|
||||
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
||||
from reflex.vars import BaseVar, ComputedVar, Var
|
||||
|
||||
Delta = Dict[str, Any]
|
||||
@ -152,7 +162,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
# Convert the event handlers to functions.
|
||||
for name, event_handler in state.event_handlers.items():
|
||||
fn = functools.partial(event_handler.fn, self)
|
||||
if event_handler.is_background:
|
||||
fn = _no_chain_background_task(type(state), name, event_handler.fn)
|
||||
else:
|
||||
fn = functools.partial(event_handler.fn, self)
|
||||
fn.__module__ = event_handler.fn.__module__ # type: ignore
|
||||
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
||||
setattr(self, name, fn)
|
||||
@ -711,6 +724,37 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
return self.substates[path[0]].get_substate(path[1:])
|
||||
|
||||
def _get_event_handler(
|
||||
self, event: Event
|
||||
) -> tuple[State | StateProxy, EventHandler]:
|
||||
"""Get the event handler for the given event.
|
||||
|
||||
Args:
|
||||
event: The event to get the handler for.
|
||||
|
||||
|
||||
Returns:
|
||||
The event handler.
|
||||
|
||||
Raises:
|
||||
ValueError: If the event handler or substate is not found.
|
||||
"""
|
||||
# Get the event handler.
|
||||
path = event.name.split(".")
|
||||
path, name = path[:-1], path[-1]
|
||||
substate = self.get_substate(path)
|
||||
if not substate:
|
||||
raise ValueError(
|
||||
"The value of state cannot be None when processing an event."
|
||||
)
|
||||
handler = substate.event_handlers[name]
|
||||
|
||||
# For background tasks, proxy the state
|
||||
if handler.is_background:
|
||||
substate = StateProxy(substate)
|
||||
|
||||
return substate, handler
|
||||
|
||||
async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
|
||||
"""Obtain event info and process event.
|
||||
|
||||
@ -719,44 +763,17 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
|
||||
Yields:
|
||||
The state update after processing the event.
|
||||
|
||||
Raises:
|
||||
ValueError: If the state value is None.
|
||||
"""
|
||||
# Get the event handler.
|
||||
path = event.name.split(".")
|
||||
path, name = path[:-1], path[-1]
|
||||
substate = self.get_substate(path)
|
||||
handler = substate.event_handlers[name] # type: ignore
|
||||
substate, handler = self._get_event_handler(event)
|
||||
|
||||
if not substate:
|
||||
raise ValueError(
|
||||
"The value of state cannot be None when processing an event."
|
||||
)
|
||||
|
||||
# Get the event generator.
|
||||
event_iter = self._process_event(
|
||||
# Run the event generator and yield state updates.
|
||||
async for update in self._process_event(
|
||||
handler=handler,
|
||||
state=substate,
|
||||
payload=event.payload,
|
||||
)
|
||||
|
||||
# Clean the state before processing the event.
|
||||
self._clean()
|
||||
|
||||
# Run the event generator and return state updates.
|
||||
async for events, final in event_iter:
|
||||
# Fix the returned events.
|
||||
events = fix_events(events, event.token) # type: ignore
|
||||
|
||||
# Get the delta after processing the event.
|
||||
delta = self.get_delta()
|
||||
|
||||
# Yield the state update.
|
||||
yield StateUpdate(delta=delta, events=events, final=final)
|
||||
|
||||
# Clean the state to prepare for the next event.
|
||||
self._clean()
|
||||
):
|
||||
yield update
|
||||
|
||||
def _check_valid(self, handler: EventHandler, events: Any) -> Any:
|
||||
"""Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
|
||||
@ -787,9 +804,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||
)
|
||||
|
||||
def _as_state_update(
|
||||
self,
|
||||
handler: EventHandler,
|
||||
events: EventSpec | list[EventSpec] | None,
|
||||
final: bool,
|
||||
) -> StateUpdate:
|
||||
"""Convert the events to a StateUpdate.
|
||||
|
||||
Fixes the events and checks for validity before converting.
|
||||
|
||||
Args:
|
||||
handler: The handler where the events originated from.
|
||||
events: The events to queue with the update.
|
||||
final: Whether the handler is done processing.
|
||||
|
||||
Returns:
|
||||
The valid StateUpdate containing the events and final flag.
|
||||
"""
|
||||
token = self.get_token()
|
||||
|
||||
# Convert valid EventHandler and EventSpec into Event
|
||||
fixed_events = fix_events(self._check_valid(handler, events), token)
|
||||
|
||||
# Get the delta after processing the event.
|
||||
delta = self.get_delta()
|
||||
self._clean()
|
||||
|
||||
return StateUpdate(
|
||||
delta=delta,
|
||||
events=fixed_events,
|
||||
final=final if not handler.is_background else True,
|
||||
)
|
||||
|
||||
async def _process_event(
|
||||
self, handler: EventHandler, state: State, payload: Dict
|
||||
) -> AsyncIterator[tuple[list[EventSpec] | None, bool]]:
|
||||
self, handler: EventHandler, state: State | StateProxy, payload: Dict
|
||||
) -> AsyncIterator[StateUpdate]:
|
||||
"""Process event.
|
||||
|
||||
Args:
|
||||
@ -798,13 +848,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
payload: The event payload.
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
0: The state update after processing the event.
|
||||
1: Whether the event is the final event.
|
||||
StateUpdate object
|
||||
"""
|
||||
# Get the function to process the event.
|
||||
fn = functools.partial(handler.fn, state)
|
||||
|
||||
# Clean the state before processing the event.
|
||||
self._clean()
|
||||
|
||||
# Wrap the function in a try/except block.
|
||||
try:
|
||||
# Handle async functions.
|
||||
@ -817,30 +868,34 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
# Handle async generators.
|
||||
if inspect.isasyncgen(events):
|
||||
async for event in events:
|
||||
yield self._check_valid(handler, event), False
|
||||
yield None, True
|
||||
yield self._as_state_update(handler, event, final=False)
|
||||
yield self._as_state_update(handler, events=None, final=True)
|
||||
|
||||
# Handle regular generators.
|
||||
elif inspect.isgenerator(events):
|
||||
try:
|
||||
while True:
|
||||
yield self._check_valid(handler, next(events)), False
|
||||
yield self._as_state_update(handler, next(events), final=False)
|
||||
except StopIteration as si:
|
||||
# the "return" value of the generator is not available
|
||||
# in the loop, we must catch StopIteration to access it
|
||||
if si.value is not None:
|
||||
yield self._check_valid(handler, si.value), False
|
||||
yield None, True
|
||||
yield self._as_state_update(handler, si.value, final=False)
|
||||
yield self._as_state_update(handler, events=None, final=True)
|
||||
|
||||
# Handle regular event chains.
|
||||
else:
|
||||
yield self._check_valid(handler, events), True
|
||||
yield self._as_state_update(handler, events, final=True)
|
||||
|
||||
# If an error occurs, throw a window alert.
|
||||
except Exception:
|
||||
error = traceback.format_exc()
|
||||
print(error)
|
||||
yield [window_alert("An error occurred. See logs for details.")], True
|
||||
yield self._as_state_update(
|
||||
handler,
|
||||
window_alert("An error occurred. See logs for details."),
|
||||
final=True,
|
||||
)
|
||||
|
||||
def _always_dirty_computed_vars(self) -> set[str]:
|
||||
"""The set of ComputedVars that always need to be recalculated.
|
||||
@ -989,6 +1044,160 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
||||
variables = {**base_vars, **computed_vars, **substate_vars}
|
||||
return {k: variables[k] for k in sorted(variables)}
|
||||
|
||||
async def __aenter__(self) -> State:
|
||||
"""Enter the async context manager protocol.
|
||||
|
||||
This should not be used for the State class, but exists for
|
||||
type-compatibility with StateProxy.
|
||||
|
||||
Raises:
|
||||
TypeError: always, because async contextmanager protocol is only supported for background task.
|
||||
"""
|
||||
raise TypeError(
|
||||
"Only background task should use `async with self` to modify state."
|
||||
)
|
||||
|
||||
async def __aexit__(self, *exc_info: Any) -> None:
|
||||
"""Exit the async context manager protocol.
|
||||
|
||||
This should not be used for the State class, but exists for
|
||||
type-compatibility with StateProxy.
|
||||
|
||||
Args:
|
||||
exc_info: The exception info tuple.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StateProxy(wrapt.ObjectProxy):
|
||||
"""Proxy of a state instance to control mutability of vars for a background task.
|
||||
|
||||
Since a background task runs against a state instance without holding the
|
||||
state_manager lock for the token, the reference may become stale if the same
|
||||
state is modified by another event handler.
|
||||
|
||||
The proxy object ensures that writes to the state are blocked unless
|
||||
explicitly entering a context which refreshes the state from state_manager
|
||||
and holds the lock for the token until exiting the context. After exiting
|
||||
the context, a StateUpdate may be emitted to the frontend to notify the
|
||||
client of the state change.
|
||||
|
||||
A background task will be passed the `StateProxy` as `self`, so mutability
|
||||
can be safely performed inside an `async with self` block.
|
||||
|
||||
class State(rx.State):
|
||||
counter: int = 0
|
||||
|
||||
@rx.background
|
||||
async def bg_increment(self):
|
||||
await asyncio.sleep(1)
|
||||
async with self:
|
||||
self.counter += 1
|
||||
"""
|
||||
|
||||
def __init__(self, state_instance):
|
||||
"""Create a proxy for a state instance.
|
||||
|
||||
Args:
|
||||
state_instance: The state instance to proxy.
|
||||
"""
|
||||
super().__init__(state_instance)
|
||||
self._self_app = getattr(prerequisites.get_app(), constants.APP_VAR)
|
||||
self._self_substate_path = state_instance.get_full_name().split(".")
|
||||
self._self_actx = None
|
||||
self._self_mutable = False
|
||||
|
||||
async def __aenter__(self) -> StateProxy:
|
||||
"""Enter the async context manager protocol.
|
||||
|
||||
Sets mutability to True and enters the `App.modify_state` async context,
|
||||
which refreshes the state from state_manager and holds the lock for the
|
||||
given state token until exiting the context.
|
||||
|
||||
Background tasks should avoid blocking calls while inside the context.
|
||||
|
||||
Returns:
|
||||
This StateProxy instance in mutable mode.
|
||||
"""
|
||||
self._self_actx = self._self_app.modify_state(self.__wrapped__.get_token())
|
||||
mutable_state = await self._self_actx.__aenter__()
|
||||
super().__setattr__(
|
||||
"__wrapped__", mutable_state.get_substate(self._self_substate_path)
|
||||
)
|
||||
self._self_mutable = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_info: Any) -> None:
|
||||
"""Exit the async context manager protocol.
|
||||
|
||||
Sets proxy mutability to False and persists any state changes.
|
||||
|
||||
Args:
|
||||
exc_info: The exception info tuple.
|
||||
"""
|
||||
if self._self_actx is None:
|
||||
return
|
||||
self._self_mutable = False
|
||||
await self._self_actx.__aexit__(*exc_info)
|
||||
self._self_actx = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter the regular context manager protocol.
|
||||
|
||||
This is not supported for background tasks, and exists only to raise a more useful exception
|
||||
when the StateProxy is used incorrectly.
|
||||
|
||||
Raises:
|
||||
TypeError: always, because only async contextmanager protocol is supported.
|
||||
"""
|
||||
raise TypeError("Background task must use `async with self` to modify state.")
|
||||
|
||||
def __exit__(self, *exc_info: Any) -> None:
|
||||
"""Exit the regular context manager protocol.
|
||||
|
||||
Args:
|
||||
exc_info: The exception info tuple.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get the attribute from the underlying state instance.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The value of the attribute.
|
||||
"""
|
||||
value = super().__getattr__(name)
|
||||
if not name.startswith("_self_") and isinstance(value, MutableProxy):
|
||||
# ensure mutations to these containers are blocked unless proxy is _mutable
|
||||
return ImmutableMutableProxy(
|
||||
wrapped=value.__wrapped__,
|
||||
state=self, # type: ignore
|
||||
field_name=value._self_field_name,
|
||||
)
|
||||
return value
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""Set the attribute on the underlying state instance.
|
||||
|
||||
If the attribute is internal, set it on the proxy instance instead.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
value: The value of the attribute.
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: If the state is not in mutable mode.
|
||||
"""
|
||||
if not name.startswith("_self_") and not self._self_mutable:
|
||||
raise ImmutableStateError(
|
||||
"Background task StateProxy is immutable outside of a context "
|
||||
"manager. Use `async with self` to modify state."
|
||||
)
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
||||
class DefaultState(State):
|
||||
"""The default empty state."""
|
||||
@ -1009,31 +1218,29 @@ class StateUpdate(Base):
|
||||
final: bool = True
|
||||
|
||||
|
||||
class StateManager(Base):
|
||||
class StateManager(Base, ABC):
|
||||
"""A class to manage many client states."""
|
||||
|
||||
# The state class to use.
|
||||
state: Type[State] = DefaultState
|
||||
state: Type[State]
|
||||
|
||||
# The mapping of client ids to states.
|
||||
states: Dict[str, State] = {}
|
||||
|
||||
# The token expiration time (s).
|
||||
token_expiration: int = constants.TOKEN_EXPIRATION
|
||||
|
||||
# The redis client to use.
|
||||
redis: Optional[Redis] = None
|
||||
|
||||
def setup(self, state: Type[State]):
|
||||
"""Set up the state manager.
|
||||
@classmethod
|
||||
def create(cls, state: Type[State] = DefaultState):
|
||||
"""Create a new state manager.
|
||||
|
||||
Args:
|
||||
state: The state class to use.
|
||||
"""
|
||||
self.state = state
|
||||
self.redis = prerequisites.get_redis()
|
||||
|
||||
def get_state(self, token: str) -> State:
|
||||
Returns:
|
||||
The state manager (either memory or redis).
|
||||
"""
|
||||
redis = prerequisites.get_redis()
|
||||
if redis is not None:
|
||||
return StateManagerRedis(state=state, redis=redis)
|
||||
return StateManagerMemory(state=state)
|
||||
|
||||
@abstractmethod
|
||||
async def get_state(self, token: str) -> State:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
@ -1042,27 +1249,266 @@ class StateManager(Base):
|
||||
Returns:
|
||||
The state for the token.
|
||||
"""
|
||||
if self.redis is not None:
|
||||
redis_state = self.redis.get(token)
|
||||
if redis_state is None:
|
||||
self.set_state(token, self.state())
|
||||
return self.get_state(token)
|
||||
return cloudpickle.loads(redis_state)
|
||||
pass
|
||||
|
||||
if token not in self.states:
|
||||
self.states[token] = self.state()
|
||||
return self.states[token]
|
||||
|
||||
def set_state(self, token: str, state: State):
|
||||
@abstractmethod
|
||||
async def set_state(self, token: str, state: State):
|
||||
"""Set the state for a token.
|
||||
|
||||
Args:
|
||||
token: The token to set the state for.
|
||||
state: The state to set.
|
||||
"""
|
||||
if self.redis is None:
|
||||
return
|
||||
self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
"""Modify the state for a token while holding exclusive lock.
|
||||
|
||||
Args:
|
||||
token: The token to modify the state for.
|
||||
|
||||
Yields:
|
||||
The state for the token.
|
||||
"""
|
||||
yield self.state()
|
||||
|
||||
|
||||
class StateManagerMemory(StateManager):
|
||||
"""A state manager that stores states in memory."""
|
||||
|
||||
# The mapping of client ids to states.
|
||||
states: Dict[str, State] = {}
|
||||
|
||||
# The mutex ensures the dict of mutexes is updated exclusively
|
||||
_state_manager_lock = asyncio.Lock()
|
||||
|
||||
# The dict of mutexes for each client
|
||||
_states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
|
||||
|
||||
class Config:
|
||||
"""The Pydantic config."""
|
||||
|
||||
fields = {
|
||||
"_states_locks": {"exclude": True},
|
||||
}
|
||||
|
||||
async def get_state(self, token: str) -> State:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for.
|
||||
|
||||
Returns:
|
||||
The state for the token.
|
||||
"""
|
||||
if token not in self.states:
|
||||
self.states[token] = self.state()
|
||||
return self.states[token]
|
||||
|
||||
async def set_state(self, token: str, state: State):
|
||||
"""Set the state for a token.
|
||||
|
||||
Args:
|
||||
token: The token to set the state for.
|
||||
state: The state to set.
|
||||
"""
|
||||
pass
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
"""Modify the state for a token while holding exclusive lock.
|
||||
|
||||
Args:
|
||||
token: The token to modify the state for.
|
||||
|
||||
Yields:
|
||||
The state for the token.
|
||||
"""
|
||||
if token not in self._states_locks:
|
||||
async with self._state_manager_lock:
|
||||
if token not in self._states_locks:
|
||||
self._states_locks[token] = asyncio.Lock()
|
||||
|
||||
async with self._states_locks[token]:
|
||||
state = await self.get_state(token)
|
||||
yield state
|
||||
await self.set_state(token, state)
|
||||
|
||||
|
||||
class StateManagerRedis(StateManager):
|
||||
"""A state manager that stores states in redis."""
|
||||
|
||||
# The redis client to use.
|
||||
redis: Redis
|
||||
|
||||
# The token expiration time (s).
|
||||
token_expiration: int = constants.TOKEN_EXPIRATION
|
||||
|
||||
# The maximum time to hold a lock (ms).
|
||||
lock_expiration: int = constants.LOCK_EXPIRATION
|
||||
|
||||
# The keyspace subscription string when redis is waiting for lock to be released
|
||||
_redis_notify_keyspace_events: str = (
|
||||
"K" # Enable keyspace notifications (target a particular key)
|
||||
"g" # For generic commands (DEL, EXPIRE, etc)
|
||||
"x" # For expired events
|
||||
"e" # For evicted events (i.e. maxmemory exceeded)
|
||||
)
|
||||
|
||||
# These events indicate that a lock is no longer held
|
||||
_redis_keyspace_lock_release_events: Set[bytes] = {
|
||||
b"del",
|
||||
b"expire",
|
||||
b"expired",
|
||||
b"evicted",
|
||||
}
|
||||
|
||||
async def get_state(self, token: str) -> State:
|
||||
"""Get the state for a token.
|
||||
|
||||
Args:
|
||||
token: The token to get the state for.
|
||||
|
||||
Returns:
|
||||
The state for the token.
|
||||
"""
|
||||
redis_state = await self.redis.get(token)
|
||||
if redis_state is None:
|
||||
await self.set_state(token, self.state())
|
||||
return await self.get_state(token)
|
||||
return cloudpickle.loads(redis_state)
|
||||
|
||||
async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
|
||||
"""Set the state for a token.
|
||||
|
||||
Args:
|
||||
token: The token to set the state for.
|
||||
state: The state to set.
|
||||
lock_id: If provided, the lock_key must be set to this value to set the state.
|
||||
|
||||
Raises:
|
||||
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
||||
"""
|
||||
# check that we're holding the lock
|
||||
if (
|
||||
lock_id is not None
|
||||
and await self.redis.get(self._lock_key(token)) != lock_id
|
||||
):
|
||||
raise LockExpiredError(
|
||||
f"Lock expired for token {token} while processing. Consider increasing "
|
||||
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
||||
"or use `@rx.background` decorator for long-running tasks."
|
||||
)
|
||||
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
"""Modify the state for a token while holding exclusive lock.
|
||||
|
||||
Args:
|
||||
token: The token to modify the state for.
|
||||
|
||||
Yields:
|
||||
The state for the token.
|
||||
"""
|
||||
async with self._lock(token) as lock_id:
|
||||
state = await self.get_state(token)
|
||||
yield state
|
||||
await self.set_state(token, state, lock_id)
|
||||
|
||||
@staticmethod
|
||||
def _lock_key(token: str) -> bytes:
|
||||
"""Get the redis key for a token's lock.
|
||||
|
||||
Args:
|
||||
token: The token to get the lock key for.
|
||||
|
||||
Returns:
|
||||
The redis lock key for the token.
|
||||
"""
|
||||
return f"{token}_lock".encode()
|
||||
|
||||
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
||||
"""Try to get a redis lock for a token.
|
||||
|
||||
Args:
|
||||
lock_key: The redis key for the lock.
|
||||
lock_id: The ID of the lock.
|
||||
|
||||
Returns:
|
||||
True if the lock was obtained.
|
||||
"""
|
||||
return await self.redis.set(
|
||||
lock_key,
|
||||
lock_id,
|
||||
px=self.lock_expiration,
|
||||
nx=True, # only set if it doesn't exist
|
||||
)
|
||||
|
||||
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
||||
"""Wait for a redis lock to be released via pubsub.
|
||||
|
||||
Coroutine will not return until the lock is obtained.
|
||||
|
||||
Args:
|
||||
lock_key: The redis key for the lock.
|
||||
lock_id: The ID of the lock.
|
||||
"""
|
||||
state_is_locked = False
|
||||
lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
|
||||
# Enable keyspace notifications for the lock key, so we know when it is available.
|
||||
await self.redis.config_set(
|
||||
"notify-keyspace-events", self._redis_notify_keyspace_events
|
||||
)
|
||||
async with self.redis.pubsub() as pubsub:
|
||||
await pubsub.psubscribe(lock_key_channel)
|
||||
while not state_is_locked:
|
||||
# wait for the lock to be released
|
||||
while True:
|
||||
if not await self.redis.exists(lock_key):
|
||||
break # key was removed, try to get the lock again
|
||||
message = await pubsub.get_message(
|
||||
ignore_subscribe_messages=True,
|
||||
timeout=self.lock_expiration / 1000.0,
|
||||
)
|
||||
if message is None:
|
||||
continue
|
||||
if message["data"] in self._redis_keyspace_lock_release_events:
|
||||
break
|
||||
state_is_locked = await self._try_get_lock(lock_key, lock_id)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _lock(self, token: str):
|
||||
"""Obtain a redis lock for a token.
|
||||
|
||||
Args:
|
||||
token: The token to obtain a lock for.
|
||||
|
||||
Yields:
|
||||
The ID of the lock (to be passed to set_state).
|
||||
|
||||
Raises:
|
||||
LockExpiredError: If the lock has expired while processing the event.
|
||||
"""
|
||||
lock_key = self._lock_key(token)
|
||||
lock_id = uuid.uuid4().hex.encode()
|
||||
|
||||
if not await self._try_get_lock(lock_key, lock_id):
|
||||
# Missed the fast-path to get lock, subscribe for lock delete/expire events
|
||||
await self._wait_lock(lock_key, lock_id)
|
||||
state_is_locked = True
|
||||
|
||||
try:
|
||||
yield lock_id
|
||||
except LockExpiredError:
|
||||
state_is_locked = False
|
||||
raise
|
||||
finally:
|
||||
if state_is_locked:
|
||||
# only delete our lock
|
||||
await self.redis.delete(lock_key)
|
||||
|
||||
|
||||
class ClientStorageBase:
|
||||
@ -1246,7 +1692,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
value, super().__getattribute__("__mutable_types__")
|
||||
) and __name not in ("__wrapped__", "_self_state"):
|
||||
# Recursively wrap mutable attribute values retrieved through this proxy.
|
||||
return MutableProxy(
|
||||
return type(self)(
|
||||
wrapped=value,
|
||||
state=self._self_state,
|
||||
field_name=self._self_field_name,
|
||||
@ -1266,7 +1712,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
value = super().__getitem__(key)
|
||||
if isinstance(value, self.__mutable_types__):
|
||||
# Recursively wrap mutable items retrieved through this proxy.
|
||||
return MutableProxy(
|
||||
return type(self)(
|
||||
wrapped=value,
|
||||
state=self._self_state,
|
||||
field_name=self._self_field_name,
|
||||
@ -1332,3 +1778,34 @@ class MutableProxy(wrapt.ObjectProxy):
|
||||
A deepcopy of the wrapped object, unconnected to the proxy.
|
||||
"""
|
||||
return copy.deepcopy(self.__wrapped__, memo=memo)
|
||||
|
||||
|
||||
class ImmutableMutableProxy(MutableProxy):
|
||||
"""A proxy for a mutable object that tracks changes.
|
||||
|
||||
This wrapper comes from StateProxy, and will raise an exception if an attempt is made
|
||||
to modify the wrapped object when the StateProxy is immutable.
|
||||
"""
|
||||
|
||||
def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
|
||||
"""Raise an exception when an attempt is made to modify the object.
|
||||
|
||||
Intended for use with `FunctionWrapper` from the `wrapt` library.
|
||||
|
||||
Args:
|
||||
wrapped: The wrapped function.
|
||||
instance: The instance of the wrapped function.
|
||||
args: The args for the wrapped function.
|
||||
kwargs: The kwargs for the wrapped function.
|
||||
|
||||
Raises:
|
||||
ImmutableStateError: if the StateProxy is not mutable.
|
||||
"""
|
||||
if not self._self_state._self_mutable:
|
||||
raise ImmutableStateError(
|
||||
"Background task StateProxy is immutable outside of a context "
|
||||
"manager. Use `async with self` to modify state."
|
||||
)
|
||||
super()._mark_dirty(
|
||||
wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
|
||||
)
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""reflex.testing - tools for testing reflex apps."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import inspect
|
||||
@ -19,14 +20,13 @@ import types
|
||||
from http.server import SimpleHTTPRequestHandler
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import psutil
|
||||
@ -38,7 +38,7 @@ import reflex.utils.build
|
||||
import reflex.utils.exec
|
||||
import reflex.utils.prerequisites
|
||||
import reflex.utils.processes
|
||||
from reflex.app import EventNamespace
|
||||
from reflex.state import State, StateManagerMemory, StateManagerRedis
|
||||
|
||||
try:
|
||||
from selenium import webdriver # pyright: ignore [reportMissingImports]
|
||||
@ -109,6 +109,7 @@ class AppHarness:
|
||||
frontend_url: Optional[str] = None
|
||||
backend_thread: Optional[threading.Thread] = None
|
||||
backend: Optional[uvicorn.Server] = None
|
||||
state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
|
||||
_frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
@ -162,6 +163,27 @@ class AppHarness:
|
||||
reflex.config.get_config(reload=True)
|
||||
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
|
||||
self.app_instance = self.app_module.app
|
||||
if isinstance(self.app_instance.state_manager, StateManagerRedis):
|
||||
# Create our own redis connection for testing.
|
||||
self.state_manager = StateManagerRedis.create(self.app_instance.state)
|
||||
else:
|
||||
self.state_manager = self.app_instance.state_manager
|
||||
|
||||
def _get_backend_shutdown_handler(self):
|
||||
if self.backend is None:
|
||||
raise RuntimeError("Backend was not initialized.")
|
||||
|
||||
original_shutdown = self.backend.shutdown
|
||||
|
||||
async def _shutdown_redis(*args, **kwargs) -> None:
|
||||
# ensure redis is closed before event loop
|
||||
if self.app_instance is not None and isinstance(
|
||||
self.app_instance.state_manager, StateManagerRedis
|
||||
):
|
||||
await self.app_instance.state_manager.redis.close()
|
||||
await original_shutdown(*args, **kwargs)
|
||||
|
||||
return _shutdown_redis
|
||||
|
||||
def _start_backend(self, port=0):
|
||||
if self.app_instance is None:
|
||||
@ -173,6 +195,7 @@ class AppHarness:
|
||||
port=port,
|
||||
)
|
||||
)
|
||||
self.backend.shutdown = self._get_backend_shutdown_handler()
|
||||
self.backend_thread = threading.Thread(target=self.backend.run)
|
||||
self.backend_thread.start()
|
||||
|
||||
@ -296,6 +319,35 @@ class AppHarness:
|
||||
time.sleep(step)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _poll_for_async(
|
||||
target: Callable[[], Coroutine[None, None, T]],
|
||||
timeout: TimeoutType = None,
|
||||
step: TimeoutType = None,
|
||||
) -> T | bool:
|
||||
"""Generic polling logic for async functions.
|
||||
|
||||
Args:
|
||||
target: callable that returns truthy if polling condition is met.
|
||||
timeout: max polling time
|
||||
step: interval between checking target()
|
||||
|
||||
Returns:
|
||||
return value of target() if truthy within timeout
|
||||
False if timeout elapses
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_TIMEOUT
|
||||
if step is None:
|
||||
step = POLL_INTERVAL
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
success = await target()
|
||||
if success:
|
||||
return success
|
||||
await asyncio.sleep(step)
|
||||
return False
|
||||
|
||||
def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
|
||||
"""Poll backend server for listening sockets.
|
||||
|
||||
@ -351,39 +403,76 @@ class AppHarness:
|
||||
self._frontends.append(driver)
|
||||
return driver
|
||||
|
||||
async def emit_state_updates(self) -> list[Any]:
|
||||
"""Send any backend state deltas to the frontend.
|
||||
async def get_state(self, token: str) -> State:
|
||||
"""Get the state associated with the given token.
|
||||
|
||||
Args:
|
||||
token: The state token to look up.
|
||||
|
||||
Returns:
|
||||
List of awaited response from each EventNamespace.emit() call.
|
||||
The state instance associated with the given token
|
||||
|
||||
Raises:
|
||||
RuntimeError: when the app hasn't started running
|
||||
"""
|
||||
if self.app_instance is None or self.app_instance.sio is None:
|
||||
if self.state_manager is None:
|
||||
raise RuntimeError("state_manager is not set.")
|
||||
try:
|
||||
return await self.state_manager.get_state(token)
|
||||
finally:
|
||||
if isinstance(self.state_manager, StateManagerRedis):
|
||||
await self.state_manager.redis.close()
|
||||
|
||||
async def set_state(self, token: str, **kwargs) -> None:
|
||||
"""Set the state associated with the given token.
|
||||
|
||||
Args:
|
||||
token: The state token to set.
|
||||
kwargs: Attributes to set on the state.
|
||||
|
||||
Raises:
|
||||
RuntimeError: when the app hasn't started running
|
||||
"""
|
||||
if self.state_manager is None:
|
||||
raise RuntimeError("state_manager is not set.")
|
||||
state = await self.get_state(token)
|
||||
for key, value in kwargs.items():
|
||||
setattr(state, key, value)
|
||||
try:
|
||||
await self.state_manager.set_state(token, state)
|
||||
finally:
|
||||
if isinstance(self.state_manager, StateManagerRedis):
|
||||
await self.state_manager.redis.close()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||
"""Modify the state associated with the given token and send update to frontend.
|
||||
|
||||
Args:
|
||||
token: The state token to modify
|
||||
|
||||
Yields:
|
||||
The state instance associated with the given token
|
||||
|
||||
Raises:
|
||||
RuntimeError: when the app hasn't started running
|
||||
"""
|
||||
if self.state_manager is None:
|
||||
raise RuntimeError("state_manager is not set.")
|
||||
if self.app_instance is None:
|
||||
raise RuntimeError("App is not running.")
|
||||
event_ns: EventNamespace = cast(
|
||||
EventNamespace,
|
||||
self.app_instance.event_namespace,
|
||||
)
|
||||
pending: list[Coroutine[Any, Any, Any]] = []
|
||||
for state in self.app_instance.state_manager.states.values():
|
||||
delta = state.get_delta()
|
||||
if delta:
|
||||
update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
|
||||
state._clean()
|
||||
# Emit the event.
|
||||
pending.append(
|
||||
event_ns.emit(
|
||||
str(reflex.constants.SocketEvent.EVENT),
|
||||
update.json(),
|
||||
to=state.get_sid(),
|
||||
),
|
||||
)
|
||||
responses = []
|
||||
for request in pending:
|
||||
responses.append(await request)
|
||||
return responses
|
||||
app_state_manager = self.app_instance.state_manager
|
||||
if isinstance(self.state_manager, StateManagerRedis):
|
||||
# Temporarily replace the app's state manager with our own, since
|
||||
# the redis connection is on the backend_thread event loop
|
||||
self.app_instance.state_manager = self.state_manager
|
||||
try:
|
||||
async with self.app_instance.modify_state(token) as state:
|
||||
yield state
|
||||
finally:
|
||||
if isinstance(self.state_manager, StateManagerRedis):
|
||||
self.app_instance.state_manager = app_state_manager
|
||||
await self.state_manager.redis.close()
|
||||
|
||||
def poll_for_content(
|
||||
self,
|
||||
@ -457,6 +546,9 @@ class AppHarness:
|
||||
if self.app_instance is None:
|
||||
raise RuntimeError("App is not running.")
|
||||
state_manager = self.app_instance.state_manager
|
||||
assert isinstance(
|
||||
state_manager, StateManagerMemory
|
||||
), "Only works with memory state manager"
|
||||
if not self._poll_for(
|
||||
target=lambda: state_manager.states,
|
||||
timeout=timeout,
|
||||
@ -534,7 +626,6 @@ class Subdir404TCPServer(socketserver.TCPServer):
|
||||
request: the requesting socket
|
||||
client_address: (host, port) referring to the client’s address.
|
||||
"""
|
||||
print(client_address, type(client_address))
|
||||
self.RequestHandlerClass(
|
||||
request,
|
||||
client_address,
|
||||
@ -605,6 +696,7 @@ class AppHarnessProd(AppHarness):
|
||||
workers=reflex.utils.processes.get_num_workers(),
|
||||
),
|
||||
)
|
||||
self.backend.shutdown = self._get_backend_shutdown_handler()
|
||||
self.backend_thread = threading.Thread(target=self.backend.run)
|
||||
self.backend_thread.start()
|
||||
|
||||
|
@ -5,3 +5,11 @@ class InvalidStylePropError(TypeError):
|
||||
"""Custom Type Error when style props have invalid values."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ImmutableStateError(AttributeError):
|
||||
"""Raised when a background task attempts to modify state outside of context."""
|
||||
|
||||
|
||||
class LockExpiredError(Exception):
|
||||
"""Raised when the state lock expires while an event is being processed."""
|
||||
|
@ -21,7 +21,7 @@ import httpx
|
||||
import typer
|
||||
from alembic.util.exc import CommandError
|
||||
from packaging import version
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from reflex import constants, model
|
||||
from reflex.compiler import templates
|
||||
@ -124,9 +124,11 @@ def get_redis() -> Redis | None:
|
||||
The redis client.
|
||||
"""
|
||||
config = get_config()
|
||||
if config.redis_url is None:
|
||||
if not config.redis_url:
|
||||
return None
|
||||
redis_url, redis_port = config.redis_url.split(":")
|
||||
redis_url, has_port, redis_port = config.redis_url.partition(":")
|
||||
if not has_port:
|
||||
redis_port = 6379
|
||||
console.info(f"Using redis at {config.redis_url}")
|
||||
return Redis(host=redis_url, port=int(redis_port), db=0)
|
||||
|
||||
|
@ -2,8 +2,9 @@
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, List, Set, Union
|
||||
from typing import Dict, Generator
|
||||
|
||||
import pytest
|
||||
|
||||
@ -11,6 +12,14 @@ import reflex as rx
|
||||
from reflex.app import App
|
||||
from reflex.event import EventSpec
|
||||
|
||||
from .states import (
|
||||
DictMutationTestState,
|
||||
ListMutationTestState,
|
||||
MutableTestState,
|
||||
SubUploadState,
|
||||
UploadState,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> App:
|
||||
@ -39,60 +48,7 @@ def list_mutation_state():
|
||||
Returns:
|
||||
A state with list mutation features.
|
||||
"""
|
||||
|
||||
class TestState(rx.State):
|
||||
"""The test state."""
|
||||
|
||||
# plain list
|
||||
plain_friends = ["Tommy"]
|
||||
|
||||
def make_friend(self):
|
||||
self.plain_friends.append("another-fd")
|
||||
|
||||
def change_first_friend(self):
|
||||
self.plain_friends[0] = "Jenny"
|
||||
|
||||
def unfriend_all_friends(self):
|
||||
self.plain_friends.clear()
|
||||
|
||||
def unfriend_first_friend(self):
|
||||
del self.plain_friends[0]
|
||||
|
||||
def remove_last_friend(self):
|
||||
self.plain_friends.pop()
|
||||
|
||||
def make_friends_with_colleagues(self):
|
||||
colleagues = ["Peter", "Jimmy"]
|
||||
self.plain_friends.extend(colleagues)
|
||||
|
||||
def remove_tommy(self):
|
||||
self.plain_friends.remove("Tommy")
|
||||
|
||||
# list in dict
|
||||
friends_in_dict = {"Tommy": ["Jenny"]}
|
||||
|
||||
def remove_jenny_from_tommy(self):
|
||||
self.friends_in_dict["Tommy"].remove("Jenny")
|
||||
|
||||
def add_jimmy_to_tommy_friends(self):
|
||||
self.friends_in_dict["Tommy"].append("Jimmy")
|
||||
|
||||
def tommy_has_no_fds(self):
|
||||
self.friends_in_dict["Tommy"].clear()
|
||||
|
||||
# nested list
|
||||
friends_in_nested_list = [["Tommy"], ["Jenny"]]
|
||||
|
||||
def remove_first_group(self):
|
||||
self.friends_in_nested_list.pop(0)
|
||||
|
||||
def remove_first_person_from_first_group(self):
|
||||
self.friends_in_nested_list[0].pop(0)
|
||||
|
||||
def add_jimmy_to_second_group(self):
|
||||
self.friends_in_nested_list[1].append("Jimmy")
|
||||
|
||||
return TestState()
|
||||
return ListMutationTestState()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -102,85 +58,7 @@ def dict_mutation_state():
|
||||
Returns:
|
||||
A state with dict mutation features.
|
||||
"""
|
||||
|
||||
class TestState(rx.State):
|
||||
"""The test state."""
|
||||
|
||||
# plain dict
|
||||
details = {"name": "Tommy"}
|
||||
|
||||
def add_age(self):
|
||||
self.details.update({"age": 20}) # type: ignore
|
||||
|
||||
def change_name(self):
|
||||
self.details["name"] = "Jenny"
|
||||
|
||||
def remove_last_detail(self):
|
||||
self.details.popitem()
|
||||
|
||||
def clear_details(self):
|
||||
self.details.clear()
|
||||
|
||||
def remove_name(self):
|
||||
del self.details["name"]
|
||||
|
||||
def pop_out_age(self):
|
||||
self.details.pop("age")
|
||||
|
||||
# dict in list
|
||||
address = [{"home": "home address"}, {"work": "work address"}]
|
||||
|
||||
def remove_home_address(self):
|
||||
self.address[0].pop("home")
|
||||
|
||||
def add_street_to_home_address(self):
|
||||
self.address[0]["street"] = "street address"
|
||||
|
||||
# nested dict
|
||||
friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
|
||||
|
||||
def change_friend_name(self):
|
||||
self.friend_in_nested_dict["friend"]["name"] = "Tommy"
|
||||
|
||||
def remove_friend(self):
|
||||
self.friend_in_nested_dict.pop("friend")
|
||||
|
||||
def add_friend_age(self):
|
||||
self.friend_in_nested_dict["friend"]["age"] = 30
|
||||
|
||||
return TestState()
|
||||
|
||||
|
||||
class UploadState(rx.State):
|
||||
"""The base state for uploading a file."""
|
||||
|
||||
async def handle_upload1(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseState(rx.State):
|
||||
"""The test base state."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SubUploadState(BaseState):
|
||||
"""The test substate."""
|
||||
|
||||
img: str
|
||||
|
||||
async def handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
pass
|
||||
return DictMutationTestState()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -203,187 +81,6 @@ def upload_event_spec():
|
||||
return EventSpec(handler=UploadState.handle_upload1, upload=True) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def upload_state(tmp_path):
|
||||
"""Create upload state.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest tmp_path
|
||||
|
||||
Returns:
|
||||
The state
|
||||
|
||||
"""
|
||||
|
||||
class FileUploadState(rx.State):
|
||||
"""The base state for uploading a file."""
|
||||
|
||||
img_list: List[str]
|
||||
|
||||
async def handle_upload2(self, files):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
assert file.filename is not None
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
return FileUploadState
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def upload_sub_state(tmp_path):
|
||||
"""Create upload substate.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest tmp_path
|
||||
|
||||
Returns:
|
||||
The state
|
||||
|
||||
"""
|
||||
|
||||
class FileState(rx.State):
|
||||
"""The base state."""
|
||||
|
||||
pass
|
||||
|
||||
class FileUploadState(FileState):
|
||||
"""The substate for uploading a file."""
|
||||
|
||||
img_list: List[str]
|
||||
|
||||
async def handle_upload2(self, files):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
assert file.filename is not None
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
return FileUploadState
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def upload_grand_sub_state(tmp_path):
|
||||
"""Create upload grand-state.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest tmp_path
|
||||
|
||||
Returns:
|
||||
The state
|
||||
|
||||
"""
|
||||
|
||||
class BaseFileState(rx.State):
|
||||
"""The base state."""
|
||||
|
||||
pass
|
||||
|
||||
class FileSubState(BaseFileState):
|
||||
"""The substate."""
|
||||
|
||||
pass
|
||||
|
||||
class FileUploadState(FileSubState):
|
||||
"""The grand-substate for uploading a file."""
|
||||
|
||||
img_list: List[str]
|
||||
|
||||
async def handle_upload2(self, files):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
assert file.filename is not None
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
assert file.filename is not None
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
return FileUploadState
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_config_values() -> Dict:
|
||||
"""Get base config values.
|
||||
@ -418,35 +115,6 @@ def sqlite_db_config_values(base_db_config_values) -> Dict:
|
||||
return base_db_config_values
|
||||
|
||||
|
||||
class GenState(rx.State):
|
||||
"""A state with event handlers that generate multiple updates."""
|
||||
|
||||
value: int
|
||||
|
||||
def go(self, c: int):
|
||||
"""Increment the value c times and update each time.
|
||||
|
||||
Args:
|
||||
c: The number of times to increment.
|
||||
|
||||
Yields:
|
||||
After each increment.
|
||||
"""
|
||||
for _ in range(c):
|
||||
self.value += 1
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gen_state() -> GenState:
|
||||
"""A state.
|
||||
|
||||
Returns:
|
||||
A test state.
|
||||
"""
|
||||
return GenState # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router_data_headers() -> Dict[str, str]:
|
||||
"""Router data headers.
|
||||
@ -546,46 +214,19 @@ def mutable_state():
|
||||
Returns:
|
||||
A state object.
|
||||
"""
|
||||
|
||||
class OtherBase(rx.Base):
|
||||
bar: str = ""
|
||||
|
||||
class CustomVar(rx.Base):
|
||||
foo: str = ""
|
||||
array: List[str] = []
|
||||
hashmap: Dict[str, str] = {}
|
||||
test_set: Set[str] = set()
|
||||
custom: OtherBase = OtherBase()
|
||||
|
||||
class MutableTestState(rx.State):
|
||||
"""A test state."""
|
||||
|
||||
array: List[Union[str, List, Dict[str, str]]] = [
|
||||
"value",
|
||||
[1, 2, 3],
|
||||
{"key": "value"},
|
||||
]
|
||||
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
|
||||
"key": ["list", "of", "values"],
|
||||
"another_key": "another_value",
|
||||
"third_key": {"key": "value"},
|
||||
}
|
||||
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
|
||||
custom: CustomVar = CustomVar()
|
||||
_be_custom: CustomVar = CustomVar()
|
||||
|
||||
def reassign_mutables(self):
|
||||
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
|
||||
self.hashmap = {
|
||||
"mod_key": ["list", "of", "values"],
|
||||
"mod_another_key": "another_value",
|
||||
"mod_third_key": {"key": "value"},
|
||||
}
|
||||
self.test_set = {1, 2, 3, 4, "five"}
|
||||
|
||||
return MutableTestState()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def token() -> str:
|
||||
"""Create a token.
|
||||
|
||||
Returns:
|
||||
A fresh/unique token string.
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def duplicate_substate():
|
||||
"""Create a Test state that has duplicate child substates.
|
||||
|
30
tests/states/__init__.py
Normal file
30
tests/states/__init__.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Common rx.State subclasses for use in tests."""
|
||||
import reflex as rx
|
||||
|
||||
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
|
||||
from .upload import (
|
||||
ChildFileUploadState,
|
||||
FileUploadState,
|
||||
GrandChildFileUploadState,
|
||||
SubUploadState,
|
||||
UploadState,
|
||||
)
|
||||
|
||||
|
||||
class GenState(rx.State):
|
||||
"""A state with event handlers that generate multiple updates."""
|
||||
|
||||
value: int
|
||||
|
||||
def go(self, c: int):
|
||||
"""Increment the value c times and update each time.
|
||||
|
||||
Args:
|
||||
c: The number of times to increment.
|
||||
|
||||
Yields:
|
||||
After each increment.
|
||||
"""
|
||||
for _ in range(c):
|
||||
self.value += 1
|
||||
yield
|
172
tests/states/mutation.py
Normal file
172
tests/states/mutation.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""Test states for mutable vars."""
|
||||
|
||||
from typing import Dict, List, Set, Union
|
||||
|
||||
import reflex as rx
|
||||
|
||||
|
||||
class DictMutationTestState(rx.State):
|
||||
"""A state for testing ReflexDict mutation."""
|
||||
|
||||
# plain dict
|
||||
details = {"name": "Tommy"}
|
||||
|
||||
def add_age(self):
|
||||
"""Add an age to the dict."""
|
||||
self.details.update({"age": 20}) # type: ignore
|
||||
|
||||
def change_name(self):
|
||||
"""Change the name in the dict."""
|
||||
self.details["name"] = "Jenny"
|
||||
|
||||
def remove_last_detail(self):
|
||||
"""Remove the last item in the dict."""
|
||||
self.details.popitem()
|
||||
|
||||
def clear_details(self):
|
||||
"""Clear the dict."""
|
||||
self.details.clear()
|
||||
|
||||
def remove_name(self):
|
||||
"""Remove the name from the dict."""
|
||||
del self.details["name"]
|
||||
|
||||
def pop_out_age(self):
|
||||
"""Pop out the age from the dict."""
|
||||
self.details.pop("age")
|
||||
|
||||
# dict in list
|
||||
address = [{"home": "home address"}, {"work": "work address"}]
|
||||
|
||||
def remove_home_address(self):
|
||||
"""Remove the home address from dict in the list."""
|
||||
self.address[0].pop("home")
|
||||
|
||||
def add_street_to_home_address(self):
|
||||
"""Set street key in the dict in the list."""
|
||||
self.address[0]["street"] = "street address"
|
||||
|
||||
# nested dict
|
||||
friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
|
||||
|
||||
def change_friend_name(self):
|
||||
"""Change the friend's name in the nested dict."""
|
||||
self.friend_in_nested_dict["friend"]["name"] = "Tommy"
|
||||
|
||||
def remove_friend(self):
|
||||
"""Remove the friend from the nested dict."""
|
||||
self.friend_in_nested_dict.pop("friend")
|
||||
|
||||
def add_friend_age(self):
|
||||
"""Add an age to the friend in the nested dict."""
|
||||
self.friend_in_nested_dict["friend"]["age"] = 30
|
||||
|
||||
|
||||
class ListMutationTestState(rx.State):
|
||||
"""A state for testing ReflexList mutation."""
|
||||
|
||||
# plain list
|
||||
plain_friends = ["Tommy"]
|
||||
|
||||
def make_friend(self):
|
||||
"""Add a friend to the list."""
|
||||
self.plain_friends.append("another-fd")
|
||||
|
||||
def change_first_friend(self):
|
||||
"""Change the first friend in the list."""
|
||||
self.plain_friends[0] = "Jenny"
|
||||
|
||||
def unfriend_all_friends(self):
|
||||
"""Unfriend all friends in the list."""
|
||||
self.plain_friends.clear()
|
||||
|
||||
def unfriend_first_friend(self):
|
||||
"""Unfriend the first friend in the list."""
|
||||
del self.plain_friends[0]
|
||||
|
||||
def remove_last_friend(self):
|
||||
"""Remove the last friend in the list."""
|
||||
self.plain_friends.pop()
|
||||
|
||||
def make_friends_with_colleagues(self):
|
||||
"""Add list of friends to the list."""
|
||||
colleagues = ["Peter", "Jimmy"]
|
||||
self.plain_friends.extend(colleagues)
|
||||
|
||||
def remove_tommy(self):
|
||||
"""Remove Tommy from the list."""
|
||||
self.plain_friends.remove("Tommy")
|
||||
|
||||
# list in dict
|
||||
friends_in_dict = {"Tommy": ["Jenny"]}
|
||||
|
||||
def remove_jenny_from_tommy(self):
|
||||
"""Remove Jenny from Tommy's friends list."""
|
||||
self.friends_in_dict["Tommy"].remove("Jenny")
|
||||
|
||||
def add_jimmy_to_tommy_friends(self):
|
||||
"""Add Jimmy to Tommy's friends list."""
|
||||
self.friends_in_dict["Tommy"].append("Jimmy")
|
||||
|
||||
def tommy_has_no_fds(self):
|
||||
"""Clear Tommy's friends list."""
|
||||
self.friends_in_dict["Tommy"].clear()
|
||||
|
||||
# nested list
|
||||
friends_in_nested_list = [["Tommy"], ["Jenny"]]
|
||||
|
||||
def remove_first_group(self):
|
||||
"""Remove the first group of friends from the nested list."""
|
||||
self.friends_in_nested_list.pop(0)
|
||||
|
||||
def remove_first_person_from_first_group(self):
|
||||
"""Remove the first person from the first group of friends in the nested list."""
|
||||
self.friends_in_nested_list[0].pop(0)
|
||||
|
||||
def add_jimmy_to_second_group(self):
|
||||
"""Add Jimmy to the second group of friends in the nested list."""
|
||||
self.friends_in_nested_list[1].append("Jimmy")
|
||||
|
||||
|
||||
class OtherBase(rx.Base):
|
||||
"""A Base model with a str field."""
|
||||
|
||||
bar: str = ""
|
||||
|
||||
|
||||
class CustomVar(rx.Base):
|
||||
"""A Base model with multiple fields."""
|
||||
|
||||
foo: str = ""
|
||||
array: List[str] = []
|
||||
hashmap: Dict[str, str] = {}
|
||||
test_set: Set[str] = set()
|
||||
custom: OtherBase = OtherBase()
|
||||
|
||||
|
||||
class MutableTestState(rx.State):
|
||||
"""A test state."""
|
||||
|
||||
array: List[Union[str, List, Dict[str, str]]] = [
|
||||
"value",
|
||||
[1, 2, 3],
|
||||
{"key": "value"},
|
||||
]
|
||||
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
|
||||
"key": ["list", "of", "values"],
|
||||
"another_key": "another_value",
|
||||
"third_key": {"key": "value"},
|
||||
}
|
||||
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
|
||||
custom: CustomVar = CustomVar()
|
||||
_be_custom: CustomVar = CustomVar()
|
||||
|
||||
def reassign_mutables(self):
|
||||
"""Assign mutable fields to different values."""
|
||||
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
|
||||
self.hashmap = {
|
||||
"mod_key": ["list", "of", "values"],
|
||||
"mod_another_key": "another_value",
|
||||
"mod_third_key": {"key": "value"},
|
||||
}
|
||||
self.test_set = {1, 2, 3, 4, "five"}
|
175
tests/states/upload.py
Normal file
175
tests/states/upload.py
Normal file
@ -0,0 +1,175 @@
|
||||
"""Test states for upload-related tests."""
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, List
|
||||
|
||||
import reflex as rx
|
||||
|
||||
|
||||
class UploadState(rx.State):
|
||||
"""The base state for uploading a file."""
|
||||
|
||||
async def handle_upload1(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseState(rx.State):
|
||||
"""The test base state."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SubUploadState(BaseState):
|
||||
"""The test substate."""
|
||||
|
||||
img: str
|
||||
|
||||
async def handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class FileUploadState(rx.State):
|
||||
"""The base state for uploading a file."""
|
||||
|
||||
img_list: List[str]
|
||||
_tmp_path: ClassVar[Path]
|
||||
|
||||
async def handle_upload2(self, files):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{self._tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{self._tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
assert file.filename is not None
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
|
||||
class FileStateBase1(rx.State):
|
||||
"""The base state for a child FileUploadState."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChildFileUploadState(FileStateBase1):
|
||||
"""The child state for uploading a file."""
|
||||
|
||||
img_list: List[str]
|
||||
_tmp_path: ClassVar[Path]
|
||||
|
||||
async def handle_upload2(self, files):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{self._tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{self._tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
assert file.filename is not None
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
|
||||
class FileStateBase2(FileStateBase1):
|
||||
"""The parent state for a grandchild FileUploadState."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GrandChildFileUploadState(FileStateBase2):
|
||||
"""The child state for uploading a file."""
|
||||
|
||||
img_list: List[str]
|
||||
_tmp_path: ClassVar[Path]
|
||||
|
||||
async def handle_upload2(self, files):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{self._tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
self.img_list.append(file.filename)
|
||||
|
||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||
"""Handle the upload of a file.
|
||||
|
||||
Args:
|
||||
files: The uploaded files.
|
||||
"""
|
||||
for file in files:
|
||||
upload_data = await file.read()
|
||||
outfile = f"{self._tmp_path}/{file.filename}"
|
||||
|
||||
# Save the file.
|
||||
with open(outfile, "wb") as file_object:
|
||||
file_object.write(upload_data)
|
||||
|
||||
# Update the img var.
|
||||
assert file.filename is not None
|
||||
self.img_list.append(file.filename)
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import io
|
||||
import os.path
|
||||
import sys
|
||||
import uuid
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
if sys.version_info.major >= 3 and sys.version_info.minor > 7:
|
||||
@ -30,11 +31,18 @@ from reflex.components import Box, Component, Cond, Fragment, Text
|
||||
from reflex.event import Event, get_hydrate_event
|
||||
from reflex.middleware import HydrateMiddleware
|
||||
from reflex.model import Model
|
||||
from reflex.state import State, StateUpdate
|
||||
from reflex.state import State, StateManagerRedis, StateUpdate
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format
|
||||
from reflex.vars import ComputedVar
|
||||
|
||||
from .states import (
|
||||
ChildFileUploadState,
|
||||
FileUploadState,
|
||||
GenState,
|
||||
GrandChildFileUploadState,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def index_page():
|
||||
@ -64,6 +72,12 @@ def about_page():
|
||||
return about
|
||||
|
||||
|
||||
class ATestState(State):
|
||||
"""A simple state for testing."""
|
||||
|
||||
var: int
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def test_state() -> Type[State]:
|
||||
"""A default state.
|
||||
@ -71,11 +85,7 @@ def test_state() -> Type[State]:
|
||||
Returns:
|
||||
A default state.
|
||||
"""
|
||||
|
||||
class TestState(State):
|
||||
var: int
|
||||
|
||||
return TestState
|
||||
return ATestState
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -313,23 +323,28 @@ def test_initialize_admin_dashboard_with_view_overrides(test_model):
|
||||
assert app.admin_dash.view_overrides[test_model] == TestModelView
|
||||
|
||||
|
||||
def test_initialize_with_state(test_state):
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_with_state(test_state: Type[ATestState], token: str):
|
||||
"""Test setting the state of an app.
|
||||
|
||||
Args:
|
||||
test_state: The default state.
|
||||
token: a Token.
|
||||
"""
|
||||
app = App(state=test_state)
|
||||
assert app.state == test_state
|
||||
|
||||
# Get a state for a given token.
|
||||
token = "token"
|
||||
state = app.state_manager.get_state(token)
|
||||
state = await app.state_manager.get_state(token)
|
||||
assert isinstance(state, test_state)
|
||||
assert state.var == 0 # type: ignore
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.redis.close()
|
||||
|
||||
def test_set_and_get_state(test_state):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_state(test_state):
|
||||
"""Test setting and getting the state of an app with different tokens.
|
||||
|
||||
Args:
|
||||
@ -338,47 +353,51 @@ def test_set_and_get_state(test_state):
|
||||
app = App(state=test_state)
|
||||
|
||||
# Create two tokens.
|
||||
token1 = "token1"
|
||||
token2 = "token2"
|
||||
token1 = str(uuid.uuid4())
|
||||
token2 = str(uuid.uuid4())
|
||||
|
||||
# Get the default state for each token.
|
||||
state1 = app.state_manager.get_state(token1)
|
||||
state2 = app.state_manager.get_state(token2)
|
||||
state1 = await app.state_manager.get_state(token1)
|
||||
state2 = await app.state_manager.get_state(token2)
|
||||
assert state1.var == 0 # type: ignore
|
||||
assert state2.var == 0 # type: ignore
|
||||
|
||||
# Set the vars to different values.
|
||||
state1.var = 1
|
||||
state2.var = 2
|
||||
app.state_manager.set_state(token1, state1)
|
||||
app.state_manager.set_state(token2, state2)
|
||||
await app.state_manager.set_state(token1, state1)
|
||||
await app.state_manager.set_state(token2, state2)
|
||||
|
||||
# Get the states again and check the values.
|
||||
state1 = app.state_manager.get_state(token1)
|
||||
state2 = app.state_manager.get_state(token2)
|
||||
state1 = await app.state_manager.get_state(token1)
|
||||
state2 = await app.state_manager.get_state(token2)
|
||||
assert state1.var == 1 # type: ignore
|
||||
assert state2.var == 2 # type: ignore
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.redis.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_var_event(test_state):
|
||||
async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
|
||||
"""Test that the default handler of a dynamic generated var
|
||||
works as expected.
|
||||
|
||||
Args:
|
||||
test_state: State Fixture.
|
||||
token: a Token.
|
||||
"""
|
||||
test_state = test_state()
|
||||
test_state.add_var("int_val", int, 0)
|
||||
result = await test_state._process(
|
||||
state = test_state() # type: ignore
|
||||
state.add_var("int_val", int, 0)
|
||||
result = await state._process(
|
||||
Event(
|
||||
token="fake-token",
|
||||
name="test_state.set_int_val",
|
||||
token=token,
|
||||
name=f"{test_state.get_name()}.set_int_val",
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={"value": 50},
|
||||
)
|
||||
).__anext__()
|
||||
assert result.delta == {"test_state": {"int_val": 50}}
|
||||
assert result.delta == {test_state.get_name(): {"int_val": 50}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -388,12 +407,20 @@ async def test_dynamic_var_event(test_state):
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.make_friend",
|
||||
{"test_state": {"plain_friends": ["Tommy", "another-fd"]}},
|
||||
"list_mutation_test_state.make_friend",
|
||||
{
|
||||
"list_mutation_test_state": {
|
||||
"plain_friends": ["Tommy", "another-fd"]
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.change_first_friend",
|
||||
{"test_state": {"plain_friends": ["Jenny", "another-fd"]}},
|
||||
"list_mutation_test_state.change_first_friend",
|
||||
{
|
||||
"list_mutation_test_state": {
|
||||
"plain_friends": ["Jenny", "another-fd"]
|
||||
}
|
||||
},
|
||||
),
|
||||
],
|
||||
id="append then __setitem__",
|
||||
@ -401,12 +428,12 @@ async def test_dynamic_var_event(test_state):
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.unfriend_first_friend",
|
||||
{"test_state": {"plain_friends": []}},
|
||||
"list_mutation_test_state.unfriend_first_friend",
|
||||
{"list_mutation_test_state": {"plain_friends": []}},
|
||||
),
|
||||
(
|
||||
"test_state.make_friend",
|
||||
{"test_state": {"plain_friends": ["another-fd"]}},
|
||||
"list_mutation_test_state.make_friend",
|
||||
{"list_mutation_test_state": {"plain_friends": ["another-fd"]}},
|
||||
),
|
||||
],
|
||||
id="delitem then append",
|
||||
@ -414,20 +441,24 @@ async def test_dynamic_var_event(test_state):
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.make_friends_with_colleagues",
|
||||
{"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}},
|
||||
"list_mutation_test_state.make_friends_with_colleagues",
|
||||
{
|
||||
"list_mutation_test_state": {
|
||||
"plain_friends": ["Tommy", "Peter", "Jimmy"]
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.remove_tommy",
|
||||
{"test_state": {"plain_friends": ["Peter", "Jimmy"]}},
|
||||
"list_mutation_test_state.remove_tommy",
|
||||
{"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}},
|
||||
),
|
||||
(
|
||||
"test_state.remove_last_friend",
|
||||
{"test_state": {"plain_friends": ["Peter"]}},
|
||||
"list_mutation_test_state.remove_last_friend",
|
||||
{"list_mutation_test_state": {"plain_friends": ["Peter"]}},
|
||||
),
|
||||
(
|
||||
"test_state.unfriend_all_friends",
|
||||
{"test_state": {"plain_friends": []}},
|
||||
"list_mutation_test_state.unfriend_all_friends",
|
||||
{"list_mutation_test_state": {"plain_friends": []}},
|
||||
),
|
||||
],
|
||||
id="extend, remove, pop, clear",
|
||||
@ -435,24 +466,28 @@ async def test_dynamic_var_event(test_state):
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.add_jimmy_to_second_group",
|
||||
"list_mutation_test_state.add_jimmy_to_second_group",
|
||||
{
|
||||
"test_state": {
|
||||
"list_mutation_test_state": {
|
||||
"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.remove_first_person_from_first_group",
|
||||
"list_mutation_test_state.remove_first_person_from_first_group",
|
||||
{
|
||||
"test_state": {
|
||||
"list_mutation_test_state": {
|
||||
"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.remove_first_group",
|
||||
{"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}},
|
||||
"list_mutation_test_state.remove_first_group",
|
||||
{
|
||||
"list_mutation_test_state": {
|
||||
"friends_in_nested_list": [["Jenny", "Jimmy"]]
|
||||
}
|
||||
},
|
||||
),
|
||||
],
|
||||
id="nested list",
|
||||
@ -460,16 +495,24 @@ async def test_dynamic_var_event(test_state):
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.add_jimmy_to_tommy_friends",
|
||||
{"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}},
|
||||
"list_mutation_test_state.add_jimmy_to_tommy_friends",
|
||||
{
|
||||
"list_mutation_test_state": {
|
||||
"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.remove_jenny_from_tommy",
|
||||
{"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}},
|
||||
"list_mutation_test_state.remove_jenny_from_tommy",
|
||||
{
|
||||
"list_mutation_test_state": {
|
||||
"friends_in_dict": {"Tommy": ["Jimmy"]}
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.tommy_has_no_fds",
|
||||
{"test_state": {"friends_in_dict": {"Tommy": []}}},
|
||||
"list_mutation_test_state.tommy_has_no_fds",
|
||||
{"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}},
|
||||
),
|
||||
],
|
||||
id="list in dict",
|
||||
@ -477,7 +520,9 @@ async def test_dynamic_var_event(test_state):
|
||||
],
|
||||
)
|
||||
async def test_list_mutation_detection__plain_list(
|
||||
event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State
|
||||
event_tuples: List[Tuple[str, List[str]]],
|
||||
list_mutation_state: State,
|
||||
token: str,
|
||||
):
|
||||
"""Test list mutation detection
|
||||
when reassignment is not explicitly included in the logic.
|
||||
@ -485,11 +530,12 @@ async def test_list_mutation_detection__plain_list(
|
||||
Args:
|
||||
event_tuples: From parametrization.
|
||||
list_mutation_state: A state with list mutation features.
|
||||
token: a Token.
|
||||
"""
|
||||
for event_name, expected_delta in event_tuples:
|
||||
result = await list_mutation_state._process(
|
||||
Event(
|
||||
token="fake-token",
|
||||
token=token,
|
||||
name=event_name,
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={},
|
||||
@ -506,16 +552,24 @@ async def test_list_mutation_detection__plain_list(
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.add_age",
|
||||
{"test_state": {"details": {"name": "Tommy", "age": 20}}},
|
||||
"dict_mutation_test_state.add_age",
|
||||
{
|
||||
"dict_mutation_test_state": {
|
||||
"details": {"name": "Tommy", "age": 20}
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.change_name",
|
||||
{"test_state": {"details": {"name": "Jenny", "age": 20}}},
|
||||
"dict_mutation_test_state.change_name",
|
||||
{
|
||||
"dict_mutation_test_state": {
|
||||
"details": {"name": "Jenny", "age": 20}
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.remove_last_detail",
|
||||
{"test_state": {"details": {"name": "Jenny"}}},
|
||||
"dict_mutation_test_state.remove_last_detail",
|
||||
{"dict_mutation_test_state": {"details": {"name": "Jenny"}}},
|
||||
),
|
||||
],
|
||||
id="update then __setitem__",
|
||||
@ -523,12 +577,12 @@ async def test_list_mutation_detection__plain_list(
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.clear_details",
|
||||
{"test_state": {"details": {}}},
|
||||
"dict_mutation_test_state.clear_details",
|
||||
{"dict_mutation_test_state": {"details": {}}},
|
||||
),
|
||||
(
|
||||
"test_state.add_age",
|
||||
{"test_state": {"details": {"age": 20}}},
|
||||
"dict_mutation_test_state.add_age",
|
||||
{"dict_mutation_test_state": {"details": {"age": 20}}},
|
||||
),
|
||||
],
|
||||
id="delitem then update",
|
||||
@ -536,16 +590,20 @@ async def test_list_mutation_detection__plain_list(
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.add_age",
|
||||
{"test_state": {"details": {"name": "Tommy", "age": 20}}},
|
||||
"dict_mutation_test_state.add_age",
|
||||
{
|
||||
"dict_mutation_test_state": {
|
||||
"details": {"name": "Tommy", "age": 20}
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.remove_name",
|
||||
{"test_state": {"details": {"age": 20}}},
|
||||
"dict_mutation_test_state.remove_name",
|
||||
{"dict_mutation_test_state": {"details": {"age": 20}}},
|
||||
),
|
||||
(
|
||||
"test_state.pop_out_age",
|
||||
{"test_state": {"details": {}}},
|
||||
"dict_mutation_test_state.pop_out_age",
|
||||
{"dict_mutation_test_state": {"details": {}}},
|
||||
),
|
||||
],
|
||||
id="add, remove, pop",
|
||||
@ -553,13 +611,17 @@ async def test_list_mutation_detection__plain_list(
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.remove_home_address",
|
||||
{"test_state": {"address": [{}, {"work": "work address"}]}},
|
||||
"dict_mutation_test_state.remove_home_address",
|
||||
{
|
||||
"dict_mutation_test_state": {
|
||||
"address": [{}, {"work": "work address"}]
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.add_street_to_home_address",
|
||||
"dict_mutation_test_state.add_street_to_home_address",
|
||||
{
|
||||
"test_state": {
|
||||
"dict_mutation_test_state": {
|
||||
"address": [
|
||||
{"street": "street address"},
|
||||
{"work": "work address"},
|
||||
@ -573,9 +635,9 @@ async def test_list_mutation_detection__plain_list(
|
||||
pytest.param(
|
||||
[
|
||||
(
|
||||
"test_state.change_friend_name",
|
||||
"dict_mutation_test_state.change_friend_name",
|
||||
{
|
||||
"test_state": {
|
||||
"dict_mutation_test_state": {
|
||||
"friend_in_nested_dict": {
|
||||
"name": "Nikhil",
|
||||
"friend": {"name": "Tommy"},
|
||||
@ -584,9 +646,9 @@ async def test_list_mutation_detection__plain_list(
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.add_friend_age",
|
||||
"dict_mutation_test_state.add_friend_age",
|
||||
{
|
||||
"test_state": {
|
||||
"dict_mutation_test_state": {
|
||||
"friend_in_nested_dict": {
|
||||
"name": "Nikhil",
|
||||
"friend": {"name": "Tommy", "age": 30},
|
||||
@ -595,8 +657,12 @@ async def test_list_mutation_detection__plain_list(
|
||||
},
|
||||
),
|
||||
(
|
||||
"test_state.remove_friend",
|
||||
{"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}},
|
||||
"dict_mutation_test_state.remove_friend",
|
||||
{
|
||||
"dict_mutation_test_state": {
|
||||
"friend_in_nested_dict": {"name": "Nikhil"}
|
||||
}
|
||||
},
|
||||
),
|
||||
],
|
||||
id="nested dict",
|
||||
@ -604,7 +670,9 @@ async def test_list_mutation_detection__plain_list(
|
||||
],
|
||||
)
|
||||
async def test_dict_mutation_detection__plain_list(
|
||||
event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State
|
||||
event_tuples: List[Tuple[str, List[str]]],
|
||||
dict_mutation_state: State,
|
||||
token: str,
|
||||
):
|
||||
"""Test dict mutation detection
|
||||
when reassignment is not explicitly included in the logic.
|
||||
@ -612,11 +680,12 @@ async def test_dict_mutation_detection__plain_list(
|
||||
Args:
|
||||
event_tuples: From parametrization.
|
||||
dict_mutation_state: A state with dict mutation features.
|
||||
token: a Token.
|
||||
"""
|
||||
for event_name, expected_delta in event_tuples:
|
||||
result = await dict_mutation_state._process(
|
||||
Event(
|
||||
token="fake-token",
|
||||
token=token,
|
||||
name=event_name,
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={},
|
||||
@ -628,41 +697,43 @@ async def test_dict_mutation_detection__plain_list(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"fixture, delta",
|
||||
("state", "delta"),
|
||||
[
|
||||
(
|
||||
"upload_state",
|
||||
FileUploadState,
|
||||
{"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
|
||||
),
|
||||
(
|
||||
"upload_sub_state",
|
||||
ChildFileUploadState,
|
||||
{
|
||||
"file_state.file_upload_state": {
|
||||
"file_state_base1.child_file_upload_state": {
|
||||
"img_list": ["image1.jpg", "image2.jpg"]
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
"upload_grand_sub_state",
|
||||
GrandChildFileUploadState,
|
||||
{
|
||||
"base_file_state.file_sub_state.file_upload_state": {
|
||||
"file_state_base1.file_state_base2.grand_child_file_upload_state": {
|
||||
"img_list": ["image1.jpg", "image2.jpg"]
|
||||
}
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_upload_file(fixture, request, delta):
|
||||
async def test_upload_file(tmp_path, state, delta, token: str):
|
||||
"""Test that file upload works correctly.
|
||||
|
||||
Args:
|
||||
fixture: The state.
|
||||
request: Fixture request.
|
||||
tmp_path: Temporary path.
|
||||
state: The state class.
|
||||
delta: Expected delta
|
||||
token: a Token.
|
||||
"""
|
||||
app = App(state=request.getfixturevalue(fixture))
|
||||
state._tmp_path = tmp_path
|
||||
app = App(state=state)
|
||||
app.event_namespace.emit = AsyncMock() # type: ignore
|
||||
current_state = app.state_manager.get_state("token")
|
||||
current_state = await app.state_manager.get_state(token)
|
||||
data = b"This is binary data"
|
||||
|
||||
# Create a binary IO object and write data to it
|
||||
@ -670,11 +741,11 @@ async def test_upload_file(fixture, request, delta):
|
||||
bio.write(data)
|
||||
|
||||
file1 = UploadFile(
|
||||
filename="token:file_upload_state.multi_handle_upload:True:image1.jpg",
|
||||
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image1.jpg",
|
||||
file=bio,
|
||||
)
|
||||
file2 = UploadFile(
|
||||
filename="token:file_upload_state.multi_handle_upload:True:image2.jpg",
|
||||
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image2.jpg",
|
||||
file=bio,
|
||||
)
|
||||
upload_fn = upload(app)
|
||||
@ -684,22 +755,27 @@ async def test_upload_file(fixture, request, delta):
|
||||
app.event_namespace.emit.assert_called_with( # type: ignore
|
||||
"event", state_update.json(), to=current_state.get_sid()
|
||||
)
|
||||
assert app.state_manager.get_state("token").dict()["img_list"] == [
|
||||
assert (await app.state_manager.get_state(token)).dict()["img_list"] == [
|
||||
"image1.jpg",
|
||||
"image2.jpg",
|
||||
]
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.redis.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"fixture", ["upload_state", "upload_sub_state", "upload_grand_sub_state"]
|
||||
"state",
|
||||
[FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
|
||||
)
|
||||
async def test_upload_file_without_annotation(fixture, request):
|
||||
async def test_upload_file_without_annotation(state, tmp_path, token):
|
||||
"""Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
|
||||
|
||||
Args:
|
||||
fixture: The state.
|
||||
request: Fixture request.
|
||||
state: The state class.
|
||||
tmp_path: Temporary path.
|
||||
token: a Token.
|
||||
"""
|
||||
data = b"This is binary data"
|
||||
|
||||
@ -707,14 +783,15 @@ async def test_upload_file_without_annotation(fixture, request):
|
||||
bio = io.BytesIO()
|
||||
bio.write(data)
|
||||
|
||||
app = App(state=request.getfixturevalue(fixture))
|
||||
state._tmp_path = tmp_path
|
||||
app = App(state=state)
|
||||
|
||||
file1 = UploadFile(
|
||||
filename="token:file_upload_state.handle_upload2:True:image1.jpg",
|
||||
filename=f"{token}:{state.get_name()}.handle_upload2:True:image1.jpg",
|
||||
file=bio,
|
||||
)
|
||||
file2 = UploadFile(
|
||||
filename="token:file_upload_state.handle_upload2:True:image2.jpg",
|
||||
filename=f"{token}:{state.get_name()}.handle_upload2:True:image2.jpg",
|
||||
file=bio,
|
||||
)
|
||||
fn = upload(app)
|
||||
@ -722,9 +799,12 @@ async def test_upload_file_without_annotation(fixture, request):
|
||||
await fn([file1, file2])
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== "`file_upload_state.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
|
||||
== f"`{state.get_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
|
||||
)
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.redis.close()
|
||||
|
||||
|
||||
class DynamicState(State):
|
||||
"""State class for testing dynamic route var.
|
||||
@ -768,6 +848,7 @@ class DynamicState(State):
|
||||
async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
index_page,
|
||||
windows_platform: bool,
|
||||
token: str,
|
||||
):
|
||||
"""Create app with dynamic route var, and simulate navigation.
|
||||
|
||||
@ -777,6 +858,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
Args:
|
||||
index_page: The index page.
|
||||
windows_platform: Whether the system is windows.
|
||||
token: a Token.
|
||||
"""
|
||||
arg_name = "dynamic"
|
||||
route = f"/test/[{arg_name}]"
|
||||
@ -792,10 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
}
|
||||
assert constants.ROUTER_DATA in app.state().computed_var_dependencies
|
||||
|
||||
token = "mock_token"
|
||||
sid = "mock_sid"
|
||||
client_ip = "127.0.0.1"
|
||||
state = app.state_manager.get_state(token)
|
||||
state = await app.state_manager.get_state(token)
|
||||
assert state.dynamic == ""
|
||||
exp_vals = ["foo", "foobar", "baz"]
|
||||
|
||||
@ -817,6 +898,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
prev_exp_val = ""
|
||||
for exp_index, exp_val in enumerate(exp_vals):
|
||||
hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
|
||||
exp_router_data = {
|
||||
@ -826,13 +908,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
"token": token,
|
||||
**hydrate_event.router_data,
|
||||
}
|
||||
update = await process(
|
||||
process_coro = process(
|
||||
app,
|
||||
event=hydrate_event,
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
).__anext__() # type: ignore
|
||||
)
|
||||
update = await process_coro.__anext__() # type: ignore
|
||||
|
||||
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
||||
assert update == StateUpdate(
|
||||
@ -860,14 +943,27 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
),
|
||||
],
|
||||
)
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
# When redis is used, the state is not updated until the processing is complete
|
||||
state = await app.state_manager.get_state(token)
|
||||
assert state.dynamic == prev_exp_val
|
||||
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
|
||||
# check that router data was written to the state_manager store
|
||||
state = await app.state_manager.get_state(token)
|
||||
assert state.dynamic == exp_val
|
||||
on_load_update = await process(
|
||||
|
||||
process_coro = process(
|
||||
app,
|
||||
event=_dynamic_state_event(name="on_load", val=exp_val),
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
).__anext__() # type: ignore
|
||||
)
|
||||
on_load_update = await process_coro.__anext__() # type: ignore
|
||||
assert on_load_update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -879,7 +975,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
},
|
||||
events=[],
|
||||
)
|
||||
on_set_is_hydrated_update = await process(
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
process_coro = process(
|
||||
app,
|
||||
event=_dynamic_state_event(
|
||||
name="set_is_hydrated", payload={"value": True}, val=exp_val
|
||||
@ -887,7 +986,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
).__anext__() # type: ignore
|
||||
)
|
||||
on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore
|
||||
assert on_set_is_hydrated_update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -899,15 +999,19 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
},
|
||||
events=[],
|
||||
)
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
|
||||
# a simple state update event should NOT trigger on_load or route var side effects
|
||||
update = await process(
|
||||
process_coro = process(
|
||||
app,
|
||||
event=_dynamic_state_event(name="on_counter", val=exp_val),
|
||||
sid=sid,
|
||||
headers={},
|
||||
client_ip=client_ip,
|
||||
).__anext__() # type: ignore
|
||||
)
|
||||
update = await process_coro.__anext__() # type: ignore
|
||||
assert update == StateUpdate(
|
||||
delta={
|
||||
state.get_name(): {
|
||||
@ -919,42 +1023,54 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
||||
},
|
||||
events=[],
|
||||
)
|
||||
# complete the processing
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await process_coro.__anext__() # type: ignore
|
||||
|
||||
prev_exp_val = exp_val
|
||||
state = await app.state_manager.get_state(token)
|
||||
assert state.loaded == len(exp_vals)
|
||||
assert state.counter == len(exp_vals)
|
||||
# print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
|
||||
# assert state.side_effect_counter == len(exp_vals)
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.redis.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_events(gen_state, mocker):
|
||||
async def test_process_events(mocker, token: str):
|
||||
"""Test that an event is processed properly and that it is postprocessed
|
||||
n+1 times. Also check that the processing flag of the last stateupdate is set to
|
||||
False.
|
||||
|
||||
Args:
|
||||
gen_state: The state.
|
||||
mocker: mocker object.
|
||||
token: a Token.
|
||||
"""
|
||||
router_data = {
|
||||
"pathname": "/",
|
||||
"query": {},
|
||||
"token": "mock_token",
|
||||
"token": token,
|
||||
"sid": "mock_sid",
|
||||
"headers": {},
|
||||
"ip": "127.0.0.1",
|
||||
}
|
||||
app = App(state=gen_state)
|
||||
app = App(state=GenState)
|
||||
mocker.patch.object(app, "postprocess", AsyncMock())
|
||||
event = Event(
|
||||
token="token", name="gen_state.go", payload={"c": 5}, router_data=router_data
|
||||
token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
|
||||
)
|
||||
|
||||
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore
|
||||
pass
|
||||
|
||||
assert app.state_manager.get_state("token").value == 5
|
||||
assert (await app.state_manager.get_state(token)).value == 5
|
||||
assert app.postprocess.call_count == 6
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.redis.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("state", "overlay_component", "exp_page_child"),
|
||||
|
@ -1,22 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import datetime
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
from typing import Dict, Generator, List
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from plotly.graph_objects import Figure
|
||||
|
||||
import reflex as rx
|
||||
from reflex.base import Base
|
||||
from reflex.constants import IS_HYDRATED, RouteVar
|
||||
from reflex.constants import APP_VAR, IS_HYDRATED, RouteVar, SocketEvent
|
||||
from reflex.event import Event, EventHandler
|
||||
from reflex.state import MutableProxy, State
|
||||
from reflex.utils import format
|
||||
from reflex.state import (
|
||||
ImmutableStateError,
|
||||
LockExpiredError,
|
||||
MutableProxy,
|
||||
State,
|
||||
StateManager,
|
||||
StateManagerMemory,
|
||||
StateManagerRedis,
|
||||
StateProxy,
|
||||
StateUpdate,
|
||||
)
|
||||
from reflex.utils import format, prerequisites
|
||||
from reflex.vars import BaseVar, ComputedVar
|
||||
|
||||
from .states import GenState
|
||||
|
||||
CI = bool(os.environ.get("CI", False))
|
||||
LOCK_EXPIRATION = 2000 if CI else 100
|
||||
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
|
||||
|
||||
|
||||
class Object(Base):
|
||||
"""A test object fixture."""
|
||||
@ -704,13 +724,9 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event_generator(gen_state):
|
||||
"""Test event handlers that generate multiple updates.
|
||||
|
||||
Args:
|
||||
gen_state: A state.
|
||||
"""
|
||||
gen_state = gen_state()
|
||||
async def test_process_event_generator():
|
||||
"""Test event handlers that generate multiple updates."""
|
||||
gen_state = GenState() # type: ignore
|
||||
event = Event(
|
||||
token="t",
|
||||
name="go",
|
||||
@ -1402,6 +1418,396 @@ def test_state_with_invalid_yield():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", params=["in_process", "redis"])
|
||||
def state_manager(request) -> Generator[StateManager, None, None]:
|
||||
"""Instance of state manager parametrized for redis and in-process.
|
||||
|
||||
Args:
|
||||
request: pytest request object.
|
||||
|
||||
Yields:
|
||||
A state manager instance
|
||||
"""
|
||||
state_manager = StateManager.create(state=TestState)
|
||||
if request.param == "redis":
|
||||
if not isinstance(state_manager, StateManagerRedis):
|
||||
pytest.skip("Test requires redis")
|
||||
else:
|
||||
# explicitly NOT using redis
|
||||
state_manager = StateManagerMemory(state=TestState)
|
||||
assert not state_manager._states_locks
|
||||
|
||||
yield state_manager
|
||||
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_modify_state(state_manager: StateManager, token: str):
|
||||
"""Test that the state manager can modify a state exclusively.
|
||||
|
||||
Args:
|
||||
state_manager: A state manager instance.
|
||||
token: A token.
|
||||
"""
|
||||
async with state_manager.modify_state(token):
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert await state_manager.redis.get(f"{token}_lock")
|
||||
elif isinstance(state_manager, StateManagerMemory):
|
||||
assert token in state_manager._states_locks
|
||||
assert state_manager._states_locks[token].locked()
|
||||
# lock should be dropped after exiting the context
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
elif isinstance(state_manager, StateManagerMemory):
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
|
||||
# separate instances should NOT share locks
|
||||
sm2 = StateManagerMemory(state=TestState)
|
||||
assert sm2._state_manager_lock is state_manager._state_manager_lock
|
||||
assert not sm2._states_locks
|
||||
if state_manager._states_locks:
|
||||
assert sm2._states_locks != state_manager._states_locks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_contend(state_manager: StateManager, token: str):
|
||||
"""Multiple coroutines attempting to access the same state.
|
||||
|
||||
Args:
|
||||
state_manager: A state manager instance.
|
||||
token: A token.
|
||||
"""
|
||||
n_coroutines = 10
|
||||
exp_num1 = 10
|
||||
|
||||
async with state_manager.modify_state(token) as state:
|
||||
state.num1 = 0
|
||||
|
||||
async def _coro():
|
||||
async with state_manager.modify_state(token) as state:
|
||||
await asyncio.sleep(0.01)
|
||||
state.num1 += 1
|
||||
|
||||
tasks = [asyncio.create_task(_coro()) for _ in range(n_coroutines)]
|
||||
|
||||
for f in asyncio.as_completed(tasks):
|
||||
await f
|
||||
|
||||
assert (await state_manager.get_state(token)).num1 == exp_num1
|
||||
|
||||
if isinstance(state_manager, StateManagerRedis):
|
||||
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||
elif isinstance(state_manager, StateManagerMemory):
|
||||
assert token in state_manager._states_locks
|
||||
assert not state_manager._states_locks[token].locked()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def state_manager_redis() -> Generator[StateManager, None, None]:
|
||||
"""Instance of state manager for redis only.
|
||||
|
||||
Yields:
|
||||
A state manager instance
|
||||
"""
|
||||
state_manager = StateManager.create(TestState)
|
||||
|
||||
if not isinstance(state_manager, StateManagerRedis):
|
||||
pytest.skip("Test requires redis")
|
||||
|
||||
yield state_manager
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
|
||||
"""Test that the state manager lock expires and raises exception exiting context.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
token: A token.
|
||||
"""
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
|
||||
async with state_manager_redis.modify_state(token):
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
with pytest.raises(LockExpiredError):
|
||||
async with state_manager_redis.modify_state(token):
|
||||
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_manager_lock_expire_contend(
|
||||
state_manager_redis: StateManager, token: str
|
||||
):
|
||||
"""Test that the state manager lock expires and queued waiters proceed.
|
||||
|
||||
Args:
|
||||
state_manager_redis: A state manager instance.
|
||||
token: A token.
|
||||
"""
|
||||
exp_num1 = 4252
|
||||
unexp_num1 = 666
|
||||
|
||||
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||
|
||||
order = []
|
||||
|
||||
async def _coro_blocker():
|
||||
async with state_manager_redis.modify_state(token) as state:
|
||||
order.append("blocker")
|
||||
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
|
||||
state.num1 = unexp_num1
|
||||
|
||||
async def _coro_waiter():
|
||||
while "blocker" not in order:
|
||||
await asyncio.sleep(0.005)
|
||||
async with state_manager_redis.modify_state(token) as state:
|
||||
order.append("waiter")
|
||||
assert state.num1 != unexp_num1
|
||||
state.num1 = exp_num1
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(_coro_blocker()),
|
||||
asyncio.create_task(_coro_waiter()),
|
||||
]
|
||||
with pytest.raises(LockExpiredError):
|
||||
await tasks[0]
|
||||
await tasks[1]
|
||||
|
||||
assert order == ["blocker", "waiter"]
|
||||
assert (await state_manager_redis.get_state(token)).num1 == exp_num1
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App:
|
||||
"""Mock app fixture.
|
||||
|
||||
Args:
|
||||
monkeypatch: Pytest monkeypatch object.
|
||||
app: An app.
|
||||
state_manager: A state manager.
|
||||
|
||||
Returns:
|
||||
The app, after mocking out prerequisites.get_app()
|
||||
"""
|
||||
app_module = Mock()
|
||||
setattr(app_module, APP_VAR, app)
|
||||
app.state = TestState
|
||||
app.state_manager = state_manager
|
||||
assert app.event_namespace is not None
|
||||
app.event_namespace.emit = AsyncMock()
|
||||
monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||
"""Test that the state proxy works.
|
||||
|
||||
Args:
|
||||
grandchild_state: A grandchild state.
|
||||
mock_app: An app that will be returned by `get_app()`
|
||||
"""
|
||||
child_state = grandchild_state.parent_state
|
||||
assert child_state is not None
|
||||
parent_state = child_state.parent_state
|
||||
assert parent_state is not None
|
||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||
mock_app.state_manager.states[parent_state.get_token()] = parent_state
|
||||
|
||||
sp = StateProxy(grandchild_state)
|
||||
assert sp.__wrapped__ == grandchild_state
|
||||
assert sp._self_substate_path == grandchild_state.get_full_name().split(".")
|
||||
assert sp._self_app is mock_app
|
||||
assert not sp._self_mutable
|
||||
assert sp._self_actx is None
|
||||
|
||||
# cannot use normal contextmanager protocol
|
||||
with pytest.raises(TypeError), sp:
|
||||
pass
|
||||
|
||||
with pytest.raises(ImmutableStateError):
|
||||
# cannot directly modify state proxy outside of async context
|
||||
sp.value2 = 16
|
||||
|
||||
async with sp:
|
||||
assert sp._self_actx is not None
|
||||
assert sp._self_mutable # proxy is mutable inside context
|
||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||
# For in-process store, only one instance of the state exists
|
||||
assert sp.__wrapped__ is grandchild_state
|
||||
else:
|
||||
# When redis is used, a new+updated instance is assigned to the proxy
|
||||
assert sp.__wrapped__ is not grandchild_state
|
||||
sp.value2 = 42
|
||||
assert not sp._self_mutable # proxy is not mutable after exiting context
|
||||
assert sp._self_actx is None
|
||||
assert sp.value2 == 42
|
||||
|
||||
# Get the state from the state manager directly and check that the value is updated
|
||||
gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
|
||||
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||
# For in-process store, only one instance of the state exists
|
||||
assert gotten_state is parent_state
|
||||
else:
|
||||
assert gotten_state is not parent_state
|
||||
gotten_grandchild_state = gotten_state.get_substate(sp._self_substate_path)
|
||||
assert gotten_grandchild_state is not None
|
||||
assert gotten_grandchild_state.value2 == 42
|
||||
|
||||
# ensure state update was emitted
|
||||
assert mock_app.event_namespace is not None
|
||||
mock_app.event_namespace.emit.assert_called_once()
|
||||
mcall = mock_app.event_namespace.emit.mock_calls[0]
|
||||
assert mcall.args[0] == str(SocketEvent.EVENT)
|
||||
assert json.loads(mcall.args[1]) == StateUpdate(
|
||||
delta={
|
||||
parent_state.get_full_name(): {
|
||||
"upper": "",
|
||||
"sum": 3.14,
|
||||
},
|
||||
grandchild_state.get_full_name(): {
|
||||
"value2": 42,
|
||||
},
|
||||
}
|
||||
)
|
||||
assert mcall.kwargs["to"] == grandchild_state.get_sid()
|
||||
|
||||
|
||||
class BackgroundTaskState(State):
|
||||
"""A state with a background task."""
|
||||
|
||||
order: List[str] = []
|
||||
dict_list: Dict[str, List[int]] = {"foo": []}
|
||||
|
||||
@rx.background
|
||||
async def background_task(self):
|
||||
"""A background task that updates the state."""
|
||||
async with self:
|
||||
assert not self.order
|
||||
self.order.append("background_task:start")
|
||||
|
||||
assert isinstance(self, StateProxy)
|
||||
with pytest.raises(ImmutableStateError):
|
||||
self.order.append("bad idea")
|
||||
|
||||
with pytest.raises(ImmutableStateError):
|
||||
# Even nested access to mutables raises an exception.
|
||||
self.dict_list["foo"].append(42)
|
||||
|
||||
# wait for some other event to happen
|
||||
while len(self.order) == 1:
|
||||
await asyncio.sleep(0.01)
|
||||
async with self:
|
||||
pass # update proxy instance
|
||||
|
||||
async with self:
|
||||
self.order.append("background_task:stop")
|
||||
|
||||
@rx.background
|
||||
async def background_task_generator(self):
|
||||
"""A background task generator that does nothing.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
yield
|
||||
|
||||
def other(self):
|
||||
"""Some other event that updates the state."""
|
||||
self.order.append("other")
|
||||
|
||||
async def bad_chain1(self):
|
||||
"""Test that a background task cannot be chained."""
|
||||
await self.background_task()
|
||||
|
||||
async def bad_chain2(self):
|
||||
"""Test that a background task generator cannot be chained."""
|
||||
async for _foo in self.background_task_generator():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||
"""Test that a background task does not block other events.
|
||||
|
||||
Args:
|
||||
mock_app: An app that will be returned by `get_app()`
|
||||
token: A token.
|
||||
"""
|
||||
router_data = {"query": {}}
|
||||
mock_app.state_manager.state = mock_app.state = BackgroundTaskState
|
||||
async for update in rx.app.process( # type: ignore
|
||||
mock_app,
|
||||
Event(
|
||||
token=token,
|
||||
name=f"{BackgroundTaskState.get_name()}.background_task",
|
||||
router_data=router_data,
|
||||
payload={},
|
||||
),
|
||||
sid="",
|
||||
headers={},
|
||||
client_ip="",
|
||||
):
|
||||
# background task returns empty update immediately
|
||||
assert update == StateUpdate()
|
||||
assert len(mock_app.background_tasks) == 1
|
||||
|
||||
# wait for the coroutine to start
|
||||
await asyncio.sleep(0.5 if CI else 0.1)
|
||||
assert len(mock_app.background_tasks) == 1
|
||||
|
||||
# Process another normal event
|
||||
async for update in rx.app.process( # type: ignore
|
||||
mock_app,
|
||||
Event(
|
||||
token=token,
|
||||
name=f"{BackgroundTaskState.get_name()}.other",
|
||||
router_data=router_data,
|
||||
payload={},
|
||||
),
|
||||
sid="",
|
||||
headers={},
|
||||
client_ip="",
|
||||
):
|
||||
# other task returns delta
|
||||
assert update == StateUpdate(
|
||||
delta={
|
||||
BackgroundTaskState.get_name(): {
|
||||
"order": [
|
||||
"background_task:start",
|
||||
"other",
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Explicit wait for background tasks
|
||||
for task in tuple(mock_app.background_tasks):
|
||||
await task
|
||||
assert not mock_app.background_tasks
|
||||
|
||||
assert (await mock_app.state_manager.get_state(token)).order == [
|
||||
"background_task:start",
|
||||
"other",
|
||||
"background_task:stop",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_task_no_chain():
|
||||
"""Test that a background task cannot be chained."""
|
||||
bts = BackgroundTaskState()
|
||||
with pytest.raises(RuntimeError):
|
||||
await bts.bad_chain1()
|
||||
with pytest.raises(RuntimeError):
|
||||
await bts.bad_chain2()
|
||||
|
||||
|
||||
def test_mutable_list(mutable_state):
|
||||
"""Test that mutable lists are tracked correctly.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user