diff --git a/docs/exceptions.md b/docs/exceptions.md index ef8e755f6..485f0332e 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -99,6 +99,9 @@ exception_handlers = { } ``` +These handlers are passed to both `ServerErrorMiddleware` and `ExceptionMiddleware`, meaning they +handle unhandled exceptions (e.g., `RuntimeError`) as well as `HTTPException(status_code=500)`. + It's important to notice that in case a [`BackgroundTask`](background.md) raises an exception, it will be handled by the `handle_error` function, but at that point, the response was already sent. In other words, the response created by `handle_error` will be discarded. In case the error happens before the response was sent, then diff --git a/starlette/applications.py b/starlette/applications.py index 721181c0a..253944ea6 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -79,8 +79,7 @@ def build_middleware_stack(self) -> ASGIApp: for key, value in self.exception_handlers.items(): if key in (500, Exception): error_handler = value - else: - exception_handlers[key] = value + exception_handlers[key] = value middleware = ( [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] diff --git a/tests/test_applications.py b/tests/test_applications.py index 20ff06385..dae41b572 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -216,6 +216,29 @@ def test_500(test_client_factory: TestClientFactory) -> None: assert response.json() == {"detail": "Server Error"} +def test_500_status_handler(test_client_factory: TestClientFactory) -> None: + """ + Test that a custom 500 status handler is invoked by ExceptionMiddleware + when an HTTPException with status_code=500 is raised. + """ + + async def custom_500_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse({"detail": "Custom 500"}, status_code=500) + + async def raise_http_500(request: Request) -> None: + raise HTTPException(status_code=500, detail="Server Error") + + app = Starlette( + routes=[Route("/http-500", endpoint=raise_http_500)], + exception_handlers={500: custom_500_handler}, + ) + + client = test_client_factory(app, raise_server_exceptions=False) + response = client.get("/http-500") + assert response.status_code == 500 + assert response.json() == {"detail": "Custom 500"} + + def test_websocket_raise_websocket_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-websocket") as session: response = session.receive()