Skip to content
Open
31 changes: 7 additions & 24 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

from starlette.concurrency import run_in_threadpool
import anyio

from starlette.types import Scope


Expand Down Expand Up @@ -413,7 +414,7 @@ class UploadFile:

def __init__(
self,
file: typing.BinaryIO,
file: anyio.SpooledTemporaryFile[bytes],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a base class analogous to BinaryIO? This was not a spoiled temporary file type before on purpose.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because that shouldn't matter for the UploadFile.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because that shouldn't matter for the UploadFile.

Yeah, that makes sense — the original use of BinaryIO was definitely meant to keep things implementation-agnostic, and I get why switching directly to SpooledTemporaryFile might feel too specific.

If you're open to it, I think one possible middle ground could be introducing a small AsyncFile protocol — something simple that defines the expected async methods (read, write, seek, etc). That way, we can preserve the flexibility while still ensuring everything stays fully async.

Happy to adjust the PR in that direction if it sounds reasonable to you!

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable, but maybe that structure already exists in anyio? 🤔

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable, but maybe that structure already exists in anyio? 🤔

Yes, it seems like anyio provides a type called AsyncFile for async file operations

Copy link
Contributor

@graingert graingert May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AsyncFile is a concrete implementation not an abc, probably uh "AsyncFileIO[bytes]" protocol would make sense here.

If it's a protocol we can move it or have it in multiple places and have everything still work

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, I don't think the file parameter should be exposed here. It seems to give users the power to customize the file, but in fact the other parts of UploadFile rely heavily on the implementation of file. If we need to make some abstractions to allow users to customize the file parsing results of multipart, we can use a design similar to baize. But at least in this PR, I think there is no problem with modifying the type in this way.

*,
size: int | None = None,
filename: str | None = None,
Expand All @@ -428,37 +429,19 @@ def __init__(
def content_type(self) -> str | None:
return self.headers.get("content-type", None)

@property
def _in_memory(self) -> bool:
# check for SpooledTemporaryFile._rolled
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk

async def write(self, data: bytes) -> None:
if self.size is not None:
self.size += len(data)

if self._in_memory:
self.file.write(data)
else:
await run_in_threadpool(self.file.write, data)
await self.file.write(data)

async def read(self, size: int = -1) -> bytes:
if self._in_memory:
return self.file.read(size)
return await run_in_threadpool(self.file.read, size)
return await self.file.read(size)

async def seek(self, offset: int) -> None:
if self._in_memory:
self.file.seek(offset)
else:
await run_in_threadpool(self.file.seek, offset)
await self.file.seek(offset)

async def close(self) -> None:
if self._in_memory:
self.file.close()
else:
await run_in_threadpool(self.file.close)
await self.file.aclose()

def __repr__(self) -> str:
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"
Expand Down
7 changes: 4 additions & 3 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import typing
from dataclasses import dataclass, field
from enum import Enum
from tempfile import SpooledTemporaryFile
from urllib.parse import unquote_plus

from anyio import SpooledTemporaryFile

from starlette.datastructures import FormData, Headers, UploadFile

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -208,7 +209,7 @@ def on_headers_finished(self) -> None:
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
file=tempfile,
size=0,
filename=filename,
headers=Headers(raw=self._current_part.item_headers),
Expand Down Expand Up @@ -268,7 +269,7 @@ async def parse(self) -> FormData:
except MultiPartException as exc:
# Close all the files if there was an error.
for file in self._files_to_close_on_error:
file.close()
await file.aclose()
raise exc

parser.finalize()
Expand Down
167 changes: 98 additions & 69 deletions tests/test_datastructures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import io
from tempfile import SpooledTemporaryFile
from typing import BinaryIO

import anyio
import pytest

from starlette.datastructures import (
Expand Down Expand Up @@ -308,29 +305,41 @@ def test_queryparams() -> None:
@pytest.mark.anyio
async def test_upload_file_file_input() -> None:
"""Test passing file/stream into the UploadFile constructor"""
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, size=4)
assert await file.read() == b"data"
assert file.size == 4
await file.write(b" and more data!")
assert await file.read() == b""
assert file.size == 19
await file.seek(0)
assert await file.read() == b"data and more data!"
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
await stream.write(b"data")
await stream.seek(0)

file = UploadFile(filename="file", file=stream, size=4)
try:
assert await file.read() == b"data"
assert file.size == 4
await file.write(b" and more data!")
assert await file.read() == b""
assert file.size == 19
await file.seek(0)
assert await file.read() == b"data and more data!"
finally:
await file.close()


@pytest.mark.anyio
async def test_upload_file_without_size() -> None:
"""Test passing file/stream into the UploadFile constructor without size"""
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream)
assert await file.read() == b"data"
assert file.size is None
await file.write(b" and more data!")
assert await file.read() == b""
assert file.size is None
await file.seek(0)
assert await file.read() == b"data and more data!"
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
await stream.write(b"data")
await stream.seek(0)

file = UploadFile(filename="file", file=stream)
try:
assert await file.read() == b"data"
assert file.size is None
await file.write(b" and more data!")
assert await file.read() == b""
assert file.size is None
await file.seek(0)
assert await file.read() == b"data and more data!"
finally:
await file.close()


@pytest.mark.anyio
Expand All @@ -339,61 +348,81 @@ async def test_uploadfile_rolling(max_size: int) -> None:
"""Test that we can r/w to a SpooledTemporaryFile
managed by UploadFile before and after it rolls to disk
"""
stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment]
max_size=max_size
)
file = UploadFile(filename="file", file=stream, size=0)
assert await file.read() == b""
assert file.size == 0
await file.write(b"data")
assert await file.read() == b""
assert file.size == 4
await file.seek(0)
assert await file.read() == b"data"
await file.write(b" more")
assert await file.read() == b""
assert file.size == 9
await file.seek(0)
assert await file.read() == b"data more"
assert file.size == 9
await file.close()


