Protect StateProxy with an asyncio.Lock (#3508)

* test_background_task: dispatch multiple async tasks

Use asyncio.gather to dispatch multiple tasks from a single background task
that all compete over the `async with self` lock. Even though the state itself
has a lock, each StateProxy instance should only allow a single `async with
self` context to run at a time.

* Protect StateProxy with an asyncio.Lock

Allow multiple tasks to reference the same StateProxy without stomping on each
other when entering an `async with self` context to acquire the state lock and
ultimately modify the state.
This commit is contained in:
Masen Furer 2024-06-18 09:48:12 -07:00 committed by GitHub
parent eb397dacc4
commit af3c9be97c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 2 deletions

View File

@ -57,6 +57,20 @@ def BackgroundTask():
async def non_blocking_pause(self):
await asyncio.sleep(0.02)
async def racy_task(self):
async with self:
self._task_id += 1
for _ix in range(int(self.iterations)):
async with self:
self.counter += 1
await asyncio.sleep(0.005)
@rx.background
async def handle_racy_event(self):
await asyncio.gather(
self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
)
def index() -> rx.Component:
return rx.vstack(
rx.chakra.input(
@ -90,6 +104,11 @@ def BackgroundTask():
on_click=State.non_blocking_pause,
id="non-blocking-pause",
),
rx.button(
"Racy Increment (x4)",
on_click=State.handle_racy_event,
id="racy-increment",
),
rx.button("Reset", on_click=State.reset_counter, id="reset"),
)
@ -176,6 +195,7 @@ def test_background_task(
increment_button = driver.find_element(By.ID, "increment")
blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
racy_increment_button = driver.find_element(By.ID, "racy-increment")
driver.find_element(By.ID, "reset")
# get a reference to the counter
@ -196,6 +216,7 @@ def test_background_task(
delayed_increment_button.click()
delayed_increment_button.click()
yield_increment_button.click()
racy_increment_button.click()
non_blocking_pause_button.click()
yield_increment_button.click()
blocking_pause_button.click()
@ -204,7 +225,7 @@ def test_background_task(
increment_button.click()
yield_increment_button.click()
blocking_pause_button.click()
assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
# all tasks should have exited and cleaned up
assert background_task._poll_for(
lambda: not background_task.app_instance.background_tasks # type: ignore

View File

@ -1988,6 +1988,7 @@ class StateProxy(wrapt.ObjectProxy):
self._self_substate_path = state_instance.get_full_name().split(".")
self._self_actx = None
self._self_mutable = False
self._self_actx_lock = asyncio.Lock()
async def __aenter__(self) -> StateProxy:
"""Enter the async context manager protocol.
@ -2001,6 +2002,7 @@ class StateProxy(wrapt.ObjectProxy):
Returns:
This StateProxy instance in mutable mode.
"""
await self._self_actx_lock.acquire()
self._self_actx = self._self_app.modify_state(
token=_substate_key(
self.__wrapped__.router.session.client_token,
@ -2025,7 +2027,10 @@ class StateProxy(wrapt.ObjectProxy):
if self._self_actx is None:
return
self._self_mutable = False
await self._self_actx.__aexit__(*exc_info)
try:
await self._self_actx.__aexit__(*exc_info)
finally:
self._self_actx_lock.release()
self._self_actx = None
def __enter__(self):