diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 5b8046347..93c664ef1 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -408,7 +408,7 @@ export const connect = async ( socket.current = io(endpoint.href, { path: endpoint["pathname"], transports: transports, - protocols: env.TEST_MODE ? undefined : [reflexEnvironment.version], + protocols: [reflexEnvironment.version], autoUnref: false, }); // Ensure undefined fields in events are sent as null instead of removed diff --git a/reflex/app.py b/reflex/app.py index 5ee424719..d432925ab 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -405,7 +405,31 @@ class App(MiddlewareMixin, LifespanMixin): self.sio.register_namespace(self.event_namespace) # Mount the socket app with the API. if self.api: - self.api.mount(str(constants.Endpoint.EVENT), socket_app) + + class HeaderMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + original_send = send + + async def modified_send(message): + headers = dict(scope["headers"]) + protocol_key = b"sec-websocket-protocol" + if ( + message["type"] == "websocket.accept" + and protocol_key in headers + ): + message["headers"] = [ + *message.get("headers", []), + (b"sec-websocket-protocol", headers[protocol_key]), + ] + return await original_send(message) + + return await self.app(scope, receive, modified_send) + + socket_app_with_headers = HeaderMiddleware(socket_app) + self.api.mount(str(constants.Endpoint.EVENT), socket_app_with_headers) # Check the exception handlers self._validate_exception_handlers()