diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 41dbee446..98dfefd69 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -300,10 +300,7 @@ export const applyEvent = async (event, socket) => { // Send the event to the server. if (socket) { - socket.emit( - "event", - event, - ); + socket.emit("event", event); return true; } @@ -408,9 +405,10 @@ export const connect = async ( path: endpoint["pathname"], transports: transports, autoUnref: false, + query: { token: getToken() }, }); // Ensure undefined fields in events are sent as null instead of removed - socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v) + socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v); function checkVisibility() { if (document.visibilityState === "visible") { @@ -461,6 +459,10 @@ export const connect = async ( event_processing = false; queueEvents([...initialEvents(), event], socket); }); + socket.current.on("new_token", async (new_token) => { + token = new_token; + window.sessionStorage.setItem(TOKEN_KEY, new_token); + }); document.addEventListener("visibilitychange", checkVisibility); }; @@ -488,7 +490,7 @@ export const uploadFiles = async ( return false; } - const upload_ref_name = `__upload_controllers_${upload_id}` + const upload_ref_name = `__upload_controllers_${upload_id}`; if (refs[upload_ref_name]) { console.log("Upload already in progress for ", upload_id); diff --git a/reflex/app.py b/reflex/app.py index 259fcca29..d0c9592c4 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -15,6 +15,7 @@ import multiprocessing import platform import sys import traceback +import uuid from datetime import datetime from pathlib import Path from types import SimpleNamespace @@ -1528,14 +1529,18 @@ class EventNamespace(AsyncNamespace): self.sid_to_token = {} self.app = app - def on_connect(self, sid, environ): + async def on_connect(self, sid, environ): """Event for when the websocket is connected. Args: sid: The Socket.IO session id. environ: The request information, including HTTP headers. """ - pass + query_string = environ.get("QUERY_STRING") + query_params = dict( + qc.split("=") for qc in query_string.split("&") if "=" in qc + ) + await self.link_token_to_sid(sid, query_params.get("token")) def on_disconnect(self, sid): """Event for when the websocket disconnects. @@ -1575,9 +1580,6 @@ class EventNamespace(AsyncNamespace): **{k: v for k, v in fields.items() if k not in ("handler", "event_actions")} ) - self.token_to_sid[event.token] = sid - self.sid_to_token[sid] = event.token - # Get the event environment. if self.app.sio is None: raise RuntimeError("Socket.IO is not initialized.") @@ -1610,3 +1612,17 @@ class EventNamespace(AsyncNamespace): """ # Emit the test event. await self.emit(str(constants.SocketEvent.PING), "pong", to=sid) + + async def link_token_to_sid(self, sid, token): + """Link a token to a session id. + + Args: + sid: The Socket.IO session id. + token: The client token. + """ + if token in self.sid_to_token.values() and sid != self.token_to_sid.get(token): + token = uuid.uuid4().hex + await self.emit("new_token", token, to=sid) + + self.token_to_sid[token] = sid + self.sid_to_token[sid] = token