diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 2f09ac2de..f571500ff 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -227,8 +227,8 @@ export const applyEvent = async (event, socket) => { a.href = eval?.( event.payload.url.replace( "getBackendURL(env.UPLOAD)", - `"${getBackendURL(env.UPLOAD)}"`, - ), + `"${getBackendURL(env.UPLOAD)}"` + ) ); } a.download = event.payload.filename; @@ -341,7 +341,7 @@ export const applyRestEvent = async (event, socket) => { event.payload.files, event.payload.upload_id, event.payload.on_upload_progress, - socket, + socket ); return false; } @@ -408,7 +408,7 @@ export const connect = async ( dispatch, transports, setConnectErrors, - client_storage = {}, + client_storage = {} ) => { // Get backend URL object from the endpoint. const endpoint = getBackendURL(EVENTURL); @@ -419,6 +419,7 @@ export const connect = async ( transports: transports, protocols: [reflexEnvironment.version], autoUnref: false, + query: { token: getToken() }, }); // Ensure undefined fields in events are sent as null instead of removed socket.current.io.encoder.replacer = (k, v) => (v === undefined ? null : v); @@ -479,6 +480,10 @@ export const connect = async ( event_processing = false; queueEvents([...initialEvents(), event], socket); }); + socket.current.on("new_token", async (new_token) => { + token = new_token; + window.sessionStorage.setItem(TOKEN_KEY, new_token); + }); document.addEventListener("visibilitychange", checkVisibility); }; @@ -499,7 +504,7 @@ export const uploadFiles = async ( files, upload_id, on_upload_progress, - socket, + socket ) => { // return if there's no file to upload if (files === undefined || files.length === 0) { @@ -604,7 +609,7 @@ export const Event = ( name, payload = {}, event_actions = {}, - handler = null, + handler = null ) => { return { name, payload, handler, event_actions }; }; @@ -631,7 +636,7 @@ export const hydrateClientStorage = (client_storage) => { for (const state_key in client_storage.local_storage) { const options = client_storage.local_storage[state_key]; const local_storage_value = localStorage.getItem( - options.name || state_key, + options.name || state_key ); if (local_storage_value !== null) { client_storage_values[state_key] = local_storage_value; @@ -642,7 +647,7 @@ export const hydrateClientStorage = (client_storage) => { for (const state_key in client_storage.session_storage) { const session_options = client_storage.session_storage[state_key]; const session_storage_value = sessionStorage.getItem( - session_options.name || state_key, + session_options.name || state_key ); if (session_storage_value != null) { client_storage_values[state_key] = session_storage_value; @@ -667,7 +672,7 @@ export const hydrateClientStorage = (client_storage) => { const applyClientStorageDelta = (client_storage, delta) => { // find the main state and check for is_hydrated const unqualified_states = Object.keys(delta).filter( - (key) => key.split(".").length === 1, + (key) => key.split(".").length === 1 ); if (unqualified_states.length === 1) { const main_state = delta[unqualified_states[0]]; @@ -701,7 +706,7 @@ const applyClientStorageDelta = (client_storage, delta) => { const session_options = client_storage.session_storage[state_key]; sessionStorage.setItem( session_options.name || state_key, - delta[substate][key], + delta[substate][key] ); } } @@ -721,7 +726,7 @@ const applyClientStorageDelta = (client_storage, delta) => { export const useEventLoop = ( dispatch, initial_events = () => [], - client_storage = {}, + client_storage = {} ) => { const socket = useRef(null); const router = useRouter(); @@ -735,7 +740,7 @@ export const useEventLoop = ( event_actions = events.reduce( (acc, e) => ({ ...acc, ...e.event_actions }), - event_actions ?? {}, + event_actions ?? {} ); const _e = args.filter((o) => o?.preventDefault !== undefined)[0]; @@ -763,7 +768,7 @@ export const useEventLoop = ( debounce( combined_name, () => queueEvents(events, socket), - event_actions.debounce, + event_actions.debounce ); } else { queueEvents(events, socket); @@ -782,7 +787,7 @@ export const useEventLoop = ( query, asPath, }))(router), - })), + })) ); sentHydrate.current = true; } @@ -828,7 +833,7 @@ export const useEventLoop = ( dispatch, ["websocket"], setConnectErrors, - client_storage, + client_storage ); } } @@ -876,7 +881,7 @@ export const useEventLoop = ( vars[storage_to_state_map[e.key]] = e.newValue; const event = Event( `${state_name}.reflex___state____update_vars_internal_state.update_vars_internal`, - { vars: vars }, + { vars: vars } ); addEvents([event], e); } @@ -969,7 +974,7 @@ export const getRefValues = (refs) => { return refs.map((ref) => ref.current ? ref.current.value || ref.current.getAttribute("aria-valuenow") - : null, + : null ); }; diff --git a/reflex/app.py b/reflex/app.py index 6e66257b4..b386001f1 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -13,6 +13,8 @@ import io import json import sys import traceback +import urllib.parse +import uuid from datetime import datetime from pathlib import Path from timeit import default_timer as timer @@ -1825,13 +1827,16 @@ class EventNamespace(AsyncNamespace): self.sid_to_token = {} self.app = app - def on_connect(self, sid: str, environ: dict): + async def on_connect(self, sid: str, environ: dict): """Event for when the websocket is connected. Args: sid: The Socket.IO session id. environ: The request information, including HTTP headers. """ + query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING")) + await self.link_token_to_sid(sid, query_params.get("token", [])[0]) + subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL") if subprotocol and subprotocol != constants.Reflex.VERSION: console.warn( @@ -1900,9 +1905,6 @@ class EventNamespace(AsyncNamespace): f"Failed to deserialize event data: {fields}." ) from ex - self.token_to_sid[event.token] = sid - self.sid_to_token[sid] = event.token - # Get the event environment. if self.app.sio is None: raise RuntimeError("Socket.IO is not initialized.") @@ -1935,3 +1937,17 @@ class EventNamespace(AsyncNamespace): """ # Emit the test event. await self.emit(str(constants.SocketEvent.PING), "pong", to=sid) + + async def link_token_to_sid(self, sid: str, token: str): + """Link a token to a session id. + + Args: + sid: The Socket.IO session id. + token: The client token. + """ + if token in self.sid_to_token.values() and sid != self.token_to_sid.get(token): + token = str(uuid.uuid4()) + await self.emit("new_token", token, to=sid) + + self.token_to_sid[token] = sid + self.sid_to_token[sid] = token