diff --git a/integration/test_call_script.py b/integration/test_call_script.py index 97f995487..3aea81e10 100644 --- a/integration/test_call_script.py +++ b/integration/test_call_script.py @@ -312,7 +312,7 @@ def test_call_script( update_counter_button.click() assert call_script.poll_for_value(counter, exp_not_equal="0") == "4" reset_button.click() - assert call_script.poll_for_value(counter, exp_not_equal="3") == "0" + assert call_script.poll_for_value(counter, exp_not_equal="4") == "0" return_button.click() update_counter_button.click() assert call_script.poll_for_value(counter, exp_not_equal="0") == "1" @@ -330,7 +330,7 @@ def test_call_script( script, ) reset_button.click() - assert call_script.poll_for_value(counter, exp_not_equal="3") == "0" + assert call_script.poll_for_value(counter, exp_not_equal="4") == "0" return_callback_button.click() update_counter_button.click() diff --git a/reflex/app.py b/reflex/app.py index 7d5e449bf..72632bc15 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -27,6 +27,7 @@ from typing import ( from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors from fastapi.responses import StreamingResponse +from fastapi.staticfiles import StaticFiles from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer from starlette_admin.contrib.sqla.admin import Admin @@ -46,7 +47,7 @@ from reflex.components.core.client_side_routing import ( Default404Page, wait_for_client_redirect, ) -from reflex.components.core.upload import Upload +from reflex.components.core.upload import Upload, get_uploaded_files_dir from reflex.components.radix import themes from reflex.config import get_config from reflex.event import Event, EventHandler, EventSpec @@ -252,6 +253,13 @@ class App(Base): if Upload.is_used: self.api.post(str(constants.Endpoint.UPLOAD))(upload(self)) + # To access uploaded files. + self.api.mount( + str(constants.Endpoint.UPLOAD), + StaticFiles(directory=get_uploaded_files_dir()), + name="uploaded_files", + ) + def add_cors(self): """Add CORS middleware to the app.""" self.api.add_middleware( diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index 489a5648a..aadb6fc77 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -1,6 +1,8 @@ """A file upload component.""" from __future__ import annotations +import os +from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Union from reflex import constants @@ -92,6 +94,43 @@ def cancel_upload(upload_id: str) -> EventSpec: return call_script(f"upload_controllers[{upload_id!r}]?.abort()") +def get_uploaded_files_dir() -> Path: + """Get the directory where uploaded files are stored. + + Returns: + The directory where uploaded files are stored. + """ + uploaded_files_dir = Path( + os.environ.get("REFLEX_UPLOADED_FILES_DIR", "./uploaded_files") + ) + uploaded_files_dir.mkdir(parents=True, exist_ok=True) + return uploaded_files_dir + + +uploaded_files_url_prefix: Var = Var.create_safe( + "${getBackendURL(env.UPLOAD)}" +)._replace( + merge_var_data=VarData( # type: ignore + imports={ + f"/{Dirs.STATE_PATH}": {imports.ImportVar(tag="getBackendURL")}, + "/env.json": {imports.ImportVar(tag="env", is_default=True)}, + } + ) +) + + +def get_uploaded_file_url(file_path: str) -> str: + """Get the URL of an uploaded file. + + Args: + file_path: The path of the uploaded file. + + Returns: + The URL of the uploaded file to be rendered from the frontend (as a str-encoded Var). + """ + return f"{uploaded_files_url_prefix}/{file_path}" + + class UploadFilesProvider(Component): """AppWrap component that provides a dict of selected files by ID via useContext.""" diff --git a/reflex/components/core/upload.pyi b/reflex/components/core/upload.pyi index a7d82ada1..f11d1ea10 100644 --- a/reflex/components/core/upload.pyi +++ b/reflex/components/core/upload.pyi @@ -7,6 +7,8 @@ from typing import Any, Dict, Literal, Optional, Union, overload from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style +import os +from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Union from reflex import constants from reflex.components.chakra.forms.input import Input @@ -27,6 +29,11 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ... @CallableEventSpec def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ... def cancel_upload(upload_id: str) -> EventSpec: ... +def get_uploaded_files_dir() -> Path: ... + +uploaded_files_url_prefix: Var + +def get_uploaded_file_url(file_path: str) -> str: ... class UploadFilesProvider(Component): @overload