Compare commits
1 Commits
main
...
masenf/upl
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7f0efa13e9 |
@ -9,11 +9,11 @@ import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import multiprocessing
|
||||
import platform
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -29,6 +29,7 @@ from typing import (
|
||||
List,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
|
||||
return JSONResponse(content=health_status, status_code=status_code)
|
||||
|
||||
|
||||
def _handle_temporary_upload_file(
|
||||
upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
|
||||
) -> tempfile.SpooledTemporaryFile:
|
||||
temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
|
||||
temp_file.write(upload_file.file.read())
|
||||
temp_file.seek(0)
|
||||
return temp_file
|
||||
|
||||
|
||||
async def temporary_upload_tree(
|
||||
token: str, files: List[UploadFile]
|
||||
) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
|
||||
"""Write the uploaded files to a temporary directory structure.
|
||||
|
||||
Args:
|
||||
token: The token to use for the temporary directory.
|
||||
files: The files to write to the temporary directory.
|
||||
|
||||
Yields:
|
||||
A list of the temporary files.
|
||||
"""
|
||||
upload_dir = get_upload_dir()
|
||||
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
|
||||
temp_files = []
|
||||
loop = asyncio.get_running_loop()
|
||||
temp_files = [
|
||||
await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
|
||||
for f in files
|
||||
]
|
||||
try:
|
||||
yield temp_files
|
||||
finally:
|
||||
|
||||
def _cleanup():
|
||||
for temp_file in temp_files:
|
||||
temp_file.close()
|
||||
temp_root.cleanup()
|
||||
|
||||
await loop.run_in_executor(None, _cleanup)
|
||||
|
||||
|
||||
def upload(app: App):
|
||||
"""Upload a file.
|
||||
|
||||
@ -1563,24 +1606,21 @@ def upload(app: App):
|
||||
# AsyncExitStack was removed from the request scope and is now
|
||||
# part of the routing function which closes this before the
|
||||
# event is handled.
|
||||
file_copies = []
|
||||
for file in files:
|
||||
content_copy = io.BytesIO()
|
||||
content_copy.write(await file.read())
|
||||
content_copy.seek(0)
|
||||
file_copies.append(
|
||||
UploadFile(
|
||||
file=content_copy,
|
||||
filename=file.filename,
|
||||
size=file.size,
|
||||
headers=file.headers,
|
||||
)
|
||||
file_ctx = temporary_upload_tree(token, files)
|
||||
temp_files = [
|
||||
UploadFile(
|
||||
file=tmp, # pyright: ignore[reportArgumentType]
|
||||
filename=file.filename,
|
||||
size=file.size,
|
||||
headers=file.headers,
|
||||
)
|
||||
for file, tmp in zip(files, await anext(file_ctx), strict=True)
|
||||
]
|
||||
|
||||
event = Event(
|
||||
token=token,
|
||||
name=handler,
|
||||
payload={handler_upload_param[0]: file_copies},
|
||||
payload={handler_upload_param[0]: temp_files},
|
||||
)
|
||||
|
||||
async def _ndjson_updates():
|
||||
@ -1595,6 +1635,9 @@ def upload(app: App):
|
||||
# Postprocess the event.
|
||||
update = await app._postprocess(state, event, update)
|
||||
yield update.json() + "\n"
|
||||
# Clean up the temporary files.
|
||||
async for _ in file_ctx:
|
||||
pass
|
||||
|
||||
# Stream updates to client
|
||||
return StreamingResponse(
|
||||
|
Loading…
Reference in New Issue
Block a user