diff --git a/starlette/_fileio.py b/starlette/_fileio.py new file mode 100644 index 000000000..d821296b8 --- /dev/null +++ b/starlette/_fileio.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import os +from typing import Protocol, runtime_checkable + + +@runtime_checkable # pragma: no cover +class AsyncFileIO(Protocol): + async def read(self, size: int = -1) -> bytes: ... + async def write(self, data: bytes) -> int: ... + async def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: ... + async def aclose(self) -> None: ... diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 38eabec52..75fd33ef2 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -4,7 +4,6 @@ from shlex import shlex from typing import ( Any, - BinaryIO, NamedTuple, TypeVar, Union, @@ -12,7 +11,7 @@ ) from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit -from starlette.concurrency import run_in_threadpool +from starlette._fileio import AsyncFileIO from starlette.types import Scope @@ -417,7 +416,7 @@ class UploadFile: def __init__( self, - file: BinaryIO, + file: AsyncFileIO, *, size: int | None = None, filename: str | None = None, @@ -428,55 +427,23 @@ def __init__( self.size = size self.headers = headers or Headers() - # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks. - # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__ - self._max_mem_size = getattr(self.file, "_max_size", 0) - @property def content_type(self) -> str | None: return self.headers.get("content-type", None) - @property - def _in_memory(self) -> bool: - # check for SpooledTemporaryFile._rolled - rolled_to_disk = getattr(self.file, "_rolled", True) - return not rolled_to_disk - - def _will_roll(self, size_to_add: int) -> bool: - # If we're not in_memory then we will always roll - if not self._in_memory: - return True - - # Check for SpooledTemporaryFile._max_size - future_size = self.file.tell() + size_to_add - return bool(future_size > self._max_mem_size) if self._max_mem_size else False - async def write(self, data: bytes) -> None: - new_data_len = len(data) if self.size is not None: - self.size += new_data_len - - if self._will_roll(new_data_len): - await run_in_threadpool(self.file.write, data) - else: - self.file.write(data) + self.size += len(data) + await self.file.write(data) async def read(self, size: int = -1) -> bytes: - if self._in_memory: - return self.file.read(size) - return await run_in_threadpool(self.file.read, size) + return await self.file.read(size) async def seek(self, offset: int) -> None: - if self._in_memory: - self.file.seek(offset) - else: - await run_in_threadpool(self.file.seek, offset) + await self.file.seek(offset) async def close(self) -> None: - if self._in_memory: - self.file.close() - else: - await run_in_threadpool(self.file.close) + await self.file.aclose() def __repr__(self) -> str: return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})" diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 8e389dec7..aedefdc80 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -1,12 +1,14 @@ from __future__ import annotations from collections.abc import AsyncGenerator +from contextlib import AsyncExitStack from dataclasses import dataclass, field from enum import Enum -from tempfile import SpooledTemporaryFile from typing import TYPE_CHECKING from urllib.parse import unquote_plus +from anyio import SpooledTemporaryFile + from starlette.datastructures import FormData, Headers, UploadFile if TYPE_CHECKING: @@ -151,7 +153,7 @@ def __init__( self._charset = "" self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = [] self._file_parts_to_finish: list[MultipartPart] = [] - self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] + self._files_to_close_on_error: AsyncExitStack = AsyncExitStack() self.max_part_size = max_part_size def on_part_begin(self) -> None: @@ -207,9 +209,9 @@ def on_headers_finished(self) -> None: raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") filename = _user_safe_decode(options[b"filename"], self._charset) tempfile = SpooledTemporaryFile(max_size=self.spool_max_size) - self._files_to_close_on_error.append(tempfile) + self._files_to_close_on_error.push_async_callback(tempfile.aclose) self._current_part.file = UploadFile( - file=tempfile, # type: ignore[arg-type] + file=tempfile, size=0, filename=filename, headers=Headers(raw=self._current_part.item_headers), @@ -266,10 +268,9 @@ async def parse(self) -> FormData: await part.file.seek(0) self._file_parts_to_write.clear() self._file_parts_to_finish.clear() - except MultiPartException as exc: + except BaseException as exc: # Close all the files if there was an error. - for file in self._files_to_close_on_error: - file.close() + await self._files_to_close_on_error.aclose() raise exc parser.finalize() diff --git a/tests/test__fileio.py b/tests/test__fileio.py new file mode 100644 index 000000000..672e217e6 --- /dev/null +++ b/tests/test__fileio.py @@ -0,0 +1,68 @@ +import os + +import pytest + +from starlette._fileio import AsyncFileIO + + +class GoodAsyncFile: + async def read(self, size: int = -1) -> bytes: + return b"ok" + + async def write(self, data: bytes) -> int: + return len(data) + + async def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: + return offset + + async def aclose(self) -> None: + return None + + +class BadAsyncFile_MissingSeek: + async def read(self, size: int = -1) -> bytes: + return b"ok" + + async def write(self, data: bytes) -> int: + return len(data) + + async def aclose(self) -> None: + return None + + +def test_async_fileio_runtime_check_positive() -> None: + obj = GoodAsyncFile() + assert isinstance(obj, AsyncFileIO) + assert issubclass(GoodAsyncFile, AsyncFileIO) + + +def test_async_fileio_runtime_check_negative() -> None: + obj = BadAsyncFile_MissingSeek() + assert not isinstance(obj, AsyncFileIO) + assert not issubclass(BadAsyncFile_MissingSeek, AsyncFileIO) + + +def test_async_fileio_runtime_check_unrelated_type() -> None: + assert not isinstance(123, AsyncFileIO) + + +def test_async_fileio_runtime_check_typeerror() -> None: + with pytest.raises(TypeError): + issubclass(123, AsyncFileIO) # type: ignore[arg-type] + + +@pytest.mark.anyio +async def test_goodasyncfile_methods_execute() -> None: + f = GoodAsyncFile() + assert await f.read() == b"ok" + assert await f.write(b"abc") == 3 + assert await f.seek(5) == 5 + await f.aclose() + + +@pytest.mark.anyio +async def test_badasyncfile_methods_execute() -> None: + f = BadAsyncFile_MissingSeek() + assert await f.read() == b"ok" + assert await f.write(b"abc") == 3 + await f.aclose() diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 0e7d35c3c..007119346 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,7 +1,4 @@ -import io -from tempfile import SpooledTemporaryFile -from typing import BinaryIO - +import anyio import pytest from starlette.datastructures import ( @@ -308,29 +305,41 @@ def test_queryparams() -> None: @pytest.mark.anyio async def test_upload_file_file_input() -> None: """Test passing file/stream into the UploadFile constructor""" - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream, size=4) - assert await file.read() == b"data" - assert file.size == 4 - await file.write(b" and more data!") - assert await file.read() == b"" - assert file.size == 19 - await file.seek(0) - assert await file.read() == b"data and more data!" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream, size=4) + try: + assert await file.read() == b"data" + assert file.size == 4 + await file.write(b" and more data!") + assert await file.read() == b"" + assert file.size == 19 + await file.seek(0) + assert await file.read() == b"data and more data!" + finally: + await file.close() @pytest.mark.anyio async def test_upload_file_without_size() -> None: """Test passing file/stream into the UploadFile constructor without size""" - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream) - assert await file.read() == b"data" - assert file.size is None - await file.write(b" and more data!") - assert await file.read() == b"" - assert file.size is None - await file.seek(0) - assert await file.read() == b"data and more data!" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream) + try: + assert await file.read() == b"data" + assert file.size is None + await file.write(b" and more data!") + assert await file.read() == b"" + assert file.size is None + await file.seek(0) + assert await file.read() == b"data and more data!" + finally: + await file.close() @pytest.mark.anyio @@ -339,61 +348,81 @@ async def test_uploadfile_rolling(max_size: int) -> None: """Test that we can r/w to a SpooledTemporaryFile managed by UploadFile before and after it rolls to disk """ - stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment] - max_size=max_size - ) - file = UploadFile(filename="file", file=stream, size=0) - assert await file.read() == b"" - assert file.size == 0 - await file.write(b"data") - assert await file.read() == b"" - assert file.size == 4 - await file.seek(0) - assert await file.read() == b"data" - await file.write(b" more") - assert await file.read() == b"" - assert file.size == 9 - await file.seek(0) - assert await file.read() == b"data more" - assert file.size == 9 - await file.close() - - -def test_formdata() -> None: - stream = io.BytesIO(b"data") - upload = UploadFile(filename="file", file=stream, size=4) - form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) - assert "a" in form - assert "A" not in form - assert "c" not in form - assert form["a"] == "456" - assert form.get("a") == "456" - assert form.get("nope", default=None) is None - assert form.getlist("a") == ["123", "456"] - assert list(form.keys()) == ["a", "b"] - assert list(form.values()) == ["456", upload] - assert list(form.items()) == [("a", "456"), ("b", upload)] - assert len(form) == 2 - assert list(form) == ["a", "b"] - assert dict(form) == {"a": "456", "b": upload} - assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" - assert FormData(form) == form - assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) - assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"} + async with anyio.SpooledTemporaryFile(max_size=max_size) as stream: + file = UploadFile(filename="file", file=stream, size=0) + try: + assert await file.read() == b"" + assert file.size == 0 + await file.write(b"data") + assert await file.read() == b"" + assert file.size == 4 + await file.seek(0) + assert await file.read() == b"data" + await file.write(b" more") + assert await file.read() == b"" + assert file.size == 9 + await file.seek(0) + assert await file.read() == b"data more" + assert file.size == 9 + finally: + await file.close() + + +@pytest.mark.anyio +async def test_formdata() -> None: + async with anyio.SpooledTemporaryFile(max_size=1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + upload = UploadFile(filename="file", file=stream, size=4) + + form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) + + assert "a" in form + assert "A" not in form + assert "c" not in form + assert form["a"] == "456" + assert form.get("a") == "456" + assert form.get("nope", default=None) is None + assert form.getlist("a") == ["123", "456"] + assert list(form.keys()) == ["a", "b"] + assert list(form.values()) == ["456", upload] + assert list(form.items()) == [("a", "456"), ("b", upload)] + assert len(form) == 2 + assert list(form) == ["a", "b"] + assert dict(form) == {"a": "456", "b": upload} + assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" + assert FormData(form) == form + assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) + assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"} @pytest.mark.anyio async def test_upload_file_repr() -> None: - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream, size=4) - assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))" + """Test the string representation of UploadFile""" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream, size=4) + try: + assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))" + finally: + await file.close() @pytest.mark.anyio async def test_upload_file_repr_headers() -> None: - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"})) - assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))" + """Test the string representation of UploadFile with custom headers""" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"})) + try: + assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))" + finally: + await file.close() def test_multidict() -> None: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 58f6a0c73..35681e78d 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,20 +1,15 @@ from __future__ import annotations import os -import threading -from collections.abc import Generator from contextlib import AbstractContextManager, nullcontext as does_not_raise -from io import BytesIO from pathlib import Path -from tempfile import SpooledTemporaryFile -from typing import Any, ClassVar -from unittest import mock +from typing import Any import pytest from starlette.applications import Starlette from starlette.datastructures import UploadFile -from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode +from starlette.formparsers import MultiPartException, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount @@ -109,22 +104,6 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) -async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None: - """Helper app to monitor what thread the app was called on. - - This can later be used to validate thread/event loop operations. - """ - request = Request(scope, receive) - - # Make sure we parse the form - await request.form() - await request.close() - - # Send back the current thread id - response = JSONResponse({"thread_ident": threading.current_thread().ident}) - await response(scope, receive, send) - - def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) @@ -324,47 +303,6 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor } -class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]): - """Helper class to track which threads performed the rollover operation. - - This is not threadsafe/multi-test safe. - """ - - rollover_threads: ClassVar[set[int | None]] = set() - - def rollover(self) -> None: - ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident) - super().rollover() - - -@pytest.fixture -def mock_spooled_temporary_file() -> Generator[None]: - try: - with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile): - yield - finally: - ThreadTrackingSpooledTemporaryFile.rollover_threads.clear() - - -def test_multipart_request_large_file_rollover_in_background_thread( - mock_spooled_temporary_file: None, test_client_factory: TestClientFactory -) -> None: - """Test that Spooled file rollovers happen in background threads.""" - data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1)) - - client = test_client_factory(app_monitor_thread) - response = client.post("/", files=[("test_large", data)]) - assert response.status_code == 200 - - # Parse the event thread id from the API response and ensure we have one - app_thread_ident = response.json().get("thread_ident") - assert app_thread_ident is not None - - # Ensure the app thread was not the same as the rollover one and that a rollover thread exists - assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads - assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1 - - def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post(