rx.background and StateManager.modify_state provides safe exclusive access to state (#1676)
This commit is contained in:
parent
211dc15995
commit
351611ca25
17
.github/workflows/integration_app_harness.yml
vendored
17
.github/workflows/integration_app_harness.yml
vendored
@ -15,7 +15,23 @@ permissions:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
integration-app-harness:
|
integration-app-harness:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
state_manager: [ "redis", "memory" ]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
services:
|
||||||
|
# Label used to access the service container
|
||||||
|
redis:
|
||||||
|
image: ${{ matrix.state_manager == 'redis' && 'redis' || '' }}
|
||||||
|
# Set health checks to wait until redis has started
|
||||||
|
options: >-
|
||||||
|
--health-cmd "redis-cli ping"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
ports:
|
||||||
|
# Maps port 6379 on service container to the host
|
||||||
|
- 6379:6379
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: ./.github/actions/setup_build_env
|
- uses: ./.github/actions/setup_build_env
|
||||||
@ -27,6 +43,7 @@ jobs:
|
|||||||
- name: Run app harness tests
|
- name: Run app harness tests
|
||||||
env:
|
env:
|
||||||
SCREENSHOT_DIR: /tmp/screenshots
|
SCREENSHOT_DIR: /tmp/screenshots
|
||||||
|
REDIS_URL: ${{ matrix.state_manager == 'redis' && 'localhost:6379' || '' }}
|
||||||
run: |
|
run: |
|
||||||
poetry run pytest integration
|
poetry run pytest integration
|
||||||
- uses: actions/upload-artifact@v3
|
- uses: actions/upload-artifact@v3
|
||||||
|
20
.github/workflows/unit_tests.yml
vendored
20
.github/workflows/unit_tests.yml
vendored
@ -40,6 +40,20 @@ jobs:
|
|||||||
- os: windows-latest
|
- os: windows-latest
|
||||||
python-version: "3.8.10"
|
python-version: "3.8.10"
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
# Service containers to run with `runner-job`
|
||||||
|
services:
|
||||||
|
# Label used to access the service container
|
||||||
|
redis:
|
||||||
|
image: ${{ matrix.os == 'ubuntu-latest' && 'redis' || '' }}
|
||||||
|
# Set health checks to wait until redis has started
|
||||||
|
options: >-
|
||||||
|
--health-cmd "redis-cli ping"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
ports:
|
||||||
|
# Maps port 6379 on service container to the host
|
||||||
|
- 6379:6379
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: ./.github/actions/setup_build_env
|
- uses: ./.github/actions/setup_build_env
|
||||||
@ -51,4 +65,10 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
export PYTHONUNBUFFERED=1
|
export PYTHONUNBUFFERED=1
|
||||||
poetry run pytest tests --cov --no-cov-on-fail --cov-report=
|
poetry run pytest tests --cov --no-cov-on-fail --cov-report=
|
||||||
|
- name: Run unit tests w/ redis
|
||||||
|
if: ${{ matrix.os == 'ubuntu-latest' }}
|
||||||
|
run: |
|
||||||
|
export PYTHONUNBUFFERED=1
|
||||||
|
export REDIS_URL=localhost:6379
|
||||||
|
poetry run pytest tests --cov --no-cov-on-fail --cov-report=
|
||||||
- run: poetry run coverage html
|
- run: poetry run coverage html
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
repos:
|
repos:
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 22.10.0
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
args: [integration, reflex, tests]
|
||||||
|
|
||||||
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
||||||
rev: v0.0.244
|
rev: v0.0.244
|
||||||
hooks:
|
hooks:
|
||||||
@ -17,9 +23,3 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: darglint
|
- id: darglint
|
||||||
exclude: '^reflex/reflex.py'
|
exclude: '^reflex/reflex.py'
|
||||||
|
|
||||||
- repo: https://github.com/psf/black
|
|
||||||
rev: 22.10.0
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
args: [integration, reflex, tests]
|
|
||||||
|
214
integration/test_background_task.py
Normal file
214
integration/test_background_task.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
"""Test @rx.background task functionality."""
|
||||||
|
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from selenium.webdriver.common.by import By
|
||||||
|
|
||||||
|
from reflex.testing import DEFAULT_TIMEOUT, AppHarness, WebDriver
|
||||||
|
|
||||||
|
|
||||||
|
def BackgroundTask():
|
||||||
|
"""Test that background tasks work as expected."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import reflex as rx
|
||||||
|
|
||||||
|
class State(rx.State):
|
||||||
|
counter: int = 0
|
||||||
|
_task_id: int = 0
|
||||||
|
iterations: int = 10
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def handle_event(self):
|
||||||
|
async with self:
|
||||||
|
self._task_id += 1
|
||||||
|
for _ix in range(int(self.iterations)):
|
||||||
|
async with self:
|
||||||
|
self.counter += 1
|
||||||
|
await asyncio.sleep(0.005)
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def handle_event_yield_only(self):
|
||||||
|
async with self:
|
||||||
|
self._task_id += 1
|
||||||
|
for ix in range(int(self.iterations)):
|
||||||
|
if ix % 2 == 0:
|
||||||
|
yield State.increment_arbitrary(1) # type: ignore
|
||||||
|
else:
|
||||||
|
yield State.increment() # type: ignore
|
||||||
|
await asyncio.sleep(0.005)
|
||||||
|
|
||||||
|
def increment(self):
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def increment_arbitrary(self, amount: int):
|
||||||
|
async with self:
|
||||||
|
self.counter += int(amount)
|
||||||
|
|
||||||
|
def reset_counter(self):
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
async def blocking_pause(self):
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def non_blocking_pause(self):
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
|
||||||
|
@rx.cached_var
|
||||||
|
def token(self) -> str:
|
||||||
|
return self.get_token()
|
||||||
|
|
||||||
|
def index() -> rx.Component:
|
||||||
|
return rx.vstack(
|
||||||
|
rx.input(id="token", value=State.token, is_read_only=True),
|
||||||
|
rx.heading(State.counter, id="counter"),
|
||||||
|
rx.input(
|
||||||
|
id="iterations",
|
||||||
|
placeholder="Iterations",
|
||||||
|
value=State.iterations.to_string(), # type: ignore
|
||||||
|
on_change=State.set_iterations, # type: ignore
|
||||||
|
),
|
||||||
|
rx.button(
|
||||||
|
"Delayed Increment",
|
||||||
|
on_click=State.handle_event,
|
||||||
|
id="delayed-increment",
|
||||||
|
),
|
||||||
|
rx.button(
|
||||||
|
"Yield Increment",
|
||||||
|
on_click=State.handle_event_yield_only,
|
||||||
|
id="yield-increment",
|
||||||
|
),
|
||||||
|
rx.button("Increment 1", on_click=State.increment, id="increment"),
|
||||||
|
rx.button(
|
||||||
|
"Blocking Pause",
|
||||||
|
on_click=State.blocking_pause,
|
||||||
|
id="blocking-pause",
|
||||||
|
),
|
||||||
|
rx.button(
|
||||||
|
"Non-Blocking Pause",
|
||||||
|
on_click=State.non_blocking_pause,
|
||||||
|
id="non-blocking-pause",
|
||||||
|
),
|
||||||
|
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||||
|
)
|
||||||
|
|
||||||
|
app = rx.App(state=State)
|
||||||
|
app.add_page(index)
|
||||||
|
app.compile()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def background_task(
|
||||||
|
tmp_path_factory,
|
||||||
|
) -> Generator[AppHarness, None, None]:
|
||||||
|
"""Start BackgroundTask app at tmp_path via AppHarness.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path_factory: pytest tmp_path_factory fixture
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
running AppHarness instance
|
||||||
|
"""
|
||||||
|
with AppHarness.create(
|
||||||
|
root=tmp_path_factory.mktemp(f"background_task"),
|
||||||
|
app_source=BackgroundTask, # type: ignore
|
||||||
|
) as harness:
|
||||||
|
yield harness
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def driver(background_task: AppHarness) -> Generator[WebDriver, None, None]:
|
||||||
|
"""Get an instance of the browser open to the background_task app.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
background_task: harness for BackgroundTask app
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
WebDriver instance.
|
||||||
|
"""
|
||||||
|
assert background_task.app_instance is not None, "app is not running"
|
||||||
|
driver = background_task.frontend()
|
||||||
|
try:
|
||||||
|
yield driver
|
||||||
|
finally:
|
||||||
|
driver.quit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def token(background_task: AppHarness, driver: WebDriver) -> str:
|
||||||
|
"""Get a function that returns the active token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
background_task: harness for BackgroundTask app.
|
||||||
|
driver: WebDriver instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The token for the connected client
|
||||||
|
"""
|
||||||
|
assert background_task.app_instance is not None
|
||||||
|
token_input = driver.find_element(By.ID, "token")
|
||||||
|
assert token_input
|
||||||
|
|
||||||
|
# wait for the backend connection to send the token
|
||||||
|
token = background_task.poll_for_value(token_input, timeout=DEFAULT_TIMEOUT * 2)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def test_background_task(
|
||||||
|
background_task: AppHarness,
|
||||||
|
driver: WebDriver,
|
||||||
|
token: str,
|
||||||
|
):
|
||||||
|
"""Test that background tasks work as expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
background_task: harness for BackgroundTask app.
|
||||||
|
driver: WebDriver instance.
|
||||||
|
token: The token for the connected client.
|
||||||
|
"""
|
||||||
|
assert background_task.app_instance is not None
|
||||||
|
|
||||||
|
# get a reference to all buttons
|
||||||
|
delayed_increment_button = driver.find_element(By.ID, "delayed-increment")
|
||||||
|
yield_increment_button = driver.find_element(By.ID, "yield-increment")
|
||||||
|
increment_button = driver.find_element(By.ID, "increment")
|
||||||
|
blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
|
||||||
|
non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
|
||||||
|
driver.find_element(By.ID, "reset")
|
||||||
|
|
||||||
|
# get a reference to the counter
|
||||||
|
counter = driver.find_element(By.ID, "counter")
|
||||||
|
|
||||||
|
# get a reference to the iterations input
|
||||||
|
iterations_input = driver.find_element(By.ID, "iterations")
|
||||||
|
|
||||||
|
# kick off background tasks
|
||||||
|
iterations_input.clear()
|
||||||
|
iterations_input.send_keys("50")
|
||||||
|
delayed_increment_button.click()
|
||||||
|
blocking_pause_button.click()
|
||||||
|
delayed_increment_button.click()
|
||||||
|
for _ in range(10):
|
||||||
|
increment_button.click()
|
||||||
|
blocking_pause_button.click()
|
||||||
|
delayed_increment_button.click()
|
||||||
|
delayed_increment_button.click()
|
||||||
|
yield_increment_button.click()
|
||||||
|
non_blocking_pause_button.click()
|
||||||
|
yield_increment_button.click()
|
||||||
|
blocking_pause_button.click()
|
||||||
|
yield_increment_button.click()
|
||||||
|
for _ in range(10):
|
||||||
|
increment_button.click()
|
||||||
|
yield_increment_button.click()
|
||||||
|
blocking_pause_button.click()
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
|
||||||
|
# all tasks should have exited and cleaned up
|
||||||
|
assert background_task._poll_for(
|
||||||
|
lambda: not background_task.app_instance.background_tasks # type: ignore
|
||||||
|
)
|
@ -133,7 +133,6 @@ def driver(client_side: AppHarness) -> Generator[WebDriver, None, None]:
|
|||||||
assert client_side.app_instance is not None, "app is not running"
|
assert client_side.app_instance is not None, "app is not running"
|
||||||
driver = client_side.frontend()
|
driver = client_side.frontend()
|
||||||
try:
|
try:
|
||||||
assert client_side.poll_for_clients()
|
|
||||||
yield driver
|
yield driver
|
||||||
finally:
|
finally:
|
||||||
driver.quit()
|
driver.quit()
|
||||||
@ -168,7 +167,20 @@ def delete_all_cookies(driver: WebDriver) -> Generator[None, None, None]:
|
|||||||
driver.delete_all_cookies()
|
driver.delete_all_cookies()
|
||||||
|
|
||||||
|
|
||||||
def test_client_side_state(
|
def cookie_info_map(driver: WebDriver) -> dict[str, dict[str, str]]:
|
||||||
|
"""Get a map of cookie names to cookie info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
driver: WebDriver instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A map of cookie names to cookie info.
|
||||||
|
"""
|
||||||
|
return {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_side_state(
|
||||||
client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
|
client_side: AppHarness, driver: WebDriver, local_storage: utils.LocalStorage
|
||||||
):
|
):
|
||||||
"""Test client side state.
|
"""Test client side state.
|
||||||
@ -187,8 +199,6 @@ def test_client_side_state(
|
|||||||
token = client_side.poll_for_value(token_input)
|
token = client_side.poll_for_value(token_input)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
backend_state = client_side.app_instance.state_manager.states[token]
|
|
||||||
|
|
||||||
# get a reference to the cookie manipulation form
|
# get a reference to the cookie manipulation form
|
||||||
state_var_input = driver.find_element(By.ID, "state_var")
|
state_var_input = driver.find_element(By.ID, "state_var")
|
||||||
input_value_input = driver.find_element(By.ID, "input_value")
|
input_value_input = driver.find_element(By.ID, "input_value")
|
||||||
@ -274,7 +284,7 @@ def test_client_side_state(
|
|||||||
input_value_input.send_keys("l1s value")
|
input_value_input.send_keys("l1s value")
|
||||||
set_sub_sub_state_button.click()
|
set_sub_sub_state_button.click()
|
||||||
|
|
||||||
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
cookies = cookie_info_map(driver)
|
||||||
assert cookies.pop("client_side_state.client_side_sub_state.c1") == {
|
assert cookies.pop("client_side_state.client_side_sub_state.c1") == {
|
||||||
"domain": "localhost",
|
"domain": "localhost",
|
||||||
"httpOnly": False,
|
"httpOnly": False,
|
||||||
@ -338,8 +348,10 @@ def test_client_side_state(
|
|||||||
state_var_input.send_keys("c3")
|
state_var_input.send_keys("c3")
|
||||||
input_value_input.send_keys("c3 value")
|
input_value_input.send_keys("c3 value")
|
||||||
set_sub_state_button.click()
|
set_sub_state_button.click()
|
||||||
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
AppHarness._poll_for(
|
||||||
c3_cookie = cookies["client_side_state.client_side_sub_state.c3"]
|
lambda: "client_side_state.client_side_sub_state.c3" in cookie_info_map(driver)
|
||||||
|
)
|
||||||
|
c3_cookie = cookie_info_map(driver)["client_side_state.client_side_sub_state.c3"]
|
||||||
assert c3_cookie.pop("expiry") is not None
|
assert c3_cookie.pop("expiry") is not None
|
||||||
assert c3_cookie == {
|
assert c3_cookie == {
|
||||||
"domain": "localhost",
|
"domain": "localhost",
|
||||||
@ -351,9 +363,7 @@ def test_client_side_state(
|
|||||||
"value": "c3%20value",
|
"value": "c3%20value",
|
||||||
}
|
}
|
||||||
time.sleep(2) # wait for c3 to expire
|
time.sleep(2) # wait for c3 to expire
|
||||||
assert "client_side_state.client_side_sub_state.c3" not in {
|
assert "client_side_state.client_side_sub_state.c3" not in cookie_info_map(driver)
|
||||||
cookie_info["name"] for cookie_info in driver.get_cookies()
|
|
||||||
}
|
|
||||||
|
|
||||||
local_storage_items = local_storage.items()
|
local_storage_items = local_storage.items()
|
||||||
local_storage_items.pop("chakra-ui-color-mode", None)
|
local_storage_items.pop("chakra-ui-color-mode", None)
|
||||||
@ -426,7 +436,8 @@ def test_client_side_state(
|
|||||||
assert l1s.text == "l1s value"
|
assert l1s.text == "l1s value"
|
||||||
|
|
||||||
# reset the backend state to force refresh from client storage
|
# reset the backend state to force refresh from client storage
|
||||||
backend_state.reset()
|
async with client_side.modify_state(token) as state:
|
||||||
|
state.reset()
|
||||||
driver.refresh()
|
driver.refresh()
|
||||||
|
|
||||||
# wait for the backend connection to send the token (again)
|
# wait for the backend connection to send the token (again)
|
||||||
@ -465,9 +476,7 @@ def test_client_side_state(
|
|||||||
assert l1s.text == "l1s value"
|
assert l1s.text == "l1s value"
|
||||||
|
|
||||||
# make sure c5 cookie shows up on the `/foo` route
|
# make sure c5 cookie shows up on the `/foo` route
|
||||||
cookies = {cookie_info["name"]: cookie_info for cookie_info in driver.get_cookies()}
|
assert cookie_info_map(driver)["client_side_state.client_side_sub_state.c5"] == {
|
||||||
|
|
||||||
assert cookies["client_side_state.client_side_sub_state.c5"] == {
|
|
||||||
"domain": "localhost",
|
"domain": "localhost",
|
||||||
"httpOnly": False,
|
"httpOnly": False,
|
||||||
"name": "client_side_state.client_side_sub_state.c5",
|
"name": "client_side_state.client_side_sub_state.c5",
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
"""Integration tests for dynamic route page behavior."""
|
"""Integration tests for dynamic route page behavior."""
|
||||||
from typing import Callable, Generator, Type
|
from typing import Callable, Coroutine, Generator, Type
|
||||||
from urllib.parse import urlsplit
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
|
|
||||||
from reflex import State
|
|
||||||
from reflex.testing import AppHarness, AppHarnessProd, WebDriver
|
from reflex.testing import AppHarness, AppHarnessProd, WebDriver
|
||||||
|
|
||||||
from .utils import poll_for_navigation
|
from .utils import poll_for_navigation
|
||||||
@ -100,22 +99,21 @@ def driver(dynamic_route: AppHarness) -> Generator[WebDriver, None, None]:
|
|||||||
assert dynamic_route.app_instance is not None, "app is not running"
|
assert dynamic_route.app_instance is not None, "app is not running"
|
||||||
driver = dynamic_route.frontend()
|
driver = dynamic_route.frontend()
|
||||||
try:
|
try:
|
||||||
assert dynamic_route.poll_for_clients()
|
|
||||||
yield driver
|
yield driver
|
||||||
finally:
|
finally:
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
|
def token(dynamic_route: AppHarness, driver: WebDriver) -> str:
|
||||||
"""Get the backend state.
|
"""Get the token associated with backend state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dynamic_route: harness for DynamicRoute app.
|
dynamic_route: harness for DynamicRoute app.
|
||||||
driver: WebDriver instance.
|
driver: WebDriver instance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The backend state associated with the token visible in the driver browser.
|
The token visible in the driver browser.
|
||||||
"""
|
"""
|
||||||
assert dynamic_route.app_instance is not None
|
assert dynamic_route.app_instance is not None
|
||||||
token_input = driver.find_element(By.ID, "token")
|
token_input = driver.find_element(By.ID, "token")
|
||||||
@ -125,43 +123,49 @@ def backend_state(dynamic_route: AppHarness, driver: WebDriver) -> State:
|
|||||||
token = dynamic_route.poll_for_value(token_input)
|
token = dynamic_route.poll_for_value(token_input)
|
||||||
assert token is not None
|
assert token is not None
|
||||||
|
|
||||||
# look up the backend state from the state manager
|
return token
|
||||||
return dynamic_route.app_instance.state_manager.states[token]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def poll_for_order(
|
def poll_for_order(
|
||||||
dynamic_route: AppHarness, backend_state: State
|
dynamic_route: AppHarness, token: str
|
||||||
) -> Callable[[list[str]], None]:
|
) -> Callable[[list[str]], Coroutine[None, None, None]]:
|
||||||
"""Poll for the order list to match the expected order.
|
"""Poll for the order list to match the expected order.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dynamic_route: harness for DynamicRoute app.
|
dynamic_route: harness for DynamicRoute app.
|
||||||
backend_state: The backend state associated with the token visible in the driver browser.
|
token: The token visible in the driver browser.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A function that polls for the order list to match the expected order.
|
An async function that polls for the order list to match the expected order.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _poll_for_order(exp_order: list[str]):
|
async def _poll_for_order(exp_order: list[str]):
|
||||||
dynamic_route._poll_for(lambda: backend_state.order == exp_order)
|
async def _backend_state():
|
||||||
assert backend_state.order == exp_order
|
return await dynamic_route.get_state(token)
|
||||||
|
|
||||||
|
async def _check():
|
||||||
|
return (await _backend_state()).order == exp_order
|
||||||
|
|
||||||
|
await AppHarness._poll_for_async(_check)
|
||||||
|
assert (await _backend_state()).order == exp_order
|
||||||
|
|
||||||
return _poll_for_order
|
return _poll_for_order
|
||||||
|
|
||||||
|
|
||||||
def test_on_load_navigate(
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_load_navigate(
|
||||||
dynamic_route: AppHarness,
|
dynamic_route: AppHarness,
|
||||||
driver: WebDriver,
|
driver: WebDriver,
|
||||||
backend_state: State,
|
token: str,
|
||||||
poll_for_order: Callable[[list[str]], None],
|
poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
|
||||||
):
|
):
|
||||||
"""Click links to navigate between dynamic pages with on_load event.
|
"""Click links to navigate between dynamic pages with on_load event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dynamic_route: harness for DynamicRoute app.
|
dynamic_route: harness for DynamicRoute app.
|
||||||
driver: WebDriver instance.
|
driver: WebDriver instance.
|
||||||
backend_state: The backend state associated with the token visible in the driver browser.
|
token: The token visible in the driver browser.
|
||||||
poll_for_order: function that polls for the order list to match the expected order.
|
poll_for_order: function that polls for the order list to match the expected order.
|
||||||
"""
|
"""
|
||||||
assert dynamic_route.app_instance is not None
|
assert dynamic_route.app_instance is not None
|
||||||
@ -184,7 +188,7 @@ def test_on_load_navigate(
|
|||||||
assert page_id_input
|
assert page_id_input
|
||||||
|
|
||||||
assert dynamic_route.poll_for_value(page_id_input) == str(ix)
|
assert dynamic_route.poll_for_value(page_id_input) == str(ix)
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
|
|
||||||
# manually load the next page to trigger client side routing in prod mode
|
# manually load the next page to trigger client side routing in prod mode
|
||||||
if is_prod:
|
if is_prod:
|
||||||
@ -192,14 +196,14 @@ def test_on_load_navigate(
|
|||||||
exp_order += ["/page/[page_id]-10"]
|
exp_order += ["/page/[page_id]-10"]
|
||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
driver.get(f"{dynamic_route.frontend_url}/page/10/")
|
driver.get(f"{dynamic_route.frontend_url}/page/10/")
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
|
|
||||||
# make sure internal nav still hydrates after redirect
|
# make sure internal nav still hydrates after redirect
|
||||||
exp_order += ["/page/[page_id]-11"]
|
exp_order += ["/page/[page_id]-11"]
|
||||||
link = driver.find_element(By.ID, "link_page_next")
|
link = driver.find_element(By.ID, "link_page_next")
|
||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
link.click()
|
link.click()
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
|
|
||||||
# load same page with a query param and make sure it passes through
|
# load same page with a query param and make sure it passes through
|
||||||
if is_prod:
|
if is_prod:
|
||||||
@ -207,14 +211,14 @@ def test_on_load_navigate(
|
|||||||
exp_order += ["/page/[page_id]-11"]
|
exp_order += ["/page/[page_id]-11"]
|
||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
driver.get(f"{driver.current_url}?foo=bar")
|
driver.get(f"{driver.current_url}?foo=bar")
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
assert backend_state.get_query_params()["foo"] == "bar"
|
assert (await dynamic_route.get_state(token)).get_query_params()["foo"] == "bar"
|
||||||
|
|
||||||
# hit a 404 and ensure we still hydrate
|
# hit a 404 and ensure we still hydrate
|
||||||
exp_order += ["/404-no page id"]
|
exp_order += ["/404-no page id"]
|
||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
driver.get(f"{dynamic_route.frontend_url}/missing")
|
driver.get(f"{dynamic_route.frontend_url}/missing")
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
|
|
||||||
# browser nav should still trigger hydration
|
# browser nav should still trigger hydration
|
||||||
if is_prod:
|
if is_prod:
|
||||||
@ -222,14 +226,14 @@ def test_on_load_navigate(
|
|||||||
exp_order += ["/page/[page_id]-11"]
|
exp_order += ["/page/[page_id]-11"]
|
||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
driver.back()
|
driver.back()
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
|
|
||||||
# next/link to a 404 and ensure we still hydrate
|
# next/link to a 404 and ensure we still hydrate
|
||||||
exp_order += ["/404-no page id"]
|
exp_order += ["/404-no page id"]
|
||||||
link = driver.find_element(By.ID, "link_missing")
|
link = driver.find_element(By.ID, "link_missing")
|
||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
link.click()
|
link.click()
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
|
|
||||||
# hit a page that redirects back to dynamic page
|
# hit a page that redirects back to dynamic page
|
||||||
if is_prod:
|
if is_prod:
|
||||||
@ -237,15 +241,16 @@ def test_on_load_navigate(
|
|||||||
exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"]
|
exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"]
|
||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
|
driver.get(f"{dynamic_route.frontend_url}/redirect-page/0/?foo=bar")
|
||||||
poll_for_order(exp_order)
|
await poll_for_order(exp_order)
|
||||||
# should have redirected back to page 0
|
# should have redirected back to page 0
|
||||||
assert urlsplit(driver.current_url).path == "/page/0/"
|
assert urlsplit(driver.current_url).path == "/page/0/"
|
||||||
|
|
||||||
|
|
||||||
def test_on_load_navigate_non_dynamic(
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_load_navigate_non_dynamic(
|
||||||
dynamic_route: AppHarness,
|
dynamic_route: AppHarness,
|
||||||
driver: WebDriver,
|
driver: WebDriver,
|
||||||
poll_for_order: Callable[[list[str]], None],
|
poll_for_order: Callable[[list[str]], Coroutine[None, None, None]],
|
||||||
):
|
):
|
||||||
"""Click links to navigate between static pages with on_load event.
|
"""Click links to navigate between static pages with on_load event.
|
||||||
|
|
||||||
@ -261,7 +266,7 @@ def test_on_load_navigate_non_dynamic(
|
|||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
link.click()
|
link.click()
|
||||||
assert urlsplit(driver.current_url).path == "/static/x/"
|
assert urlsplit(driver.current_url).path == "/static/x/"
|
||||||
poll_for_order(["/static/x-no page id"])
|
await poll_for_order(["/static/x-no page id"])
|
||||||
|
|
||||||
# go back to the index and navigate back to the static route
|
# go back to the index and navigate back to the static route
|
||||||
link = driver.find_element(By.ID, "link_index")
|
link = driver.find_element(By.ID, "link_index")
|
||||||
@ -273,4 +278,4 @@ def test_on_load_navigate_non_dynamic(
|
|||||||
with poll_for_navigation(driver):
|
with poll_for_navigation(driver):
|
||||||
link.click()
|
link.click()
|
||||||
assert urlsplit(driver.current_url).path == "/static/x/"
|
assert urlsplit(driver.current_url).path == "/static/x/"
|
||||||
poll_for_order(["/static/x-no page id", "/static/x-no page id"])
|
await poll_for_order(["/static/x-no page id", "/static/x-no page id"])
|
||||||
|
@ -1,18 +1,20 @@
|
|||||||
"""Ensure that Event Chains are properly queued and handled between frontend and backend."""
|
"""Ensure that Event Chains are properly queued and handled between frontend and backend."""
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from selenium.webdriver.common.by import By
|
from selenium.webdriver.common.by import By
|
||||||
|
|
||||||
from reflex.testing import AppHarness
|
from reflex.testing import AppHarness, WebDriver
|
||||||
|
|
||||||
MANY_EVENTS = 50
|
MANY_EVENTS = 50
|
||||||
|
|
||||||
|
|
||||||
def EventChain():
|
def EventChain():
|
||||||
"""App with chained event handlers."""
|
"""App with chained event handlers."""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
import reflex as rx
|
import reflex as rx
|
||||||
|
|
||||||
# repeated here since the outer global isn't exported into the App module
|
# repeated here since the outer global isn't exported into the App module
|
||||||
@ -20,6 +22,7 @@ def EventChain():
|
|||||||
|
|
||||||
class State(rx.State):
|
class State(rx.State):
|
||||||
event_order: list[str] = []
|
event_order: list[str] = []
|
||||||
|
interim_value: str = ""
|
||||||
|
|
||||||
@rx.var
|
@rx.var
|
||||||
def token(self) -> str:
|
def token(self) -> str:
|
||||||
@ -111,12 +114,25 @@ def EventChain():
|
|||||||
self.event_order.append("click_return_dict_type")
|
self.event_order.append("click_return_dict_type")
|
||||||
return State.event_arg_repr_type({"a": 1}) # type: ignore
|
return State.event_arg_repr_type({"a": 1}) # type: ignore
|
||||||
|
|
||||||
|
async def click_yield_interim_value_async(self):
|
||||||
|
self.interim_value = "interim"
|
||||||
|
yield
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
self.interim_value = "final"
|
||||||
|
|
||||||
|
def click_yield_interim_value(self):
|
||||||
|
self.interim_value = "interim"
|
||||||
|
yield
|
||||||
|
time.sleep(0.5)
|
||||||
|
self.interim_value = "final"
|
||||||
|
|
||||||
app = rx.App(state=State)
|
app = rx.App(state=State)
|
||||||
|
|
||||||
@app.add_page
|
@app.add_page
|
||||||
def index():
|
def index():
|
||||||
return rx.fragment(
|
return rx.fragment(
|
||||||
rx.input(value=State.token, readonly=True, id="token"),
|
rx.input(value=State.token, readonly=True, id="token"),
|
||||||
|
rx.input(value=State.interim_value, readonly=True, id="interim_value"),
|
||||||
rx.button(
|
rx.button(
|
||||||
"Return Event",
|
"Return Event",
|
||||||
id="return_event",
|
id="return_event",
|
||||||
@ -172,6 +188,16 @@ def EventChain():
|
|||||||
id="return_dict_type",
|
id="return_dict_type",
|
||||||
on_click=State.click_return_dict_type,
|
on_click=State.click_return_dict_type,
|
||||||
),
|
),
|
||||||
|
rx.button(
|
||||||
|
"Click Yield Interim Value (Async)",
|
||||||
|
id="click_yield_interim_value_async",
|
||||||
|
on_click=State.click_yield_interim_value_async,
|
||||||
|
),
|
||||||
|
rx.button(
|
||||||
|
"Click Yield Interim Value",
|
||||||
|
id="click_yield_interim_value",
|
||||||
|
on_click=State.click_yield_interim_value,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_load_return_chain():
|
def on_load_return_chain():
|
||||||
@ -237,7 +263,7 @@ def event_chain(tmp_path_factory) -> Generator[AppHarness, None, None]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def driver(event_chain: AppHarness):
|
def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]:
|
||||||
"""Get an instance of the browser open to the event_chain app.
|
"""Get an instance of the browser open to the event_chain app.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -249,7 +275,6 @@ def driver(event_chain: AppHarness):
|
|||||||
assert event_chain.app_instance is not None, "app is not running"
|
assert event_chain.app_instance is not None, "app is not running"
|
||||||
driver = event_chain.frontend()
|
driver = event_chain.frontend()
|
||||||
try:
|
try:
|
||||||
assert event_chain.poll_for_clients()
|
|
||||||
yield driver
|
yield driver
|
||||||
finally:
|
finally:
|
||||||
driver.quit()
|
driver.quit()
|
||||||
@ -335,7 +360,13 @@ def driver(event_chain: AppHarness):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_chain_click(
|
||||||
|
event_chain: AppHarness,
|
||||||
|
driver: WebDriver,
|
||||||
|
button_id: str,
|
||||||
|
exp_event_order: list[str],
|
||||||
|
):
|
||||||
"""Click the button, assert that the events are handled in the correct order.
|
"""Click the button, assert that the events are handled in the correct order.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -350,17 +381,18 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
|
|||||||
assert btn
|
assert btn
|
||||||
|
|
||||||
token = event_chain.poll_for_value(token_input)
|
token = event_chain.poll_for_value(token_input)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
btn.click()
|
btn.click()
|
||||||
if "redirect" in button_id:
|
|
||||||
# wait a bit longer if we're redirecting
|
async def _has_all_events():
|
||||||
time.sleep(1)
|
return len((await event_chain.get_state(token)).event_order) == len(
|
||||||
if "many_events" in button_id:
|
exp_event_order
|
||||||
# wait a bit longer if we have loads of events
|
)
|
||||||
time.sleep(1)
|
|
||||||
time.sleep(0.5)
|
await AppHarness._poll_for_async(_has_all_events)
|
||||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
event_order = (await event_chain.get_state(token)).event_order
|
||||||
assert backend_state.event_order == exp_event_order
|
assert event_order == exp_event_order
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -386,7 +418,13 @@ def test_event_chain_click(event_chain, driver, button_id, exp_event_order):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_chain_on_load(
|
||||||
|
event_chain: AppHarness,
|
||||||
|
driver: WebDriver,
|
||||||
|
uri: str,
|
||||||
|
exp_event_order: list[str],
|
||||||
|
):
|
||||||
"""Load the URI, assert that the events are handled in the correct order.
|
"""Load the URI, assert that the events are handled in the correct order.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -395,16 +433,23 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
|
|||||||
uri: the page to load
|
uri: the page to load
|
||||||
exp_event_order: the expected events recorded in the State
|
exp_event_order: the expected events recorded in the State
|
||||||
"""
|
"""
|
||||||
|
assert event_chain.frontend_url is not None
|
||||||
driver.get(event_chain.frontend_url + uri)
|
driver.get(event_chain.frontend_url + uri)
|
||||||
token_input = driver.find_element(By.ID, "token")
|
token_input = driver.find_element(By.ID, "token")
|
||||||
assert token_input
|
assert token_input
|
||||||
|
|
||||||
token = event_chain.poll_for_value(token_input)
|
token = event_chain.poll_for_value(token_input)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
time.sleep(0.5)
|
async def _has_all_events():
|
||||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
return len((await event_chain.get_state(token)).event_order) == len(
|
||||||
assert backend_state.is_hydrated is True
|
exp_event_order
|
||||||
|
)
|
||||||
|
|
||||||
|
await AppHarness._poll_for_async(_has_all_events)
|
||||||
|
backend_state = await event_chain.get_state(token)
|
||||||
assert backend_state.event_order == exp_event_order
|
assert backend_state.event_order == exp_event_order
|
||||||
|
assert backend_state.is_hydrated is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -444,7 +489,13 @@ def test_event_chain_on_load(event_chain, driver, uri, exp_event_order):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
|
@pytest.mark.asyncio
|
||||||
|
async def test_event_chain_on_mount(
|
||||||
|
event_chain: AppHarness,
|
||||||
|
driver: WebDriver,
|
||||||
|
uri: str,
|
||||||
|
exp_event_order: list[str],
|
||||||
|
):
|
||||||
"""Load the URI, assert that the events are handled in the correct order.
|
"""Load the URI, assert that the events are handled in the correct order.
|
||||||
|
|
||||||
These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
|
These pages use `on_mount` and `on_unmount`, which get fired twice in dev mode
|
||||||
@ -458,16 +509,53 @@ def test_event_chain_on_mount(event_chain, driver, uri, exp_event_order):
|
|||||||
uri: the page to load
|
uri: the page to load
|
||||||
exp_event_order: the expected events recorded in the State
|
exp_event_order: the expected events recorded in the State
|
||||||
"""
|
"""
|
||||||
|
assert event_chain.frontend_url is not None
|
||||||
driver.get(event_chain.frontend_url + uri)
|
driver.get(event_chain.frontend_url + uri)
|
||||||
token_input = driver.find_element(By.ID, "token")
|
token_input = driver.find_element(By.ID, "token")
|
||||||
assert token_input
|
assert token_input
|
||||||
|
|
||||||
token = event_chain.poll_for_value(token_input)
|
token = event_chain.poll_for_value(token_input)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
unmount_button = driver.find_element(By.ID, "unmount")
|
unmount_button = driver.find_element(By.ID, "unmount")
|
||||||
assert unmount_button
|
assert unmount_button
|
||||||
unmount_button.click()
|
unmount_button.click()
|
||||||
|
|
||||||
time.sleep(1)
|
async def _has_all_events():
|
||||||
backend_state = event_chain.app_instance.state_manager.states[token]
|
return len((await event_chain.get_state(token)).event_order) == len(
|
||||||
assert backend_state.event_order == exp_event_order
|
exp_event_order
|
||||||
|
)
|
||||||
|
|
||||||
|
await AppHarness._poll_for_async(_has_all_events)
|
||||||
|
event_order = (await event_chain.get_state(token)).event_order
|
||||||
|
assert event_order == exp_event_order
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("button_id",),
|
||||||
|
[
|
||||||
|
("click_yield_interim_value_async",),
|
||||||
|
("click_yield_interim_value",),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str):
|
||||||
|
"""Click the button, assert that the interim value is set, then final value is set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_chain: AppHarness for the event_chain app
|
||||||
|
driver: selenium WebDriver open to the app
|
||||||
|
button_id: the ID of the button to click
|
||||||
|
"""
|
||||||
|
token_input = driver.find_element(By.ID, "token")
|
||||||
|
interim_value_input = driver.find_element(By.ID, "interim_value")
|
||||||
|
assert event_chain.poll_for_value(token_input)
|
||||||
|
|
||||||
|
btn = driver.find_element(By.ID, button_id)
|
||||||
|
btn.click()
|
||||||
|
assert (
|
||||||
|
event_chain.poll_for_value(interim_value_input, exp_not_equal="") == "interim"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
event_chain.poll_for_value(interim_value_input, exp_not_equal="interim")
|
||||||
|
== "final"
|
||||||
|
)
|
||||||
|
@ -19,11 +19,16 @@ def FormSubmit():
|
|||||||
def form_submit(self, form_data: dict):
|
def form_submit(self, form_data: dict):
|
||||||
self.form_data = form_data
|
self.form_data = form_data
|
||||||
|
|
||||||
|
@rx.var
|
||||||
|
def token(self) -> str:
|
||||||
|
return self.get_token()
|
||||||
|
|
||||||
app = rx.App(state=FormState)
|
app = rx.App(state=FormState)
|
||||||
|
|
||||||
@app.add_page
|
@app.add_page
|
||||||
def index():
|
def index():
|
||||||
return rx.vstack(
|
return rx.vstack(
|
||||||
|
rx.input(value=FormState.token, is_read_only=True, id="token"),
|
||||||
rx.form(
|
rx.form(
|
||||||
rx.vstack(
|
rx.vstack(
|
||||||
rx.input(id="name_input"),
|
rx.input(id="name_input"),
|
||||||
@ -82,13 +87,13 @@ def driver(form_submit: AppHarness):
|
|||||||
"""
|
"""
|
||||||
driver = form_submit.frontend()
|
driver = form_submit.frontend()
|
||||||
try:
|
try:
|
||||||
assert form_submit.poll_for_clients()
|
|
||||||
yield driver
|
yield driver
|
||||||
finally:
|
finally:
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
|
||||||
|
|
||||||
def test_submit(driver, form_submit: AppHarness):
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit(driver, form_submit: AppHarness):
|
||||||
"""Fill a form with various different output, submit it to backend and verify
|
"""Fill a form with various different output, submit it to backend and verify
|
||||||
the output.
|
the output.
|
||||||
|
|
||||||
@ -97,7 +102,14 @@ def test_submit(driver, form_submit: AppHarness):
|
|||||||
form_submit: harness for FormSubmit app
|
form_submit: harness for FormSubmit app
|
||||||
"""
|
"""
|
||||||
assert form_submit.app_instance is not None, "app is not running"
|
assert form_submit.app_instance is not None, "app is not running"
|
||||||
_, backend_state = list(form_submit.app_instance.state_manager.states.items())[0]
|
|
||||||
|
# get a reference to the connected client
|
||||||
|
token_input = driver.find_element(By.ID, "token")
|
||||||
|
assert token_input
|
||||||
|
|
||||||
|
# wait for the backend connection to send the token
|
||||||
|
token = form_submit.poll_for_value(token_input)
|
||||||
|
assert token
|
||||||
|
|
||||||
name_input = driver.find_element(By.ID, "name_input")
|
name_input = driver.find_element(By.ID, "name_input")
|
||||||
name_input.send_keys("foo")
|
name_input.send_keys("foo")
|
||||||
@ -132,19 +144,21 @@ def test_submit(driver, form_submit: AppHarness):
|
|||||||
submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
|
submit_input = driver.find_element(By.CLASS_NAME, "chakra-button")
|
||||||
submit_input.click()
|
submit_input.click()
|
||||||
|
|
||||||
# wait for the form data to arrive at the backend
|
async def get_form_data():
|
||||||
AppHarness._poll_for(
|
return (await form_submit.get_state(token)).form_data
|
||||||
lambda: backend_state.form_data != {},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert backend_state.form_data["name_input"] == "foo"
|
# wait for the form data to arrive at the backend
|
||||||
assert backend_state.form_data["pin_input"] == pin_values
|
form_data = await AppHarness._poll_for_async(get_form_data)
|
||||||
assert backend_state.form_data["number_input"] == "-3"
|
assert isinstance(form_data, dict)
|
||||||
assert backend_state.form_data["bool_input"] is True
|
|
||||||
assert backend_state.form_data["bool_input2"] is True
|
assert form_data["name_input"] == "foo"
|
||||||
assert backend_state.form_data["slider_input"] == "50"
|
assert form_data["pin_input"] == pin_values
|
||||||
assert backend_state.form_data["range_input"] == ["25", "75"]
|
assert form_data["number_input"] == "-3"
|
||||||
assert backend_state.form_data["radio_input"] == "option2"
|
assert form_data["bool_input"] is True
|
||||||
assert backend_state.form_data["select_input"] == "option1"
|
assert form_data["bool_input2"] is True
|
||||||
assert backend_state.form_data["text_area_input"] == "Some\nText"
|
assert form_data["slider_input"] == "50"
|
||||||
assert backend_state.form_data["debounce_input"] == "bar baz"
|
assert form_data["range_input"] == ["25", "75"]
|
||||||
|
assert form_data["radio_input"] == "option2"
|
||||||
|
assert form_data["select_input"] == "option1"
|
||||||
|
assert form_data["text_area_input"] == "Some\nText"
|
||||||
|
assert form_data["debounce_input"] == "bar baz"
|
||||||
|
@ -16,11 +16,16 @@ def FullyControlledInput():
|
|||||||
class State(rx.State):
|
class State(rx.State):
|
||||||
text: str = "initial"
|
text: str = "initial"
|
||||||
|
|
||||||
|
@rx.var
|
||||||
|
def token(self) -> str:
|
||||||
|
return self.get_token()
|
||||||
|
|
||||||
app = rx.App(state=State)
|
app = rx.App(state=State)
|
||||||
|
|
||||||
@app.add_page
|
@app.add_page
|
||||||
def index():
|
def index():
|
||||||
return rx.fragment(
|
return rx.fragment(
|
||||||
|
rx.input(value=State.token, is_read_only=True, id="token"),
|
||||||
rx.input(
|
rx.input(
|
||||||
id="debounce_input_input",
|
id="debounce_input_input",
|
||||||
on_change=State.set_text, # type: ignore
|
on_change=State.set_text, # type: ignore
|
||||||
@ -62,10 +67,12 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
|||||||
driver = fully_controlled_input.frontend()
|
driver = fully_controlled_input.frontend()
|
||||||
|
|
||||||
# get a reference to the connected client
|
# get a reference to the connected client
|
||||||
assert len(fully_controlled_input.poll_for_clients()) == 1
|
token_input = driver.find_element(By.ID, "token")
|
||||||
token, backend_state = list(
|
assert token_input
|
||||||
fully_controlled_input.app_instance.state_manager.states.items()
|
|
||||||
)[0]
|
# wait for the backend connection to send the token
|
||||||
|
token = fully_controlled_input.poll_for_value(token_input)
|
||||||
|
assert token
|
||||||
|
|
||||||
# find the input and wait for it to have the initial state value
|
# find the input and wait for it to have the initial state value
|
||||||
debounce_input = driver.find_element(By.ID, "debounce_input_input")
|
debounce_input = driver.find_element(By.ID, "debounce_input_input")
|
||||||
@ -80,14 +87,13 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
|||||||
debounce_input.send_keys("foo")
|
debounce_input.send_keys("foo")
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
assert debounce_input.get_attribute("value") == "ifoonitial"
|
assert debounce_input.get_attribute("value") == "ifoonitial"
|
||||||
assert backend_state.text == "ifoonitial"
|
assert (await fully_controlled_input.get_state(token)).text == "ifoonitial"
|
||||||
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
|
assert fully_controlled_input.poll_for_value(value_input) == "ifoonitial"
|
||||||
|
|
||||||
# clear the input on the backend
|
# clear the input on the backend
|
||||||
backend_state.text = ""
|
async with fully_controlled_input.modify_state(token) as state:
|
||||||
fully_controlled_input.app_instance.state_manager.set_state(token, backend_state)
|
state.text = ""
|
||||||
await fully_controlled_input.emit_state_updates()
|
assert (await fully_controlled_input.get_state(token)).text == ""
|
||||||
assert backend_state.text == ""
|
|
||||||
assert (
|
assert (
|
||||||
fully_controlled_input.poll_for_value(
|
fully_controlled_input.poll_for_value(
|
||||||
debounce_input, exp_not_equal="ifoonitial"
|
debounce_input, exp_not_equal="ifoonitial"
|
||||||
@ -99,7 +105,9 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
|||||||
debounce_input.send_keys("getting testing done")
|
debounce_input.send_keys("getting testing done")
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
assert debounce_input.get_attribute("value") == "getting testing done"
|
assert debounce_input.get_attribute("value") == "getting testing done"
|
||||||
assert backend_state.text == "getting testing done"
|
assert (
|
||||||
|
await fully_controlled_input.get_state(token)
|
||||||
|
).text == "getting testing done"
|
||||||
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
|
assert fully_controlled_input.poll_for_value(value_input) == "getting testing done"
|
||||||
|
|
||||||
# type into the on_change input
|
# type into the on_change input
|
||||||
@ -107,7 +115,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
assert debounce_input.get_attribute("value") == "overwrite the state"
|
assert debounce_input.get_attribute("value") == "overwrite the state"
|
||||||
assert on_change_input.get_attribute("value") == "overwrite the state"
|
assert on_change_input.get_attribute("value") == "overwrite the state"
|
||||||
assert backend_state.text == "overwrite the state"
|
assert (await fully_controlled_input.get_state(token)).text == "overwrite the state"
|
||||||
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
|
assert fully_controlled_input.poll_for_value(value_input) == "overwrite the state"
|
||||||
|
|
||||||
clear_button.click()
|
clear_button.click()
|
||||||
|
@ -33,11 +33,16 @@ def ServerSideEvent():
|
|||||||
def set_value_return_c(self):
|
def set_value_return_c(self):
|
||||||
return rx.set_value("c", "")
|
return rx.set_value("c", "")
|
||||||
|
|
||||||
|
@rx.var
|
||||||
|
def token(self) -> str:
|
||||||
|
return self.get_token()
|
||||||
|
|
||||||
app = rx.App(state=SSState)
|
app = rx.App(state=SSState)
|
||||||
|
|
||||||
@app.add_page
|
@app.add_page
|
||||||
def index():
|
def index():
|
||||||
return rx.fragment(
|
return rx.fragment(
|
||||||
|
rx.input(id="token", value=SSState.token, is_read_only=True),
|
||||||
rx.input(default_value="a", id="a"),
|
rx.input(default_value="a", id="a"),
|
||||||
rx.input(default_value="b", id="b"),
|
rx.input(default_value="b", id="b"),
|
||||||
rx.input(default_value="c", id="c"),
|
rx.input(default_value="c", id="c"),
|
||||||
@ -106,7 +111,12 @@ def driver(server_side_event: AppHarness):
|
|||||||
assert server_side_event.app_instance is not None, "app is not running"
|
assert server_side_event.app_instance is not None, "app is not running"
|
||||||
driver = server_side_event.frontend()
|
driver = server_side_event.frontend()
|
||||||
try:
|
try:
|
||||||
assert server_side_event.poll_for_clients()
|
token_input = driver.find_element(By.ID, "token")
|
||||||
|
assert token_input
|
||||||
|
# wait for the backend connection to send the token
|
||||||
|
token = server_side_event.poll_for_value(token_input)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
yield driver
|
yield driver
|
||||||
finally:
|
finally:
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
@ -89,13 +89,13 @@ def driver(upload_file: AppHarness):
|
|||||||
assert upload_file.app_instance is not None, "app is not running"
|
assert upload_file.app_instance is not None, "app is not running"
|
||||||
driver = upload_file.frontend()
|
driver = upload_file.frontend()
|
||||||
try:
|
try:
|
||||||
assert upload_file.poll_for_clients()
|
|
||||||
yield driver
|
yield driver
|
||||||
finally:
|
finally:
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
|
||||||
|
|
||||||
def test_upload_file(tmp_path, upload_file: AppHarness, driver):
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
|
||||||
"""Submit a file upload and check that it arrived on the backend.
|
"""Submit a file upload and check that it arrived on the backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -124,16 +124,20 @@ def test_upload_file(tmp_path, upload_file: AppHarness, driver):
|
|||||||
upload_button.click()
|
upload_button.click()
|
||||||
|
|
||||||
# look up the backend state and assert on uploaded contents
|
# look up the backend state and assert on uploaded contents
|
||||||
backend_state = upload_file.app_instance.state_manager.states[token]
|
async def get_file_data():
|
||||||
time.sleep(0.5)
|
return (await upload_file.get_state(token))._file_data
|
||||||
assert backend_state._file_data[exp_name] == exp_contents
|
|
||||||
|
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||||
|
assert isinstance(file_data, dict)
|
||||||
|
assert file_data[exp_name] == exp_contents
|
||||||
|
|
||||||
# check that the selected files are displayed
|
# check that the selected files are displayed
|
||||||
selected_files = driver.find_element(By.ID, "selected_files")
|
selected_files = driver.find_element(By.ID, "selected_files")
|
||||||
assert selected_files.text == exp_name
|
assert selected_files.text == exp_name
|
||||||
|
|
||||||
|
|
||||||
def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
||||||
"""Submit several file uploads and check that they arrived on the backend.
|
"""Submit several file uploads and check that they arrived on the backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -173,10 +177,13 @@ def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
|
|||||||
upload_button.click()
|
upload_button.click()
|
||||||
|
|
||||||
# look up the backend state and assert on uploaded contents
|
# look up the backend state and assert on uploaded contents
|
||||||
backend_state = upload_file.app_instance.state_manager.states[token]
|
async def get_file_data():
|
||||||
time.sleep(0.5)
|
return (await upload_file.get_state(token))._file_data
|
||||||
|
|
||||||
|
file_data = await AppHarness._poll_for_async(get_file_data)
|
||||||
|
assert isinstance(file_data, dict)
|
||||||
for exp_name, exp_contents in exp_files.items():
|
for exp_name, exp_contents in exp_files.items():
|
||||||
assert backend_state._file_data[exp_name] == exp_contents
|
assert file_data[exp_name] == exp_contents
|
||||||
|
|
||||||
|
|
||||||
def test_clear_files(tmp_path, upload_file: AppHarness, driver):
|
def test_clear_files(tmp_path, upload_file: AppHarness, driver):
|
||||||
|
@ -26,11 +26,16 @@ def VarOperations():
|
|||||||
dict1: dict = {1: 2}
|
dict1: dict = {1: 2}
|
||||||
dict2: dict = {3: 4}
|
dict2: dict = {3: 4}
|
||||||
|
|
||||||
|
@rx.var
|
||||||
|
def token(self) -> str:
|
||||||
|
return self.get_token()
|
||||||
|
|
||||||
app = rx.App(state=VarOperationState)
|
app = rx.App(state=VarOperationState)
|
||||||
|
|
||||||
@app.add_page
|
@app.add_page
|
||||||
def index():
|
def index():
|
||||||
return rx.vstack(
|
return rx.vstack(
|
||||||
|
rx.input(id="token", value=VarOperationState.token, is_read_only=True),
|
||||||
# INT INT
|
# INT INT
|
||||||
rx.text(
|
rx.text(
|
||||||
VarOperationState.int_var1 + VarOperationState.int_var2,
|
VarOperationState.int_var1 + VarOperationState.int_var2,
|
||||||
@ -544,7 +549,12 @@ def driver(var_operations: AppHarness):
|
|||||||
"""
|
"""
|
||||||
driver = var_operations.frontend()
|
driver = var_operations.frontend()
|
||||||
try:
|
try:
|
||||||
assert var_operations.poll_for_clients()
|
token_input = driver.find_element(By.ID, "token")
|
||||||
|
assert token_input
|
||||||
|
# wait for the backend connection to send the token
|
||||||
|
token = var_operations.poll_for_value(token_input)
|
||||||
|
assert token is not None
|
||||||
|
|
||||||
yield driver
|
yield driver
|
||||||
finally:
|
finally:
|
||||||
driver.quit()
|
driver.quit()
|
||||||
|
@ -21,6 +21,7 @@ from .constants import Env as Env
|
|||||||
from .event import EVENT_ARG as EVENT_ARG
|
from .event import EVENT_ARG as EVENT_ARG
|
||||||
from .event import EventChain as EventChain
|
from .event import EventChain as EventChain
|
||||||
from .event import FileUpload as upload_files
|
from .event import FileUpload as upload_files
|
||||||
|
from .event import background as background
|
||||||
from .event import clear_local_storage as clear_local_storage
|
from .event import clear_local_storage as clear_local_storage
|
||||||
from .event import console_log as console_log
|
from .event import console_log as console_log
|
||||||
from .event import download as download
|
from .event import download as download
|
||||||
|
229
reflex/app.py
229
reflex/app.py
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
@ -13,6 +14,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Set,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -49,7 +51,13 @@ from reflex.route import (
|
|||||||
get_route_args,
|
get_route_args,
|
||||||
verify_route_validity,
|
verify_route_validity,
|
||||||
)
|
)
|
||||||
from reflex.state import DefaultState, State, StateManager, StateUpdate
|
from reflex.state import (
|
||||||
|
DefaultState,
|
||||||
|
State,
|
||||||
|
StateManager,
|
||||||
|
StateManagerMemory,
|
||||||
|
StateUpdate,
|
||||||
|
)
|
||||||
from reflex.utils import console, format, prerequisites, types
|
from reflex.utils import console, format, prerequisites, types
|
||||||
from reflex.vars import ImportVar
|
from reflex.vars import ImportVar
|
||||||
|
|
||||||
@ -89,7 +97,7 @@ class App(Base):
|
|||||||
state: Type[State] = DefaultState
|
state: Type[State] = DefaultState
|
||||||
|
|
||||||
# Class to manage many client states.
|
# Class to manage many client states.
|
||||||
state_manager: StateManager = StateManager()
|
state_manager: StateManager = StateManagerMemory(state=DefaultState)
|
||||||
|
|
||||||
# The styling to apply to each component.
|
# The styling to apply to each component.
|
||||||
style: ComponentStyle = {}
|
style: ComponentStyle = {}
|
||||||
@ -104,13 +112,16 @@ class App(Base):
|
|||||||
admin_dash: Optional[AdminDash] = None
|
admin_dash: Optional[AdminDash] = None
|
||||||
|
|
||||||
# The async server name space
|
# The async server name space
|
||||||
event_namespace: Optional[AsyncNamespace] = None
|
event_namespace: Optional[EventNamespace] = None
|
||||||
|
|
||||||
# A component that is present on every page.
|
# A component that is present on every page.
|
||||||
overlay_component: Optional[
|
overlay_component: Optional[
|
||||||
Union[Component, ComponentCallable]
|
Union[Component, ComponentCallable]
|
||||||
] = default_overlay_component
|
] = default_overlay_component
|
||||||
|
|
||||||
|
# Background tasks that are currently running
|
||||||
|
background_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
"""Initialize the app.
|
"""Initialize the app.
|
||||||
|
|
||||||
@ -154,7 +165,7 @@ class App(Base):
|
|||||||
self.middleware.append(HydrateMiddleware())
|
self.middleware.append(HydrateMiddleware())
|
||||||
|
|
||||||
# Set up the state manager.
|
# Set up the state manager.
|
||||||
self.state_manager.setup(state=self.state)
|
self.state_manager = StateManager.create(state=self.state)
|
||||||
|
|
||||||
# Set up the API.
|
# Set up the API.
|
||||||
self.api = FastAPI()
|
self.api = FastAPI()
|
||||||
@ -646,6 +657,76 @@ class App(Base):
|
|||||||
thread_pool.close()
|
thread_pool.close()
|
||||||
thread_pool.join()
|
thread_pool.join()
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||||
|
"""Modify the state out of band.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to modify the state for.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The state to modify.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the app has not been initialized yet.
|
||||||
|
"""
|
||||||
|
if self.event_namespace is None:
|
||||||
|
raise RuntimeError("App has not been initialized yet.")
|
||||||
|
# Get exclusive access to the state.
|
||||||
|
async with self.state_manager.modify_state(token) as state:
|
||||||
|
# No other event handler can modify the state while in this context.
|
||||||
|
yield state
|
||||||
|
delta = state.get_delta()
|
||||||
|
if delta:
|
||||||
|
# When the state is modified reset dirty status and emit the delta to the frontend.
|
||||||
|
state._clean()
|
||||||
|
await self.event_namespace.emit_update(
|
||||||
|
update=StateUpdate(delta=delta),
|
||||||
|
sid=state.get_sid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_background(self, state: State, event: Event) -> asyncio.Task | None:
|
||||||
|
"""Process an event in the background and emit updates as they arrive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The state to process the event for.
|
||||||
|
event: The event to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task if the event was backgroundable, otherwise None
|
||||||
|
"""
|
||||||
|
substate, handler = state._get_event_handler(event)
|
||||||
|
if not handler.is_background:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _coro():
|
||||||
|
"""Coroutine to process the event and emit updates inside an asyncio.Task.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the app has not been initialized yet.
|
||||||
|
"""
|
||||||
|
if self.event_namespace is None:
|
||||||
|
raise RuntimeError("App has not been initialized yet.")
|
||||||
|
|
||||||
|
# Process the event.
|
||||||
|
async for update in state._process_event(
|
||||||
|
handler=handler, state=substate, payload=event.payload
|
||||||
|
):
|
||||||
|
# Postprocess the event.
|
||||||
|
update = await self.postprocess(state, event, update)
|
||||||
|
|
||||||
|
# Send the update to the client.
|
||||||
|
await self.event_namespace.emit_update(
|
||||||
|
update=update,
|
||||||
|
sid=state.get_sid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
task = asyncio.create_task(_coro())
|
||||||
|
self.background_tasks.add(task)
|
||||||
|
# Clean up task from background_tasks set when complete.
|
||||||
|
task.add_done_callback(self.background_tasks.discard)
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||||
@ -662,9 +743,6 @@ async def process(
|
|||||||
Yields:
|
Yields:
|
||||||
The state updates after processing the event.
|
The state updates after processing the event.
|
||||||
"""
|
"""
|
||||||
# Get the state for the session.
|
|
||||||
state = app.state_manager.get_state(event.token)
|
|
||||||
|
|
||||||
# Add request data to the state.
|
# Add request data to the state.
|
||||||
router_data = event.router_data
|
router_data = event.router_data
|
||||||
router_data.update(
|
router_data.update(
|
||||||
@ -676,31 +754,35 @@ async def process(
|
|||||||
constants.RouteVar.CLIENT_IP: client_ip,
|
constants.RouteVar.CLIENT_IP: client_ip,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# re-assign only when the value is different
|
# Get the state for the session exclusively.
|
||||||
if state.router_data != router_data:
|
async with app.state_manager.modify_state(event.token) as state:
|
||||||
# assignment will recurse into substates and force recalculation of
|
# re-assign only when the value is different
|
||||||
# dependent ComputedVar (dynamic route variables)
|
if state.router_data != router_data:
|
||||||
state.router_data = router_data
|
# assignment will recurse into substates and force recalculation of
|
||||||
|
# dependent ComputedVar (dynamic route variables)
|
||||||
|
state.router_data = router_data
|
||||||
|
|
||||||
# Preprocess the event.
|
# Preprocess the event.
|
||||||
update = await app.preprocess(state, event)
|
update = await app.preprocess(state, event)
|
||||||
|
|
||||||
# If there was an update, yield it.
|
# If there was an update, yield it.
|
||||||
if update is not None:
|
if update is not None:
|
||||||
yield update
|
|
||||||
|
|
||||||
# Only process the event if there is no update.
|
|
||||||
else:
|
|
||||||
# Process the event.
|
|
||||||
async for update in state._process(event):
|
|
||||||
# Postprocess the event.
|
|
||||||
update = await app.postprocess(state, event, update)
|
|
||||||
|
|
||||||
# Yield the update.
|
|
||||||
yield update
|
yield update
|
||||||
|
|
||||||
# Set the state for the session.
|
# Only process the event if there is no update.
|
||||||
app.state_manager.set_state(event.token, state)
|
else:
|
||||||
|
if app._process_background(state, event) is not None:
|
||||||
|
# `final=True` allows the frontend send more events immediately.
|
||||||
|
yield StateUpdate(final=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process the event synchronously.
|
||||||
|
async for update in state._process(event):
|
||||||
|
# Postprocess the event.
|
||||||
|
update = await app.postprocess(state, event, update)
|
||||||
|
|
||||||
|
# Yield the update.
|
||||||
|
yield update
|
||||||
|
|
||||||
|
|
||||||
async def ping() -> str:
|
async def ping() -> str:
|
||||||
@ -737,47 +819,46 @@ def upload(app: App):
|
|||||||
assert file.filename is not None
|
assert file.filename is not None
|
||||||
file.filename = file.filename.split(":")[-1]
|
file.filename = file.filename.split(":")[-1]
|
||||||
# Get the state for the session.
|
# Get the state for the session.
|
||||||
state = app.state_manager.get_state(token)
|
async with app.state_manager.modify_state(token) as state:
|
||||||
# get the current session ID
|
# get the current session ID
|
||||||
sid = state.get_sid()
|
sid = state.get_sid()
|
||||||
# get the current state(parent state/substate)
|
# get the current state(parent state/substate)
|
||||||
path = handler.split(".")[:-1]
|
path = handler.split(".")[:-1]
|
||||||
current_state = state.get_substate(path)
|
current_state = state.get_substate(path)
|
||||||
handler_upload_param = ()
|
handler_upload_param = ()
|
||||||
|
|
||||||
# get handler function
|
# get handler function
|
||||||
func = getattr(current_state, handler.split(".")[-1])
|
func = getattr(current_state, handler.split(".")[-1])
|
||||||
|
|
||||||
# check if there exists any handler args with annotation, List[UploadFile]
|
# check if there exists any handler args with annotation, List[UploadFile]
|
||||||
for k, v in inspect.getfullargspec(
|
for k, v in inspect.getfullargspec(
|
||||||
func.fn if isinstance(func, EventHandler) else func
|
func.fn if isinstance(func, EventHandler) else func
|
||||||
).annotations.items():
|
).annotations.items():
|
||||||
if types.is_generic_alias(v) and types._issubclass(
|
if types.is_generic_alias(v) and types._issubclass(
|
||||||
v.__args__[0], UploadFile
|
v.__args__[0], UploadFile
|
||||||
):
|
):
|
||||||
handler_upload_param = (k, v)
|
handler_upload_param = (k, v)
|
||||||
break
|
break
|
||||||
|
|
||||||
if not handler_upload_param:
|
if not handler_upload_param:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`{handler}` handler should have a parameter annotated as List["
|
f"`{handler}` handler should have a parameter annotated as List["
|
||||||
f"rx.UploadFile]"
|
f"rx.UploadFile]"
|
||||||
|
)
|
||||||
|
|
||||||
|
event = Event(
|
||||||
|
token=token,
|
||||||
|
name=handler,
|
||||||
|
payload={handler_upload_param[0]: files},
|
||||||
)
|
)
|
||||||
|
async for update in state._process(event):
|
||||||
event = Event(
|
# Postprocess the event.
|
||||||
token=token,
|
update = await app.postprocess(state, event, update)
|
||||||
name=handler,
|
# Send update to client
|
||||||
payload={handler_upload_param[0]: files},
|
await app.event_namespace.emit_update( # type: ignore
|
||||||
)
|
update=update,
|
||||||
async for update in state._process(event):
|
sid=sid,
|
||||||
# Postprocess the event.
|
)
|
||||||
update = await app.postprocess(state, event, update)
|
|
||||||
# Send update to client
|
|
||||||
await asyncio.create_task(
|
|
||||||
app.event_namespace.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) # type: ignore
|
|
||||||
)
|
|
||||||
# Set the state for the session.
|
|
||||||
app.state_manager.set_state(event.token, state)
|
|
||||||
|
|
||||||
return upload_file
|
return upload_file
|
||||||
|
|
||||||
@ -815,6 +896,18 @@ class EventNamespace(AsyncNamespace):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def emit_update(self, update: StateUpdate, sid: str) -> None:
|
||||||
|
"""Emit an update to the client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update: The state update to send.
|
||||||
|
sid: The Socket.IO session id.
|
||||||
|
"""
|
||||||
|
# Creating a task prevents the update from being blocked behind other coroutines.
|
||||||
|
await asyncio.create_task(
|
||||||
|
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
||||||
|
)
|
||||||
|
|
||||||
async def on_event(self, sid, data):
|
async def on_event(self, sid, data):
|
||||||
"""Event for receiving front-end websocket events.
|
"""Event for receiving front-end websocket events.
|
||||||
|
|
||||||
@ -841,10 +934,8 @@ class EventNamespace(AsyncNamespace):
|
|||||||
|
|
||||||
# Process the events.
|
# Process the events.
|
||||||
async for update in process(self.app, event, sid, headers, client_ip):
|
async for update in process(self.app, event, sid, headers, client_ip):
|
||||||
# Emit the event.
|
# Emit the update from processing the event.
|
||||||
await asyncio.create_task(
|
await self.emit_update(update=update, sid=sid)
|
||||||
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_ping(self, sid):
|
async def on_ping(self, sid):
|
||||||
"""Event for testing the API endpoint.
|
"""Event for testing the API endpoint.
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
""" Generated with stubgen from mypy, then manually edited, do not regen."""
|
""" Generated with stubgen from mypy, then manually edited, do not regen."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi import UploadFile as UploadFile
|
from fastapi import UploadFile as UploadFile
|
||||||
from reflex import constants as constants
|
from reflex import constants as constants
|
||||||
@ -45,12 +46,14 @@ from reflex.utils import (
|
|||||||
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
from socketio import ASGIApp, AsyncNamespace, AsyncServer
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncContextManager,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Set,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
overload,
|
overload,
|
||||||
@ -75,6 +78,7 @@ class App(Base):
|
|||||||
admin_dash: Optional[AdminDash]
|
admin_dash: Optional[AdminDash]
|
||||||
event_namespace: Optional[AsyncNamespace]
|
event_namespace: Optional[AsyncNamespace]
|
||||||
overlay_component: Optional[Union[Component, ComponentCallable]]
|
overlay_component: Optional[Union[Component, ComponentCallable]]
|
||||||
|
background_tasks: Set[asyncio.Task] = set()
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
@ -116,6 +120,10 @@ class App(Base):
|
|||||||
def setup_admin_dash(self) -> None: ...
|
def setup_admin_dash(self) -> None: ...
|
||||||
def get_frontend_packages(self, imports: Dict[str, str]): ...
|
def get_frontend_packages(self, imports: Dict[str, str]): ...
|
||||||
def compile(self) -> None: ...
|
def compile(self) -> None: ...
|
||||||
|
def modify_state(self, token: str) -> AsyncContextManager[State]: ...
|
||||||
|
def _process_background(
|
||||||
|
self, state: State, event: Event
|
||||||
|
) -> asyncio.Task | None: ...
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
app: App, event: Event, sid: str, headers: Dict, client_ip: str
|
||||||
|
@ -219,6 +219,8 @@ OLD_CONFIG_FILE = f"pcconfig{PY_EXT}"
|
|||||||
PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
|
PRODUCTION_BACKEND_URL = "https://{username}-{app_name}.api.pynecone.app"
|
||||||
# Token expiration time in seconds.
|
# Token expiration time in seconds.
|
||||||
TOKEN_EXPIRATION = 60 * 60
|
TOKEN_EXPIRATION = 60 * 60
|
||||||
|
# Maximum time in milliseconds that a state can be locked for exclusive access.
|
||||||
|
LOCK_EXPIRATION = 10000
|
||||||
|
|
||||||
# Testing variables.
|
# Testing variables.
|
||||||
# Testing os env set by pytest when running a test case.
|
# Testing os env set by pytest when running a test case.
|
||||||
|
@ -2,7 +2,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
@ -10,6 +20,9 @@ from reflex.utils import console, format
|
|||||||
from reflex.utils.types import ArgsSpec
|
from reflex.utils.types import ArgsSpec
|
||||||
from reflex.vars import BaseVar, Var
|
from reflex.vars import BaseVar, Var
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from reflex.state import State
|
||||||
|
|
||||||
|
|
||||||
class Event(Base):
|
class Event(Base):
|
||||||
"""An event that describes any state change in the app."""
|
"""An event that describes any state change in the app."""
|
||||||
@ -27,6 +40,66 @@ class Event(Base):
|
|||||||
payload: Dict[str, Any] = {}
|
payload: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
BACKGROUND_TASK_MARKER = "_reflex_background_task"
|
||||||
|
|
||||||
|
|
||||||
|
def background(fn):
|
||||||
|
"""Decorator to mark event handler as running in the background.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The function to decorate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The same function, but with a marker set.
|
||||||
|
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the function is not a coroutine function or async generator.
|
||||||
|
"""
|
||||||
|
if not inspect.iscoroutinefunction(fn) and not inspect.isasyncgenfunction(fn):
|
||||||
|
raise TypeError("Background task must be async function or generator.")
|
||||||
|
setattr(fn, BACKGROUND_TASK_MARKER, True)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def _no_chain_background_task(
|
||||||
|
state_cls: Type["State"], name: str, fn: Callable
|
||||||
|
) -> Callable:
|
||||||
|
"""Protect against directly chaining a background task from another event handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_cls: The state class that the event handler is in.
|
||||||
|
name: The name of the background task.
|
||||||
|
fn: The background task coroutine function / generator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A compatible coroutine function / generator that raises a runtime error.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the background task is not async.
|
||||||
|
"""
|
||||||
|
call = f"{state_cls.__name__}.{name}"
|
||||||
|
message = (
|
||||||
|
f"Cannot directly call background task {name!r}, use "
|
||||||
|
f"`yield {call}` or `return {call}` instead."
|
||||||
|
)
|
||||||
|
if inspect.iscoroutinefunction(fn):
|
||||||
|
|
||||||
|
async def _no_chain_background_task_co(*args, **kwargs):
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
return _no_chain_background_task_co
|
||||||
|
if inspect.isasyncgenfunction(fn):
|
||||||
|
|
||||||
|
async def _no_chain_background_task_gen(*args, **kwargs):
|
||||||
|
yield
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
return _no_chain_background_task_gen
|
||||||
|
|
||||||
|
raise TypeError(f"{fn} is marked as a background task, but is not async.")
|
||||||
|
|
||||||
|
|
||||||
class EventHandler(Base):
|
class EventHandler(Base):
|
||||||
"""An event handler responds to an event to update the state."""
|
"""An event handler responds to an event to update the state."""
|
||||||
|
|
||||||
@ -39,6 +112,15 @@ class EventHandler(Base):
|
|||||||
# Needed to allow serialization of Callable.
|
# Needed to allow serialization of Callable.
|
||||||
frozen = True
|
frozen = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_background(self) -> bool:
|
||||||
|
"""Whether the event handler is a background task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the event handler is marked as a background task.
|
||||||
|
"""
|
||||||
|
return getattr(self.fn, BACKGROUND_TASK_MARKER, False)
|
||||||
|
|
||||||
def __call__(self, *args: Var) -> EventSpec:
|
def __call__(self, *args: Var) -> EventSpec:
|
||||||
"""Pass arguments to the handler to get an event spec.
|
"""Pass arguments to the handler to get an event spec.
|
||||||
|
|
||||||
@ -530,7 +612,7 @@ def get_handler_args(event_spec: EventSpec) -> tuple[tuple[Var, Var], ...]:
|
|||||||
|
|
||||||
|
|
||||||
def fix_events(
|
def fix_events(
|
||||||
events: list[EventHandler | EventSpec],
|
events: list[EventHandler | EventSpec] | None,
|
||||||
token: str,
|
token: str,
|
||||||
router_data: dict[str, Any] | None = None,
|
router_data: dict[str, Any] | None = None,
|
||||||
) -> list[Event]:
|
) -> list[Event]:
|
||||||
|
639
reflex/state.py
639
reflex/state.py
@ -2,13 +2,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from abc import ABC
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -27,12 +29,20 @@ from typing import (
|
|||||||
import cloudpickle
|
import cloudpickle
|
||||||
import pydantic
|
import pydantic
|
||||||
import wrapt
|
import wrapt
|
||||||
from redis import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.event import Event, EventHandler, EventSpec, fix_events, window_alert
|
from reflex.event import (
|
||||||
|
Event,
|
||||||
|
EventHandler,
|
||||||
|
EventSpec,
|
||||||
|
_no_chain_background_task,
|
||||||
|
fix_events,
|
||||||
|
window_alert,
|
||||||
|
)
|
||||||
from reflex.utils import format, prerequisites, types
|
from reflex.utils import format, prerequisites, types
|
||||||
|
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
|
||||||
from reflex.vars import BaseVar, ComputedVar, Var
|
from reflex.vars import BaseVar, ComputedVar, Var
|
||||||
|
|
||||||
Delta = Dict[str, Any]
|
Delta = Dict[str, Any]
|
||||||
@ -152,7 +162,10 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
# Convert the event handlers to functions.
|
# Convert the event handlers to functions.
|
||||||
for name, event_handler in state.event_handlers.items():
|
for name, event_handler in state.event_handlers.items():
|
||||||
fn = functools.partial(event_handler.fn, self)
|
if event_handler.is_background:
|
||||||
|
fn = _no_chain_background_task(type(state), name, event_handler.fn)
|
||||||
|
else:
|
||||||
|
fn = functools.partial(event_handler.fn, self)
|
||||||
fn.__module__ = event_handler.fn.__module__ # type: ignore
|
fn.__module__ = event_handler.fn.__module__ # type: ignore
|
||||||
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
fn.__qualname__ = event_handler.fn.__qualname__ # type: ignore
|
||||||
setattr(self, name, fn)
|
setattr(self, name, fn)
|
||||||
@ -711,6 +724,37 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
raise ValueError(f"Invalid path: {path}")
|
raise ValueError(f"Invalid path: {path}")
|
||||||
return self.substates[path[0]].get_substate(path[1:])
|
return self.substates[path[0]].get_substate(path[1:])
|
||||||
|
|
||||||
|
def _get_event_handler(
|
||||||
|
self, event: Event
|
||||||
|
) -> tuple[State | StateProxy, EventHandler]:
|
||||||
|
"""Get the event handler for the given event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The event to get the handler for.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The event handler.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the event handler or substate is not found.
|
||||||
|
"""
|
||||||
|
# Get the event handler.
|
||||||
|
path = event.name.split(".")
|
||||||
|
path, name = path[:-1], path[-1]
|
||||||
|
substate = self.get_substate(path)
|
||||||
|
if not substate:
|
||||||
|
raise ValueError(
|
||||||
|
"The value of state cannot be None when processing an event."
|
||||||
|
)
|
||||||
|
handler = substate.event_handlers[name]
|
||||||
|
|
||||||
|
# For background tasks, proxy the state
|
||||||
|
if handler.is_background:
|
||||||
|
substate = StateProxy(substate)
|
||||||
|
|
||||||
|
return substate, handler
|
||||||
|
|
||||||
async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
|
async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
|
||||||
"""Obtain event info and process event.
|
"""Obtain event info and process event.
|
||||||
|
|
||||||
@ -719,44 +763,17 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
The state update after processing the event.
|
The state update after processing the event.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the state value is None.
|
|
||||||
"""
|
"""
|
||||||
# Get the event handler.
|
# Get the event handler.
|
||||||
path = event.name.split(".")
|
substate, handler = self._get_event_handler(event)
|
||||||
path, name = path[:-1], path[-1]
|
|
||||||
substate = self.get_substate(path)
|
|
||||||
handler = substate.event_handlers[name] # type: ignore
|
|
||||||
|
|
||||||
if not substate:
|
# Run the event generator and yield state updates.
|
||||||
raise ValueError(
|
async for update in self._process_event(
|
||||||
"The value of state cannot be None when processing an event."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the event generator.
|
|
||||||
event_iter = self._process_event(
|
|
||||||
handler=handler,
|
handler=handler,
|
||||||
state=substate,
|
state=substate,
|
||||||
payload=event.payload,
|
payload=event.payload,
|
||||||
)
|
):
|
||||||
|
yield update
|
||||||
# Clean the state before processing the event.
|
|
||||||
self._clean()
|
|
||||||
|
|
||||||
# Run the event generator and return state updates.
|
|
||||||
async for events, final in event_iter:
|
|
||||||
# Fix the returned events.
|
|
||||||
events = fix_events(events, event.token) # type: ignore
|
|
||||||
|
|
||||||
# Get the delta after processing the event.
|
|
||||||
delta = self.get_delta()
|
|
||||||
|
|
||||||
# Yield the state update.
|
|
||||||
yield StateUpdate(delta=delta, events=events, final=final)
|
|
||||||
|
|
||||||
# Clean the state to prepare for the next event.
|
|
||||||
self._clean()
|
|
||||||
|
|
||||||
def _check_valid(self, handler: EventHandler, events: Any) -> Any:
|
def _check_valid(self, handler: EventHandler, events: Any) -> Any:
|
||||||
"""Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
|
"""Check if the events yielded are valid. They must be EventHandlers or EventSpecs.
|
||||||
@ -787,9 +804,42 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _as_state_update(
|
||||||
|
self,
|
||||||
|
handler: EventHandler,
|
||||||
|
events: EventSpec | list[EventSpec] | None,
|
||||||
|
final: bool,
|
||||||
|
) -> StateUpdate:
|
||||||
|
"""Convert the events to a StateUpdate.
|
||||||
|
|
||||||
|
Fixes the events and checks for validity before converting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler: The handler where the events originated from.
|
||||||
|
events: The events to queue with the update.
|
||||||
|
final: Whether the handler is done processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The valid StateUpdate containing the events and final flag.
|
||||||
|
"""
|
||||||
|
token = self.get_token()
|
||||||
|
|
||||||
|
# Convert valid EventHandler and EventSpec into Event
|
||||||
|
fixed_events = fix_events(self._check_valid(handler, events), token)
|
||||||
|
|
||||||
|
# Get the delta after processing the event.
|
||||||
|
delta = self.get_delta()
|
||||||
|
self._clean()
|
||||||
|
|
||||||
|
return StateUpdate(
|
||||||
|
delta=delta,
|
||||||
|
events=fixed_events,
|
||||||
|
final=final if not handler.is_background else True,
|
||||||
|
)
|
||||||
|
|
||||||
async def _process_event(
|
async def _process_event(
|
||||||
self, handler: EventHandler, state: State, payload: Dict
|
self, handler: EventHandler, state: State | StateProxy, payload: Dict
|
||||||
) -> AsyncIterator[tuple[list[EventSpec] | None, bool]]:
|
) -> AsyncIterator[StateUpdate]:
|
||||||
"""Process event.
|
"""Process event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -798,13 +848,14 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
payload: The event payload.
|
payload: The event payload.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Tuple containing:
|
StateUpdate object
|
||||||
0: The state update after processing the event.
|
|
||||||
1: Whether the event is the final event.
|
|
||||||
"""
|
"""
|
||||||
# Get the function to process the event.
|
# Get the function to process the event.
|
||||||
fn = functools.partial(handler.fn, state)
|
fn = functools.partial(handler.fn, state)
|
||||||
|
|
||||||
|
# Clean the state before processing the event.
|
||||||
|
self._clean()
|
||||||
|
|
||||||
# Wrap the function in a try/except block.
|
# Wrap the function in a try/except block.
|
||||||
try:
|
try:
|
||||||
# Handle async functions.
|
# Handle async functions.
|
||||||
@ -817,30 +868,34 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
# Handle async generators.
|
# Handle async generators.
|
||||||
if inspect.isasyncgen(events):
|
if inspect.isasyncgen(events):
|
||||||
async for event in events:
|
async for event in events:
|
||||||
yield self._check_valid(handler, event), False
|
yield self._as_state_update(handler, event, final=False)
|
||||||
yield None, True
|
yield self._as_state_update(handler, events=None, final=True)
|
||||||
|
|
||||||
# Handle regular generators.
|
# Handle regular generators.
|
||||||
elif inspect.isgenerator(events):
|
elif inspect.isgenerator(events):
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield self._check_valid(handler, next(events)), False
|
yield self._as_state_update(handler, next(events), final=False)
|
||||||
except StopIteration as si:
|
except StopIteration as si:
|
||||||
# the "return" value of the generator is not available
|
# the "return" value of the generator is not available
|
||||||
# in the loop, we must catch StopIteration to access it
|
# in the loop, we must catch StopIteration to access it
|
||||||
if si.value is not None:
|
if si.value is not None:
|
||||||
yield self._check_valid(handler, si.value), False
|
yield self._as_state_update(handler, si.value, final=False)
|
||||||
yield None, True
|
yield self._as_state_update(handler, events=None, final=True)
|
||||||
|
|
||||||
# Handle regular event chains.
|
# Handle regular event chains.
|
||||||
else:
|
else:
|
||||||
yield self._check_valid(handler, events), True
|
yield self._as_state_update(handler, events, final=True)
|
||||||
|
|
||||||
# If an error occurs, throw a window alert.
|
# If an error occurs, throw a window alert.
|
||||||
except Exception:
|
except Exception:
|
||||||
error = traceback.format_exc()
|
error = traceback.format_exc()
|
||||||
print(error)
|
print(error)
|
||||||
yield [window_alert("An error occurred. See logs for details.")], True
|
yield self._as_state_update(
|
||||||
|
handler,
|
||||||
|
window_alert("An error occurred. See logs for details."),
|
||||||
|
final=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _always_dirty_computed_vars(self) -> set[str]:
|
def _always_dirty_computed_vars(self) -> set[str]:
|
||||||
"""The set of ComputedVars that always need to be recalculated.
|
"""The set of ComputedVars that always need to be recalculated.
|
||||||
@ -989,6 +1044,160 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
variables = {**base_vars, **computed_vars, **substate_vars}
|
variables = {**base_vars, **computed_vars, **substate_vars}
|
||||||
return {k: variables[k] for k in sorted(variables)}
|
return {k: variables[k] for k in sorted(variables)}
|
||||||
|
|
||||||
|
async def __aenter__(self) -> State:
|
||||||
|
"""Enter the async context manager protocol.
|
||||||
|
|
||||||
|
This should not be used for the State class, but exists for
|
||||||
|
type-compatibility with StateProxy.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: always, because async contextmanager protocol is only supported for background task.
|
||||||
|
"""
|
||||||
|
raise TypeError(
|
||||||
|
"Only background task should use `async with self` to modify state."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc_info: Any) -> None:
|
||||||
|
"""Exit the async context manager protocol.
|
||||||
|
|
||||||
|
This should not be used for the State class, but exists for
|
||||||
|
type-compatibility with StateProxy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_info: The exception info tuple.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StateProxy(wrapt.ObjectProxy):
|
||||||
|
"""Proxy of a state instance to control mutability of vars for a background task.
|
||||||
|
|
||||||
|
Since a background task runs against a state instance without holding the
|
||||||
|
state_manager lock for the token, the reference may become stale if the same
|
||||||
|
state is modified by another event handler.
|
||||||
|
|
||||||
|
The proxy object ensures that writes to the state are blocked unless
|
||||||
|
explicitly entering a context which refreshes the state from state_manager
|
||||||
|
and holds the lock for the token until exiting the context. After exiting
|
||||||
|
the context, a StateUpdate may be emitted to the frontend to notify the
|
||||||
|
client of the state change.
|
||||||
|
|
||||||
|
A background task will be passed the `StateProxy` as `self`, so mutability
|
||||||
|
can be safely performed inside an `async with self` block.
|
||||||
|
|
||||||
|
class State(rx.State):
|
||||||
|
counter: int = 0
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def bg_increment(self):
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
async with self:
|
||||||
|
self.counter += 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_instance):
|
||||||
|
"""Create a proxy for a state instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_instance: The state instance to proxy.
|
||||||
|
"""
|
||||||
|
super().__init__(state_instance)
|
||||||
|
self._self_app = getattr(prerequisites.get_app(), constants.APP_VAR)
|
||||||
|
self._self_substate_path = state_instance.get_full_name().split(".")
|
||||||
|
self._self_actx = None
|
||||||
|
self._self_mutable = False
|
||||||
|
|
||||||
|
async def __aenter__(self) -> StateProxy:
|
||||||
|
"""Enter the async context manager protocol.
|
||||||
|
|
||||||
|
Sets mutability to True and enters the `App.modify_state` async context,
|
||||||
|
which refreshes the state from state_manager and holds the lock for the
|
||||||
|
given state token until exiting the context.
|
||||||
|
|
||||||
|
Background tasks should avoid blocking calls while inside the context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
This StateProxy instance in mutable mode.
|
||||||
|
"""
|
||||||
|
self._self_actx = self._self_app.modify_state(self.__wrapped__.get_token())
|
||||||
|
mutable_state = await self._self_actx.__aenter__()
|
||||||
|
super().__setattr__(
|
||||||
|
"__wrapped__", mutable_state.get_substate(self._self_substate_path)
|
||||||
|
)
|
||||||
|
self._self_mutable = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc_info: Any) -> None:
|
||||||
|
"""Exit the async context manager protocol.
|
||||||
|
|
||||||
|
Sets proxy mutability to False and persists any state changes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_info: The exception info tuple.
|
||||||
|
"""
|
||||||
|
if self._self_actx is None:
|
||||||
|
return
|
||||||
|
self._self_mutable = False
|
||||||
|
await self._self_actx.__aexit__(*exc_info)
|
||||||
|
self._self_actx = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Enter the regular context manager protocol.
|
||||||
|
|
||||||
|
This is not supported for background tasks, and exists only to raise a more useful exception
|
||||||
|
when the StateProxy is used incorrectly.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: always, because only async contextmanager protocol is supported.
|
||||||
|
"""
|
||||||
|
raise TypeError("Background task must use `async with self` to modify state.")
|
||||||
|
|
||||||
|
def __exit__(self, *exc_info: Any) -> None:
|
||||||
|
"""Exit the regular context manager protocol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_info: The exception info tuple.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
"""Get the attribute from the underlying state instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the attribute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The value of the attribute.
|
||||||
|
"""
|
||||||
|
value = super().__getattr__(name)
|
||||||
|
if not name.startswith("_self_") and isinstance(value, MutableProxy):
|
||||||
|
# ensure mutations to these containers are blocked unless proxy is _mutable
|
||||||
|
return ImmutableMutableProxy(
|
||||||
|
wrapped=value.__wrapped__,
|
||||||
|
state=self, # type: ignore
|
||||||
|
field_name=value._self_field_name,
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
"""Set the attribute on the underlying state instance.
|
||||||
|
|
||||||
|
If the attribute is internal, set it on the proxy instance instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the attribute.
|
||||||
|
value: The value of the attribute.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImmutableStateError: If the state is not in mutable mode.
|
||||||
|
"""
|
||||||
|
if not name.startswith("_self_") and not self._self_mutable:
|
||||||
|
raise ImmutableStateError(
|
||||||
|
"Background task StateProxy is immutable outside of a context "
|
||||||
|
"manager. Use `async with self` to modify state."
|
||||||
|
)
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
|
|
||||||
class DefaultState(State):
|
class DefaultState(State):
|
||||||
"""The default empty state."""
|
"""The default empty state."""
|
||||||
@ -1009,31 +1218,29 @@ class StateUpdate(Base):
|
|||||||
final: bool = True
|
final: bool = True
|
||||||
|
|
||||||
|
|
||||||
class StateManager(Base):
|
class StateManager(Base, ABC):
|
||||||
"""A class to manage many client states."""
|
"""A class to manage many client states."""
|
||||||
|
|
||||||
# The state class to use.
|
# The state class to use.
|
||||||
state: Type[State] = DefaultState
|
state: Type[State]
|
||||||
|
|
||||||
# The mapping of client ids to states.
|
@classmethod
|
||||||
states: Dict[str, State] = {}
|
def create(cls, state: Type[State] = DefaultState):
|
||||||
|
"""Create a new state manager.
|
||||||
# The token expiration time (s).
|
|
||||||
token_expiration: int = constants.TOKEN_EXPIRATION
|
|
||||||
|
|
||||||
# The redis client to use.
|
|
||||||
redis: Optional[Redis] = None
|
|
||||||
|
|
||||||
def setup(self, state: Type[State]):
|
|
||||||
"""Set up the state manager.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: The state class to use.
|
state: The state class to use.
|
||||||
"""
|
|
||||||
self.state = state
|
|
||||||
self.redis = prerequisites.get_redis()
|
|
||||||
|
|
||||||
def get_state(self, token: str) -> State:
|
Returns:
|
||||||
|
The state manager (either memory or redis).
|
||||||
|
"""
|
||||||
|
redis = prerequisites.get_redis()
|
||||||
|
if redis is not None:
|
||||||
|
return StateManagerRedis(state=state, redis=redis)
|
||||||
|
return StateManagerMemory(state=state)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_state(self, token: str) -> State:
|
||||||
"""Get the state for a token.
|
"""Get the state for a token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1042,27 +1249,266 @@ class StateManager(Base):
|
|||||||
Returns:
|
Returns:
|
||||||
The state for the token.
|
The state for the token.
|
||||||
"""
|
"""
|
||||||
if self.redis is not None:
|
pass
|
||||||
redis_state = self.redis.get(token)
|
|
||||||
if redis_state is None:
|
|
||||||
self.set_state(token, self.state())
|
|
||||||
return self.get_state(token)
|
|
||||||
return cloudpickle.loads(redis_state)
|
|
||||||
|
|
||||||
if token not in self.states:
|
@abstractmethod
|
||||||
self.states[token] = self.state()
|
async def set_state(self, token: str, state: State):
|
||||||
return self.states[token]
|
|
||||||
|
|
||||||
def set_state(self, token: str, state: State):
|
|
||||||
"""Set the state for a token.
|
"""Set the state for a token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The token to set the state for.
|
token: The token to set the state for.
|
||||||
state: The state to set.
|
state: The state to set.
|
||||||
"""
|
"""
|
||||||
if self.redis is None:
|
pass
|
||||||
return
|
|
||||||
self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
@abstractmethod
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||||
|
"""Modify the state for a token while holding exclusive lock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to modify the state for.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The state for the token.
|
||||||
|
"""
|
||||||
|
yield self.state()
|
||||||
|
|
||||||
|
|
||||||
|
class StateManagerMemory(StateManager):
|
||||||
|
"""A state manager that stores states in memory."""
|
||||||
|
|
||||||
|
# The mapping of client ids to states.
|
||||||
|
states: Dict[str, State] = {}
|
||||||
|
|
||||||
|
# The mutex ensures the dict of mutexes is updated exclusively
|
||||||
|
_state_manager_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# The dict of mutexes for each client
|
||||||
|
_states_locks: Dict[str, asyncio.Lock] = pydantic.PrivateAttr({})
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""The Pydantic config."""
|
||||||
|
|
||||||
|
fields = {
|
||||||
|
"_states_locks": {"exclude": True},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_state(self, token: str) -> State:
|
||||||
|
"""Get the state for a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to get the state for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state for the token.
|
||||||
|
"""
|
||||||
|
if token not in self.states:
|
||||||
|
self.states[token] = self.state()
|
||||||
|
return self.states[token]
|
||||||
|
|
||||||
|
async def set_state(self, token: str, state: State):
|
||||||
|
"""Set the state for a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to set the state for.
|
||||||
|
state: The state to set.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||||
|
"""Modify the state for a token while holding exclusive lock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to modify the state for.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The state for the token.
|
||||||
|
"""
|
||||||
|
if token not in self._states_locks:
|
||||||
|
async with self._state_manager_lock:
|
||||||
|
if token not in self._states_locks:
|
||||||
|
self._states_locks[token] = asyncio.Lock()
|
||||||
|
|
||||||
|
async with self._states_locks[token]:
|
||||||
|
state = await self.get_state(token)
|
||||||
|
yield state
|
||||||
|
await self.set_state(token, state)
|
||||||
|
|
||||||
|
|
||||||
|
class StateManagerRedis(StateManager):
|
||||||
|
"""A state manager that stores states in redis."""
|
||||||
|
|
||||||
|
# The redis client to use.
|
||||||
|
redis: Redis
|
||||||
|
|
||||||
|
# The token expiration time (s).
|
||||||
|
token_expiration: int = constants.TOKEN_EXPIRATION
|
||||||
|
|
||||||
|
# The maximum time to hold a lock (ms).
|
||||||
|
lock_expiration: int = constants.LOCK_EXPIRATION
|
||||||
|
|
||||||
|
# The keyspace subscription string when redis is waiting for lock to be released
|
||||||
|
_redis_notify_keyspace_events: str = (
|
||||||
|
"K" # Enable keyspace notifications (target a particular key)
|
||||||
|
"g" # For generic commands (DEL, EXPIRE, etc)
|
||||||
|
"x" # For expired events
|
||||||
|
"e" # For evicted events (i.e. maxmemory exceeded)
|
||||||
|
)
|
||||||
|
|
||||||
|
# These events indicate that a lock is no longer held
|
||||||
|
_redis_keyspace_lock_release_events: Set[bytes] = {
|
||||||
|
b"del",
|
||||||
|
b"expire",
|
||||||
|
b"expired",
|
||||||
|
b"evicted",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_state(self, token: str) -> State:
|
||||||
|
"""Get the state for a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to get the state for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state for the token.
|
||||||
|
"""
|
||||||
|
redis_state = await self.redis.get(token)
|
||||||
|
if redis_state is None:
|
||||||
|
await self.set_state(token, self.state())
|
||||||
|
return await self.get_state(token)
|
||||||
|
return cloudpickle.loads(redis_state)
|
||||||
|
|
||||||
|
async def set_state(self, token: str, state: State, lock_id: bytes | None = None):
|
||||||
|
"""Set the state for a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to set the state for.
|
||||||
|
state: The state to set.
|
||||||
|
lock_id: If provided, the lock_key must be set to this value to set the state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID.
|
||||||
|
"""
|
||||||
|
# check that we're holding the lock
|
||||||
|
if (
|
||||||
|
lock_id is not None
|
||||||
|
and await self.redis.get(self._lock_key(token)) != lock_id
|
||||||
|
):
|
||||||
|
raise LockExpiredError(
|
||||||
|
f"Lock expired for token {token} while processing. Consider increasing "
|
||||||
|
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
|
||||||
|
"or use `@rx.background` decorator for long-running tasks."
|
||||||
|
)
|
||||||
|
await self.redis.set(token, cloudpickle.dumps(state), ex=self.token_expiration)
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||||
|
"""Modify the state for a token while holding exclusive lock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to modify the state for.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The state for the token.
|
||||||
|
"""
|
||||||
|
async with self._lock(token) as lock_id:
|
||||||
|
state = await self.get_state(token)
|
||||||
|
yield state
|
||||||
|
await self.set_state(token, state, lock_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _lock_key(token: str) -> bytes:
|
||||||
|
"""Get the redis key for a token's lock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to get the lock key for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The redis lock key for the token.
|
||||||
|
"""
|
||||||
|
return f"{token}_lock".encode()
|
||||||
|
|
||||||
|
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
|
||||||
|
"""Try to get a redis lock for a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lock_key: The redis key for the lock.
|
||||||
|
lock_id: The ID of the lock.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the lock was obtained.
|
||||||
|
"""
|
||||||
|
return await self.redis.set(
|
||||||
|
lock_key,
|
||||||
|
lock_id,
|
||||||
|
px=self.lock_expiration,
|
||||||
|
nx=True, # only set if it doesn't exist
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
|
||||||
|
"""Wait for a redis lock to be released via pubsub.
|
||||||
|
|
||||||
|
Coroutine will not return until the lock is obtained.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lock_key: The redis key for the lock.
|
||||||
|
lock_id: The ID of the lock.
|
||||||
|
"""
|
||||||
|
state_is_locked = False
|
||||||
|
lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
|
||||||
|
# Enable keyspace notifications for the lock key, so we know when it is available.
|
||||||
|
await self.redis.config_set(
|
||||||
|
"notify-keyspace-events", self._redis_notify_keyspace_events
|
||||||
|
)
|
||||||
|
async with self.redis.pubsub() as pubsub:
|
||||||
|
await pubsub.psubscribe(lock_key_channel)
|
||||||
|
while not state_is_locked:
|
||||||
|
# wait for the lock to be released
|
||||||
|
while True:
|
||||||
|
if not await self.redis.exists(lock_key):
|
||||||
|
break # key was removed, try to get the lock again
|
||||||
|
message = await pubsub.get_message(
|
||||||
|
ignore_subscribe_messages=True,
|
||||||
|
timeout=self.lock_expiration / 1000.0,
|
||||||
|
)
|
||||||
|
if message is None:
|
||||||
|
continue
|
||||||
|
if message["data"] in self._redis_keyspace_lock_release_events:
|
||||||
|
break
|
||||||
|
state_is_locked = await self._try_get_lock(lock_key, lock_id)
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _lock(self, token: str):
|
||||||
|
"""Obtain a redis lock for a token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to obtain a lock for.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The ID of the lock (to be passed to set_state).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LockExpiredError: If the lock has expired while processing the event.
|
||||||
|
"""
|
||||||
|
lock_key = self._lock_key(token)
|
||||||
|
lock_id = uuid.uuid4().hex.encode()
|
||||||
|
|
||||||
|
if not await self._try_get_lock(lock_key, lock_id):
|
||||||
|
# Missed the fast-path to get lock, subscribe for lock delete/expire events
|
||||||
|
await self._wait_lock(lock_key, lock_id)
|
||||||
|
state_is_locked = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield lock_id
|
||||||
|
except LockExpiredError:
|
||||||
|
state_is_locked = False
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if state_is_locked:
|
||||||
|
# only delete our lock
|
||||||
|
await self.redis.delete(lock_key)
|
||||||
|
|
||||||
|
|
||||||
class ClientStorageBase:
|
class ClientStorageBase:
|
||||||
@ -1246,7 +1692,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
value, super().__getattribute__("__mutable_types__")
|
value, super().__getattribute__("__mutable_types__")
|
||||||
) and __name not in ("__wrapped__", "_self_state"):
|
) and __name not in ("__wrapped__", "_self_state"):
|
||||||
# Recursively wrap mutable attribute values retrieved through this proxy.
|
# Recursively wrap mutable attribute values retrieved through this proxy.
|
||||||
return MutableProxy(
|
return type(self)(
|
||||||
wrapped=value,
|
wrapped=value,
|
||||||
state=self._self_state,
|
state=self._self_state,
|
||||||
field_name=self._self_field_name,
|
field_name=self._self_field_name,
|
||||||
@ -1266,7 +1712,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
value = super().__getitem__(key)
|
value = super().__getitem__(key)
|
||||||
if isinstance(value, self.__mutable_types__):
|
if isinstance(value, self.__mutable_types__):
|
||||||
# Recursively wrap mutable items retrieved through this proxy.
|
# Recursively wrap mutable items retrieved through this proxy.
|
||||||
return MutableProxy(
|
return type(self)(
|
||||||
wrapped=value,
|
wrapped=value,
|
||||||
state=self._self_state,
|
state=self._self_state,
|
||||||
field_name=self._self_field_name,
|
field_name=self._self_field_name,
|
||||||
@ -1332,3 +1778,34 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
A deepcopy of the wrapped object, unconnected to the proxy.
|
A deepcopy of the wrapped object, unconnected to the proxy.
|
||||||
"""
|
"""
|
||||||
return copy.deepcopy(self.__wrapped__, memo=memo)
|
return copy.deepcopy(self.__wrapped__, memo=memo)
|
||||||
|
|
||||||
|
|
||||||
|
class ImmutableMutableProxy(MutableProxy):
|
||||||
|
"""A proxy for a mutable object that tracks changes.
|
||||||
|
|
||||||
|
This wrapper comes from StateProxy, and will raise an exception if an attempt is made
|
||||||
|
to modify the wrapped object when the StateProxy is immutable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _mark_dirty(self, wrapped=None, instance=None, args=tuple(), kwargs=None):
|
||||||
|
"""Raise an exception when an attempt is made to modify the object.
|
||||||
|
|
||||||
|
Intended for use with `FunctionWrapper` from the `wrapt` library.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wrapped: The wrapped function.
|
||||||
|
instance: The instance of the wrapped function.
|
||||||
|
args: The args for the wrapped function.
|
||||||
|
kwargs: The kwargs for the wrapped function.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImmutableStateError: if the StateProxy is not mutable.
|
||||||
|
"""
|
||||||
|
if not self._self_state._self_mutable:
|
||||||
|
raise ImmutableStateError(
|
||||||
|
"Background task StateProxy is immutable outside of a context "
|
||||||
|
"manager. Use `async with self` to modify state."
|
||||||
|
)
|
||||||
|
super()._mark_dirty(
|
||||||
|
wrapped=wrapped, instance=instance, args=args, kwargs=kwargs
|
||||||
|
)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""reflex.testing - tools for testing reflex apps."""
|
"""reflex.testing - tools for testing reflex apps."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
@ -19,14 +20,13 @@ import types
|
|||||||
from http.server import SimpleHTTPRequestHandler
|
from http.server import SimpleHTTPRequestHandler
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Optional,
|
Optional,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
@ -38,7 +38,7 @@ import reflex.utils.build
|
|||||||
import reflex.utils.exec
|
import reflex.utils.exec
|
||||||
import reflex.utils.prerequisites
|
import reflex.utils.prerequisites
|
||||||
import reflex.utils.processes
|
import reflex.utils.processes
|
||||||
from reflex.app import EventNamespace
|
from reflex.state import State, StateManagerMemory, StateManagerRedis
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from selenium import webdriver # pyright: ignore [reportMissingImports]
|
from selenium import webdriver # pyright: ignore [reportMissingImports]
|
||||||
@ -109,6 +109,7 @@ class AppHarness:
|
|||||||
frontend_url: Optional[str] = None
|
frontend_url: Optional[str] = None
|
||||||
backend_thread: Optional[threading.Thread] = None
|
backend_thread: Optional[threading.Thread] = None
|
||||||
backend: Optional[uvicorn.Server] = None
|
backend: Optional[uvicorn.Server] = None
|
||||||
|
state_manager: Optional[StateManagerMemory | StateManagerRedis] = None
|
||||||
_frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
|
_frontends: list["WebDriver"] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -162,6 +163,27 @@ class AppHarness:
|
|||||||
reflex.config.get_config(reload=True)
|
reflex.config.get_config(reload=True)
|
||||||
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
|
self.app_module = reflex.utils.prerequisites.get_app(reload=True)
|
||||||
self.app_instance = self.app_module.app
|
self.app_instance = self.app_module.app
|
||||||
|
if isinstance(self.app_instance.state_manager, StateManagerRedis):
|
||||||
|
# Create our own redis connection for testing.
|
||||||
|
self.state_manager = StateManagerRedis.create(self.app_instance.state)
|
||||||
|
else:
|
||||||
|
self.state_manager = self.app_instance.state_manager
|
||||||
|
|
||||||
|
def _get_backend_shutdown_handler(self):
|
||||||
|
if self.backend is None:
|
||||||
|
raise RuntimeError("Backend was not initialized.")
|
||||||
|
|
||||||
|
original_shutdown = self.backend.shutdown
|
||||||
|
|
||||||
|
async def _shutdown_redis(*args, **kwargs) -> None:
|
||||||
|
# ensure redis is closed before event loop
|
||||||
|
if self.app_instance is not None and isinstance(
|
||||||
|
self.app_instance.state_manager, StateManagerRedis
|
||||||
|
):
|
||||||
|
await self.app_instance.state_manager.redis.close()
|
||||||
|
await original_shutdown(*args, **kwargs)
|
||||||
|
|
||||||
|
return _shutdown_redis
|
||||||
|
|
||||||
def _start_backend(self, port=0):
|
def _start_backend(self, port=0):
|
||||||
if self.app_instance is None:
|
if self.app_instance is None:
|
||||||
@ -173,6 +195,7 @@ class AppHarness:
|
|||||||
port=port,
|
port=port,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.backend.shutdown = self._get_backend_shutdown_handler()
|
||||||
self.backend_thread = threading.Thread(target=self.backend.run)
|
self.backend_thread = threading.Thread(target=self.backend.run)
|
||||||
self.backend_thread.start()
|
self.backend_thread.start()
|
||||||
|
|
||||||
@ -296,6 +319,35 @@ class AppHarness:
|
|||||||
time.sleep(step)
|
time.sleep(step)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _poll_for_async(
|
||||||
|
target: Callable[[], Coroutine[None, None, T]],
|
||||||
|
timeout: TimeoutType = None,
|
||||||
|
step: TimeoutType = None,
|
||||||
|
) -> T | bool:
|
||||||
|
"""Generic polling logic for async functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: callable that returns truthy if polling condition is met.
|
||||||
|
timeout: max polling time
|
||||||
|
step: interval between checking target()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return value of target() if truthy within timeout
|
||||||
|
False if timeout elapses
|
||||||
|
"""
|
||||||
|
if timeout is None:
|
||||||
|
timeout = DEFAULT_TIMEOUT
|
||||||
|
if step is None:
|
||||||
|
step = POLL_INTERVAL
|
||||||
|
deadline = time.time() + timeout
|
||||||
|
while time.time() < deadline:
|
||||||
|
success = await target()
|
||||||
|
if success:
|
||||||
|
return success
|
||||||
|
await asyncio.sleep(step)
|
||||||
|
return False
|
||||||
|
|
||||||
def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
|
def _poll_for_servers(self, timeout: TimeoutType = None) -> socket.socket:
|
||||||
"""Poll backend server for listening sockets.
|
"""Poll backend server for listening sockets.
|
||||||
|
|
||||||
@ -351,39 +403,76 @@ class AppHarness:
|
|||||||
self._frontends.append(driver)
|
self._frontends.append(driver)
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
async def emit_state_updates(self) -> list[Any]:
|
async def get_state(self, token: str) -> State:
|
||||||
"""Send any backend state deltas to the frontend.
|
"""Get the state associated with the given token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The state token to look up.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of awaited response from each EventNamespace.emit() call.
|
The state instance associated with the given token
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: when the app hasn't started running
|
RuntimeError: when the app hasn't started running
|
||||||
"""
|
"""
|
||||||
if self.app_instance is None or self.app_instance.sio is None:
|
if self.state_manager is None:
|
||||||
|
raise RuntimeError("state_manager is not set.")
|
||||||
|
try:
|
||||||
|
return await self.state_manager.get_state(token)
|
||||||
|
finally:
|
||||||
|
if isinstance(self.state_manager, StateManagerRedis):
|
||||||
|
await self.state_manager.redis.close()
|
||||||
|
|
||||||
|
async def set_state(self, token: str, **kwargs) -> None:
|
||||||
|
"""Set the state associated with the given token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The state token to set.
|
||||||
|
kwargs: Attributes to set on the state.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: when the app hasn't started running
|
||||||
|
"""
|
||||||
|
if self.state_manager is None:
|
||||||
|
raise RuntimeError("state_manager is not set.")
|
||||||
|
state = await self.get_state(token)
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(state, key, value)
|
||||||
|
try:
|
||||||
|
await self.state_manager.set_state(token, state)
|
||||||
|
finally:
|
||||||
|
if isinstance(self.state_manager, StateManagerRedis):
|
||||||
|
await self.state_manager.redis.close()
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def modify_state(self, token: str) -> AsyncIterator[State]:
|
||||||
|
"""Modify the state associated with the given token and send update to frontend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The state token to modify
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The state instance associated with the given token
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: when the app hasn't started running
|
||||||
|
"""
|
||||||
|
if self.state_manager is None:
|
||||||
|
raise RuntimeError("state_manager is not set.")
|
||||||
|
if self.app_instance is None:
|
||||||
raise RuntimeError("App is not running.")
|
raise RuntimeError("App is not running.")
|
||||||
event_ns: EventNamespace = cast(
|
app_state_manager = self.app_instance.state_manager
|
||||||
EventNamespace,
|
if isinstance(self.state_manager, StateManagerRedis):
|
||||||
self.app_instance.event_namespace,
|
# Temporarily replace the app's state manager with our own, since
|
||||||
)
|
# the redis connection is on the backend_thread event loop
|
||||||
pending: list[Coroutine[Any, Any, Any]] = []
|
self.app_instance.state_manager = self.state_manager
|
||||||
for state in self.app_instance.state_manager.states.values():
|
try:
|
||||||
delta = state.get_delta()
|
async with self.app_instance.modify_state(token) as state:
|
||||||
if delta:
|
yield state
|
||||||
update = reflex.state.StateUpdate(delta=delta, events=[], final=True)
|
finally:
|
||||||
state._clean()
|
if isinstance(self.state_manager, StateManagerRedis):
|
||||||
# Emit the event.
|
self.app_instance.state_manager = app_state_manager
|
||||||
pending.append(
|
await self.state_manager.redis.close()
|
||||||
event_ns.emit(
|
|
||||||
str(reflex.constants.SocketEvent.EVENT),
|
|
||||||
update.json(),
|
|
||||||
to=state.get_sid(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
responses = []
|
|
||||||
for request in pending:
|
|
||||||
responses.append(await request)
|
|
||||||
return responses
|
|
||||||
|
|
||||||
def poll_for_content(
|
def poll_for_content(
|
||||||
self,
|
self,
|
||||||
@ -457,6 +546,9 @@ class AppHarness:
|
|||||||
if self.app_instance is None:
|
if self.app_instance is None:
|
||||||
raise RuntimeError("App is not running.")
|
raise RuntimeError("App is not running.")
|
||||||
state_manager = self.app_instance.state_manager
|
state_manager = self.app_instance.state_manager
|
||||||
|
assert isinstance(
|
||||||
|
state_manager, StateManagerMemory
|
||||||
|
), "Only works with memory state manager"
|
||||||
if not self._poll_for(
|
if not self._poll_for(
|
||||||
target=lambda: state_manager.states,
|
target=lambda: state_manager.states,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
@ -534,7 +626,6 @@ class Subdir404TCPServer(socketserver.TCPServer):
|
|||||||
request: the requesting socket
|
request: the requesting socket
|
||||||
client_address: (host, port) referring to the client’s address.
|
client_address: (host, port) referring to the client’s address.
|
||||||
"""
|
"""
|
||||||
print(client_address, type(client_address))
|
|
||||||
self.RequestHandlerClass(
|
self.RequestHandlerClass(
|
||||||
request,
|
request,
|
||||||
client_address,
|
client_address,
|
||||||
@ -605,6 +696,7 @@ class AppHarnessProd(AppHarness):
|
|||||||
workers=reflex.utils.processes.get_num_workers(),
|
workers=reflex.utils.processes.get_num_workers(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self.backend.shutdown = self._get_backend_shutdown_handler()
|
||||||
self.backend_thread = threading.Thread(target=self.backend.run)
|
self.backend_thread = threading.Thread(target=self.backend.run)
|
||||||
self.backend_thread.start()
|
self.backend_thread.start()
|
||||||
|
|
||||||
|
@ -5,3 +5,11 @@ class InvalidStylePropError(TypeError):
|
|||||||
"""Custom Type Error when style props have invalid values."""
|
"""Custom Type Error when style props have invalid values."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ImmutableStateError(AttributeError):
|
||||||
|
"""Raised when a background task attempts to modify state outside of context."""
|
||||||
|
|
||||||
|
|
||||||
|
class LockExpiredError(Exception):
|
||||||
|
"""Raised when the state lock expires while an event is being processed."""
|
||||||
|
@ -21,7 +21,7 @@ import httpx
|
|||||||
import typer
|
import typer
|
||||||
from alembic.util.exc import CommandError
|
from alembic.util.exc import CommandError
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from redis import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
from reflex import constants, model
|
from reflex import constants, model
|
||||||
from reflex.compiler import templates
|
from reflex.compiler import templates
|
||||||
@ -124,9 +124,11 @@ def get_redis() -> Redis | None:
|
|||||||
The redis client.
|
The redis client.
|
||||||
"""
|
"""
|
||||||
config = get_config()
|
config = get_config()
|
||||||
if config.redis_url is None:
|
if not config.redis_url:
|
||||||
return None
|
return None
|
||||||
redis_url, redis_port = config.redis_url.split(":")
|
redis_url, has_port, redis_port = config.redis_url.partition(":")
|
||||||
|
if not has_port:
|
||||||
|
redis_port = 6379
|
||||||
console.info(f"Using redis at {config.redis_url}")
|
console.info(f"Using redis at {config.redis_url}")
|
||||||
return Redis(host=redis_url, port=int(redis_port), db=0)
|
return Redis(host=redis_url, port=int(redis_port), db=0)
|
||||||
|
|
||||||
|
@ -2,8 +2,9 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Generator, List, Set, Union
|
from typing import Dict, Generator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -11,6 +12,14 @@ import reflex as rx
|
|||||||
from reflex.app import App
|
from reflex.app import App
|
||||||
from reflex.event import EventSpec
|
from reflex.event import EventSpec
|
||||||
|
|
||||||
|
from .states import (
|
||||||
|
DictMutationTestState,
|
||||||
|
ListMutationTestState,
|
||||||
|
MutableTestState,
|
||||||
|
SubUploadState,
|
||||||
|
UploadState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def app() -> App:
|
def app() -> App:
|
||||||
@ -39,60 +48,7 @@ def list_mutation_state():
|
|||||||
Returns:
|
Returns:
|
||||||
A state with list mutation features.
|
A state with list mutation features.
|
||||||
"""
|
"""
|
||||||
|
return ListMutationTestState()
|
||||||
class TestState(rx.State):
|
|
||||||
"""The test state."""
|
|
||||||
|
|
||||||
# plain list
|
|
||||||
plain_friends = ["Tommy"]
|
|
||||||
|
|
||||||
def make_friend(self):
|
|
||||||
self.plain_friends.append("another-fd")
|
|
||||||
|
|
||||||
def change_first_friend(self):
|
|
||||||
self.plain_friends[0] = "Jenny"
|
|
||||||
|
|
||||||
def unfriend_all_friends(self):
|
|
||||||
self.plain_friends.clear()
|
|
||||||
|
|
||||||
def unfriend_first_friend(self):
|
|
||||||
del self.plain_friends[0]
|
|
||||||
|
|
||||||
def remove_last_friend(self):
|
|
||||||
self.plain_friends.pop()
|
|
||||||
|
|
||||||
def make_friends_with_colleagues(self):
|
|
||||||
colleagues = ["Peter", "Jimmy"]
|
|
||||||
self.plain_friends.extend(colleagues)
|
|
||||||
|
|
||||||
def remove_tommy(self):
|
|
||||||
self.plain_friends.remove("Tommy")
|
|
||||||
|
|
||||||
# list in dict
|
|
||||||
friends_in_dict = {"Tommy": ["Jenny"]}
|
|
||||||
|
|
||||||
def remove_jenny_from_tommy(self):
|
|
||||||
self.friends_in_dict["Tommy"].remove("Jenny")
|
|
||||||
|
|
||||||
def add_jimmy_to_tommy_friends(self):
|
|
||||||
self.friends_in_dict["Tommy"].append("Jimmy")
|
|
||||||
|
|
||||||
def tommy_has_no_fds(self):
|
|
||||||
self.friends_in_dict["Tommy"].clear()
|
|
||||||
|
|
||||||
# nested list
|
|
||||||
friends_in_nested_list = [["Tommy"], ["Jenny"]]
|
|
||||||
|
|
||||||
def remove_first_group(self):
|
|
||||||
self.friends_in_nested_list.pop(0)
|
|
||||||
|
|
||||||
def remove_first_person_from_first_group(self):
|
|
||||||
self.friends_in_nested_list[0].pop(0)
|
|
||||||
|
|
||||||
def add_jimmy_to_second_group(self):
|
|
||||||
self.friends_in_nested_list[1].append("Jimmy")
|
|
||||||
|
|
||||||
return TestState()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -102,85 +58,7 @@ def dict_mutation_state():
|
|||||||
Returns:
|
Returns:
|
||||||
A state with dict mutation features.
|
A state with dict mutation features.
|
||||||
"""
|
"""
|
||||||
|
return DictMutationTestState()
|
||||||
class TestState(rx.State):
|
|
||||||
"""The test state."""
|
|
||||||
|
|
||||||
# plain dict
|
|
||||||
details = {"name": "Tommy"}
|
|
||||||
|
|
||||||
def add_age(self):
|
|
||||||
self.details.update({"age": 20}) # type: ignore
|
|
||||||
|
|
||||||
def change_name(self):
|
|
||||||
self.details["name"] = "Jenny"
|
|
||||||
|
|
||||||
def remove_last_detail(self):
|
|
||||||
self.details.popitem()
|
|
||||||
|
|
||||||
def clear_details(self):
|
|
||||||
self.details.clear()
|
|
||||||
|
|
||||||
def remove_name(self):
|
|
||||||
del self.details["name"]
|
|
||||||
|
|
||||||
def pop_out_age(self):
|
|
||||||
self.details.pop("age")
|
|
||||||
|
|
||||||
# dict in list
|
|
||||||
address = [{"home": "home address"}, {"work": "work address"}]
|
|
||||||
|
|
||||||
def remove_home_address(self):
|
|
||||||
self.address[0].pop("home")
|
|
||||||
|
|
||||||
def add_street_to_home_address(self):
|
|
||||||
self.address[0]["street"] = "street address"
|
|
||||||
|
|
||||||
# nested dict
|
|
||||||
friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
|
|
||||||
|
|
||||||
def change_friend_name(self):
|
|
||||||
self.friend_in_nested_dict["friend"]["name"] = "Tommy"
|
|
||||||
|
|
||||||
def remove_friend(self):
|
|
||||||
self.friend_in_nested_dict.pop("friend")
|
|
||||||
|
|
||||||
def add_friend_age(self):
|
|
||||||
self.friend_in_nested_dict["friend"]["age"] = 30
|
|
||||||
|
|
||||||
return TestState()
|
|
||||||
|
|
||||||
|
|
||||||
class UploadState(rx.State):
|
|
||||||
"""The base state for uploading a file."""
|
|
||||||
|
|
||||||
async def handle_upload1(self, files: List[rx.UploadFile]):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseState(rx.State):
|
|
||||||
"""The test base state."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SubUploadState(BaseState):
|
|
||||||
"""The test substate."""
|
|
||||||
|
|
||||||
img: str
|
|
||||||
|
|
||||||
async def handle_upload(self, files: List[rx.UploadFile]):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -203,187 +81,6 @@ def upload_event_spec():
|
|||||||
return EventSpec(handler=UploadState.handle_upload1, upload=True) # type: ignore
|
return EventSpec(handler=UploadState.handle_upload1, upload=True) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def upload_state(tmp_path):
|
|
||||||
"""Create upload state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tmp_path: pytest tmp_path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The state
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
class FileUploadState(rx.State):
|
|
||||||
"""The base state for uploading a file."""
|
|
||||||
|
|
||||||
img_list: List[str]
|
|
||||||
|
|
||||||
async def handle_upload2(self, files):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
upload_data = await file.read()
|
|
||||||
outfile = f"{tmp_path}/{file.filename}"
|
|
||||||
|
|
||||||
# Save the file.
|
|
||||||
with open(outfile, "wb") as file_object:
|
|
||||||
file_object.write(upload_data)
|
|
||||||
|
|
||||||
# Update the img var.
|
|
||||||
self.img_list.append(file.filename)
|
|
||||||
|
|
||||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
upload_data = await file.read()
|
|
||||||
outfile = f"{tmp_path}/{file.filename}"
|
|
||||||
|
|
||||||
# Save the file.
|
|
||||||
with open(outfile, "wb") as file_object:
|
|
||||||
file_object.write(upload_data)
|
|
||||||
|
|
||||||
# Update the img var.
|
|
||||||
assert file.filename is not None
|
|
||||||
self.img_list.append(file.filename)
|
|
||||||
|
|
||||||
return FileUploadState
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def upload_sub_state(tmp_path):
|
|
||||||
"""Create upload substate.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tmp_path: pytest tmp_path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The state
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
class FileState(rx.State):
|
|
||||||
"""The base state."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
class FileUploadState(FileState):
|
|
||||||
"""The substate for uploading a file."""
|
|
||||||
|
|
||||||
img_list: List[str]
|
|
||||||
|
|
||||||
async def handle_upload2(self, files):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
upload_data = await file.read()
|
|
||||||
outfile = f"{tmp_path}/{file.filename}"
|
|
||||||
|
|
||||||
# Save the file.
|
|
||||||
with open(outfile, "wb") as file_object:
|
|
||||||
file_object.write(upload_data)
|
|
||||||
|
|
||||||
# Update the img var.
|
|
||||||
self.img_list.append(file.filename)
|
|
||||||
|
|
||||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
upload_data = await file.read()
|
|
||||||
outfile = f"{tmp_path}/{file.filename}"
|
|
||||||
|
|
||||||
# Save the file.
|
|
||||||
with open(outfile, "wb") as file_object:
|
|
||||||
file_object.write(upload_data)
|
|
||||||
|
|
||||||
# Update the img var.
|
|
||||||
assert file.filename is not None
|
|
||||||
self.img_list.append(file.filename)
|
|
||||||
|
|
||||||
return FileUploadState
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def upload_grand_sub_state(tmp_path):
|
|
||||||
"""Create upload grand-state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tmp_path: pytest tmp_path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The state
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
class BaseFileState(rx.State):
|
|
||||||
"""The base state."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
class FileSubState(BaseFileState):
|
|
||||||
"""The substate."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
class FileUploadState(FileSubState):
|
|
||||||
"""The grand-substate for uploading a file."""
|
|
||||||
|
|
||||||
img_list: List[str]
|
|
||||||
|
|
||||||
async def handle_upload2(self, files):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
upload_data = await file.read()
|
|
||||||
outfile = f"{tmp_path}/{file.filename}"
|
|
||||||
|
|
||||||
# Save the file.
|
|
||||||
with open(outfile, "wb") as file_object:
|
|
||||||
file_object.write(upload_data)
|
|
||||||
|
|
||||||
# Update the img var.
|
|
||||||
assert file.filename is not None
|
|
||||||
self.img_list.append(file.filename)
|
|
||||||
|
|
||||||
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
|
||||||
"""Handle the upload of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: The uploaded files.
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
upload_data = await file.read()
|
|
||||||
outfile = f"{tmp_path}/{file.filename}"
|
|
||||||
|
|
||||||
# Save the file.
|
|
||||||
with open(outfile, "wb") as file_object:
|
|
||||||
file_object.write(upload_data)
|
|
||||||
|
|
||||||
# Update the img var.
|
|
||||||
assert file.filename is not None
|
|
||||||
self.img_list.append(file.filename)
|
|
||||||
|
|
||||||
return FileUploadState
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def base_config_values() -> Dict:
|
def base_config_values() -> Dict:
|
||||||
"""Get base config values.
|
"""Get base config values.
|
||||||
@ -418,35 +115,6 @@ def sqlite_db_config_values(base_db_config_values) -> Dict:
|
|||||||
return base_db_config_values
|
return base_db_config_values
|
||||||
|
|
||||||
|
|
||||||
class GenState(rx.State):
|
|
||||||
"""A state with event handlers that generate multiple updates."""
|
|
||||||
|
|
||||||
value: int
|
|
||||||
|
|
||||||
def go(self, c: int):
|
|
||||||
"""Increment the value c times and update each time.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
c: The number of times to increment.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
After each increment.
|
|
||||||
"""
|
|
||||||
for _ in range(c):
|
|
||||||
self.value += 1
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def gen_state() -> GenState:
|
|
||||||
"""A state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A test state.
|
|
||||||
"""
|
|
||||||
return GenState # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def router_data_headers() -> Dict[str, str]:
|
def router_data_headers() -> Dict[str, str]:
|
||||||
"""Router data headers.
|
"""Router data headers.
|
||||||
@ -546,46 +214,19 @@ def mutable_state():
|
|||||||
Returns:
|
Returns:
|
||||||
A state object.
|
A state object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class OtherBase(rx.Base):
|
|
||||||
bar: str = ""
|
|
||||||
|
|
||||||
class CustomVar(rx.Base):
|
|
||||||
foo: str = ""
|
|
||||||
array: List[str] = []
|
|
||||||
hashmap: Dict[str, str] = {}
|
|
||||||
test_set: Set[str] = set()
|
|
||||||
custom: OtherBase = OtherBase()
|
|
||||||
|
|
||||||
class MutableTestState(rx.State):
|
|
||||||
"""A test state."""
|
|
||||||
|
|
||||||
array: List[Union[str, List, Dict[str, str]]] = [
|
|
||||||
"value",
|
|
||||||
[1, 2, 3],
|
|
||||||
{"key": "value"},
|
|
||||||
]
|
|
||||||
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
|
|
||||||
"key": ["list", "of", "values"],
|
|
||||||
"another_key": "another_value",
|
|
||||||
"third_key": {"key": "value"},
|
|
||||||
}
|
|
||||||
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
|
|
||||||
custom: CustomVar = CustomVar()
|
|
||||||
_be_custom: CustomVar = CustomVar()
|
|
||||||
|
|
||||||
def reassign_mutables(self):
|
|
||||||
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
|
|
||||||
self.hashmap = {
|
|
||||||
"mod_key": ["list", "of", "values"],
|
|
||||||
"mod_another_key": "another_value",
|
|
||||||
"mod_third_key": {"key": "value"},
|
|
||||||
}
|
|
||||||
self.test_set = {1, 2, 3, 4, "five"}
|
|
||||||
|
|
||||||
return MutableTestState()
|
return MutableTestState()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def token() -> str:
|
||||||
|
"""Create a token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A fresh/unique token string.
|
||||||
|
"""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def duplicate_substate():
|
def duplicate_substate():
|
||||||
"""Create a Test state that has duplicate child substates.
|
"""Create a Test state that has duplicate child substates.
|
||||||
|
30
tests/states/__init__.py
Normal file
30
tests/states/__init__.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
"""Common rx.State subclasses for use in tests."""
|
||||||
|
import reflex as rx
|
||||||
|
|
||||||
|
from .mutation import DictMutationTestState, ListMutationTestState, MutableTestState
|
||||||
|
from .upload import (
|
||||||
|
ChildFileUploadState,
|
||||||
|
FileUploadState,
|
||||||
|
GrandChildFileUploadState,
|
||||||
|
SubUploadState,
|
||||||
|
UploadState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GenState(rx.State):
|
||||||
|
"""A state with event handlers that generate multiple updates."""
|
||||||
|
|
||||||
|
value: int
|
||||||
|
|
||||||
|
def go(self, c: int):
|
||||||
|
"""Increment the value c times and update each time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
c: The number of times to increment.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
After each increment.
|
||||||
|
"""
|
||||||
|
for _ in range(c):
|
||||||
|
self.value += 1
|
||||||
|
yield
|
172
tests/states/mutation.py
Normal file
172
tests/states/mutation.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
"""Test states for mutable vars."""
|
||||||
|
|
||||||
|
from typing import Dict, List, Set, Union
|
||||||
|
|
||||||
|
import reflex as rx
|
||||||
|
|
||||||
|
|
||||||
|
class DictMutationTestState(rx.State):
|
||||||
|
"""A state for testing ReflexDict mutation."""
|
||||||
|
|
||||||
|
# plain dict
|
||||||
|
details = {"name": "Tommy"}
|
||||||
|
|
||||||
|
def add_age(self):
|
||||||
|
"""Add an age to the dict."""
|
||||||
|
self.details.update({"age": 20}) # type: ignore
|
||||||
|
|
||||||
|
def change_name(self):
|
||||||
|
"""Change the name in the dict."""
|
||||||
|
self.details["name"] = "Jenny"
|
||||||
|
|
||||||
|
def remove_last_detail(self):
|
||||||
|
"""Remove the last item in the dict."""
|
||||||
|
self.details.popitem()
|
||||||
|
|
||||||
|
def clear_details(self):
|
||||||
|
"""Clear the dict."""
|
||||||
|
self.details.clear()
|
||||||
|
|
||||||
|
def remove_name(self):
|
||||||
|
"""Remove the name from the dict."""
|
||||||
|
del self.details["name"]
|
||||||
|
|
||||||
|
def pop_out_age(self):
|
||||||
|
"""Pop out the age from the dict."""
|
||||||
|
self.details.pop("age")
|
||||||
|
|
||||||
|
# dict in list
|
||||||
|
address = [{"home": "home address"}, {"work": "work address"}]
|
||||||
|
|
||||||
|
def remove_home_address(self):
|
||||||
|
"""Remove the home address from dict in the list."""
|
||||||
|
self.address[0].pop("home")
|
||||||
|
|
||||||
|
def add_street_to_home_address(self):
|
||||||
|
"""Set street key in the dict in the list."""
|
||||||
|
self.address[0]["street"] = "street address"
|
||||||
|
|
||||||
|
# nested dict
|
||||||
|
friend_in_nested_dict = {"name": "Nikhil", "friend": {"name": "Alek"}}
|
||||||
|
|
||||||
|
def change_friend_name(self):
|
||||||
|
"""Change the friend's name in the nested dict."""
|
||||||
|
self.friend_in_nested_dict["friend"]["name"] = "Tommy"
|
||||||
|
|
||||||
|
def remove_friend(self):
|
||||||
|
"""Remove the friend from the nested dict."""
|
||||||
|
self.friend_in_nested_dict.pop("friend")
|
||||||
|
|
||||||
|
def add_friend_age(self):
|
||||||
|
"""Add an age to the friend in the nested dict."""
|
||||||
|
self.friend_in_nested_dict["friend"]["age"] = 30
|
||||||
|
|
||||||
|
|
||||||
|
class ListMutationTestState(rx.State):
|
||||||
|
"""A state for testing ReflexList mutation."""
|
||||||
|
|
||||||
|
# plain list
|
||||||
|
plain_friends = ["Tommy"]
|
||||||
|
|
||||||
|
def make_friend(self):
|
||||||
|
"""Add a friend to the list."""
|
||||||
|
self.plain_friends.append("another-fd")
|
||||||
|
|
||||||
|
def change_first_friend(self):
|
||||||
|
"""Change the first friend in the list."""
|
||||||
|
self.plain_friends[0] = "Jenny"
|
||||||
|
|
||||||
|
def unfriend_all_friends(self):
|
||||||
|
"""Unfriend all friends in the list."""
|
||||||
|
self.plain_friends.clear()
|
||||||
|
|
||||||
|
def unfriend_first_friend(self):
|
||||||
|
"""Unfriend the first friend in the list."""
|
||||||
|
del self.plain_friends[0]
|
||||||
|
|
||||||
|
def remove_last_friend(self):
|
||||||
|
"""Remove the last friend in the list."""
|
||||||
|
self.plain_friends.pop()
|
||||||
|
|
||||||
|
def make_friends_with_colleagues(self):
|
||||||
|
"""Add list of friends to the list."""
|
||||||
|
colleagues = ["Peter", "Jimmy"]
|
||||||
|
self.plain_friends.extend(colleagues)
|
||||||
|
|
||||||
|
def remove_tommy(self):
|
||||||
|
"""Remove Tommy from the list."""
|
||||||
|
self.plain_friends.remove("Tommy")
|
||||||
|
|
||||||
|
# list in dict
|
||||||
|
friends_in_dict = {"Tommy": ["Jenny"]}
|
||||||
|
|
||||||
|
def remove_jenny_from_tommy(self):
|
||||||
|
"""Remove Jenny from Tommy's friends list."""
|
||||||
|
self.friends_in_dict["Tommy"].remove("Jenny")
|
||||||
|
|
||||||
|
def add_jimmy_to_tommy_friends(self):
|
||||||
|
"""Add Jimmy to Tommy's friends list."""
|
||||||
|
self.friends_in_dict["Tommy"].append("Jimmy")
|
||||||
|
|
||||||
|
def tommy_has_no_fds(self):
|
||||||
|
"""Clear Tommy's friends list."""
|
||||||
|
self.friends_in_dict["Tommy"].clear()
|
||||||
|
|
||||||
|
# nested list
|
||||||
|
friends_in_nested_list = [["Tommy"], ["Jenny"]]
|
||||||
|
|
||||||
|
def remove_first_group(self):
|
||||||
|
"""Remove the first group of friends from the nested list."""
|
||||||
|
self.friends_in_nested_list.pop(0)
|
||||||
|
|
||||||
|
def remove_first_person_from_first_group(self):
|
||||||
|
"""Remove the first person from the first group of friends in the nested list."""
|
||||||
|
self.friends_in_nested_list[0].pop(0)
|
||||||
|
|
||||||
|
def add_jimmy_to_second_group(self):
|
||||||
|
"""Add Jimmy to the second group of friends in the nested list."""
|
||||||
|
self.friends_in_nested_list[1].append("Jimmy")
|
||||||
|
|
||||||
|
|
||||||
|
class OtherBase(rx.Base):
|
||||||
|
"""A Base model with a str field."""
|
||||||
|
|
||||||
|
bar: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class CustomVar(rx.Base):
|
||||||
|
"""A Base model with multiple fields."""
|
||||||
|
|
||||||
|
foo: str = ""
|
||||||
|
array: List[str] = []
|
||||||
|
hashmap: Dict[str, str] = {}
|
||||||
|
test_set: Set[str] = set()
|
||||||
|
custom: OtherBase = OtherBase()
|
||||||
|
|
||||||
|
|
||||||
|
class MutableTestState(rx.State):
|
||||||
|
"""A test state."""
|
||||||
|
|
||||||
|
array: List[Union[str, List, Dict[str, str]]] = [
|
||||||
|
"value",
|
||||||
|
[1, 2, 3],
|
||||||
|
{"key": "value"},
|
||||||
|
]
|
||||||
|
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
|
||||||
|
"key": ["list", "of", "values"],
|
||||||
|
"another_key": "another_value",
|
||||||
|
"third_key": {"key": "value"},
|
||||||
|
}
|
||||||
|
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
|
||||||
|
custom: CustomVar = CustomVar()
|
||||||
|
_be_custom: CustomVar = CustomVar()
|
||||||
|
|
||||||
|
def reassign_mutables(self):
|
||||||
|
"""Assign mutable fields to different values."""
|
||||||
|
self.array = ["modified_value", [1, 2, 3], {"mod_key": "mod_value"}]
|
||||||
|
self.hashmap = {
|
||||||
|
"mod_key": ["list", "of", "values"],
|
||||||
|
"mod_another_key": "another_value",
|
||||||
|
"mod_third_key": {"key": "value"},
|
||||||
|
}
|
||||||
|
self.test_set = {1, 2, 3, 4, "five"}
|
175
tests/states/upload.py
Normal file
175
tests/states/upload.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
"""Test states for upload-related tests."""
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import ClassVar, List
|
||||||
|
|
||||||
|
import reflex as rx
|
||||||
|
|
||||||
|
|
||||||
|
class UploadState(rx.State):
|
||||||
|
"""The base state for uploading a file."""
|
||||||
|
|
||||||
|
async def handle_upload1(self, files: List[rx.UploadFile]):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseState(rx.State):
|
||||||
|
"""The test base state."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SubUploadState(BaseState):
|
||||||
|
"""The test substate."""
|
||||||
|
|
||||||
|
img: str
|
||||||
|
|
||||||
|
async def handle_upload(self, files: List[rx.UploadFile]):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FileUploadState(rx.State):
|
||||||
|
"""The base state for uploading a file."""
|
||||||
|
|
||||||
|
img_list: List[str]
|
||||||
|
_tmp_path: ClassVar[Path]
|
||||||
|
|
||||||
|
async def handle_upload2(self, files):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
upload_data = await file.read()
|
||||||
|
outfile = f"{self._tmp_path}/{file.filename}"
|
||||||
|
|
||||||
|
# Save the file.
|
||||||
|
with open(outfile, "wb") as file_object:
|
||||||
|
file_object.write(upload_data)
|
||||||
|
|
||||||
|
# Update the img var.
|
||||||
|
self.img_list.append(file.filename)
|
||||||
|
|
||||||
|
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
upload_data = await file.read()
|
||||||
|
outfile = f"{self._tmp_path}/{file.filename}"
|
||||||
|
|
||||||
|
# Save the file.
|
||||||
|
with open(outfile, "wb") as file_object:
|
||||||
|
file_object.write(upload_data)
|
||||||
|
|
||||||
|
# Update the img var.
|
||||||
|
assert file.filename is not None
|
||||||
|
self.img_list.append(file.filename)
|
||||||
|
|
||||||
|
|
||||||
|
class FileStateBase1(rx.State):
|
||||||
|
"""The base state for a child FileUploadState."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChildFileUploadState(FileStateBase1):
|
||||||
|
"""The child state for uploading a file."""
|
||||||
|
|
||||||
|
img_list: List[str]
|
||||||
|
_tmp_path: ClassVar[Path]
|
||||||
|
|
||||||
|
async def handle_upload2(self, files):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
upload_data = await file.read()
|
||||||
|
outfile = f"{self._tmp_path}/{file.filename}"
|
||||||
|
|
||||||
|
# Save the file.
|
||||||
|
with open(outfile, "wb") as file_object:
|
||||||
|
file_object.write(upload_data)
|
||||||
|
|
||||||
|
# Update the img var.
|
||||||
|
self.img_list.append(file.filename)
|
||||||
|
|
||||||
|
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
upload_data = await file.read()
|
||||||
|
outfile = f"{self._tmp_path}/{file.filename}"
|
||||||
|
|
||||||
|
# Save the file.
|
||||||
|
with open(outfile, "wb") as file_object:
|
||||||
|
file_object.write(upload_data)
|
||||||
|
|
||||||
|
# Update the img var.
|
||||||
|
assert file.filename is not None
|
||||||
|
self.img_list.append(file.filename)
|
||||||
|
|
||||||
|
|
||||||
|
class FileStateBase2(FileStateBase1):
|
||||||
|
"""The parent state for a grandchild FileUploadState."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GrandChildFileUploadState(FileStateBase2):
|
||||||
|
"""The child state for uploading a file."""
|
||||||
|
|
||||||
|
img_list: List[str]
|
||||||
|
_tmp_path: ClassVar[Path]
|
||||||
|
|
||||||
|
async def handle_upload2(self, files):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
upload_data = await file.read()
|
||||||
|
outfile = f"{self._tmp_path}/{file.filename}"
|
||||||
|
|
||||||
|
# Save the file.
|
||||||
|
with open(outfile, "wb") as file_object:
|
||||||
|
file_object.write(upload_data)
|
||||||
|
|
||||||
|
# Update the img var.
|
||||||
|
self.img_list.append(file.filename)
|
||||||
|
|
||||||
|
async def multi_handle_upload(self, files: List[rx.UploadFile]):
|
||||||
|
"""Handle the upload of a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: The uploaded files.
|
||||||
|
"""
|
||||||
|
for file in files:
|
||||||
|
upload_data = await file.read()
|
||||||
|
outfile = f"{self._tmp_path}/{file.filename}"
|
||||||
|
|
||||||
|
# Save the file.
|
||||||
|
with open(outfile, "wb") as file_object:
|
||||||
|
file_object.write(upload_data)
|
||||||
|
|
||||||
|
# Update the img var.
|
||||||
|
assert file.filename is not None
|
||||||
|
self.img_list.append(file.filename)
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import io
|
import io
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
|
import uuid
|
||||||
from typing import List, Tuple, Type
|
from typing import List, Tuple, Type
|
||||||
|
|
||||||
if sys.version_info.major >= 3 and sys.version_info.minor > 7:
|
if sys.version_info.major >= 3 and sys.version_info.minor > 7:
|
||||||
@ -30,11 +31,18 @@ from reflex.components import Box, Component, Cond, Fragment, Text
|
|||||||
from reflex.event import Event, get_hydrate_event
|
from reflex.event import Event, get_hydrate_event
|
||||||
from reflex.middleware import HydrateMiddleware
|
from reflex.middleware import HydrateMiddleware
|
||||||
from reflex.model import Model
|
from reflex.model import Model
|
||||||
from reflex.state import State, StateUpdate
|
from reflex.state import State, StateManagerRedis, StateUpdate
|
||||||
from reflex.style import Style
|
from reflex.style import Style
|
||||||
from reflex.utils import format
|
from reflex.utils import format
|
||||||
from reflex.vars import ComputedVar
|
from reflex.vars import ComputedVar
|
||||||
|
|
||||||
|
from .states import (
|
||||||
|
ChildFileUploadState,
|
||||||
|
FileUploadState,
|
||||||
|
GenState,
|
||||||
|
GrandChildFileUploadState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def index_page():
|
def index_page():
|
||||||
@ -64,6 +72,12 @@ def about_page():
|
|||||||
return about
|
return about
|
||||||
|
|
||||||
|
|
||||||
|
class ATestState(State):
|
||||||
|
"""A simple state for testing."""
|
||||||
|
|
||||||
|
var: int
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def test_state() -> Type[State]:
|
def test_state() -> Type[State]:
|
||||||
"""A default state.
|
"""A default state.
|
||||||
@ -71,11 +85,7 @@ def test_state() -> Type[State]:
|
|||||||
Returns:
|
Returns:
|
||||||
A default state.
|
A default state.
|
||||||
"""
|
"""
|
||||||
|
return ATestState
|
||||||
class TestState(State):
|
|
||||||
var: int
|
|
||||||
|
|
||||||
return TestState
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@ -313,23 +323,28 @@ def test_initialize_admin_dashboard_with_view_overrides(test_model):
|
|||||||
assert app.admin_dash.view_overrides[test_model] == TestModelView
|
assert app.admin_dash.view_overrides[test_model] == TestModelView
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_with_state(test_state):
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize_with_state(test_state: Type[ATestState], token: str):
|
||||||
"""Test setting the state of an app.
|
"""Test setting the state of an app.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
test_state: The default state.
|
test_state: The default state.
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
app = App(state=test_state)
|
app = App(state=test_state)
|
||||||
assert app.state == test_state
|
assert app.state == test_state
|
||||||
|
|
||||||
# Get a state for a given token.
|
# Get a state for a given token.
|
||||||
token = "token"
|
state = await app.state_manager.get_state(token)
|
||||||
state = app.state_manager.get_state(token)
|
|
||||||
assert isinstance(state, test_state)
|
assert isinstance(state, test_state)
|
||||||
assert state.var == 0 # type: ignore
|
assert state.var == 0 # type: ignore
|
||||||
|
|
||||||
|
if isinstance(app.state_manager, StateManagerRedis):
|
||||||
|
await app.state_manager.redis.close()
|
||||||
|
|
||||||
def test_set_and_get_state(test_state):
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_and_get_state(test_state):
|
||||||
"""Test setting and getting the state of an app with different tokens.
|
"""Test setting and getting the state of an app with different tokens.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -338,47 +353,51 @@ def test_set_and_get_state(test_state):
|
|||||||
app = App(state=test_state)
|
app = App(state=test_state)
|
||||||
|
|
||||||
# Create two tokens.
|
# Create two tokens.
|
||||||
token1 = "token1"
|
token1 = str(uuid.uuid4())
|
||||||
token2 = "token2"
|
token2 = str(uuid.uuid4())
|
||||||
|
|
||||||
# Get the default state for each token.
|
# Get the default state for each token.
|
||||||
state1 = app.state_manager.get_state(token1)
|
state1 = await app.state_manager.get_state(token1)
|
||||||
state2 = app.state_manager.get_state(token2)
|
state2 = await app.state_manager.get_state(token2)
|
||||||
assert state1.var == 0 # type: ignore
|
assert state1.var == 0 # type: ignore
|
||||||
assert state2.var == 0 # type: ignore
|
assert state2.var == 0 # type: ignore
|
||||||
|
|
||||||
# Set the vars to different values.
|
# Set the vars to different values.
|
||||||
state1.var = 1
|
state1.var = 1
|
||||||
state2.var = 2
|
state2.var = 2
|
||||||
app.state_manager.set_state(token1, state1)
|
await app.state_manager.set_state(token1, state1)
|
||||||
app.state_manager.set_state(token2, state2)
|
await app.state_manager.set_state(token2, state2)
|
||||||
|
|
||||||
# Get the states again and check the values.
|
# Get the states again and check the values.
|
||||||
state1 = app.state_manager.get_state(token1)
|
state1 = await app.state_manager.get_state(token1)
|
||||||
state2 = app.state_manager.get_state(token2)
|
state2 = await app.state_manager.get_state(token2)
|
||||||
assert state1.var == 1 # type: ignore
|
assert state1.var == 1 # type: ignore
|
||||||
assert state2.var == 2 # type: ignore
|
assert state2.var == 2 # type: ignore
|
||||||
|
|
||||||
|
if isinstance(app.state_manager, StateManagerRedis):
|
||||||
|
await app.state_manager.redis.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_dynamic_var_event(test_state):
|
async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
|
||||||
"""Test that the default handler of a dynamic generated var
|
"""Test that the default handler of a dynamic generated var
|
||||||
works as expected.
|
works as expected.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
test_state: State Fixture.
|
test_state: State Fixture.
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
test_state = test_state()
|
state = test_state() # type: ignore
|
||||||
test_state.add_var("int_val", int, 0)
|
state.add_var("int_val", int, 0)
|
||||||
result = await test_state._process(
|
result = await state._process(
|
||||||
Event(
|
Event(
|
||||||
token="fake-token",
|
token=token,
|
||||||
name="test_state.set_int_val",
|
name=f"{test_state.get_name()}.set_int_val",
|
||||||
router_data={"pathname": "/", "query": {}},
|
router_data={"pathname": "/", "query": {}},
|
||||||
payload={"value": 50},
|
payload={"value": 50},
|
||||||
)
|
)
|
||||||
).__anext__()
|
).__anext__()
|
||||||
assert result.delta == {"test_state": {"int_val": 50}}
|
assert result.delta == {test_state.get_name(): {"int_val": 50}}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -388,12 +407,20 @@ async def test_dynamic_var_event(test_state):
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.make_friend",
|
"list_mutation_test_state.make_friend",
|
||||||
{"test_state": {"plain_friends": ["Tommy", "another-fd"]}},
|
{
|
||||||
|
"list_mutation_test_state": {
|
||||||
|
"plain_friends": ["Tommy", "another-fd"]
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.change_first_friend",
|
"list_mutation_test_state.change_first_friend",
|
||||||
{"test_state": {"plain_friends": ["Jenny", "another-fd"]}},
|
{
|
||||||
|
"list_mutation_test_state": {
|
||||||
|
"plain_friends": ["Jenny", "another-fd"]
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="append then __setitem__",
|
id="append then __setitem__",
|
||||||
@ -401,12 +428,12 @@ async def test_dynamic_var_event(test_state):
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.unfriend_first_friend",
|
"list_mutation_test_state.unfriend_first_friend",
|
||||||
{"test_state": {"plain_friends": []}},
|
{"list_mutation_test_state": {"plain_friends": []}},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.make_friend",
|
"list_mutation_test_state.make_friend",
|
||||||
{"test_state": {"plain_friends": ["another-fd"]}},
|
{"list_mutation_test_state": {"plain_friends": ["another-fd"]}},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="delitem then append",
|
id="delitem then append",
|
||||||
@ -414,20 +441,24 @@ async def test_dynamic_var_event(test_state):
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.make_friends_with_colleagues",
|
"list_mutation_test_state.make_friends_with_colleagues",
|
||||||
{"test_state": {"plain_friends": ["Tommy", "Peter", "Jimmy"]}},
|
{
|
||||||
|
"list_mutation_test_state": {
|
||||||
|
"plain_friends": ["Tommy", "Peter", "Jimmy"]
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_tommy",
|
"list_mutation_test_state.remove_tommy",
|
||||||
{"test_state": {"plain_friends": ["Peter", "Jimmy"]}},
|
{"list_mutation_test_state": {"plain_friends": ["Peter", "Jimmy"]}},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_last_friend",
|
"list_mutation_test_state.remove_last_friend",
|
||||||
{"test_state": {"plain_friends": ["Peter"]}},
|
{"list_mutation_test_state": {"plain_friends": ["Peter"]}},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.unfriend_all_friends",
|
"list_mutation_test_state.unfriend_all_friends",
|
||||||
{"test_state": {"plain_friends": []}},
|
{"list_mutation_test_state": {"plain_friends": []}},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="extend, remove, pop, clear",
|
id="extend, remove, pop, clear",
|
||||||
@ -435,24 +466,28 @@ async def test_dynamic_var_event(test_state):
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.add_jimmy_to_second_group",
|
"list_mutation_test_state.add_jimmy_to_second_group",
|
||||||
{
|
{
|
||||||
"test_state": {
|
"list_mutation_test_state": {
|
||||||
"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
|
"friends_in_nested_list": [["Tommy"], ["Jenny", "Jimmy"]]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_first_person_from_first_group",
|
"list_mutation_test_state.remove_first_person_from_first_group",
|
||||||
{
|
{
|
||||||
"test_state": {
|
"list_mutation_test_state": {
|
||||||
"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
|
"friends_in_nested_list": [[], ["Jenny", "Jimmy"]]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_first_group",
|
"list_mutation_test_state.remove_first_group",
|
||||||
{"test_state": {"friends_in_nested_list": [["Jenny", "Jimmy"]]}},
|
{
|
||||||
|
"list_mutation_test_state": {
|
||||||
|
"friends_in_nested_list": [["Jenny", "Jimmy"]]
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="nested list",
|
id="nested list",
|
||||||
@ -460,16 +495,24 @@ async def test_dynamic_var_event(test_state):
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.add_jimmy_to_tommy_friends",
|
"list_mutation_test_state.add_jimmy_to_tommy_friends",
|
||||||
{"test_state": {"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}}},
|
{
|
||||||
|
"list_mutation_test_state": {
|
||||||
|
"friends_in_dict": {"Tommy": ["Jenny", "Jimmy"]}
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_jenny_from_tommy",
|
"list_mutation_test_state.remove_jenny_from_tommy",
|
||||||
{"test_state": {"friends_in_dict": {"Tommy": ["Jimmy"]}}},
|
{
|
||||||
|
"list_mutation_test_state": {
|
||||||
|
"friends_in_dict": {"Tommy": ["Jimmy"]}
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.tommy_has_no_fds",
|
"list_mutation_test_state.tommy_has_no_fds",
|
||||||
{"test_state": {"friends_in_dict": {"Tommy": []}}},
|
{"list_mutation_test_state": {"friends_in_dict": {"Tommy": []}}},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="list in dict",
|
id="list in dict",
|
||||||
@ -477,7 +520,9 @@ async def test_dynamic_var_event(test_state):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_list_mutation_detection__plain_list(
|
async def test_list_mutation_detection__plain_list(
|
||||||
event_tuples: List[Tuple[str, List[str]]], list_mutation_state: State
|
event_tuples: List[Tuple[str, List[str]]],
|
||||||
|
list_mutation_state: State,
|
||||||
|
token: str,
|
||||||
):
|
):
|
||||||
"""Test list mutation detection
|
"""Test list mutation detection
|
||||||
when reassignment is not explicitly included in the logic.
|
when reassignment is not explicitly included in the logic.
|
||||||
@ -485,11 +530,12 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
Args:
|
Args:
|
||||||
event_tuples: From parametrization.
|
event_tuples: From parametrization.
|
||||||
list_mutation_state: A state with list mutation features.
|
list_mutation_state: A state with list mutation features.
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
for event_name, expected_delta in event_tuples:
|
for event_name, expected_delta in event_tuples:
|
||||||
result = await list_mutation_state._process(
|
result = await list_mutation_state._process(
|
||||||
Event(
|
Event(
|
||||||
token="fake-token",
|
token=token,
|
||||||
name=event_name,
|
name=event_name,
|
||||||
router_data={"pathname": "/", "query": {}},
|
router_data={"pathname": "/", "query": {}},
|
||||||
payload={},
|
payload={},
|
||||||
@ -506,16 +552,24 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.add_age",
|
"dict_mutation_test_state.add_age",
|
||||||
{"test_state": {"details": {"name": "Tommy", "age": 20}}},
|
{
|
||||||
|
"dict_mutation_test_state": {
|
||||||
|
"details": {"name": "Tommy", "age": 20}
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.change_name",
|
"dict_mutation_test_state.change_name",
|
||||||
{"test_state": {"details": {"name": "Jenny", "age": 20}}},
|
{
|
||||||
|
"dict_mutation_test_state": {
|
||||||
|
"details": {"name": "Jenny", "age": 20}
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_last_detail",
|
"dict_mutation_test_state.remove_last_detail",
|
||||||
{"test_state": {"details": {"name": "Jenny"}}},
|
{"dict_mutation_test_state": {"details": {"name": "Jenny"}}},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="update then __setitem__",
|
id="update then __setitem__",
|
||||||
@ -523,12 +577,12 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.clear_details",
|
"dict_mutation_test_state.clear_details",
|
||||||
{"test_state": {"details": {}}},
|
{"dict_mutation_test_state": {"details": {}}},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.add_age",
|
"dict_mutation_test_state.add_age",
|
||||||
{"test_state": {"details": {"age": 20}}},
|
{"dict_mutation_test_state": {"details": {"age": 20}}},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="delitem then update",
|
id="delitem then update",
|
||||||
@ -536,16 +590,20 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.add_age",
|
"dict_mutation_test_state.add_age",
|
||||||
{"test_state": {"details": {"name": "Tommy", "age": 20}}},
|
{
|
||||||
|
"dict_mutation_test_state": {
|
||||||
|
"details": {"name": "Tommy", "age": 20}
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_name",
|
"dict_mutation_test_state.remove_name",
|
||||||
{"test_state": {"details": {"age": 20}}},
|
{"dict_mutation_test_state": {"details": {"age": 20}}},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.pop_out_age",
|
"dict_mutation_test_state.pop_out_age",
|
||||||
{"test_state": {"details": {}}},
|
{"dict_mutation_test_state": {"details": {}}},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="add, remove, pop",
|
id="add, remove, pop",
|
||||||
@ -553,13 +611,17 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.remove_home_address",
|
"dict_mutation_test_state.remove_home_address",
|
||||||
{"test_state": {"address": [{}, {"work": "work address"}]}},
|
{
|
||||||
|
"dict_mutation_test_state": {
|
||||||
|
"address": [{}, {"work": "work address"}]
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.add_street_to_home_address",
|
"dict_mutation_test_state.add_street_to_home_address",
|
||||||
{
|
{
|
||||||
"test_state": {
|
"dict_mutation_test_state": {
|
||||||
"address": [
|
"address": [
|
||||||
{"street": "street address"},
|
{"street": "street address"},
|
||||||
{"work": "work address"},
|
{"work": "work address"},
|
||||||
@ -573,9 +635,9 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
pytest.param(
|
pytest.param(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"test_state.change_friend_name",
|
"dict_mutation_test_state.change_friend_name",
|
||||||
{
|
{
|
||||||
"test_state": {
|
"dict_mutation_test_state": {
|
||||||
"friend_in_nested_dict": {
|
"friend_in_nested_dict": {
|
||||||
"name": "Nikhil",
|
"name": "Nikhil",
|
||||||
"friend": {"name": "Tommy"},
|
"friend": {"name": "Tommy"},
|
||||||
@ -584,9 +646,9 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.add_friend_age",
|
"dict_mutation_test_state.add_friend_age",
|
||||||
{
|
{
|
||||||
"test_state": {
|
"dict_mutation_test_state": {
|
||||||
"friend_in_nested_dict": {
|
"friend_in_nested_dict": {
|
||||||
"name": "Nikhil",
|
"name": "Nikhil",
|
||||||
"friend": {"name": "Tommy", "age": 30},
|
"friend": {"name": "Tommy", "age": 30},
|
||||||
@ -595,8 +657,12 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"test_state.remove_friend",
|
"dict_mutation_test_state.remove_friend",
|
||||||
{"test_state": {"friend_in_nested_dict": {"name": "Nikhil"}}},
|
{
|
||||||
|
"dict_mutation_test_state": {
|
||||||
|
"friend_in_nested_dict": {"name": "Nikhil"}
|
||||||
|
}
|
||||||
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
id="nested dict",
|
id="nested dict",
|
||||||
@ -604,7 +670,9 @@ async def test_list_mutation_detection__plain_list(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_dict_mutation_detection__plain_list(
|
async def test_dict_mutation_detection__plain_list(
|
||||||
event_tuples: List[Tuple[str, List[str]]], dict_mutation_state: State
|
event_tuples: List[Tuple[str, List[str]]],
|
||||||
|
dict_mutation_state: State,
|
||||||
|
token: str,
|
||||||
):
|
):
|
||||||
"""Test dict mutation detection
|
"""Test dict mutation detection
|
||||||
when reassignment is not explicitly included in the logic.
|
when reassignment is not explicitly included in the logic.
|
||||||
@ -612,11 +680,12 @@ async def test_dict_mutation_detection__plain_list(
|
|||||||
Args:
|
Args:
|
||||||
event_tuples: From parametrization.
|
event_tuples: From parametrization.
|
||||||
dict_mutation_state: A state with dict mutation features.
|
dict_mutation_state: A state with dict mutation features.
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
for event_name, expected_delta in event_tuples:
|
for event_name, expected_delta in event_tuples:
|
||||||
result = await dict_mutation_state._process(
|
result = await dict_mutation_state._process(
|
||||||
Event(
|
Event(
|
||||||
token="fake-token",
|
token=token,
|
||||||
name=event_name,
|
name=event_name,
|
||||||
router_data={"pathname": "/", "query": {}},
|
router_data={"pathname": "/", "query": {}},
|
||||||
payload={},
|
payload={},
|
||||||
@ -628,41 +697,43 @@ async def test_dict_mutation_detection__plain_list(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"fixture, delta",
|
("state", "delta"),
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"upload_state",
|
FileUploadState,
|
||||||
{"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
|
{"file_upload_state": {"img_list": ["image1.jpg", "image2.jpg"]}},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"upload_sub_state",
|
ChildFileUploadState,
|
||||||
{
|
{
|
||||||
"file_state.file_upload_state": {
|
"file_state_base1.child_file_upload_state": {
|
||||||
"img_list": ["image1.jpg", "image2.jpg"]
|
"img_list": ["image1.jpg", "image2.jpg"]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"upload_grand_sub_state",
|
GrandChildFileUploadState,
|
||||||
{
|
{
|
||||||
"base_file_state.file_sub_state.file_upload_state": {
|
"file_state_base1.file_state_base2.grand_child_file_upload_state": {
|
||||||
"img_list": ["image1.jpg", "image2.jpg"]
|
"img_list": ["image1.jpg", "image2.jpg"]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_upload_file(fixture, request, delta):
|
async def test_upload_file(tmp_path, state, delta, token: str):
|
||||||
"""Test that file upload works correctly.
|
"""Test that file upload works correctly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fixture: The state.
|
tmp_path: Temporary path.
|
||||||
request: Fixture request.
|
state: The state class.
|
||||||
delta: Expected delta
|
delta: Expected delta
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
app = App(state=request.getfixturevalue(fixture))
|
state._tmp_path = tmp_path
|
||||||
|
app = App(state=state)
|
||||||
app.event_namespace.emit = AsyncMock() # type: ignore
|
app.event_namespace.emit = AsyncMock() # type: ignore
|
||||||
current_state = app.state_manager.get_state("token")
|
current_state = await app.state_manager.get_state(token)
|
||||||
data = b"This is binary data"
|
data = b"This is binary data"
|
||||||
|
|
||||||
# Create a binary IO object and write data to it
|
# Create a binary IO object and write data to it
|
||||||
@ -670,11 +741,11 @@ async def test_upload_file(fixture, request, delta):
|
|||||||
bio.write(data)
|
bio.write(data)
|
||||||
|
|
||||||
file1 = UploadFile(
|
file1 = UploadFile(
|
||||||
filename="token:file_upload_state.multi_handle_upload:True:image1.jpg",
|
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image1.jpg",
|
||||||
file=bio,
|
file=bio,
|
||||||
)
|
)
|
||||||
file2 = UploadFile(
|
file2 = UploadFile(
|
||||||
filename="token:file_upload_state.multi_handle_upload:True:image2.jpg",
|
filename=f"{token}:{state.get_name()}.multi_handle_upload:True:image2.jpg",
|
||||||
file=bio,
|
file=bio,
|
||||||
)
|
)
|
||||||
upload_fn = upload(app)
|
upload_fn = upload(app)
|
||||||
@ -684,22 +755,27 @@ async def test_upload_file(fixture, request, delta):
|
|||||||
app.event_namespace.emit.assert_called_with( # type: ignore
|
app.event_namespace.emit.assert_called_with( # type: ignore
|
||||||
"event", state_update.json(), to=current_state.get_sid()
|
"event", state_update.json(), to=current_state.get_sid()
|
||||||
)
|
)
|
||||||
assert app.state_manager.get_state("token").dict()["img_list"] == [
|
assert (await app.state_manager.get_state(token)).dict()["img_list"] == [
|
||||||
"image1.jpg",
|
"image1.jpg",
|
||||||
"image2.jpg",
|
"image2.jpg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if isinstance(app.state_manager, StateManagerRedis):
|
||||||
|
await app.state_manager.redis.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"fixture", ["upload_state", "upload_sub_state", "upload_grand_sub_state"]
|
"state",
|
||||||
|
[FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
|
||||||
)
|
)
|
||||||
async def test_upload_file_without_annotation(fixture, request):
|
async def test_upload_file_without_annotation(state, tmp_path, token):
|
||||||
"""Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
|
"""Test that an error is thrown when there's no param annotated with rx.UploadFile or List[UploadFile].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fixture: The state.
|
state: The state class.
|
||||||
request: Fixture request.
|
tmp_path: Temporary path.
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
data = b"This is binary data"
|
data = b"This is binary data"
|
||||||
|
|
||||||
@ -707,14 +783,15 @@ async def test_upload_file_without_annotation(fixture, request):
|
|||||||
bio = io.BytesIO()
|
bio = io.BytesIO()
|
||||||
bio.write(data)
|
bio.write(data)
|
||||||
|
|
||||||
app = App(state=request.getfixturevalue(fixture))
|
state._tmp_path = tmp_path
|
||||||
|
app = App(state=state)
|
||||||
|
|
||||||
file1 = UploadFile(
|
file1 = UploadFile(
|
||||||
filename="token:file_upload_state.handle_upload2:True:image1.jpg",
|
filename=f"{token}:{state.get_name()}.handle_upload2:True:image1.jpg",
|
||||||
file=bio,
|
file=bio,
|
||||||
)
|
)
|
||||||
file2 = UploadFile(
|
file2 = UploadFile(
|
||||||
filename="token:file_upload_state.handle_upload2:True:image2.jpg",
|
filename=f"{token}:{state.get_name()}.handle_upload2:True:image2.jpg",
|
||||||
file=bio,
|
file=bio,
|
||||||
)
|
)
|
||||||
fn = upload(app)
|
fn = upload(app)
|
||||||
@ -722,9 +799,12 @@ async def test_upload_file_without_annotation(fixture, request):
|
|||||||
await fn([file1, file2])
|
await fn([file1, file2])
|
||||||
assert (
|
assert (
|
||||||
err.value.args[0]
|
err.value.args[0]
|
||||||
== "`file_upload_state.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
|
== f"`{state.get_name()}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(app.state_manager, StateManagerRedis):
|
||||||
|
await app.state_manager.redis.close()
|
||||||
|
|
||||||
|
|
||||||
class DynamicState(State):
|
class DynamicState(State):
|
||||||
"""State class for testing dynamic route var.
|
"""State class for testing dynamic route var.
|
||||||
@ -768,6 +848,7 @@ class DynamicState(State):
|
|||||||
async def test_dynamic_route_var_route_change_completed_on_load(
|
async def test_dynamic_route_var_route_change_completed_on_load(
|
||||||
index_page,
|
index_page,
|
||||||
windows_platform: bool,
|
windows_platform: bool,
|
||||||
|
token: str,
|
||||||
):
|
):
|
||||||
"""Create app with dynamic route var, and simulate navigation.
|
"""Create app with dynamic route var, and simulate navigation.
|
||||||
|
|
||||||
@ -777,6 +858,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
Args:
|
Args:
|
||||||
index_page: The index page.
|
index_page: The index page.
|
||||||
windows_platform: Whether the system is windows.
|
windows_platform: Whether the system is windows.
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
arg_name = "dynamic"
|
arg_name = "dynamic"
|
||||||
route = f"/test/[{arg_name}]"
|
route = f"/test/[{arg_name}]"
|
||||||
@ -792,10 +874,9 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
}
|
}
|
||||||
assert constants.ROUTER_DATA in app.state().computed_var_dependencies
|
assert constants.ROUTER_DATA in app.state().computed_var_dependencies
|
||||||
|
|
||||||
token = "mock_token"
|
|
||||||
sid = "mock_sid"
|
sid = "mock_sid"
|
||||||
client_ip = "127.0.0.1"
|
client_ip = "127.0.0.1"
|
||||||
state = app.state_manager.get_state(token)
|
state = await app.state_manager.get_state(token)
|
||||||
assert state.dynamic == ""
|
assert state.dynamic == ""
|
||||||
exp_vals = ["foo", "foobar", "baz"]
|
exp_vals = ["foo", "foobar", "baz"]
|
||||||
|
|
||||||
@ -817,6 +898,7 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prev_exp_val = ""
|
||||||
for exp_index, exp_val in enumerate(exp_vals):
|
for exp_index, exp_val in enumerate(exp_vals):
|
||||||
hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
|
hydrate_event = _event(name=get_hydrate_event(state), val=exp_val)
|
||||||
exp_router_data = {
|
exp_router_data = {
|
||||||
@ -826,13 +908,14 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
"token": token,
|
"token": token,
|
||||||
**hydrate_event.router_data,
|
**hydrate_event.router_data,
|
||||||
}
|
}
|
||||||
update = await process(
|
process_coro = process(
|
||||||
app,
|
app,
|
||||||
event=hydrate_event,
|
event=hydrate_event,
|
||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
).__anext__() # type: ignore
|
)
|
||||||
|
update = await process_coro.__anext__() # type: ignore
|
||||||
|
|
||||||
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
# route change triggers: [full state dict, call on_load events, call set_is_hydrated(True)]
|
||||||
assert update == StateUpdate(
|
assert update == StateUpdate(
|
||||||
@ -860,14 +943,27 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
if isinstance(app.state_manager, StateManagerRedis):
|
||||||
|
# When redis is used, the state is not updated until the processing is complete
|
||||||
|
state = await app.state_manager.get_state(token)
|
||||||
|
assert state.dynamic == prev_exp_val
|
||||||
|
|
||||||
|
# complete the processing
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await process_coro.__anext__() # type: ignore
|
||||||
|
|
||||||
|
# check that router data was written to the state_manager store
|
||||||
|
state = await app.state_manager.get_state(token)
|
||||||
assert state.dynamic == exp_val
|
assert state.dynamic == exp_val
|
||||||
on_load_update = await process(
|
|
||||||
|
process_coro = process(
|
||||||
app,
|
app,
|
||||||
event=_dynamic_state_event(name="on_load", val=exp_val),
|
event=_dynamic_state_event(name="on_load", val=exp_val),
|
||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
).__anext__() # type: ignore
|
)
|
||||||
|
on_load_update = await process_coro.__anext__() # type: ignore
|
||||||
assert on_load_update == StateUpdate(
|
assert on_load_update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -879,7 +975,10 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
},
|
},
|
||||||
events=[],
|
events=[],
|
||||||
)
|
)
|
||||||
on_set_is_hydrated_update = await process(
|
# complete the processing
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await process_coro.__anext__() # type: ignore
|
||||||
|
process_coro = process(
|
||||||
app,
|
app,
|
||||||
event=_dynamic_state_event(
|
event=_dynamic_state_event(
|
||||||
name="set_is_hydrated", payload={"value": True}, val=exp_val
|
name="set_is_hydrated", payload={"value": True}, val=exp_val
|
||||||
@ -887,7 +986,8 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
).__anext__() # type: ignore
|
)
|
||||||
|
on_set_is_hydrated_update = await process_coro.__anext__() # type: ignore
|
||||||
assert on_set_is_hydrated_update == StateUpdate(
|
assert on_set_is_hydrated_update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -899,15 +999,19 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
},
|
},
|
||||||
events=[],
|
events=[],
|
||||||
)
|
)
|
||||||
|
# complete the processing
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await process_coro.__anext__() # type: ignore
|
||||||
|
|
||||||
# a simple state update event should NOT trigger on_load or route var side effects
|
# a simple state update event should NOT trigger on_load or route var side effects
|
||||||
update = await process(
|
process_coro = process(
|
||||||
app,
|
app,
|
||||||
event=_dynamic_state_event(name="on_counter", val=exp_val),
|
event=_dynamic_state_event(name="on_counter", val=exp_val),
|
||||||
sid=sid,
|
sid=sid,
|
||||||
headers={},
|
headers={},
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
).__anext__() # type: ignore
|
)
|
||||||
|
update = await process_coro.__anext__() # type: ignore
|
||||||
assert update == StateUpdate(
|
assert update == StateUpdate(
|
||||||
delta={
|
delta={
|
||||||
state.get_name(): {
|
state.get_name(): {
|
||||||
@ -919,42 +1023,54 @@ async def test_dynamic_route_var_route_change_completed_on_load(
|
|||||||
},
|
},
|
||||||
events=[],
|
events=[],
|
||||||
)
|
)
|
||||||
|
# complete the processing
|
||||||
|
with pytest.raises(StopAsyncIteration):
|
||||||
|
await process_coro.__anext__() # type: ignore
|
||||||
|
|
||||||
|
prev_exp_val = exp_val
|
||||||
|
state = await app.state_manager.get_state(token)
|
||||||
assert state.loaded == len(exp_vals)
|
assert state.loaded == len(exp_vals)
|
||||||
assert state.counter == len(exp_vals)
|
assert state.counter == len(exp_vals)
|
||||||
# print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
|
# print(f"Expected {exp_vals} rendering side effects, got {state.side_effect_counter}")
|
||||||
# assert state.side_effect_counter == len(exp_vals)
|
# assert state.side_effect_counter == len(exp_vals)
|
||||||
|
|
||||||
|
if isinstance(app.state_manager, StateManagerRedis):
|
||||||
|
await app.state_manager.redis.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_events(gen_state, mocker):
|
async def test_process_events(mocker, token: str):
|
||||||
"""Test that an event is processed properly and that it is postprocessed
|
"""Test that an event is processed properly and that it is postprocessed
|
||||||
n+1 times. Also check that the processing flag of the last stateupdate is set to
|
n+1 times. Also check that the processing flag of the last stateupdate is set to
|
||||||
False.
|
False.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gen_state: The state.
|
|
||||||
mocker: mocker object.
|
mocker: mocker object.
|
||||||
|
token: a Token.
|
||||||
"""
|
"""
|
||||||
router_data = {
|
router_data = {
|
||||||
"pathname": "/",
|
"pathname": "/",
|
||||||
"query": {},
|
"query": {},
|
||||||
"token": "mock_token",
|
"token": token,
|
||||||
"sid": "mock_sid",
|
"sid": "mock_sid",
|
||||||
"headers": {},
|
"headers": {},
|
||||||
"ip": "127.0.0.1",
|
"ip": "127.0.0.1",
|
||||||
}
|
}
|
||||||
app = App(state=gen_state)
|
app = App(state=GenState)
|
||||||
mocker.patch.object(app, "postprocess", AsyncMock())
|
mocker.patch.object(app, "postprocess", AsyncMock())
|
||||||
event = Event(
|
event = Event(
|
||||||
token="token", name="gen_state.go", payload={"c": 5}, router_data=router_data
|
token=token, name="gen_state.go", payload={"c": 5}, router_data=router_data
|
||||||
)
|
)
|
||||||
|
|
||||||
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore
|
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): # type: ignore
|
||||||
pass
|
pass
|
||||||
|
|
||||||
assert app.state_manager.get_state("token").value == 5
|
assert (await app.state_manager.get_state(token)).value == 5
|
||||||
assert app.postprocess.call_count == 6
|
assert app.postprocess.call_count == 6
|
||||||
|
|
||||||
|
if isinstance(app.state_manager, StateManagerRedis):
|
||||||
|
await app.state_manager.redis.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("state", "overlay_component", "exp_page_child"),
|
("state", "overlay_component", "exp_page_child"),
|
||||||
|
@ -1,22 +1,42 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, List
|
from typing import Dict, Generator, List
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from plotly.graph_objects import Figure
|
from plotly.graph_objects import Figure
|
||||||
|
|
||||||
import reflex as rx
|
import reflex as rx
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.constants import IS_HYDRATED, RouteVar
|
from reflex.constants import APP_VAR, IS_HYDRATED, RouteVar, SocketEvent
|
||||||
from reflex.event import Event, EventHandler
|
from reflex.event import Event, EventHandler
|
||||||
from reflex.state import MutableProxy, State
|
from reflex.state import (
|
||||||
from reflex.utils import format
|
ImmutableStateError,
|
||||||
|
LockExpiredError,
|
||||||
|
MutableProxy,
|
||||||
|
State,
|
||||||
|
StateManager,
|
||||||
|
StateManagerMemory,
|
||||||
|
StateManagerRedis,
|
||||||
|
StateProxy,
|
||||||
|
StateUpdate,
|
||||||
|
)
|
||||||
|
from reflex.utils import format, prerequisites
|
||||||
from reflex.vars import BaseVar, ComputedVar
|
from reflex.vars import BaseVar, ComputedVar
|
||||||
|
|
||||||
|
from .states import GenState
|
||||||
|
|
||||||
|
CI = bool(os.environ.get("CI", False))
|
||||||
|
LOCK_EXPIRATION = 2000 if CI else 100
|
||||||
|
LOCK_EXPIRE_SLEEP = 2.5 if CI else 0.2
|
||||||
|
|
||||||
|
|
||||||
class Object(Base):
|
class Object(Base):
|
||||||
"""A test object fixture."""
|
"""A test object fixture."""
|
||||||
@ -704,13 +724,9 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_event_generator(gen_state):
|
async def test_process_event_generator():
|
||||||
"""Test event handlers that generate multiple updates.
|
"""Test event handlers that generate multiple updates."""
|
||||||
|
gen_state = GenState() # type: ignore
|
||||||
Args:
|
|
||||||
gen_state: A state.
|
|
||||||
"""
|
|
||||||
gen_state = gen_state()
|
|
||||||
event = Event(
|
event = Event(
|
||||||
token="t",
|
token="t",
|
||||||
name="go",
|
name="go",
|
||||||
@ -1402,6 +1418,396 @@ def test_state_with_invalid_yield():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", params=["in_process", "redis"])
|
||||||
|
def state_manager(request) -> Generator[StateManager, None, None]:
|
||||||
|
"""Instance of state manager parametrized for redis and in-process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: pytest request object.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
A state manager instance
|
||||||
|
"""
|
||||||
|
state_manager = StateManager.create(state=TestState)
|
||||||
|
if request.param == "redis":
|
||||||
|
if not isinstance(state_manager, StateManagerRedis):
|
||||||
|
pytest.skip("Test requires redis")
|
||||||
|
else:
|
||||||
|
# explicitly NOT using redis
|
||||||
|
state_manager = StateManagerMemory(state=TestState)
|
||||||
|
assert not state_manager._states_locks
|
||||||
|
|
||||||
|
yield state_manager
|
||||||
|
|
||||||
|
if isinstance(state_manager, StateManagerRedis):
|
||||||
|
asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_manager_modify_state(state_manager: StateManager, token: str):
|
||||||
|
"""Test that the state manager can modify a state exclusively.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager: A state manager instance.
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
async with state_manager.modify_state(token):
|
||||||
|
if isinstance(state_manager, StateManagerRedis):
|
||||||
|
assert await state_manager.redis.get(f"{token}_lock")
|
||||||
|
elif isinstance(state_manager, StateManagerMemory):
|
||||||
|
assert token in state_manager._states_locks
|
||||||
|
assert state_manager._states_locks[token].locked()
|
||||||
|
# lock should be dropped after exiting the context
|
||||||
|
if isinstance(state_manager, StateManagerRedis):
|
||||||
|
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||||
|
elif isinstance(state_manager, StateManagerMemory):
|
||||||
|
assert not state_manager._states_locks[token].locked()
|
||||||
|
|
||||||
|
# separate instances should NOT share locks
|
||||||
|
sm2 = StateManagerMemory(state=TestState)
|
||||||
|
assert sm2._state_manager_lock is state_manager._state_manager_lock
|
||||||
|
assert not sm2._states_locks
|
||||||
|
if state_manager._states_locks:
|
||||||
|
assert sm2._states_locks != state_manager._states_locks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_manager_contend(state_manager: StateManager, token: str):
|
||||||
|
"""Multiple coroutines attempting to access the same state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager: A state manager instance.
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
n_coroutines = 10
|
||||||
|
exp_num1 = 10
|
||||||
|
|
||||||
|
async with state_manager.modify_state(token) as state:
|
||||||
|
state.num1 = 0
|
||||||
|
|
||||||
|
async def _coro():
|
||||||
|
async with state_manager.modify_state(token) as state:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
state.num1 += 1
|
||||||
|
|
||||||
|
tasks = [asyncio.create_task(_coro()) for _ in range(n_coroutines)]
|
||||||
|
|
||||||
|
for f in asyncio.as_completed(tasks):
|
||||||
|
await f
|
||||||
|
|
||||||
|
assert (await state_manager.get_state(token)).num1 == exp_num1
|
||||||
|
|
||||||
|
if isinstance(state_manager, StateManagerRedis):
|
||||||
|
assert (await state_manager.redis.get(f"{token}_lock")) is None
|
||||||
|
elif isinstance(state_manager, StateManagerMemory):
|
||||||
|
assert token in state_manager._states_locks
|
||||||
|
assert not state_manager._states_locks[token].locked()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def state_manager_redis() -> Generator[StateManager, None, None]:
|
||||||
|
"""Instance of state manager for redis only.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
A state manager instance
|
||||||
|
"""
|
||||||
|
state_manager = StateManager.create(TestState)
|
||||||
|
|
||||||
|
if not isinstance(state_manager, StateManagerRedis):
|
||||||
|
pytest.skip("Test requires redis")
|
||||||
|
|
||||||
|
yield state_manager
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(state_manager.redis.close())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_manager_lock_expire(state_manager_redis: StateManager, token: str):
|
||||||
|
"""Test that the state manager lock expires and raises exception exiting context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager_redis: A state manager instance.
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||||
|
|
||||||
|
async with state_manager_redis.modify_state(token):
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
with pytest.raises(LockExpiredError):
|
||||||
|
async with state_manager_redis.modify_state(token):
|
||||||
|
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_manager_lock_expire_contend(
|
||||||
|
state_manager_redis: StateManager, token: str
|
||||||
|
):
|
||||||
|
"""Test that the state manager lock expires and queued waiters proceed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_manager_redis: A state manager instance.
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
exp_num1 = 4252
|
||||||
|
unexp_num1 = 666
|
||||||
|
|
||||||
|
state_manager_redis.lock_expiration = LOCK_EXPIRATION
|
||||||
|
|
||||||
|
order = []
|
||||||
|
|
||||||
|
async def _coro_blocker():
|
||||||
|
async with state_manager_redis.modify_state(token) as state:
|
||||||
|
order.append("blocker")
|
||||||
|
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
|
||||||
|
state.num1 = unexp_num1
|
||||||
|
|
||||||
|
async def _coro_waiter():
|
||||||
|
while "blocker" not in order:
|
||||||
|
await asyncio.sleep(0.005)
|
||||||
|
async with state_manager_redis.modify_state(token) as state:
|
||||||
|
order.append("waiter")
|
||||||
|
assert state.num1 != unexp_num1
|
||||||
|
state.num1 = exp_num1
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(_coro_blocker()),
|
||||||
|
asyncio.create_task(_coro_waiter()),
|
||||||
|
]
|
||||||
|
with pytest.raises(LockExpiredError):
|
||||||
|
await tasks[0]
|
||||||
|
await tasks[1]
|
||||||
|
|
||||||
|
assert order == ["blocker", "waiter"]
|
||||||
|
assert (await state_manager_redis.get_state(token)).num1 == exp_num1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def mock_app(monkeypatch, app: rx.App, state_manager: StateManager) -> rx.App:
|
||||||
|
"""Mock app fixture.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
monkeypatch: Pytest monkeypatch object.
|
||||||
|
app: An app.
|
||||||
|
state_manager: A state manager.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The app, after mocking out prerequisites.get_app()
|
||||||
|
"""
|
||||||
|
app_module = Mock()
|
||||||
|
setattr(app_module, APP_VAR, app)
|
||||||
|
app.state = TestState
|
||||||
|
app.state_manager = state_manager
|
||||||
|
assert app.event_namespace is not None
|
||||||
|
app.event_namespace.emit = AsyncMock()
|
||||||
|
monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
||||||
|
"""Test that the state proxy works.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grandchild_state: A grandchild state.
|
||||||
|
mock_app: An app that will be returned by `get_app()`
|
||||||
|
"""
|
||||||
|
child_state = grandchild_state.parent_state
|
||||||
|
assert child_state is not None
|
||||||
|
parent_state = child_state.parent_state
|
||||||
|
assert parent_state is not None
|
||||||
|
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||||
|
mock_app.state_manager.states[parent_state.get_token()] = parent_state
|
||||||
|
|
||||||
|
sp = StateProxy(grandchild_state)
|
||||||
|
assert sp.__wrapped__ == grandchild_state
|
||||||
|
assert sp._self_substate_path == grandchild_state.get_full_name().split(".")
|
||||||
|
assert sp._self_app is mock_app
|
||||||
|
assert not sp._self_mutable
|
||||||
|
assert sp._self_actx is None
|
||||||
|
|
||||||
|
# cannot use normal contextmanager protocol
|
||||||
|
with pytest.raises(TypeError), sp:
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
# cannot directly modify state proxy outside of async context
|
||||||
|
sp.value2 = 16
|
||||||
|
|
||||||
|
async with sp:
|
||||||
|
assert sp._self_actx is not None
|
||||||
|
assert sp._self_mutable # proxy is mutable inside context
|
||||||
|
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||||
|
# For in-process store, only one instance of the state exists
|
||||||
|
assert sp.__wrapped__ is grandchild_state
|
||||||
|
else:
|
||||||
|
# When redis is used, a new+updated instance is assigned to the proxy
|
||||||
|
assert sp.__wrapped__ is not grandchild_state
|
||||||
|
sp.value2 = 42
|
||||||
|
assert not sp._self_mutable # proxy is not mutable after exiting context
|
||||||
|
assert sp._self_actx is None
|
||||||
|
assert sp.value2 == 42
|
||||||
|
|
||||||
|
# Get the state from the state manager directly and check that the value is updated
|
||||||
|
gotten_state = await mock_app.state_manager.get_state(grandchild_state.get_token())
|
||||||
|
if isinstance(mock_app.state_manager, StateManagerMemory):
|
||||||
|
# For in-process store, only one instance of the state exists
|
||||||
|
assert gotten_state is parent_state
|
||||||
|
else:
|
||||||
|
assert gotten_state is not parent_state
|
||||||
|
gotten_grandchild_state = gotten_state.get_substate(sp._self_substate_path)
|
||||||
|
assert gotten_grandchild_state is not None
|
||||||
|
assert gotten_grandchild_state.value2 == 42
|
||||||
|
|
||||||
|
# ensure state update was emitted
|
||||||
|
assert mock_app.event_namespace is not None
|
||||||
|
mock_app.event_namespace.emit.assert_called_once()
|
||||||
|
mcall = mock_app.event_namespace.emit.mock_calls[0]
|
||||||
|
assert mcall.args[0] == str(SocketEvent.EVENT)
|
||||||
|
assert json.loads(mcall.args[1]) == StateUpdate(
|
||||||
|
delta={
|
||||||
|
parent_state.get_full_name(): {
|
||||||
|
"upper": "",
|
||||||
|
"sum": 3.14,
|
||||||
|
},
|
||||||
|
grandchild_state.get_full_name(): {
|
||||||
|
"value2": 42,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert mcall.kwargs["to"] == grandchild_state.get_sid()
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundTaskState(State):
|
||||||
|
"""A state with a background task."""
|
||||||
|
|
||||||
|
order: List[str] = []
|
||||||
|
dict_list: Dict[str, List[int]] = {"foo": []}
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def background_task(self):
|
||||||
|
"""A background task that updates the state."""
|
||||||
|
async with self:
|
||||||
|
assert not self.order
|
||||||
|
self.order.append("background_task:start")
|
||||||
|
|
||||||
|
assert isinstance(self, StateProxy)
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
self.order.append("bad idea")
|
||||||
|
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
# Even nested access to mutables raises an exception.
|
||||||
|
self.dict_list["foo"].append(42)
|
||||||
|
|
||||||
|
# wait for some other event to happen
|
||||||
|
while len(self.order) == 1:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
async with self:
|
||||||
|
pass # update proxy instance
|
||||||
|
|
||||||
|
async with self:
|
||||||
|
self.order.append("background_task:stop")
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def background_task_generator(self):
|
||||||
|
"""A background task generator that does nothing.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
yield
|
||||||
|
|
||||||
|
def other(self):
|
||||||
|
"""Some other event that updates the state."""
|
||||||
|
self.order.append("other")
|
||||||
|
|
||||||
|
async def bad_chain1(self):
|
||||||
|
"""Test that a background task cannot be chained."""
|
||||||
|
await self.background_task()
|
||||||
|
|
||||||
|
async def bad_chain2(self):
|
||||||
|
"""Test that a background task generator cannot be chained."""
|
||||||
|
async for _foo in self.background_task_generator():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_background_task_no_block(mock_app: rx.App, token: str):
|
||||||
|
"""Test that a background task does not block other events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mock_app: An app that will be returned by `get_app()`
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
router_data = {"query": {}}
|
||||||
|
mock_app.state_manager.state = mock_app.state = BackgroundTaskState
|
||||||
|
async for update in rx.app.process( # type: ignore
|
||||||
|
mock_app,
|
||||||
|
Event(
|
||||||
|
token=token,
|
||||||
|
name=f"{BackgroundTaskState.get_name()}.background_task",
|
||||||
|
router_data=router_data,
|
||||||
|
payload={},
|
||||||
|
),
|
||||||
|
sid="",
|
||||||
|
headers={},
|
||||||
|
client_ip="",
|
||||||
|
):
|
||||||
|
# background task returns empty update immediately
|
||||||
|
assert update == StateUpdate()
|
||||||
|
assert len(mock_app.background_tasks) == 1
|
||||||
|
|
||||||
|
# wait for the coroutine to start
|
||||||
|
await asyncio.sleep(0.5 if CI else 0.1)
|
||||||
|
assert len(mock_app.background_tasks) == 1
|
||||||
|
|
||||||
|
# Process another normal event
|
||||||
|
async for update in rx.app.process( # type: ignore
|
||||||
|
mock_app,
|
||||||
|
Event(
|
||||||
|
token=token,
|
||||||
|
name=f"{BackgroundTaskState.get_name()}.other",
|
||||||
|
router_data=router_data,
|
||||||
|
payload={},
|
||||||
|
),
|
||||||
|
sid="",
|
||||||
|
headers={},
|
||||||
|
client_ip="",
|
||||||
|
):
|
||||||
|
# other task returns delta
|
||||||
|
assert update == StateUpdate(
|
||||||
|
delta={
|
||||||
|
BackgroundTaskState.get_name(): {
|
||||||
|
"order": [
|
||||||
|
"background_task:start",
|
||||||
|
"other",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit wait for background tasks
|
||||||
|
for task in tuple(mock_app.background_tasks):
|
||||||
|
await task
|
||||||
|
assert not mock_app.background_tasks
|
||||||
|
|
||||||
|
assert (await mock_app.state_manager.get_state(token)).order == [
|
||||||
|
"background_task:start",
|
||||||
|
"other",
|
||||||
|
"background_task:stop",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_background_task_no_chain():
|
||||||
|
"""Test that a background task cannot be chained."""
|
||||||
|
bts = BackgroundTaskState()
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await bts.bad_chain1()
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await bts.bad_chain2()
|
||||||
|
|
||||||
|
|
||||||
def test_mutable_list(mutable_state):
|
def test_mutable_list(mutable_state):
|
||||||
"""Test that mutable lists are tracked correctly.
|
"""Test that mutable lists are tracked correctly.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user