diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 95f7c5a05..aa60c02b7 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -162,7 +162,10 @@ export const applyEvent = async (event, socket) => { const a = document.createElement("a"); a.hidden = true; // Special case when linking to uploaded files - a.href = event.payload.url.replace("${getBackendURL(env.UPLOAD)}", getBackendURL(env.UPLOAD)) + a.href = event.payload.url.replace( + "${getBackendURL(env.UPLOAD)}", + getBackendURL(env.UPLOAD) + ); a.download = event.payload.filename; a.click(); a.remove(); @@ -647,11 +650,11 @@ export const useEventLoop = ( // Route after the initial page hydration. useEffect(() => { const change_start = () => { - const main_state_dispatch = dispatch["state"] + const main_state_dispatch = dispatch["state"]; if (main_state_dispatch !== undefined) { - main_state_dispatch({is_hydrated: false}) + main_state_dispatch({ is_hydrated: false }); } - } + }; const change_complete = () => addEvents(onLoadInternalEvent()); router.events.on("routeChangeStart", change_start); router.events.on("routeChangeComplete", change_complete); diff --git a/reflex/app.py b/reflex/app.py index 50000e507..82c528fb2 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -77,6 +77,7 @@ from reflex.state import ( _substate_key, code_uses_state_contexts, ) +from reflex.staticfiles import DownloadFiles from reflex.utils import console, exceptions, format, prerequisites, types from reflex.utils.exec import is_testing_env, should_skip_compile from reflex.utils.imports import ImportVar @@ -271,17 +272,23 @@ class App(Base): self.api.get(str(constants.Endpoint.PING))(ping) def add_optional_endpoints(self): - """Add optional api endpoints (_upload).""" + """Add optional api endpoints (_upload and _download).""" # To upload files. if Upload.is_used: self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) - # To access uploaded files. + # To access uploaded files as assets. self.api.mount( str(constants.Endpoint.UPLOAD), StaticFiles(directory=get_upload_dir()), name="uploaded_files", ) + # To download uploaded files. + self.api.mount( + str(constants.Endpoint.DOWNLOAD), + DownloadFiles(directory=get_upload_dir()), + name="download_files", + ) def add_cors(self): """Add CORS middleware to the app.""" diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index dfc62bc0f..adeb64cad 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -117,29 +117,37 @@ def get_upload_dir() -> Path: return uploaded_files_dir -uploaded_files_url_prefix: Var = Var.create_safe( - "${getBackendURL(env.UPLOAD)}" -)._replace( - merge_var_data=VarData( # type: ignore - imports={ - f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, - "/env.json": {imports.ImportVar(tag="env", is_default=True)}, - } - ) +upload_var_data: VarData = VarData( # type: ignore + imports={ + f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, + "/env.json": {imports.ImportVar(tag="env", is_default=True)}, + } ) +uploaded_files_url_prefix: Var = Var.create_safe( + "${getBackendURL(env.UPLOAD)}" +)._replace(merge_var_data=upload_var_data) -def get_upload_url(file_path: str) -> Var[str]: +download_files_url_prefix: Var = Var.create_safe( + "${getBackendURL(env.DOWNLOAD)}" +)._replace(merge_var_data=upload_var_data) + + +def get_upload_url(file_path: str, download: bool = False) -> Var[str]: """Get the URL of an uploaded file. Args: file_path: The path of the uploaded file. + download: Whether to get the download URL instead of the upload URL. Returns: The URL of the uploaded file to be rendered from the frontend (as a str-encoded Var). """ Upload.is_used = True + if download: + return Var.create_safe(f"{download_files_url_prefix}/{file_path}") + return Var.create_safe(f"{uploaded_files_url_prefix}/{file_path}") diff --git a/reflex/components/core/upload.pyi b/reflex/components/core/upload.pyi index b8387e696..23238b9fa 100644 --- a/reflex/components/core/upload.pyi +++ b/reflex/components/core/upload.pyi @@ -38,9 +38,11 @@ def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ... def cancel_upload(upload_id: str) -> EventSpec: ... def get_upload_dir() -> Path: ... +upload_var_data: VarData uploaded_files_url_prefix: Var +download_files_url_prefix: Var -def get_upload_url(file_path: str) -> Var[str]: ... +def get_upload_url(file_path: str, download: bool = False) -> Var[str]: ... class UploadFilesProvider(Component): @overload diff --git a/reflex/constants/event.py b/reflex/constants/event.py index aa6f8c713..4797a65ca 100644 --- a/reflex/constants/event.py +++ b/reflex/constants/event.py @@ -1,4 +1,5 @@ """Event-related constants.""" + from enum import Enum from types import SimpleNamespace @@ -9,6 +10,7 @@ class Endpoint(Enum): PING = "ping" EVENT = "_event" UPLOAD = "_upload" + DOWNLOAD = "_download" def __str__(self) -> str: """Get the string representation of the endpoint. diff --git a/reflex/event.py b/reflex/event.py index 19fdad5e9..b46badead 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -585,6 +585,10 @@ def download( if filename is None: filename = url.rpartition("/")[-1] + elif isinstance(url, Var): + # May need additional check here to only impact get_upload_url Vars ? + url._var_is_string = True + if filename is None: filename = "" diff --git a/reflex/staticfiles.py b/reflex/staticfiles.py new file mode 100644 index 000000000..83bf4c994 --- /dev/null +++ b/reflex/staticfiles.py @@ -0,0 +1,23 @@ +"""Classes for staticfiles served by Reflex backend.""" + +from fastapi.staticfiles import StaticFiles +from starlette.responses import Response +from starlette.types import Scope + + +class DownloadFiles(StaticFiles): + """Static files with download headers.""" + + async def get_response(self, path: str, scope: Scope) -> Response: + """Get the response for a static file with download headers. + + Args: + path: The path of the static file. + scope: The request scope. + + Returns: + The response for the static file with download headers. + """ + response = await super().get_response(path, scope) + response.headers["Content-Disposition"] = "attachment" + return response