From 0025ee13820a7d61effd2e5bd4f9792c113d38c0 Mon Sep 17 00:00:00 2001 From: ShreySinha02 Date: Mon, 12 Feb 2024 12:58:48 +0530 Subject: [PATCH] Add handling of Authorization header in CORS requests --- starlette/middleware/cors.py | 7 +++++-- tests/middleware/test_cors.py | 9 +++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 5c9bfa684..c7570c8a4 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -69,6 +69,7 @@ def __init__( self.allow_origin_regex = compiled_allow_origin_regex self.simple_headers = simple_headers self.preflight_headers = preflight_headers + self.allow_credentials = allow_credentials async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": # pragma: no cover @@ -157,11 +158,13 @@ async def send( headers.update(self.simple_headers) origin = request_headers["Origin"] has_cookie = "cookie" in request_headers + has_authorization = "Authorization" in request_headers # If request includes any cookie headers, then we must respond # with the specific origin instead of '*'. - if self.allow_all_origins and has_cookie: - self.allow_explicit_origin(headers, origin) + if self.allow_all_origins: + if self.allow_credentials and has_authorization or has_cookie: + self.allow_explicit_origin(headers, origin) # If we only allow specific origins, then we have to mirror back # the Origin header in the response. diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 09ec9513f..9a39030d2 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -66,6 +66,15 @@ def homepage(request: Request) -> PlainTextResponse: assert response.headers["access-control-expose-headers"] == "X-Status" assert response.headers["access-control-allow-credentials"] == "true" + # Test Authorization order + headers = {"Origin": "https://example.org", "Authorization": "Some_token"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + # Test non-CORS response response = client.get("/") assert response.status_code == 200