Compare commits

...

1 Commits

Author SHA1 Message Date
Masen Furer
7f0efa13e9
Write uploaded file bytes to a SpooledTemporaryFile 2025-02-03 12:12:59 -08:00

View File

@ -9,11 +9,11 @@ import copy
import dataclasses
import functools
import inspect
import io
import json
import multiprocessing
import platform
import sys
import tempfile
import traceback
from datetime import datetime
from pathlib import Path
@ -29,6 +29,7 @@ from typing import (
List,
MutableMapping,
Optional,
Sequence,
Set,
Type,
Union,
@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
return JSONResponse(content=health_status, status_code=status_code)
def _handle_temporary_upload_file(
upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
) -> tempfile.SpooledTemporaryFile:
temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
temp_file.write(upload_file.file.read())
temp_file.seek(0)
return temp_file
async def temporary_upload_tree(
token: str, files: List[UploadFile]
) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
"""Write the uploaded files to a temporary directory structure.
Args:
token: The token to use for the temporary directory.
files: The files to write to the temporary directory.
Yields:
A list of the temporary files.
"""
upload_dir = get_upload_dir()
upload_dir.mkdir(parents=True, exist_ok=True)
temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
temp_files = []
loop = asyncio.get_running_loop()
temp_files = [
await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
for f in files
]
try:
yield temp_files
finally:
def _cleanup():
for temp_file in temp_files:
temp_file.close()
temp_root.cleanup()
await loop.run_in_executor(None, _cleanup)
def upload(app: App):
"""Upload a file.
@ -1563,24 +1606,21 @@ def upload(app: App):
# AsyncExitStack was removed from the request scope and is now
# part of the routing function which closes this before the
# event is handled.
file_copies = []
for file in files:
content_copy = io.BytesIO()
content_copy.write(await file.read())
content_copy.seek(0)
file_copies.append(
UploadFile(
file=content_copy,
filename=file.filename,
size=file.size,
headers=file.headers,
)
file_ctx = temporary_upload_tree(token, files)
temp_files = [
UploadFile(
file=tmp, # pyright: ignore[reportArgumentType]
filename=file.filename,
size=file.size,
headers=file.headers,
)
for file, tmp in zip(files, await anext(file_ctx), strict=True)
]
event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: file_copies},
payload={handler_upload_param[0]: temp_files},
)
async def _ndjson_updates():
@ -1595,6 +1635,9 @@ def upload(app: App):
# Postprocess the event.
update = await app._postprocess(state, event, update)
yield update.json() + "\n"
# Clean up the temporary files.
async for _ in file_ctx:
pass
# Stream updates to client
return StreamingResponse(