From 83d7a044fedeb7c1b5c05a3db219ed81b9a2dfeb Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 21 Sep 2023 17:59:18 -0700 Subject: [PATCH] StateProxy rebinds functools.partial and methods that are bound to the proxied State (#1853) --- reflex/state.py | 13 ++++++++- tests/test_state.py | 66 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index f7ef2577f..f6a11f1cf 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -12,7 +12,7 @@ import urllib.parse import uuid from abc import ABC, abstractmethod from collections import defaultdict -from types import FunctionType +from types import FunctionType, MethodType from typing import ( Any, AsyncIterator, @@ -1177,6 +1177,17 @@ class StateProxy(wrapt.ObjectProxy): state=self, # type: ignore field_name=value._self_field_name, ) + if isinstance(value, functools.partial) and value.args[0] is self.__wrapped__: + # Rebind event handler to the proxy instance + value = functools.partial( + value.func, + self, + *value.args[1:], + **value.keywords, + ) + if isinstance(value, MethodType) and value.__self__ is self.__wrapped__: + # Rebind methods to the proxy instance + value = type(value)(value.__func__, self) # type: ignore return value def __setattr__(self, name: str, value: Any) -> None: diff --git a/tests/test_state.py b/tests/test_state.py index e24985bea..a6ea8e4e0 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1699,6 +1699,14 @@ class BackgroundTaskState(State): # Even nested access to mutables raises an exception. self.dict_list["foo"].append(42) + with pytest.raises(ImmutableStateError): + # Direct calling another handler that modifies state raises an exception. + self.other() + + with pytest.raises(ImmutableStateError): + # Calling other methods that modify state raises an exception. + self._private_method() + # wait for some other event to happen while len(self.order) == 1: await asyncio.sleep(0.01) @@ -1707,6 +1715,22 @@ class BackgroundTaskState(State): async with self: self.order.append("background_task:stop") + self.other() # direct calling event handlers works in context + self._private_method() + + @rx.background + async def background_task_reset(self): + """A background task that resets the state.""" + with pytest.raises(ImmutableStateError): + # Resetting the state should be explicitly blocked. + self.reset() + + async with self: + self.order.append("foo") + self.reset() + assert not self.order + async with self: + self.order.append("reset") @rx.background async def background_task_generator(self): @@ -1721,6 +1745,10 @@ class BackgroundTaskState(State): """Some other event that updates the state.""" self.order.append("other") + def _private_method(self): + """Some private method that updates the state.""" + self.order.append("private") + async def bad_chain1(self): """Test that a background task cannot be chained.""" await self.background_task() @@ -1755,7 +1783,6 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): ): # background task returns empty update immediately assert update == StateUpdate() - assert len(mock_app.background_tasks) == 1 # wait for the coroutine to start await asyncio.sleep(0.5 if CI else 0.1) @@ -1795,6 +1822,43 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "background_task:start", "other", "background_task:stop", + "other", + "private", + ] + + +@pytest.mark.asyncio +async def test_background_task_reset(mock_app: rx.App, token: str): + """Test that a background task calling reset is protected by the state proxy. + + Args: + mock_app: An app that will be returned by `get_app()` + token: A token. + """ + router_data = {"query": {}} + mock_app.state_manager.state = mock_app.state = BackgroundTaskState + async for update in rx.app.process( # type: ignore + mock_app, + Event( + token=token, + name=f"{BackgroundTaskState.get_name()}.background_task_reset", + router_data=router_data, + payload={}, + ), + sid="", + headers={}, + client_ip="", + ): + # background task returns empty update immediately + assert update == StateUpdate() + + # Explicit wait for background tasks + for task in tuple(mock_app.background_tasks): + await task + assert not mock_app.background_tasks + + assert (await mock_app.state_manager.get_state(token)).order == [ + "reset", ]