From 9f7ec2eb512fcc3fe90b43cb9dd9e1d08696bec1 Mon Sep 17 00:00:00 2001
From: Michael Honaker <37811263+HonakerM@users.noreply.github.com>
Date: Mon, 21 Jul 2025 02:24:02 +0900
Subject: [PATCH] Make UploadFile check for future rollover (#2962)

Co-authored-by: Marcelo Trylesinski
---
 starlette/datastructures.py | 22 ++++++++++---
 starlette/formparsers.py    |  6 ++++--
 tests/test_formparsers.py   | 65 ++++++++++++++++++++++++++++++++++++--
 3 files changed, 86 insertions(+), 7 deletions(-)

Index: starlette-0.41.3/starlette/datastructures.py
===================================================================
--- starlette-0.41.3.orig/starlette/datastructures.py
+++ starlette-0.41.3/starlette/datastructures.py
@@ -424,6 +424,10 @@ class UploadFile:
         self.size = size
         self.headers = headers or Headers()
 
+        # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
+        # Note 0 means unlimited, mirroring SpooledTemporaryFile's __init__.
+        self._max_mem_size = getattr(self.file, "_max_size", 0)
+
     @property
     def content_type(self) -> str | None:
         return self.headers.get("content-type", None)
@@ -434,14 +438,24 @@
         rolled_to_disk = getattr(self.file, "_rolled", True)
         return not rolled_to_disk
 
+    def _will_roll(self, size_to_add: int) -> bool:
+        # If we're not in memory we have already rolled, so writes always go to disk.
+        if not self._in_memory:
+            return True
+
+        # Check the pending size against SpooledTemporaryFile._max_size.
+        future_size = self.file.tell() + size_to_add
+        return bool(future_size > self._max_mem_size) if self._max_mem_size else False
+
     async def write(self, data: bytes) -> None:
+        new_data_len = len(data)
         if self.size is not None:
-            self.size += len(data)
+            self.size += new_data_len
 
-        if self._in_memory:
-            self.file.write(data)
-        else:
+        if self._will_roll(new_data_len):
             await run_in_threadpool(self.file.write, data)
+        else:
+            self.file.write(data)
 
     async def read(self, size: int = -1) -> bytes:
         if self._in_memory:
Index: starlette-0.41.3/tests/test_formparsers.py
===================================================================
--- starlette-0.41.3.orig/tests/test_formparsers.py
+++ starlette-0.41.3/tests/test_formparsers.py
@@ -1,15 +1,21 @@
 from __future__ import annotations
 
 import os
+import threading
 import typing
+from collections.abc import Generator
 from contextlib import nullcontext as does_not_raise
+from io import BytesIO
 from pathlib import Path
+from tempfile import SpooledTemporaryFile
+from typing import Any, ClassVar
+from unittest import mock
 
 import pytest
 
 from starlette.applications import Starlette
 from starlette.datastructures import UploadFile
-from starlette.formparsers import MultiPartException, _user_safe_decode
+from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Mount
@@ -104,6 +110,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
     await response(scope, receive, send)
 
 
+async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
+    """Helper app to monitor which thread the app was called on.
+
+    This can later be used to validate thread/event loop operations.
+ """ + request = Request(scope, receive) + + # Make sure we parse the form + await request.form() + await request.close() + + # Send back the current thread id + response = JSONResponse({"thread_ident": threading.current_thread().ident}) + await response(scope, receive, send) + + def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) @@ -303,6 +325,88 @@ def test_multipart_request_mixed_files_a } +class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]): + """Helper class to track which threads performed the rollover operation. + + This is not threadsafe/multi-test safe. + """ + + rollover_threads: ClassVar[set[int | None]] = set() + + def rollover(self) -> None: + ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident) + super().rollover() + + +@pytest.fixture +def mock_spooled_temporary_file() -> Generator[None]: + try: + with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile): + yield + finally: + ThreadTrackingSpooledTemporaryFile.rollover_threads.clear() + + +def test_multipart_request_large_file_rollover_in_background_thread( + mock_spooled_temporary_file: None, test_client_factory: TestClientFactory +) -> None: + """Test that Spooled file rollovers happen in background threads.""" + data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1)) + + client = test_client_factory(app_monitor_thread) + response = client.post("/", files=[("test_large", data)]) + assert response.status_code == 200 + + # Parse the event thread id from the API response and ensure we have one + app_thread_ident = response.json().get("thread_ident") + assert app_thread_ident is not None + + # Ensure the app thread was not the same as the rollover one and that a rollover thread exists + assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads + assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1 + + +class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]): + """Helper class to track which threads performed the rollover operation. + + This is not threadsafe/multi-test safe. 
+ """ + + rollover_threads: ClassVar[set[int | None]] = set() + + def rollover(self) -> None: + ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident) + super().rollover() + + +@pytest.fixture +def mock_spooled_temporary_file() -> Generator[None]: + try: + with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile): + yield + finally: + ThreadTrackingSpooledTemporaryFile.rollover_threads.clear() + + +def test_multipart_request_large_file_rollover_in_background_thread( + mock_spooled_temporary_file: None, test_client_factory: TestClientFactory +) -> None: + """Test that Spooled file rollovers happen in background threads.""" + data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1)) + + client = test_client_factory(app_monitor_thread) + response = client.post("/", files=[("test_large", data)]) + assert response.status_code == 200 + + # Parse the event thread id from the API response and ensure we have one + app_thread_ident = response.json().get("thread_ident") + assert app_thread_ident is not None + + # Ensure the app thread was not the same as the rollover one and that a rollover thread exists + assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads + assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1 + + def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.post( Index: starlette-0.41.3/starlette/formparsers.py =================================================================== --- starlette-0.41.3.orig/starlette/formparsers.py +++ starlette-0.41.3/starlette/formparsers.py @@ -122,7 +122,10 @@ class FormParser: class MultiPartParser: - max_file_size = 1024 * 1024 # 1MB + spool_max_size = 1024 * 1024 # 1MB + """The maximum size of the spooled temporary file used to store file data.""" + max_part_size = 1024 * 1024 # 1MB + """The maximum size of a part in the multipart request.""" max_part_size = 1024 * 1024 # 1MB def __init__( @@ -201,7 +204,7 @@ class MultiPartParser: if self._current_files > self.max_files: raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") filename = _user_safe_decode(options[b"filename"], self._charset) - tempfile = SpooledTemporaryFile(max_size=self.max_file_size) + tempfile = SpooledTemporaryFile(max_size=self.spool_max_size) self._files_to_close_on_error.append(tempfile) self._current_part.file = UploadFile( file=tempfile, # type: ignore[arg-type]