diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index e135c7c0b..304bc92de 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => { if (socket) { socket.emit( "event", - JSON.stringify(event, (k, v) => (v === undefined ? null : v)) + event, ); return true; } @@ -407,6 +407,8 @@ export const connect = async ( transports: transports, autoUnref: false, }); + // Ensure undefined fields in events are sent as null instead of removed + socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v) function checkVisibility() { if (document.visibilityState === "visible") { @@ -443,8 +445,7 @@ export const connect = async ( }); // On each received message, queue the updates and events. - socket.current.on("event", async (message) => { - const update = JSON5.parse(message); + socket.current.on("event", async (update) => { for (const substate in update.delta) { dispatch[substate](update.delta[substate]); } @@ -456,7 +457,7 @@ export const connect = async ( }); socket.current.on("reload", async (event) => { event_processing = false; - queueEvents([...initialEvents(), JSON5.parse(event)], socket); + queueEvents([...initialEvents(), event], socket); }); document.addEventListener("visibilitychange", checkVisibility); @@ -497,23 +498,31 @@ export const uploadFiles = async ( // Whenever called, responseText will contain the entire response so far. const chunks = progressEvent.event.target.responseText.trim().split("\n"); // So only process _new_ chunks beyond resp_idx. - chunks.slice(resp_idx).map((chunk) => { - event_callbacks.map((f, ix) => { - f(chunk) - .then(() => { - if (ix === event_callbacks.length - 1) { - // Mark this chunk as processed. - resp_idx += 1; - } - }) - .catch((e) => { - if (progressEvent.progress === 1) { - // Chunk may be incomplete, so only report errors when full response is available. - console.log("Error parsing chunk", chunk, e); - } - return; - }); - }); + chunks.slice(resp_idx).map((chunk_json) => { + try { + const chunk = JSON5.parse(chunk_json); + event_callbacks.map((f, ix) => { + f(chunk) + .then(() => { + if (ix === event_callbacks.length - 1) { + // Mark this chunk as processed. + resp_idx += 1; + } + }) + .catch((e) => { + if (progressEvent.progress === 1) { + // Chunk may be incomplete, so only report errors when full response is available. + console.log("Error processing chunk", chunk, e); + } + return; + }); + }); + } catch (e) { + if (progressEvent.progress === 1) { + console.log("Error parsing chunk", chunk_json, e); + } + return; + } }); }; diff --git a/reflex/app.py b/reflex/app.py index 42808823a..67bb203fa 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -17,6 +17,7 @@ import sys import traceback from datetime import datetime from pathlib import Path +from types import SimpleNamespace from typing import ( TYPE_CHECKING, Any, @@ -363,6 +364,10 @@ class App(MiddlewareMixin, LifespanMixin): max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, ping_interval=constants.Ping.INTERVAL, ping_timeout=constants.Ping.TIMEOUT, + json=SimpleNamespace( + dumps=staticmethod(format.json_dumps), + loads=staticmethod(json.loads), + ), transports=["websocket"], ) elif getattr(self.sio, "async_mode", "") != "asgi": @@ -1543,7 +1548,7 @@ class EventNamespace(AsyncNamespace): """ # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) + self.emit(str(constants.SocketEvent.EVENT), update, to=sid) ) async def on_event(self, sid, data): @@ -1556,7 +1561,7 @@ class EventNamespace(AsyncNamespace): sid: The Socket.IO session id. data: The event data. """ - fields = json.loads(data) + fields = data # Get the event. event = Event( **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 0159a17c3..1d6671a0b 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -664,18 +664,22 @@ def format_library_name(library_fullname: str): return lib -def json_dumps(obj: Any) -> str: +def json_dumps(obj: Any, **kwargs) -> str: """Takes an object and returns a jsonified string. Args: obj: The object to be serialized. + kwargs: Additional keyword arguments to pass to json.dumps. Returns: A string """ from reflex.utils import serializers - return json.dumps(obj, ensure_ascii=False, default=serializers.serialize) + kwargs.setdefault("ensure_ascii", False) + kwargs.setdefault("default", serializers.serialize) + + return json.dumps(obj, **kwargs) def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index a580f9d74..9e952e10f 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1840,6 +1840,24 @@ async def test_state_manager_lock_expire_contend( assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 +class CopyingAsyncMock(AsyncMock): + """An AsyncMock, but deepcopy the args and kwargs first.""" + + def __call__(self, *args, **kwargs): + """Call the mock. + + Args: + args: the arguments passed to the mock + kwargs: the keyword arguments passed to the mock + + Returns: + The result of the mock call + """ + args = copy.deepcopy(args) + kwargs = copy.deepcopy(kwargs) + return super().__call__(*args, **kwargs) + + @pytest.fixture(scope="function") def mock_app_simple(monkeypatch) -> rx.App: """Simple Mock app fixture. @@ -1856,7 +1874,7 @@ def mock_app_simple(monkeypatch) -> rx.App: setattr(app_module, CompileVars.APP, app) app.state = TestState - app.event_namespace.emit = AsyncMock() # type: ignore + app.event_namespace.emit = CopyingAsyncMock() # type: ignore def _mock_get_app(*args, **kwargs): return app_module @@ -1960,21 +1978,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): mock_app.event_namespace.emit.assert_called_once() mcall = mock_app.event_namespace.emit.mock_calls[0] assert mcall.args[0] == str(SocketEvent.EVENT) - assert json.loads(mcall.args[1]) == dataclasses.asdict( - StateUpdate( - delta={ - parent_state.get_full_name(): { - "upper": "", - "sum": 3.14, - }, - grandchild_state.get_full_name(): { - "value2": "42", - }, - GrandchildState3.get_full_name(): { - "computed": "", - }, - } - ) + assert mcall.args[1] == StateUpdate( + delta={ + parent_state.get_full_name(): { + "upper": "", + "sum": 3.14, + }, + grandchild_state.get_full_name(): { + "value2": "42", + }, + GrandchildState3.get_full_name(): { + "computed": "", + }, + } ) assert mcall.kwargs["to"] == grandchild_state.router.session.session_id @@ -2156,51 +2172,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit - first_ws_message = json.loads(emit_mock.mock_calls[0].args[1]) + first_ws_message = emit_mock.mock_calls[0].args[1] assert ( - first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router") + first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router") is not None ) - assert first_ws_message == { - "delta": { + assert first_ws_message == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "order": ["background_task:start"], "computed_order": ["background_task:start"], } }, - "events": [], - "final": True, - } + events=[], + final=True, + ) for call in emit_mock.mock_calls[1:5]: - assert json.loads(call.args[1]) == { - "delta": { + assert call.args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "computed_order": ["background_task:start"], } }, - "events": [], - "final": True, - } - assert json.loads(emit_mock.mock_calls[-2].args[1]) == { - "delta": { + events=[], + final=True, + ) + assert emit_mock.mock_calls[-2].args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "order": exp_order, "computed_order": exp_order, "dict_list": {}, } }, - "events": [], - "final": True, - } - assert json.loads(emit_mock.mock_calls[-1].args[1]) == { - "delta": { + events=[], + final=True, + ) + assert emit_mock.mock_calls[-1].args[1] == StateUpdate( + delta={ BackgroundTaskState.get_full_name(): { "computed_order": exp_order, }, }, - "events": [], - "final": True, - } + events=[], + final=True, + ) @pytest.mark.asyncio