Write uploaded file bytes to a SpooledTemporaryFile

This commit is contained in:
Masen Furer 2025-02-03 12:11:46 -08:00
parent ef93161840
commit 7f0efa13e9
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -9,11 +9,11 @@ import copy
import dataclasses import dataclasses
import functools import functools
import inspect import inspect
import io
import json import json
import multiprocessing import multiprocessing
import platform import platform
import sys import sys
import tempfile
import traceback import traceback
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@ -29,6 +29,7 @@ from typing import (
List, List,
MutableMapping, MutableMapping,
Optional, Optional,
Sequence,
Set, Set,
Type, Type,
Union, Union,
@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
return JSONResponse(content=health_status, status_code=status_code) return JSONResponse(content=health_status, status_code=status_code)
def _handle_temporary_upload_file(
upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
) -> tempfile.SpooledTemporaryFile:
temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
temp_file.write(upload_file.file.read())
temp_file.seek(0)
return temp_file
async def temporary_upload_tree(
token: str, files: List[UploadFile]
) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
"""Write the uploaded files to a temporary directory structure.
Args:
token: The token to use for the temporary directory.
files: The files to write to the temporary directory.
Yields:
A list of the temporary files.
"""
upload_dir = get_upload_dir()
upload_dir.mkdir(parents=True, exist_ok=True)
temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
temp_files = []
loop = asyncio.get_running_loop()
temp_files = [
await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
for f in files
]
try:
yield temp_files
finally:
def _cleanup():
for temp_file in temp_files:
temp_file.close()
temp_root.cleanup()
await loop.run_in_executor(None, _cleanup)
def upload(app: App): def upload(app: App):
"""Upload a file. """Upload a file.
@ -1563,24 +1606,21 @@ def upload(app: App):
# AsyncExitStack was removed from the request scope and is now # AsyncExitStack was removed from the request scope and is now
# part of the routing function which closes this before the # part of the routing function which closes this before the
# event is handled. # event is handled.
file_copies = [] file_ctx = temporary_upload_tree(token, files)
for file in files: temp_files = [
content_copy = io.BytesIO() UploadFile(
content_copy.write(await file.read()) file=tmp, # pyright: ignore[reportArgumentType]
content_copy.seek(0) filename=file.filename,
file_copies.append( size=file.size,
UploadFile( headers=file.headers,
file=content_copy,
filename=file.filename,
size=file.size,
headers=file.headers,
)
) )
for file, tmp in zip(files, await anext(file_ctx), strict=True)
]
event = Event( event = Event(
token=token, token=token,
name=handler, name=handler,
payload={handler_upload_param[0]: file_copies}, payload={handler_upload_param[0]: temp_files},
) )
async def _ndjson_updates(): async def _ndjson_updates():
@ -1595,6 +1635,9 @@ def upload(app: App):
# Postprocess the event. # Postprocess the event.
update = await app._postprocess(state, event, update) update = await app._postprocess(state, event, update)
yield update.json() + "\n" yield update.json() + "\n"
# Clean up the temporary files.
async for _ in file_ctx:
pass
# Stream updates to client # Stream updates to client
return StreamingResponse( return StreamingResponse(