From bda99bc9b89a9f13e4140a9c975c0c599dd3b306 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 4 Feb 2025 13:56:57 -0800 Subject: [PATCH] all my friends hate fast api upload file --- reflex/app.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index f1f6cb43a..314b48c16 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -22,6 +22,7 @@ from typing import ( TYPE_CHECKING, Any, AsyncIterator, + BinaryIO, Callable, Coroutine, Dict, @@ -36,12 +37,15 @@ from typing import ( get_type_hints, ) -from fastapi import FastAPI, HTTPException, Request, UploadFile +from fastapi import FastAPI, HTTPException, Request +from fastapi import UploadFile as FastAPIUploadFile from fastapi.middleware import cors from fastapi.responses import JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer +from starlette.datastructures import Headers +from starlette.datastructures import UploadFile as StarletteUploadFile from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.view import ModelView @@ -192,6 +196,53 @@ class OverlayFragment(Fragment): pass +@dataclasses.dataclass(frozen=True) +class UploadFile(StarletteUploadFile): + """A file uploaded to the server. + + Args: + file: The standard Python file object (non-async). + filename: The original file name. + size: The size of the file in bytes. + headers: The headers of the request. + """ + + file: BinaryIO + + path: Optional[Path] = dataclasses.field(default=None) + + deprecated_filename: Optional[str] = dataclasses.field(default=None) + + size: Optional[int] = dataclasses.field(default=None) + + headers: Headers = dataclasses.field(default_factory=Headers) + + @property + def name(self) -> Optional[str]: + """Get the name of the uploaded file. + + Returns: + The name of the uploaded file. + """ + if self.path: + return self.path.name + + @property + def filename(self) -> Optional[str]: + """Get the filename of the uploaded file. + + Returns: + The filename of the uploaded file. + """ + console.deprecate( + feature_name="UploadFile.filename", + reason="Use UploadFile.name instead.", + deprecation_version="0.7.1", + removal_version="0.8.0", + ) + return self.deprecated_filename + + @dataclasses.dataclass( frozen=True, ) @@ -1483,7 +1534,7 @@ def upload(app: App): The upload function. """ - async def upload_file(request: Request, files: List[UploadFile]): + async def upload_file(request: Request, files: List[FastAPIUploadFile]): """Upload a file. Args: @@ -1559,7 +1610,8 @@ def upload(app: App): file_copies.append( UploadFile( file=content_copy, - filename=Path(file.filename).name if file.filename else None, + path=Path(file.filename) if file.filename else None, + deprecated_filename=file.filename, size=file.size, headers=file.headers, )