From 64fb78ac5e3310fa4783739f9054468b5913cf06 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 28 Jan 2025 11:46:00 -0800 Subject: [PATCH] fix subprotocol for granian (#4698) * fix subprotocol for granian * use scope subprotocols * use subprotocols or headers * separate the logic --- reflex/app.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 2e9765d21..20c01cba6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -27,6 +27,7 @@ from typing import ( Dict, Generic, List, + MutableMapping, Optional, Set, Type, @@ -410,20 +411,25 @@ class App(MiddlewareMixin, LifespanMixin): def __init__(self, app): self.app = app - async def __call__(self, scope, receive, send): + async def __call__( + self, scope: MutableMapping[str, Any], receive, send + ): original_send = send async def modified_send(message): - headers = dict(scope["headers"]) - protocol_key = b"sec-websocket-protocol" - if ( - message["type"] == "websocket.accept" - and protocol_key in headers - ): - message["headers"] = [ - *message.get("headers", []), - (b"sec-websocket-protocol", headers[protocol_key]), - ] + if message["type"] == "websocket.accept": + if scope.get("subprotocols"): + # The following *does* say "subprotocol" instead of "subprotocols", intentionally. + message["subprotocol"] = scope["subprotocols"][0] + + headers = dict(message.get("headers", [])) + header_key = b"sec-websocket-protocol" + if subprotocol := headers.get(header_key): + message["headers"] = [ + *message.get("headers", []), + (header_key, subprotocol), + ] + return await original_send(message) return await self.app(scope, receive, modified_send)