[REF-3184] [REF-3339] Background task locking improvements (#3696)

* [REF-3184] Raise exception when encountering nested `async with self` blocks

Avoid deadlock when the background task already holds the mutation lock for a
given state.

* [REF-3339] get_state from background task links to StateProxy

When calling `get_state` from a background task, the resulting state instance
is wrapped in a StateProxy that is bound to the original StateProxy and shares
the same async context, lock, and mutability flag.

* If StateProxy has a _self_parent_state_proxy, retrieve the correct substate

* test_state fixup
This commit is contained in:
Masen Furer 2024-07-23 15:28:38 -07:00 committed by GitHub
parent b9927b6f49
commit 0845d2ee76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 160 additions and 9 deletions

View File

@ -12,7 +12,10 @@ def BackgroundTask():
"""Test that background tasks work as expected.""" """Test that background tasks work as expected."""
import asyncio import asyncio
import pytest
import reflex as rx import reflex as rx
from reflex.state import ImmutableStateError
class State(rx.State): class State(rx.State):
counter: int = 0 counter: int = 0
@ -71,6 +74,38 @@ def BackgroundTask():
self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task() self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
) )
@rx.background
async def nested_async_with_self(self):
async with self:
self.counter += 1
with pytest.raises(ImmutableStateError):
async with self:
self.counter += 1
async def triple_count(self):
third_state = await self.get_state(ThirdState)
await third_state._triple_count()
class OtherState(rx.State):
@rx.background
async def get_other_state(self):
async with self:
state = await self.get_state(State)
state.counter += 1
await state.triple_count()
with pytest.raises(ImmutableStateError):
await state.triple_count()
with pytest.raises(ImmutableStateError):
state.counter += 1
async with state:
state.counter += 1
await state.triple_count()
class ThirdState(rx.State):
async def _triple_count(self):
state = await self.get_state(State)
state.counter *= 3
def index() -> rx.Component: def index() -> rx.Component:
return rx.vstack( return rx.vstack(
rx.chakra.input( rx.chakra.input(
@ -109,6 +144,16 @@ def BackgroundTask():
on_click=State.handle_racy_event, on_click=State.handle_racy_event,
id="racy-increment", id="racy-increment",
), ),
rx.button(
"Nested Async with Self",
on_click=State.nested_async_with_self,
id="nested-async-with-self",
),
rx.button(
"Increment from OtherState",
on_click=OtherState.get_other_state,
id="increment-from-other-state",
),
rx.button("Reset", on_click=State.reset_counter, id="reset"), rx.button("Reset", on_click=State.reset_counter, id="reset"),
) )
@ -230,3 +275,61 @@ def test_background_task(
assert background_task._poll_for( assert background_task._poll_for(
lambda: not background_task.app_instance.background_tasks # type: ignore lambda: not background_task.app_instance.background_tasks # type: ignore
) )
def test_nested_async_with_self(
background_task: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that nested async with self in the same coroutine raises Exception.
Args:
background_task: harness for BackgroundTask app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert background_task.app_instance is not None
# get a reference to all buttons
nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
increment_button = driver.find_element(By.ID, "increment")
# get a reference to the counter
counter = driver.find_element(By.ID, "counter")
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
nested_async_with_self_button.click()
assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
increment_button.click()
assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
def test_get_state(
background_task: AppHarness,
driver: WebDriver,
token: str,
):
"""Test that get_state returns a state bound to the correct StateProxy.
Args:
background_task: harness for BackgroundTask app.
driver: WebDriver instance.
token: The token for the connected client.
"""
assert background_task.app_instance is not None
# get a reference to all buttons
other_state_button = driver.find_element(By.ID, "increment-from-other-state")
increment_button = driver.find_element(By.ID, "increment")
# get a reference to the counter
counter = driver.find_element(By.ID, "counter")
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
other_state_button.click()
assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
increment_button.click()
assert background_task._poll_for(lambda: counter.text == "13", timeout=5)

View File

@ -202,7 +202,7 @@ def _no_chain_background_task(
def _substate_key( def _substate_key(
token: str, token: str,
state_cls_or_name: BaseState | Type[BaseState] | str | list[str], state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str],
) -> str: ) -> str:
"""Get the substate key. """Get the substate key.
@ -2029,19 +2029,38 @@ class StateProxy(wrapt.ObjectProxy):
self.counter += 1 self.counter += 1
""" """
def __init__(self, state_instance): def __init__(
self, state_instance, parent_state_proxy: Optional["StateProxy"] = None
):
"""Create a proxy for a state instance. """Create a proxy for a state instance.
If `get_state` is used on a StateProxy, the resulting state will be
linked to the given state via parent_state_proxy. The first state in the
chain is the state that initiated the background task.
Args: Args:
state_instance: The state instance to proxy. state_instance: The state instance to proxy.
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
""" """
super().__init__(state_instance) super().__init__(state_instance)
# compile is not relevant to backend logic # compile is not relevant to backend logic
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
self._self_substate_path = state_instance.get_full_name().split(".") self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None self._self_actx = None
self._self_mutable = False self._self_mutable = False
self._self_actx_lock = asyncio.Lock() self._self_actx_lock = asyncio.Lock()
self._self_actx_lock_holder = None
self._self_parent_state_proxy = parent_state_proxy
def _is_mutable(self) -> bool:
"""Check if the state is mutable.
Returns:
Whether the state is mutable.
"""
if self._self_parent_state_proxy is not None:
return self._self_parent_state_proxy._is_mutable()
return self._self_mutable
async def __aenter__(self) -> StateProxy: async def __aenter__(self) -> StateProxy:
"""Enter the async context manager protocol. """Enter the async context manager protocol.
@ -2054,8 +2073,31 @@ class StateProxy(wrapt.ObjectProxy):
Returns: Returns:
This StateProxy instance in mutable mode. This StateProxy instance in mutable mode.
Raises:
ImmutableStateError: If the state is already mutable.
""" """
if self._self_parent_state_proxy is not None:
parent_state = (
await self._self_parent_state_proxy.__aenter__()
).__wrapped__
super().__setattr__(
"__wrapped__",
await parent_state.get_state(
State.get_class_substate(self._self_substate_path)
),
)
return self
current_task = asyncio.current_task()
if (
self._self_actx_lock.locked()
and current_task == self._self_actx_lock_holder
):
raise ImmutableStateError(
"The state is already mutable. Do not nest `async with self` blocks."
)
await self._self_actx_lock.acquire() await self._self_actx_lock.acquire()
self._self_actx_lock_holder = current_task
self._self_actx = self._self_app.modify_state( self._self_actx = self._self_app.modify_state(
token=_substate_key( token=_substate_key(
self.__wrapped__.router.session.client_token, self.__wrapped__.router.session.client_token,
@ -2077,12 +2119,16 @@ class StateProxy(wrapt.ObjectProxy):
Args: Args:
exc_info: The exception info tuple. exc_info: The exception info tuple.
""" """
if self._self_parent_state_proxy is not None:
await self._self_parent_state_proxy.__aexit__(*exc_info)
return
if self._self_actx is None: if self._self_actx is None:
return return
self._self_mutable = False self._self_mutable = False
try: try:
await self._self_actx.__aexit__(*exc_info) await self._self_actx.__aexit__(*exc_info)
finally: finally:
self._self_actx_lock_holder = None
self._self_actx_lock.release() self._self_actx_lock.release()
self._self_actx = None self._self_actx = None
@ -2117,7 +2163,7 @@ class StateProxy(wrapt.ObjectProxy):
Raises: Raises:
ImmutableStateError: If the state is not in mutable mode. ImmutableStateError: If the state is not in mutable mode.
""" """
if name in ["substates", "parent_state"] and not self._self_mutable: if name in ["substates", "parent_state"] and not self._is_mutable():
raise ImmutableStateError( raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context " "Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state." "manager. Use `async with self` to modify state."
@ -2157,7 +2203,7 @@ class StateProxy(wrapt.ObjectProxy):
""" """
if ( if (
name.startswith("_self_") # wrapper attribute name.startswith("_self_") # wrapper attribute
or self._self_mutable # lock held or self._is_mutable() # lock held
# non-persisted state attribute # non-persisted state attribute
or name in self.__wrapped__.get_skip_vars() or name in self.__wrapped__.get_skip_vars()
): ):
@ -2181,7 +2227,7 @@ class StateProxy(wrapt.ObjectProxy):
Raises: Raises:
ImmutableStateError: If the state is not in mutable mode. ImmutableStateError: If the state is not in mutable mode.
""" """
if not self._self_mutable: if not self._is_mutable():
raise ImmutableStateError( raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context " "Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state." "manager. Use `async with self` to modify state."
@ -2200,12 +2246,14 @@ class StateProxy(wrapt.ObjectProxy):
Raises: Raises:
ImmutableStateError: If the state is not in mutable mode. ImmutableStateError: If the state is not in mutable mode.
""" """
if not self._self_mutable: if not self._is_mutable():
raise ImmutableStateError( raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context " "Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state." "manager. Use `async with self` to modify state."
) )
return await self.__wrapped__.get_state(state_cls) return type(self)(
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
)
def _as_state_update(self, *args, **kwargs) -> StateUpdate: def _as_state_update(self, *args, **kwargs) -> StateUpdate:
"""Temporarily allow mutability to access parent_state. """Temporarily allow mutability to access parent_state.

View File

@ -1825,7 +1825,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
sp = StateProxy(grandchild_state) sp = StateProxy(grandchild_state)
assert sp.__wrapped__ == grandchild_state assert sp.__wrapped__ == grandchild_state
assert sp._self_substate_path == grandchild_state.get_full_name().split(".") assert sp._self_substate_path == tuple(grandchild_state.get_full_name().split("."))
assert sp._self_app is mock_app assert sp._self_app is mock_app
assert not sp._self_mutable assert not sp._self_mutable
assert sp._self_actx is None assert sp._self_actx is None