StateProxy rebinds functools.partial and methods that are bound to the proxied State (#1853)

This commit is contained in:
Masen Furer 2023-09-21 17:59:18 -07:00 committed by GitHub
parent 351611ca25
commit 83d7a044fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 2 deletions

View File

@ -12,7 +12,7 @@ import urllib.parse
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from types import FunctionType
from types import FunctionType, MethodType
from typing import (
Any,
AsyncIterator,
@ -1177,6 +1177,17 @@ class StateProxy(wrapt.ObjectProxy):
state=self, # type: ignore
field_name=value._self_field_name,
)
if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
# Rebind event handler to the proxy instance
value = functools.partial(
value.func,
self,
*value.args[1:],
**value.keywords,
)
if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
# Rebind methods to the proxy instance
value = type(value)(value.__func__, self) # type: ignore
return value
def __setattr__(self, name: str, value: Any) -> None:

View File

@ -1699,6 +1699,14 @@ class BackgroundTaskState(State):
# Even nested access to mutables raises an exception.
self.dict_list["foo"].append(42)
with pytest.raises(ImmutableStateError):
# Direct calling another handler that modifies state raises an exception.
self.other()
with pytest.raises(ImmutableStateError):
# Calling other methods that modify state raises an exception.
self._private_method()
# wait for some other event to happen
while len(self.order) == 1:
await asyncio.sleep(0.01)
@ -1707,6 +1715,22 @@ class BackgroundTaskState(State):
async with self:
self.order.append("background_task:stop")
self.other() # direct calling event handlers works in context
self._private_method()
@rx.background
async def background_task_reset(self):
"""A background task that resets the state."""
with pytest.raises(ImmutableStateError):
# Resetting the state should be explicitly blocked.
self.reset()
async with self:
self.order.append("foo")
self.reset()
assert not self.order
async with self:
self.order.append("reset")
@rx.background
async def background_task_generator(self):
@ -1721,6 +1745,10 @@ class BackgroundTaskState(State):
"""Some other event that updates the state."""
self.order.append("other")
def _private_method(self):
"""Some private method that updates the state."""
self.order.append("private")
async def bad_chain1(self):
"""Test that a background task cannot be chained."""
await self.background_task()
@ -1755,7 +1783,6 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
):
# background task returns empty update immediately
assert update == StateUpdate()
assert len(mock_app.background_tasks) == 1
# wait for the coroutine to start
await asyncio.sleep(0.5 if CI else 0.1)
@ -1795,6 +1822,43 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
"background_task:start",
"other",
"background_task:stop",
"other",
"private",
]
@pytest.mark.asyncio
async def test_background_task_reset(mock_app: rx.App, token: str):
"""Test that a background task calling reset is protected by the state proxy.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
router_data = {"query": {}}
mock_app.state_manager.state = mock_app.state = BackgroundTaskState
async for update in rx.app.process( # type: ignore
mock_app,
Event(
token=token,
name=f"{BackgroundTaskState.get_name()}.background_task_reset",
router_data=router_data,
payload={},
),
sid="",
headers={},
client_ip="",
):
# background task returns empty update immediately
assert update == StateUpdate()
# Explicit wait for background tasks
for task in tuple(mock_app.background_tasks):
await task
assert not mock_app.background_tasks
assert (await mock_app.state_manager.get_state(token)).order == [
"reset",
]