diff --git a/integration/test_background_task.py b/integration/test_background_task.py index 5edf810f4..96a47e951 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -57,6 +57,20 @@ def BackgroundTask(): async def non_blocking_pause(self): await asyncio.sleep(0.02) + async def racy_task(self): + async with self: + self._task_id += 1 + for _ix in range(int(self.iterations)): + async with self: + self.counter += 1 + await asyncio.sleep(0.005) + + @rx.background + async def handle_racy_event(self): + await asyncio.gather( + self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task() + ) + def index() -> rx.Component: return rx.vstack( rx.chakra.input( @@ -90,6 +104,11 @@ def BackgroundTask(): on_click=State.non_blocking_pause, id="non-blocking-pause", ), + rx.button( + "Racy Increment (x4)", + on_click=State.handle_racy_event, + id="racy-increment", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -176,6 +195,7 @@ def test_background_task( increment_button = driver.find_element(By.ID, "increment") blocking_pause_button = driver.find_element(By.ID, "blocking-pause") non_blocking_pause_button = driver.find_element(By.ID, "non-blocking-pause") + racy_increment_button = driver.find_element(By.ID, "racy-increment") driver.find_element(By.ID, "reset") # get a reference to the counter @@ -196,6 +216,7 @@ def test_background_task( delayed_increment_button.click() delayed_increment_button.click() yield_increment_button.click() + racy_increment_button.click() non_blocking_pause_button.click() yield_increment_button.click() blocking_pause_button.click() @@ -204,7 +225,7 @@ def test_background_task( increment_button.click() yield_increment_button.click() blocking_pause_button.click() - assert background_task._poll_for(lambda: counter.text == "420", timeout=40) + assert background_task._poll_for(lambda: counter.text == "620", timeout=40) # all tasks should have exited and cleaned up assert background_task._poll_for( lambda: not background_task.app_instance.background_tasks # type: ignore diff --git a/reflex/state.py b/reflex/state.py index 156751ac2..2a4494e65 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1988,6 +1988,7 @@ class StateProxy(wrapt.ObjectProxy): self._self_substate_path = state_instance.get_full_name().split(".") self._self_actx = None self._self_mutable = False + self._self_actx_lock = asyncio.Lock() async def __aenter__(self) -> StateProxy: """Enter the async context manager protocol. @@ -2001,6 +2002,7 @@ class StateProxy(wrapt.ObjectProxy): Returns: This StateProxy instance in mutable mode. """ + await self._self_actx_lock.acquire() self._self_actx = self._self_app.modify_state( token=_substate_key( self.__wrapped__.router.session.client_token, @@ -2025,7 +2027,10 @@ class StateProxy(wrapt.ObjectProxy): if self._self_actx is None: return self._self_mutable = False - await self._self_actx.__aexit__(*exc_info) + try: + await self._self_actx.__aexit__(*exc_info) + finally: + self._self_actx_lock.release() self._self_actx = None def __enter__(self):