diff --git a/integration/test_background_task.py b/integration/test_background_task.py index 45866ba00..3764f67b4 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -87,6 +87,13 @@ def BackgroundTask(): third_state = await self.get_state(ThirdState) await third_state._triple_count() + @rx.background + async def yield_in_async_with_self(self): + async with self: + self.counter += 1 + yield + self.counter += 1 + class OtherState(rx.State): @rx.background async def get_other_state(self): @@ -155,6 +162,11 @@ def BackgroundTask(): on_click=OtherState.get_other_state, id="increment-from-other-state", ), + rx.button( + "Yield in Async with Self", + on_click=State.yield_in_async_with_self, + id="yield-in-async-with-self", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -334,3 +346,30 @@ def test_get_state( increment_button.click() assert background_task._poll_for(lambda: counter.text == "13", timeout=5) + + +def test_yield_in_async_with_self( + background_task: AppHarness, + driver: WebDriver, + token: str, +): + """Test that yielding inside async with self does not disable mutability. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + yield_in_async_with_self_button = driver.find_element( + By.ID, "yield-in-async-with-self" + ) + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + assert background_task._poll_for(lambda: counter.text == "0", timeout=5) + + yield_in_async_with_self_button.click() + assert background_task._poll_for(lambda: counter.text == "2", timeout=5) diff --git a/reflex/state.py b/reflex/state.py index 534672bd9..c94adb39d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2295,11 +2295,12 @@ class StateProxy(wrapt.ObjectProxy): Returns: The state update. """ + original_mutable = self._self_mutable self._self_mutable = True try: return self.__wrapped__._as_state_update(*args, **kwargs) finally: - self._self_mutable = False + self._self_mutable = original_mutable class StateUpdate(Base):