diff --git a/integration/test_upload.py b/integration/test_upload.py index 9639fbe8e..648c68be5 100644 --- a/integration/test_upload.py +++ b/integration/test_upload.py @@ -1,13 +1,14 @@ """Integration tests for file upload.""" from __future__ import annotations +import asyncio import time from typing import Generator import pytest from selenium.webdriver.common.by import By -from reflex.testing import AppHarness +from reflex.testing import AppHarness, WebDriver def UploadFile(): @@ -16,12 +17,28 @@ def UploadFile(): class UploadState(rx.State): _file_data: dict[str, str] = {} + event_order: list[str] = [] + progress_dicts: list[dict] = [] async def handle_upload(self, files: list[rx.UploadFile]): for file in files: upload_data = await file.read() self._file_data[file.filename or ""] = upload_data.decode("utf-8") + async def handle_upload_secondary(self, files: list[rx.UploadFile]): + for file in files: + upload_data = await file.read() + self._file_data[file.filename or ""] = upload_data.decode("utf-8") + yield UploadState.chain_event + + def upload_progress(self, progress): + assert progress + self.event_order.append("upload_progress") + self.progress_dicts.append(progress) + + def chain_event(self): + self.event_order.append("chain_event") + def index(): return rx.vstack( rx.input( @@ -29,6 +46,7 @@ def UploadFile(): is_read_only=True, id="token", ), + rx.heading("Default Upload"), rx.upload( rx.vstack( rx.button("Select File"), @@ -52,6 +70,47 @@ def UploadFile(): on_click=rx.clear_selected_files, id="clear_button", ), + rx.heading("Secondary Upload"), + rx.upload( + rx.vstack( + rx.button("Select File"), + rx.text("Drag and drop files here or click to select files"), + ), + id="secondary", + ), + rx.button( + "Upload", + on_click=UploadState.handle_upload_secondary( # type: ignore + rx.upload_files( + upload_id="secondary", + on_upload_progress=UploadState.upload_progress, + ), + ), + id="upload_button_secondary", + ), + rx.box( + rx.foreach( + rx.selected_files("secondary"), + lambda f: rx.text(f), + ), + id="selected_files_secondary", + ), + rx.button( + "Clear", + on_click=rx.clear_selected_files("secondary"), + id="clear_button_secondary", + ), + rx.vstack( + rx.foreach( + UploadState.progress_dicts, # type: ignore + lambda d: rx.text(d.to_string()), + ) + ), + rx.button( + "Cancel", + on_click=rx.cancel_upload("secondary"), + id="cancel_button_secondary", + ), ) app = rx.App(state=UploadState) @@ -94,14 +153,18 @@ def driver(upload_file: AppHarness): driver.quit() +@pytest.mark.parametrize("secondary", [False, True]) @pytest.mark.asyncio -async def test_upload_file(tmp_path, upload_file: AppHarness, driver): +async def test_upload_file( + tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool +): """Submit a file upload and check that it arrived on the backend. Args: tmp_path: pytest tmp_path fixture upload_file: harness for UploadFile app. driver: WebDriver instance. + secondary: whether to use the secondary upload form """ assert upload_file.app_instance is not None token_input = driver.find_element(By.ID, "token") @@ -110,9 +173,13 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver): token = upload_file.poll_for_value(token_input) assert token is not None - upload_box = driver.find_element(By.XPATH, "//input[@type='file']") + suffix = "_secondary" if secondary else "" + + upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[ + 1 if secondary else 0 + ] assert upload_box - upload_button = driver.find_element(By.ID, "upload_button") + upload_button = driver.find_element(By.ID, f"upload_button{suffix}") assert upload_button exp_name = "test.txt" @@ -132,9 +199,15 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver): assert file_data[exp_name] == exp_contents # check that the selected files are displayed - selected_files = driver.find_element(By.ID, "selected_files") + selected_files = driver.find_element(By.ID, f"selected_files{suffix}") assert selected_files.text == exp_name + state = await upload_file.get_state(token) + if secondary: + # only the secondary form tracks progress and chain events + assert state.event_order.count("upload_progress") == 1 + assert state.event_order.count("chain_event") == 1 + @pytest.mark.asyncio async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): @@ -186,13 +259,17 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): assert file_data[exp_name] == exp_contents -def test_clear_files(tmp_path, upload_file: AppHarness, driver): +@pytest.mark.parametrize("secondary", [False, True]) +def test_clear_files( + tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool +): """Select then clear several file uploads and check that they are cleared. Args: tmp_path: pytest tmp_path fixture upload_file: harness for UploadFile app. driver: WebDriver instance. + secondary: whether to use the secondary upload form. """ assert upload_file.app_instance is not None token_input = driver.find_element(By.ID, "token") @@ -201,9 +278,13 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver): token = upload_file.poll_for_value(token_input) assert token is not None - upload_box = driver.find_element(By.XPATH, "//input[@type='file']") + suffix = "_secondary" if secondary else "" + + upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[ + 1 if secondary else 0 + ] assert upload_box - upload_button = driver.find_element(By.ID, "upload_button") + upload_button = driver.find_element(By.ID, f"upload_button{suffix}") assert upload_button exp_files = { @@ -219,13 +300,56 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver): time.sleep(0.2) # check that the selected files are displayed - selected_files = driver.find_element(By.ID, "selected_files") + selected_files = driver.find_element(By.ID, f"selected_files{suffix}") assert selected_files.text == "\n".join(exp_files) - clear_button = driver.find_element(By.ID, "clear_button") + clear_button = driver.find_element(By.ID, f"clear_button{suffix}") assert clear_button clear_button.click() # check that the selected files are cleared - selected_files = driver.find_element(By.ID, "selected_files") + selected_files = driver.find_element(By.ID, f"selected_files{suffix}") assert selected_files.text == "" + + +# TODO: drag and drop directory +# https://gist.github.com/florentbr/349b1ab024ca9f3de56e6bf8af2ac69e + + +@pytest.mark.asyncio +async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDriver): + """Submit a large file upload and cancel it. + + Args: + tmp_path: pytest tmp_path fixture + upload_file: harness for UploadFile app. + driver: WebDriver instance. + """ + assert upload_file.app_instance is not None + token_input = driver.find_element(By.ID, "token") + assert token_input + # wait for the backend connection to send the token + token = upload_file.poll_for_value(token_input) + assert token is not None + + upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1] + upload_button = driver.find_element(By.ID, f"upload_button_secondary") + cancel_button = driver.find_element(By.ID, f"cancel_button_secondary") + + exp_name = "large.txt" + target_file = tmp_path / exp_name + with target_file.open("wb") as f: + f.seek(1024 * 1024 * 256) + f.write(b"0") + + upload_box.send_keys(str(target_file)) + upload_button.click() + await asyncio.sleep(0.3) + cancel_button.click() + + # look up the backend state and assert on progress + state = await upload_file.get_state(token) + assert state.progress_dicts + assert exp_name not in state._file_data + + target_file.unlink() diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 0243ce31a..22c853dea 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -32,6 +32,11 @@ let event_processing = false // Array holding pending events to be processed. const event_queue = []; +// Pending upload promises, by id +const upload_controllers = {}; +// Upload files state by id +export const upload_files = {}; + /** * Generate a UUID (Used for session tokens). * Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid @@ -235,14 +240,22 @@ export const applyEvent = async (event, socket) => { /** * Send an event to the server via REST. * @param event The current event. - * @param state The state with the event queue. + * @param socket The socket object to send the response event(s) on. * * @returns Whether the event was sent. */ -export const applyRestEvent = async (event) => { +export const applyRestEvent = async (event, socket) => { let eventSent = false; if (event.handler == "uploadFiles") { - eventSent = await uploadFiles(event.name, event.payload.files); + // Start upload, but do not wait for it, which would block other events. + uploadFiles( + event.name, + event.payload.files, + event.payload.upload_id, + event.payload.on_upload_progress, + socket + ); + return false; } return eventSent; }; @@ -283,7 +296,7 @@ export const processEvent = async ( let eventSent = false // Process events with handlers via REST and all others via websockets. if (event.handler) { - eventSent = await applyRestEvent(event); + eventSent = await applyRestEvent(event, socket); } else { eventSent = await applyEvent(event, socket); } @@ -347,50 +360,86 @@ export const connect = async ( * * @param state The state to apply the delta to. * @param handler The handler to use. + * @param upload_id The upload id to use. + * @param on_upload_progress The function to call on upload progress. + * @param socket the websocket connection * - * @returns Whether the files were uploaded. + * @returns The response from posting to the UPLOADURL endpoint. */ -export const uploadFiles = async (handler, files) => { +export const uploadFiles = async (handler, files, upload_id, on_upload_progress, socket) => { // return if there's no file to upload if (files.length == 0) { return false; } - const headers = { - "Content-Type": files[0].type, - }; + if (upload_controllers[upload_id]) { + console.log("Upload already in progress for ", upload_id) + return false; + } + + let resp_idx = 0; + const eventHandler = (progressEvent) => { + // handle any delta / event streamed from the upload event handler + const chunks = progressEvent.event.target.responseText.trim().split("\n") + chunks.slice(resp_idx).map((chunk) => { + try { + socket._callbacks.$event.map((f) => { + f(chunk) + }) + resp_idx += 1 + } catch (e) { + console.log("Error parsing chunk", chunk, e) + return + } + }) + } + + const controller = new AbortController() + const config = { + headers: { + "Reflex-Client-Token": getToken(), + "Reflex-Event-Handler": handler, + }, + signal: controller.signal, + onDownloadProgress: eventHandler, + } + if (on_upload_progress) { + config["onUploadProgress"] = on_upload_progress + } const formdata = new FormData(); // Add the token and handler to the file name. - for (let i = 0; i < files.length; i++) { + files.forEach((file) => { formdata.append( "files", - files[i], - getToken() + ":" + handler + ":" + files[i].name + file, + file.path || file.name ); - } + }) // Send the file to the server. - await axios.post(UPLOADURL, formdata, headers) - .then(() => { return true; }) - .catch( - error => { - if (error.response) { - // The request was made and the server responded with a status code - // that falls out of the range of 2xx - console.log(error.response.data); - } else if (error.request) { - // The request was made but no response was received - // `error.request` is an instance of XMLHttpRequest in the browser and an instance of - // http.ClientRequest in node.js - console.log(error.request); - } else { - // Something happened in setting up the request that triggered an Error - console.log(error.message); - } - return false; - } - ) + upload_controllers[upload_id] = controller + + try { + return await axios.post(UPLOADURL, formdata, config) + } catch (error) { + if (error.response) { + // The request was made and the server responded with a status code + // that falls out of the range of 2xx + console.log(error.response.data); + } else if (error.request) { + // The request was made but no response was received + // `error.request` is an instance of XMLHttpRequest in the browser and an instance of + // http.ClientRequest in node.js + console.log(error.request); + } else { + // Something happened in setting up the request that triggered an Error + console.log(error.message); + } + return false; + } finally { + delete upload_controllers[upload_id] + } }; /** diff --git a/reflex/__init__.py b/reflex/__init__.py index 03b13c527..dd11b73f4 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -229,6 +229,7 @@ _ALL_COMPONENTS = [ _ALL_COMPONENTS += [to_snake_case(component) for component in _ALL_COMPONENTS] _ALL_COMPONENTS += [ + "cancel_upload", "components", "color_mode_cond", "desktop_only", diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 800645220..1b5d64d9b 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -58,6 +58,9 @@ from reflex.components import ConnectionModal as ConnectionModal from reflex.components import Container as Container from reflex.components import DataTable as DataTable from reflex.components import DataEditor as DataEditor +from reflex.components import DataEditorTheme as DataEditorTheme +from reflex.components import DatePicker as DatePicker +from reflex.components import DateTimePicker as DateTimePicker from reflex.components import DebounceInput as DebounceInput from reflex.components import Divider as Divider from reflex.components import Drawer as Drawer @@ -265,6 +268,9 @@ from reflex.components import connection_modal as connection_modal from reflex.components import container as container from reflex.components import data_table as data_table from reflex.components import data_editor as data_editor +from reflex.components import data_editor_theme as data_editor_theme +from reflex.components import date_picker as date_picker +from reflex.components import date_time_picker as date_time_picker from reflex.components import debounce_input as debounce_input from reflex.components import divider as divider from reflex.components import drawer as drawer @@ -421,7 +427,9 @@ from reflex.components import visually_hidden as visually_hidden from reflex.components import vstack as vstack from reflex.components import wrap as wrap from reflex.components import wrap_item as wrap_item +from reflex.components import cancel_upload as cancel_upload from reflex import components as components +from reflex.components import color_mode_cond as color_mode_cond from reflex.components import desktop_only as desktop_only from reflex.components import mobile_only as mobile_only from reflex.components import tablet_only as tablet_only @@ -429,7 +437,9 @@ from reflex.components import mobile_and_tablet as mobile_and_tablet from reflex.components import tablet_and_desktop as tablet_and_desktop from reflex.components import selected_files as selected_files from reflex.components import clear_selected_files as clear_selected_files +from reflex.components import EditorButtonList as EditorButtonList from reflex.components import EditorOptions as EditorOptions +from reflex.components import NoSSRComponent as NoSSRComponent from reflex.components.component import memo as memo from reflex.components.graphing import recharts as recharts from reflex import config as config diff --git a/reflex/app.py b/reflex/app.py index ab53672a5..477d06511 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio import contextlib -import inspect +import functools import os from multiprocessing.pool import ThreadPool from typing import ( @@ -17,10 +17,13 @@ from typing import ( Set, Type, Union, + get_args, + get_type_hints, ) -from fastapi import FastAPI, UploadFile +from fastapi import FastAPI, HTTPException, Request, UploadFile from fastapi.middleware import cors +from fastapi.responses import StreamingResponse from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp, AsyncNamespace, AsyncServer from starlette_admin.contrib.sqla.admin import Admin @@ -880,62 +883,90 @@ def upload(app: App): The upload function. """ - async def upload_file(files: List[UploadFile]): + async def upload_file(request: Request, files: List[UploadFile]): """Upload a file. Args: + request: The FastAPI request object. files: The file(s) to upload. + Returns: + StreamingResponse yielding newline-delimited JSON of StateUpdate + emitted by the upload handler. + Raises: ValueError: if there are no args with supported annotation. + TypeError: if a background task is used as the handler. + HTTPException: when the request does not include token / handler headers. """ - assert files[0].filename is not None - token, handler = files[0].filename.split(":")[:2] - for file in files: - assert file.filename is not None - file.filename = file.filename.split(":")[-1] + token = request.headers.get("reflex-client-token") + handler = request.headers.get("reflex-event-handler") + + if not token or not handler: + raise HTTPException( + status_code=400, + detail="Missing reflex-client-token or reflex-event-handler header.", + ) # Get the state for the session. - async with app.state_manager.modify_state(token) as state: - # get the current session ID - sid = state.router.session.session_id - # get the current state(parent state/substate) - path = handler.split(".")[:-1] - current_state = state.get_substate(path) - handler_upload_param = () + state = await app.state_manager.get_state(token) - # get handler function - func = getattr(current_state, handler.split(".")[-1]) + # get the current session ID + # get the current state(parent state/substate) + path = handler.split(".")[:-1] + current_state = state.get_substate(path) + handler_upload_param = () - # check if there exists any handler args with annotation, List[UploadFile] - for k, v in inspect.getfullargspec( - func.fn if isinstance(func, EventHandler) else func - ).annotations.items(): - if types.is_generic_alias(v) and types._issubclass( - v.__args__[0], UploadFile - ): - handler_upload_param = (k, v) - break + # get handler function + func = getattr(type(current_state), handler.split(".")[-1]) - if not handler_upload_param: - raise ValueError( - f"`{handler}` handler should have a parameter annotated as List[" - f"rx.UploadFile]" + # check if there exists any handler args with annotation, List[UploadFile] + if isinstance(func, EventHandler): + if func.is_background: + raise TypeError( + f"@rx.background is not supported for upload handler `{handler}`.", ) + func = func.fn + if isinstance(func, functools.partial): + func = func.func + for k, v in get_type_hints(func).items(): + if types.is_generic_alias(v) and types._issubclass( + get_args(v)[0], + UploadFile, + ): + handler_upload_param = (k, v) + break - event = Event( - token=token, - name=handler, - payload={handler_upload_param[0]: files}, + if not handler_upload_param: + raise ValueError( + f"`{handler}` handler should have a parameter annotated as " + "List[rx.UploadFile]" ) - async for update in state._process(event): - # Postprocess the event. - update = await app.postprocess(state, event, update) - # Send update to client - await app.event_namespace.emit_update( # type: ignore - update=update, - sid=sid, - ) + + event = Event( + token=token, + name=handler, + payload={handler_upload_param[0]: files}, + ) + + async def _ndjson_updates(): + """Process the upload event, generating ndjson updates. + + Yields: + Each state update as JSON followed by a new line. + """ + # Process the event. + async with app.state_manager.modify_state(token) as state: + async for update in state._process(event): + # Postprocess the event. + update = await app.postprocess(state, event, update) + yield update.json() + "\n" + + # Stream updates to client + return StreamingResponse( + _ndjson_updates(), + media_type="application/x-ndjson", + ) return upload_file diff --git a/reflex/components/forms/__init__.py b/reflex/components/forms/__init__.py index d5ff4a164..05b7a2c12 100644 --- a/reflex/components/forms/__init__.py +++ b/reflex/components/forms/__init__.py @@ -46,12 +46,18 @@ from .select import Option, Select from .slider import Slider, SliderFilledTrack, SliderMark, SliderThumb, SliderTrack from .switch import Switch from .textarea import TextArea -from .upload import Upload, clear_selected_files, selected_files +from .upload import ( + Upload, + cancel_upload, + clear_selected_files, + selected_files, +) helpers = [ "color_mode_cond", - "selected_files", + "cancel_upload", "clear_selected_files", + "selected_files", ] __all__ = [f for f in dir() if f[0].isupper()] + helpers # type: ignore diff --git a/reflex/components/forms/upload.py b/reflex/components/forms/upload.py index d1a91adc7..21457e64e 100644 --- a/reflex/components/forms/upload.py +++ b/reflex/components/forms/upload.py @@ -3,26 +3,75 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Union +from reflex import constants from reflex.components.component import Component from reflex.components.forms.input import Input from reflex.components.layout.box import Box -from reflex.constants import EventTriggers -from reflex.event import EventChain -from reflex.vars import BaseVar, Var +from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script +from reflex.utils import imports +from reflex.vars import BaseVar, CallableVar, ImportVar, Var -files_state: str = "const [files, setFiles] = useState([]);" -upload_file: BaseVar = BaseVar( - _var_name="e => setFiles((files) => e)", _var_type=EventChain -) +DEFAULT_UPLOAD_ID: str = "default" -# Use this var along with the Upload component to render the list of selected files. -selected_files: BaseVar = BaseVar( - _var_name="files.map((f) => f.name)", _var_type=List[str] -) -clear_selected_files: BaseVar = BaseVar( - _var_name="_e => setFiles((files) => [])", _var_type=EventChain -) +@CallableVar +def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: + """Get the file upload drop trigger. + + This var is passed to the dropzone component to update the file list when a + drop occurs. + + Args: + id_: The id of the upload to get the drop trigger for. + + Returns: + A var referencing the file upload drop trigger. + """ + return BaseVar( + _var_name=f"e => upload_files.{id_}[1]((files) => e)", + _var_type=EventChain, + ) + + +@CallableVar +def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: + """Get the list of selected files. + + Args: + id_: The id of the upload to get the selected files for. + + Returns: + A var referencing the list of selected file paths. + """ + return BaseVar( + _var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])", + _var_type=List[str], + ) + + +@CallableEventSpec +def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: + """Clear the list of selected files. + + Args: + id_: The id of the upload to clear. + + Returns: + An event spec that clears the list of selected files when triggered. + """ + return call_script(f"upload_files.{id_}[1]((files) => [])") + + +def cancel_upload(upload_id: str) -> EventSpec: + """Cancel an upload. + + Args: + upload_id: The id of the upload to cancel. + + Returns: + An event spec that cancels the upload when triggered. + """ + return call_script(f"upload_controllers[{upload_id!r}]?.abort()") class Upload(Component): @@ -94,7 +143,10 @@ class Upload(Component): zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)} # Create the component. - return super().create(zone, on_drop=upload_file, **upload_props) + upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID) + return super().create( + zone, on_drop=upload_file(upload_props["id"]), **upload_props + ) def get_event_triggers(self) -> dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. @@ -104,7 +156,7 @@ class Upload(Component): """ return { **super().get_event_triggers(), - EventTriggers.ON_DROP: lambda e0: [e0], + constants.EventTriggers.ON_DROP: lambda e0: [e0], } def _render(self): @@ -113,4 +165,15 @@ class Upload(Component): return out def _get_hooks(self) -> str | None: - return (super()._get_hooks() or "") + files_state + return ( + (super()._get_hooks() or "") + + f""" + upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]); + """ + ) + + def _get_imports(self) -> imports.ImportDict: + return { + **super()._get_imports(), + f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")}, + } diff --git a/reflex/components/forms/upload.pyi b/reflex/components/forms/upload.pyi index 7323807fb..87e5d58c7 100644 --- a/reflex/components/forms/upload.pyi +++ b/reflex/components/forms/upload.pyi @@ -8,17 +8,23 @@ from reflex.vars import Var, BaseVar, ComputedVar from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style from typing import Any, Dict, List, Optional, Union +from reflex import constants from reflex.components.component import Component from reflex.components.forms.input import Input from reflex.components.layout.box import Box -from reflex.constants import EventTriggers -from reflex.event import EventChain -from reflex.vars import BaseVar, Var +from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script +from reflex.utils import imports +from reflex.vars import BaseVar, CallableVar, ImportVar, Var -files_state: str -upload_file: BaseVar -selected_files: BaseVar -clear_selected_files: BaseVar +DEFAULT_UPLOAD_ID: str + +@CallableVar +def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ... +@CallableVar +def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ... +@CallableEventSpec +def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ... +def cancel_upload(upload_id: str) -> EventSpec: ... class Upload(Component): @overload diff --git a/reflex/event.py b/reflex/event.py index 3fa586a4a..9c0c8e13c 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -174,13 +174,7 @@ class EventHandler(EventActionsMixin): for arg in args: # Special case for file uploads. if isinstance(arg, FileUpload): - return EventSpec( - handler=self, - client_handler_name="uploadFiles", - # `files` is defined in the Upload component's _use_hooks - args=((Var.create_safe("files"), Var.create_safe("files")),), - event_actions=self.event_actions.copy(), - ) + return arg.as_event_spec(handler=self) # Otherwise, convert to JSON. try: @@ -236,6 +230,50 @@ class EventSpec(EventActionsMixin): ) +class CallableEventSpec(EventSpec): + """Decorate an EventSpec-returning function to act as both a EventSpec and a function. + + This is used as a compatibility shim for replacing EventSpec objects in the + API with functions that return a family of EventSpec. + """ + + fn: Optional[Callable[..., EventSpec]] = None + + def __init__(self, fn: Callable[..., EventSpec] | None = None, **kwargs): + """Initialize a CallableEventSpec. + + Args: + fn: The function to decorate. + **kwargs: The kwargs to pass to pydantic initializer + """ + if fn is not None: + default_event_spec = fn() + super().__init__( + fn=fn, # type: ignore + **default_event_spec.dict(), + **kwargs, + ) + else: + super().__init__(**kwargs) + + def __call__(self, *args, **kwargs) -> EventSpec: + """Call the decorated function. + + Args: + *args: The args to pass to the function. + **kwargs: The kwargs to pass to the function. + + Returns: + The EventSpec returned from calling the function. + + Raises: + TypeError: If the CallableEventSpec has no associated function. + """ + if self.fn is None: + raise TypeError("CallableEventSpec has no associated function.") + return self.fn(*args, **kwargs) + + class EventChain(EventActionsMixin): """Container for a chain of events that will be executed in order.""" @@ -267,7 +305,76 @@ class FrontendEvent(Base): class FileUpload(Base): """Class to represent a file upload.""" - pass + upload_id: Optional[str] = None + on_upload_progress: Optional[Union[EventHandler, Callable]] = None + + @staticmethod + def on_upload_progress_args_spec(_prog: dict[str, int | float | bool]): + """Args spec for on_upload_progress event handler. + + Returns: + The arg mapping passed to backend event handler + """ + return [_prog] + + def as_event_spec(self, handler: EventHandler) -> EventSpec: + """Get the EventSpec for the file upload. + + Args: + handler: The event handler. + + Returns: + The event spec for the handler. + + Raises: + ValueError: If the on_upload_progress is not a valid event handler. + """ + from reflex.components.forms.upload import DEFAULT_UPLOAD_ID + + upload_id = self.upload_id or DEFAULT_UPLOAD_ID + + spec_args = [ + # `upload_files` is defined in state.js and assigned in the Upload component's _use_hooks + (Var.create_safe("files"), Var.create_safe(f"upload_files.{upload_id}[0]")), + ( + Var.create_safe("upload_id"), + Var.create_safe(upload_id, _var_is_string=True), + ), + ] + if self.on_upload_progress is not None: + on_upload_progress = self.on_upload_progress + if isinstance(on_upload_progress, EventHandler): + events = [ + call_event_handler( + on_upload_progress, + self.on_upload_progress_args_spec, + ), + ] + elif isinstance(on_upload_progress, Callable): + # Call the lambda to get the event chain. + events = call_event_fn(on_upload_progress, self.on_upload_progress_args_spec) # type: ignore + else: + raise ValueError(f"{on_upload_progress} is not a valid event handler.") + on_upload_progress_chain = EventChain( + events=events, + args_spec=self.on_upload_progress_args_spec, + ) + formatted_chain = str(format.format_prop(on_upload_progress_chain)) + spec_args.append( + ( + Var.create_safe("on_upload_progress"), + BaseVar( + _var_name=formatted_chain.strip("{}"), + _var_type=EventChain, + ), + ), + ) + return EventSpec( + handler=handler, + client_handler_name="uploadFiles", + args=tuple(spec_args), + event_actions=handler.event_actions.copy(), + ) # Alias for rx.upload_files diff --git a/reflex/vars.py b/reflex/vars.py index 33d37ebb0..468dc2a24 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1590,3 +1590,33 @@ class NoRenderImportVar(ImportVar): """A import that doesn't need to be rendered.""" render: Optional[bool] = False + + +class CallableVar(BaseVar): + """Decorate a Var-returning function to act as both a Var and a function. + + This is used as a compatibility shim for replacing Var objects in the + API with functions that return a family of Var. + """ + + def __init__(self, fn: Callable[..., BaseVar]): + """Initialize a CallableVar. + + Args: + fn: The function to decorate (must return Var) + """ + self.fn = fn + default_var = fn() + super().__init__(**dataclasses.asdict(default_var)) + + def __call__(self, *args, **kwargs) -> BaseVar: + """Call the decorated function. + + Args: + *args: The args to pass to the function. + **kwargs: The kwargs to pass to the function. + + Returns: + The Var returned from calling the function. + """ + return self.fn(*args, **kwargs) diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 9df2f7ece..001ba9aa3 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -137,3 +137,7 @@ class NoRenderImportVar(ImportVar): """A import that doesn't need to be rendered.""" def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ... + +class CallableVar(BaseVar): + def __init__(self, fn: Callable[..., BaseVar]): ... + def __call__(self, *args, **kwargs) -> BaseVar: ... diff --git a/tests/components/forms/test_uploads.py b/tests/components/forms/test_uploads.py index d59d760bd..fee58ec3b 100644 --- a/tests/components/forms/test_uploads.py +++ b/tests/components/forms/test_uploads.py @@ -52,8 +52,10 @@ def test_upload_component_render(upload_component): # upload assert upload["name"] == "ReactDropzone" assert upload["props"] == [ + "id={`default`}", "multiple={true}", - "onDrop={e => setFiles((files) => e)}", + "onDrop={e => upload_files.default[1]((files) => e)}", + "ref={ref_default}", ] assert upload["args"] == ("getRootProps", "getInputProps") @@ -89,8 +91,10 @@ def test_upload_component_with_props_render(upload_component_with_props): upload = upload_component_with_props.render() assert upload["props"] == [ + "id={`default`}", "maxFiles={2}", "multiple={true}", "noDrag={true}", - "onDrop={e => setFiles((files) => e)}", + "onDrop={e => upload_files.default[1]((files) => e)}", + "ref={ref_default}", ] diff --git a/tests/states/upload.py b/tests/states/upload.py index 893947930..ec2585dd1 100644 --- a/tests/states/upload.py +++ b/tests/states/upload.py @@ -49,16 +49,7 @@ class FileUploadState(rx.State): Args: files: The uploaded files. """ - for file in files: - upload_data = await file.read() - outfile = f"{self._tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - self.img_list.append(file.filename) + pass async def multi_handle_upload(self, files: List[rx.UploadFile]): """Handle the upload of a file. @@ -78,6 +69,15 @@ class FileUploadState(rx.State): assert file.filename is not None self.img_list.append(file.filename) + @rx.background + async def bg_upload(self, files: List[rx.UploadFile]): + """Background task cannot be upload handler. + + Args: + files: The uploaded files. + """ + pass + class FileStateBase1(rx.State): """The base state for a child FileUploadState.""" @@ -97,16 +97,7 @@ class ChildFileUploadState(FileStateBase1): Args: files: The uploaded files. """ - for file in files: - upload_data = await file.read() - outfile = f"{self._tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - self.img_list.append(file.filename) + pass async def multi_handle_upload(self, files: List[rx.UploadFile]): """Handle the upload of a file. @@ -126,6 +117,15 @@ class ChildFileUploadState(FileStateBase1): assert file.filename is not None self.img_list.append(file.filename) + @rx.background + async def bg_upload(self, files: List[rx.UploadFile]): + """Background task cannot be upload handler. + + Args: + files: The uploaded files. + """ + pass + class FileStateBase2(FileStateBase1): """The parent state for a grandchild FileUploadState.""" @@ -145,16 +145,7 @@ class GrandChildFileUploadState(FileStateBase2): Args: files: The uploaded files. """ - for file in files: - upload_data = await file.read() - outfile = f"{self._tmp_path}/{file.filename}" - - # Save the file. - with open(outfile, "wb") as file_object: - file_object.write(upload_data) - - # Update the img var. - self.img_list.append(file.filename) + pass async def multi_handle_upload(self, files: List[rx.UploadFile]): """Handle the upload of a file. @@ -173,3 +164,12 @@ class GrandChildFileUploadState(FileStateBase2): # Update the img var. assert file.filename is not None self.img_list.append(file.filename) + + @rx.background + async def bg_upload(self, files: List[rx.UploadFile]): + """Background task cannot be upload handler. + + Args: + files: The uploaded files. + """ + pass diff --git a/tests/test_app.py b/tests/test_app.py index 619b2775f..434eeb60f 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -746,23 +746,28 @@ async def test_upload_file(tmp_path, state, delta, token: str): bio.write(data) state_name = state.get_full_name().partition(".")[2] or state.get_name() - handler_prefix = f"{token}:{state_name}" + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{state_name}.multi_handle_upload", + } file1 = UploadFile( - filename=f"{handler_prefix}.multi_handle_upload:True:image1.jpg", + filename=f"image1.jpg", file=bio, ) file2 = UploadFile( - filename=f"{handler_prefix}.multi_handle_upload:True:image2.jpg", + filename=f"image2.jpg", file=bio, ) upload_fn = upload(app) - await upload_fn([file1, file2]) - state_update = StateUpdate(delta=delta, events=[], final=True) + streaming_response = await upload_fn(request_mock, [file1, file2]) + async for state_update in streaming_response.body_iterator: + assert ( + state_update + == StateUpdate(delta=delta, events=[], final=True).json() + "\n" + ) - app.event_namespace.emit.assert_called_with( # type: ignore - "event", state_update.json(), to=current_state.router.session.session_id - ) current_state = await app.state_manager.get_state(token) state_dict = current_state.dict() for substate in state.get_full_name().split(".")[1:]: @@ -789,30 +794,20 @@ async def test_upload_file_without_annotation(state, tmp_path, token): tmp_path: Temporary path. token: a Token. """ - data = b"This is binary data" - - # Create a binary IO object and write data to it - bio = io.BytesIO() - bio.write(data) - state._tmp_path = tmp_path # The App state must be the "root" of the state tree app = App(state=state if state is FileUploadState else FileStateBase1) state_name = state.get_full_name().partition(".")[2] or state.get_name() - handler_prefix = f"{token}:{state_name}" - - file1 = UploadFile( - filename=f"{handler_prefix}.handle_upload2:True:image1.jpg", - file=bio, - ) - file2 = UploadFile( - filename=f"{handler_prefix}.handle_upload2:True:image2.jpg", - file=bio, - ) + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{state_name}.handle_upload2", + } + file_mock = unittest.mock.Mock(filename="image1.jpg") fn = upload(app) with pytest.raises(ValueError) as err: - await fn([file1, file2]) + await fn(request_mock, [file_mock]) assert ( err.value.args[0] == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" @@ -822,6 +817,42 @@ async def test_upload_file_without_annotation(state, tmp_path, token): await app.state_manager.redis.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "state", + [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], +) +async def test_upload_file_background(state, tmp_path, token): + """Test that an error is thrown handler is a background task. + + Args: + state: The state class. + tmp_path: Temporary path. + token: a Token. + """ + state._tmp_path = tmp_path + # The App state must be the "root" of the state tree + app = App(state=state if state is FileUploadState else FileStateBase1) + + state_name = state.get_full_name().partition(".")[2] or state.get_name() + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{state_name}.bg_upload", + } + file_mock = unittest.mock.Mock(filename="image1.jpg") + fn = upload(app) + with pytest.raises(TypeError) as err: + await fn(request_mock, [file_mock]) + assert ( + err.value.args[0] + == f"@rx.background is not supported for upload handler `{state_name}.bg_upload`." + ) + + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.redis.close() + + class DynamicState(State): """State class for testing dynamic route var.