StateProxy rebinds functools.partial and methods that are bound to the proxied State (#1853)
This commit is contained in:
parent
351611ca25
commit
83d7a044fe
@ -12,7 +12,7 @@ import urllib.parse
|
|||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from types import FunctionType
|
from types import FunctionType, MethodType
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
@ -1177,6 +1177,17 @@ class StateProxy(wrapt.ObjectProxy):
|
|||||||
state=self, # type: ignore
|
state=self, # type: ignore
|
||||||
field_name=value._self_field_name,
|
field_name=value._self_field_name,
|
||||||
)
|
)
|
||||||
|
if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
|
||||||
|
# Rebind event handler to the proxy instance
|
||||||
|
value = functools.partial(
|
||||||
|
value.func,
|
||||||
|
self,
|
||||||
|
*value.args[1:],
|
||||||
|
**value.keywords,
|
||||||
|
)
|
||||||
|
if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
|
||||||
|
# Rebind methods to the proxy instance
|
||||||
|
value = type(value)(value.__func__, self) # type: ignore
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None:
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
@ -1699,6 +1699,14 @@ class BackgroundTaskState(State):
|
|||||||
# Even nested access to mutables raises an exception.
|
# Even nested access to mutables raises an exception.
|
||||||
self.dict_list["foo"].append(42)
|
self.dict_list["foo"].append(42)
|
||||||
|
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
# Direct calling another handler that modifies state raises an exception.
|
||||||
|
self.other()
|
||||||
|
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
# Calling other methods that modify state raises an exception.
|
||||||
|
self._private_method()
|
||||||
|
|
||||||
# wait for some other event to happen
|
# wait for some other event to happen
|
||||||
while len(self.order) == 1:
|
while len(self.order) == 1:
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
@ -1707,6 +1715,22 @@ class BackgroundTaskState(State):
|
|||||||
|
|
||||||
async with self:
|
async with self:
|
||||||
self.order.append("background_task:stop")
|
self.order.append("background_task:stop")
|
||||||
|
self.other() # direct calling event handlers works in context
|
||||||
|
self._private_method()
|
||||||
|
|
||||||
|
@rx.background
|
||||||
|
async def background_task_reset(self):
|
||||||
|
"""A background task that resets the state."""
|
||||||
|
with pytest.raises(ImmutableStateError):
|
||||||
|
# Resetting the state should be explicitly blocked.
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
async with self:
|
||||||
|
self.order.append("foo")
|
||||||
|
self.reset()
|
||||||
|
assert not self.order
|
||||||
|
async with self:
|
||||||
|
self.order.append("reset")
|
||||||
|
|
||||||
@rx.background
|
@rx.background
|
||||||
async def background_task_generator(self):
|
async def background_task_generator(self):
|
||||||
@ -1721,6 +1745,10 @@ class BackgroundTaskState(State):
|
|||||||
"""Some other event that updates the state."""
|
"""Some other event that updates the state."""
|
||||||
self.order.append("other")
|
self.order.append("other")
|
||||||
|
|
||||||
|
def _private_method(self):
|
||||||
|
"""Some private method that updates the state."""
|
||||||
|
self.order.append("private")
|
||||||
|
|
||||||
async def bad_chain1(self):
|
async def bad_chain1(self):
|
||||||
"""Test that a background task cannot be chained."""
|
"""Test that a background task cannot be chained."""
|
||||||
await self.background_task()
|
await self.background_task()
|
||||||
@ -1755,7 +1783,6 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
|||||||
):
|
):
|
||||||
# background task returns empty update immediately
|
# background task returns empty update immediately
|
||||||
assert update == StateUpdate()
|
assert update == StateUpdate()
|
||||||
assert len(mock_app.background_tasks) == 1
|
|
||||||
|
|
||||||
# wait for the coroutine to start
|
# wait for the coroutine to start
|
||||||
await asyncio.sleep(0.5 if CI else 0.1)
|
await asyncio.sleep(0.5 if CI else 0.1)
|
||||||
@ -1795,6 +1822,43 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
|
|||||||
"background_task:start",
|
"background_task:start",
|
||||||
"other",
|
"other",
|
||||||
"background_task:stop",
|
"background_task:stop",
|
||||||
|
"other",
|
||||||
|
"private",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_background_task_reset(mock_app: rx.App, token: str):
|
||||||
|
"""Test that a background task calling reset is protected by the state proxy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mock_app: An app that will be returned by `get_app()`
|
||||||
|
token: A token.
|
||||||
|
"""
|
||||||
|
router_data = {"query": {}}
|
||||||
|
mock_app.state_manager.state = mock_app.state = BackgroundTaskState
|
||||||
|
async for update in rx.app.process( # type: ignore
|
||||||
|
mock_app,
|
||||||
|
Event(
|
||||||
|
token=token,
|
||||||
|
name=f"{BackgroundTaskState.get_name()}.background_task_reset",
|
||||||
|
router_data=router_data,
|
||||||
|
payload={},
|
||||||
|
),
|
||||||
|
sid="",
|
||||||
|
headers={},
|
||||||
|
client_ip="",
|
||||||
|
):
|
||||||
|
# background task returns empty update immediately
|
||||||
|
assert update == StateUpdate()
|
||||||
|
|
||||||
|
# Explicit wait for background tasks
|
||||||
|
for task in tuple(mock_app.background_tasks):
|
||||||
|
await task
|
||||||
|
assert not mock_app.background_tasks
|
||||||
|
|
||||||
|
assert (await mock_app.state_manager.get_state(token)).order == [
|
||||||
|
"reset",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user