def test_formdata() -> None:
stream = io.BytesIO(b"data")
upload = UploadFile(filename="file", file=stream, size=4)
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
assert "a" in form
assert "A" not in form
assert "c" not in form
assert form["a"] == "456"
assert form.get("a") == "456"
assert form.get("nope", default=None) is None
assert form.getlist("a") == ["123", "456"]
assert list(form.keys()) == ["a", "b"]
assert list(form.values()) == ["456", upload]
assert list(form.items()) == [("a", "456"), ("b", upload)]
assert len(form) == 2
assert list(form) == ["a", "b"]
assert dict(form) == {"a": "456", "b": upload}
assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
assert FormData(form) == form
assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}
async with anyio.SpooledTemporaryFile(max_size=max_size) as stream:
file = UploadFile(filename="file", file=stream, size=0)
try:
assert await file.read() == b""
assert file.size == 0
await file.write(b"data")
assert await file.read() == b""
assert file.size == 4
await file.seek(0)
assert await file.read() == b"data"
await file.write(b" more")
assert await file.read() == b""
assert file.size == 9
await file.seek(0)
assert await file.read() == b"data more"
assert file.size == 9
finally:
await file.close()


@pytest.mark.anyio
async def test_formdata() -> None:
async with anyio.SpooledTemporaryFile(max_size=1024) as stream:
await stream.write(b"data")
await stream.seek(0)

upload = UploadFile(filename="file", file=stream, size=4)

form = FormData([("a", "123"), ("a", "456"), ("b", upload)])

assert "a" in form
assert "A" not in form
assert "c" not in form
assert form["a"] == "456"
assert form.get("a") == "456"
assert form.get("nope", default=None) is None
assert form.getlist("a") == ["123", "456"]
assert list(form.keys()) == ["a", "b"]
assert list(form.values()) == ["456", upload]
assert list(form.items()) == [("a", "456"), ("b", upload)]
assert len(form) == 2
assert list(form) == ["a", "b"]
assert dict(form) == {"a": "456", "b": upload}
assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
assert FormData(form) == form
assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}


@pytest.mark.anyio
async def test_upload_file_repr() -> None:
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, size=4)
assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))"
"""Test the string representation of UploadFile"""
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
await stream.write(b"data")
await stream.seek(0)

file = UploadFile(filename="file", file=stream, size=4)
try:
assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))"
finally:
await file.close()


@pytest.mark.anyio
async def test_upload_file_repr_headers() -> None:
stream = io.BytesIO(b"data")
file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
"""Test the string representation of UploadFile with custom headers"""
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
await stream.write(b"data")
await stream.seek(0)

file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
try:
assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
finally:
await file.close()


def test_multidict() -> None:
Expand Down