fix download for uploaded files

This commit is contained in:
Lendemor 2024-04-08 18:24:27 +02:00
parent 67702e0927
commit 7a97cf3672
7 changed files with 66 additions and 17 deletions

View File

@ -162,7 +162,10 @@ export const applyEvent = async (event, socket) => {
const a = document.createElement("a"); const a = document.createElement("a");
a.hidden = true; a.hidden = true;
// Special case when linking to uploaded files // Special case when linking to uploaded files
a.href = event.payload.url.replace("${getBackendURL(env.UPLOAD)}", getBackendURL(env.UPLOAD)) a.href = event.payload.url.replace(
"${getBackendURL(env.UPLOAD)}",
getBackendURL(env.UPLOAD)
);
a.download = event.payload.filename; a.download = event.payload.filename;
a.click(); a.click();
a.remove(); a.remove();
@ -647,11 +650,11 @@ export const useEventLoop = (
// Route after the initial page hydration. // Route after the initial page hydration.
useEffect(() => { useEffect(() => {
const change_start = () => { const change_start = () => {
const main_state_dispatch = dispatch["state"] const main_state_dispatch = dispatch["state"];
if (main_state_dispatch !== undefined) { if (main_state_dispatch !== undefined) {
main_state_dispatch({is_hydrated: false}) main_state_dispatch({ is_hydrated: false });
}
} }
};
const change_complete = () => addEvents(onLoadInternalEvent()); const change_complete = () => addEvents(onLoadInternalEvent());
router.events.on("routeChangeStart", change_start); router.events.on("routeChangeStart", change_start);
router.events.on("routeChangeComplete", change_complete); router.events.on("routeChangeComplete", change_complete);

View File

@ -77,6 +77,7 @@ from reflex.state import (
_substate_key, _substate_key,
code_uses_state_contexts, code_uses_state_contexts,
) )
from reflex.staticfiles import DownloadFiles
from reflex.utils import console, exceptions, format, prerequisites, types from reflex.utils import console, exceptions, format, prerequisites, types
from reflex.utils.exec import is_testing_env, should_skip_compile from reflex.utils.exec import is_testing_env, should_skip_compile
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
@ -271,17 +272,23 @@ class App(Base):
self.api.get(str(constants.Endpoint.PING))(ping) self.api.get(str(constants.Endpoint.PING))(ping)
def add_optional_endpoints(self): def add_optional_endpoints(self):
"""Add optional api endpoints (_upload).""" """Add optional api endpoints (_upload and _download)."""
# To upload files. # To upload files.
if Upload.is_used: if Upload.is_used:
self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) self.api.post(str(constants.Endpoint.UPLOAD))(upload(self))
# To access uploaded files. # To access uploaded files as assets.
self.api.mount( self.api.mount(
str(constants.Endpoint.UPLOAD), str(constants.Endpoint.UPLOAD),
StaticFiles(directory=get_upload_dir()), StaticFiles(directory=get_upload_dir()),
name="uploaded_files", name="uploaded_files",
) )
# To download uploaded files.
self.api.mount(
str(constants.Endpoint.DOWNLOAD),
DownloadFiles(directory=get_upload_dir()),
name="download_files",
)
def add_cors(self): def add_cors(self):
"""Add CORS middleware to the app.""" """Add CORS middleware to the app."""

View File

@ -117,29 +117,37 @@ def get_upload_dir() -> Path:
return uploaded_files_dir return uploaded_files_dir
uploaded_files_url_prefix: Var = Var.create_safe( upload_var_data: VarData = VarData( # type: ignore
"${getBackendURL(env.UPLOAD)}"
)._replace(
merge_var_data=VarData( # type: ignore
imports={ imports={
f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")},
"/env.json": {imports.ImportVar(tag="env", is_default=True)}, "/env.json": {imports.ImportVar(tag="env", is_default=True)},
} }
) )
)
uploaded_files_url_prefix: Var = Var.create_safe(
"${getBackendURL(env.UPLOAD)}"
)._replace(merge_var_data=upload_var_data)
download_files_url_prefix: Var = Var.create_safe(
"${getBackendURL(env.DOWNLOAD)}"
)._replace(merge_var_data=upload_var_data)
def get_upload_url(file_path: str) -> Var[str]: def get_upload_url(file_path: str, download: bool = False) -> Var[str]:
"""Get the URL of an uploaded file. """Get the URL of an uploaded file.
Args: Args:
file_path: The path of the uploaded file. file_path: The path of the uploaded file.
download: Whether to get the download URL instead of the upload URL.
Returns: Returns:
The URL of the uploaded file to be rendered from the frontend (as a str-encoded Var). The URL of the uploaded file to be rendered from the frontend (as a str-encoded Var).
""" """
Upload.is_used = True Upload.is_used = True
if download:
return Var.create_safe(f"{download_files_url_prefix}/{file_path}")
return Var.create_safe(f"{uploaded_files_url_prefix}/{file_path}") return Var.create_safe(f"{uploaded_files_url_prefix}/{file_path}")

View File

@ -38,9 +38,11 @@ def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...
def cancel_upload(upload_id: str) -> EventSpec: ... def cancel_upload(upload_id: str) -> EventSpec: ...
def get_upload_dir() -> Path: ... def get_upload_dir() -> Path: ...
upload_var_data: VarData
uploaded_files_url_prefix: Var uploaded_files_url_prefix: Var
download_files_url_prefix: Var
def get_upload_url(file_path: str) -> Var[str]: ... def get_upload_url(file_path: str, download: bool = False) -> Var[str]: ...
class UploadFilesProvider(Component): class UploadFilesProvider(Component):
@overload @overload

View File

@ -1,4 +1,5 @@
"""Event-related constants.""" """Event-related constants."""
from enum import Enum from enum import Enum
from types import SimpleNamespace from types import SimpleNamespace
@ -9,6 +10,7 @@ class Endpoint(Enum):
PING = "ping" PING = "ping"
EVENT = "_event" EVENT = "_event"
UPLOAD = "_upload" UPLOAD = "_upload"
DOWNLOAD = "_download"
def __str__(self) -> str: def __str__(self) -> str:
"""Get the string representation of the endpoint. """Get the string representation of the endpoint.

View File

@ -585,6 +585,10 @@ def download(
if filename is None: if filename is None:
filename = url.rpartition("/")[-1] filename = url.rpartition("/")[-1]
elif isinstance(url, Var):
# May need additional check here to only impact get_upload_url Vars ?
url._var_is_string = True
if filename is None: if filename is None:
filename = "" filename = ""

23
reflex/staticfiles.py Normal file
View File

@ -0,0 +1,23 @@
"""Classes for staticfiles served by Reflex backend."""
from fastapi.staticfiles import StaticFiles
from starlette.responses import Response
from starlette.types import Scope
class DownloadFiles(StaticFiles):
"""Static files with download headers."""
async def get_response(self, path: str, scope: Scope) -> Response:
"""Get the response for a static file with download headers.
Args:
path: The path of the static file.
scope: The request scope.
Returns:
The response for the static file with download headers.
"""
response = await super().get_response(path, scope)
response.headers["Content-Disposition"] = "attachment"
return response