From 5b283baef40921a1a32224e486627242c807c9c9 Mon Sep 17 00:00:00 2001 From: Lendemor Date: Mon, 8 Apr 2024 21:00:12 +0200 Subject: [PATCH] simplify changes to reduce API changes --- reflex/app.py | 11 ++--------- reflex/components/core/upload.py | 23 ++++++++--------------- reflex/components/core/upload.pyi | 2 -- reflex/staticfiles.py | 11 +++++++++-- 4 files changed, 19 insertions(+), 28 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 82c528fb2..27f34a8bb 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -29,7 +29,6 @@ from typing import ( from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors from fastapi.responses import StreamingResponse -from fastapi.staticfiles import StaticFiles from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer from starlette_admin.contrib.sqla.admin import Admin @@ -77,7 +76,7 @@ from reflex.state import ( _substate_key, code_uses_state_contexts, ) -from reflex.staticfiles import DownloadFiles +from reflex.staticfiles import UploadedFiles from reflex.utils import console, exceptions, format, prerequisites, types from reflex.utils.exec import is_testing_env, should_skip_compile from reflex.utils.imports import ImportVar @@ -280,15 +279,9 @@ class App(Base): # To access uploaded files as assets. self.api.mount( str(constants.Endpoint.UPLOAD), - StaticFiles(directory=get_upload_dir()), + UploadedFiles(directory=get_upload_dir()), name="uploaded_files", ) - # To download uploaded files. - self.api.mount( - str(constants.Endpoint.DOWNLOAD), - DownloadFiles(directory=get_upload_dir()), - name="download_files", - ) def add_cors(self): """Add CORS middleware to the app.""" diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index adeb64cad..75154ca37 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -117,20 +117,16 @@ def get_upload_dir() -> Path: return uploaded_files_dir -upload_var_data: VarData = VarData( # type: ignore - imports={ - f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, - "/env.json": {imports.ImportVar(tag="env", is_default=True)}, - } -) - uploaded_files_url_prefix: Var = Var.create_safe( "${getBackendURL(env.UPLOAD)}" -)._replace(merge_var_data=upload_var_data) - -download_files_url_prefix: Var = Var.create_safe( - "${getBackendURL(env.DOWNLOAD)}" -)._replace(merge_var_data=upload_var_data) +)._replace( + merge_var_data=VarData( # type: ignore + imports={ + f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, + "/env.json": {imports.ImportVar(tag="env", is_default=True)}, + } + ) +) def get_upload_url(file_path: str, download: bool = False) -> Var[str]: @@ -145,9 +141,6 @@ def get_upload_url(file_path: str, download: bool = False) -> Var[str]: """ Upload.is_used = True - if download: - return Var.create_safe(f"{download_files_url_prefix}/{file_path}") - return Var.create_safe(f"{uploaded_files_url_prefix}/{file_path}") diff --git a/reflex/components/core/upload.pyi b/reflex/components/core/upload.pyi index 23238b9fa..1f9d16d73 100644 --- a/reflex/components/core/upload.pyi +++ b/reflex/components/core/upload.pyi @@ -38,9 +38,7 @@ def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ... def cancel_upload(upload_id: str) -> EventSpec: ... def get_upload_dir() -> Path: ... -upload_var_data: VarData uploaded_files_url_prefix: Var -download_files_url_prefix: Var def get_upload_url(file_path: str, download: bool = False) -> Var[str]: ... diff --git a/reflex/staticfiles.py b/reflex/staticfiles.py index 83bf4c994..d929eec79 100644 --- a/reflex/staticfiles.py +++ b/reflex/staticfiles.py @@ -1,11 +1,12 @@ """Classes for staticfiles served by Reflex backend.""" from fastapi.staticfiles import StaticFiles +from starlette.requests import Request from starlette.responses import Response from starlette.types import Scope -class DownloadFiles(StaticFiles): +class UploadedFiles(StaticFiles): """Static files with download headers.""" async def get_response(self, path: str, scope: Scope) -> Response: @@ -18,6 +19,12 @@ class DownloadFiles(StaticFiles): Returns: The response for the static file with download headers. """ + req = Request(scope) + if "filename" in req.query_params: + filename = req.query_params["filename"] + content_disposition = f'attachment; filename="{filename}"' + else: + content_disposition = "attachment" response = await super().get_response(path, scope) - response.headers["Content-Disposition"] = "attachment" + response.headers["Content-Disposition"] = content_disposition return response