Compare commits

...

1 Commit

Author: Masen Furer
SHA1: 7f0efa13e9
Message: Write uploaded file bytes to a SpooledTemporaryFile
Date: 2025-02-03 12:12:59 -08:00
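
The commit replaces the per-file io.BytesIO copy of each upload with a tempfile.SpooledTemporaryFile, which keeps small uploads in an in-memory buffer and transparently rolls over to an on-disk temporary file once more than max_size bytes have been written. A minimal standalone sketch of that stdlib behavior, reusing the 1 MiB threshold from the commit (the payload sizes and the peek at the private _rolled flag are purely illustrative):

import tempfile

# Spooled file: stays in memory until more than max_size bytes are written.
spooled = tempfile.SpooledTemporaryFile(max_size=1024 * 1024)

spooled.write(b"x" * 512)                 # small payload, still an in-memory buffer
print(spooled._rolled)                    # False (private attribute, shown only for illustration)

spooled.write(b"y" * (2 * 1024 * 1024))   # exceeds max_size, spills to a real temp file on disk
print(spooled._rolled)                    # True

spooled.seek(0)                           # rewind before handing the object to a reader
data = spooled.read()
spooled.close()                           # closing also removes the on-disk backing file, if any
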


@@ -9,11 +9,11 @@ import copy
 import dataclasses
 import functools
 import inspect
-import io
 import json
 import multiprocessing
 import platform
 import sys
+import tempfile
 import traceback
 from datetime import datetime
 from pathlib import Path
@@ -29,6 +29,7 @@ from typing import (
     List,
     MutableMapping,
     Optional,
+    Sequence,
     Set,
     Type,
     Union,
@@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
     return JSONResponse(content=health_status, status_code=status_code)


+def _handle_temporary_upload_file(
+    upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
+) -> tempfile.SpooledTemporaryFile:
+    temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
+    temp_file.write(upload_file.file.read())
+    temp_file.seek(0)
+    return temp_file
+
+
+async def temporary_upload_tree(
+    token: str, files: List[UploadFile]
+) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
+    """Write the uploaded files to a temporary directory structure.
+
+    Args:
+        token: The token to use for the temporary directory.
+        files: The files to write to the temporary directory.
+
+    Yields:
+        A list of the temporary files.
+    """
+    upload_dir = get_upload_dir()
+    upload_dir.mkdir(parents=True, exist_ok=True)
+    temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
+    temp_files = []
+    loop = asyncio.get_running_loop()
+    temp_files = [
+        await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
+        for f in files
+    ]
+    try:
+        yield temp_files
+    finally:
+
+        def _cleanup():
+            for temp_file in temp_files:
+                temp_file.close()
+            temp_root.cleanup()
+
+        await loop.run_in_executor(None, _cleanup)
+
+
 def upload(app: App):
     """Upload a file.
@@ -1563,24 +1606,21 @@ def upload(app: App):
         # AsyncExitStack was removed from the request scope and is now
         # part of the routing function which closes this before the
         # event is handled.
-        file_copies = []
-        for file in files:
-            content_copy = io.BytesIO()
-            content_copy.write(await file.read())
-            content_copy.seek(0)
-            file_copies.append(
-                UploadFile(
-                    file=content_copy,
-                    filename=file.filename,
-                    size=file.size,
-                    headers=file.headers,
-                )
-            )
+        file_ctx = temporary_upload_tree(token, files)
+        temp_files = [
+            UploadFile(
+                file=tmp,  # pyright: ignore[reportArgumentType]
+                filename=file.filename,
+                size=file.size,
+                headers=file.headers,
+            )
+            for file, tmp in zip(files, await anext(file_ctx), strict=True)
+        ]

         event = Event(
             token=token,
             name=handler,
-            payload={handler_upload_param[0]: file_copies},
+            payload={handler_upload_param[0]: temp_files},
         )

         async def _ndjson_updates():
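
The event handler still receives UploadFile instances: each spooled temporary file is wrapped back into an UploadFile that carries the original filename, size, and headers, and zip(..., strict=True) raises if the uploads and their temporary copies ever fall out of step. A rough, self-contained illustration of that wrapping using starlette's UploadFile directly (the metadata values are invented, and the type-ignore mirrors the pyright comment in the diff):

import tempfile

from starlette.datastructures import Headers, UploadFile

# A spooled copy of the request body, as produced by the helper above.
tmp = tempfile.SpooledTemporaryFile(max_size=1024 * 1024)
tmp.write(b"file contents")
tmp.seek(0)

# Re-wrap it so downstream code keeps seeing an UploadFile.
wrapped = UploadFile(
    file=tmp,  # type: ignore[arg-type]  # SpooledTemporaryFile rather than a plain BinaryIO
    filename="report.csv",
    size=13,
    headers=Headers({"content-type": "text/csv"}),
)
print(wrapped.filename, wrapped.size)
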
@@ -1595,6 +1635,9 @@ def upload(app: App):
                 # Postprocess the event.
                 update = await app._postprocess(state, event, update)
                 yield update.json() + "\n"
+            # Clean up the temporary files.
+            async for _ in file_ctx:
+                pass

         # Stream updates to client
         return StreamingResponse(
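
Because temporary_upload_tree is an async generator, its finally block (which closes the spooled files and deletes the temporary directory) only runs once the generator is resumed past its yield; that is exactly what the trailing "async for _ in file_ctx: pass" does after the last update has been streamed. A toy, self-contained consumer showing the same two-phase protocol (fake_upload_tree is an invented stand-in, not the function from the diff):

import asyncio
import tempfile
from typing import AsyncIterator, List, Sequence


async def fake_upload_tree(
    blobs: List[bytes],
) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
    # Stand-in for temporary_upload_tree: yield spooled copies, then clean up.
    temp_files = []
    for blob in blobs:
        tmp = tempfile.SpooledTemporaryFile(max_size=1024 * 1024)
        tmp.write(blob)
        tmp.seek(0)
        temp_files.append(tmp)
    try:
        yield temp_files
    finally:
        for tmp in temp_files:
            tmp.close()


async def main() -> None:
    ctx = fake_upload_tree([b"hello", b"world"])
    files = await anext(ctx)           # first resume: receive the spooled files
    print([f.read() for f in files])   # use them while the generator is suspended
    async for _ in ctx:                # second resume: the generator finishes, finally runs
        pass


asyncio.run(main())
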