Skip to content

Make UploadFile check for future rollover #2962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ def __init__(
self.size = size
self.headers = headers or Headers()

# Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
# Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
self._max_mem_size = getattr(self.file, "_max_size", 0)

@property
def content_type(self) -> str | None:
return self.headers.get("content-type", None)
Expand All @@ -438,14 +442,24 @@ def _in_memory(self) -> bool:
rolled_to_disk = getattr(self.file, "_rolled", True)
return not rolled_to_disk

def _will_roll(self, size_to_add: int) -> bool:
# If we're not in_memory then we will always roll
if not self._in_memory:
return True

# Check for SpooledTemporaryFile._max_size
future_size = self.file.tell() + size_to_add
return bool(future_size > self._max_mem_size) if self._max_mem_size else False

async def write(self, data: bytes) -> None:
new_data_len = len(data)
if self.size is not None:
self.size += len(data)
self.size += new_data_len

if self._in_memory:
self.file.write(data)
else:
if self._will_roll(new_data_len):
await run_in_threadpool(self.file.write, data)
else:
self.file.write(data)

async def read(self, size: int = -1) -> bytes:
if self._in_memory:
Expand Down
66 changes: 64 additions & 2 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations

import os
import threading
from collections.abc import Generator
from contextlib import AbstractContextManager, nullcontext as does_not_raise
from io import BytesIO
from pathlib import Path
from typing import Any
from tempfile import SpooledTemporaryFile
from typing import Any, ClassVar
from unittest import mock

import pytest

from starlette.applications import Starlette
from starlette.datastructures import UploadFile
from starlette.formparsers import MultiPartException, _user_safe_decode
from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount
Expand Down Expand Up @@ -104,6 +109,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
await response(scope, receive, send)


async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
    """Helper app that reports which thread handled the request.

    Tests use the reported id to validate thread/event loop operations.
    """
    request = Request(scope, receive)

    # Force the form body to be parsed, then release its resources.
    await request.form()
    await request.close()

    # Reply with the id of the thread the app ran on.
    ident = threading.current_thread().ident
    response = JSONResponse({"thread_ident": ident})
    await response(scope, receive, send)


def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
Expand Down Expand Up @@ -303,6 +324,47 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor
}


class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
    """Spooled temporary file that records which threads performed rollovers.

    Not threadsafe/multi-test safe: state is shared at class level.
    """

    # Idents of every thread that triggered a rollover (shared across instances).
    rollover_threads: ClassVar[set[int | None]] = set()

    def rollover(self) -> None:
        ident = threading.current_thread().ident
        ThreadTrackingSpooledTemporaryFile.rollover_threads.add(ident)
        super().rollover()


@pytest.fixture
def mock_spooled_temporary_file() -> Generator[None]:
    """Patch the form parser's spool class with the thread-tracking variant.

    Always clears the shared rollover record afterwards so tests stay isolated.
    """
    with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
        try:
            yield
        finally:
            ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()


def test_multipart_request_large_file_rollover_in_background_thread(
    mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
) -> None:
    """Test that Spooled file rollovers happen in background threads."""
    # One byte over the spool limit guarantees a rollover during parsing.
    payload = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))

    client = test_client_factory(app_monitor_thread)
    response = client.post("/", files=[("test_large", payload)])
    assert response.status_code == 200

    # The app reports the thread it ran on; make sure we actually got one.
    app_thread_ident = response.json().get("thread_ident")
    assert app_thread_ident is not None

    # Exactly one rollover occurred, and it did not run on the app's thread.
    assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
    assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1


def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post(
Expand Down