rx.background and StateManager.modify_state provide safe exclusive access to state (#1676)

Masen Furer 2023-09-21 11:42:11 -07:00 committed by GitHub
parent 211dc15995
commit 351611ca25
27 changed files with 2517 additions and 812 deletions
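
In short, this change adds an `@rx.background` decorator for async event handlers that run as asyncio tasks outside the usual per-event state lock, plus `StateManager.modify_state` / `App.modify_state` async context managers that take that lock for a given client token. A minimal usage sketch of the pattern (the state class and var names below are illustrative, not taken from this diff):

import asyncio

import reflex as rx


class CounterState(rx.State):
    counter: int = 0

    @rx.background
    async def slow_increment(self):
        # Long-running work happens outside the state lock, so other event
        # handlers for this client are not blocked while the task sleeps.
        await asyncio.sleep(1)
        async with self:
            # Entering the context re-acquires exclusive access to the state
            # (via StateManager.modify_state) before mutating it; writes
            # outside this block raise ImmutableStateError.
            self.counter += 1
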

View File

@ -15,7 +15,23 @@ permissions:
jobs:
integration-app-harness:
strategy:
matrix:
state_manager: [ "redis", "memory" ]
runs-on: ubuntu-latest
services:
# Label used to access the service container
redis:
image: ${{ matrix.state_manager == 'redis' && 'redis' || '' }}
# Set health checks to wait until redis has started
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
# Maps port 6379 on service container to the host
- 6379:6379
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup_build_env
@ -27,6 +43,7 @@ jobs:
- name: Run app harness tests
env:
SCREENSHOT_DIR: /tmp/screenshots
REDIS_URL: ${{ matrix.state_manager == 'redis' && 'localhost:6379' || '' }}
run: |
poetry run pytest integration
- uses: actions/upload-artifact@v3

View File

@ -40,6 +40,20 @@ jobs:
- os: windows-latest
python-version: "3.8.10"
runs-on: ${{ matrix.os }}
# Service containers to run with `runner-job`
services:
# Label used to access the service container
redis:
image: ${{ matrix.os == 'ubuntu-latest' && 'redis' || '' }}
# Set health checks to wait until redis has started
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
# Maps port 6379 on service container to the host
- 6379:6379
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/setup_build_env
@ -51,4 +65,10 @@ jobs:
run: |
export PYTHONUNBUFFERED=1
poetry run pytest tests --cov --no-cov-on-fail --cov-report=
- name: Run unit tests w/ redis
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
export PYTHONUNBUFFERED=1
export REDIS_URL=localhost:6379
poetry run pytest tests --cov --no-cov-on-fail --cov-report=
- run: poetry run coverage html

View File

@ -1,4 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
args: [integration, reflex, tests]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.244
hooks:
@ -17,9 +23,3 @@ repos:
hooks:
- id: darglint
exclude: '^reflex/reflex.py'
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
- id: black
args: [integration, reflex, tests]

View File

@ -0,0 +1,214 @@
"""Test @rx.background task functionality."""
from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
"""Test that background tasks work as expected."""
import asyncio
import reflex as rx
class State(rx.State):
counter: int = 0
_task_id: int = 0
iterations: int = 10
@rx.background
async def handle_event(self):
async with self:
self._task_id += 1
for _ix in range(int(self.iterations)):
async with self:
self.counter += 1
await asyncio.sleep(0.005)
@rx.background
async def handle_event_yield_only(self):
async with self:
self._task_id += 1
for ix in range(int(self.iterations)):
if ix % 2 == 0:
yield State.increment_arbitrary(1) # type: ignore
else:
yield State.increment() # type: ignore
await asyncio.sleep(0.005)
def increment(self):
self.counter += 1
@rx.background
async def increment_arbitrary(self, amount: int):
async with self:
self.counter += int(amount)
def reset_counter(self):
self.counter = 0
async def blocking_pause(self):
await asyncio.sleep(0.02)
@rx.background
async def non_blocking_pause(self):
await asyncio.sleep(0.02)
@rx.cached_var
def token(self) -> str:
return self.get_token()
def index() -> rx.Component:
return rx.vstack(
rx.input(id="token", value=State.token, is_read_only=True),
rx.heading(State.counter, id="counter"),
rx.input(
id="iterations",
placeholder="Iterations",
value=State.iterations.to_string(), # type: ignore
on_change=State.set_iterations, # type: ignore
),
rx.button(
"Delayed Increment",
on_click=State.handle_event,
id="delayed-increment",
),
rx.button(
"Yield Increment",
on_click=State.handle_event_yield_only,
id="yield-increment",
),
rx.button("Increment 1", on_click=State.increment, id="increment"),
rx.button(
"Blocking Pause",
on_click=State.blocking_pause,
id="blocking-pause",
),
rx.button(
"Non-Blocking Pause",
on_click=State.non_blocking_pause,
id="non-blocking-pause",
),
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)
app = rx.App(state=State)
app.add_page(index)
app.compile()
@pytest.fixture(scope="session")
def background_task(
tmp_path_factory,
) -> Generator[AppHarness, None, None]:
"""Start BackgroundTask app at tmp_path via AppHarness.
Args:
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with AppHarness.create(
root=tmp_path_factory.mktemp(f"background_task"),
app_source=BackgroundTask, # type: ignore
) as harness:
yield harness
@pytest.fixture
def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the background_task app.
Args:
background_task: harness for BackgroundTask app
Yields:
WebDriver instance.
"""
assert background_task.app_instance is not None, "app is not running"
driver = background_task.frontend()
try:
yield driver
finally:
driver.quit()
@pytest.fixture()
def token(background_task: AppHarness, driver: WebDriver) -> str:
"""Get a function that returns the active token.
Args:
background_task: harness for BackgroundTask app.
driver: WebDriver instance.
Returns:
The token for the connected client
"""
assert background_task.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
assert token is not None
return token
def test_background_task(
background_task: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that background tasks work as expected.
Args:
background_task: harness for BackgroundTask app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert background_task.app_instance is not None
# get a reference to all buttons
delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
yield_increment_button = driver.find_element(By.ID, "yield-increment")
increment_button = driver.find_element(By.ID, "increment")
blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
driver.find_element(By.ID, "reset")
# get a reference to the counter
counter = driver.find_element(By.ID, "counter")
# get a reference to the iterations input
iterations_input = driver.find_element(By.ID, "iterations")
# kick off background tasks
iterations_input.clear()
iterations_input.send_keys("50")
delayed_increment_button.click()
blocking_pause_button.click()
delayed_increment_button.click()
for _ in range(10):
increment_button.click()
blocking_pause_button.click()
delayed_increment_button.click()
delayed_increment_button.click()
yield_increment_button.click()
non_blocking_pause_button.click()
yield_increment_button.click()
blocking_pause_button.click()
yield_increment_button.click()
for _ in range(10):
increment_button.click()
yield_increment_button.click()
blocking_pause_button.click()
assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
# all tasks should have exited and cleaned up
assert background_task._poll_for(
lambda: not background_task.app_instance.background_tasks # type: ignore
)

View File

@ -133,7 +133,6 @@ def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]:
assert client_side.app_instance is not None, "app is not running"
driver = client_side.frontend()
try:
assert client_side.poll_for_clients()
yield driver
finally:
driver.quit()
@ -168,7 +167,20 @@ def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
driver.delete_all_cookies()
def test_client_side_state(
def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
"""Get a map of cookie names to cookie info.
Args:
driver: WebDriver instance.
Returns:
A map of cookie names to cookie info.
"""
return {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
@pytest.mark.asyncio
async def test_client_side_state(
client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
):
"""Test client side state.
@ -187,8 +199,6 @@ def test_client_side_state(
token = client_side.poll_for_value(token_input)
assert token is not None
backend_state = client_side.app_instance.state_manager.states[token]
# get a reference to the cookie manipulation form
state_var_input = driver.find_element(By.ID, "state_var")
input_value_input = driver.find_element(By.ID, "input_value")
@ -274,7 +284,7 @@ def test_client_side_state(
input_value_input.send_keys("l1s value")
set_sub_sub_state_button.click()
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
cookies = cookie_info_map(driver)
assert cookies.pop("client_side_state.client_side_sub_state.c1") == {
"domain": "localhost",
"httpOnly": False,
@ -338,8 +348,10 @@ def test_client_side_state(
state_var_input.send_keys("c3")
input_value_input.send_keys("c3 value")
set_sub_state_button.click()
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
c3_cookie = cookies["client_side_state.client_side_sub_state.c3"]
AppHarness._poll_for(
lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
)
c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
assert c3_cookie.pop("expiry") is not None
assert c3_cookie == {
"domain": "localhost",
@ -351,9 +363,7 @@ def test_client_side_state(
"value": "c3%20value",
}
time.sleep(2) # wait for c3 to expire
assert "client_side_state.client_side_sub_state.c3" not in {
cookie_info["name"] for cookie_info in driver.get_cookies()
}
assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
local_storage_items = local_storage.items()
local_storage_items.pop("chakra-ui-color-mode", None)
@ -426,7 +436,8 @@ def test_client_side_state(
assert l1s.text == "l1s value"
# reset the backend state to force refresh from client storage
backend_state.reset()
async with client_side.modify_state(token) as state:
state.reset()
driver.refresh()
# wait for the backend connection to send the token (again)
@ -465,9 +476,7 @@ def test_client_side_state(
assert l1s.text == "l1s value"
# make sure c5 cookie shows up on the `/foo` route
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
assert cookies["client_side_state.client_side_sub_state.c5"] == {
assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
"domain": "localhost",
"httpOnly": False,
"name": "client_side_state.client_side_sub_state.c5",

View File

@ -1,11 +1,10 @@
"""Integration tests for dynamic route page behavior."""
from typing import Callable, Generator, Type
from typing import Callable, Coroutine, Generator, Type
from urllib.parse import urlsplit
import pytest
from selenium.webdriver.common.by import By
from reflex import State
from reflex.testing import AppHarness, AppHarnessProd, WebDriver
from .utils import poll_for_navigation
@ -100,22 +99,21 @@ def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]:
assert dynamic_route.app_instance is not None, "app is not running"
driver = dynamic_route.frontend()
try:
assert dynamic_route.poll_for_clients()
yield driver
finally:
driver.quit()
@pytest.fixture()
def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
"""Get the backend state.
def token(dynamic_route: AppHarness, driver: WebDriver) -> str:
"""Get the token associated with backend state.
Args:
dynamic_route: harness for DynamicRoute app.
driver: WebDriver instance.
Returns:
The backend state associated with the token visible in the driver browser.
The token visible in the driver browser.
"""
assert dynamic_route.app_instance is not None
token_input = driver.find_element(By.ID, "token")
@ -125,43 +123,49 @@ def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
token = dynamic_route.poll_for_value(token_input)
assert token is not None
# look up the backend state from the state manager
return dynamic_route.app_instance.state_manager.states[token]
return token
@pytest.fixture()
def poll_for_order(
dynamic_route: AppHarness, backend_state: State
) -> Callable[[list[str]], None]:
dynamic_route: AppHarness, token: str
) -> Callable[[list[str]], Coroutine[None, None, None]]:
"""Poll for the order list to match the expected order.
Args:
dynamic_route: harness for DynamicRoute app.
backend_state: The backend state associated with the token visible in the driver browser.
token: The token visible in the driver browser.
Returns:
A function that polls for the order list to match the expected order.
An async function that polls for the order list to match the expected order.
"""
def _poll_for_order(exp_order: list[str]):
dynamic_route._poll_for(lambda: backend_state.order == exp_order)
assert backend_state.order == exp_order
async def _poll_for_order(exp_order: list[str]):
async def _backend_state():
return await dynamic_route.get_state(token)
async def _check():
return (await _backend_state()).order == exp_order
await AppHarness._poll_for_async(_check)
assert (await _backend_state()).order == exp_order
return _poll_for_order
def test_on_load_navigate(
@pytest.mark.asyncio
async def test_on_load_navigate(
dynamic_route: AppHarness,
driver: WebDriver,
backend_state: State,
poll_for_order: Callable[[list[str]], None],
token: str,
poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
):
"""Click links to navigate between dynamic pages with on_load event.
Args:
dynamic_route: harness for DynamicRoute app.
driver: WebDriver instance.
backend_state: The backend state associated with the token visible in the driver browser.
token: The token visible in the driver browser.
poll_for_order: function that polls for the order list to match the expected order.
"""
assert dynamic_route.app_instance is not None
@ -184,7 +188,7 @@ def test_on_load_navigate(
assert page_id_input
assert dynamic_route.poll_for_value(page_id_input) == str(ix)
poll_for_order(exp_order)
await poll_for_order(exp_order)
# manually load the next page to trigger client side routing in prod mode
if is_prod:
@ -192,14 +196,14 @@ def test_on_load_navigate(
exp_order += ["/page/[page_id]-10"]
with poll_for_navigation(driver):
driver.get(f"{dynamic_route.frontend_url}/page/10/")
poll_for_order(exp_order)
await poll_for_order(exp_order)
# make sure internal nav still hydrates after redirect
exp_order += ["/page/[page_id]-11"]
link = driver.find_element(By.ID, "link_page_next")
with poll_for_navigation(driver):
link.click()
poll_for_order(exp_order)
await poll_for_order(exp_order)
# load same page with a query param and make sure it passes through
if is_prod:
@ -207,14 +211,14 @@ def test_on_load_navigate(
exp_order += ["/page/[page_id]-11"]
with poll_for_navigation(driver):
driver.get(f"{driver.current_url}?foo=bar")
poll_for_order(exp_order)
assert backend_state.get_query_params()["foo"] == "bar"
await poll_for_order(exp_order)
assert (await dynamic_route.get_state(token)).get_query_params()["foo"] == "bar"
# hit a 404 and ensure we still hydrate
exp_order += ["/404-no page id"]
with poll_for_navigation(driver):
driver.get(f"{dynamic_route.frontend_url}/missing")
poll_for_order(exp_order)
await poll_for_order(exp_order)
# browser nav should still trigger hydration
if is_prod:
@ -222,14 +226,14 @@ def test_on_load_navigate(
exp_order += ["/page/[page_id]-11"]
with poll_for_navigation(driver):
driver.back()
poll_for_order(exp_order)
await poll_for_order(exp_order)
# next/link to a 404 and ensure we still hydrate
exp_order += ["/404-no page id"]
link = driver.find_element(By.ID, "link_missing")
with poll_for_navigation(driver):
link.click()
poll_for_order(exp_order)
await poll_for_order(exp_order)
# hit a page that redirects back to dynamic page
if is_prod:
@ -237,15 +241,16 @@ def test_on_load_navigate(
exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"]
with poll_for_navigation(driver):
driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
poll_for_order(exp_order)
await poll_for_order(exp_order)
# should have redirected back to page 0
assert urlsplit(driver.current_url).path == "/page/0/"
def test_on_load_navigate_non_dynamic(
@pytest.mark.asyncio
async def test_on_load_navigate_non_dynamic(
dynamic_route: AppHarness,
driver: WebDriver,
poll_for_order: Callable[[list[str]], None],
poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
):
"""Click links to navigate between static pages with on_load event.
@ -261,7 +266,7 @@ def test_on_load_navigate_non_dynamic(
with poll_for_navigation(driver):
link.click()
assert urlsplit(driver.current_url).path == "/static/x/"
poll_for_order(["/static/x-no page id"])
await poll_for_order(["/static/x-no page id"])
# go back to the index and navigate back to the static route
link = driver.find_element(By.ID, "link_index")
@ -273,4 +278,4 @@ def test_on_load_navigate_non_dynamic(
with poll_for_navigation(driver):
link.click()
assert urlsplit(driver.current_url).path == "/static/x/"
poll_for_order(["/static/x-no page id", "/static/x-no page id"])
await poll_for_order(["/static/x-no page id", "/static/x-no page id"])

View File

@ -1,18 +1,20 @@
"""Ensure that Event Chains are properly queued and handled between frontend and backend."""
import time
from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.testing import AppHarness
from reflex.testing import AppHarness, WebDriver
MANY_EVENTS = 50
def EventChain():
"""App with chained event handlers."""
import asyncio
import time
import reflex as rx
# repeated here since the outer global isn't exported into the App module
@ -20,6 +22,7 @@ def EventChain():
class State(rx.State):
event_order: list[str] = []
interim_value: str = ""
@rx.var
def token(self) -> str:
@ -111,12 +114,25 @@ def EventChain():
self.event_order.append("click_return_dict_type")
return State.event_arg_repr_type({"a": 1}) # type: ignore
async def click_yield_interim_value_async(self):
self.interim_value = "interim"
yield
await asyncio.sleep(0.5)
self.interim_value = "final"
def click_yield_interim_value(self):
self.interim_value = "interim"
yield
time.sleep(0.5)
self.interim_value = "final"
app = rx.App(state=State)
@app.add_page
def index():
return rx.fragment(
rx.input(value=State.token, readonly=True, id="token"),
rx.input(value=State.interim_value, readonly=True, id="interim_value"),
rx.button(
"Return Event",
id="return_event",
@ -172,6 +188,16 @@ def EventChain():
id="return_dict_type",
on_click=State.click_return_dict_type,
),
rx.button(
"Click Yield Interim Value (Async)",
id="click_yield_interim_value_async",
on_click=State.click_yield_interim_value_async,
),
rx.button(
"Click Yield Interim Value",
id="click_yield_interim_value",
on_click=State.click_yield_interim_value,
),
)
def on_load_return_chain():
@ -237,7 +263,7 @@ def event_chain(tmp_path_factory) -> Generator[AppHarness, None, None]:
@pytest.fixture
def driver(event_chain: AppHarness):
def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the event_chain app.
Args:
@ -249,7 +275,6 @@ def driver(event_chain: AppHarness):
assert event_chain.app_instance is not None, "app is not running"
driver = event_chain.frontend()
try:
assert event_chain.poll_for_clients()
yield driver
finally:
driver.quit()
@ -335,7 +360,13 @@ def driver(event_chain: AppHarness):
),
],
)
def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
@pytest.mark.asyncio
async def test_event_chain_click(
event_chain: AppHarness,
driver: WebDriver,
button_id: str,
exp_event_order: list[str],
):
"""Click the button, assert that the events are handled in the correct order.
Args:
@ -350,17 +381,18 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
assert btn
token = event_chain.poll_for_value(token_input)
assert token is not None
btn.click()
if "redirect" in button_id:
# wait a bit longer if we're redirecting
time.sleep(1)
if "many_events" in button_id:
# wait a bit longer if we have loads of events
time.sleep(1)
time.sleep(0.5)
backend_state = event_chain.app_instance.state_manager.states[token]
assert backend_state.event_order == exp_event_order
async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len(
exp_event_order
)
await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).event_order
assert event_order == exp_event_order
@pytest.mark.parametrize(
@ -386,7 +418,13 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
),
],
)
def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
@pytest.mark.asyncio
async def test_event_chain_on_load(
event_chain: AppHarness,
driver: WebDriver,
uri: str,
exp_event_order: list[str],
):
"""Load the URI, assert that the events are handled in the correct order.
Args:
@ -395,16 +433,23 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
uri: the page to load
exp_event_order: the expected events recorded in the State
"""
assert event_chain.frontend_url is not None
driver.get(event_chain.frontend_url + uri)
token_input = driver.find_element(By.ID, "token")
assert token_input
token = event_chain.poll_for_value(token_input)
assert token is not None
time.sleep(0.5)
backend_state = event_chain.app_instance.state_manager.states[token]
assert backend_state.is_hydrated is True
async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len(
exp_event_order
)
await AppHarness._poll_for_async(_has_all_events)
backend_state = await event_chain.get_state(token)
assert backend_state.event_order == exp_event_order
assert backend_state.is_hydrated is True
@pytest.mark.parametrize(
@ -444,7 +489,13 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
),
],
)
def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
@pytest.mark.asyncio
async def test_event_chain_on_mount(
event_chain: AppHarness,
driver: WebDriver,
uri: str,
exp_event_order: list[str],
):
"""Load the URI, assert that the events are handled in the correct order.
These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
@ -458,16 +509,53 @@ def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
uri: the page to load
exp_event_order: the expected events recorded in the State
"""
assert event_chain.frontend_url is not None
driver.get(event_chain.frontend_url + uri)
token_input = driver.find_element(By.ID, "token")
assert token_input
token = event_chain.poll_for_value(token_input)
assert token is not None
unmount_button = driver.find_element(By.ID, "unmount")
assert unmount_button
unmount_button.click()
time.sleep(1)
backend_state = event_chain.app_instance.state_manager.states[token]
assert backend_state.event_order == exp_event_order
async def _has_all_events():
return len((await event_chain.get_state(token)).event_order) == len(
exp_event_order
)
await AppHarness._poll_for_async(_has_all_events)
event_order = (await event_chain.get_state(token)).event_order
assert event_order == exp_event_order
@pytest.mark.parametrize(
("button_id",),
[
("click_yield_interim_value_async",),
("click_yield_interim_value",),
],
)
def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str):
"""Click the button, assert that the interim value is set, then final value is set.
Args:
event_chain: AppHarness for the event_chain app
driver: selenium WebDriver open to the app
button_id: the ID of the button to click
"""
token_input = driver.find_element(By.ID, "token")
interim_value_input = driver.find_element(By.ID, "interim_value")
assert event_chain.poll_for_value(token_input)
btn = driver.find_element(By.ID, button_id)
btn.click()
assert (
event_chain.poll_for_value(interim_value_input, exp_not_equal="") == "interim"
)
assert (
event_chain.poll_for_value(interim_value_input, exp_not_equal="interim")
== "final"
)

View File

@ -19,11 +19,16 @@ def FormSubmit():
def form_submit(self, form_data: dict):
self.form_data = form_data
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=FormState)
@app.add_page
def index():
return rx.vstack(
rx.input(value=FormState.token, is_read_only=True, id="token"),
rx.form(
rx.vstack(
rx.input(id="name_input"),
@ -82,13 +87,13 @@ def driver(form_submit: AppHarness):
"""
driver = form_submit.frontend()
try:
assert form_submit.poll_for_clients()
yield driver
finally:
driver.quit()
def test_submit(driver, form_submit: AppHarness):
@pytest.mark.asyncio
async def test_submit(driver, form_submit: AppHarness):
"""Fill a form with various different output, submit it to backend and verify
the output.
@ -97,7 +102,14 @@ def test_submit(driver, form_submit: AppHarness):
form_submit: harness for FormSubmit app
"""
assert form_submit.app_instance is not None, "app is not running"
_, backend_state = list(form_submit.app_instance.state_manager.states.items())[0]
# get a reference to the connected client
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = form_submit.poll_for_value(token_input)
assert token
name_input = driver.find_element(By.ID, "name_input")
name_input.send_keys("foo")
@ -132,19 +144,21 @@ def test_submit(driver, form_submit: AppHarness):
submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
submit_input.click()
# wait for the form data to arrive at the backend
AppHarness._poll_for(
lambda: backend_state.form_data != {},
)
async def get_form_data():
return (await form_submit.get_state(token)).form_data
assert backend_state.form_data["name_input"] == "foo"
assert backend_state.form_data["pin_input"] == pin_values
assert backend_state.form_data["number_input"] == "-3"
assert backend_state.form_data["bool_input"] is True
assert backend_state.form_data["bool_input2"] is True
assert backend_state.form_data["slider_input"] == "50"
assert backend_state.form_data["range_input"] == ["25", "75"]
assert backend_state.form_data["radio_input"] == "option2"
assert backend_state.form_data["select_input"] == "option1"
assert backend_state.form_data["text_area_input"] == "Some\nText"
assert backend_state.form_data["debounce_input"] == "bar baz"
# wait for the form data to arrive at the backend
form_data = await AppHarness._poll_for_async(get_form_data)
assert isinstance(form_data, dict)
assert form_data["name_input"] == "foo"
assert form_data["pin_input"] == pin_values
assert form_data["number_input"] == "-3"
assert form_data["bool_input"] is True
assert form_data["bool_input2"] is True
assert form_data["slider_input"] == "50"
assert form_data["range_input"] == ["25", "75"]
assert form_data["radio_input"] == "option2"
assert form_data["select_input"] == "option1"
assert form_data["text_area_input"] == "Some\nText"
assert form_data["debounce_input"] == "bar baz"

View File

@ -16,11 +16,16 @@ def FullyControlledInput():
class State(rx.State):
text: str = "initial"
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=State)
@app.add_page
def index():
return rx.fragment(
rx.input(value=State.token, is_read_only=True, id="token"),
rx.input(
id="debounce_input_input",
on_change=State.set_text, # type: ignore
@ -62,10 +67,12 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
driver = fully_controlled_input.frontend()
# get a reference to the connected client
assert len(fully_controlled_input.poll_for_clients()) == 1
token, backend_state = list(
fully_controlled_input.app_instance.state_manager.states.items()
)[0]
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = fully_controlled_input.poll_for_value(token_input)
assert token
# find the input and wait for it to have the initial state value
debounce_input = driver.find_element(By.ID, "debounce_input_input")
@ -80,14 +87,13 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("foo")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "ifoonitial"
assert backend_state.text == "ifoonitial"
assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
# clear the input on the backend
backend_state.text = ""
fully_controlled_input.app_instance.state_manager.set_state(token, backend_state)
await fully_controlled_input.emit_state_updates()
assert backend_state.text == ""
async with fully_controlled_input.modify_state(token) as state:
state.text = ""
assert (await fully_controlled_input.get_state(token)).text == ""
assert (
fully_controlled_input.poll_for_value(
debounce_input, exp_not_equal="ifoonitial"
@ -99,7 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
debounce_input.send_keys("getting testing done")
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "getting testing done"
assert backend_state.text == "getting testing done"
assert (
await fully_controlled_input.get_state(token)
).text == "getting testing done"
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
# type into the on_change input
@ -107,7 +115,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
time.sleep(0.5)
assert debounce_input.get_attribute("value") == "overwrite the state"
assert on_change_input.get_attribute("value") == "overwrite the state"
assert backend_state.text == "overwrite the state"
assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
clear_button.click()

View File

@ -33,11 +33,16 @@ def ServerSideEvent():
def set_value_return_c(self):
return rx.set_value("c", "")
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=SSState)
@app.add_page
def index():
return rx.fragment(
rx.input(id="token", value=SSState.token, is_read_only=True),
rx.input(default_value="a", id="a"),
rx.input(default_value="b", id="b"),
rx.input(default_value="c", id="c"),
@ -106,7 +111,12 @@ def driver(server_side_event: AppHarness):
assert server_side_event.app_instance is not None, "app is not running"
driver = server_side_event.frontend()
try:
assert server_side_event.poll_for_clients()
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = server_side_event.poll_for_value(token_input)
assert token is not None
yield driver
finally:
driver.quit()

View File

@ -89,13 +89,13 @@ def driver(upload_file: AppHarness):
assert upload_file.app_instance is not None, "app is not running"
driver = upload_file.frontend()
try:
assert upload_file.poll_for_clients()
yield driver
finally:
driver.quit()
def test_upload_file(tmp_path, upload_file: AppHarness, driver):
@pytest.mark.asyncio
async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
"""Submit a file upload and check that it arrived on the backend.
Args:
@ -124,16 +124,20 @@ def test_upload_file(tmp_path, upload_file: AppHarness, driver):
upload_button.click()
# look up the backend state and assert on uploaded contents
backend_state = upload_file.app_instance.state_manager.states[token]
time.sleep(0.5)
assert backend_state._file_data[exp_name] == exp_contents
async def get_file_data():
return (await upload_file.get_state(token))._file_data
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
assert file_data[exp_name] == exp_contents
# check that the selected files are displayed
selected_files = driver.find_element(By.ID, "selected_files")
assert selected_files.text == exp_name
def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
@pytest.mark.asyncio
async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
"""Submit several file uploads and check that they arrived on the backend.
Args:
@ -173,10 +177,13 @@ def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
upload_button.click()
# look up the backend state and assert on uploaded contents
backend_state = upload_file.app_instance.state_manager.states[token]
time.sleep(0.5)
async def get_file_data():
return (await upload_file.get_state(token))._file_data
file_data = await AppHarness._poll_for_async(get_file_data)
assert isinstance(file_data, dict)
for exp_name, exp_contents in exp_files.items():
assert backend_state._file_data[exp_name] == exp_contents
assert file_data[exp_name] == exp_contents
def test_clear_files(tmp_path, upload_file: AppHarness, driver):

View File

@ -26,11 +26,16 @@ def VarOperations():
dict1: dict = {1: 2}
dict2: dict = {3: 4}
@rx.var
def token(self) -> str:
return self.get_token()
app = rx.App(state=VarOperationState)
@app.add_page
def index():
return rx.vstack(
rx.input(id="token", value=VarOperationState.token, is_read_only=True),
# INT INT
rx.text(
VarOperationState.int_var1 + VarOperationState.int_var2,
@ -544,7 +549,12 @@ def driver(var_operations: AppHarness):
"""
driver = var_operations.frontend()
try:
assert var_operations.poll_for_clients()
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = var_operations.poll_for_value(token_input)
assert token is not None
yield driver
finally:
driver.quit()

View File

@ -21,6 +21,7 @@ from .constants import Env as Env
from .event import EVENT_ARG as EVENT_ARG
from .event import EventChain as EventChain
from .event import FileUpload as upload_files
from .event import background as background
from .event import clear_local_storage as clear_local_storage
from .event import console_log as console_log
from .event import download as download

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import contextlib
import inspect
import os
from multiprocessing.pool import ThreadPool
@ -13,6 +14,7 @@ from typing import (
Dict,
List,
Optional,
Set,
Type,
Union,
)
@ -49,7 +51,13 @@ from reflex.route import (
get_route_args,
verify_route_validity,
)
from reflex.state import DefaultState, State, StateManager, StateUpdate
from reflex.state import (
DefaultState,
State,
StateManager,
StateManagerMemory,
StateUpdate,
)
from reflex.utils import console, format, prerequisites, types
from reflex.vars import ImportVar
@ -89,7 +97,7 @@ class App(Base):
state: Type[State] = DefaultState
# Class to manage many client states.
state_manager: StateManager = StateManager()
state_manager: StateManager = StateManagerMemory(state=DefaultState)
# The styling to apply to each component.
style: ComponentStyle = {}
@ -104,13 +112,16 @@ class App(Base):
admin_dash: Optional[AdminDash] = None
# The async server name space
event_namespace: Optional[AsyncNamespace] = None
event_namespace: Optional[EventNamespace] = None
# A component that is present on every page.
overlay_component: Optional[
Union[Component, ComponentCallable]
] = default_overlay_component
# Background tasks that are currently running
background_tasks: Set[asyncio.Task] = set()
def __init__(self, *args, **kwargs):
"""Initialize the app.
@ -154,7 +165,7 @@ class App(Base):
self.middleware.append(HydrateMiddleware())
# Set up the state manager.
self.state_manager.setup(state=self.state)
self.state_manager = StateManager.create(state=self.state)
# Set up the API.
self.api = FastAPI()
@ -646,6 +657,76 @@ class App(Base):
thread_pool.close()
thread_pool.join()
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
"""Modify the state out of band.
Args:
token: The token to modify the state for.
Yields:
The state to modify.
Raises:
RuntimeError: If the app has not been initialized yet.
"""
if self.event_namespace is None:
raise RuntimeError("App has not been initialized yet.")
# Get exclusive access to the state.
async with self.state_manager.modify_state(token) as state:
# No other event handler can modify the state while in this context.
yield state
delta = state.get_delta()
if delta:
# When the state is modified, reset the dirty status and emit the delta to the frontend.
state._clean()
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
sid=state.get_sid(),
)
def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
"""Process an event in the background and emit updates as they arrive.
Args:
state: The state to process the event for.
event: The event to process.
Returns:
Task if the event was backgroundable, otherwise None
"""
substate, handler = state._get_event_handler(event)
if not handler.is_background:
return None
async def _coro():
"""Coroutine to process the event and emit updates inside an asyncio.Task.
Raises:
RuntimeError: If the app has not been initialized yet.
"""
if self.event_namespace is None:
raise RuntimeError("App has not been initialized yet.")
# Process the event.
async for update in state._process_event(
handler=handler, state=substate, payload=event.payload
):
# Postprocess the event.
update = await self.postprocess(state, event, update)
# Send the update to the client.
await self.event_namespace.emit_update(
update=update,
sid=state.get_sid(),
)
task = asyncio.create_task(_coro())
self.background_tasks.add(task)
# Clean up task from background_tasks set when complete.
task.add_done_callback(self.background_tasks.discard)
return task
async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str
@ -662,9 +743,6 @@ async def process(
Yields:
The state updates after processing the event.
"""
# Get the state for the session.
state = app.state_manager.get_state(event.token)
# Add request data to the state.
router_data = event.router_data
router_data.update(
@ -676,31 +754,35 @@ async def process(
constants.RouteVar.CLIENT_IP: client_ip,
}
)
# re-assign only when the value is different
if state.router_data != router_data:
# assignment will recurse into substates and force recalculation of
# dependent ComputedVar (dynamic route variables)
state.router_data = router_data
# Get the state for the session exclusively.
async with app.state_manager.modify_state(event.token) as state:
# re-assign only when the value is different
if state.router_data != router_data:
# assignment will recurse into substates and force recalculation of
# dependent ComputedVar (dynamic route variables)
state.router_data = router_data
# Preprocess the event.
update = await app.preprocess(state, event)
# Preprocess the event.
update = await app.preprocess(state, event)
# If there was an update, yield it.
if update is not None:
yield update
# Only process the event if there is no update.
else:
# Process the event.
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
# Yield the update.
# If there was an update, yield it.
if update is not None:
yield update
# Set the state for the session.
app.state_manager.set_state(event.token, state)
# Only process the event if there is no update.
else:
if app._process_background(state, event) is not None:
# `final=True` allows the frontend to send more events immediately.
yield StateUpdate(final=True)
return
# Process the event synchronously.
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
# Yield the update.
yield update
async def ping() -> str:
@ -737,47 +819,46 @@ def upload(app: App):
assert file.filename is not None
file.filename = file.filename.split(":")[-1]
# Get the state for the session.
state = app.state_manager.get_state(token)
# get the current session ID
sid = state.get_sid()
# get the current state(parent state/substate)
path = handler.split(".")[:-1]
current_state = state.get_substate(path)
handler_upload_param = ()
async with app.state_manager.modify_state(token) as state:
# get the current session ID
sid = state.get_sid()
# get the current state(parent state/substate)
path = handler.split(".")[:-1]
current_state = state.get_substate(path)
handler_upload_param = ()
# get handler function
func = getattr(current_state, handler.split(".")[-1])
# get handler function
func = getattr(current_state, handler.split(".")[-1])
# check if there exists any handler args with annotation, List[UploadFile]
for k, v in inspect.getfullargspec(
func.fn if isinstance(func, EventHandler) else func
).annotations.items():
if types.is_generic_alias(v) and types._issubclass(
v.__args__[0], UploadFile
):
handler_upload_param = (k, v)
break
# check if there exists any handler args with annotation, List[UploadFile]
for k, v in inspect.getfullargspec(
func.fn if isinstance(func, EventHandler) else func
).annotations.items():
if types.is_generic_alias(v) and types._issubclass(
v.__args__[0], UploadFile
):
handler_upload_param = (k, v)
break
if not handler_upload_param:
raise ValueError(
f"`{handler}` handler should have a parameter annotated as List["
f"rx.UploadFile]"
if not handler_upload_param:
raise ValueError(
f"`{handler}` handler should have a parameter annotated as List["
f"rx.UploadFile]"
)
event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: files},
)
event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: files},
)
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
# Send update to client
await asyncio.create_task(
app.event_namespace.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
)
# Set the state for the session.
app.state_manager.set_state(event.token, state)
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
# Send update to client
await app.event_namespace.emit_update( # type: ignore
update=update,
sid=sid,
)
return upload_file
@ -815,6 +896,18 @@ class EventNamespace(AsyncNamespace):
"""
pass
async def emit_update(self, update: StateUpdate, sid: str) -> None:
"""Emit an update to the client.
Args:
update: The state update to send.
sid: The Socket.IO session id.
"""
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
)
async def on_event(self, sid, data):
"""Event for receiving front-end websocket events.
@ -841,10 +934,8 @@ class EventNamespace(AsyncNamespace):
# Process the events.
async for update in process(self.app, event, sid, headers, client_ip):
# Emit the event.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
)
# Emit the update from processing the event.
await self.emit_update(update=update, sid=sid)
async def on_ping(self, sid):
"""Event for testing the API endpoint.

View File

@ -1,5 +1,6 @@
""" Generated with stubgen from mypy, then manually edited, do not regen."""
import asyncio
from fastapi import FastAPI
from fastapi import UploadFile as UploadFile
from reflex import constants as constants
@ -45,12 +46,14 @@ from reflex.utils import (
from socketio import ASGIApp, AsyncNamespace, AsyncServer
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Callable,
Coroutine,
Dict,
List,
Optional,
Set,
Type,
Union,
overload,
@ -75,6 +78,7 @@ class App(Base):
admin_dash: Optional[AdminDash]
event_namespace: Optional[AsyncNamespace]
overlay_component: Optional[Union[Component, ComponentCallable]]
background_tasks: Set[asyncio.Task] = set()
def __init__(
self,
*args,
@ -116,6 +120,10 @@ class App(Base):
def setup_admin_dash(self) -> None: ...
def get_frontend_packages(self, imports: Dict[str, str]): ...
def compile(self) -> None: ...
def modify_state(self, token: str) -> AsyncContextManager[State]: ...
def _process_background(
self, state: State, event: Event
) -> asyncio.Task | None: ...
async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str

View File

@ -219,6 +219,8 @@ OLD_CONFIG_FILE = f"pcconfig{PY_EXT}"
PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
# Token expiration time in seconds.
TOKEN_EXPIRATION = 60 * 60
# Maximum time in milliseconds that a state can be locked for exclusive access.
LOCK_EXPIRATION = 10000
# Testing variables.
# Testing os env set by pytest when running a test case.

View File

@ -2,7 +2,17 @@
from __future__ import annotations
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from reflex import constants
from reflex.base import Base
@ -10,6 +20,9 @@ from reflex.utils import console, format
from reflex.utils.types import ArgsSpec
from reflex.vars import BaseVar, Var
if TYPE_CHECKING:
from reflex.state import State
class Event(Base):
"""An event that describes any state change in the app."""
@ -27,6 +40,66 @@ class Event(Base):
payload: Dict[str, Any] = {}
BACKGROUND_TASK_MARKER = "_reflex_background_task"
def background(fn):
"""Decorator to mark event handler as running in the background.
Args:
fn: The function to decorate.
Returns:
The same function, but with a marker set.
Raises:
TypeError: If the function is not a coroutine function or async generator.
"""
if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn):
raise TypeError("Background task must be async function or generator.")
setattr(fn, BACKGROUND_TASK_MARKER, True)
return fn
def _no_chain_background_task(
state_cls: Type["State"], name: str, fn: Callable
) -> Callable:
"""Protect against directly chaining a background task from another event handler.
Args:
state_cls: The state class that the event handler is in.
name: The name of the background task.
fn: The background task coroutine function / generator.
Returns:
A compatible coroutine function / generator that raises a runtime error.
Raises:
TypeError: If the background task is not async.
"""
call = f"{state_cls.__name__}.{name}"
message = (
f"Cannot directly call background task {name!r}, use "
f"`yield {call}` or `return {call}` instead."
)
if inspect.iscoroutinefunction(fn):
async def _no_chain_background_task_co(*args, **kwargs):
raise RuntimeError(message)
return _no_chain_background_task_co
if inspect.isasyncgenfunction(fn):
async def _no_chain_background_task_gen(*args, **kwargs):
yield
raise RuntimeError(message)
return _no_chain_background_task_gen
raise TypeError(f"{fn} is marked as a background task, but is not async.")
class EventHandler(Base):
"""An event handler responds to an event to update the state."""
@ -39,6 +112,15 @@ class EventHandler(Base):
# Needed to allow serialization of Callable.
frozen = True
@property
def is_background(self) -> bool:
"""Whether the event handler is a background task.
Returns:
True if the event handler is marked as a background task.
"""
return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
def __call__(self, *args: Var) -> EventSpec:
"""Pass arguments to the handler to get an event spec.
@ -530,7 +612,7 @@ def get_handler_args(event_spec: EventSpec) -> tuple[tuple[Var, Var], ...]:
def fix_events(
events: list[EventHandler | EventSpec],
events: list[EventHandler | EventSpec] | None,
token: str,
router_data: dict[str, Any] | None = None,
) -> list[Event]:
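
To illustrate the guard rails added in this file (the handler names here are made up): `@rx.background` rejects non-async functions with a TypeError, and `_no_chain_background_task` replaces the handler on the state instance so that calling a background task directly from another handler raises, forcing it to be queued by class reference instead.

import reflex as rx


class GuardState(rx.State):
    @rx.background
    async def bg_task(self):
        async with self:
            pass

    def trigger(self):
        # Correct: queue the background task by class reference.
        return GuardState.bg_task

    async def bad_trigger(self):
        # Incorrect: awaiting self.bg_task() raises RuntimeError, because a
        # background task cannot be chained by direct call from another
        # handler (it would run without the state lock).
        await self.bg_task()
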

View File

@ -2,13 +2,15 @@
from __future__ import annotations
import asyncio
import contextlib
import copy
import functools
import inspect
import json
import traceback
import urllib.parse
from abc import ABC
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from types import FunctionType
from typing import (
@ -27,12 +29,20 @@ from typing import (
import cloudpickle
import pydantic
import wrapt
from redis import Redis
from redis.asyncio import Redis
from reflex import constants
from reflex.base import Base
from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
from reflex.event import (
Event,
EventHandler,
EventSpec,
_no_chain_background_task,
fix_events,
window_alert,
)
from reflex.utils import format, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.vars import BaseVar, ComputedVar, Var
Delta = Dict[str, Any]
@ -152,7 +162,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Convert the event handlers to functions.
for name, event_handler in state.event_handlers.items():
fn = functools.partial(event_handler.fn, self)
if event_handler.is_background:
fn = _no_chain_background_task(type(state), name, event_handler.fn)
else:
fn = functools.partial(event_handler.fn, self)
fn.__module__ = event_handler.fn.__module__ # type: ignore
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
setattr(self, name, fn)
@ -711,6 +724,37 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
raise ValueError(f"Invalid path: {path}")
return self.substates[path[0]].get_substate(path[1:])
def _get_event_handler(
self, event: Event
) -> tuple[State | StateProxy, EventHandler]:
"""Get the event handler for the given event.
Args:
event: The event to get the handler for.
Returns:
The event handler.
Raises:
ValueError: If the event handler or substate is not found.
"""
# Get the event handler.
path = event.name.split(".")
path, name = path[:-1], path[-1]
substate = self.get_substate(path)
if not substate:
raise ValueError(
"The value of state cannot be None when processing an event."
)
handler = substate.event_handlers[name]
# For background tasks, proxy the state
if handler.is_background:
substate = StateProxy(substate)
return substate, handler
async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
"""Obtain event info and process event.
@ -719,44 +763,17 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Yields:
The state update after processing the event.
Raises:
ValueError: If the state value is None.
"""
# Get the event handler.
path = event.name.split(".")
path, name = path[:-1], path[-1]
substate = self.get_substate(path)
handler = substate.event_handlers[name] # type: ignore
substate, handler = self._get_event_handler(event)
if not substate:
raise ValueError(
"The value of state cannot be None when processing an event."
)
# Get the event generator.
event_iter = self._process_event(
# Run the event generator and yield state updates.
async for update in self._process_event(
handler=handler,
state=substate,
payload=event.payload,
)
# Clean the state before processing the event.
self._clean()
# Run the event generator and return state updates.
async for events, final in event_iter:
# Fix the returned events.
events = fix_events(events, event.token) # type: ignore
# Get the delta after processing the event.
delta = self.get_delta()
# Yield the state update.
yield StateUpdate(delta=delta, events=events, final=final)
# Clean the state to prepare for the next event.
self._clean()
):
yield update
def _check_valid(self, handler: EventHandler, events: Any) -> Any:
"""Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
@ -787,9 +804,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
)
def _as_state_update(
self,
handler: EventHandler,
events: EventSpec | list[EventSpec] | None,
final: bool,
) -> StateUpdate:
"""Convert the events to a StateUpdate.
Fixes the events and checks for validity before converting.
Args:
handler: The handler where the events originated from.
events: The events to queue with the update.
final: Whether the handler is done processing.
Returns:
The valid StateUpdate containing the events and final flag.
"""
token = self.get_token()
# Convert valid EventHandler and EventSpec into Event
fixed_events = fix_events(self._check_valid(handler, events), token)
# Get the delta after processing the event.
delta = self.get_delta()
self._clean()
return StateUpdate(
delta=delta,
events=fixed_events,
final=final if not handler.is_background else True,
)
async def _process_event(
self, handler: EventHandler, state: State, payload: Dict
) -> AsyncIterator[tuple[list[EventSpec] | None, bool]]:
self, handler: EventHandler, state: State | StateProxy, payload: Dict
) -> AsyncIterator[StateUpdate]:
"""Process event.
Args:
@ -798,13 +848,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
payload: The event payload.
Yields:
Tuple containing:
0: The state update after processing the event.
1: Whether the event is the final event.
StateUpdate object
"""
# Get the function to process the event.
fn = functools.partial(handler.fn, state)
# Clean the state before processing the event.
self._clean()
# Wrap the function in a try/except block.
try:
# Handle async functions.
@ -817,30 +868,34 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Handle async generators.
if inspect.isasyncgen(events):
async for event in events:
yield self._check_valid(handler, event), False
yield None, True
yield self._as_state_update(handler, event, final=False)
yield self._as_state_update(handler, events=None, final=True)
# Handle regular generators.
elif inspect.isgenerator(events):
try:
while True:
yield self._check_valid(handler, next(events)), False
yield self._as_state_update(handler, next(events), final=False)
except StopIteration as si:
# the "return" value of the generator is not available
# in the loop, we must catch StopIteration to access it
if si.value is not None:
yield self._check_valid(handler, si.value), False
yield None, True
yield self._as_state_update(handler, si.value, final=False)
yield self._as_state_update(handler, events=None, final=True)
# Handle regular event chains.
else:
yield self._check_valid(handler, events), True
yield self._as_state_update(handler, events, final=True)
# If an error occurs, throw a window alert.
except Exception:
error = traceback.format_exc()
print(error)
yield [window_alert("An error occurred. See logs for details.")], True
yield self._as_state_update(
handler,
window_alert("An error occurred. See logs for details."),
final=True,
)
def _always_dirty_computed_vars(self) -> set[str]:
"""The set of ComputedVars that always need to be recalculated.
@ -989,6 +1044,160 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
variables = {**base_vars, **computed_vars, **substate_vars}
return {k: variables[k] for k in sorted(variables)}
async def __aenter__(self) -> State:
"""Enter the async context manager protocol.
This should not be used for the State class, but exists for
type-compatibility with StateProxy.
Raises:
TypeError: always, because async contextmanager protocol is only supported for background task.
"""
raise TypeError(
"Only background task should use `async with self` to modify state."
)
async def __aexit__(self, *exc_info: Any) -> None:
"""Exit the async context manager protocol.
This should not be used for the State class, but exists for
type-compatibility with StateProxy.
Args:
exc_info: The exception info tuple.
"""
pass
class StateProxy(wrapt.ObjectProxy):
"""Proxy of a state instance to control mutability of vars for a background task.
Since a background task runs against a state instance without holding the
state_manager lock for the token, the reference may become stale if the same
state is modified by another event handler.
The proxy object ensures that writes to the state are blocked unless
explicitly entering a context which refreshes the state from state_manager
and holds the lock for the token until exiting the context. After exiting
the context, a StateUpdate may be emitted to the frontend to notify the
client of the state change.
A background task will be passed the `StateProxy` as `self`, so state may
be safely mutated inside an `async with self` block.
class State(rx.State):
counter: int = 0
@rx.background
async def bg_increment(self):
await asyncio.sleep(1)
async with self:
self.counter += 1
"""
def __init__(self, state_instance):
"""Create a proxy for a state instance.
Args:
state_instance: The state instance to proxy.
"""
super().__init__(state_instance)
self._self_app = getattr(prerequisites.get_app(), constants.APP_VAR)
self._self_substate_path = state_instance.get_full_name().split(".")
self._self_actx = None
self._self_mutable = False
async def __aenter__(self) -> StateProxy:
"""Enter the async context manager protocol.
Sets mutability to True and enters the `App.modify_state` async context,
which refreshes the state from state_manager and holds the lock for the
given state token until exiting the context.
Background tasks should avoid blocking calls while inside the context.
Returns:
This StateProxy instance in mutable mode.
"""
self._self_actx = self._self_app.modify_state(self.__wrapped__.get_token())
mutable_state = await self._self_actx.__aenter__()
super().__setattr__(
"__wrapped__", mutable_state.get_substate(self._self_substate_path)
)
self._self_mutable = True
return self
async def __aexit__(self, *exc_info: Any) -> None:
"""Exit the async context manager protocol.
Sets proxy mutability to False and persists any state changes.
Args:
exc_info: The exception info tuple.
"""
if self._self_actx is None:
return
self._self_mutable = False
await self._self_actx.__aexit__(*exc_info)
self._self_actx = None
def __enter__(self):
"""Enter the regular context manager protocol.
This is not supported for background tasks, and exists only to raise a more useful exception
when the StateProxy is used incorrectly.
Raises:
TypeError: always, because only the async context manager protocol is supported.
"""
raise TypeError("Background task must use `async with self` to modify state.")
def __exit__(self, *exc_info: Any) -> None:
"""Exit the regular context manager protocol.
Args:
exc_info: The exception info tuple.
"""
pass
def __getattr__(self, name: str) -> Any:
"""Get the attribute from the underlying state instance.
Args:
name: The name of the attribute.
Returns:
The value of the attribute.
"""
value = super().__getattr__(name)
if not name.startswith("_self_") and isinstance(value, MutableProxy):
# ensure mutations to these containers are blocked unless proxy is _mutable
return ImmutableMutableProxy(
wrapped=value.__wrapped__,
state=self, # type: ignore
field_name=value._self_field_name,
)
return value
def __setattr__(self, name: str, value: Any) -> None:
"""Set the attribute on the underlying state instance.
If the attribute is internal, set it on the proxy instance instead.
Args:
name: The name of the attribute.
value: The value of the attribute.
Raises:
ImmutableStateError: If the state is not in mutable mode.
"""
if not name.startswith("_self_") and not self._self_mutable:
raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
)
super().__setattr__(name, value)
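# Illustrative sketch (not part of this diff): writes through the proxy are
# rejected until the lock-holding context is entered. The state name and the
# `counter` var are hypothetical.
import reflex as rx

class ProxyDemoState(rx.State):
    counter: int = 0

    @rx.background
    async def bg_increment(self):
        # `self` is a StateProxy here; assigning now would raise ImmutableStateError
        async with self:       # refresh from state_manager and hold the token lock
            self.counter += 1  # allowed while the proxy is mutable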
class DefaultState(State):
"""The default empty state."""
@ -1009,31 +1218,29 @@ class StateUpdate(Base):
final: bool = True
class StateManager(Base):
class StateManager(Base, ABC):
"""A class to manage many client states."""
# The state class to use.
state: Type[State] = DefaultState
state: Type[State]
# The mapping of client ids to states.
states: Dict[str, State] = {}
# The token expiration time (s).
token_expiration: int = constants.TOKEN_EXPIRATION
# The redis client to use.
redis: Optional[Redis] = None
def setup(self, state: Type[State]):
"""Set up the state manager.
@classmethod
def create(cls, state: Type[State] = DefaultState):
"""Create a new state manager.
Args:
state: The state class to use.
"""
self.state = state
self.redis = prerequisites.get_redis()
def get_state(self, token: str) -> State:
Returns:
The state manager (either memory or redis).
"""
redis = prerequisites.get_redis()
if redis is not None:
return StateManagerRedis(state=state, redis=redis)
return StateManagerMemory(state=state)
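# Illustrative sketch (not part of this diff): callers go through the factory
# rather than instantiating a concrete manager; the backend is chosen by
# whether a redis connection is configured (e.g. via REDIS_URL). MyState is a
# hypothetical state class.
import reflex as rx

class MyState(rx.State):
    pass

state_manager = StateManager.create(state=MyState)
# redis configured -> StateManagerRedis(state=MyState, redis=<client>)
# otherwise        -> StateManagerMemory(state=MyState)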
@abstractmethod
async def get_state(self, token: str) -> State:
"""Get the state for a token.
Args:
@ -1042,27 +1249,266 @@ class StateManager(Base):
Returns:
The state for the token.
"""
if self.redis is not None:
redis_state = self.redis.get(token)
if redis_state is None:
self.set_state(token, self.state())
return self.get_state(token)
return cloudpickle.loads(redis_state)
pass
if token not in self.states:
self.states[token] = self.state()
return self.states[token]
def set_state(self, token: str, state: State):
@abstractmethod
async def set_state(self, token: str, state: State):
"""Set the state for a token.
Args:
token: The token to set the state for.
state: The state to set.
"""
if self.redis is None:
return
self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
pass
@abstractmethod
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
"""Modify the state for a token while holding exclusive lock.
Args:
token: The token to modify the state for.
Yields:
The state for the token.
"""
yield self.state()
class StateManagerMemory(StateManager):
"""A state manager that stores states in memory."""
# The mapping of client ids to states.
states: Dict[str, State] = {}
# The mutex ensures the dict of mutexes is updated exclusively
_state_manager_lock = asyncio.Lock()
# The dict of mutexes for each client
_states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
class Config:
"""The Pydantic config."""
fields = {
"_states_locks": {"exclude": True},
}
async def get_state(self, token: str) -> State:
"""Get the state for a token.
Args:
token: The token to get the state for.
Returns:
The state for the token.
"""
if token not in self.states:
self.states[token] = self.state()
return self.states[token]
async def set_state(self, token: str, state: State):
"""Set the state for a token.
Args:
token: The token to set the state for.
state: The state to set.
"""
pass
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
"""Modify the state for a token while holding exclusive lock.
Args:
token: The token to modify the state for.
Yields:
The state for the token.
"""
if token not in self._states_locks:
async with self._state_manager_lock:
if token not in self._states_locks:
self._states_locks[token] = asyncio.Lock()
async with self._states_locks[token]:
state = await self.get_state(token)
yield state
await self.set_state(token, state)
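# Illustrative sketch (not part of this diff): modify_state serializes access
# per token, so concurrent writers never interleave inside the critical
# section. The `counter` var and token value are hypothetical.
import asyncio

async def bump(manager: StateManagerMemory, token: str) -> None:
    async with manager.modify_state(token) as state:
        state.counter += 1  # assumes the state class defines `counter`

# await asyncio.gather(bump(manager, "tok"), bump(manager, "tok"))
# runs the two critical sections strictly one after the other.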
class StateManagerRedis(StateManager):
"""A state manager that stores states in redis."""
# The redis client to use.
redis: Redis
# The token expiration time (s).
token_expiration: int = constants.TOKEN_EXPIRATION
# The maximum time to hold a lock (ms).
lock_expiration: int = constants.LOCK_EXPIRATION
# The keyspace subscription string used while redis is waiting for the lock to be released
_redis_notify_keyspace_events: str = (
"K" # Enable keyspace notifications (target a particular key)
"g" # For generic commands (DEL, EXPIRE, etc)
"x" # For expired events
"e" # For evicted events (i.e. maxmemory exceeded)
)
# These events indicate that a lock is no longer held
_redis_keyspace_lock_release_events: Set[bytes] = {
b"del",
b"expire",
b"expired",
b"evicted",
}
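# For reference (not part of this diff): the flags above amount to running
# `CONFIG SET notify-keyspace-events Kgxe` against the redis server, which
# _wait_lock below applies programmatically via redis.config_set().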
async def get_state(self, token: str) -> State:
"""Get the state for a token.
Args:
token: The token to get the state for.
Returns:
The state for the token.
"""
redis_state = await self.redis.get(token)
if redis_state is None:
await self.set_state(token, self.state())
return await self.get_state(token)
return cloudpickle.loads(redis_state)
async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
"""Set the state for a token.
Args:
token: The token to set the state for.
state: The state to set.
lock_id: If provided, the lock_key must be set to this value to set the state.
Raises:
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
"""
# check that we're holding the lock
if (
lock_id is not None
and await self.redis.get(self._lock_key(token)) != lock_id
):
raise LockExpiredError(
f"Lock expired for token {token} while processing. Consider increasing "
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.background` decorator for long-running tasks."
)
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
"""Modify the state for a token while holding exclusive lock.
Args:
token: The token to modify the state for.
Yields:
The state for the token.
"""
async with self._lock(token) as lock_id:
state = await self.get_state(token)
yield state
await self.set_state(token, state, lock_id)
@staticmethod
def _lock_key(token: str) -> bytes:
"""Get the redis key for a token's lock.
Args:
token: The token to get the lock key for.
Returns:
The redis lock key for the token.
"""
return f"{token}_lock".encode()
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
"""Try to get a redis lock for a token.
Args:
lock_key: The redis key for the lock.
lock_id: The ID of the lock.
Returns:
True if the lock was obtained.
"""
return await self.redis.set(
lock_key,
lock_id,
px=self.lock_expiration,
nx=True, # only set if it doesn't exist
)
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
"""Wait for a redis lock to be released via pubsub.
The coroutine will not return until the lock is obtained.
Args:
lock_key: The redis key for the lock.
lock_id: The ID of the lock.
"""
state_is_locked = False
lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
# Enable keyspace notifications for the lock key, so we know when it is available.
await self.redis.config_set(
"notify-keyspace-events", self._redis_notify_keyspace_events
)
async with self.redis.pubsub() as pubsub:
await pubsub.psubscribe(lock_key_channel)
while not state_is_locked:
# wait for the lock to be released
while True:
if not await self.redis.exists(lock_key):
break # key was removed, try to get the lock again
message = await pubsub.get_message(
ignore_subscribe_messages=True,
timeout=self.lock_expiration / 1000.0,
)
if message is None:
continue
if message["data"] in self._redis_keyspace_lock_release_events:
break
state_is_locked = await self._try_get_lock(lock_key, lock_id)
@contextlib.asynccontextmanager
async def _lock(self, token: str):
"""Obtain a redis lock for a token.
Args:
token: The token to obtain a lock for.
Yields:
The ID of the lock (to be passed to set_state).
Raises:
LockExpiredError: If the lock has expired while processing the event.
"""
lock_key = self._lock_key(token)
lock_id = uuid.uuid4().hex.encode()
if not await self._try_get_lock(lock_key, lock_id):
# Missed the fast-path to get lock, subscribe for lock delete/expire events
await self._wait_lock(lock_key, lock_id)
state_is_locked = True
try:
yield lock_id
except LockExpiredError:
state_is_locked = False
raise
finally:
if state_is_locked:
# only delete our lock
await self.redis.delete(lock_key)
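# Illustrative sketch (not part of this diff): the lock is a plain redis key
# set with NX/PX, so a handler that holds the state longer than lock_expiration
# triggers the LockExpiredError raised in set_state above. The timeout can be
# tuned on the app's state manager; the app object here is hypothetical.
import reflex as rx

app = rx.App()
if isinstance(app.state_manager, StateManagerRedis):
    app.state_manager.lock_expiration = 30_000  # hold the per-token lock for up to 30s (ms)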
class ClientStorageBase:
@ -1246,7 +1692,7 @@ class MutableProxy(wrapt.ObjectProxy):
value, super().__getattribute__("__mutable_types__")
) and __name not in ("__wrapped__", "_self_state"):
# Recursively wrap mutable attribute values retrieved through this proxy.
return MutableProxy(
return type(self)(
wrapped=value,
state=self._self_state,
field_name=self._self_field_name,
@ -1266,7 +1712,7 @@ class MutableProxy(wrapt.ObjectProxy):
value = super().__getitem__(key)
if isinstance(value, self.__mutable_types__):
# Recursively wrap mutable items retrieved through this proxy.
return MutableProxy(
return type(self)(
wrapped=value,
state=self._self_state,
field_name=self._self_field_name,
@ -1332,3 +1778,34 @@ class MutableProxy(wrapt.ObjectProxy):
A deepcopy of the wrapped object, unconnected to the proxy.
"""
return copy.deepcopy(self.__wrapped__, memo=memo)
class ImmutableMutableProxy(MutableProxy):
"""A proxy for a mutable object that tracks changes.
This wrapper comes from StateProxy, and will raise an exception if an attempt is made
to modify the wrapped object when the StateProxy is immutable.
"""
def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
"""Raise an exception when an attempt is made to modify the object.
Intended for use with `FunctionWrapper` from the `wrapt` library.
Args:
wrapped: The wrapped function.
instance: The instance of the wrapped function.
args: The args for the wrapped function.
kwargs: The kwargs for the wrapped function.
Raises:
ImmutableStateError: if the StateProxy is not mutable.
"""
if not self._self_state._self_mutable:
raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
)
super()._mark_dirty(
wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
)

View File

@ -1,6 +1,7 @@
"""reflex.testing - tools for testing reflex apps."""
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import inspect
@ -19,14 +20,13 @@ import types
from http.server import SimpleHTTPRequestHandler
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Coroutine,
Optional,
Type,
TypeVar,
Union,
cast,
)
import psutil
@ -38,7 +38,7 @@ import reflex.utils.build
import reflex.utils.exec
import reflex.utils.prerequisites
import reflex.utils.processes
from reflex.app import EventNamespace
from reflex.state import State, StateManagerMemory, StateManagerRedis
try:
from selenium import webdriver # pyright: ignore [reportMissingImports]
@ -109,6 +109,7 @@ class AppHarness:
frontend_url: Optional[str] = None
backend_thread: Optional[threading.Thread] = None
backend: Optional[uvicorn.Server] = None
state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
_frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
@classmethod
@ -162,6 +163,27 @@ class AppHarness:
reflex.config.get_config(reload=True)
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
self.app_instance = self.app_module.app
if isinstance(self.app_instance.state_manager, StateManagerRedis):
# Create our own redis connection for testing.
self.state_manager = StateManagerRedis.create(self.app_instance.state)
else:
self.state_manager = self.app_instance.state_manager
def _get_backend_shutdown_handler(self):
if self.backend is None:
raise RuntimeError("Backend was not initialized.")
original_shutdown = self.backend.shutdown
async def _shutdown_redis(*args, **kwargs) -> None:
# ensure the redis connection is closed before the event loop shuts down
if self.app_instance is not None and isinstance(
self.app_instance.state_manager, StateManagerRedis
):
await self.app_instance.state_manager.redis.close()
await original_shutdown(*args, **kwargs)
return _shutdown_redis
def _start_backend(self, port=0):
if self.app_instance is None:
@ -173,6 +195,7 @@ class AppHarness:
port=port,
)
)
self.backend.shutdown = self._get_backend_shutdown_handler()
self.backend_thread = threading.Thread(target=self.backend.run)
self.backend_thread.start()
@ -296,6 +319,35 @@ class AppHarness:
time.sleep(step)
return False
@staticmethod
async def _poll_for_async(
target: Callable[[], Coroutine[None, None, T]],
timeout: TimeoutType = None,
step: TimeoutType = None,
) -> T | bool:
"""Generic polling logic for async functions.
Args:
target: callable that returns truthy if polling condition is met.
timeout: max polling time
step: interval between checking target()
Returns:
return value of target() if truthy within timeout
False if timeout elapses
"""
if timeout is None:
timeout = DEFAULT_TIMEOUT
if step is None:
step = POLL_INTERVAL
deadline = time.time() + timeout
while time.time() < deadline:
success = await target()
if success:
return success
await asyncio.sleep(step)
return False
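# Illustrative sketch (not part of this diff): polling an async condition in a
# test until a background task has advanced the state far enough. The fixture
# names and the `counter` var are hypothetical.
async def test_counter_eventually_reaches_ten(harness: AppHarness, token: str):
    async def counter_reached_ten() -> bool:
        state = await harness.get_state(token)
        return state.counter >= 10

    assert await AppHarness._poll_for_async(counter_reached_ten, timeout=10)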
def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
"""Poll backend server for listening sockets.
@ -351,39 +403,76 @@ class AppHarness:
self._frontends.append(driver)
return driver
async def emit_state_updates(self) -> list[Any]:
"""Send any backend state deltas to the frontend.
async def get_state(self, token: str) -> State:
"""Get the state associated with the given token.
Args:
token: The state token to look up.
Returns:
List of awaited response from each EventNamespace.emit() call.
The state instance associated with the given token
Raises:
RuntimeError: when the app hasn't started running
"""
if self.app_instance is None or self.app_instance.sio is None:
if self.state_manager is None:
raise RuntimeError("state_manager is not set.")
try:
return await self.state_manager.get_state(token)
finally:
if isinstance(self.state_manager, StateManagerRedis):
await self.state_manager.redis.close()
async def set_state(self, token: str, **kwargs) -> None:
"""Set the state associated with the given token.
Args:
token: The state token to set.
kwargs: Attributes to set on the state.
Raises:
RuntimeError: when the app hasn't started running
"""
if self.state_manager is None:
raise RuntimeError("state_manager is not set.")
state = await self.get_state(token)
for key, value in kwargs.items():
setattr(state, key, value)
try:
await self.state_manager.set_state(token, state)
finally:
if isinstance(self.state_manager, StateManagerRedis):
await self.state_manager.redis.close()
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[State]:
"""Modify the state associated with the given token and send update to frontend.
Args:
token: The state token to modify
Yields:
The state instance associated with the given token
Raises:
RuntimeError: when the app hasn't started running
"""
if self.state_manager is None:
raise RuntimeError("state_manager is not set.")
if self.app_instance is None:
raise RuntimeError("App is not running.")
event_ns: EventNamespace = cast(
EventNamespace,
self.app_instance.event_namespace,
)
pending: list[Coroutine[Any, Any, Any]] = []
for state in self.app_instance.state_manager.states.values():
delta = state.get_delta()
if delta:
update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
state._clean()
# Emit the event.
pending.append(
event_ns.emit(
str(reflex.constants.SocketEvent.EVENT),
update.json(),
to=state.get_sid(),
),
)
responses = []
for request in pending:
responses.append(await request)
return responses
app_state_manager = self.app_instance.state_manager
if isinstance(self.state_manager, StateManagerRedis):
# Temporarily replace the app's state manager with our own, since
# the redis connection is on the backend_thread event loop
self.app_instance.state_manager = self.state_manager
try:
async with self.app_instance.modify_state(token) as state:
yield state
finally:
if isinstance(self.state_manager, StateManagerRedis):
self.app_instance.state_manager = app_state_manager
await self.state_manager.redis.close()
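# Illustrative sketch (not part of this diff): a test exercising the helpers
# above. The fixture names and the `counter` var are hypothetical.
async def test_set_counter(app_harness: AppHarness, token: str):
    async with app_harness.modify_state(token) as state:
        state.counter = 42  # mutated under the exclusive lock; the delta is
                            # pushed to the frontend when the context exits
    state = await app_harness.get_state(token)
    assert state.counter == 42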
def poll_for_content(
self,
@ -457,6 +546,9 @@ class AppHarness:
if self.app_instance is None:
raise RuntimeError("App is not running.")
state_manager = self.app_instance.state_manager
assert isinstance(
state_manager, StateManagerMemory
), "Only works with memory state manager"
if not self._poll_for(
target=lambda: state_manager.states,
timeout=timeout,
@ -534,7 +626,6 @@ class Subdir404TCPServer(socketserver.TCPServer):
request: the requesting socket
client_address: (host, port) referring to the clients address.
"""
print(client_address, type(client_address))
self.RequestHandlerClass(
request,
client_address,
@ -605,6 +696,7 @@ class AppHarnessProd(AppHarness):
workers=reflex.utils.processes.get_num_workers(),
),
)
self.backend.shutdown = self._get_backend_shutdown_handler()
self.backend_thread = threading.Thread(target=self.backend.run)
self.backend_thread.start()

View File

@ -5,3 +5,11 @@ class InvalidStylePropError(TypeError):
"""Custom Type Error when style props have invalid values."""
pass
class ImmutableStateError(AttributeError):
"""Raised when a background task attempts to modify state outside of context."""
class LockExpiredError(Exception):
"""Raised when the state lock expires while an event is being processed."""

View File

@ -21,7 +21,7 @@ import httpx
import typer
from alembic.util.exc import CommandError
from packaging import version
from redis import Redis
from redis.asyncio import Redis
from reflex import constants, model
from reflex.compiler import templates
@ -124,9 +124,11 @@ def get_redis() -> Redis | None:
The redis client.
"""
config = get_config()
if config.redis_url is None:
if not config.redis_url:
return None
redis_url, redis_port = config.redis_url.split(":")
redis_url, has_port, redis_port = config.redis_url.partition(":")
if not has_port:
redis_port = 6379
console.info(f"Using redis at {config.redis_url}")
return Redis(host=redis_url, port=int(redis_port), db=0)
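# For reference (not part of this diff): redis_url values accepted after this
# change; a bare hostname now falls back to the default port 6379.
#   REDIS_URL=localhost        -> Redis(host="localhost", port=6379, db=0)
#   REDIS_URL=redis-host:7000  -> Redis(host="redis-host", port=7000, db=0)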

View File

@ -2,8 +2,9 @@
import contextlib
import os
import platform
import uuid
from pathlib import Path
from typing import Dict, Generator, List, Set, Union
from typing import Dict, Generator
import pytest
@ -11,6 +12,14 @@ import reflex as rx
from reflex.app import App
from reflex.event import EventSpec
from .states import (
DictMutationTestState,
ListMutationTestState,
MutableTestState,
SubUploadState,
UploadState,
)
@pytest.fixture
def app() -> App:
@ -39,60 +48,7 @@ def list_mutation_state():
Returns:
A state with list mutation features.
"""
class TestState(rx.State):
"""The test state."""
# plain list
plain_friends = ["Tommy"]
def make_friend(self):
self.plain_friends.append("another-fd")
def change_first_friend(self):
self.plain_friends[0] = "Jenny"
def unfriend_all_friends(self):
self.plain_friends.clear()
def unfriend_first_friend(self):
del self.plain_friends[0]
def remove_last_friend(self):
self.plain_friends.pop()
def make_friends_with_colleagues(self):
colleagues = ["Peter", "Jimmy"]
self.plain_friends.extend(colleagues)
def remove_tommy(self):
self.plain_friends.remove("Tommy")
# list in dict
friends_in_dict = {"Tommy": ["Jenny"]}
def remove_jenny_from_tommy(self):
self.friends_in_dict["Tommy"].remove("Jenny")
def add_jimmy_to_tommy_friends(self):
self.friends_in_dict["Tommy"].append("Jimmy")
def tommy_has_no_fds(self):
self.friends_in_dict["Tommy"].clear()
# nested list
friends_in_nested_list = [["Tommy"], ["Jenny"]]
def remove_first_group(self):
self.friends_in_nested_list.pop(0)
def remove_first_person_from_first_group(self):
self.friends_in_nested_list[0].pop(0)
def add_jimmy_to_second_group(self):
self.friends_in_nested_list[1].append("Jimmy")
return TestState()
return ListMutationTestState()
@pytest.fixture
@ -102,85 +58,7 @@ def dict_mutation_state():
Returns:
A state with dict mutation features.
"""
class TestState(rx.State):
"""The test state."""
# plain dict
details = {"name": "Tommy"}
def add_age(self):
self.details.update({"age": 20}) # type: ignore
def change_name(self):
self.details["name"] = "Jenny"
def remove_last_detail(self):
self.details.popitem()
def clear_details(self):
self.details.clear()
def remove_name(self):
del self.details["name"]
def pop_out_age(self):
self.details.pop("age")
# dict in list
address = [{"home": "home address"}, {"work": "work address"}]
def remove_home_address(self):
self.address[0].pop("home")
def add_street_to_home_address(self):
self.address[0]["street"] = "street address"
# nested dict
friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
def change_friend_name(self):
self.friend_in_nested_dict["friend"]["name"] = "Tommy"
def remove_friend(self):
self.friend_in_nested_dict.pop("friend")
def add_friend_age(self):
self.friend_in_nested_dict["friend"]["age"] = 30
return TestState()
class UploadState(rx.State):
"""The base state for uploading a file."""
async def handle_upload1(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
pass
class BaseState(rx.State):
"""The test base state."""
pass
class SubUploadState(BaseState):
"""The test substate."""
img: str
async def handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
pass
return DictMutationTestState()
@pytest.fixture
@ -203,187 +81,6 @@ def upload_event_spec():
return EventSpec(handler=UploadState.handle_upload1, upload=True) # type: ignore
@pytest.fixture
def upload_state(tmp_path):
"""Create upload state.
Args:
tmp_path: pytest tmp_path
Returns:
The state
"""
class FileUploadState(rx.State):
"""The base state for uploading a file."""
img_list: List[str]
async def handle_upload2(self, files):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
return FileUploadState
@pytest.fixture
def upload_sub_state(tmp_path):
"""Create upload substate.
Args:
tmp_path: pytest tmp_path
Returns:
The state
"""
class FileState(rx.State):
"""The base state."""
pass
class FileUploadState(FileState):
"""The substate for uploading a file."""
img_list: List[str]
async def handle_upload2(self, files):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
return FileUploadState
@pytest.fixture
def upload_grand_sub_state(tmp_path):
"""Create upload grand-state.
Args:
tmp_path: pytest tmp_path
Returns:
The state
"""
class BaseFileState(rx.State):
"""The base state."""
pass
class FileSubState(BaseFileState):
"""The substate."""
pass
class FileUploadState(FileSubState):
"""The grand-substate for uploading a file."""
img_list: List[str]
async def handle_upload2(self, files):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
return FileUploadState
@pytest.fixture
def base_config_values() -> Dict:
"""Get base config values.
@ -418,35 +115,6 @@ def sqlite_db_config_values(base_db_config_values) -> Dict:
return base_db_config_values
class GenState(rx.State):
"""A state with event handlers that generate multiple updates."""
value: int
def go(self, c: int):
"""Increment the value c times and update each time.
Args:
c: The number of times to increment.
Yields:
After each increment.
"""
for _ in range(c):
self.value += 1
yield
@pytest.fixture
def gen_state() -> GenState:
"""A state.
Returns:
A test state.
"""
return GenState # type: ignore
@pytest.fixture
def router_data_headers() -> Dict[str, str]:
"""Router data headers.
@ -546,46 +214,19 @@ def mutable_state():
Returns:
A state object.
"""
class OtherBase(rx.Base):
bar: str = ""
class CustomVar(rx.Base):
foo: str = ""
array: List[str] = []
hashmap: Dict[str, str] = {}
test_set: Set[str] = set()
custom: OtherBase = OtherBase()
class MutableTestState(rx.State):
"""A test state."""
array: List[Union[str, List, Dict[str, str]]] = [
"value",
[1, 2, 3],
{"key": "value"},
]
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
"key": ["list", "of", "values"],
"another_key": "another_value",
"third_key": {"key": "value"},
}
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
custom: CustomVar = CustomVar()
_be_custom: CustomVar = CustomVar()
def reassign_mutables(self):
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
self.hashmap = {
"mod_key": ["list", "of", "values"],
"mod_another_key": "another_value",
"mod_third_key": {"key": "value"},
}
self.test_set = {1, 2, 3, 4, "five"}
return MutableTestState()
@pytest.fixture(scope="function")
def token() -> str:
"""Create a token.
Returns:
A fresh/unique token string.
"""
return str(uuid.uuid4())
@pytest.fixture
def duplicate_substate():
"""Create a Test state that has duplicate child substates.

tests/states/__init__.py Normal file
View File

@ -0,0 +1,30 @@
"""Common rx.State subclasses for use in tests."""
import reflex as rx
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
from .upload import (
ChildFileUploadState,
FileUploadState,
GrandChildFileUploadState,
SubUploadState,
UploadState,
)
class GenState(rx.State):
"""A state with event handlers that generate multiple updates."""
value: int
def go(self, c: int):
"""Increment the value c times and update each time.
Args:
c: The number of times to increment.
Yields:
After each increment.
"""
for _ in range(c):
self.value += 1
yield

tests/states/mutation.py Normal file
View File

@ -0,0 +1,172 @@
"""Test states for mutable vars."""
from typing import Dict, List, Set, Union
import reflex as rx
class DictMutationTestState(rx.State):
"""A state for testing ReflexDict mutation."""
# plain dict
details = {"name": "Tommy"}
def add_age(self):
"""Add an age to the dict."""
self.details.update({"age": 20}) # type: ignore
def change_name(self):
"""Change the name in the dict."""
self.details["name"] = "Jenny"
def remove_last_detail(self):
"""Remove the last item in the dict."""
self.details.popitem()
def clear_details(self):
"""Clear the dict."""
self.details.clear()
def remove_name(self):
"""Remove the name from the dict."""
del self.details["name"]
def pop_out_age(self):
"""Pop out the age from the dict."""
self.details.pop("age")
# dict in list
address = [{"home": "home address"}, {"work": "work address"}]
def remove_home_address(self):
"""Remove the home address from dict in the list."""
self.address[0].pop("home")
def add_street_to_home_address(self):
"""Set street key in the dict in the list."""
self.address[0]["street"] = "street address"
# nested dict
friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
def change_friend_name(self):
"""Change the friend's name in the nested dict."""
self.friend_in_nested_dict["friend"]["name"] = "Tommy"
def remove_friend(self):
"""Remove the friend from the nested dict."""
self.friend_in_nested_dict.pop("friend")
def add_friend_age(self):
"""Add an age to the friend in the nested dict."""
self.friend_in_nested_dict["friend"]["age"] = 30
class ListMutationTestState(rx.State):
"""A state for testing ReflexList mutation."""
# plain list
plain_friends = ["Tommy"]
def make_friend(self):
"""Add a friend to the list."""
self.plain_friends.append("another-fd")
def change_first_friend(self):
"""Change the first friend in the list."""
self.plain_friends[0] = "Jenny"
def unfriend_all_friends(self):
"""Unfriend all friends in the list."""
self.plain_friends.clear()
def unfriend_first_friend(self):
"""Unfriend the first friend in the list."""
del self.plain_friends[0]
def remove_last_friend(self):
"""Remove the last friend in the list."""
self.plain_friends.pop()
def make_friends_with_colleagues(self):
"""Add list of friends to the list."""
colleagues = ["Peter", "Jimmy"]
self.plain_friends.extend(colleagues)
def remove_tommy(self):
"""Remove Tommy from the list."""
self.plain_friends.remove("Tommy")
# list in dict
friends_in_dict = {"Tommy": ["Jenny"]}
def remove_jenny_from_tommy(self):
"""Remove Jenny from Tommy's friends list."""
self.friends_in_dict["Tommy"].remove("Jenny")
def add_jimmy_to_tommy_friends(self):
"""Add Jimmy to Tommy's friends list."""
self.friends_in_dict["Tommy"].append("Jimmy")
def tommy_has_no_fds(self):
"""Clear Tommy's friends list."""
self.friends_in_dict["Tommy"].clear()
# nested list
friends_in_nested_list = [["Tommy"], ["Jenny"]]
def remove_first_group(self):
"""Remove the first group of friends from the nested list."""
self.friends_in_nested_list.pop(0)
def remove_first_person_from_first_group(self):
"""Remove the first person from the first group of friends in the nested list."""
self.friends_in_nested_list[0].pop(0)
def add_jimmy_to_second_group(self):
"""Add Jimmy to the second group of friends in the nested list."""
self.friends_in_nested_list[1].append("Jimmy")
class OtherBase(rx.Base):
"""A Base model with a str field."""
bar: str = ""
class CustomVar(rx.Base):
"""A Base model with multiple fields."""
foo: str = ""
array: List[str] = []
hashmap: Dict[str, str] = {}
test_set: Set[str] = set()
custom: OtherBase = OtherBase()
class MutableTestState(rx.State):
"""A test state."""
array: List[Union[str, List, Dict[str, str]]] = [
"value",
[1, 2, 3],
{"key": "value"},
]
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
"key": ["list", "of", "values"],
"another_key": "another_value",
"third_key": {"key": "value"},
}
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
custom: CustomVar = CustomVar()
_be_custom: CustomVar = CustomVar()
def reassign_mutables(self):
"""Assign mutable fields to different values."""
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
self.hashmap = {
"mod_key": ["list", "of", "values"],
"mod_another_key": "another_value",
"mod_third_key": {"key": "value"},
}
self.test_set = {1, 2, 3, 4, "five"}

tests/states/upload.py Normal file
View File

@ -0,0 +1,175 @@
"""Test states for upload-related tests."""
from pathlib import Path
from typing import ClassVar, List
import reflex as rx
class UploadState(rx.State):
"""The base state for uploading a file."""
async def handle_upload1(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
pass
class BaseState(rx.State):
"""The test base state."""
pass
class SubUploadState(BaseState):
"""The test substate."""
img: str
async def handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
pass
class FileUploadState(rx.State):
"""The base state for uploading a file."""
img_list: List[str]
_tmp_path: ClassVar[Path]
async def handle_upload2(self, files):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
class FileStateBase1(rx.State):
"""The base state for a child FileUploadState."""
pass
class ChildFileUploadState(FileStateBase1):
"""The child state for uploading a file."""
img_list: List[str]
_tmp_path: ClassVar[Path]
async def handle_upload2(self, files):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
class FileStateBase2(FileStateBase1):
"""The parent state for a grandchild FileUploadState."""
pass
class GrandChildFileUploadState(FileStateBase2):
"""The child state for uploading a file."""
img_list: List[str]
_tmp_path: ClassVar[Path]
async def handle_upload2(self, files):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import io
import os.path
import sys
import uuid
from typing import List, Tuple, Type
if sys.version_info.major >= 3 and sys.version_info.minor > 7:
@ -30,11 +31,18 @@ from reflex.components import Box, Component, Cond, Fragment, Text
from reflex.event import Event, get_hydrate_event
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import State, StateUpdate
from reflex.state import State, StateManagerRedis, StateUpdate
from reflex.style import Style
from reflex.utils import format
from reflex.vars import ComputedVar
from .states import (
ChildFileUploadState,
FileUploadState,
GenState,
GrandChildFileUploadState,
)
@pytest.fixture
def index_page():
@ -64,6 +72,12 @@ def about_page():
return about
class ATestState(State):
"""A simple state for testing."""
var: int
@pytest.fixture()
def test_state() -> Type[State]:
"""A default state.
@ -71,11 +85,7 @@ def test_state() -> Type[State]:
Returns:
A default state.
"""
class TestState(State):
var: int
return TestState
return ATestState
@pytest.fixture()
@ -313,23 +323,28 @@ def test_initialize_admin_dashboard_with_view_overrides(test_model):
assert app.admin_dash.view_overrides[test_model] == TestModelView
def test_initialize_with_state(test_state):
@pytest.mark.asyncio
async def test_initialize_with_state(test_state: Type[ATestState], token: str):
"""Test setting the state of an app.
Args:
test_state: The default state.
token: a Token.
"""
app = App(state=test_state)
assert app.state == test_state
# Get a state for a given token.
token = "token"
state = app.state_manager.get_state(token)
state = await app.state_manager.get_state(token)
assert isinstance(state, test_state)
assert state.var == 0 # type: ignore
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
def test_set_and_get_state(test_state):
@pytest.mark.asyncio
async def test_set_and_get_state(test_state):
"""Test setting and getting the state of an app with different tokens.
Args:
@ -338,47 +353,51 @@ def test_set_and_get_state(test_state):
app = App(state=test_state)
# Create two tokens.
token1 = "token1"
token2 = "token2"
token1 = str(uuid.uuid4())
token2 = str(uuid.uuid4())
# Get the default state for each token.
state1 = app.state_manager.get_state(token1)
state2 = app.state_manager.get_state(token2)
state1 = await app.state_manager.get_state(token1)
state2 = await app.state_manager.get_state(token2)
assert state1.var == 0 # type: ignore
assert state2.var == 0 # type: ignore
# Set the vars to different values.
state1.var = 1
state2.var = 2
app.state_manager.set_state(token1, state1)
app.state_manager.set_state(token2, state2)
await app.state_manager.set_state(token1, state1)
await app.state_manager.set_state(token2, state2)
# Get the states again and check the values.
state1 = app.state_manager.get_state(token1)
state2 = app.state_manager.get_state(token2)
state1 = await app.state_manager.get_state(token1)
state2 = await app.state_manager.get_state(token2)
assert state1.var == 1 # type: ignore
assert state2.var == 2 # type: ignore
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
@pytest.mark.asyncio
async def test_dynamic_var_event(test_state):
async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
"""Test that the default handler of a dynamic generated var
works as expected.
Args:
test_state: State Fixture.
token: a Token.
"""
test_state = test_state()
test_state.add_var("int_val", int, 0)
result = await test_state._process(
state = test_state() # type: ignore
state.add_var("int_val", int, 0)
result = await state._process(
Event(
token="fake-token",
name="test_state.set_int_val",
token=token,
name=f"{test_state.get_name()}.set_int_val",
router_data={"pathname": "/", "query": {}},
payload={"value": 50},
)
).__anext__()
assert result.delta == {"test_state": {"int_val": 50}}
assert result.delta == {test_state.get_name(): {"int_val": 50}}
@pytest.mark.asyncio
@ -388,12 +407,20 @@ async def test_dynamic_var_event(test_state):
pytest.param(
[
(
"test_state.make_friend",
{"test_state": {"plain_friends": ["Tommy", "another-fd"]}},
"list_mutation_test_state.make_friend",
{
"list_mutation_test_state": {
"plain_friends": ["Tommy", "another-fd"]
}
},
),
(
"test_state.change_first_friend",
{"test_state": {"plain_friends": ["Jenny", "another-fd"]}},
"list_mutation_test_state.change_first_friend",
{
"list_mutation_test_state": {
"plain_friends": ["Jenny", "another-fd"]
}
},
),
],
id="append then __setitem__",
@ -401,12 +428,12 @@ async def test_dynamic_var_event(test_state):
pytest.param(
[
(
"test_state.unfriend_first_friend",
{"test_state": {"plain_friends": []}},
"list_mutation_test_state.unfriend_first_friend",
{"list_mutation_test_state": {"plain_friends": []}},
),
(
"test_state.make_friend",
{"test_state": {"plain_friends": ["another-fd"]}},
"list_mutation_test_state.make_friend",
{"list_mutation_test_state": {"plain_friends": ["another-fd"]}},
),
],
id="delitem then append",
@ -414,20 +441,24 @@ async def test_dynamic_var_event(test_state):
pytest.param(
[
(
"test_state.make_friends_with_colleagues",
{"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}},
"list_mutation_test_state.make_friends_with_colleagues",
{
"list_mutation_test_state": {
"plain_friends": ["Tommy", "Peter", "Jimmy"]
}
},
),
(
"test_state.remove_tommy",
{"test_state": {"plain_friends": ["Peter", "Jimmy"]}},
"list_mutation_test_state.remove_tommy",
{"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}},
),
(
"test_state.remove_last_friend",
{"test_state": {"plain_friends": ["Peter"]}},
"list_mutation_test_state.remove_last_friend",
{"list_mutation_test_state": {"plain_friends": ["Peter"]}},
),
(
"test_state.unfriend_all_friends",
{"test_state": {"plain_friends": []}},
"list_mutation_test_state.unfriend_all_friends",
{"list_mutation_test_state": {"plain_friends": []}},
),
],
id="extend, remove, pop, clear",
@ -435,24 +466,28 @@ async def test_dynamic_var_event(test_state):
pytest.param(
[
(
"test_state.add_jimmy_to_second_group",
"list_mutation_test_state.add_jimmy_to_second_group",
{
"test_state": {
"list_mutation_test_state": {
"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
}
},
),
(
"test_state.remove_first_person_from_first_group",
"list_mutation_test_state.remove_first_person_from_first_group",
{
"test_state": {
"list_mutation_test_state": {
"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
}
},
),
(
"test_state.remove_first_group",
{"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}},
"list_mutation_test_state.remove_first_group",
{
"list_mutation_test_state": {
"friends_in_nested_list": [["Jenny", "Jimmy"]]
}
},
),
],
id="nested list",
@ -460,16 +495,24 @@ async def test_dynamic_var_event(test_state):
pytest.param(
[
(
"test_state.add_jimmy_to_tommy_friends",
{"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}},
"list_mutation_test_state.add_jimmy_to_tommy_friends",
{
"list_mutation_test_state": {
"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}
}
},
),
(
"test_state.remove_jenny_from_tommy",
{"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}},
"list_mutation_test_state.remove_jenny_from_tommy",
{
"list_mutation_test_state": {
"friends_in_dict": {"Tommy": ["Jimmy"]}
}
},
),
(
"test_state.tommy_has_no_fds",
{"test_state": {"friends_in_dict": {"Tommy": []}}},
"list_mutation_test_state.tommy_has_no_fds",
{"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}},
),
],
id="list in dict",
@ -477,7 +520,9 @@ async def test_dynamic_var_event(test_state):
],
)
async def test_list_mutation_detection__plain_list(
event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State
event_tuples: List[Tuple[str, List[str]]],
list_mutation_state: State,
token: str,
):
"""Test list mutation detection
when reassignment is not explicitly included in the logic.
@ -485,11 +530,12 @@ async def test_list_mutation_detection__plain_list(
Args:
event_tuples: From parametrization.
list_mutation_state: A state with list mutation features.
token: a Token.
"""
for event_name, expected_delta in event_tuples:
result = await list_mutation_state._process(
Event(
token="fake-token",
token=token,
name=event_name,
router_data={"pathname": "/", "query": {}},
payload={},
@ -506,16 +552,24 @@ async def test_list_mutation_detection__plain_list(
pytest.param(
[
(
"test_state.add_age",
{"test_state": {"details": {"name": "Tommy", "age": 20}}},
"dict_mutation_test_state.add_age",
{
"dict_mutation_test_state": {
"details": {"name": "Tommy", "age": 20}
}
},
),
(
"test_state.change_name",
{"test_state": {"details": {"name": "Jenny", "age": 20}}},
"dict_mutation_test_state.change_name",
{
"dict_mutation_test_state": {
"details": {"name": "Jenny", "age": 20}
}
},
),
(
"test_state.remove_last_detail",
{"test_state": {"details": {"name": "Jenny"}}},
"dict_mutation_test_state.remove_last_detail",
{"dict_mutation_test_state": {"details": {"name": "Jenny"}}},
),
],
id="update then __setitem__",
@ -523,12 +577,12 @@ async def test_list_mutation_detection__plain_list(
pytest.param(
[
(
"test_state.clear_details",
{"test_state": {"details": {}}},
"dict_mutation_test_state.clear_details",
{"dict_mutation_test_state": {"details": {}}},
),
(
"test_state.add_age",
{"test_state": {"details": {"age": 20}}},
"dict_mutation_test_state.add_age",
{"dict_mutation_test_state": {"details": {"age": 20}}},
),
],
id="delitem then update",
@ -536,16 +590,20 @@ async def test_list_mutation_detection__plain_list(
pytest.param(
[
(
"test_state.add_age",
{"test_state": {"details": {"name": "Tommy", "age": 20}}},
"dict_mutation_test_state.add_age",
{
"dict_mutation_test_state": {
"details": {"name": "Tommy", "age": 20}
}
},
),
(
"test_state.remove_name",
{"test_state": {"details": {"age": 20}}},
"dict_mutation_test_state.remove_name",
{"dict_mutation_test_state": {"details": {"age": 20}}},
),
(
"test_state.pop_out_age",
{"test_state": {"details": {}}},
"dict_mutation_test_state.pop_out_age",
{"dict_mutation_test_state": {"details": {}}},
),
],
id="add, remove, pop",
@ -553,13 +611,17 @@ async def test_list_mutation_detection__plain_list(
pytest.param(
[
(
"test_state.remove_home_address",
{"test_state": {"address": [{}, {"work": "work address"}]}},
"dict_mutation_test_state.remove_home_address",
{
"dict_mutation_test_state": {
"address": [{}, {"work": "work address"}]
}
},
),
(
"test_state.add_street_to_home_address",
"dict_mutation_test_state.add_street_to_home_address",
{
"test_state": {
"dict_mutation_test_state": {
"address": [
{"street": "street address"},
{"work": "work address"},
@ -573,9 +635,9 @@ async def test_list_mutation_detection__plain_list(
pytest.param(
[
(
"test_state.change_friend_name",
"dict_mutation_test_state.change_friend_name",
{
"test_state": {
"dict_mutation_test_state": {
"friend_in_nested_dict": {
"name": "Nikhil",
"friend": {"name": "Tommy"},
@ -584,9 +646,9 @@ async def test_list_mutation_detection__plain_list(
},
),
(
"test_state.add_friend_age",
"dict_mutation_test_state.add_friend_age",
{
"test_state": {
"dict_mutation_test_state": {
"friend_in_nested_dict": {
"name": "Nikhil",
"friend": {"name": "Tommy", "age": 30},
@ -595,8 +657,12 @@ async def test_list_mutation_detection__plain_list(
},
),
(
"test_state.remove_friend",
{"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}},
"dict_mutation_test_state.remove_friend",
{
"dict_mutation_test_state": {
"friend_in_nested_dict": {"name": "Nikhil"}
}
},
),
],
id="nested dict",
@ -604,7 +670,9 @@ async def test_list_mutation_detection__plain_list(
],
)
async def test_dict_mutation_detection__plain_list(
event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State
event_tuples: List[Tuple[str, List[str]]],
dict_mutation_state: State,
token: str,
):
"""Test dict mutation detection
when reassignment is not explicitly included in the logic.
@ -612,11 +680,12 @@ async def test_dict_mutation_detection__plain_list(
Args:
event_tuples: From parametrization.
dict_mutation_state: A state with dict mutation features.
token: a Token.
"""
for event_name, expected_delta in event_tuples:
result = await dict_mutation_state._process(
Event(
token="fake-token",
token=token,
name=event_name,
router_data={"pathname": "/", "query": {}},
payload={},
@ -628,41 +697,43 @@ async def test_dict_mutation_detection__plain_list(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"fixture, delta",
("state", "delta"),
[
(
"upload_state",
FileUploadState,
{"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
),
(
"upload_sub_state",
ChildFileUploadState,
{
"file_state.file_upload_state": {
"file_state_base1.child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"]
}
},
),
(
"upload_grand_sub_state",
GrandChildFileUploadState,
{
"base_file_state.file_sub_state.file_upload_state": {
"file_state_base1.file_state_base2.grand_child_file_upload_state": {
"img_list": ["image1.jpg", "image2.jpg"]
}
},
),
],
)
async def test_upload_file(fixture, request, delta):
async def test_upload_file(tmp_path, state, delta, token: str):
"""Test that file upload works correctly.
Args:
fixture: The state.
request: Fixture request.
tmp_path: Temporary path.
state: The state class.
delta: Expected delta
token: a Token.
"""
app = App(state=request.getfixturevalue(fixture))
state._tmp_path = tmp_path
app = App(state=state)
app.event_namespace.emit = AsyncMock() # type: ignore
current_state = app.state_manager.get_state("token")
current_state = await app.state_manager.get_state(token)
data = b"This is binary data"
# Create a binary IO object and write data to it
@ -670,11 +741,11 @@ async def test_upload_file(fixture, request, delta):
bio.write(data)
file1 = UploadFile(
filename="token:file_upload_state.multi_handle_upload:True:image1.jpg",
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image1.jpg",
file=bio,
)
file2 = UploadFile(
filename="token:file_upload_state.multi_handle_upload:True:image2.jpg",
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image2.jpg",
file=bio,
)
upload_fn = upload(app)
@ -684,22 +755,27 @@ async def test_upload_file(fixture, request, delta):
app.event_namespace.emit.assert_called_with( # type: ignore
"event", state_update.json(), to=current_state.get_sid()
)
assert app.state_manager.get_state("token").dict()["img_list"] == [
assert (await app.state_manager.get_state(token)).dict()["img_list"] == [
"image1.jpg",
"image2.jpg",
]
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"fixture", ["upload_state", "upload_sub_state", "upload_grand_sub_state"]
"state",
[FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
)
async def test_upload_file_without_annotation(fixture, request):
async def test_upload_file_without_annotation(state, tmp_path, token):
"""Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
Args:
fixture: The state.
request: Fixture request.
state: The state class.
tmp_path: Temporary path.
token: a Token.
"""
data = b"This is binary data"
@ -707,14 +783,15 @@ async def test_upload_file_without_annotation(fixture, request):
bio = io.BytesIO()
bio.write(data)
app = App(state=request.getfixturevalue(fixture))
state._tmp_path = tmp_path
app = App(state=state)
file1 = UploadFile(
filename="token:file_upload_state.handle_upload2:True:image1.jpg",
filename=f"{token}:{state.get_name()}.handle_upload2:True:image1.jpg",
file=bio,
)
file2 = UploadFile(
filename="token:file_upload_state.handle_upload2:True:image2.jpg",
filename=f"{token}:{state.get_name()}.handle_upload2:True:image2.jpg",
file=bio,
)
fn = upload(app)
@ -722,9 +799,12 @@ async def test_upload_file_without_annotation(fixture, request):
await fn([file1, file2])
assert (
err.value.args[0]
== "`file_upload_state.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
== f"`{state.get_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
class DynamicState(State):
"""State class for testing dynamic route var.
@ -768,6 +848,7 @@ class DynamicState(State):
async def test_dynamic_route_var_route_change_completed_on_load(
index_page,
windows_platform: bool,
token: str,
):
"""Create app with dynamic route var, and simulate navigation.
@ -777,6 +858,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
Args:
index_page: The index page.
windows_platform: Whether the system is windows.
token: a Token.
"""
arg_name = "dynamic"
route = f"/test/[{arg_name}]"
@ -792,10 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
}
assert constants.ROUTER_DATA in app.state().computed_var_dependencies
token = "mock_token"
sid = "mock_sid"
client_ip = "127.0.0.1"
state = app.state_manager.get_state(token)
state = await app.state_manager.get_state(token)
assert state.dynamic == ""
exp_vals = ["foo", "foobar", "baz"]
@ -817,6 +898,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
**kwargs,
)
prev_exp_val = ""
for exp_index, exp_val in enumerate(exp_vals):
hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
exp_router_data = {
@ -826,13 +908,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
"token": token,
**hydrate_event.router_data,
}
update = await process(
process_coro = process(
app,
event=hydrate_event,
sid=sid,
headers={},
client_ip=client_ip,
).__anext__() # type: ignore
)
update = await process_coro.__anext__() # type: ignore
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
assert update == StateUpdate(
@ -860,14 +943,27 @@ async def test_dynamic_route_var_route_change_completed_on_load(
),
],
)
if isinstance(app.state_manager, StateManagerRedis):
# When redis is used, the state is not updated until the processing is complete
state = await app.state_manager.get_state(token)
assert state.dynamic == prev_exp_val
# complete the processing
with pytest.raises(StopAsyncIteration):
await process_coro.__anext__() # type: ignore
# check that router data was written to the state_manager store
state = await app.state_manager.get_state(token)
assert state.dynamic == exp_val
on_load_update = await process(
process_coro = process(
app,
event=_dynamic_state_event(name="on_load", val=exp_val),
sid=sid,
headers={},
client_ip=client_ip,
).__anext__() # type: ignore
)
on_load_update = await process_coro.__anext__() # type: ignore
assert on_load_update == StateUpdate(
delta={
state.get_name(): {
@ -879,7 +975,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
},
events=[],
)
on_set_is_hydrated_update = await process(
# complete the processing
with pytest.raises(StopAsyncIteration):
await process_coro.__anext__() # type: ignore
process_coro = process(
app,
event=_dynamic_state_event(
name="set_is_hydrated", payload={"value": True}, val=exp_val
@ -887,7 +986,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
sid=sid,
headers={},
client_ip=client_ip,
).__anext__() # type: ignore
)
on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore
assert on_set_is_hydrated_update == StateUpdate(
delta={
state.get_name(): {
@ -899,15 +999,19 @@ async def test_dynamic_route_var_route_change_completed_on_load(
},
events=[],
)
# complete the processing
with pytest.raises(StopAsyncIteration):
await process_coro.__anext__() # type: ignore
# a simple state update event should NOT trigger on_load or route var side effects
update = await process(
process_coro = process(
app,
event=_dynamic_state_event(name="on_counter", val=exp_val),
sid=sid,
headers={},
client_ip=client_ip,
).__anext__() # type: ignore
)
update = await process_coro.__anext__() # type: ignore
assert update == StateUpdate(
delta={
state.get_name(): {
@ -919,42 +1023,54 @@ async def test_dynamic_route_var_route_change_completed_on_load(
},
events=[],
)
# complete the processing
with pytest.raises(StopAsyncIteration):
await process_coro.__anext__() # type: ignore
prev_exp_val = exp_val
state = await app.state_manager.get_state(token)
assert state.loaded == len(exp_vals)
assert state.counter == len(exp_vals)
# print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
# assert state.side_effect_counter == len(exp_vals)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
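For orientation, the pattern the loop above depends on can be sketched roughly as follows: process() is an async generator, and the mutated state is only written back through the state manager once the generator is exhausted, which is why the redis branch re-fetches the state after forcing StopAsyncIteration. The helper name and the literal sid/ip values below are illustrative, not part of this change.

import reflex as rx
from reflex.app import process
from reflex.event import Event

async def drain_updates(app: rx.App, event: Event) -> list:
    """Hypothetical helper: run one event to completion and collect its updates."""
    updates = []
    async for update in process(
        app, event=event, sid="mock_sid", headers={}, client_ip="127.0.0.1"
    ):
        updates.append(update)
    # Only now has the state manager persisted the handler's changes.
    return updates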
@pytest.mark.asyncio
async def test_process_events(gen_state, mocker):
async def test_process_events(mocker, token: str):
"""Test that an event is processed properly and that it is postprocessed
n+1 times. Also check that the processing flag of the last StateUpdate is set to
False.
Args:
gen_state: The state.
mocker: mocker object.
token: A token.
"""
router_data = {
"pathname": "/",
"query": {},
"token": "mock_token",
"token": token,
"sid": "mock_sid",
"headers": {},
"ip": "127.0.0.1",
}
app = App(state=gen_state)
app = App(state=GenState)
mocker.patch.object(app, "postprocess", AsyncMock())
event = Event(
token="token", name="gen_state.go", payload={"c": 5}, router_data=router_data
token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
)
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore
pass
assert app.state_manager.get_state("token").value == 5
assert (await app.state_manager.get_state(token)).value == 5
assert app.postprocess.call_count == 6
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
@pytest.mark.parametrize(
("state", "overlay_component", "exp_page_child"),

View File

@ -1,22 +1,42 @@
from __future__ import annotations
import asyncio
import copy
import datetime
import functools
import json
import os
import sys
from typing import Dict, List
from typing import Dict, Generator, List
from unittest.mock import AsyncMock, Mock
import pytest
from plotly.graph_objects import Figure
import reflex as rx
from reflex.base import Base
from reflex.constants import IS_HYDRATED, RouteVar
from reflex.constants import APP_VAR, IS_HYDRATED, RouteVar, SocketEvent
from reflex.event import Event, EventHandler
from reflex.state import MutableProxy, State
from reflex.utils import format
from reflex.state import (
ImmutableStateError,
LockExpiredError,
MutableProxy,
State,
StateManager,
StateManagerMemory,
StateManagerRedis,
StateProxy,
StateUpdate,
)
from reflex.utils import format, prerequisites
from reflex.vars import BaseVar, ComputedVar
from .states import GenState
CI = bool(os.environ.get("CI", False))
LOCK_EXPIRATION = 2000 if CI else 100
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
class Object(Base):
"""A test object fixture."""
@ -704,13 +724,9 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
@pytest.mark.asyncio
async def test_process_event_generator(gen_state):
"""Test event handlers that generate multiple updates.
Args:
gen_state: A state.
"""
gen_state = gen_state()
async def test_process_event_generator():
"""Test event handlers that generate multiple updates."""
gen_state = GenState() # type: ignore
event = Event(
token="t",
name="go",
@ -1402,6 +1418,396 @@ def test_state_with_invalid_yield():
)
@pytest.fixture(scope="function", params=["in_process", "redis"])
def state_manager(request) -> Generator[StateManager, None, None]:
"""Instance of state manager parametrized for redis and in-process.
Args:
request: pytest request object.
Yields:
A state manager instance
"""
state_manager = StateManager.create(state=TestState)
if request.param == "redis":
if not isinstance(state_manager, StateManagerRedis):
pytest.skip("Test requires redis")
else:
# explicitly NOT using redis
state_manager = StateManagerMemory(state=TestState)
assert not state_manager._states_locks
yield state_manager
if isinstance(state_manager, StateManagerRedis):
asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
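A rough sketch of what this fixture relies on, with DemoState as a hypothetical stand-in: StateManager.create inspects the configuration and hands back a redis-backed manager when redis is available, otherwise the in-memory one.

import reflex as rx
from reflex.state import StateManager, StateManagerMemory, StateManagerRedis

class DemoState(rx.State):
    """Hypothetical state used only for this sketch."""
    num: int = 0

manager = StateManager.create(state=DemoState)
# The concrete class depends on whether redis is configured, which is what
# the redis/in_process parametrization above exercises.
assert isinstance(manager, (StateManagerMemory, StateManagerRedis))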
@pytest.mark.asyncio
async def test_state_manager_modify_state(state_manager: StateManager, token: str):
"""Test that the state manager can modify a state exclusively.
Args:
state_manager: A state manager instance.
token: A token.
"""
async with state_manager.modify_state(token):
if isinstance(state_manager, StateManagerRedis):
assert await state_manager.redis.get(f"{token}_lock")
elif isinstance(state_manager, StateManagerMemory):
assert token in state_manager._states_locks
assert state_manager._states_locks[token].locked()
# lock should be dropped after exiting the context
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
elif isinstance(state_manager, StateManagerMemory):
assert not state_manager._states_locks[token].locked()
# separate instances should NOT share locks
sm2 = StateManagerMemory(state=TestState)
assert sm2._state_manager_lock is state_manager._state_manager_lock
assert not sm2._states_locks
if state_manager._states_locks:
assert sm2._states_locks != state_manager._states_locks
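A minimal sketch of the contract checked here, assuming a hypothetical CounterState: modify_state acquires the per-token lock, yields the state for mutation, then persists it and releases the lock when the async context exits.

import reflex as rx
from reflex.state import StateManager

class CounterState(rx.State):
    """Hypothetical state used only for this sketch."""
    count: int = 0

async def bump(manager: StateManager, token: str) -> int:
    # The lock is held only for the duration of the async with block;
    # the updated state is visible to later get_state() calls.
    async with manager.modify_state(token) as state:
        state.count += 1
        return state.count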
@pytest.mark.asyncio
async def test_state_manager_contend(state_manager: StateManager, token: str):
"""Multiple coroutines attempting to access the same state.
Args:
state_manager: A state manager instance.
token: A token.
"""
n_coroutines = 10
exp_num1 = 10
async with state_manager.modify_state(token) as state:
state.num1 = 0
async def _coro():
async with state_manager.modify_state(token) as state:
await asyncio.sleep(0.01)
state.num1 += 1
tasks = [asyncio.create_task(_coro()) for _ in range(n_coroutines)]
for f in asyncio.as_completed(tasks):
await f
assert (await state_manager.get_state(token)).num1 == exp_num1
if isinstance(state_manager, StateManagerRedis):
assert (await state_manager.redis.get(f"{token}_lock")) is None
elif isinstance(state_manager, StateManagerMemory):
assert token in state_manager._states_locks
assert not state_manager._states_locks[token].locked()
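The same idea from the caller's side, as a hedged sketch (names are illustrative): because each modify_state block holds the token's lock, concurrent increments serialize instead of clobbering each other.

import asyncio
from reflex.state import StateManager

async def concurrent_bumps(manager: StateManager, token: str, n: int = 10) -> int:
    async def one():
        async with manager.modify_state(token) as state:
            state.num1 += 1  # assumes a state with a num1 var, as in the tests above

    # All n coroutines contend for the same token; each waits its turn.
    await asyncio.gather(*(one() for _ in range(n)))
    return (await manager.get_state(token)).num1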
@pytest.fixture(scope="function")
def state_manager_redis() -> Generator[StateManager, None, None]:
"""Instance of state manager for redis only.
Yields:
A state manager instance
"""
state_manager = StateManager.create(TestState)
if not isinstance(state_manager, StateManagerRedis):
pytest.skip("Test requires redis")
yield state_manager
asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
@pytest.mark.asyncio
async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
"""Test that the state manager lock expires and raises exception exiting context.
Args:
state_manager_redis: A state manager instance.
token: A token.
"""
state_manager_redis.lock_expiration = LOCK_EXPIRATION
async with state_manager_redis.modify_state(token):
await asyncio.sleep(0.01)
with pytest.raises(LockExpiredError):
async with state_manager_redis.modify_state(token):
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
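A hedged sketch of the failure mode exercised here (the timing and the num1 var are illustrative): when the work inside modify_state outlives lock_expiration, the redis lock lapses, and committing the state on context exit raises LockExpiredError instead of silently overwriting another writer's changes.

import asyncio
from reflex.state import LockExpiredError, StateManagerRedis

async def slow_mutation(manager: StateManagerRedis, token: str):
    try:
        async with manager.modify_state(token) as state:
            # lock_expiration is in milliseconds; sleeping past it lets the
            # redis lock expire while the handler still holds stale state.
            await asyncio.sleep(manager.lock_expiration / 1000 + 0.1)
            state.num1 = 1  # assumes a state with a num1 var, as above
    except LockExpiredError:
        pass  # the mutation was not persisted; the caller may retry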
@pytest.mark.asyncio
async def test_state_manager_lock_expire_contend(
state_manager_redis: StateManager, token: str
):
"""Test that the state manager lock expires and queued waiters proceed.
Args:
state_manager_redis: A state manager instance.
token: A token.
"""
exp_num1 = 4252
unexp_num1 = 666
state_manager_redis.lock_expiration = LOCK_EXPIRATION
order = []
async def _coro_blocker():
async with state_manager_redis.modify_state(token) as state:
order.append("blocker")
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
state.num1 = unexp_num1
async def _coro_waiter():
while "blocker" not in order:
await asyncio.sleep(0.005)
async with state_manager_redis.modify_state(token) as state:
order.append("waiter")
assert state.num1 != unexp_num1
state.num1 = exp_num1
tasks = [
asyncio.create_task(_coro_blocker()),
asyncio.create_task(_coro_waiter()),
]
with pytest.raises(LockExpiredError):
await tasks[0]
await tasks[1]
assert order == ["blocker", "waiter"]
assert (await state_manager_redis.get_state(token)).num1 == exp_num1
@pytest.fixture(scope="function")
def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App:
"""Mock app fixture.
Args:
monkeypatch: Pytest monkeypatch object.
app: An app.
state_manager: A state manager.
Returns:
The app, after mocking out prerequisites.get_app()
"""
app_module = Mock()
setattr(app_module, APP_VAR, app)
app.state = TestState
app.state_manager = state_manager
assert app.event_namespace is not None
app.event_namespace.emit = AsyncMock()
monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
return app
@pytest.mark.asyncio
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
"""Test that the state proxy works.
Args:
grandchild_state: A grandchild state.
mock_app: An app that will be returned by `get_app()`
"""
child_state = grandchild_state.parent_state
assert child_state is not None
parent_state = child_state.parent_state
assert parent_state is not None
if isinstance(mock_app.state_manager, StateManagerMemory):
mock_app.state_manager.states[parent_state.get_token()] = parent_state
sp = StateProxy(grandchild_state)
assert sp.__wrapped__ == grandchild_state
assert sp._self_substate_path == grandchild_state.get_full_name().split(".")
assert sp._self_app is mock_app
assert not sp._self_mutable
assert sp._self_actx is None
# cannot use the synchronous context manager protocol
with pytest.raises(TypeError), sp:
pass
with pytest.raises(ImmutableStateError):
# cannot directly modify state proxy outside of async context
sp.value2 = 16
async with sp:
assert sp._self_actx is not None
assert sp._self_mutable # proxy is mutable inside context
if isinstance(mock_app.state_manager, StateManagerMemory):
# For in-process store, only one instance of the state exists
assert sp.__wrapped__ is grandchild_state
else:
# When redis is used, a new+updated instance is assigned to the proxy
assert sp.__wrapped__ is not grandchild_state
sp.value2 = 42
assert not sp._self_mutable # proxy is not mutable after exiting context
assert sp._self_actx is None
assert sp.value2 == 42
# Get the state from the state manager directly and check that the value is updated
gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
if isinstance(mock_app.state_manager, StateManagerMemory):
# For in-process store, only one instance of the state exists
assert gotten_state is parent_state
else:
assert gotten_state is not parent_state
gotten_grandchild_state = gotten_state.get_substate(sp._self_substate_path)
assert gotten_grandchild_state is not None
assert gotten_grandchild_state.value2 == 42
# ensure state update was emitted
assert mock_app.event_namespace is not None
mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT)
assert json.loads(mcall.args[1]) == StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": 42,
},
}
)
assert mcall.kwargs["to"] == grandchild_state.get_sid()
class BackgroundTaskState(State):
"""A state with a background task."""
order: List[str] = []
dict_list: Dict[str, List[int]] = {"foo": []}
@rx.background
async def background_task(self):
"""A background task that updates the state."""
async with self:
assert not self.order
self.order.append("background_task:start")
assert isinstance(self, StateProxy)
with pytest.raises(ImmutableStateError):
self.order.append("bad idea")
with pytest.raises(ImmutableStateError):
# Even nested access to mutables raises an exception.
self.dict_list["foo"].append(42)
# wait for some other event to happen
while len(self.order) == 1:
await asyncio.sleep(0.01)
async with self:
pass # update proxy instance
async with self:
self.order.append("background_task:stop")
@rx.background
async def background_task_generator(self):
"""A background task generator that does nothing.
Yields:
None
"""
yield
def other(self):
"""Some other event that updates the state."""
self.order.append("other")
async def bad_chain1(self):
"""Test that a background task cannot be chained."""
await self.background_task()
async def bad_chain2(self):
"""Test that a background task generator cannot be chained."""
async for _foo in self.background_task_generator():
pass
@pytest.mark.asyncio
async def test_background_task_no_block(mock_app: rx.App, token: str):
"""Test that a background task does not block other events.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
router_data = {"query": {}}
mock_app.state_manager.state = mock_app.state = BackgroundTaskState
async for update in rx.app.process( # type: ignore
mock_app,
Event(
token=token,
name=f"{BackgroundTaskState.get_name()}.background_task",
router_data=router_data,
payload={},
),
sid="",
headers={},
client_ip="",
):
# background task returns empty update immediately
assert update == StateUpdate()
assert len(mock_app.background_tasks) == 1
# wait for the coroutine to start
await asyncio.sleep(0.5 if CI else 0.1)
assert len(mock_app.background_tasks) == 1
# Process another normal event
async for update in rx.app.process( # type: ignore
mock_app,
Event(
token=token,
name=f"{BackgroundTaskState.get_name()}.other",
router_data=router_data,
payload={},
),
sid="",
headers={},
client_ip="",
):
# other task returns delta
assert update == StateUpdate(
delta={
BackgroundTaskState.get_name(): {
"order": [
"background_task:start",
"other",
],
}
}
)
# Explicit wait for background tasks
for task in tuple(mock_app.background_tasks):
await task
assert not mock_app.background_tasks
assert (await mock_app.state_manager.get_state(token)).order == [
"background_task:start",
"other",
"background_task:stop",
]
@pytest.mark.asyncio
async def test_background_task_no_chain():
"""Test that a background task cannot be chained."""
bts = BackgroundTaskState()
with pytest.raises(RuntimeError):
await bts.bad_chain1()
with pytest.raises(RuntimeError):
await bts.bad_chain2()
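The two bad_chain handlers above pin down what is rejected at runtime. As a hedged sketch (ChainState and its vars are hypothetical), the supported way for a regular handler to trigger a background task is to yield it as an event rather than calling it directly:

import reflex as rx

class ChainState(rx.State):
    """Hypothetical state used only for this sketch."""
    done: bool = False

    @rx.background
    async def long_job(self):
        async with self:
            self.done = True

    async def kick_off(self):
        # Yielding the handler lets the framework schedule it as its own
        # task instead of running it inline (which raises RuntimeError).
        yield ChainState.long_job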
def test_mutable_list(mutable_state):
"""Test that mutable lists are tracked correctly.