Avoid double JSON encode/decode for socket.io (#4449)

* Avoid double JSON encode/decode for socket.io

socket.io (python and js) already has a built-in mechanism for JSON encoding
and decoding messages over the websocket. To use it, we pass a custom `json`
namespace which uses `format.json_dumps` (leveraging reflex serializers) to encode the
messages. This avoids sending a JSON-encoded string of JSON over the wire, and
reduces the number of serialization/deserialization passes over the message
data.

The side benefit is that debugging websocket messages in browser tools displays
the parsed JSON hierarchy and is much easier to work with.

* JSON5.parse in on_upload_progress handler responses
This commit is contained in:
Masen Furer 2024-12-12 05:47:23 -08:00 committed by GitHub
parent 053cbe7558
commit a2f14e7713
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 97 additions and 63 deletions

View File

@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => {
if (socket) { if (socket) {
socket.emit( socket.emit(
"event", "event",
JSON.stringify(event, (k, v) => (v === undefined ? null : v)) event,
); );
return true; return true;
} }
@ -407,6 +407,8 @@ export const connect = async (
transports: transports, transports: transports,
autoUnref: false, autoUnref: false,
}); });
// Ensure undefined fields in events are sent as null instead of removed
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
function checkVisibility() { function checkVisibility() {
if (document.visibilityState === "visible") { if (document.visibilityState === "visible") {
@ -443,8 +445,7 @@ export const connect = async (
}); });
// On each received message, queue the updates and events. // On each received message, queue the updates and events.
socket.current.on("event", async (message) => { socket.current.on("event", async (update) => {
const update = JSON5.parse(message);
for (const substate in update.delta) { for (const substate in update.delta) {
dispatch[substate](update.delta[substate]); dispatch[substate](update.delta[substate]);
} }
@ -456,7 +457,7 @@ export const connect = async (
}); });
socket.current.on("reload", async (event) => { socket.current.on("reload", async (event) => {
event_processing = false; event_processing = false;
queueEvents([...initialEvents(), JSON5.parse(event)], socket); queueEvents([...initialEvents(), event], socket);
}); });
document.addEventListener("visibilitychange", checkVisibility); document.addEventListener("visibilitychange", checkVisibility);
@ -497,23 +498,31 @@ export const uploadFiles = async (
// Whenever called, responseText will contain the entire response so far. // Whenever called, responseText will contain the entire response so far.
const chunks = progressEvent.event.target.responseText.trim().split("\n"); const chunks = progressEvent.event.target.responseText.trim().split("\n");
// So only process _new_ chunks beyond resp_idx. // So only process _new_ chunks beyond resp_idx.
chunks.slice(resp_idx).map((chunk) => { chunks.slice(resp_idx).map((chunk_json) => {
event_callbacks.map((f, ix) => { try {
f(chunk) const chunk = JSON5.parse(chunk_json);
.then(() => { event_callbacks.map((f, ix) => {
if (ix === event_callbacks.length - 1) { f(chunk)
// Mark this chunk as processed. .then(() => {
resp_idx += 1; if (ix === event_callbacks.length - 1) {
} // Mark this chunk as processed.
}) resp_idx += 1;
.catch((e) => { }
if (progressEvent.progress === 1) { })
// Chunk may be incomplete, so only report errors when full response is available. .catch((e) => {
console.log("Error parsing chunk", chunk, e); if (progressEvent.progress === 1) {
} // Chunk may be incomplete, so only report errors when full response is available.
return; console.log("Error processing chunk", chunk, e);
}); }
}); return;
});
});
} catch (e) {
if (progressEvent.progress === 1) {
console.log("Error parsing chunk", chunk_json, e);
}
return;
}
}); });
}; };

View File

