diff --git a/reflex/app.py b/reflex/app.py index bd0e932b6..9cac205f1 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -27,6 +27,7 @@ from typing import ( Dict, Generic, List, + MutableMapping, Optional, Set, Type, @@ -410,20 +411,16 @@ class App(MiddlewareMixin, LifespanMixin): def __init__(self, app): self.app = app - async def __call__(self, scope, receive, send): + async def __call__( + self, scope: MutableMapping[str, Any], receive, send + ): original_send = send async def modified_send(message): - headers = dict(scope["headers"]) - protocol_key = b"sec-websocket-protocol" - if ( - message["type"] == "websocket.accept" - and protocol_key in headers - and isinstance( - (subprotocol := headers[protocol_key]), bytes - ) + if message["type"] == "websocket.accept" and ( + subprotocols := scope.get("subprotocols") ): - message["subprotocol"] = subprotocol.decode() + message["subprotocol"] = subprotocols[0] return await original_send(message)