From 7f0efa13e9230d84286af2290a80e66399d42154 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 3 Feb 2025 12:11:46 -0800 Subject: [PATCH] Write uploaded file bytes to a SpooledTemporaryFile --- reflex/app.py | 71 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index d9104ece6..aa392552e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -9,11 +9,11 @@ import copy import dataclasses import functools import inspect -import io import json import multiprocessing import platform import sys +import tempfile import traceback from datetime import datetime from pathlib import Path @@ -29,6 +29,7 @@ from typing import ( List, MutableMapping, Optional, + Sequence, Set, Type, Union, @@ -1485,6 +1486,48 @@ async def health() -> JSONResponse: return JSONResponse(content=health_status, status_code=status_code) +def _handle_temporary_upload_file( + upload_file: UploadFile, temp_root: tempfile.TemporaryDirectory +) -> tempfile.SpooledTemporaryFile: + temp_file = tempfile.SpooledTemporaryFile(max_size=1024 * 1024, dir=temp_root.name) + temp_file.write(upload_file.file.read()) + temp_file.seek(0) + return temp_file + + +async def temporary_upload_tree( + token: str, files: List[UploadFile] +) -> AsyncIterator[Sequence[tempfile.SpooledTemporaryFile]]: + """Write the uploaded files to a temporary directory structure. + + Args: + token: The token to use for the temporary directory. + files: The files to write to the temporary directory. + + Yields: + A list of the temporary files. + """ + upload_dir = get_upload_dir() + upload_dir.mkdir(parents=True, exist_ok=True) + temp_root = tempfile.TemporaryDirectory(prefix=token, dir=upload_dir) + temp_files = [] + loop = asyncio.get_running_loop() + temp_files = [ + await loop.run_in_executor(None, _handle_temporary_upload_file, f, temp_root) + for f in files + ] + try: + yield temp_files + finally: + + def _cleanup(): + for temp_file in temp_files: + temp_file.close() + temp_root.cleanup() + + await loop.run_in_executor(None, _cleanup) + + def upload(app: App): """Upload a file. @@ -1563,24 +1606,21 @@ def upload(app: App): # AsyncExitStack was removed from the request scope and is now # part of the routing function which closes this before the # event is handled. - file_copies = [] - for file in files: - content_copy = io.BytesIO() - content_copy.write(await file.read()) - content_copy.seek(0) - file_copies.append( - UploadFile( - file=content_copy, - filename=file.filename, - size=file.size, - headers=file.headers, - ) + file_ctx = temporary_upload_tree(token, files) + temp_files = [ + UploadFile( + file=tmp, # pyright: ignore[reportArgumentType] + filename=file.filename, + size=file.size, + headers=file.headers, ) + for file, tmp in zip(files, await anext(file_ctx), strict=True) + ] event = Event( token=token, name=handler, - payload={handler_upload_param[0]: file_copies}, + payload={handler_upload_param[0]: temp_files}, ) async def _ndjson_updates(): @@ -1595,6 +1635,9 @@ def upload(app: App): # Postprocess the event. update = await app._postprocess(state, event, update) yield update.json() + "\n" + # Clean up the temporary files. + async for _ in file_ctx: + pass # Stream updates to client return StreamingResponse(