reflex/tests/integration/test_background_task.py
Thomas Brandého 42e6dfa40d
enable PGH, bump pyright and fix all #type: ignore (#4699)
* enable PGH, bump pyright and fix  all #type: ignore

* relock poetry file

* ignore incompatible override

* fix varop tests

* ignore missing imports

* fix

* fix stuff

* fix tests

* rechange tests

* relock with poetry 2.0
2025-01-28 13:11:05 -08:00

378 lines
12 KiB
Python

"""Test @rx.event(background=True) task functionality."""
from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
def BackgroundTask():
    """Define the reflex app whose background event handlers are under test.

    The function object is passed to ``AppHarness.create`` as ``app_source``
    by the ``background_task`` fixture. All imports the app needs are made
    inside the body rather than at module level.
    """
    import asyncio

    import pytest

    import reflex as rx
    from reflex.state import ImmutableStateError

    class State(rx.State):
        """Counter state mutated by both regular and background handlers."""

        counter: int = 0
        _task_id: int = 0
        iterations: int = 10

        @rx.event(background=True)
        async def handle_event(self):
            """Add one to the counter ``iterations`` times, 5 ms apart."""
            async with self:
                self._task_id += 1
            for _ in range(int(self.iterations)):
                # Each mutation gets its own `async with self` block; the
                # sleep happens outside of it.
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_event_yield_only(self):
            """Change the counter only by yielding other events.

            Yields:
                ``increment_arbitrary(1)`` on even steps and ``increment()``
                on odd steps, one event per iteration.
            """
            async with self:
                self._task_id += 1
            for step in range(int(self.iterations)):
                if step % 2:
                    yield State.increment()
                else:
                    yield State.increment_arbitrary(1)
                await asyncio.sleep(0.005)

        @rx.event
        def increment(self):
            """Add one to the counter."""
            self.counter += 1

        @rx.event(background=True)
        async def increment_arbitrary(self, amount: int):
            """Add ``amount`` to the counter.

            Args:
                amount: Value to add; passed through ``int()`` first.
            """
            async with self:
                self.counter += int(amount)

        @rx.event
        def reset_counter(self):
            """Set the counter back to zero."""
            self.counter = 0

        @rx.event
        async def blocking_pause(self):
            """Sleep for 20 ms inside a regular (non-background) handler."""
            await asyncio.sleep(0.02)

        @rx.event(background=True)
        async def non_blocking_pause(self):
            """Sleep for 20 ms inside a background handler."""
            await asyncio.sleep(0.02)

        async def racy_task(self):
            """Run the same increment loop as ``handle_event``.

            This is a plain coroutine method, not an event handler; it is
            awaited concurrently by ``handle_racy_event``.
            """
            async with self:
                self._task_id += 1
            for _ in range(int(self.iterations)):
                async with self:
                    self.counter += 1
                await asyncio.sleep(0.005)

        @rx.event(background=True)
        async def handle_racy_event(self):
            """Run four ``racy_task`` coroutines concurrently."""
            await asyncio.gather(*(self.racy_task() for _ in range(4)))

        @rx.event(background=True)
        async def nested_async_with_self(self):
            """Check that re-entering ``async with self`` is rejected.

            The counter is incremented once in the outer block; the inner
            block is expected to raise ``ImmutableStateError``.
            """
            async with self:
                self.counter += 1
                with pytest.raises(ImmutableStateError):
                    async with self:
                        self.counter += 1

        async def triple_count(self):
            """Delegate to ``ThirdState._triple_count`` via ``get_state``."""
            third = await self.get_state(ThirdState)
            await third._triple_count()

        @rx.event(background=True)
        async def yield_in_async_with_self(self):
            """Increment, yield, then increment again within one block.

            Yields:
                A bare ``yield`` (no event) between the two increments.
            """
            async with self:
                self.counter += 1
                yield
                self.counter += 1

    class OtherState(rx.State):
        """Second state that reaches into ``State`` from a background task."""

        @rx.event(background=True)
        async def get_other_state(self):
            """Mutate ``State`` obtained through ``get_state``.

            Inside ``async with self`` the counter is incremented and tripled.
            Outside any block both operations are expected to raise
            ``ImmutableStateError``. Inside ``async with state`` they are
            performed once more.
            """
            async with self:
                state = await self.get_state(State)
                state.counter += 1
                await state.triple_count()

            # The same two operations outside of any `async with` block.
            with pytest.raises(ImmutableStateError):
                await state.triple_count()
            with pytest.raises(ImmutableStateError):
                state.counter += 1

            async with state:
                state.counter += 1
                await state.triple_count()

    class ThirdState(rx.State):
        """Third state; its only job is to triple ``State.counter``."""

        async def _triple_count(self):
            """Multiply ``State.counter`` by three."""
            state = await self.get_state(State)
            state.counter *= 3

    def index() -> rx.Component:
        """Build the test page.

        Returns:
            A vstack holding the token input, the counter heading, the
            iterations input and one button per handler under test.
        """
        token_field = rx.input(
            id="token", value=State.router.session.client_token, is_read_only=True
        )
        counter_heading = rx.heading(State.counter, id="counter")
        iterations_field = rx.input(
            id="iterations",
            placeholder="Iterations",
            value=State.iterations.to_string(),  # pyright: ignore [reportAttributeAccessIssue]
            on_change=State.set_iterations,  # pyright: ignore [reportAttributeAccessIssue]
        )
        # (label, on_click handler, element id), in display order.
        button_specs = [
            ("Delayed Increment", State.handle_event, "delayed-increment"),
            ("Yield Increment", State.handle_event_yield_only, "yield-increment"),
            ("Increment 1", State.increment, "increment"),
            ("Blocking Pause", State.blocking_pause, "blocking-pause"),
            ("Non-Blocking Pause", State.non_blocking_pause, "non-blocking-pause"),
            ("Racy Increment (x4)", State.handle_racy_event, "racy-increment"),
            (
                "Nested Async with Self",
                State.nested_async_with_self,
                "nested-async-with-self",
            ),
            (
                "Increment from OtherState",
                OtherState.get_other_state,
                "increment-from-other-state",
            ),
            (
                "Yield in Async with Self",
                State.yield_in_async_with_self,
                "yield-in-async-with-self",
            ),
            ("Reset", State.reset_counter, "reset"),
        ]
        buttons = [
            rx.button(label, on_click=handler, id=element_id)
            for label, handler, element_id in button_specs
        ]
        return rx.vstack(token_field, counter_heading, iterations_field, *buttons)

    app = rx.App(_state=rx.State)
    app.add_page(index)
