Protect StateProxy with an asyncio.Lock (#3508)
* test_background_task: dispatch multiple async tasks Use asyncio.gather to dispatch multiple tasks from a single background task that all compete over the `async with self` lock. Even though the state itself has a lock, each StateProxy instance should only allow a single `async with self` context to run at a time. * Protect StateProxy with an asyncio.Lock Allow multiple tasks to reference the same StateProxy without stomping on each other when entering an `async with self` context to acquire the state lock and ultimately modify the state.
This commit is contained in:
parent
eb397dacc4
commit
af3c9be97c
@ -57,6 +57,20 @@ def BackgroundTask():
|
|||||||
async def non_blocking_pause(self):
|
async def non_blocking_pause(self):
|
||||||
await asyncio.sleep(0.02)
|
await asyncio.sleep(0.02)
|
||||||
|
|
||||||
|
async def racy_task(self):
|
||||||
|
async with self:
|
||||||
|
self._task_id += 1
|
||||||
|
for _ix in range(int(self.iterations)):
|
||||||
|
async with self:
|
||||||
|
self.counter += 1
|
||||||
|
await asyncio.sleep(0.005)
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def handle_racy_event(self):
|
||||||
|
await asyncio.gather(
|
||||||
|
self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
|
||||||
|
)
|
||||||
|
|
||||||
def index() -> rx.Component:
|
def index() -> rx.Component:
|
||||||
return rx.vstack(
|
return rx.vstack(
|
||||||
rx.chakra.input(
|
rx.chakra.input(
|
||||||
@ -90,6 +104,11 @@ def BackgroundTask():
|
|||||||
on_click=State.non_blocking_pause,
|
on_click=State.non_blocking_pause,
|
||||||
id="non-blocking-pause",
|
id="non-blocking-pause",
|
||||||
),
|
),
|
||||||
|
rx.button(
|
||||||
|
"Racy Increment (x4)",
|
||||||
|
on_click=State.handle_racy_event,
|
||||||
|
id="racy-increment",
|
||||||
|
),
|
||||||
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -176,6 +195,7 @@ def test_background_task(
|
|||||||
increment_button = driver.find_element(By.ID, "increment")
|
increment_button = driver.find_element(By.ID, "increment")
|
||||||
blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
|
blocking_pause_button = driver.find_element(By.ID, "blocking-pause")
|
||||||
non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
|
non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause")
|
||||||
|
racy_increment_button = driver.find_element(By.ID, "racy-increment")
|
||||||
driver.find_element(By.ID, "reset")
|
driver.find_element(By.ID, "reset")
|
||||||
|
|
||||||
# get a reference to the counter
|
# get a reference to the counter
|
||||||
@ -196,6 +216,7 @@ def test_background_task(
|
|||||||
delayed_increment_button.click()
|
delayed_increment_button.click()
|
||||||
delayed_increment_button.click()
|
delayed_increment_button.click()
|
||||||
yield_increment_button.click()
|
yield_increment_button.click()
|
||||||
|
racy_increment_button.click()
|
||||||
non_blocking_pause_button.click()
|
non_blocking_pause_button.click()
|
||||||
yield_increment_button.click()
|
yield_increment_button.click()
|
||||||
blocking_pause_button.click()
|
blocking_pause_button.click()
|
||||||
@ -204,7 +225,7 @@ def test_background_task(
|
|||||||
increment_button.click()
|
increment_button.click()
|
||||||
yield_increment_button.click()
|
yield_increment_button.click()
|
||||||
blocking_pause_button.click()
|
blocking_pause_button.click()
|
||||||
assert background_task._poll_for(lambda: counter.text == "420", timeout=40)
|
assert background_task._poll_for(lambda: counter.text == "620", timeout=40)
|
||||||
# all tasks should have exited and cleaned up
|
# all tasks should have exited and cleaned up
|
||||||
assert background_task._poll_for(
|
assert background_task._poll_for(
|
||||||
lambda: not background_task.app_instance.background_tasks # type: ignore
|
lambda: not background_task.app_instance.background_tasks # type: ignore
|
||||||
|
@ -1988,6 +1988,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
self._self_substate_path = state_instance.get_full_name().split(".")
|
self._self_substate_path = state_instance.get_full_name().split(".")
|
||||||
self._self_actx = None
|
self._self_actx = None
|
||||||
self._self_mutable = False
|
self._self_mutable = False
|
||||||
|
self._self_actx_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def __aenter__(self) -> StateProxy:
|
async def __aenter__(self) -> StateProxy:
|
||||||
"""Enter the async context manager protocol.
|
"""Enter the async context manager protocol.
|
||||||
@ -2001,6 +2002,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
Returns:
|
Returns:
|
||||||
This StateProxy instance in mutable mode.
|
This StateProxy instance in mutable mode.
|
||||||
"""
|
"""
|
||||||
|
await self._self_actx_lock.acquire()
|
||||||
self._self_actx = self._self_app.modify_state(
|
self._self_actx = self._self_app.modify_state(
|
||||||
token=_substate_key(
|
token=_substate_key(
|
||||||
self.__wrapped__.router.session.client_token,
|
self.__wrapped__.router.session.client_token,
|
||||||
@ -2025,7 +2027,10 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
if self._self_actx is None:
|
if self._self_actx is None:
|
||||||
return
|
return
|
||||||
self._self_mutable = False
|
self._self_mutable = False
|
||||||
await self._self_actx.__aexit__(*exc_info)
|
try:
|
||||||
|
await self._self_actx.__aexit__(*exc_info)
|
||||||
|
finally:
|
||||||
|
self._self_actx_lock.release()
|
||||||
self._self_actx = None
|
self._self_actx = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user