diff --git a/docs/requests.md b/docs/requests.md index d63055532..1fde29f84 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -122,6 +122,20 @@ async with request.form(max_files=1000, max_fields=1000, max_part_size=1024*1024 ... ``` +You can configure maximum [spooled file](https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile) size per file uploaded with the parameter `max_file_size`: + +```python +async with request.form(max_spool_size=100*1024*1024): # 100 MB limit per file + ... +``` + +You can configure maximum size (on disk) of all files uploaded with the parameter `max_files_size`: + +```python +async with request.form(max_file_size=1024*1024*1024): # 1 GB limit per file + ... +``` + !!! info These limits are for security reasons, allowing an unlimited number of fields or files could lead to a denial of service attack by consuming a lot of CPU and memory parsing too many empty fields. diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 4551d6887..99d8ebe3e 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -135,6 +135,8 @@ def __init__( max_files: int | float = 1000, max_fields: int | float = 1000, max_part_size: int = 1024 * 1024, # 1MB + max_spool_size: int = 1024 * 1024, # 1MB + max_file_size: int = 1024 * 1024 * 1024, # 1GB ) -> None: assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." self.headers = headers @@ -152,6 +154,8 @@ def __init__( self._file_parts_to_finish: list[MultipartPart] = [] self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] self.max_part_size = max_part_size + self.spool_max_size = max_spool_size + self.max_file_size = max_file_size def on_part_begin(self) -> None: self._current_part = MultipartPart() @@ -259,6 +263,8 @@ async def parse(self) -> FormData: # the main thread. for part, data in self._file_parts_to_write: assert part.file # for type checkers + if part.file.size is not None and part.file.size + len(data) > self.max_file_size: + raise MultiPartException(f"File exceeded maximum size of {int(self.max_file_size / 1024)}KB.") await part.file.write(data) for part in self._file_parts_to_finish: assert part.file # for type checkers diff --git a/starlette/requests.py b/starlette/requests.py index 7dc04a746..e2fa0a352 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -255,6 +255,8 @@ async def _get_form( max_files: int | float = 1000, max_fields: int | float = 1000, max_part_size: int = 1024 * 1024, + max_spool_size: int = 1024 * 1024, + max_file_size: int = 1024 * 1024 * 1024, # 1 GB ) -> FormData: if self._form is None: # pragma: no branch assert parse_options_header is not None, ( @@ -271,6 +273,8 @@ async def _get_form( max_files=max_files, max_fields=max_fields, max_part_size=max_part_size, + max_spool_size=max_spool_size, + max_file_size=max_file_size, ) self._form = await multipart_parser.parse() except MultiPartException as exc: @@ -290,9 +294,17 @@ def form( max_files: int | float = 1000, max_fields: int | float = 1000, max_part_size: int = 1024 * 1024, + max_spool_size: int = 1024 * 1024, + max_file_size: int = 1024 * 1024 * 1024, # 1 GB ) -> AwaitableOrContextManager[FormData]: return AwaitableOrContextManagerWrapper( - self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size) + self._get_form( + max_files=max_files, + max_fields=max_fields, + max_part_size=max_part_size, + max_spool_size=max_spool_size, + max_file_size=max_file_size, + ) ) async def close(self) -> None: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index b18fd6c40..dfd1bcf24 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -4,6 +4,7 @@ import typing from contextlib import nullcontext as does_not_raise from pathlib import Path +from secrets import token_bytes import pytest @@ -127,6 +128,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: return app +def make_app_max_file_size(max_file_size: int) -> ASGIApp: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope, receive) + await request.form(max_file_size=max_file_size) + + return app + + def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) @@ -739,3 +748,27 @@ def test_max_part_size_exceeds_custom_limit( response = client.post("/", content=multipart_data, headers=headers) assert response.status_code == 400 assert response.text == "Part exceeded maximum size of 10KB." + + +@pytest.mark.parametrize( + "app,expectation", + [ + (make_app_max_file_size(1024), pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=make_app_max_file_size(1024))]), does_not_raise()), + ], +) +def test_max_part_file_size_raise( + tmpdir: Path, + app: ASGIApp, + expectation: typing.ContextManager[Exception], + test_client_factory: TestClientFactory, +) -> None: + path = os.path.join(tmpdir, "test.txt") + with open(path, "wb") as file: + file.write(token_bytes(1024 + 1)) + + client = test_client_factory(app) + with open(path, "rb") as f, expectation: + response = client.post("/", files={"test": f}) + assert response.status_code == 400 + assert response.text == "File exceeded maximum size of 1KB."