diff --git a/starlette/responses.py b/starlette/responses.py index d0024760e..cac1ab699 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -242,10 +242,11 @@ async def listen_for_disconnect(self, receive: Receive) -> None: if message["type"] == "http.disconnect": break - async def stream_response(self, send: Send) -> None: + async def stream_response(self, send: Send, scope: Scope) -> None: + prefix = "websocket." if scope.get("type") == "websocket" else "" await send( { - "type": "http.response.start", + "type": prefix + "http.response.start", "status": self.status_code, "headers": self.raw_headers, } @@ -253,16 +254,16 @@ async def stream_response(self, send: Send) -> None: async for chunk in self.body_iterator: if not isinstance(chunk, bytes | memoryview): chunk = chunk.encode(self.charset) - await send({"type": "http.response.body", "body": chunk, "more_body": True}) + await send({"type": prefix + "http.response.body", "body": chunk, "more_body": True}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) + await send({"type": prefix + "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) if spec_version >= (2, 4): try: - await self.stream_response(send) + await self.stream_response(send, scope) except OSError: raise ClientDisconnect() else: @@ -273,7 +274,7 @@ async def wrap(func: Callable[[], Awaitable[None]]) -> None: await func() task_group.cancel_scope.cancel() - task_group.start_soon(wrap, partial(self.stream_response, send)) + task_group.start_soon(wrap, partial(self.stream_response, send, scope)) await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: @@ -341,8 +342,15 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: self.headers.setdefault("etag", etag) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - send_header_only: bool = scope["method"].upper() == "HEAD" - send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {}) + # For WebSocket denial responses, method and pathsend don't apply + if scope.get("type") == "websocket": + send_header_only = False + send_pathsend = False + prefix = "websocket." + else: + send_header_only: bool = scope.get("method", "GET").upper() == "HEAD" + send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {}) + prefix = "" if self.stat_result is None: try: @@ -362,7 +370,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: http_if_range = headers.get("if-range") if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)): - await self._handle_simple(send, send_header_only, send_pathsend) + await self._handle_simple(send, send_header_only, send_pathsend, prefix) else: try: ranges = self._parse_range_header(http_range, stat_result.st_size) @@ -374,35 +382,35 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if len(ranges) == 1: start, end = ranges[0] - await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only) + await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only, prefix) else: - await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only) + await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only, prefix) if self.background is not None: await self.background() - async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None: - await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) + async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool, prefix: str = "") -> None: + await send({"type": prefix + "http.response.start", "status": self.status_code, "headers": self.raw_headers}) if send_header_only: - await send({"type": "http.response.body", "body": b"", "more_body": False}) + await send({"type": prefix + "http.response.body", "body": b"", "more_body": False}) elif send_pathsend: - await send({"type": "http.response.pathsend", "path": str(self.path)}) + await send({"type": prefix + "http.response.pathsend", "path": str(self.path)}) else: async with await anyio.open_file(self.path, mode="rb") as file: more_body = True while more_body: chunk = await file.read(self.chunk_size) more_body = len(chunk) == self.chunk_size - await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) + await send({"type": prefix + "http.response.body", "body": chunk, "more_body": more_body}) async def _handle_single_range( - self, send: Send, start: int, end: int, file_size: int, send_header_only: bool + self, send: Send, start: int, end: int, file_size: int, send_header_only: bool, prefix: str = "" ) -> None: self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}" self.headers["content-length"] = str(end - start) - await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers}) + await send({"type": prefix + "http.response.start", "status": 206, "headers": self.raw_headers}) if send_header_only: - await send({"type": "http.response.body", "body": b"", "more_body": False}) + await send({"type": prefix + "http.response.body", "body": b"", "more_body": False}) else: async with await anyio.open_file(self.path, mode="rb") as file: await file.seek(start) @@ -411,7 +419,7 @@ async def _handle_single_range( chunk = await file.read(min(self.chunk_size, end - start)) start += len(chunk) more_body = len(chunk) == self.chunk_size and start < end - await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) + await send({"type": prefix + "http.response.body", "body": chunk, "more_body": more_body}) async def _handle_multiple_ranges( self, @@ -419,6 +427,7 @@ async def _handle_multiple_ranges( ranges: list[tuple[int, int]], file_size: int, send_header_only: bool, + prefix: str = "", ) -> None: # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes). boundary = token_hex(13) @@ -427,22 +436,22 @@ async def _handle_multiple_ranges( ) self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}" self.headers["content-length"] = str(content_length) - await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers}) + await send({"type": prefix + "http.response.start", "status": 206, "headers": self.raw_headers}) if send_header_only: - await send({"type": "http.response.body", "body": b"", "more_body": False}) + await send({"type": prefix + "http.response.body", "body": b"", "more_body": False}) else: async with await anyio.open_file(self.path, mode="rb") as file: for start, end in ranges: - await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True}) + await send({"type": prefix + "http.response.body", "body": header_generator(start, end), "more_body": True}) await file.seek(start) while start < end: chunk = await file.read(min(self.chunk_size, end - start)) start += len(chunk) - await send({"type": "http.response.body", "body": chunk, "more_body": True}) - await send({"type": "http.response.body", "body": b"\n", "more_body": True}) + await send({"type": prefix + "http.response.body", "body": chunk, "more_body": True}) + await send({"type": prefix + "http.response.body", "body": b"\n", "more_body": True}) await send( { - "type": "http.response.body", + "type": prefix + "http.response.body", "body": f"\n--{boundary}--\n".encode("latin-1"), "more_body": False, } diff --git a/tests/test_websockets.py b/tests/test_websockets.py index e76d8f29b..b3f194581 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -322,6 +322,63 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert exc.value.content == b"foo" +def test_send_denial_response_with_streaming_response(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + from starlette.responses import StreamingResponse + + websocket = WebSocket(scope, receive=receive, send=send) + msg = await websocket.receive() + assert msg == {"type": "websocket.connect"} + + async def generate() -> None: + yield b"hello" + yield b"world" + + response = StreamingResponse(generate(), status_code=403) + await websocket.send_denial_response(response) + + client = test_client_factory(app) + with pytest.raises(WebSocketDenialResponse) as exc: + with client.websocket_connect("/"): + pass # pragma: no cover + assert exc.value.status_code == 403 + assert exc.value.content == b"helloworld" + + +def test_send_denial_response_with_file_response(test_client_factory: TestClientFactory) -> None: + import tempfile + from starlette.responses import FileResponse + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + msg = await websocket.receive() + assert msg == {"type": "websocket.connect"} + + response = FileResponse(scope["app"].file_path, status_code=401) + await websocket.send_denial_response(response) + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f: + f.write("test content") + temp_file = f.name + + try: + async def test_app(scope: Scope, receive: Receive, send: Send) -> None: + scope["app"] = type("App", (), {"file_path": temp_file})() + await app(scope, receive, send) + + client = test_client_factory(test_app) + with pytest.raises(WebSocketDenialResponse) as exc: + with client.websocket_connect("/"): + pass # pragma: no cover + assert exc.value.status_code == 401 + assert exc.value.content == b"test content" + finally: + import os + + os.unlink(temp_file) + + def test_send_response_multi(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send)