Skip to content

Commit 7d586f7

Browse files
authored
Support for send client disconnect to HTTP (#2732)
1 parent 530cbf6 commit 7d586f7

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

starlette/responses.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from starlette.background import BackgroundTask
2222
from starlette.concurrency import iterate_in_threadpool
2323
from starlette.datastructures import URL, Headers, MutableHeaders
24+
from starlette.requests import ClientDisconnect
2425
from starlette.types import Receive, Scope, Send
2526

2627

@@ -249,14 +250,22 @@ async def stream_response(self, send: Send) -> None:
249250
await send({"type": "http.response.body", "body": b"", "more_body": False})
250251

251252
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
252-
async with anyio.create_task_group() as task_group:
253+
spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split(".")))
253254

254-
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
255-
await func()
256-
task_group.cancel_scope.cancel()
255+
if spec_version >= (2, 4):
256+
try:
257+
await self.stream_response(send)
258+
except OSError:
259+
raise ClientDisconnect()
260+
else:
261+
async with anyio.create_task_group() as task_group:
262+
263+
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
264+
await func()
265+
task_group.cancel_scope.cancel()
257266

258-
task_group.start_soon(wrap, partial(self.stream_response, send))
259-
await wrap(partial(self.listen_for_disconnect, receive))
267+
task_group.start_soon(wrap, partial(self.stream_response, send))
268+
await wrap(partial(self.listen_for_disconnect, receive))
260269

261270
if self.background is not None:
262271
await self.background()

tests/test_responses.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import time
55
from http.cookies import SimpleCookie
66
from pathlib import Path
7-
from typing import Any, AsyncIterator, Iterator
7+
from typing import Any, AsyncGenerator, AsyncIterator, Iterator
88

99
import anyio
1010
import pytest
1111

1212
from starlette import status
1313
from starlette.background import BackgroundTask
1414
from starlette.datastructures import Headers
15-
from starlette.requests import Request
15+
from starlette.requests import ClientDisconnect, Request
1616
from starlette.responses import FileResponse, JSONResponse, RedirectResponse, Response, StreamingResponse
1717
from starlette.testclient import TestClient
1818
from starlette.types import Message, Receive, Scope, Send
@@ -542,6 +542,39 @@ async def stream_indefinitely() -> AsyncIterator[bytes]:
542542
assert not cancel_scope.cancel_called, "Content streaming should stop itself."
543543

544544

545+
@pytest.mark.anyio
546+
async def test_streaming_response_on_client_disconnects() -> None:
547+
chunks = bytearray()
548+
streamed = False
549+
550+
async def receive_disconnect() -> Message:
551+
raise NotImplementedError
552+
553+
async def send(message: Message) -> None:
554+
nonlocal streamed
555+
if message["type"] == "http.response.body":
556+
if not streamed:
557+
chunks.extend(message.get("body", b""))
558+
streamed = True
559+
else:
560+
raise OSError
561+
562+
async def stream_indefinitely() -> AsyncGenerator[bytes, None]:
563+
while True:
564+
await anyio.sleep(0)
565+
yield b"chunk"
566+
567+
stream = stream_indefinitely()
568+
response = StreamingResponse(content=stream)
569+
570+
with anyio.move_on_after(1) as cancel_scope:
571+
with pytest.raises(ClientDisconnect):
572+
await response({"asgi": {"spec_version": "2.4"}}, receive_disconnect, send)
573+
assert not cancel_scope.cancel_called, "Content streaming should stop itself."
574+
assert chunks == b"chunk"
575+
await stream.aclose()
576+
577+
545578
README = """\
546579
# BáiZé
547580

0 commit comments

Comments
 (0)