@pytest.fixture(scope="module")
def background_task(
    tmp_path_factory,
) -> Generator[AppHarness, None, None]:
    """Run the BackgroundTask app in a temporary directory.

    Args:
        tmp_path_factory: pytest tmp_path_factory fixture

    Yields:
        The AppHarness created for the BackgroundTask app.
    """
    app_root = tmp_path_factory.mktemp("background_task")
    with AppHarness.create(root=app_root, app_source=BackgroundTask) as running_app:
        yield running_app
@pytest.fixture
def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
    """Open a browser on the background_task app and quit it afterwards.

    Args:
        background_task: harness for BackgroundTask app

    Yields:
        WebDriver instance.
    """
    assert background_task.app_instance is not None, "app is not running"
    browser = background_task.frontend()
    try:
        yield browser
    finally:
        # Quit the browser even when the requesting test failed.
        browser.quit()
@pytest.fixture()
def token(background_task: AppHarness, driver: WebDriver) -> str:
    """Read the client token from the page's token input.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.

    Returns:
        The token for the connected client
    """
    assert background_task.app_instance is not None
    token_field = driver.find_element(By.ID, "token")
    assert token_field

    # The input stays empty until the backend connection delivers the token,
    # so poll for it with twice the default timeout.
    client_token = background_task.poll_for_value(
        token_field, timeout=DEFAULT_TIMEOUT * 2
    )
    assert client_token is not None
    return client_token
