From 8b9e306ae1f1e6dc55cf0c73a6ea2eb5ecaf0af4 Mon Sep 17 00:00:00 2001 From: Paulo Costa Date: Wed, 25 Jun 2025 19:45:30 -0300 Subject: [PATCH] Support returning Not Modified responses in FileResponse --- starlette/responses.py | 52 ++++++++++++++++++++++++++++++---- starlette/staticfiles.py | 61 ++-------------------------------------- 2 files changed, 49 insertions(+), 64 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 031633b15..1bacdd08b 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -10,7 +10,7 @@ import warnings from collections.abc import AsyncIterable, Awaitable, Iterable, Mapping, Sequence from datetime import datetime -from email.utils import format_datetime, formatdate +from email.utils import format_datetime, formatdate, parsedate from functools import partial from mimetypes import guess_type from secrets import token_hex @@ -297,6 +297,15 @@ def __init__(self, max_size: int) -> None: class FileResponse(Response): chunk_size = 64 * 1024 + NOT_MODIFIED_HEADERS = { + b"cache-control", + b"content-location", + b"date", + b"etag", + b"expires", + b"vary", + } + def __init__( self, path: str | os.PathLike[str], @@ -362,12 +371,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: stat_result = self.stat_result headers = Headers(scope=scope) + http_if_none_match = headers.get("if-none-match") + http_if_modified_since = headers.get("if-modified-since") http_range = headers.get("range") http_if_range = headers.get("if-range") - if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)): - await self._handle_simple(send, send_header_only, send_pathsend) - else: + if self.status_code == 200 and self._is_not_modified(http_if_none_match, http_if_modified_since): + await self._handle_not_modified(send) + elif self.status_code == 200 and http_range is not None and self._should_use_range(http_if_range): try: ranges = self._parse_range_header(http_range, stat_result.st_size) except MalformedRangeHeader as exc: @@ -381,6 +392,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only) else: await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only) + else: + await self._handle_simple(send, send_header_only, send_pathsend) if self.background is not None: await self.background() @@ -399,6 +412,11 @@ async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend more_body = len(chunk) == self.chunk_size await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) + async def _handle_not_modified(self, send: Send) -> None: + headers = [(k, v) for k, v in self.raw_headers if k in FileResponse.NOT_MODIFIED_HEADERS] + await send({"type": "http.response.start", "status": 304, "headers": headers}) + await send({"type": "http.response.body", "body": b"", "more_body": False}) + async def _handle_single_range( self, send: Send, start: int, end: int, file_size: int, send_header_only: bool ) -> None: @@ -452,8 +470,30 @@ async def _handle_multiple_ranges( } ) - def _should_use_range(self, http_if_range: str) -> bool: - return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"] + def _is_not_modified(self, http_if_none_match: str | None, http_if_modified_since: str | None) -> bool: + """ + Given the request and response headers, return `True` if an HTTP + "Not Modified" response could be returned instead. + """ + if http_if_none_match is not None: + match = [tag.strip(" W/") for tag in http_if_none_match.split(",")] + etag = self.headers["etag"] + return etag in match # Client already has the version with current tag + + if http_if_modified_since: + since = parsedate(http_if_modified_since) + last_modified = parsedate(self.headers["last-modified"]) + if since is not None and last_modified is not None: + return since >= last_modified + + return False + + def _should_use_range(self, http_if_range: str | None) -> bool: + return http_if_range in ( + None, + self.headers["last-modified"], + self.headers["etag"], + ) @staticmethod def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]: diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 7fba9aa95..b13d3fbd6 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -4,14 +4,13 @@ import importlib.util import os import stat -from email.utils import parsedate from typing import Union import anyio import anyio.to_thread from starlette._utils import get_route_path -from starlette.datastructures import URL, Headers +from starlette.datastructures import URL from starlette.exceptions import HTTPException from starlette.responses import FileResponse, RedirectResponse, Response from starlette.types import Receive, Scope, Send @@ -19,23 +18,6 @@ PathLike = Union[str, "os.PathLike[str]"] -class NotModifiedResponse(Response): - NOT_MODIFIED_HEADERS = ( - "cache-control", - "content-location", - "date", - "etag", - "expires", - "vary", - ) - - def __init__(self, headers: Headers): - super().__init__( - status_code=304, - headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS}, - ) - - class StaticFiles: def __init__( self, @@ -126,7 +108,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: if stat_result and stat.S_ISREG(stat_result.st_mode): # We have a static file to serve. - return self.file_response(full_path, stat_result, scope) + return FileResponse(full_path, stat_result=stat_result) elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html: # We're in HTML mode, and have got a directory URL. @@ -139,7 +121,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: url = URL(scope=scope) url = url.replace(path=url.path + "/") return RedirectResponse(url=url) - return self.file_response(full_path, stat_result, scope) + return FileResponse(full_path, stat_result=stat_result) if self.html: # Check for '404.html' if we're in HTML mode. @@ -166,20 +148,6 @@ def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]: continue return "", None - def file_response( - self, - full_path: PathLike, - stat_result: os.stat_result, - scope: Scope, - status_code: int = 200, - ) -> Response: - request_headers = Headers(scope=scope) - - response = FileResponse(full_path, status_code=status_code, stat_result=stat_result) - if self.is_not_modified(response.headers, request_headers): - return NotModifiedResponse(response.headers) - return response - async def check_config(self) -> None: """ Perform a one-off configuration check that StaticFiles is actually @@ -195,26 +163,3 @@ async def check_config(self) -> None: raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.") if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)): raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.") - - def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool: - """ - Given the request and response headers, return `True` if an HTTP - "Not Modified" response could be returned instead. - """ - try: - if_none_match = request_headers["if-none-match"] - etag = response_headers["etag"] - if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]: - return True - except KeyError: - pass - - try: - if_modified_since = parsedate(request_headers["if-modified-since"]) - last_modified = parsedate(response_headers["last-modified"]) - if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified: - return True - except KeyError: - pass - - return False