@ -17,6 +17,7 @@ import sys
import traceback import traceback
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -363,6 +364,10 @@ class App(MiddlewareMixin, LifespanMixin):
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
ping_interval=constants.Ping.INTERVAL, ping_interval=constants.Ping.INTERVAL,
ping_timeout=constants.Ping.TIMEOUT, ping_timeout=constants.Ping.TIMEOUT,
json=SimpleNamespace(
dumps=staticmethod(format.json_dumps),
loads=staticmethod(json.loads),
),
transports=["websocket"], transports=["websocket"],
) )
elif getattr(self.sio, "async_mode", "") != "asgi": elif getattr(self.sio, "async_mode", "") != "asgi":
@ -1543,7 +1548,7 @@ class EventNamespace(AsyncNamespace):
""" """
# Creating a task prevents the update from being blocked behind other coroutines. # Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task( await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
) )
async def on_event(self, sid, data): async def on_event(self, sid, data):
@ -1556,7 +1561,7 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id. sid: The Socket.IO session id.
data: The event data. data: The event data.
""" """
fields = json.loads(data) fields = data
# Get the event. # Get the event.
event = Event( event = Event(
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}

View File

@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
return lib return lib
def json_dumps(obj: Any) -> str: def json_dumps(obj: Any, **kwargs) -> str:
"""Takes an object and returns a jsonified string. """Takes an object and returns a jsonified string.
Args: Args:
obj: The object to be serialized. obj: The object to be serialized.
kwargs: Additional keyword arguments to pass to json.dumps.
Returns: Returns:
A string A string
""" """
from reflex.utils import serializers from reflex.utils import serializers
return json.dumps(obj, ensure_ascii=False, default=serializers.serialize) kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("default", serializers.serialize)
return json.dumps(obj, **kwargs)
def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]: def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:

View File

@ -1840,6 +1840,24 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
class CopyingAsyncMock(AsyncMock):
"""An AsyncMock, but deepcopy the args and kwargs first."""
def __call__(self, *args, **kwargs):
"""Call the mock.
Args:
args: the arguments passed to the mock
kwargs: the keyword arguments passed to the mock
Returns:
The result of the mock call
"""
args = copy.deepcopy(args)
kwargs = copy.deepcopy(kwargs)
return super().__call__(*args, **kwargs)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_app_simple(monkeypatch) -> rx.App: def mock_app_simple(monkeypatch) -> rx.App:
"""Simple Mock app fixture. """Simple Mock app fixture.
@ -1856,7 +1874,7 @@ def mock_app_simple(monkeypatch) -> rx.App:
setattr(app_module, CompileVars.APP, app) setattr(app_module, CompileVars.APP, app)
app.state = TestState app.state = TestState
app.event_namespace.emit = AsyncMock() # type: ignore app.event_namespace.emit = CopyingAsyncMock() # type: ignore
def _mock_get_app(*args, **kwargs): def _mock_get_app(*args, **kwargs):
return app_module return app_module
@ -1960,21 +1978,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
mock_app.event_namespace.emit.assert_called_once() mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0] mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[0] == str(SocketEvent.EVENT)
assert json.loads(mcall.args[1]) == dataclasses.asdict( assert mcall.args[1] == StateUpdate(
StateUpdate( delta={
delta={ parent_state.get_full_name(): {
parent_state.get_full_name(): { "upper": "",
"upper": "", "sum": 3.14,
"sum": 3.14, },
}, grandchild_state.get_full_name(): {
grandchild_state.get_full_name(): { "value2": "42",
"value2": "42", },
}, GrandchildState3.get_full_name(): {
GrandchildState3.get_full_name(): { "computed": "",
"computed": "", },
}, }
}
)
) )
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
@ -2156,51 +2172,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit emit_mock = mock_app.event_namespace.emit
first_ws_message = json.loads(emit_mock.mock_calls[0].args[1]) first_ws_message = emit_mock.mock_calls[0].args[1]
assert ( assert (
first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router") first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None is not None
) )
assert first_ws_message == { assert first_ws_message == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"order": ["background_task:start"], "order": ["background_task:start"],
"computed_order": ["background_task:start"], "computed_order": ["background_task:start"],
} }
}, },
"events": [], events=[],
"final": True, final=True,
} )
for call in emit_mock.mock_calls[1:5]: for call in emit_mock.mock_calls[1:5]:
assert json.loads(call.args[1]) == { assert call.args[1] == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"computed_order": ["background_task:start"], "computed_order": ["background_task:start"],
} }
}, },
"events": [], events=[],
"final": True, final=True,
} )
assert json.loads(emit_mock.mock_calls[-2].args[1]) == { assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"order": exp_order, "order": exp_order,
"computed_order": exp_order, "computed_order": exp_order,
"dict_list": {}, "dict_list": {},
} }
}, },
"events": [], events=[],
"final": True, final=True,
} )
assert json.loads(emit_mock.mock_calls[-1].args[1]) == { assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"computed_order": exp_order, "computed_order": exp_order,
}, },
}, },
"events": [], events=[],
"final": True, final=True,
} )
@pytest.mark.asyncio @pytest.mark.asyncio