Avoid double JSON encode/decode for socket.io

socket.io (python and js) already has a built in mechanism for JSON encoding
and decoding messages over the websocket. To use it, we pass a custom `json`
namespace which uses `format.json_dumps` (leveraging reflex serializers) to encode the
messages. This avoids sending a JSON-encoded string of JSON over the wire, and
reduces the number of serialization/deserialization passes over the message
data.

The side benefit is that debugging websocket messages in browser tools displays
the parsed JSON hierarchy and is much easier to work with.
This commit is contained in:
Masen Furer 2024-11-27 13:58:12 -08:00
parent 24ff29f74d
commit 6a4c2a1b9e
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95
4 changed files with 71 additions and 44 deletions

View File

@ -300,7 +300,7 @@ export const applyEvent = async (event, socket) => {
if (socket) { if (socket) {
socket.emit( socket.emit(
"event", "event",
JSON.stringify(event, (k, v) => (v === undefined ? null : v)) event,
); );
return true; return true;
} }
@ -407,6 +407,8 @@ export const connect = async (
transports: transports, transports: transports,
autoUnref: false, autoUnref: false,
}); });
// Ensure undefined fields in events are sent as null instead of removed
socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v)
function checkVisibility() { function checkVisibility() {
if (document.visibilityState === "visible") { if (document.visibilityState === "visible") {
@ -444,7 +446,7 @@ export const connect = async (
// On each received message, queue the updates and events. // On each received message, queue the updates and events.
socket.current.on("event", async (message) => { socket.current.on("event", async (message) => {
const update = JSON5.parse(message); const update = message;
for (const substate in update.delta) { for (const substate in update.delta) {
dispatch[substate](update.delta[substate]); dispatch[substate](update.delta[substate]);
} }

View File

@ -17,6 +17,7 @@ import sys
import traceback import traceback
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -362,6 +363,10 @@ class App(MiddlewareMixin, LifespanMixin):
max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE, max_http_buffer_size=constants.POLLING_MAX_HTTP_BUFFER_SIZE,
ping_interval=constants.Ping.INTERVAL, ping_interval=constants.Ping.INTERVAL,
ping_timeout=constants.Ping.TIMEOUT, ping_timeout=constants.Ping.TIMEOUT,
json=SimpleNamespace(
dumps=staticmethod(format.json_dumps),
loads=staticmethod(json.loads),
),
) )
elif getattr(self.sio, "async_mode", "") != "asgi": elif getattr(self.sio, "async_mode", "") != "asgi":
raise RuntimeError( raise RuntimeError(
@ -1507,7 +1512,7 @@ class EventNamespace(AsyncNamespace):
""" """
# Creating a task prevents the update from being blocked behind other coroutines. # Creating a task prevents the update from being blocked behind other coroutines.
await asyncio.create_task( await asyncio.create_task(
self.emit(str(constants.SocketEvent.EVENT), update.json(), to=sid) self.emit(str(constants.SocketEvent.EVENT), update, to=sid)
) )
async def on_event(self, sid, data): async def on_event(self, sid, data):
@ -1520,7 +1525,7 @@ class EventNamespace(AsyncNamespace):
sid: The Socket.IO session id. sid: The Socket.IO session id.
data: The event data. data: The event data.
""" """
fields = json.loads(data) fields = data
# Get the event. # Get the event.
event = Event( event = Event(
**{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")}

View File

@ -664,18 +664,22 @@ def format_library_name(library_fullname: str):
return lib return lib
def json_dumps(obj: Any) -> str: def json_dumps(obj: Any, **kwargs) -> str:
"""Takes an object and returns a jsonified string. """Takes an object and returns a jsonified string.
Args: Args:
obj: The object to be serialized. obj: The object to be serialized.
kwargs: Additional keyword arguments to pass to json.dumps.
Returns: Returns:
A string A string
""" """
from reflex.utils import serializers from reflex.utils import serializers
return json.dumps(obj, ensure_ascii=False, default=serializers.serialize) kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault("default", serializers.serialize)
return json.dumps(obj, **kwargs)
def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]: def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:

View File

@ -1837,6 +1837,24 @@ async def test_state_manager_lock_expire_contend(
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1 assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
class CopyingAsyncMock(AsyncMock):
"""An AsyncMock, but deepcopy the args and kwargs first."""
def __call__(self, *args, **kwargs):
"""Call the mock.
Args:
args: the arguments passed to the mock
kwargs: the keyword arguments passed to the mock
Returns:
The result of the mock call
"""
args = copy.deepcopy(args)
kwargs = copy.deepcopy(kwargs)
return super().__call__(*args, **kwargs)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mock_app_simple(monkeypatch) -> rx.App: def mock_app_simple(monkeypatch) -> rx.App:
"""Simple Mock app fixture. """Simple Mock app fixture.
@ -1853,7 +1871,7 @@ def mock_app_simple(monkeypatch) -> rx.App:
setattr(app_module, CompileVars.APP, app) setattr(app_module, CompileVars.APP, app)
app.state = TestState app.state = TestState
app.event_namespace.emit = AsyncMock() # type: ignore app.event_namespace.emit = CopyingAsyncMock() # type: ignore
def _mock_get_app(*args, **kwargs): def _mock_get_app(*args, **kwargs):
return app_module return app_module
@ -1957,21 +1975,19 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
mock_app.event_namespace.emit.assert_called_once() mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0] mcall = mock_app.event_namespace.emit.mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[0] == str(SocketEvent.EVENT)
assert json.loads(mcall.args[1]) == dataclasses.asdict( assert mcall.args[1] == StateUpdate(
StateUpdate( delta={
delta={ parent_state.get_full_name(): {
parent_state.get_full_name(): { "upper": "",
"upper": "", "sum": 3.14,
"sum": 3.14, },
}, grandchild_state.get_full_name(): {
grandchild_state.get_full_name(): { "value2": "42",
"value2": "42", },
}, GrandchildState3.get_full_name(): {
GrandchildState3.get_full_name(): { "computed": "",
"computed": "", },
}, }
}
)
) )
assert mcall.kwargs["to"] == grandchild_state.router.session.session_id assert mcall.kwargs["to"] == grandchild_state.router.session.session_id
@ -2149,51 +2165,51 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit emit_mock = mock_app.event_namespace.emit
first_ws_message = json.loads(emit_mock.mock_calls[0].args[1]) first_ws_message = emit_mock.mock_calls[0].args[1]
assert ( assert (
first_ws_message["delta"][BackgroundTaskState.get_full_name()].pop("router") first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None is not None
) )
assert first_ws_message == { assert first_ws_message == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"order": ["background_task:start"], "order": ["background_task:start"],
"computed_order": ["background_task:start"], "computed_order": ["background_task:start"],
} }
}, },
"events": [], events=[],
"final": True, final=True,
} )
for call in emit_mock.mock_calls[1:5]: for call in emit_mock.mock_calls[1:5]:
assert json.loads(call.args[1]) == { assert call.args[1] == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"computed_order": ["background_task:start"], "computed_order": ["background_task:start"],
} }
}, },
"events": [], events=[],
"final": True, final=True,
} )
assert json.loads(emit_mock.mock_calls[-2].args[1]) == { assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"order": exp_order, "order": exp_order,
"computed_order": exp_order, "computed_order": exp_order,
"dict_list": {}, "dict_list": {},
} }
}, },
"events": [], events=[],
"final": True, final=True,
} )
assert json.loads(emit_mock.mock_calls[-1].args[1]) == { assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
"delta": { delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"computed_order": exp_order, "computed_order": exp_order,
}, },
}, },
"events": [], events=[],
"final": True, final=True,
} )
@pytest.mark.asyncio @pytest.mark.asyncio