diff --git a/starlette/websockets.py b/starlette/websockets.py index fb76361c8..51a02a9a2 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -182,10 +182,23 @@ async def close(self, code: int = 1000, reason: str | None = None) -> None: async def send_denial_response(self, response: Response) -> None: if "websocket.http.response" in self.scope.get("extensions", {}): - await response(self.scope, self.receive, self.send) + wrapped_send = self._send_wrap(self.send, is_websocket_denial=True) + await response(self.scope, self.receive, wrapped_send) else: raise RuntimeError("The server doesn't support the Websocket Denial Response extension.") + @staticmethod + def _send_wrap(send: Send, is_websocket_denial: bool) -> Send: + async def wrapped(message: Message) -> None: + message_type = message["type"] + if is_websocket_denial: + if not message_type.startswith("websocket."): + message["type"] = "websocket." + message_type + + await send(message) + + return wrapped + class WebSocketClose: def __init__(self, code: int = 1000, reason: str | None = None) -> None: