StateProxy rebinds functools.partial and methods that are bound to the proxied State (#1853)

This commit is contained in:
Masen Furer 2023-09-21 17:59:18 -07:00 committed by GitHub
parent 351611ca25
commit 83d7a044fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 2 deletions

View File

@ -12,7 +12,7 @@ import urllib.parse
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from types import FunctionType from types import FunctionType, MethodType
from typing import ( from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
@ -1177,6 +1177,17 @@ class StateProxy(wrapt.ObjectProxy):
state=self, # type: ignore state=self, # type: ignore
field_name=value._self_field_name, field_name=value._self_field_name,
) )
if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__:
# Rebind event handler to the proxy instance
value = functools.partial(
value.func,
self,
*value.args[1:],
**value.keywords,
)
if isinstance(value, MethodType) and value.__self__ is self.__wrapped__:
# Rebind methods to the proxy instance
value = type(value)(value.__func__, self) # type: ignore
return value return value
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:

View File

@ -1699,6 +1699,14 @@ class BackgroundTaskState(State):
# Even nested access to mutables raises an exception. # Even nested access to mutables raises an exception.
self.dict_list["foo"].append(42) self.dict_list["foo"].append(42)
with pytest.raises(ImmutableStateError):
# Direct calling another handler that modifies state raises an exception.
self.other()
with pytest.raises(ImmutableStateError):
# Calling other methods that modify state raises an exception.
self._private_method()
# wait for some other event to happen # wait for some other event to happen
while len(self.order) == 1: while len(self.order) == 1:
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
@ -1707,6 +1715,22 @@ class BackgroundTaskState(State):
async with self: async with self:
self.order.append("background_task:stop") self.order.append("background_task:stop")
self.other() # direct calling event handlers works in context
self._private_method()
@rx.background
async def background_task_reset(self):
"""A background task that resets the state."""
with pytest.raises(ImmutableStateError):
# Resetting the state should be explicitly blocked.
self.reset()
async with self:
self.order.append("foo")
self.reset()
assert not self.order
async with self:
self.order.append("reset")
@rx.background @rx.background
async def background_task_generator(self): async def background_task_generator(self):
@ -1721,6 +1745,10 @@ class BackgroundTaskState(State):
"""Some other event that updates the state.""" """Some other event that updates the state."""
self.order.append("other") self.order.append("other")
def _private_method(self):
"""Some private method that updates the state."""
self.order.append("private")
async def bad_chain1(self): async def bad_chain1(self):
"""Test that a background task cannot be chained.""" """Test that a background task cannot be chained."""
await self.background_task() await self.background_task()
@ -1755,7 +1783,6 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
): ):
# background task returns empty update immediately # background task returns empty update immediately
assert update == StateUpdate() assert update == StateUpdate()
assert len(mock_app.background_tasks) == 1
# wait for the coroutine to start # wait for the coroutine to start
await asyncio.sleep(0.5 if CI else 0.1) await asyncio.sleep(0.5 if CI else 0.1)
@ -1795,6 +1822,43 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
"background_task:start", "background_task:start",
"other", "other",
"background_task:stop", "background_task:stop",
"other",
"private",
]
@pytest.mark.asyncio
async def test_background_task_reset(mock_app: rx.App, token: str):
"""Test that a background task calling reset is protected by the state proxy.
Args:
mock_app: An app that will be returned by `get_app()`
token: A token.
"""
router_data = {"query": {}}
mock_app.state_manager.state = mock_app.state = BackgroundTaskState
async for update in rx.app.process( # type: ignore
mock_app,
Event(
token=token,
name=f"{BackgroundTaskState.get_name()}.background_task_reset",
router_data=router_data,
payload={},
),
sid="",
headers={},
client_ip="",
):
# background task returns empty update immediately
assert update == StateUpdate()
# Explicit wait for background tasks
for task in tuple(mock_app.background_tasks):
await task
assert not mock_app.background_tasks
assert (await mock_app.state_manager.get_state(token)).order == [
"reset",
] ]