Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = _CachedRequest(scope, receive)
wrapped_receive = request.wrapped_receive
response_sent = anyio.Event()
app_exc: Exception | None = None

async def call_next(request: Request) -> Response:
app_exc: Exception | None = None
Expand Down Expand Up @@ -165,9 +166,6 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
if not message.get("more_body", False):
break

if app_exc is not None:
raise app_exc

response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info)
response.raw_headers = message["headers"]
return response
Expand All @@ -180,6 +178,9 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response_sent.set()
recv_stream.close()

if app_exc is not None:
raise app_exc

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
raise NotImplementedError() # pragma: no cover

Expand Down
23 changes: 23 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,29 @@ async def send(message: Message) -> None:
assert background_task_run.is_set()


def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None:
# test for https://github.com/encode/starlette/issues/2625

async def sleep_and_set() -> None:
await anyio.sleep(0.1)
raise ValueError("TEST")

async def endpoint_with_background_task(_: Request) -> PlainTextResponse:
return PlainTextResponse(background=BackgroundTask(sleep_and_set))

async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
return await call_next(request)

app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)

client = test_client_factory(app)
with pytest.raises(ValueError, match="TEST"):
client.get("/")


@pytest.mark.anyio
async def test_do_not_block_on_background_tasks() -> None:
response_complete = anyio.Event()
Expand Down
Loading