diff --git a/reflex/state.py b/reflex/state.py index e7f4113b9..3718f7254 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -963,14 +963,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow): Returns: The valid StateUpdate containing the events and final flag. """ + # get the delta from the root of the state tree + state = self + while state.parent_state is not None: + state = state.parent_state + token = self.router.session.client_token # Convert valid EventHandler and EventSpec into Event fixed_events = fix_events(self._check_valid(handler, events), token) # Get the delta after processing the event. - delta = self.get_delta() - self._clean() + delta = state.get_delta() + state._clean() return StateUpdate( delta=delta, @@ -1009,30 +1014,30 @@ class State(Base, ABC, extra=pydantic.Extra.allow): # Handle async generators. if inspect.isasyncgen(events): async for event in events: - yield self._as_state_update(handler, event, final=False) - yield self._as_state_update(handler, events=None, final=True) + yield state._as_state_update(handler, event, final=False) + yield state._as_state_update(handler, events=None, final=True) # Handle regular generators. elif inspect.isgenerator(events): try: while True: - yield self._as_state_update(handler, next(events), final=False) + yield state._as_state_update(handler, next(events), final=False) except StopIteration as si: # the "return" value of the generator is not available # in the loop, we must catch StopIteration to access it if si.value is not None: - yield self._as_state_update(handler, si.value, final=False) - yield self._as_state_update(handler, events=None, final=True) + yield state._as_state_update(handler, si.value, final=False) + yield state._as_state_update(handler, events=None, final=True) # Handle regular event chains. else: - yield self._as_state_update(handler, events, final=True) + yield state._as_state_update(handler, events, final=True) # If an error occurs, throw a window alert. except Exception: error = traceback.format_exc() print(error) - yield self._as_state_update( + yield state._as_state_update( handler, window_alert("An error occurred. See logs for details."), final=True, @@ -1360,12 +1365,19 @@ class StateProxy(wrapt.ObjectProxy): Raises: ImmutableStateError: If the state is not in mutable mode. """ - if not name.startswith("_self_") and not self._self_mutable: - raise ImmutableStateError( - "Background task StateProxy is immutable outside of a context " - "manager. Use `async with self` to modify state." - ) - super().__setattr__(name, value) + if ( + name.startswith("_self_") # wrapper attribute + or self._self_mutable # lock held + # non-persisted state attribute + or name in self.__wrapped__.get_skip_vars() + ): + super().__setattr__(name, value) + return + + raise ImmutableStateError( + "Background task StateProxy is immutable outside of a context " + "manager. Use `async with self` to modify state." + ) class StateUpdate(Base): diff --git a/tests/test_state.py b/tests/test_state.py index c3d3ab519..62d051819 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1577,7 +1577,7 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App: setattr(app_module, CompileVars.APP, app) app.state = TestState - app.state_manager = state_manager + app._state_manager = state_manager app.event_namespace.emit = AsyncMock() # type: ignore monkeypatch.setattr(prerequisites, "get_app", lambda: app_module) return app @@ -1663,6 +1663,15 @@ class BackgroundTaskState(State): order: List[str] = [] dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]} + @rx.var + def computed_order(self) -> List[str]: + """Get the order as a computed var. + + Returns: + The value of 'order' var. + """ + return self.order + @rx.background async def background_task(self): """A background task that updates the state.""" @@ -1791,6 +1800,10 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "background_task:start", "other", ], + "computed_order": [ + "background_task:start", + "other", + ], } } ) @@ -1800,7 +1813,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): await task assert not mock_app.background_tasks - assert (await mock_app.state_manager.get_state(token)).order == [ + exp_order = [ "background_task:start", "other", "background_task:stop", @@ -1808,6 +1821,50 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "private", ] + assert (await mock_app.state_manager.get_state(token)).order == exp_order + + assert mock_app.event_namespace is not None + emit_mock = mock_app.event_namespace.emit + + assert json.loads(emit_mock.mock_calls[0].args[1]) == { + "delta": { + "background_task_state": { + "order": ["background_task:start"], + "computed_order": ["background_task:start"], + } + }, + "events": [], + "final": True, + } + for call in emit_mock.mock_calls[1:5]: + assert json.loads(call.args[1]) == { + "delta": { + "background_task_state": {"computed_order": ["background_task:start"]} + }, + "events": [], + "final": True, + } + assert json.loads(emit_mock.mock_calls[-2].args[1]) == { + "delta": { + "background_task_state": { + "order": exp_order, + "computed_order": exp_order, + "dict_list": {}, + } + }, + "events": [], + "final": True, + } + assert json.loads(emit_mock.mock_calls[-1].args[1]) == { + "delta": { + "background_task_state": { + "computed_order": exp_order, + }, + }, + "events": [], + "final": True, + } + @pytest.mark.asyncio async def test_background_task_reset(mock_app: rx.App, token: str):