[REF-3184] [REF-3339] Background task locking improvements (#3696)
* [REF-3184] Raise exception when encountering nested `async with self` blocks Avoid deadlock when the background task already holds the mutation lock for a given state. * [REF-3339] get_state from background task links to StateProxy When calling `get_state` from a background task, the resulting state instance is wrapped in a StateProxy that is bound to the original StateProxy and shares the same async context, lock, and mutability flag. * If StateProxy has a _self_parent_state_proxy, retrieve the correct substate * test_state fixup
This commit is contained in:
parent
b9927b6f49
commit
0845d2ee76
@ -12,7 +12,10 @@ def BackgroundTask():
|
|||||||
"""Test that background tasks work as expected."""
|
"""Test that background tasks work as expected."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
import reflex as rx
|
import reflex as rx
|
||||||
|
from reflex.state import ImmutableStateError
|
||||||
|
|
||||||
class State(rx.State):
|
class State(rx.State):
|
||||||
counter: int = 0
|
counter: int = 0
|
||||||
@ -71,6 +74,38 @@ def BackgroundTask():
|
|||||||
self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
|
self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def nested_async_with_self(self):
|
||||||
|
async with self:
|
||||||
|
self.counter += 1
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
async with self:
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
async def triple_count(self):
|
||||||
|
third_state = await self.get_state(ThirdState)
|
||||||
|
await third_state._triple_count()
|
||||||
|
|
||||||
|
class OtherState(rx.State):
|
||||||
|
@rx.background
|
||||||
|
async def get_other_state(self):
|
||||||
|
async with self:
|
||||||
|
state = await self.get_state(State)
|
||||||
|
state.counter += 1
|
||||||
|
await state.triple_count()
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
await state.triple_count()
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
state.counter += 1
|
||||||
|
async with state:
|
||||||
|
state.counter += 1
|
||||||
|
await state.triple_count()
|
||||||
|
|
||||||
|
class ThirdState(rx.State):
|
||||||
|
async def _triple_count(self):
|
||||||
|
state = await self.get_state(State)
|
||||||
|
state.counter *= 3
|
||||||
|
|
||||||
def index() -> rx.Component:
|
def index() -> rx.Component:
|
||||||
return rx.vstack(
|
return rx.vstack(
|
||||||
rx.chakra.input(
|
rx.chakra.input(
|
||||||
@ -109,6 +144,16 @@ def BackgroundTask():
|
|||||||
on_click=State.handle_racy_event,
|
on_click=State.handle_racy_event,
|
||||||
id="racy-increment",
|
id="racy-increment",
|
||||||
),
|
),
|
||||||
|
rx.button(
|
||||||
|
"Nested Async with Self",
|
||||||
|
on_click=State.nested_async_with_self,
|
||||||
|
id="nested-async-with-self",
|
||||||
|
),
|
||||||
|
rx.button(
|
||||||
|
"Increment from OtherState",
|
||||||
|
on_click=OtherState.get_other_state,
|
||||||
|
id="increment-from-other-state",
|
||||||
|
),
|
||||||
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
rx.button("Reset", on_click=State.reset_counter, id="reset"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -230,3 +275,61 @@ def test_background_task(
|
|||||||
assert background_task._poll_for(
|
assert background_task._poll_for(
|
||||||
lambda: not background_task.app_instance.background_tasks # type: ignore
|
lambda: not background_task.app_instance.background_tasks # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_async_with_self(
|
||||||
|
background_task: AppHarness,
|
||||||
|
driver: WebDriver,
|
||||||
|
token: str,
|
||||||
|
):
|
||||||
|
"""Test that nested async with self in the same coroutine raises Exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
background_task: harness for BackgroundTask app.
|
||||||
|
driver: WebDriver instance.
|
||||||
|
token: The token for the connected client.
|
||||||
|
"""
|
||||||
|
assert background_task.app_instance is not None
|
||||||
|
|
||||||
|
# get a reference to all buttons
|
||||||
|
nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self")
|
||||||
|
increment_button = driver.find_element(By.ID, "increment")
|
||||||
|
|
||||||
|
# get a reference to the counter
|
||||||
|
counter = driver.find_element(By.ID, "counter")
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
|
||||||
|
|
||||||
|
nested_async_with_self_button.click()
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "1", timeout=5)
|
||||||
|
|
||||||
|
increment_button.click()
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "2", timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_state(
|
||||||
|
background_task: AppHarness,
|
||||||
|
driver: WebDriver,
|
||||||
|
token: str,
|
||||||
|
):
|
||||||
|
"""Test that get_state returns a state bound to the correct StateProxy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
background_task: harness for BackgroundTask app.
|
||||||
|
driver: WebDriver instance.
|
||||||
|
token: The token for the connected client.
|
||||||
|
"""
|
||||||
|
assert background_task.app_instance is not None
|
||||||
|
|
||||||
|
# get a reference to all buttons
|
||||||
|
other_state_button = driver.find_element(By.ID, "increment-from-other-state")
|
||||||
|
increment_button = driver.find_element(By.ID, "increment")
|
||||||
|
|
||||||
|
# get a reference to the counter
|
||||||
|
counter = driver.find_element(By.ID, "counter")
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
|
||||||
|
|
||||||
|
other_state_button.click()
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "12", timeout=5)
|
||||||
|
|
||||||
|
increment_button.click()
|
||||||
|
assert background_task._poll_for(lambda: counter.text == "13", timeout=5)
|
||||||
|
@ -202,7 +202,7 @@ def _no_chain_background_task(
|
|||||||
|
|
||||||
def _substate_key(
|
def _substate_key(
|
||||||
token: str,
|
token: str,
|
||||||
state_cls_or_name: BaseState | Type[BaseState] | str | list[str],
|
state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get the substate key.
|
"""Get the substate key.
|
||||||
|
|
||||||
@ -2029,19 +2029,38 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
self.counter += 1
|
self.counter += 1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, state_instance):
|
def __init__(
|
||||||
|
self, state_instance, parent_state_proxy: Optional["StateProxy"] = None
|
||||||
|
):
|
||||||
"""Create a proxy for a state instance.
|
"""Create a proxy for a state instance.
|
||||||
|
|
||||||
|
If `get_state` is used on a StateProxy, the resulting state will be
|
||||||
|
linked to the given state via parent_state_proxy. The first state in the
|
||||||
|
chain is the state that initiated the background task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_instance: The state instance to proxy.
|
state_instance: The state instance to proxy.
|
||||||
|
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
|
||||||
"""
|
"""
|
||||||
super().__init__(state_instance)
|
super().__init__(state_instance)
|
||||||
# compile is not relevant to backend logic
|
# compile is not relevant to backend logic
|
||||||
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
|
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
|
||||||
self._self_substate_path = state_instance.get_full_name().split(".")
|
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
|
||||||
self._self_actx = None
|
self._self_actx = None
|
||||||
self._self_mutable = False
|
self._self_mutable = False
|
||||||
self._self_actx_lock = asyncio.Lock()
|
self._self_actx_lock = asyncio.Lock()
|
||||||
|
self._self_actx_lock_holder = None
|
||||||
|
self._self_parent_state_proxy = parent_state_proxy
|
||||||
|
|
||||||
|
def _is_mutable(self) -> bool:
|
||||||
|
"""Check if the state is mutable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the state is mutable.
|
||||||
|
"""
|
||||||
|
if self._self_parent_state_proxy is not None:
|
||||||
|
return self._self_parent_state_proxy._is_mutable()
|
||||||
|
return self._self_mutable
|
||||||
|
|
||||||
async def __aenter__(self) -> StateProxy:
|
async def __aenter__(self) -> StateProxy:
|
||||||
"""Enter the async context manager protocol.
|
"""Enter the async context manager protocol.
|
||||||
@ -2054,8 +2073,31 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
This StateProxy instance in mutable mode.
|
This StateProxy instance in mutable mode.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImmutableStateError: If the state is already mutable.
|
||||||
"""
|
"""
|
||||||
|
if self._self_parent_state_proxy is not None:
|
||||||
|
parent_state = (
|
||||||
|
await self._self_parent_state_proxy.__aenter__()
|
||||||
|
).__wrapped__
|
||||||
|
super().__setattr__(
|
||||||
|
"__wrapped__",
|
||||||
|
await parent_state.get_state(
|
||||||
|
State.get_class_substate(self._self_substate_path)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
current_task = asyncio.current_task()
|
||||||
|
if (
|
||||||
|
self._self_actx_lock.locked()
|
||||||
|
and current_task == self._self_actx_lock_holder
|
||||||
|
):
|
||||||
|
raise ImmutableStateError(
|
||||||
|
"The state is already mutable. Do not nest `async with self` blocks."
|
||||||
|
)
|
||||||
await self._self_actx_lock.acquire()
|
await self._self_actx_lock.acquire()
|
||||||
|
self._self_actx_lock_holder = current_task
|
||||||
self._self_actx = self._self_app.modify_state(
|
self._self_actx = self._self_app.modify_state(
|
||||||
token=_substate_key(
|
token=_substate_key(
|
||||||
self.__wrapped__.router.session.client_token,
|
self.__wrapped__.router.session.client_token,
|
||||||
@ -2077,12 +2119,16 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
Args:
|
Args:
|
||||||
exc_info: The exception info tuple.
|
exc_info: The exception info tuple.
|
||||||
"""
|
"""
|
||||||
|
if self._self_parent_state_proxy is not None:
|
||||||
|
await self._self_parent_state_proxy.__aexit__(*exc_info)
|
||||||
|
return
|
||||||
if self._self_actx is None:
|
if self._self_actx is None:
|
||||||
return
|
return
|
||||||
self._self_mutable = False
|
self._self_mutable = False
|
||||||
try:
|
try:
|
||||||
await self._self_actx.__aexit__(*exc_info)
|
await self._self_actx.__aexit__(*exc_info)
|
||||||
finally:
|
finally:
|
||||||
|
self._self_actx_lock_holder = None
|
||||||
self._self_actx_lock.release()
|
self._self_actx_lock.release()
|
||||||
self._self_actx = None
|
self._self_actx = None
|
||||||
|
|
||||||
@ -2117,7 +2163,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
Raises:
|
Raises:
|
||||||
ImmutableStateError: If the state is not in mutable mode.
|
ImmutableStateError: If the state is not in mutable mode.
|
||||||
"""
|
"""
|
||||||
if name in ["substates", "parent_state"] and not self._self_mutable:
|
if name in ["substates", "parent_state"] and not self._is_mutable():
|
||||||
raise ImmutableStateError(
|
raise ImmutableStateError(
|
||||||
"Background task StateProxy is immutable outside of a context "
|
"Background task StateProxy is immutable outside of a context "
|
||||||
"manager. Use `async with self` to modify state."
|
"manager. Use `async with self` to modify state."
|
||||||
@ -2157,7 +2203,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
name.startswith("_self_") # wrapper attribute
|
name.startswith("_self_") # wrapper attribute
|
||||||
or self._self_mutable # lock held
|
or self._is_mutable() # lock held
|
||||||
# non-persisted state attribute
|
# non-persisted state attribute
|
||||||
or name in self.__wrapped__.get_skip_vars()
|
or name in self.__wrapped__.get_skip_vars()
|
||||||
):
|
):
|
||||||
@ -2181,7 +2227,7 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
Raises:
|
Raises:
|
||||||
ImmutableStateError: If the state is not in mutable mode.
|
ImmutableStateError: If the state is not in mutable mode.
|
||||||
"""
|
"""
|
||||||
if not self._self_mutable:
|
if not self._is_mutable():
|
||||||
raise ImmutableStateError(
|
raise ImmutableStateError(
|
||||||
"Background task StateProxy is immutable outside of a context "
|
"Background task StateProxy is immutable outside of a context "
|
||||||
"manager. Use `async with self` to modify state."
|
"manager. Use `async with self` to modify state."
|
||||||
@ -2200,12 +2246,14 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
Raises:
|
Raises:
|
||||||
ImmutableStateError: If the state is not in mutable mode.
|
ImmutableStateError: If the state is not in mutable mode.
|
||||||
"""
|
"""
|
||||||
if not self._self_mutable:
|
if not self._is_mutable():
|
||||||
raise ImmutableStateError(
|
raise ImmutableStateError(
|
||||||
"Background task StateProxy is immutable outside of a context "
|
"Background task StateProxy is immutable outside of a context "
|
||||||
"manager. Use `async with self` to modify state."
|
"manager. Use `async with self` to modify state."
|
||||||
)
|
)
|
||||||
return await self.__wrapped__.get_state(state_cls)
|
return type(self)(
|
||||||
|
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
|
||||||
|
)
|
||||||
|
|
||||||
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
def _as_state_update(self, *args, **kwargs) -> StateUpdate:
|
||||||
"""Temporarily allow mutability to access parent_state.
|
"""Temporarily allow mutability to access parent_state.
|
||||||
|
@ -1825,7 +1825,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
|
|||||||
|
|
||||||
sp = StateProxy(grandchild_state)
|
sp = StateProxy(grandchild_state)
|
||||||
assert sp.__wrapped__ == grandchild_state
|
assert sp.__wrapped__ == grandchild_state
|
||||||
assert sp._self_substate_path == grandchild_state.get_full_name().split(".")
|
assert sp._self_substate_path == tuple(grandchild_state.get_full_name().split("."))
|
||||||
assert sp._self_app is mock_app
|
assert sp._self_app is mock_app
|
||||||
assert not sp._self_mutable
|
assert not sp._self_mutable
|
||||||
assert sp._self_actx is None
|
assert sp._self_actx is None
|
||||||
|
Loading…
Reference in New Issue
Block a user