diff --git a/docs/middleware.md b/docs/middleware.md
index aa9d119fe..aecffc3bf 100644
--- a/docs/middleware.md
+++ b/docs/middleware.md
@@ -239,9 +239,27 @@ The following arguments are supported:
 
 * `minimum_size` - Do not GZip responses that are smaller than this minimum size in bytes. Defaults to `500`.
 * `compresslevel` - Used during GZip compression. It is an integer ranging from 1 to 9. Defaults to `9`. Lower value results in faster compression but larger file sizes, while higher value results in slower compression but smaller file sizes.
+* `excluded_content_types` - A tuple of content type prefixes that should not be compressed. Defaults to `("text/event-stream", "application/zip", "application/gzip", "application/x-gzip", "image/", "video/", "audio/")`. You can customize this to add or remove content types as needed.
 
-The middleware won't GZip responses that already have either a `Content-Encoding` set, to prevent them from
-being encoded twice, or a `Content-Type` set to `text/event-stream`, to avoid compressing server-sent events.
+The middleware won't GZip responses that:
+
+* Already have a `Content-Encoding` set, to prevent them from being encoded twice
+* Have a `Content-Type` that starts with any of the prefixes in `excluded_content_types`
+
+By default, the following content types are excluded:
+
+* `text/event-stream` - Server-sent events should not be compressed
+* Already compressed formats: `application/zip`, `application/gzip`, `application/x-gzip`
+* Media files: `image/*`, `video/*`, `audio/*` (any image, video, or audio format)
+
+### Customizing excluded content types
+
+Pass your own tuple of content type prefixes to replace the defaults:
+```python
+middleware = [
+    Middleware(GZipMiddleware, minimum_size=1000, compresslevel=9, excluded_content_types=("text/event-stream", "image/"))
+]
+```
 
 ## BaseHTTPMiddleware
 
diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py
index abd898b2d..c3507554c 100644
--- a/starlette/middleware/gzip.py
+++ b/starlette/middleware/gzip.py
@@ -5,14 +5,31 @@
 from starlette.datastructures import Headers, MutableHeaders
 from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
-DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
+# Content-type prefixes that are not compressed by default: server-sent events,
+# already-compressed archives, and image/video/audio media.
+DEFAULT_EXCLUDED_CONTENT_TYPES = (
+    "text/event-stream",
+    "application/zip",
+    "application/gzip",
+    "application/x-gzip",
+    "image/",
+    "video/",
+    "audio/",
+)
 
 
 class GZipMiddleware:
-    def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
+    def __init__(
+        self,
+        app: ASGIApp,
+        minimum_size: int = 500,
+        compresslevel: int = 9,
+        excluded_content_types: tuple[str, ...] = DEFAULT_EXCLUDED_CONTENT_TYPES,
+    ) -> None:
         self.app = app
         self.minimum_size = minimum_size
         self.compresslevel = compresslevel
+        self.excluded_content_types = excluded_content_types
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         if scope["type"] != "http":  # pragma: no cover
@@ -22,9 +39,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         headers = Headers(scope=scope)
         responder: ASGIApp
         if "gzip" in headers.get("Accept-Encoding", ""):
-            responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
+            responder = GZipResponder(
+                self.app, self.minimum_size, self.excluded_content_types, compresslevel=self.compresslevel
+            )
         else:
-            responder = IdentityResponder(self.app, self.minimum_size)
+            responder = IdentityResponder(self.app, self.minimum_size, self.excluded_content_types)
 
         await responder(scope, receive, send)
 
@@ -32,7 +51,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
 class IdentityResponder:
     content_encoding: str
 
-    def __init__(self, app: ASGIApp, minimum_size: int) -> None:
+    def __init__(
+        self,
+        app: ASGIApp,
+        minimum_size: int,
+        excluded_content_types: tuple[str, ...],
+    ) -> None:
         self.app = app
         self.minimum_size = minimum_size
         self.send: Send = unattached_send
@@ -40,6 +64,8 @@ def __init__(self, app: ASGIApp, minimum_size: int) -> None:
         self.started = False
         self.content_encoding_set = False
         self.content_type_is_excluded = False
+        # str.startswith() only accepts a str or a tuple of str, so normalise once here.
+        self.excluded_content_types = tuple(excluded_content_types)
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         self.send = send
@@ -53,7 +79,7 @@ async def send_with_compression(self, message: Message) -> None:
             self.initial_message = message
             headers = Headers(raw=self.initial_message["headers"])
             self.content_encoding_set = "content-encoding" in headers
-            self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
+            self.content_type_is_excluded = headers.get("content-type", "").startswith(self.excluded_content_types)
         elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
             if not self.started:
                 self.started = True
@@ -119,8 +145,10 @@ def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
 class GZipResponder(IdentityResponder):
     content_encoding = "gzip"
 
-    def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
-        super().__init__(app, minimum_size)
+    def __init__(
+        self, app: ASGIApp, minimum_size: int, excluded_content_types: tuple[str, ...], compresslevel: int = 9
+    ) -> None:
+        super().__init__(app, minimum_size, excluded_content_types)
         self.gzip_buffer = io.BytesIO()
         self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
 
diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py
index 3a2f2e0f3..344492a59 100644
--- a/tests/middleware/test_gzip.py
+++ b/tests/middleware/test_gzip.py
@@ -8,7 +8,7 @@
 from starlette.middleware import Middleware
 from starlette.middleware.gzip import GZipMiddleware
 from starlette.requests import Request
-from starlette.responses import ContentStream, FileResponse, PlainTextResponse, StreamingResponse
+from starlette.responses import ContentStream, FileResponse, PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route
 from starlette.types import Message
 from tests.types import TestClientFactory
@@ -163,6 +163,37 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
     assert "Content-Length" not in response.headers
 
 
+@pytest.mark.parametrize(
+    "content_type,content",
+    [
+        ("image/png", b"\x89PNG\r\n\x1a\n" + b"x" * 1000),
+        ("image/jpeg", b"\xff\xd8\xff" + b"x" * 1000),
+        ("video/mp4", b"x" * 1000),
+        ("audio/mpeg", b"x" * 1000),
+        ("application/zip", b"PK\x03\x04" + b"x" * 1000),
+        ("application/gzip", b"\x1f\x8b" + b"x" * 1000),
+        ("application/x-gzip", b"\x1f\x8b" + b"x" * 1000),
+    ],
+)
+def test_gzip_ignored_on_compressed_content_types(
+    test_client_factory: TestClientFactory, content_type: str, content: bytes
+) -> None:
+    def endpoint(request: Request) -> Response:
+        return Response(content, status_code=200, media_type=content_type)
+
+    app = Starlette(
+        routes=[Route("/", endpoint=endpoint)],
+        middleware=[Middleware(GZipMiddleware)],
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/", headers={"accept-encoding": "gzip"})
+    assert response.status_code == 200
+    assert response.content == content
+    assert "Content-Encoding" not in response.headers
+    assert int(response.headers["Content-Length"]) == len(content)
+
+
 @pytest.mark.anyio
 async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None:
     path = tmpdir / "example.txt"