diff --git a/starlette/responses.py b/starlette/responses.py index 031633b15..f847dc1a6 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -14,7 +14,7 @@ from functools import partial from mimetypes import guess_type from secrets import token_hex -from typing import Any, Callable, Literal, Union +from typing import IO, Any, Callable, Literal, Union from urllib.parse import quote import anyio @@ -308,6 +308,7 @@ def __init__( stat_result: os.stat_result | None = None, method: str | None = None, content_disposition_type: str = "attachment", + file: IO[bytes] | None = None, ) -> None: self.path = path self.status_code = status_code @@ -334,6 +335,8 @@ def __init__( if stat_result is not None: self.set_stat_headers(stat_result) + self.file = file + def set_stat_headers(self, stat_result: os.stat_result) -> None: content_length = str(stat_result.st_size) last_modified = formatdate(stat_result.st_mtime, usegmt=True) @@ -346,7 +349,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: send_header_only: bool = scope["method"].upper() == "HEAD" - send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {}) + send_pathsend: bool = self.file is None and "http.response.pathsend" in scope.get("extensions", {}) if self.stat_result is None: try: @@ -385,6 +388,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.background is not None: await self.background() + async def _open_file(self) -> anyio.AsyncFile[bytes]: + if self.file is not None: + return anyio.wrap_file(self.file) + else: + return await anyio.open_file(self.path, mode="rb") + async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None: await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) if send_header_only: @@ -392,7 +401,7 @@ async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend elif send_pathsend: await send({"type": "http.response.pathsend", "path": str(self.path)}) else: - async with await anyio.open_file(self.path, mode="rb") as file: + async with await self._open_file() as file: more_body = True while more_body: chunk = await file.read(self.chunk_size) @@ -408,7 +417,7 @@ async def _handle_single_range( if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: - async with await anyio.open_file(self.path, mode="rb") as file: + async with await self._open_file() as file: await file.seek(start) more_body = True while more_body: @@ -435,7 +444,7 @@ async def _handle_multiple_ranges( if send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: - async with await anyio.open_file(self.path, mode="rb") as file: + async with await self._open_file() as file: for start, end in ranges: await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True}) await file.seek(start)