Compare commits
1 Commits
main
...
masenf/upl
Author | SHA1 | Date | |
---|---|---|---|
![]() |
7f0efa13e9 |
@ -9,11 +9,11 @@ import copy
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -29,6 +29,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
MutableMapping,
|
MutableMapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
@ -1485,6 +1486,48 @@ async def health() -> JSONResponse:
|
|||||||
return JSONResponse(content=health_status, status_code=status_code)
|
return JSONResponse(content=health_status, status_code=status_code)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_temporary_upload_file(
|
||||||
|
upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory
|
||||||
|
) -> tempfile.SpooledTemporaryFile:
|
||||||
|
temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name)
|
||||||
|
temp_file.write(upload_file.file.read())
|
||||||
|
temp_file.seek(0)
|
||||||
|
return temp_file
|
||||||
|
|
||||||
|
|
||||||
|
async def temporary_upload_tree(
|
||||||
|
token: str, files: List[UploadFile]
|
||||||
|
) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]:
|
||||||
|
"""Write the uploaded files to a temporary directory structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token to use for the temporary directory.
|
||||||
|
files: The files to write to the temporary directory.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
A list of the temporary files.
|
||||||
|
"""
|
||||||
|
upload_dir = get_upload_dir()
|
||||||
|
upload_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir)
|
||||||
|
temp_files = []
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
temp_files = [
|
||||||
|
await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root)
|
||||||
|
for f in files
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
yield temp_files
|
||||||
|
finally:
|
||||||
|
|
||||||
|
def _cleanup():
|
||||||
|
for temp_file in temp_files:
|
||||||
|
temp_file.close()
|
||||||
|
temp_root.cleanup()
|
||||||
|
|
||||||
|
await loop.run_in_executor(None, _cleanup)
|
||||||
|
|
||||||
|
|
||||||
def upload(app: App):
|
def upload(app: App):
|
||||||
"""Upload a file.
|
"""Upload a file.
|
||||||
|
|
||||||
@ -1563,24 +1606,21 @@ def upload(app: App):
|
|||||||
# AsyncExitStack was removed from the request scope and is now
|
# AsyncExitStack was removed from the request scope and is now
|
||||||
# part of the routing function which closes this before the
|
# part of the routing function which closes this before the
|
||||||
# event is handled.
|
# event is handled.
|
||||||
file_copies = []
|
file_ctx = temporary_upload_tree(token, files)
|
||||||
for file in files:
|
temp_files = [
|
||||||
content_copy = io.BytesIO()
|
UploadFile(
|
||||||
content_copy.write(await file.read())
|
file=tmp, # pyright: ignore[reportArgumentType]
|
||||||
content_copy.seek(0)
|
filename=file.filename,
|
||||||
file_copies.append(
|
size=file.size,
|
||||||
UploadFile(
|
headers=file.headers,
|
||||||
file=content_copy,
|
|
||||||
filename=file.filename,
|
|
||||||
size=file.size,
|
|
||||||
headers=file.headers,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for file, tmp in zip(files, await anext(file_ctx), strict=True)
|
||||||
|
]
|
||||||
|
|
||||||
event = Event(
|
event = Event(
|
||||||
token=token,
|
token=token,
|
||||||
name=handler,
|
name=handler,
|
||||||
payload={handler_upload_param[0]: file_copies},
|
payload={handler_upload_param[0]: temp_files},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _ndjson_updates():
|
async def _ndjson_updates():
|
||||||
@ -1595,6 +1635,9 @@ def upload(app: App):
|
|||||||
# Postprocess the event.
|
# Postprocess the event.
|
||||||
update = await app._postprocess(state, event, update)
|
update = await app._postprocess(state, event, update)
|
||||||
yield update.json() + "\n"
|
yield update.json() + "\n"
|
||||||
|
# Clean up the temporary files.
|
||||||
|
async for _ in file_ctx:
|
||||||
|
pass
|
||||||
|
|
||||||
# Stream updates to client
|
# Stream updates to client
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
Loading…
Reference in New Issue
Block a user