220 lines
8.5 KiB
Diff
From 9f7ec2eb512fcc3fe90b43cb9dd9e1d08696bec1 Mon Sep 17 00:00:00 2001
|
|
From: Michael Honaker <37811263+HonakerM@users.noreply.github.com>
|
|
Date: Mon, 21 Jul 2025 02:24:02 +0900
|
|
Subject: [PATCH] Make UploadFile check for future rollover (#2962)
|
|
|
|
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
|
|
---
|
|
starlette/datastructures.py | 22 ++++++++++---
|
|
tests/test_formparsers.py | 66 +++++++++++++++++++++++++++++++++++--
|
|
2 files changed, 82 insertions(+), 6 deletions(-)
|
|
|
|
Index: starlette-0.41.3/starlette/datastructures.py
|
|
===================================================================
|
|
--- starlette-0.41.3.orig/starlette/datastructures.py
|
|
+++ starlette-0.41.3/starlette/datastructures.py
|
|
@@ -424,6 +424,10 @@ class UploadFile:
|
|
self.size = size
|
|
self.headers = headers or Headers()
|
|
|
|
+ # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
|
|
+ # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
|
|
+ self._max_mem_size = getattr(self.file, "_max_size", 0)
|
|
+
|
|
@property
|
|
def content_type(self) -> str | None:
|
|
return self.headers.get("content-type", None)
|
|
@@ -434,14 +438,24 @@ class UploadFile:
|
|
rolled_to_disk = getattr(self.file, "_rolled", True)
|
|
return not rolled_to_disk
|
|
|
|
+ def _will_roll(self, size_to_add: int) -> bool:
|
|
+ # If we're not in_memory then we will always roll
|
|
+ if not self._in_memory:
|
|
+ return True
|
|
+
|
|
+ # Check for SpooledTemporaryFile._max_size
|
|
+ future_size = self.file.tell() + size_to_add
|
|
+ return bool(future_size > self._max_mem_size) if self._max_mem_size else False
|
|
+
|
|
async def write(self, data: bytes) -> None:
|
|
+ new_data_len = len(data)
|
|
if self.size is not None:
|
|
- self.size += len(data)
|
|
+ self.size += new_data_len
|
|
|
|
- if self._in_memory:
|
|
- self.file.write(data)
|
|
- else:
|
|
+ if self._will_roll(new_data_len):
|
|
await run_in_threadpool(self.file.write, data)
|
|
+ else:
|
|
+ self.file.write(data)
|
|
|
|
async def read(self, size: int = -1) -> bytes:
|
|
if self._in_memory:
|
|
Index: starlette-0.41.3/tests/test_formparsers.py
|
|
===================================================================
|
|
--- starlette-0.41.3.orig/tests/test_formparsers.py
|
|
+++ starlette-0.41.3/tests/test_formparsers.py
|
|
@@ -1,15 +1,21 @@
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
+import threading
|
|
import typing
|
|
+from collections.abc import Generator
|
|
from contextlib import nullcontext as does_not_raise
|
|
+from io import BytesIO
|
|
from pathlib import Path
|
|
+from tempfile import SpooledTemporaryFile
|
|
+from typing import Any, ClassVar
|
|
+from unittest import mock
|
|
|
|
import pytest
|
|
|
|
from starlette.applications import Starlette
|
|
from starlette.datastructures import UploadFile
|
|
-from starlette.formparsers import MultiPartException, _user_safe_decode
|
|
+from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import Mount
|
|
@@ -104,6 +110,22 @@ async def app_read_body(scope: Scope, re
|
|
await response(scope, receive, send)
|
|
|
|
|
|
+async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
|
|
+ """Helper app to monitor what thread the app was called on.
|
|
+
|
|
+ This can later be used to validate thread/event loop operations.
|
|
+ """
|
|
+ request = Request(scope, receive)
|
|
+
|
|
+ # Make sure we parse the form
|
|
+ await request.form()
|
|
+ await request.close()
|
|
+
|
|
+ # Send back the current thread id
|
|
+ response = JSONResponse({"thread_ident": threading.current_thread().ident})
|
|
+ await response(scope, receive, send)
|
|
+
|
|
+
|
|
def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp:
|
|
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
|
request = Request(scope, receive)
|
|
@@ -303,6 +325,88 @@ def test_multipart_request_mixed_files_a
|
|
}
|
|
|
|
|
|
+class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
|
|
+ """Helper class to track which threads performed the rollover operation.
|
|
+
|
|
+ This is not threadsafe/multi-test safe.
|
|
+ """
|
|
+
|
|
+ rollover_threads: ClassVar[set[int | None]] = set()
|
|
+
|
|
+ def rollover(self) -> None:
|
|
+ ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
|
|
+ super().rollover()
|
|
+
|
|
+
|
|
+@pytest.fixture
|
|
+def mock_spooled_temporary_file() -> Generator[None]:
|
|
+ try:
|
|
+ with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
|
|
+ yield
|
|
+ finally:
|
|
+ ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()
|
|
+
|
|
+
|
|
+def test_multipart_request_large_file_rollover_in_background_thread(
|
|
+ mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
|
|
+) -> None:
|
|
+ """Test that Spooled file rollovers happen in background threads."""
|
|
+ data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))
|
|
+
|
|
+ client = test_client_factory(app_monitor_thread)
|
|
+ response = client.post("/", files=[("test_large", data)])
|
|
+ assert response.status_code == 200
|
|
+
|
|
+ # Parse the event thread id from the API response and ensure we have one
|
|
+ app_thread_ident = response.json().get("thread_ident")
|
|
+ assert app_thread_ident is not None
|
|
+
|
|
+ # Ensure the app thread was not the same as the rollover one and that a rollover thread exists
|
|
+ assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
|
|
+ assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1
|
|
+
|
|
+
|
|
def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
|
|
client = test_client_factory(app)
|
|
response = client.post(
|
|
Index: starlette-0.41.3/starlette/formparsers.py
|
|
===================================================================
|
|
--- starlette-0.41.3.orig/starlette/formparsers.py
|
|
+++ starlette-0.41.3/starlette/formparsers.py
|
|
@@ -122,7 +122,10 @@ class FormParser:
|
|
|
|
|
|
class MultiPartParser:
|
|
- max_file_size = 1024 * 1024 # 1MB
|
|
+ spool_max_size = 1024 * 1024 # 1MB
|
|
+ """The maximum size of the spooled temporary file used to store file data."""
|
|
max_part_size = 1024 * 1024  # 1MB
+ """The maximum size of a part in the multipart request."""
|
|
|
|
def __init__(
|
|
@@ -201,7 +204,7 @@ class MultiPartParser:
|
|
if self._current_files > self.max_files:
|
|
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
|
|
filename = _user_safe_decode(options[b"filename"], self._charset)
|
|
- tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
|
|
+ tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
|
|
self._files_to_close_on_error.append(tempfile)
|
|
self._current_part.file = UploadFile(
|
|
file=tempfile, # type: ignore[arg-type]
|