fix subprotocol for granian (#4698)

* fix subprotocol for granian

* use scope subprotocols

* use subprotocols or headers

* separate the logic
This commit is contained in:
Khaleel Al-Adhami 2025-01-28 11:46:00 -08:00 committed by GitHub
parent 9e36efbd21
commit 64fb78ac5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,6 +27,7 @@ from typing import (
Dict, Dict,
Generic, Generic,
List, List,
MutableMapping,
Optional, Optional,
Set, Set,
Type, Type,
@ -410,20 +411,25 @@ class App(MiddlewareMixin, LifespanMixin):
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
async def __call__(self, scope, receive, send): async def __call__(
self, scope: MutableMapping[str, Any], receive, send
):
original_send = send original_send = send
async def modified_send(message): async def modified_send(message):
headers = dict(scope["headers"]) if message["type"] == "websocket.accept":
protocol_key = b"sec-websocket-protocol" if scope.get("subprotocols"):
if ( # The following *does* say "subprotocol" instead of "subprotocols", intentionally.
message["type"] == "websocket.accept" message["subprotocol"] = scope["subprotocols"][0]
and protocol_key in headers
): headers = dict(message.get("headers", []))
message["headers"] = [ header_key = b"sec-websocket-protocol"
*message.get("headers", []), if subprotocol := headers.get(header_key):
(b"sec-websocket-protocol", headers[protocol_key]), message["headers"] = [
] *message.get("headers", []),
(header_key, subprotocol),
]
return await original_send(message) return await original_send(message)
return await self.app(scope, receive, modified_send) return await self.app(scope, receive, modified_send)