Skip to content

Commit

Permalink
Fix CORS headers not set on exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
nielsbox committed Nov 28, 2023
1 parent bbd085b commit a0f2647
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
13 changes: 13 additions & 0 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from connexion.middleware.abstract import SpecMiddleware
from connexion.middleware.context import ContextMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.server_error import ServerErrorMiddleware
from connexion.middleware.lifespan import Lifespan, LifespanMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
Expand Down Expand Up @@ -92,6 +93,17 @@ def replace(self, **changes) -> "_Options":
class MiddlewarePosition(enum.Enum):
"""Positions to insert a middleware"""

BEFORE_EXCEPTION = ExceptionMiddleware
"""Add before the :class:`ExceptionMiddleware`. This is useful if you want your changes to
affect the way exceptions are handled, such as a custom error handler.
Be mindful that security has not yet been applied at this stage.
Additionally, the inserted middleware is positioned before the RoutingMiddleware, so you cannot
leverage any routing information yet and should implement your middleware to work globally
instead of on an operation level.
Usefull for CORS middleware which should be applied before the exception middleware.
"""
BEFORE_SWAGGER = SwaggerUIMiddleware
"""Add before the :class:`SwaggerUIMiddleware`. This is useful if you want your changes to
affect the Swagger UI, such as a path altering middleware that should also alter the paths
Expand Down Expand Up @@ -164,6 +176,7 @@ class ConnexionMiddleware:
provided application."""

default_middlewares = [
ServerErrorMiddleware,
ExceptionMiddleware,
SwaggerUIMiddleware,
RoutingMiddleware,
Expand Down
68 changes: 68 additions & 0 deletions connexion/middleware/server_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import asyncio
import functools
import logging
import typing as t

from starlette.concurrency import run_in_threadpool
from starlette.middleware.errors import (
ServerErrorMiddleware as StarletteServerErrorMiddleware,
)
from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.exceptions import InternalServerError
from connexion.lifecycle import ConnexionRequest, ConnexionResponse
from connexion.types import MaybeAwaitable

logger = logging.getLogger(__name__)


def connexion_wrapper(
handler: t.Callable[
[ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
]
) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]:
"""Wrapper that translates Starlette requests to Connexion requests before passing
them to the error handler, and translates the returned Connexion responses to
Starlette responses."""

@functools.wraps(handler)
async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse:
request = ConnexionRequest.from_starlette_request(request)

if asyncio.iscoroutinefunction(handler):
response = await handler(request, exc) # type: ignore
else:
response = await run_in_threadpool(handler, request, exc)

while asyncio.iscoroutine(response):
response = await response

return StarletteResponse(
content=response.body,
status_code=response.status_code,
media_type=response.mimetype,
headers=response.headers,
)

return wrapper


class ServerErrorMiddleware(StarletteServerErrorMiddleware):
"""Subclass of starlette ServerErrorMiddleware to change handling of Unhandled Server
exceptions to existing connexion behavior."""

def __init__(self, next_app: ASGIApp):
super().__init__(next_app)

@staticmethod
def error_response(
_request: StarletteRequest, exc: Exception
) -> ConnexionResponse:
"""Default handler for any unhandled Exception"""
logger.error("%r", exc, exc_info=exc)
return InternalServerError().to_problem()

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await super().__call__(scope, receive, send)
6 changes: 3 additions & 3 deletions docs/cookbook.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -62,7 +62,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -96,7 +96,7 @@ Starlette. You can add it to your application, ideally in front of the ``Routing
app.add_middleware(
CORSMiddleware,
position=MiddlewarePosition.BEFORE_ROUTING,
position=MiddlewarePosition.BEFORE_EXCEPTION,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down

0 comments on commit a0f2647

Please sign in to comment.