def test_background_task(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Interleave background and regular events and check the final count.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    def by_id(element_id: str):
        return driver.find_element(By.ID, element_id)

    # Buttons, in the same lookup order as before.
    delayed = by_id("delayed-increment")
    yielding = by_id("yield-increment")
    increment = by_id("increment")
    blocking = by_id("blocking-pause")
    non_blocking = by_id("non-blocking-pause")
    racy = by_id("racy-increment")
    by_id("reset")  # looked up but never clicked; the result is unused

    counter = by_id("counter")
    iterations_input = by_id("iterations")

    # Every looping handler runs 50 iterations.
    iterations_input.clear()
    iterations_input.send_keys("50")

    # Kick off the background tasks, interleaved with regular events.
    click_order = [
        delayed,
        blocking,
        delayed,
        *([increment] * 10),
        blocking,
        delayed,
        delayed,
        yielding,
        racy,
        non_blocking,
        yielding,
        blocking,
        yielding,
        *([increment] * 10),
        yielding,
        blocking,
    ]
    for button in click_order:
        button.click()

    # Expected total: 4 delayed clicks * 50 + 4 yield clicks * 50
    # + 1 racy click * 4 tasks * 50 + 20 plain increments = 620.
    def counter_reached_total() -> bool:
        return counter.text == "620"

    assert background_task._poll_for(counter_reached_total, timeout=40)

    # all tasks should have exited and cleaned up
    def no_tasks_left() -> bool:
        return not background_task.app_instance._background_tasks  # pyright: ignore [reportOptionalMemberAccess]

    assert background_task._poll_for(no_tasks_left)
def test_nested_async_with_self(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Check that a nested ``async with self`` adds only one to the counter.

    The handler increments once in the outer block and expects the inner
    block to raise, so the counter goes 0 -> 1; a plain increment afterwards
    takes it to 2.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    nested_button = driver.find_element(By.ID, "nested-async-with-self")
    increment_button = driver.find_element(By.ID, "increment")
    counter = driver.find_element(By.ID, "counter")

    def counter_shows(expected: str):
        return background_task._poll_for(lambda: counter.text == expected, timeout=5)

    assert counter_shows("0")
    # (button to click, counter text expected afterwards)
    for button, expected in ((nested_button, "1"), (increment_button, "2")):
        button.click()
        assert counter_shows(expected)
def test_get_state(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Check mutations made through a state obtained via ``get_state``.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    other_state_button = driver.find_element(By.ID, "increment-from-other-state")
    increment_button = driver.find_element(By.ID, "increment")
    counter = driver.find_element(By.ID, "counter")

    def counter_shows(expected: str):
        return background_task._poll_for(lambda: counter.text == expected, timeout=5)

    assert counter_shows("0")
    # get_other_state: (0 + 1) * 3 = 3, then (3 + 1) * 3 = 12.
    # The plain increment afterwards gives 13.
    for button, expected in ((other_state_button, "12"), (increment_button, "13")):
        button.click()
        assert counter_shows(expected)
def test_yield_in_async_with_self(
    background_task: AppHarness,
    driver: WebDriver,
    token: str,
):
    """Check that the state stays mutable after a yield in ``async with self``.

    The handler increments before and after the yield, so one click must take
    the counter from 0 to 2.

    Args:
        background_task: harness for BackgroundTask app.
        driver: WebDriver instance.
        token: The token for the connected client.
    """
    assert background_task.app_instance is not None

    trigger = driver.find_element(By.ID, "yield-in-async-with-self")
    counter = driver.find_element(By.ID, "counter")

    def counter_shows(expected: str):
        return background_task._poll_for(lambda: counter.text == expected, timeout=5)

    assert counter_shows("0")
    trigger.click()
    assert counter_shows("2")