Avoid double JSON encode/decode for socket.io

socket.io (python and js) already has a built in mechanism for JSON encoding
and decoding messages over the websocket. To use it, we pass a custom `json`
namespace which uses `format.json_dumps` (leveraging reflex serializers) to encode the
messages. This avoids sending a JSON-encoded string of JSON over the wire, and
reduces the number of serialization/deserialization passes over the message
data.

The side benefit is that debugging websocket messages in browser tools displays
the parsed JSON hierarchy and is much easier to work with.
This commit is contained in:
Masen Furer 2024-11-27 13:58:12 -08:00
parent 24ff29f74d
commit 6a4c2a1b9e
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
4 changed files with 71 additions and 44 deletions

View File

@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => {
if (socket) {
socket.emit(
"event",
JSON.stringify(event, (k, v) => (v === undefined ? null : v))
event,
);
return true;
}
@ -407,6 +407,8 @@ export const connect = async (
transports: transports,
autoUnref: false,
});
// Ensure undefined fields in events are sent as null instead of removed
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
function checkVisibility() {
if (document.visibilityState === "visible") {
@ -444,7 +446,7 @@ export const connect = async (
// On each received message, queue the updates and events.
socket.current.on("event", async (message) => {
const update = JSON5.parse(message);
const update = message;
for (const substate in update.delta) {
dispatch[substate](update.delta[substate]);
}

View File

@ -17,6 +17,7 @@ import sys
import traceback
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import (
TYPE_CHECKING,
Any,
@ -362,6 +363,10 @@ class App(MiddlewareMixin, LifespanMixin):
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
ping_interval=constants.Ping.INTERVAL,
ping_timeout=constants.Ping.TIMEOUT,
json=SimpleNamespace(
dumps=staticmethod(format.json_dumps),
loads=staticmethod(json.loads),
),
)
elif getattr(self.sio, "async_mode", "") != "asgi":
raise RuntimeError(
@ -1507,7 +1512,7 @@ class EventNamespace(AsyncNamespace):
"""
# Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid)
self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
)
async def on_event(self, sid, data):
@ -1520,7 +1525,7 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id.
data: The event data.
"""
fields = json.loads(data)
fields = data
# Get the event.
event = Event(
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}

View File

@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
return lib
def json_dumps(obj: Any) -> str:
def json_dumps(obj: Any, **kwargs) -> str:
"""Takes an object and returns a jsonified string.
Args:
obj: The object to be serialized.
kwargs: Additional keyword arguments to pass to json.dumps.
Returns:
A string
"""
from reflex.utils import serializers
return json.dumps(obj, ensure_ascii=False, default=serializers.serialize)
kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("default", serializers.serialize)
return json.dumps(obj, **kwargs)
def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:

View File

@ -1837,6 +1837,24 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
class CopyingAsyncMock(AsyncMock):
"""An AsyncMock, but deepcopy the args and kwargs first."""
def __call__(self, *args, **kwargs):
"""Call the mock.
Args:
args: the arguments passed to the mock
kwargs: the keyword arguments passed to the mock
Returns:
The result of the mock call
"""
args = copy.deepcopy(args)
kwargs = copy.deepcopy(kwargs)
return super().__call__(*args, **kwargs)
@pytest.fixture(scope="function")
def mock_app_simple(monkeypatch) -> rx.App:
"""Simple Mock app fixture.
@ -1853,7 +1871,7 @@ def mock_app_simple(monkeypatch) -> rx.App:
setattr(app_module, CompileVars.APP, app)
app.state = TestState
app.event_namespace.emit = AsyncMock() # type: ignore
app.event_namespace.emit = CopyingAsyncMock() # type: ignore
def _mock_get_app(*args, **kwargs):
return app_module
@ -1957,21 +1975,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT)
assert json.loads(mcall.args[1]) == dataclasses.asdict(
StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": "42",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
)
assert mcall.args[1] == StateUpdate(
delta={
parent_state.get_full_name(): {
"upper": "",
"sum": 3.14,
},
grandchild_state.get_full_name(): {
"value2": "42",
},
GrandchildState3.get_full_name(): {
"computed": "",
},
}
)
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
@ -2149,51 +2165,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
first_ws_message = json.loads(emit_mock.mock_calls[0].args[1])
first_ws_message = emit_mock.mock_calls[0].args[1]
assert (
first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router")
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None
)
assert first_ws_message == {
"delta": {
assert first_ws_message == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"order": ["background_task:start"],
"computed_order": ["background_task:start"],
}
},
"events": [],
"final": True,
}
events=[],
final=True,
)
for call in emit_mock.mock_calls[1:5]:
assert json.loads(call.args[1]) == {
"delta": {
assert call.args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"computed_order": ["background_task:start"],
}
},
"events": [],
"final": True,
}
assert json.loads(emit_mock.mock_calls[-2].args[1]) == {
"delta": {
events=[],
final=True,
)
assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"order": exp_order,
"computed_order": exp_order,
"dict_list": {},
}
},
"events": [],
"final": True,
}
assert json.loads(emit_mock.mock_calls[-1].args[1]) == {
"delta": {
events=[],
final=True,
)
assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"computed_order": exp_order,
},
},
"events": [],
"final": True,
}
events=[],
final=True,
)
@pytest.mark.asyncio