Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from functools import partial
from mimetypes import guess_type
from secrets import token_hex
from typing import Any, Callable, Literal, Union
from typing import IO, Any, Callable, Literal, Union
from urllib.parse import quote

import anyio
Expand Down Expand Up @@ -308,6 +308,7 @@ def __init__(
stat_result: os.stat_result | None = None,
method: str | None = None,
content_disposition_type: str = "attachment",
file: IO[bytes] | None = None,
) -> None:
self.path = path
self.status_code = status_code
Expand All @@ -334,6 +335,8 @@ def __init__(
if stat_result is not None:
self.set_stat_headers(stat_result)

self.file = file

def set_stat_headers(self, stat_result: os.stat_result) -> None:
content_length = str(stat_result.st_size)
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
Expand All @@ -346,7 +349,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
send_header_only: bool = scope["method"].upper() == "HEAD"
send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})
send_pathsend: bool = self.file is None and "http.response.pathsend" in scope.get("extensions", {})

if self.stat_result is None:
try:
Expand Down Expand Up @@ -385,14 +388,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.background is not None:
await self.background()

async def _open_file(self) -> anyio.AsyncFile[bytes]:
if self.file is not None:
return anyio.wrap_file(self.file)
else:
return await anyio.open_file(self.path, mode="rb")

async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
elif send_pathsend:
await send({"type": "http.response.pathsend", "path": str(self.path)})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
async with await self._open_file() as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
Expand All @@ -408,7 +417,7 @@ async def _handle_single_range(
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
async with await self._open_file() as file:
await file.seek(start)
more_body = True
while more_body:
Expand All @@ -435,7 +444,7 @@ async def _handle_multiple_ranges(
if send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
async with await self._open_file() as file:
for start, end in ranges:
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
await file.seek(start)
Expand Down
Loading