diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 577918eb9..dc353a26d 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -156,7 +156,16 @@ async def coro() -> None: if app_exc is not None: nonlocal exception_already_raised exception_already_raised = True - raise app_exc + # Prevent `anyio.EndOfStream` from polluting app exception context. + # If both cause and context are None then the context is suppressed + # and `anyio.EndOfStream` is not present in the exception traceback. + # If exception cause is not None then it is propagated with + # reraising here. + # If exception has no cause but has context set then the context is + # propagated as a cause with the reraise. This is necessary in order + # to prevent `anyio.EndOfStream` from polluting the exception + # context. + raise app_exc from app_exc.__cause__ or app_exc.__context__ raise RuntimeError("No response returned.") assert message["type"] == "http.response.start" diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index d4548e66b..1b0b94760 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1243,3 +1243,62 @@ async def send(message: Message) -> None: assert len(events) == 2 assert events[0]["type"] == "http.response.start" assert events[1]["type"] == "http.response.pathsend" + + +def test_error_context_propagation(test_client_factory: TestClientFactory) -> None: + class PassthroughMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + return await call_next(request) + + def exception_without_context(request: Request) -> None: + raise Exception("Exception") + + def exception_with_context(request: Request) -> None: + try: + raise Exception("Inner exception") + except Exception: + raise Exception("Outer exception") + + def exception_with_cause(request: Request) -> None: + try: + raise Exception("Inner exception") + except Exception as e: + raise Exception("Outer exception") from e + + app = Starlette( + routes=[ + Route("/exception-without-context", endpoint=exception_without_context), + Route("/exception-with-context", endpoint=exception_with_context), + Route("/exception-with-cause", endpoint=exception_with_cause), + ], + middleware=[Middleware(PassthroughMiddleware)], + ) + client = test_client_factory(app) + + # For exceptions without context the context is filled with the `anyio.EndOfStream` + # but it is suppressed therefore not propagated to traceback. + with pytest.raises(Exception) as ctx: + client.get("/exception-without-context") + assert str(ctx.value) == "Exception" + assert ctx.value.__cause__ is None + assert ctx.value.__context__ is not None + assert ctx.value.__suppress_context__ is True + + # For exceptions with context the context is propagated as a cause to avoid + # `anyio.EndOfStream` error from overwriting it. + with pytest.raises(Exception) as ctx: + client.get("/exception-with-context") + assert str(ctx.value) == "Outer exception" + assert ctx.value.__cause__ is not None + assert str(ctx.value.__cause__) == "Inner exception" + + # For exceptions with cause check that it gets correctly propagated. + with pytest.raises(Exception) as ctx: + client.get("/exception-with-cause") + assert str(ctx.value) == "Outer exception" + assert ctx.value.__cause__ is not None + assert str(ctx.value.__cause__) == "Inner exception"