diff --git a/reflex/app.py b/reflex/app.py index 2e9765d21..20c01cba6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -27,6 +27,7 @@ from typing import ( Dict, Generic, List, + MutableMapping, Optional, Set, Type, @@ -410,20 +411,25 @@ class App(MiddlewareMixin, LifespanMixin): def __init__(self, app): self.app = app - async def __call__(self, scope, receive, send): + async def __call__( + self, scope: MutableMapping[str, Any], receive, send + ): original_send = send async def modified_send(message): - headers = dict(scope["headers"]) - protocol_key = b"sec-websocket-protocol" - if ( - message["type"] == "websocket.accept" - and protocol_key in headers - ): - message["headers"] = [ - *message.get("headers", []), - (b"sec-websocket-protocol", headers[protocol_key]), - ] + if message["type"] == "websocket.accept": + if scope.get("subprotocols"): + # The following *does* say "subprotocol" instead of "subprotocols", intentionally. + message["subprotocol"] = scope["subprotocols"][0] + + headers = dict(message.get("headers", [])) + header_key = b"sec-websocket-protocol" + if subprotocol := headers.get(header_key): + message["headers"] = [ + *message.get("headers", []), + (header_key, subprotocol), + ] + return await original_send(message) return await self.app(scope, receive, modified_send)