Generate state delta from processed state instance (#2023)

This commit is contained in:
Masen Furer 2023-10-24 10:44:12 -07:00 committed by GitHub
parent 6ea657a4fd
commit 1734ba0b6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 86 additions and 17 deletions

View File

@ -963,14 +963,19 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
Returns:
The valid StateUpdate containing the events and final flag.
"""
# get the delta from the root of the state tree
state = self
while state.parent_state is not None:
state = state.parent_state
token = self.router.session.client_token
# Convert valid EventHandler and EventSpec into Event
fixed_events = fix_events(self._check_valid(handler, events), token)
# Get the delta after processing the event.
delta = self.get_delta()
self._clean()
delta = state.get_delta()
state._clean()
return StateUpdate(
delta=delta,
@ -1009,30 +1014,30 @@ class State(Base, ABC, extra=pydantic.Extra.allow):
# Handle async generators.
if inspect.isasyncgen(events):
async for event in events:
yield self._as_state_update(handler, event, final=False)
yield self._as_state_update(handler, events=None, final=True)
yield state._as_state_update(handler, event, final=False)
yield state._as_state_update(handler, events=None, final=True)
# Handle regular generators.
elif inspect.isgenerator(events):
try:
while True:
yield self._as_state_update(handler, next(events), final=False)
yield state._as_state_update(handler, next(events), final=False)
except StopIteration as si:
# the "return" value of the generator is not available
# in the loop, we must catch StopIteration to access it
if si.value is not None:
yield self._as_state_update(handler, si.value, final=False)
yield self._as_state_update(handler, events=None, final=True)
yield state._as_state_update(handler, si.value, final=False)
yield state._as_state_update(handler, events=None, final=True)
# Handle regular event chains.
else:
yield self._as_state_update(handler, events, final=True)
yield state._as_state_update(handler, events, final=True)
# If an error occurs, throw a window alert.
except Exception:
error = traceback.format_exc()
print(error)
yield self._as_state_update(
yield state._as_state_update(
handler,
window_alert("An error occurred. See logs for details."),
final=True,
@ -1360,12 +1365,19 @@ class StateProxy(wrapt.ObjectProxy):
Raises:
ImmutableStateError: If the state is not in mutable mode.
"""
if not name.startswith("_self_") and not self._self_mutable:
raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
)
super().__setattr__(name, value)
if (
name.startswith("_self_") # wrapper attribute
or self._self_mutable # lock held
# non-persisted state attribute
or name in self.__wrapped__.get_skip_vars()
):
super().__setattr__(name, value)
return
raise ImmutableStateError(
"Background task StateProxy is immutable outside of a context "
"manager. Use `async with self` to modify state."
)
class StateUpdate(Base):

View File

@ -1577,7 +1577,7 @@ def mock_app(monkeypatch, state_manager: StateManager) -> rx.App:
setattr(app_module, CompileVars.APP, app)
app.state = TestState
app.state_manager = state_manager
app._state_manager = state_manager
app.event_namespace.emit = AsyncMock() # type: ignore
monkeypatch.setattr(prerequisites, "get_app", lambda: app_module)
return app
@ -1663,6 +1663,15 @@ class BackgroundTaskState(State):
order: List[str] = []
dict_list: Dict[str, List[int]] = {"foo": [1, 2, 3]}
@rx.var
def computed_order(self) -> List[str]:
"""Get the order as a computed var.
Returns:
The value of 'order' var.
"""
return self.order
@rx.background
async def background_task(self):
"""A background task that updates the state."""
@ -1791,6 +1800,10 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
"background_task:start",
"other",
],
"computed_order": [
"background_task:start",
"other",
],
}
}
)
@ -1800,7 +1813,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
await task
assert not mock_app.background_tasks
assert (await mock_app.state_manager.get_state(token)).order == [
exp_order = [
"background_task:start",
"other",
"background_task:stop",
@ -1808,6 +1821,50 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
"private",
]
assert (await mock_app.state_manager.get_state(token)).order == exp_order
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
assert json.loads(emit_mock.mock_calls[0].args[1]) == {
"delta": {
"background_task_state": {
"order": ["background_task:start"],
"computed_order": ["background_task:start"],
}
},
"events": [],
"final": True,
}
for call in emit_mock.mock_calls[1:5]:
assert json.loads(call.args[1]) == {
"delta": {
"background_task_state": {"computed_order": ["background_task:start"]}
},
"events": [],
"final": True,
}
assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
"delta": {
"background_task_state": {
"order": exp_order,
"computed_order": exp_order,
"dict_list": {},
}
},
"events": [],
"final": True,
}
assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
"delta": {
"background_task_state": {
"computed_order": exp_order,
},
},
"events": [],
"final": True,
}
@pytest.mark.asyncio
async def test_background_task_reset(mock_app: rx.App, token: str):