fix subprotocol for granian (#4698)
* fix subprotocol for granian * use scope subprotocols * use subprotocols or headers * separate the logic
This commit is contained in:
parent
9e36efbd21
commit
64fb78ac5e
@ -27,6 +27,7 @@ from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
@ -410,20 +411,25 @@ class App(MiddlewareMixin, LifespanMixin):
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
async def __call__(
|
||||
self, scope: MutableMapping[str, Any], receive, send
|
||||
):
|
||||
original_send = send
|
||||
|
||||
async def modified_send(message):
|
||||
headers = dict(scope["headers"])
|
||||
protocol_key = b"sec-websocket-protocol"
|
||||
if (
|
||||
message["type"] == "websocket.accept"
|
||||
and protocol_key in headers
|
||||
):
|
||||
message["headers"] = [
|
||||
*message.get("headers", []),
|
||||
(b"sec-websocket-protocol", headers[protocol_key]),
|
||||
]
|
||||
if message["type"] == "websocket.accept":
|
||||
if scope.get("subprotocols"):
|
||||
# The following *does* say "subprotocol" instead of "subprotocols", intentionally.
|
||||
message["subprotocol"] = scope["subprotocols"][0]
|
||||
|
||||
headers = dict(message.get("headers", []))
|
||||
header_key = b"sec-websocket-protocol"
|
||||
if subprotocol := headers.get(header_key):
|
||||
message["headers"] = [
|
||||
*message.get("headers", []),
|
||||
(header_key, subprotocol),
|
||||
]
|
||||
|
||||
return await original_send(message)
|
||||
|
||||
return await self.app(scope, receive, modified_send)
|
||||
|
Loading…
Reference in New Issue
Block a user