Skip to content
Open
12 changes: 12 additions & 0 deletions starlette/_fileio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

import os
from typing import Protocol, runtime_checkable


@runtime_checkable # pragma: no cover
class AsyncFileIO(Protocol):
    """Structural (duck-typed) interface for an asynchronous binary file object.

    Intended to match async file wrappers such as anyio's
    ``SpooledTemporaryFile`` (see formparsers, which passes one to
    ``UploadFile``). ``@runtime_checkable`` enables ``isinstance()`` /
    ``issubclass()`` checks, which verify method *presence* only — not
    signatures or return types.
    """

    async def read(self, size: int = -1) -> bytes:
        """Read and return up to *size* bytes; -1 conventionally means read to EOF."""
        ...

    async def write(self, data: bytes) -> int:
        """Write *data*, returning the number of bytes written."""
        ...

    async def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
        """Move the file position (io-style *whence* semantics) and return the new position."""
        ...

    async def aclose(self) -> None:
        """Asynchronously close the file."""
        ...
47 changes: 7 additions & 40 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from shlex import shlex
from typing import (
Any,
BinaryIO,
NamedTuple,
TypeVar,
Union,
cast,
)
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

from starlette.concurrency import run_in_threadpool
from starlette._fileio import AsyncFileIO
from starlette.types import Scope


Expand Down Expand Up @@ -417,7 +416,7 @@ class UploadFile:

def __init__(
self,
file: BinaryIO,
file: AsyncFileIO,
*,
size: int | None = None,
filename: str | None = None,
Expand All @@ -428,55 +427,23 @@ def __init__(
self.size = size
self.headers = headers or Headers()

# Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
# Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
self._max_mem_size = getattr(self.file, "_max_size", 0)

@property
def content_type(self) -> str | None:
return self.headers.get("content-type", None)

@property
def _in_memory(self) -> bool:
# check for SpooledTemporaryFile._rolled
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk

def _will_roll(self, size_to_add: int) -> bool:
# If we're not in_memory then we will always roll
if not self._in_memory:
return True

# Check for SpooledTemporaryFile._max_size
future_size = self.file.tell() + size_to_add
return bool(future_size > self._max_mem_size) if self._max_mem_size else False

async def write(self, data: bytes) -> None:
new_data_len = len(data)
if self.size is not None:
self.size += new_data_len

if self._will_roll(new_data_len):
await run_in_threadpool(self.file.write, data)
else:
self.file.write(data)
self.size += len(data)
await self.file.write(data)

async def read(self, size: int = -1) -> bytes:
if self._in_memory:
return self.file.read(size)
return await run_in_threadpool(self.file.read, size)
return await self.file.read(size)

async def seek(self, offset: int) -> None:
if self._in_memory:
self.file.seek(offset)
else:
await run_in_threadpool(self.file.seek, offset)
await self.file.seek(offset)

async def close(self) -> None:
if self._in_memory:
self.file.close()
else:
await run_in_threadpool(self.file.close)
await self.file.aclose()

def __repr__(self) -> str:
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
Expand Down
15 changes: 8 additions & 7 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack
from dataclasses import dataclass, field
from enum import Enum
from tempfile import SpooledTemporaryFile
from typing import TYPE_CHECKING
from urllib.parse import unquote_plus

from anyio import SpooledTemporaryFile

from starlette.datastructures import FormData, Headers, UploadFile

if TYPE_CHECKING:
Expand Down Expand Up @@ -151,7 +153,7 @@ def __init__(
self._charset = ""
self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: list[MultipartPart] = []
self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
self._files_to_close_on_error: AsyncExitStack = AsyncExitStack()
self.max_part_size = max_part_size

def on_part_begin(self) -> None:
Expand Down Expand Up @@ -207,9 +209,9 @@ def on_headers_finished(self) -> None:
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
self._files_to_close_on_error.append(tempfile)
self._files_to_close_on_error.push_async_callback(tempfile.aclose)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
file=tempfile,
size=0,
filename=filename,
headers=Headers(raw=self._current_part.item_headers),
Expand Down Expand Up @@ -266,10 +268,9 @@ async def parse(self) -> FormData:
await part.file.seek(0)
self._file_parts_to_write.clear()
self._file_parts_to_finish.clear()
except MultiPartException as exc:
except BaseException as exc:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this change?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this change?

I changed it to BaseException assuming it would help ensure cleanup even for cases like CancelledError or KeyboardInterrupt, beyond just MultiPartException.

But perhaps I misunderstood — please let me know if I’m thinking about this the wrong way!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CancelledError being the most likely one

# Close all the files if there was an error.
for file in self._files_to_close_on_error:
file.close()
await self._files_to_close_on_error.aclose()
raise exc

parser.finalize()
Expand Down
68 changes: 68 additions & 0 deletions tests/test__fileio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os

import pytest

from starlette._fileio import AsyncFileIO


class GoodAsyncFile:
    """Minimal stub providing every coroutine method the AsyncFileIO protocol requires."""

    async def read(self, size: int = -1) -> bytes:
        payload = b"ok"
        return payload

    async def write(self, data: bytes) -> int:
        written = len(data)
        return written

    async def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
        new_position = offset
        return new_position

    async def aclose(self) -> None:
        return None


class BadAsyncFile_MissingSeek:
    """Stub that deliberately omits ``seek`` so AsyncFileIO runtime checks reject it."""

    async def read(self, size: int = -1) -> bytes:
        result = b"ok"
        return result

    async def write(self, data: bytes) -> int:
        byte_count = len(data)
        return byte_count

    async def aclose(self) -> None:
        pass


def test_async_fileio_runtime_check_positive() -> None:
    """A class implementing all four coroutine methods satisfies the protocol both ways."""
    implementation = GoodAsyncFile()
    assert issubclass(GoodAsyncFile, AsyncFileIO)
    assert isinstance(implementation, AsyncFileIO)


def test_async_fileio_runtime_check_negative() -> None:
    """Omitting seek() must make both runtime protocol checks fail."""
    incomplete = BadAsyncFile_MissingSeek()
    assert isinstance(incomplete, AsyncFileIO) is False
    assert issubclass(BadAsyncFile_MissingSeek, AsyncFileIO) is False


def test_async_fileio_runtime_check_unrelated_type() -> None:
    """A plain int exposes none of the protocol's methods, so the check fails."""
    number = 123
    assert isinstance(number, AsyncFileIO) is False


def test_async_fileio_runtime_check_typeerror() -> None:
    """issubclass() raises TypeError when its first argument is not a class."""
    # Callable form of pytest.raises: equivalent to the context-manager form.
    pytest.raises(TypeError, issubclass, 123, AsyncFileIO)  # type: ignore[arg-type]


@pytest.mark.anyio
async def test_goodasyncfile_methods_execute() -> None:
    """Every GoodAsyncFile coroutine runs to completion and returns its canned value."""
    stub = GoodAsyncFile()
    # The stub is stateless, so call order does not matter.
    assert (await stub.write(b"abc")) == 3
    assert (await stub.seek(5)) == 5
    assert (await stub.read()) == b"ok"
    await stub.aclose()


@pytest.mark.anyio
async def test_badasyncfile_methods_execute() -> None:
    """The coroutines the incomplete stub *does* define still execute normally."""
    incomplete = BadAsyncFile_MissingSeek()
    assert (await incomplete.write(b"abc")) == 3
    assert (await incomplete.read()) == b"ok"
    await incomplete.aclose()
Loading