Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from __future__ import annotations

from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Mapping, MutableMapping
from collections.abc import (
AsyncGenerator,
AsyncIterable,
Awaitable,
Callable,
Mapping,
MutableMapping,
)
from typing import Any, TypeVar

import anyio
Expand Down Expand Up @@ -29,7 +36,7 @@ def __init__(self, scope: Scope, receive: Receive):
super().__init__(scope, receive)
self._wrapped_rcv_disconnected = False
self._wrapped_rcv_consumed = False
self._wrapped_rc_stream = self.stream()
self._wrapped_rcv_stream: AsyncGenerator[bytes, None] | None = None

async def wrapped_receive(self) -> Message:
# wrapped_rcv state 1: disconnected
Expand Down Expand Up @@ -80,7 +87,11 @@ async def wrapped_receive(self) -> Message:
else:
# body() was never called and stream() wasn't consumed
try:
stream = self.stream()
stream = self._wrapped_rcv_stream
if stream is None:
stream = self.stream()
self._wrapped_rcv_stream = stream

chunk = await stream.__anext__()
self._wrapped_rcv_consumed = self._stream_consumed
return {
Expand All @@ -103,7 +114,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request = _CachedRequest(scope, receive)
request: _CachedRequest = scope.setdefault("__starlette_CachedRequest__", _CachedRequest(scope, receive))
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()
app_exc: Exception | None = None
Expand Down
77 changes: 75 additions & 2 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from starlette.middleware import Middleware, _MiddlewareFactory
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import ClientDisconnect, Request
from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse
from starlette.responses import (
FileResponse,
PlainTextResponse,
Response,
StreamingResponse,
)
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -106,6 +111,72 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
assert text == "Hello, world!"


@pytest.mark.anyio
async def test_streaming_request(
test_client_factory: TestClientFactory,
) -> None:
async def echo(request: Request) -> Response:
body = await request.body()
return Response(
body,
status_code=200,
media_type=request.headers.get("Content-Type", "text/plain"),
)

async def skip_dispatch(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)

app = Starlette(
routes=[Route("/echo", echo)],
middleware=[Middleware(BaseHTTPMiddleware, dispatch=skip_dispatch)],
)

async def receive_generator() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"first chunk\n", "more_body": True}
yield {"type": "http.request", "body": b"second chunk", "more_body": False}
yield {"type": "http.disconnect"}

async def send(message: Message) -> None:
sent.append(message)

sent: list[Message] = []

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/echo",
"headers": [[b"Content-Type", "text/plain"]],
}

receive = receive_generator().__anext__

await app(scope, receive, send)

assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-length", b"24"),
(b"content-type", b"text/plain; charset=utf-8"),
],
},
{
"type": "http.response.body",
"more_body": True,
"body": b"first chunk\nsecond chunk",
},
# BaseHTTPMiddleware converts responses into StreamingResponses that yield b"" at end
{"type": "http.response.body", "more_body": False, "body": b""},
]

assert await receive() == {"type": "http.disconnect"}

with pytest.raises(StopAsyncIteration):
await receive()


def test_state_data_across_multiple_middlewares(
test_client_factory: TestClientFactory,
) -> None:
Expand Down Expand Up @@ -298,7 +369,9 @@ async def send(message: Message) -> None:
assert background_task_run.is_set()


def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None:
def test_run_background_tasks_raise_exceptions(
test_client_factory: TestClientFactory,
) -> None:
# test for https://github.com/Kludex/starlette/issues/2625

async def sleep_and_set() -> None:
Expand Down