Skip to content

Make UploadFile check for future rollover #2962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update test to actually test the bug
Signed-off-by: Michael Honaker <mchonaker@gmail.com>
  • Loading branch information
HonakerM committed Jul 11, 2025
commit 753e790fde6bac8f5cbce690dfd6f88846c24328
52 changes: 46 additions & 6 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from contextlib import AbstractContextManager, nullcontext as does_not_raise
from io import BytesIO
from pathlib import Path
from tempfile import SpooledTemporaryFile
import threading
from typing import Any
from unittest import mock

import pytest

Expand Down Expand Up @@ -105,6 +108,20 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
await response(scope, receive, send)


async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
"""Helper app to monitor what thread the app was called on. This can later
be used to validate thread/event loop operations"""
request = Request(scope, receive)

# Make sure we parse the form
await request.form()
await request.close()

# Send back the current thread id
response = JSONResponse({"thread_ident": threading.current_thread().ident})
await response(scope, receive, send)


def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
Expand Down Expand Up @@ -304,14 +321,37 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor
}


class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile):
"""Helper class to track which threads performed the rollover operation. This is
not threadsafe/multi-test safe"""

rollover_threads: set[int] = set()

def rollover(self):
ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
return super().rollover()


def test_multipart_request_large_file(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
data = BytesIO(b" " * MultiPartParser.spool_max_size * 2)
client = test_client_factory(app)
response = client.post(
"/",
files=[("test_large", data)],
)
assert response.status_code == 200

# Mock the formparser to use our monitoring class
with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
client = test_client_factory(app_monitor_thread)
response = client.post(
"/",
files=[("test_large", data)],
)
assert response.status_code == 200

# Parse the event thread id from the API response and ensure we have one
app_thread_ident = response.json().get("thread_ident")
assert app_thread_ident

# Ensure the app thread was not the same as the rollover one and that a rollover thread
# exists
assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) > 0


def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
Expand Down
Loading