diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 03afdd9c5..0be4973a5 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,6 +1,13 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping +from collections.abc import ( + AsyncGenerator, + AsyncIterable, + Awaitable, + Callable, + Mapping, + MutableMapping, +) from typing import Any, TypeVar import anyio @@ -29,7 +36,7 @@ def __init__(self, scope: Scope, receive: Receive): super().__init__(scope, receive) self._wrapped_rcv_disconnected = False self._wrapped_rcv_consumed = False - self._wrapped_rc_stream = self.stream() + self._wrapped_rcv_stream: AsyncGenerator[bytes, None] | None = None async def wrapped_receive(self) -> Message: # wrapped_rcv state 1: disconnected @@ -80,7 +87,11 @@ async def wrapped_receive(self) -> Message: else: # body() was never called and stream() wasn't consumed try: - stream = self.stream() + stream = self._wrapped_rcv_stream + if stream is None: + stream = self.stream() + self._wrapped_rcv_stream = stream + chunk = await stream.__anext__() self._wrapped_rcv_consumed = self._stream_consumed return { @@ -103,7 +114,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - request = _CachedRequest(scope, receive) + request: _CachedRequest = scope.setdefault("__starlette_CachedRequest__", _CachedRequest(scope, receive)) wrapped_receive = request.wrapped_receive response_sent = anyio.Event() app_exc: Exception | None = None diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 1b0b94760..875b031d1 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -15,7 +15,12 @@ from starlette.middleware import Middleware, _MiddlewareFactory from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import ClientDisconnect, Request -from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse +from starlette.responses import ( + FileResponse, + PlainTextResponse, + Response, + StreamingResponse, +) from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -106,6 +111,72 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None: assert text == "Hello, world!" +@pytest.mark.anyio +async def test_streaming_request( + test_client_factory: TestClientFactory, +) -> None: + async def echo(request: Request) -> Response: + body = await request.body() + return Response( + body, + status_code=200, + media_type=request.headers.get("Content-Type", "text/plain"), + ) + + async def skip_dispatch(request: Request, call_next: RequestResponseEndpoint) -> Response: + return await call_next(request) + + app = Starlette( + routes=[Route("/echo", echo)], + middleware=[Middleware(BaseHTTPMiddleware, dispatch=skip_dispatch)], + ) + + async def receive_generator() -> AsyncGenerator[Message, None]: + yield {"type": "http.request", "body": b"first chunk\n", "more_body": True} + yield {"type": "http.request", "body": b"second chunk", "more_body": False} + yield {"type": "http.disconnect"} + + async def send(message: Message) -> None: + sent.append(message) + + sent: list[Message] = [] + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/echo", + "headers": [[b"Content-Type", "text/plain"]], + } + + receive = receive_generator().__anext__ + + await app(scope, receive, send) + + assert sent == [ + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-length", b"24"), + (b"content-type", b"text/plain; charset=utf-8"), + ], + }, + { + "type": "http.response.body", + "more_body": True, + "body": b"first chunk\nsecond chunk", + }, + # BaseHTTPMiddleware converts responses into StreamingResponses that yield b"" at end + {"type": "http.response.body", "more_body": False, "body": b""}, + ] + + assert await receive() == {"type": "http.disconnect"} + + with pytest.raises(StopAsyncIteration): + await receive() + + def test_state_data_across_multiple_middlewares( test_client_factory: TestClientFactory, ) -> None: @@ -298,7 +369,9 @@ async def send(message: Message) -> None: assert background_task_run.is_set() -def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None: +def test_run_background_tasks_raise_exceptions( + test_client_factory: TestClientFactory, +) -> None: # test for https://github.com/Kludex/starlette/issues/2625 async def sleep_and_set() -> None: