From 1eb9655b6cf39bc90f16bfc711c986d5ac1f069e Mon Sep 17 00:00:00 2001 From: 11kkw <11kkw@naver.com> Date: Wed, 16 Apr 2025 12:35:04 +0900 Subject: [PATCH 1/7] fix: Use anyio.SpooledTemporaryFile in UploadFile for proper async handling --- starlette/datastructures.py | 31 ++----- starlette/formparsers.py | 7 +- tests/test_datastructures.py | 167 ++++++++++++++++++++--------------- 3 files changed, 109 insertions(+), 96 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index f5d74d25f..7fa0faec9 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -4,7 +4,8 @@ from shlex import shlex from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit -from starlette.concurrency import run_in_threadpool +import anyio + from starlette.types import Scope @@ -413,7 +414,7 @@ class UploadFile: def __init__( self, - file: typing.BinaryIO, + file: anyio.SpooledTemporaryFile[bytes], *, size: int | None = None, filename: str | None = None, @@ -428,37 +429,19 @@ def __init__( def content_type(self) -> str | None: return self.headers.get("content-type", None) - @property - def _in_memory(self) -> bool: - # check for SpooledTemporaryFile._rolled - rolled_to_disk = getattr(self.file, "_rolled", True) - return not rolled_to_disk - async def write(self, data: bytes) -> None: if self.size is not None: self.size += len(data) - - if self._in_memory: - self.file.write(data) - else: - await run_in_threadpool(self.file.write, data) + await self.file.write(data) async def read(self, size: int = -1) -> bytes: - if self._in_memory: - return self.file.read(size) - return await run_in_threadpool(self.file.read, size) + return await self.file.read(size) async def seek(self, offset: int) -> None: - if self._in_memory: - self.file.seek(offset) - else: - await run_in_threadpool(self.file.seek, offset) + await self.file.seek(offset) async def close(self) -> None: - if self._in_memory: - self.file.close() - else: - await run_in_threadpool(self.file.close) + await self.file.aclose() def __repr__(self) -> str: return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})" diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 4551d6887..3d24dfad3 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -3,9 +3,10 @@ import typing from dataclasses import dataclass, field from enum import Enum -from tempfile import SpooledTemporaryFile from urllib.parse import unquote_plus +from anyio import SpooledTemporaryFile + from starlette.datastructures import FormData, Headers, UploadFile if typing.TYPE_CHECKING: @@ -208,7 +209,7 @@ def on_headers_finished(self) -> None: tempfile = SpooledTemporaryFile(max_size=self.spool_max_size) self._files_to_close_on_error.append(tempfile) self._current_part.file = UploadFile( - file=tempfile, # type: ignore[arg-type] + file=tempfile, size=0, filename=filename, headers=Headers(raw=self._current_part.item_headers), @@ -268,7 +269,7 @@ async def parse(self) -> FormData: except MultiPartException as exc: # Close all the files if there was an error. 
for file in self._files_to_close_on_error: - file.close() + await file.aclose() raise exc parser.finalize() diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 0e7d35c3c..007119346 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,7 +1,4 @@ -import io -from tempfile import SpooledTemporaryFile -from typing import BinaryIO - +import anyio import pytest from starlette.datastructures import ( @@ -308,29 +305,41 @@ def test_queryparams() -> None: @pytest.mark.anyio async def test_upload_file_file_input() -> None: """Test passing file/stream into the UploadFile constructor""" - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream, size=4) - assert await file.read() == b"data" - assert file.size == 4 - await file.write(b" and more data!") - assert await file.read() == b"" - assert file.size == 19 - await file.seek(0) - assert await file.read() == b"data and more data!" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream, size=4) + try: + assert await file.read() == b"data" + assert file.size == 4 + await file.write(b" and more data!") + assert await file.read() == b"" + assert file.size == 19 + await file.seek(0) + assert await file.read() == b"data and more data!" + finally: + await file.close() @pytest.mark.anyio async def test_upload_file_without_size() -> None: """Test passing file/stream into the UploadFile constructor without size""" - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream) - assert await file.read() == b"data" - assert file.size is None - await file.write(b" and more data!") - assert await file.read() == b"" - assert file.size is None - await file.seek(0) - assert await file.read() == b"data and more data!" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream) + try: + assert await file.read() == b"data" + assert file.size is None + await file.write(b" and more data!") + assert await file.read() == b"" + assert file.size is None + await file.seek(0) + assert await file.read() == b"data and more data!" 
+ finally: + await file.close() @pytest.mark.anyio @@ -339,61 +348,81 @@ async def test_uploadfile_rolling(max_size: int) -> None: """Test that we can r/w to a SpooledTemporaryFile managed by UploadFile before and after it rolls to disk """ - stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment] - max_size=max_size - ) - file = UploadFile(filename="file", file=stream, size=0) - assert await file.read() == b"" - assert file.size == 0 - await file.write(b"data") - assert await file.read() == b"" - assert file.size == 4 - await file.seek(0) - assert await file.read() == b"data" - await file.write(b" more") - assert await file.read() == b"" - assert file.size == 9 - await file.seek(0) - assert await file.read() == b"data more" - assert file.size == 9 - await file.close() - - -def test_formdata() -> None: - stream = io.BytesIO(b"data") - upload = UploadFile(filename="file", file=stream, size=4) - form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) - assert "a" in form - assert "A" not in form - assert "c" not in form - assert form["a"] == "456" - assert form.get("a") == "456" - assert form.get("nope", default=None) is None - assert form.getlist("a") == ["123", "456"] - assert list(form.keys()) == ["a", "b"] - assert list(form.values()) == ["456", upload] - assert list(form.items()) == [("a", "456"), ("b", upload)] - assert len(form) == 2 - assert list(form) == ["a", "b"] - assert dict(form) == {"a": "456", "b": upload} - assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" - assert FormData(form) == form - assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) - assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"} + async with anyio.SpooledTemporaryFile(max_size=max_size) as stream: + file = UploadFile(filename="file", file=stream, size=0) + try: + assert await file.read() == b"" + assert file.size == 0 + await file.write(b"data") + assert await file.read() == b"" + assert file.size == 4 + await file.seek(0) + assert await file.read() == b"data" + await file.write(b" more") + assert await file.read() == b"" + assert file.size == 9 + await file.seek(0) + assert await file.read() == b"data more" + assert file.size == 9 + finally: + await file.close() + + +@pytest.mark.anyio +async def test_formdata() -> None: + async with anyio.SpooledTemporaryFile(max_size=1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + upload = UploadFile(filename="file", file=stream, size=4) + + form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) + + assert "a" in form + assert "A" not in form + assert "c" not in form + assert form["a"] == "456" + assert form.get("a") == "456" + assert form.get("nope", default=None) is None + assert form.getlist("a") == ["123", "456"] + assert list(form.keys()) == ["a", "b"] + assert list(form.values()) == ["456", upload] + assert list(form.items()) == [("a", "456"), ("b", upload)] + assert len(form) == 2 + assert list(form) == ["a", "b"] + assert dict(form) == {"a": "456", "b": upload} + assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])" + assert FormData(form) == form + assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")]) + assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"} @pytest.mark.anyio async def test_upload_file_repr() -> None: - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream, size=4) - assert repr(file) == 
"UploadFile(filename='file', size=4, headers=Headers({}))" + """Test the string representation of UploadFile""" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream, size=4) + try: + assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))" + finally: + await file.close() @pytest.mark.anyio async def test_upload_file_repr_headers() -> None: - stream = io.BytesIO(b"data") - file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"})) - assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))" + """Test the string representation of UploadFile with custom headers""" + async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream: + await stream.write(b"data") + await stream.seek(0) + + file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"})) + try: + assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))" + finally: + await file.close() def test_multidict() -> None: From 873e2af417364f97e93a3165e1a6eeacccb546b8 Mon Sep 17 00:00:00 2001 From: 11kkw <11kkw@naver.com> Date: Thu, 8 May 2025 01:52:30 +0900 Subject: [PATCH 2/7] Refactor multipart parsing to use AsyncExitStack for safe file cleanup on error --- starlette/formparsers.py | 45 +++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 3d24dfad3..fa5685549 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from contextlib import AsyncExitStack from dataclasses import dataclass, field from enum import Enum from urllib.parse import unquote_plus @@ -249,28 +250,30 @@ async def parse(self) -> FormData: # Create the parser. parser = multipart.MultipartParser(boundary, callbacks) - try: - # Feed the parser with data from the request. - async for chunk in self.stream: - parser.write(chunk) - # Write file data, it needs to use await with the UploadFile methods - # that call the corresponding file methods *in a threadpool*, - # otherwise, if they were called directly in the callback methods above - # (regular, non-async functions), that would block the event loop in - # the main thread. - for part, data in self._file_parts_to_write: - assert part.file # for type checkers - await part.file.write(data) - for part in self._file_parts_to_finish: - assert part.file # for type checkers - await part.file.seek(0) - self._file_parts_to_write.clear() - self._file_parts_to_finish.clear() - except MultiPartException as exc: - # Close all the files if there was an error. + + async with AsyncExitStack() as stack: for file in self._files_to_close_on_error: - await file.aclose() - raise exc + stack.push_async_callback(file.aclose) + + try: + # Feed the parser with data from the request. + async for chunk in self.stream: + parser.write(chunk) + # Write file data, it needs to use await with the UploadFile methods + # that call the corresponding file methods *in a threadpool*, + # otherwise, if they were called directly in the callback methods above + # (regular, non-async functions), that would block the event loop in + # the main thread. 
+ for part, data in self._file_parts_to_write: + assert part.file # for type checkers + await part.file.write(data) + for part in self._file_parts_to_finish: + assert part.file # for type checkers + await part.file.seek(0) + self._file_parts_to_write.clear() + self._file_parts_to_finish.clear() + except MultiPartException as exc: + raise exc parser.finalize() return FormData(self.items) From c87a49a642810af11c430d275f10bd79fe0cf216 Mon Sep 17 00:00:00 2001 From: 11kkw <11kkw@naver.com> Date: Thu, 8 May 2025 03:15:45 +0900 Subject: [PATCH 3/7] Use AsyncExitStack for safe file cleanup on exception --- starlette/formparsers.py | 47 ++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/starlette/formparsers.py b/starlette/formparsers.py index fa5685549..616e312bc 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -250,30 +250,29 @@ async def parse(self) -> FormData: # Create the parser. parser = multipart.MultipartParser(boundary, callbacks) - - async with AsyncExitStack() as stack: - for file in self._files_to_close_on_error: - stack.push_async_callback(file.aclose) - - try: - # Feed the parser with data from the request. - async for chunk in self.stream: - parser.write(chunk) - # Write file data, it needs to use await with the UploadFile methods - # that call the corresponding file methods *in a threadpool*, - # otherwise, if they were called directly in the callback methods above - # (regular, non-async functions), that would block the event loop in - # the main thread. - for part, data in self._file_parts_to_write: - assert part.file # for type checkers - await part.file.write(data) - for part in self._file_parts_to_finish: - assert part.file # for type checkers - await part.file.seek(0) - self._file_parts_to_write.clear() - self._file_parts_to_finish.clear() - except MultiPartException as exc: - raise exc + try: + # Feed the parser with data from the request. + async for chunk in self.stream: + parser.write(chunk) + # Write file data, it needs to use await with the UploadFile methods + # that call the corresponding file methods *in a threadpool*, + # otherwise, if they were called directly in the callback methods above + # (regular, non-async functions), that would block the event loop in + # the main thread. + for part, data in self._file_parts_to_write: + assert part.file # for type checkers + await part.file.write(data) + for part in self._file_parts_to_finish: + assert part.file # for type checkers + await part.file.seek(0) + self._file_parts_to_write.clear() + self._file_parts_to_finish.clear() + except MultiPartException as exc: + # Close all the files if there was an error. 
+ async with AsyncExitStack() as stack: + for f in self._files_to_close_on_error: + stack.push_async_callback(f.aclose) + raise exc parser.finalize() return FormData(self.items) From 73b68cd0c9d55c69c56758d554edc1144461974d Mon Sep 17 00:00:00 2001 From: 11kkw <11kkw@naver.com> Date: Mon, 12 May 2025 02:44:57 +0900 Subject: [PATCH 4/7] Use AsyncExitStack to cleanup temp files on any error --- starlette/formparsers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 616e312bc..5f4db40df 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -152,7 +152,7 @@ def __init__( self._charset = "" self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = [] self._file_parts_to_finish: list[MultipartPart] = [] - self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] + self._files_to_close_on_error: AsyncExitStack = AsyncExitStack() self.max_part_size = max_part_size def on_part_begin(self) -> None: @@ -208,7 +208,7 @@ def on_headers_finished(self) -> None: raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") filename = _user_safe_decode(options[b"filename"], self._charset) tempfile = SpooledTemporaryFile(max_size=self.spool_max_size) - self._files_to_close_on_error.append(tempfile) + self._files_to_close_on_error.push_async_callback(tempfile.aclose) self._current_part.file = UploadFile( file=tempfile, size=0, @@ -267,11 +267,9 @@ async def parse(self) -> FormData: await part.file.seek(0) self._file_parts_to_write.clear() self._file_parts_to_finish.clear() - except MultiPartException as exc: + except BaseException as exc: # Close all the files if there was an error. - async with AsyncExitStack() as stack: - for f in self._files_to_close_on_error: - stack.push_async_callback(f.aclose) + await self._files_to_close_on_error.aclose() raise exc parser.finalize() From 72e214545408d345286b50b90752df8d5ac510c4 Mon Sep 17 00:00:00 2001 From: 11kkw <11kkw17@gmail.com> Date: Tue, 9 Sep 2025 17:04:30 +0900 Subject: [PATCH 5/7] tests(formparsers): remove rollover thread test and clean up imports The old test depended on SpooledTemporaryFile internals (background thread rollover), which are now handled by anyio. Removed the test, dropped unused imports, and fixed minor ruff formatting issues. 
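
For reference, the behaviour the removed test asserted is now assumed to be
anyio's concern rather than Starlette's: writes that cross max_size trigger a
rollover to disk without blocking the event loop. A minimal sketch of that
assumption (requires anyio >= 4.9, which provides anyio.SpooledTemporaryFile;
demo() is an illustrative name, not part of this series):

    import anyio

    async def demo() -> None:
        async with anyio.SpooledTemporaryFile(max_size=4) as f:
            await f.write(b"12345")  # exceeds max_size; anyio spools to disk
            await f.seek(0)
            assert await f.read() == b"12345"

    anyio.run(demo)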
--- starlette/datastructures.py | 4 +-- starlette/formparsers.py | 1 + tests/test_formparsers.py | 66 ++----------------------------------- 3 files changed, 4 insertions(+), 67 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 5cee0e24a..559b9d954 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -4,7 +4,6 @@ from shlex import shlex from typing import ( Any, - BinaryIO, NamedTuple, TypeVar, Union, @@ -436,9 +435,8 @@ def __init__( @property def content_type(self) -> str | None: return self.headers.get("content-type", None) - + async def write(self, data: bytes) -> None: - new_data_len = len(data) if self.size is not None: self.size += len(data) await self.file.write(data) diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 5e524a9e0..aedefdc80 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -1,4 +1,5 @@ from __future__ import annotations + from collections.abc import AsyncGenerator from contextlib import AsyncExitStack from dataclasses import dataclass, field diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 58f6a0c73..35681e78d 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,20 +1,15 @@ from __future__ import annotations import os -import threading -from collections.abc import Generator from contextlib import AbstractContextManager, nullcontext as does_not_raise -from io import BytesIO from pathlib import Path -from tempfile import SpooledTemporaryFile -from typing import Any, ClassVar -from unittest import mock +from typing import Any import pytest from starlette.applications import Starlette from starlette.datastructures import UploadFile -from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode +from starlette.formparsers import MultiPartException, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount @@ -109,22 +104,6 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) -async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None: - """Helper app to monitor what thread the app was called on. - - This can later be used to validate thread/event loop operations. - """ - request = Request(scope, receive) - - # Make sure we parse the form - await request.form() - await request.close() - - # Send back the current thread id - response = JSONResponse({"thread_ident": threading.current_thread().ident}) - await response(scope, receive, send) - - def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) @@ -324,47 +303,6 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor } -class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]): - """Helper class to track which threads performed the rollover operation. - - This is not threadsafe/multi-test safe. 
- """ - - rollover_threads: ClassVar[set[int | None]] = set() - - def rollover(self) -> None: - ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident) - super().rollover() - - -@pytest.fixture -def mock_spooled_temporary_file() -> Generator[None]: - try: - with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile): - yield - finally: - ThreadTrackingSpooledTemporaryFile.rollover_threads.clear() - - -def test_multipart_request_large_file_rollover_in_background_thread( - mock_spooled_temporary_file: None, test_client_factory: TestClientFactory -) -> None: - """Test that Spooled file rollovers happen in background threads.""" - data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1)) - - client = test_client_factory(app_monitor_thread) - response = client.post("/", files=[("test_large", data)]) - assert response.status_code == 200 - - # Parse the event thread id from the API response and ensure we have one - app_thread_ident = response.json().get("thread_ident") - assert app_thread_ident is not None - - # Ensure the app thread was not the same as the rollover one and that a rollover thread exists - assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads - assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1 - - def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post( From d3b043971aa08ac3d72db68c6e7e187094a5c258 Mon Sep 17 00:00:00 2001 From: 11kkw <11kkw@naver.com> Date: Wed, 10 Sep 2025 02:36:52 +0900 Subject: [PATCH 6/7] types: add AsyncFileIO protocol with minimal async file methods Defines , , , and for use in UploadFile --- starlette/_fileio.py | 12 ++++++++++++ starlette/datastructures.py | 9 ++------- 2 files changed, 14 insertions(+), 7 deletions(-) create mode 100644 starlette/_fileio.py diff --git a/starlette/_fileio.py b/starlette/_fileio.py new file mode 100644 index 000000000..c5a3edd65 --- /dev/null +++ b/starlette/_fileio.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +import os +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class AsyncFileIO(Protocol): + async def read(self, size: int = -1) -> bytes: ... + async def write(self, data: bytes) -> int: ... + async def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: ... + async def aclose(self) -> None: ... diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 559b9d954..75fd33ef2 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -11,8 +11,7 @@ ) from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit -import anyio - +from starlette._fileio import AsyncFileIO from starlette.types import Scope @@ -417,7 +416,7 @@ class UploadFile: def __init__( self, - file: anyio.SpooledTemporaryFile[bytes], + file: AsyncFileIO, *, size: int | None = None, filename: str | None = None, @@ -428,10 +427,6 @@ def __init__( self.size = size self.headers = headers or Headers() - # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks. 
-        # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
-        self._max_mem_size = getattr(self.file, "_max_size", 0)
-
     @property
     def content_type(self) -> str | None:
         return self.headers.get("content-type", None)

From 979840bb0e0c03dd7e90c7ce06dba744805c9050 Mon Sep 17 00:00:00 2001
From: 11kkw <11kkw@naver.com>
Date: Wed, 10 Sep 2025 03:14:41 +0900
Subject: [PATCH 7/7] tests: add coverage for AsyncFileIO protocol

Add dedicated tests for the AsyncFileIO runtime-checkable protocol,
covering positive/negative cases, unrelated types, and method execution
to ensure full coverage. Mark the protocol definition with
"# pragma: no cover" to avoid false negatives from coverage.py.
---
 starlette/_fileio.py  |  2 +-
 tests/test__fileio.py | 68 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 69 insertions(+), 1 deletion(-)
 create mode 100644 tests/test__fileio.py

diff --git a/starlette/_fileio.py b/starlette/_fileio.py
index c5a3edd65..d821296b8 100644
--- a/starlette/_fileio.py
+++ b/starlette/_fileio.py
@@ -4,7 +4,7 @@
 from typing import Protocol, runtime_checkable
 
 
-@runtime_checkable
+@runtime_checkable  # pragma: no cover
 class AsyncFileIO(Protocol):
     async def read(self, size: int = -1) -> bytes: ...
     async def write(self, data: bytes) -> int: ...
diff --git a/tests/test__fileio.py b/tests/test__fileio.py
new file mode 100644
index 000000000..672e217e6
--- /dev/null
+++ b/tests/test__fileio.py
@@ -0,0 +1,68 @@
+import os
+
+import pytest
+
+from starlette._fileio import AsyncFileIO
+
+
+class GoodAsyncFile:
+    async def read(self, size: int = -1) -> bytes:
+        return b"ok"
+
+    async def write(self, data: bytes) -> int:
+        return len(data)
+
+    async def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
+        return offset
+
+    async def aclose(self) -> None:
+        return None
+
+
+class BadAsyncFile_MissingSeek:
+    async def read(self, size: int = -1) -> bytes:
+        return b"ok"
+
+    async def write(self, data: bytes) -> int:
+        return len(data)
+
+    async def aclose(self) -> None:
+        return None
+
+
+def test_async_fileio_runtime_check_positive() -> None:
+    obj = GoodAsyncFile()
+    assert isinstance(obj, AsyncFileIO)
+    assert issubclass(GoodAsyncFile, AsyncFileIO)
+
+
+def test_async_fileio_runtime_check_negative() -> None:
+    obj = BadAsyncFile_MissingSeek()
+    assert not isinstance(obj, AsyncFileIO)
+    assert not issubclass(BadAsyncFile_MissingSeek, AsyncFileIO)
+
+
+def test_async_fileio_runtime_check_unrelated_type() -> None:
+    assert not isinstance(123, AsyncFileIO)
+
+
+def test_async_fileio_runtime_check_typeerror() -> None:
+    with pytest.raises(TypeError):
+        issubclass(123, AsyncFileIO)  # type: ignore[arg-type]
+
+
+@pytest.mark.anyio
+async def test_goodasyncfile_methods_execute() -> None:
+    f = GoodAsyncFile()
+    assert await f.read() == b"ok"
+    assert await f.write(b"abc") == 3
+    assert await f.seek(5) == 5
+    await f.aclose()
+
+
+@pytest.mark.anyio
+async def test_badasyncfile_methods_execute() -> None:
+    f = BadAsyncFile_MissingSeek()
+    assert await f.read() == b"ok"
+    assert await f.write(b"abc") == 3
+    await f.aclose()
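
Editor's note (not part of the patches): with the series applied, UploadFile
accepts any object satisfying the AsyncFileIO protocol, so callers can back it
with an anyio spooled file end to end. A rough usage sketch under that
assumption (demo() is an illustrative name, not Starlette API; the surrounding
async-with closes the underlying stream, so upload.close() is not called here):

    import anyio
    from starlette.datastructures import UploadFile

    async def demo() -> None:
        async with anyio.SpooledTemporaryFile(max_size=1024) as stream:
            upload = UploadFile(filename="example.txt", file=stream, size=0)
            await upload.write(b"hello")  # awaits the async file and bumps upload.size
            await upload.seek(0)
            assert await upload.read() == b"hello"

    anyio.run(demo)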