From fd038f3070c302bff17ef7d173dbb0b007617733 Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski
Date: Tue, 15 Oct 2024 08:40:51 +0200
Subject: [PATCH] Merge commit from fork

Limit the size of non-file multipart parts.

MultiPartParser.on_part_data previously appended every chunk of a
non-file part to an in-memory bytes buffer with no size check. The
buffer is now a bytearray (extended in place), and a MultiPartException
("Part exceeded maximum size of 1024KB.") is raised as soon as a part
would grow past the new max_part_size class attribute (1024 * 1024
bytes). A test covers both the raw ASGI app (exception propagates) and
the app mounted in Starlette (exception becomes a 400 response).
---
 starlette/formparsers.py  | 11 +++++++----
 tests/test_formparsers.py | 41 ++++++++++++++++++++++++++++++++++++---
 2 files changed, 45 insertions(+), 7 deletions(-)

Index: starlette-0.38.5/starlette/formparsers.py
===================================================================
--- starlette-0.38.5.orig/starlette/formparsers.py
+++ starlette-0.38.5/starlette/formparsers.py
@@ -28,12 +28,12 @@ class FormMessage(Enum):
 class MultipartPart:
     content_disposition: bytes | None = None
     field_name: str = ""
-    data: bytes = b""
+    data: bytearray = field(default_factory=bytearray)
     file: UploadFile | None = None
     item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
 
 
-def _user_safe_decode(src: bytes, codec: str) -> str:
+def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
     try:
         return src.decode(codec)
     except (UnicodeDecodeError, LookupError):
@@ -114,7 +114,8 @@ class FormParser:
 
 
 class MultiPartParser:
-    max_file_size = 1024 * 1024
+    max_file_size = 1024 * 1024  # 1MB
+    max_part_size = 1024 * 1024  # 1MB
 
     def __init__(
         self,
@@ -146,7 +147,9 @@ class MultiPartParser:
     def on_part_data(self, data: bytes, start: int, end: int) -> None:
         message_bytes = data[start:end]
         if self._current_part.file is None:
-            self._current_part.data += message_bytes
+            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
+                raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
+            self._current_part.data.extend(message_bytes)
         else:
             self._file_parts_to_write.append((self._current_part, message_bytes))
 
Index: starlette-0.38.5/tests/test_formparsers.py
===================================================================
--- starlette-0.38.5.orig/tests/test_formparsers.py
+++ starlette-0.38.5/tests/test_formparsers.py
@@ -640,9 +640,7 @@ def test_max_files_is_customizable_low_r
     assert res.text == "Too many files. Maximum number of files is 1."
 
 
-def test_max_fields_is_customizable_high(
-    test_client_factory: TestClientFactory,
-) -> None:
+def test_max_fields_is_customizable_high(test_client_factory: TestClientFactory) -> None:
     client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000))
     fields = []
     for i in range(2000):
@@ -664,3 +662,40 @@ def test_max_fields_is_customizable_high
         "content": "",
         "content_type": None,
     }
+
+
+@pytest.mark.parametrize(
+    "app,expectation",
+    [
+        (app, pytest.raises(MultiPartException)),
+        (Starlette(routes=[Mount("/", app=app)]), does_not_raise()),
+    ],
+)
+def test_max_part_size_exceeds_limit(
+    app: ASGIApp,
+    expectation: typing.ContextManager[Exception],
+    test_client_factory: TestClientFactory,
+) -> None:
+    client = test_client_factory(app)
+    boundary = "------------------------4K1ON9fZkj9uCUmqLHRbbR"
+
+    multipart_data = (
+        f"--{boundary}\r\n"
+        f'Content-Disposition: form-data; name="small"\r\n\r\n'
+        "small content\r\n"
+        f"--{boundary}\r\n"
+        f'Content-Disposition: form-data; name="large"\r\n\r\n'
+        + ("x" * 1024 * 1024 + "x")  # 1MB + 1 byte of data
+        + "\r\n"
+        f"--{boundary}--\r\n"
+    ).encode("utf-8")
+
+    headers = {
+        "Content-Type": f"multipart/form-data; boundary={boundary}",
+        "Transfer-Encoding": "chunked",
+    }
+
+    with expectation:
+        response = client.post("/", data=multipart_data, headers=headers)  # type: ignore
+        assert response.status_code == 400
+        assert response.text == "Part exceeded maximum size of 1024KB."