Protect StateProxy with an asyncio.Lock (#3508)
* test_background_task: dispatch multiple async tasks Use asyncio.gather to dispatch multiple tasks from a single background task that all compete over the `async with self` lock. Even though the state itself has a lock, each StateProxy instance should only allow a single `async with self` context to run at a time. * Protect StateProxy with an asyncio.Lock Allow multiple tasks to reference the same StateProxy without stomping on each other when entering an `async with self` context to acquire the state lock and ultimately modify the state.
This commit is contained in:
parent
eb397dacc4
commit
af3c9be97c
@ -57,6 +57,20 @@ def BackgroundTask():
|
||||
async def non_blocking_pause(self):
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
async def racy_task(self):
|
||||
async with self:
|
||||
self._task_id += 1
|
||||
for _ix in range(int(self.iterations)):
|
||||
async with self:
|
||||
self.counter += 1
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
@rx.background
|
||||
async def handle_racy_event(self):
|
||||
await asyncio.gather(
|
||||
self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
|
||||
)
|
||||
|
||||
def index() -> rx.Component:
|
||||
return rx.vstack(
|
||||
rx.chakra.input(
|
||||
@ -90,6 +104,11 @@ def BackgroundTask():
|
||||
on_click=State.non_blocking_pause,
|
||||
id="non-blocking-pause",
|
||||
),
|
||||
rx.button(
|
||||
"Racy Increment (x4)",
|
||||
on_click=State.handle_racy_event,
|
||||
id="racy-increment",
|
||||
),
|
||||
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||
)
|
||||
|
||||
@ -176,6 +195,7 @@ def test_background_task(
|
||||
increment_button = driver.find_element(By.ID, "increment")
|
||||
blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
|
||||
non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
|
||||
racy_increment_button = driver.find_element(By.ID, "racy-increment")
|
||||
driver.find_element(By.ID, "reset")
|
||||
|
||||
# get a reference to the counter
|
||||
@ -196,6 +216,7 @@ def test_background_task(
|
||||
delayed_increment_button.click()
|
||||
delayed_increment_button.click()
|
||||
yield_increment_button.click()
|
||||
racy_increment_button.click()
|
||||
non_blocking_pause_button.click()
|
||||
yield_increment_button.click()
|
||||
blocking_pause_button.click()
|
||||
@ -204,7 +225,7 @@ def test_background_task(
|
||||
increment_button.click()
|
||||
yield_increment_button.click()
|
||||
blocking_pause_button.click()
|
||||
assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
|
||||
assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
|
||||
# all tasks should have exited and cleaned up
|
||||
assert background_task._poll_for(
|
||||
lambda: not background_task.app_instance.background_tasks # type: ignore
|
||||
|
@ -1988,6 +1988,7 @@ class StateProxy(wrapt.ObjectProxy):
|
||||
self._self_substate_path = state_instance.get_full_name().split(".")
|
||||
self._self_actx = None
|
||||
self._self_mutable = False
|
||||
self._self_actx_lock = asyncio.Lock()
|
||||
|
||||
async def __aenter__(self) -> StateProxy:
|
||||
"""Enter the async context manager protocol.
|
||||
@ -2001,6 +2002,7 @@ class StateProxy(wrapt.ObjectProxy):
|
||||
Returns:
|
||||
This StateProxy instance in mutable mode.
|
||||
"""
|
||||
await self._self_actx_lock.acquire()
|
||||
self._self_actx = self._self_app.modify_state(
|
||||
token=_substate_key(
|
||||
self.__wrapped__.router.session.client_token,
|
||||
@ -2025,7 +2027,10 @@ class StateProxy(wrapt.ObjectProxy):
|
||||
if self._self_actx is None:
|
||||
return
|
||||
self._self_mutable = False
|
||||
await self._self_actx.__aexit__(*exc_info)
|
||||
try:
|
||||
await self._self_actx.__aexit__(*exc_info)
|
||||
finally:
|
||||
self._self_actx_lock.release()
|
||||
self._self_actx = None
|
||||
|
||||
def __enter__(self):
|
||||
|
Loading…
Reference in New Issue
Block a user