From 0845d2ee7689e91a4532a8b800314f71230febaa Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 23 Jul 2024 15:28:38 -0700 Subject: [PATCH] [REF-3184] [REF-3339] Background task locking improvements (#3696) * [REF-3184] Raise exception when encountering nested `async with self` blocks Avoid deadlock when the background task already holds the mutation lock for a given state. * [REF-3339] get_state from background task links to StateProxy When calling `get_state` from a background task, the resulting state instance is wrapped in a StateProxy that is bound to the original StateProxy and shares the same async context, lock, and mutability flag. * If StateProxy has a _self_parent_state_proxy, retrieve the correct substate * test_state fixup --- integration/test_background_task.py | 103 ++++++++++++++++++++++++++++ reflex/state.py | 64 ++++++++++++++--- tests/test_state.py | 2 +- 3 files changed, 160 insertions(+), 9 deletions(-) diff --git a/integration/test_background_task.py b/integration/test_background_task.py index 96a47e951..98b6e48ff 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -12,7 +12,10 @@ def BackgroundTask(): """Test that background tasks work as expected.""" import asyncio + import pytest + import reflex as rx + from reflex.state import ImmutableStateError class State(rx.State): counter: int = 0 @@ -71,6 +74,38 @@ def BackgroundTask(): self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task() ) + @rx.background + async def nested_async_with_self(self): + async with self: + self.counter += 1 + with pytest.raises(ImmutableStateError): + async with self: + self.counter += 1 + + async def triple_count(self): + third_state = await self.get_state(ThirdState) + await third_state._triple_count() + + class OtherState(rx.State): + @rx.background + async def get_other_state(self): + async with self: + state = await self.get_state(State) + state.counter += 1 + await state.triple_count() + with pytest.raises(ImmutableStateError): + await state.triple_count() + with pytest.raises(ImmutableStateError): + state.counter += 1 + async with state: + state.counter += 1 + await state.triple_count() + + class ThirdState(rx.State): + async def _triple_count(self): + state = await self.get_state(State) + state.counter *= 3 + def index() -> rx.Component: return rx.vstack( rx.chakra.input( @@ -109,6 +144,16 @@ def BackgroundTask(): on_click=State.handle_racy_event, id="racy-increment", ), + rx.button( + "Nested Async with Self", + on_click=State.nested_async_with_self, + id="nested-async-with-self", + ), + rx.button( + "Increment from OtherState", + on_click=OtherState.get_other_state, + id="increment-from-other-state", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -230,3 +275,61 @@ def test_background_task( assert background_task._poll_for( lambda: not background_task.app_instance.background_tasks # type: ignore ) + + +def test_nested_async_with_self( + background_task: AppHarness, + driver: WebDriver, + token: str, +): + """Test that nested async with self in the same coroutine raises Exception. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self") + increment_button = driver.find_element(By.ID, "increment") + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + assert background_task._poll_for(lambda: counter.text == "0", timeout=5) + + nested_async_with_self_button.click() + assert background_task._poll_for(lambda: counter.text == "1", timeout=5) + + increment_button.click() + assert background_task._poll_for(lambda: counter.text == "2", timeout=5) + + +def test_get_state( + background_task: AppHarness, + driver: WebDriver, + token: str, +): + """Test that get_state returns a state bound to the correct StateProxy. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + other_state_button = driver.find_element(By.ID, "increment-from-other-state") + increment_button = driver.find_element(By.ID, "increment") + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + assert background_task._poll_for(lambda: counter.text == "0", timeout=5) + + other_state_button.click() + assert background_task._poll_for(lambda: counter.text == "12", timeout=5) + + increment_button.click() + assert background_task._poll_for(lambda: counter.text == "13", timeout=5) diff --git a/reflex/state.py b/reflex/state.py index 9313939dc..e29336042 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -202,7 +202,7 @@ def _no_chain_background_task( def _substate_key( token: str, - state_cls_or_name: BaseState | Type[BaseState] | str | list[str], + state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str], ) -> str: """Get the substate key. @@ -2029,19 +2029,38 @@ class StateProxy(wrapt.ObjectProxy): self.counter += 1 """ - def __init__(self, state_instance): + def __init__( + self, state_instance, parent_state_proxy: Optional["StateProxy"] = None + ): """Create a proxy for a state instance. + If `get_state` is used on a StateProxy, the resulting state will be + linked to the given state via parent_state_proxy. The first state in the + chain is the state that initiated the background task. + Args: state_instance: The state instance to proxy. + parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ super().__init__(state_instance) # compile is not relevant to backend logic self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - self._self_substate_path = state_instance.get_full_name().split(".") + self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_actx = None self._self_mutable = False self._self_actx_lock = asyncio.Lock() + self._self_actx_lock_holder = None + self._self_parent_state_proxy = parent_state_proxy + + def _is_mutable(self) -> bool: + """Check if the state is mutable. + + Returns: + Whether the state is mutable. + """ + if self._self_parent_state_proxy is not None: + return self._self_parent_state_proxy._is_mutable() + return self._self_mutable async def __aenter__(self) -> StateProxy: """Enter the async context manager protocol. @@ -2054,8 +2073,31 @@ class StateProxy(wrapt.ObjectProxy): Returns: This StateProxy instance in mutable mode. + + Raises: + ImmutableStateError: If the state is already mutable. """ + if self._self_parent_state_proxy is not None: + parent_state = ( + await self._self_parent_state_proxy.__aenter__() + ).__wrapped__ + super().__setattr__( + "__wrapped__", + await parent_state.get_state( + State.get_class_substate(self._self_substate_path) + ), + ) + return self + current_task = asyncio.current_task() + if ( + self._self_actx_lock.locked() + and current_task == self._self_actx_lock_holder + ): + raise ImmutableStateError( + "The state is already mutable. Do not nest `async with self` blocks." + ) await self._self_actx_lock.acquire() + self._self_actx_lock_holder = current_task self._self_actx = self._self_app.modify_state( token=_substate_key( self.__wrapped__.router.session.client_token, @@ -2077,12 +2119,16 @@ class StateProxy(wrapt.ObjectProxy): Args: exc_info: The exception info tuple. """ + if self._self_parent_state_proxy is not None: + await self._self_parent_state_proxy.__aexit__(*exc_info) + return if self._self_actx is None: return self._self_mutable = False try: await self._self_actx.__aexit__(*exc_info) finally: + self._self_actx_lock_holder = None self._self_actx_lock.release() self._self_actx = None @@ -2117,7 +2163,7 @@ class StateProxy(wrapt.ObjectProxy): Raises: ImmutableStateError: If the state is not in mutable mode. """ - if name in ["substates", "parent_state"] and not self._self_mutable: + if name in ["substates", "parent_state"] and not self._is_mutable(): raise ImmutableStateError( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." @@ -2157,7 +2203,7 @@ class StateProxy(wrapt.ObjectProxy): """ if ( name.startswith("_self_") # wrapper attribute - or self._self_mutable # lock held + or self._is_mutable() # lock held # non-persisted state attribute or name in self.__wrapped__.get_skip_vars() ): @@ -2181,7 +2227,7 @@ class StateProxy(wrapt.ObjectProxy): Raises: ImmutableStateError: If the state is not in mutable mode. """ - if not self._self_mutable: + if not self._is_mutable(): raise ImmutableStateError( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." @@ -2200,12 +2246,14 @@ class StateProxy(wrapt.ObjectProxy): Raises: ImmutableStateError: If the state is not in mutable mode. """ - if not self._self_mutable: + if not self._is_mutable(): raise ImmutableStateError( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." ) - return await self.__wrapped__.get_state(state_cls) + return type(self)( + await self.__wrapped__.get_state(state_cls), parent_state_proxy=self + ) def _as_state_update(self, *args, **kwargs) -> StateUpdate: """Temporarily allow mutability to access parent_state. diff --git a/tests/test_state.py b/tests/test_state.py index c998944ef..18d740015 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1825,7 +1825,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): sp = StateProxy(grandchild_state) assert sp.__wrapped__ == grandchild_state - assert sp._self_substate_path == grandchild_state.get_full_name().split(".") + assert sp._self_substate_path == tuple(grandchild_state.get_full_name().split(".")) assert sp._self_app is mock_app assert not sp._self_mutable assert sp._self_actx is None