[REF-723+] Upload with progress and cancellation (#1899)

This commit is contained in:
Masen Furer 2023-11-16 15:46:13 -08:00 committed by GitHub
parent e399b5a98c
commit 7eccc6d988
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 643 additions and 177 deletions

View File

@ -1,13 +1,14 @@
"""Integration tests for file upload.""" """Integration tests for file upload."""
from __future__ import annotations from __future__ import annotations
import asyncio
import time import time
from typing import Generator from typing import Generator
import pytest import pytest
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from reflex.testing import AppHarness from reflex.testing import AppHarness, WebDriver
def UploadFile(): def UploadFile():
@ -16,12 +17,28 @@ def UploadFile():
class UploadState(rx.State): class UploadState(rx.State):
_file_data: dict[str, str] = {} _file_data: dict[str, str] = {}
event_order: list[str] = []
progress_dicts: list[dict] = []
async def handle_upload(self, files: list[rx.UploadFile]): async def handle_upload(self, files: list[rx.UploadFile]):
for file in files: for file in files:
upload_data = await file.read() upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8") self._file_data[file.filename or ""] = upload_data.decode("utf-8")
async def handle_upload_secondary(self, files: list[rx.UploadFile]):
for file in files:
upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8")
yield UploadState.chain_event
def upload_progress(self, progress):
assert progress
self.event_order.append("upload_progress")
self.progress_dicts.append(progress)
def chain_event(self):
self.event_order.append("chain_event")
def index(): def index():
return rx.vstack( return rx.vstack(
rx.input( rx.input(
@ -29,6 +46,7 @@ def UploadFile():
is_read_only=True, is_read_only=True,
id="token", id="token",
), ),
rx.heading("Default Upload"),
rx.upload( rx.upload(
rx.vstack( rx.vstack(
rx.button("Select File"), rx.button("Select File"),
@ -52,6 +70,47 @@ def UploadFile():
on_click=rx.clear_selected_files, on_click=rx.clear_selected_files,
id="clear_button", id="clear_button",
), ),
rx.heading("Secondary Upload"),
rx.upload(
rx.vstack(
rx.button("Select File"),
rx.text("Drag and drop files here or click to select files"),
),
id="secondary",
),
rx.button(
"Upload",
on_click=UploadState.handle_upload_secondary( # type: ignore
rx.upload_files(
upload_id="secondary",
on_upload_progress=UploadState.upload_progress,
),
),
id="upload_button_secondary",
),
rx.box(
rx.foreach(
rx.selected_files("secondary"),
lambda f: rx.text(f),
),
id="selected_files_secondary",
),
rx.button(
"Clear",
on_click=rx.clear_selected_files("secondary"),
id="clear_button_secondary",
),
rx.vstack(
rx.foreach(
UploadState.progress_dicts, # type: ignore
lambda d: rx.text(d.to_string()),
)
),
rx.button(
"Cancel",
on_click=rx.cancel_upload("secondary"),
id="cancel_button_secondary",
),
) )
app = rx.App(state=UploadState) app = rx.App(state=UploadState)
@ -94,14 +153,18 @@ def driver(upload_file: AppHarness):
driver.quit() driver.quit()
@pytest.mark.parametrize("secondary", [False, True])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_upload_file(tmp_path, upload_file: AppHarness, driver): async def test_upload_file(
tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
):
"""Submit a file upload and check that it arrived on the backend. """Submit a file upload and check that it arrived on the backend.
Args: Args:
tmp_path: pytest tmp_path fixture tmp_path: pytest tmp_path fixture
upload_file: harness for UploadFile app. upload_file: harness for UploadFile app.
driver: WebDriver instance. driver: WebDriver instance.
secondary: whether to use the secondary upload form
""" """
assert upload_file.app_instance is not None assert upload_file.app_instance is not None
token_input = driver.find_element(By.ID, "token") token_input = driver.find_element(By.ID, "token")
@ -110,9 +173,13 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
token = upload_file.poll_for_value(token_input) token = upload_file.poll_for_value(token_input)
assert token is not None assert token is not None
upload_box = driver.find_element(By.XPATH, "//input[@type='file']") suffix = "_secondary" if secondary else ""
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
1 if secondary else 0
]
assert upload_box assert upload_box
upload_button = driver.find_element(By.ID, "upload_button") upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
assert upload_button assert upload_button
exp_name = "test.txt" exp_name = "test.txt"
@ -132,9 +199,15 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
assert file_data[exp_name] == exp_contents assert file_data[exp_name] == exp_contents
# check that the selected files are displayed # check that the selected files are displayed
selected_files = driver.find_element(By.ID, "selected_files") selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == exp_name assert selected_files.text == exp_name
state = await upload_file.get_state(token)
if secondary:
# only the secondary form tracks progress and chain events
assert state.event_order.count("upload_progress") == 1
assert state.event_order.count("chain_event") == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
@ -186,13 +259,17 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
assert file_data[exp_name] == exp_contents assert file_data[exp_name] == exp_contents
def test_clear_files(tmp_path, upload_file: AppHarness, driver): @pytest.mark.parametrize("secondary", [False, True])
def test_clear_files(
tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
):
"""Select then clear several file uploads and check that they are cleared. """Select then clear several file uploads and check that they are cleared.
Args: Args:
tmp_path: pytest tmp_path fixture tmp_path: pytest tmp_path fixture
upload_file: harness for UploadFile app. upload_file: harness for UploadFile app.
driver: WebDriver instance. driver: WebDriver instance.
secondary: whether to use the secondary upload form.
""" """
assert upload_file.app_instance is not None assert upload_file.app_instance is not None
token_input = driver.find_element(By.ID, "token") token_input = driver.find_element(By.ID, "token")
@ -201,9 +278,13 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
token = upload_file.poll_for_value(token_input) token = upload_file.poll_for_value(token_input)
assert token is not None assert token is not None
upload_box = driver.find_element(By.XPATH, "//input[@type='file']") suffix = "_secondary" if secondary else ""
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
1 if secondary else 0
]
assert upload_box assert upload_box
upload_button = driver.find_element(By.ID, "upload_button") upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
assert upload_button assert upload_button
exp_files = { exp_files = {
@ -219,13 +300,56 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
time.sleep(0.2) time.sleep(0.2)
# check that the selected files are displayed # check that the selected files are displayed
selected_files = driver.find_element(By.ID, "selected_files") selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == "\n".join(exp_files) assert selected_files.text == "\n".join(exp_files)
clear_button = driver.find_element(By.ID, "clear_button") clear_button = driver.find_element(By.ID, f"clear_button{suffix}")
assert clear_button assert clear_button
clear_button.click() clear_button.click()
# check that the selected files are cleared # check that the selected files are cleared
selected_files = driver.find_element(By.ID, "selected_files") selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == "" assert selected_files.text == ""
# TODO: drag and drop directory
# https://gist.github.com/florentbr/349b1ab024ca9f3de56e6bf8af2ac69e
@pytest.mark.asyncio
async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDriver):
"""Submit a large file upload and cancel it.
Args:
tmp_path: pytest tmp_path fixture
upload_file: harness for UploadFile app.
driver: WebDriver instance.
"""
assert upload_file.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
upload_button = driver.find_element(By.ID, f"upload_button_secondary")
cancel_button = driver.find_element(By.ID, f"cancel_button_secondary")
exp_name = "large.txt"
target_file = tmp_path / exp_name
with target_file.open("wb") as f:
f.seek(1024 * 1024 * 256)
f.write(b"0")
upload_box.send_keys(str(target_file))
upload_button.click()
await asyncio.sleep(0.3)
cancel_button.click()
# look up the backend state and assert on progress
state = await upload_file.get_state(token)
assert state.progress_dicts
assert exp_name not in state._file_data
target_file.unlink()

View File

@ -32,6 +32,11 @@ let event_processing = false
// Array holding pending events to be processed. // Array holding pending events to be processed.
const event_queue = []; const event_queue = [];
// Pending upload promises, by id
const upload_controllers = {};
// Upload files state by id
export const upload_files = {};
/** /**
* Generate a UUID (Used for session tokens). * Generate a UUID (Used for session tokens).
* Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid * Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid
@ -235,14 +240,22 @@ export const applyEvent = async (event, socket) => {
/** /**
* Send an event to the server via REST. * Send an event to the server via REST.
* @param event The current event. * @param event The current event.
* @param state The state with the event queue. * @param socket The socket object to send the response event(s) on.
* *
* @returns Whether the event was sent. * @returns Whether the event was sent.
*/ */
export const applyRestEvent = async (event) => { export const applyRestEvent = async (event, socket) => {
let eventSent = false; let eventSent = false;
if (event.handler == "uploadFiles") { if (event.handler == "uploadFiles") {
eventSent = await uploadFiles(event.name, event.payload.files); // Start upload, but do not wait for it, which would block other events.
uploadFiles(
event.name,
event.payload.files,
event.payload.upload_id,
event.payload.on_upload_progress,
socket
);
return false;
} }
return eventSent; return eventSent;
}; };
@ -283,7 +296,7 @@ export const processEvent = async (
let eventSent = false let eventSent = false
// Process events with handlers via REST and all others via websockets. // Process events with handlers via REST and all others via websockets.
if (event.handler) { if (event.handler) {
eventSent = await applyRestEvent(event); eventSent = await applyRestEvent(event, socket);
} else { } else {
eventSent = await applyEvent(event, socket); eventSent = await applyEvent(event, socket);
} }
@ -347,50 +360,86 @@ export const connect = async (
* *
* @param state The state to apply the delta to. * @param state The state to apply the delta to.
* @param handler The handler to use. * @param handler The handler to use.
* @param upload_id The upload id to use.
* @param on_upload_progress The function to call on upload progress.
* @param socket the websocket connection
* *
* @returns Whether the files were uploaded. * @returns The response from posting to the UPLOADURL endpoint.
*/ */
export const uploadFiles = async (handler, files) => { export const uploadFiles = async (handler, files, upload_id, on_upload_progress, socket) => {
// return if there's no file to upload // return if there's no file to upload
if (files.length == 0) { if (files.length == 0) {
return false; return false;
} }
const headers = { if (upload_controllers[upload_id]) {
"Content-Type": files[0].type, console.log("Upload already in progress for ", upload_id)
}; return false;
}
let resp_idx = 0;
const eventHandler = (progressEvent) => {
// handle any delta / event streamed from the upload event handler
const chunks = progressEvent.event.target.responseText.trim().split("\n")
chunks.slice(resp_idx).map((chunk) => {
try {
socket._callbacks.$event.map((f) => {
f(chunk)
})
resp_idx += 1
} catch (e) {
console.log("Error parsing chunk", chunk, e)
return
}
})
}
const controller = new AbortController()
const config = {
headers: {
"Reflex-Client-Token": getToken(),
"Reflex-Event-Handler": handler,
},
signal: controller.signal,
onDownloadProgress: eventHandler,
}
if (on_upload_progress) {
config["onUploadProgress"] = on_upload_progress
}
const formdata = new FormData(); const formdata = new FormData();
// Add the token and handler to the file name. // Add the token and handler to the file name.
for (let i = 0; i < files.length; i++) { files.forEach((file) => {
formdata.append( formdata.append(
"files", "files",
files[i], file,
getToken() + ":" + handler + ":" + files[i].name file.path || file.name
); );
} })
// Send the file to the server. // Send the file to the server.
await axios.post(UPLOADURL, formdata, headers) upload_controllers[upload_id] = controller
.then(() => { return true; })
.catch( try {
error => { return await axios.post(UPLOADURL, formdata, config)
if (error.response) { } catch (error) {
// The request was made and the server responded with a status code if (error.response) {
// that falls out of the range of 2xx // The request was made and the server responded with a status code
console.log(error.response.data); // that falls out of the range of 2xx
} else if (error.request) { console.log(error.response.data);
// The request was made but no response was received } else if (error.request) {
// `error.request` is an instance of XMLHttpRequest in the browser and an instance of // The request was made but no response was received
// http.ClientRequest in node.js // `error.request` is an instance of XMLHttpRequest in the browser and an instance of
console.log(error.request); // http.ClientRequest in node.js
} else { console.log(error.request);
// Something happened in setting up the request that triggered an Error } else {
console.log(error.message); // Something happened in setting up the request that triggered an Error
} console.log(error.message);
return false; }
} return false;
) } finally {
delete upload_controllers[upload_id]
}
}; };
/** /**

View File

@ -229,6 +229,7 @@ _ALL_COMPONENTS = [
_ALL_COMPONENTS += [to_snake_case(component) for component in _ALL_COMPONENTS] _ALL_COMPONENTS += [to_snake_case(component) for component in _ALL_COMPONENTS]
_ALL_COMPONENTS += [ _ALL_COMPONENTS += [
"cancel_upload",
"components", "components",
"color_mode_cond", "color_mode_cond",
"desktop_only", "desktop_only",

View File

@ -58,6 +58,9 @@ from reflex.components import ConnectionModal as ConnectionModal
from reflex.components import Container as Container from reflex.components import Container as Container
from reflex.components import DataTable as DataTable from reflex.components import DataTable as DataTable
from reflex.components import DataEditor as DataEditor from reflex.components import DataEditor as DataEditor
from reflex.components import DataEditorTheme as DataEditorTheme
from reflex.components import DatePicker as DatePicker
from reflex.components import DateTimePicker as DateTimePicker
from reflex.components import DebounceInput as DebounceInput from reflex.components import DebounceInput as DebounceInput
from reflex.components import Divider as Divider from reflex.components import Divider as Divider
from reflex.components import Drawer as Drawer from reflex.components import Drawer as Drawer
@ -265,6 +268,9 @@ from reflex.components import connection_modal as connection_modal
from reflex.components import container as container from reflex.components import container as container
from reflex.components import data_table as data_table from reflex.components import data_table as data_table
from reflex.components import data_editor as data_editor from reflex.components import data_editor as data_editor
from reflex.components import data_editor_theme as data_editor_theme
from reflex.components import date_picker as date_picker
from reflex.components import date_time_picker as date_time_picker
from reflex.components import debounce_input as debounce_input from reflex.components import debounce_input as debounce_input
from reflex.components import divider as divider from reflex.components import divider as divider
from reflex.components import drawer as drawer from reflex.components import drawer as drawer
@ -421,7 +427,9 @@ from reflex.components import visually_hidden as visually_hidden
from reflex.components import vstack as vstack from reflex.components import vstack as vstack
from reflex.components import wrap as wrap from reflex.components import wrap as wrap
from reflex.components import wrap_item as wrap_item from reflex.components import wrap_item as wrap_item
from reflex.components import cancel_upload as cancel_upload
from reflex import components as components from reflex import components as components
from reflex.components import color_mode_cond as color_mode_cond
from reflex.components import desktop_only as desktop_only from reflex.components import desktop_only as desktop_only
from reflex.components import mobile_only as mobile_only from reflex.components import mobile_only as mobile_only
from reflex.components import tablet_only as tablet_only from reflex.components import tablet_only as tablet_only
@ -429,7 +437,9 @@ from reflex.components import mobile_and_tablet as mobile_and_tablet
from reflex.components import tablet_and_desktop as tablet_and_desktop from reflex.components import tablet_and_desktop as tablet_and_desktop
from reflex.components import selected_files as selected_files from reflex.components import selected_files as selected_files
from reflex.components import clear_selected_files as clear_selected_files from reflex.components import clear_selected_files as clear_selected_files
from reflex.components import EditorButtonList as EditorButtonList
from reflex.components import EditorOptions as EditorOptions from reflex.components import EditorOptions as EditorOptions
from reflex.components import NoSSRComponent as NoSSRComponent
from reflex.components.component import memo as memo from reflex.components.component import memo as memo
from reflex.components.graphing import recharts as recharts from reflex.components.graphing import recharts as recharts
from reflex import config as config from reflex import config as config

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import inspect import functools
import os import os
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from typing import ( from typing import (
@ -17,10 +17,13 @@ from typing import (
Set, Set,
Type, Type,
Union, Union,
get_args,
get_type_hints,
) )
from fastapi import FastAPI, UploadFile from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware import cors from fastapi.middleware import cors
from fastapi.responses import StreamingResponse
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer from socketio import ASGIApp, AsyncNamespace, AsyncServer
from starlette_admin.contrib.sqla.admin import Admin from starlette_admin.contrib.sqla.admin import Admin
@ -880,62 +883,90 @@ def upload(app: App):
The upload function. The upload function.
""" """
async def upload_file(files: List[UploadFile]): async def upload_file(request: Request, files: List[UploadFile]):
"""Upload a file. """Upload a file.
Args: Args:
request: The FastAPI request object.
files: The file(s) to upload. files: The file(s) to upload.
Returns:
StreamingResponse yielding newline-delimited JSON of StateUpdate
emitted by the upload handler.
Raises: Raises:
ValueError: if there are no args with supported annotation. ValueError: if there are no args with supported annotation.
TypeError: if a background task is used as the handler.
HTTPException: when the request does not include token / handler headers.
""" """
assert files[0].filename is not None token = request.headers.get("reflex-client-token")
token, handler = files[0].filename.split(":")[:2] handler = request.headers.get("reflex-event-handler")
for file in files:
assert file.filename is not None if not token or not handler:
file.filename = file.filename.split(":")[-1] raise HTTPException(
status_code=400,
detail="Missing reflex-client-token or reflex-event-handler header.",
)
# Get the state for the session. # Get the state for the session.
async with app.state_manager.modify_state(token) as state: state = await app.state_manager.get_state(token)
# get the current session ID
sid = state.router.session.session_id
# get the current state(parent state/substate)
path = handler.split(".")[:-1]
current_state = state.get_substate(path)
handler_upload_param = ()
# get handler function # get the current session ID
func = getattr(current_state, handler.split(".")[-1]) # get the current state(parent state/substate)
path = handler.split(".")[:-1]
current_state = state.get_substate(path)
handler_upload_param = ()
# check if there exists any handler args with annotation, List[UploadFile] # get handler function
for k, v in inspect.getfullargspec( func = getattr(type(current_state), handler.split(".")[-1])
func.fn if isinstance(func, EventHandler) else func
).annotations.items():
if types.is_generic_alias(v) and types._issubclass(
v.__args__[0], UploadFile
):
handler_upload_param = (k, v)
break
if not handler_upload_param: # check if there exists any handler args with annotation, List[UploadFile]
raise ValueError( if isinstance(func, EventHandler):
f"`{handler}` handler should have a parameter annotated as List[" if func.is_background:
f"rx.UploadFile]" raise TypeError(
f"@rx.background is not supported for upload handler `{handler}`.",
) )
func = func.fn
if isinstance(func, functools.partial):
func = func.func
for k, v in get_type_hints(func).items():
if types.is_generic_alias(v) and types._issubclass(
get_args(v)[0],
UploadFile,
):
handler_upload_param = (k, v)
break
event = Event( if not handler_upload_param:
token=token, raise ValueError(
name=handler, f"`{handler}` handler should have a parameter annotated as "
payload={handler_upload_param[0]: files}, "List[rx.UploadFile]"
) )
async for update in state._process(event):
# Postprocess the event. event = Event(
update = await app.postprocess(state, event, update) token=token,
# Send update to client name=handler,
await app.event_namespace.emit_update( # type: ignore payload={handler_upload_param[0]: files},
update=update, )
sid=sid,
) async def _ndjson_updates():
"""Process the upload event, generating ndjson updates.
Yields:
Each state update as JSON followed by a new line.
"""
# Process the event.
async with app.state_manager.modify_state(token) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
yield update.json() + "\n"
# Stream updates to client
return StreamingResponse(
_ndjson_updates(),
media_type="application/x-ndjson",
)
return upload_file return upload_file

View File

@ -46,12 +46,18 @@ from .select import Option, Select
from .slider import Slider, SliderFilledTrack, SliderMark, SliderThumb, SliderTrack from .slider import Slider, SliderFilledTrack, SliderMark, SliderThumb, SliderTrack
from .switch import Switch from .switch import Switch
from .textarea import TextArea from .textarea import TextArea
from .upload import Upload, clear_selected_files, selected_files from .upload import (
Upload,
cancel_upload,
clear_selected_files,
selected_files,
)
helpers = [ helpers = [
"color_mode_cond", "color_mode_cond",
"selected_files", "cancel_upload",
"clear_selected_files", "clear_selected_files",
"selected_files",
] ]
__all__ = [f for f in dir() if f[0].isupper()] + helpers # type: ignore __all__ = [f for f in dir() if f[0].isupper()] + helpers # type: ignore

View File

@ -3,26 +3,75 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from reflex import constants
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.forms.input import Input from reflex.components.forms.input import Input
from reflex.components.layout.box import Box from reflex.components.layout.box import Box
from reflex.constants import EventTriggers from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.event import EventChain from reflex.utils import imports
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, CallableVar, ImportVar, Var
files_state: str = "const [files, setFiles] = useState([]);" DEFAULT_UPLOAD_ID: str = "default"
upload_file: BaseVar = BaseVar(
_var_name="e => setFiles((files) => e)", _var_type=EventChain
)
# Use this var along with the Upload component to render the list of selected files.
selected_files: BaseVar = BaseVar(
_var_name="files.map((f) => f.name)", _var_type=List[str]
)
clear_selected_files: BaseVar = BaseVar( @CallableVar
_var_name="_e => setFiles((files) => [])", _var_type=EventChain def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
) """Get the file upload drop trigger.
This var is passed to the dropzone component to update the file list when a
drop occurs.
Args:
id_: The id of the upload to get the drop trigger for.
Returns:
A var referencing the file upload drop trigger.
"""
return BaseVar(
_var_name=f"e => upload_files.{id_}[1]((files) => e)",
_var_type=EventChain,
)
@CallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
"""Get the list of selected files.
Args:
id_: The id of the upload to get the selected files for.
Returns:
A var referencing the list of selected file paths.
"""
return BaseVar(
_var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
_var_type=List[str],
)
@CallableEventSpec
def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec:
"""Clear the list of selected files.
Args:
id_: The id of the upload to clear.
Returns:
An event spec that clears the list of selected files when triggered.
"""
return call_script(f"upload_files.{id_}[1]((files) => [])")
def cancel_upload(upload_id: str) -> EventSpec:
"""Cancel an upload.
Args:
upload_id: The id of the upload to cancel.
Returns:
An event spec that cancels the upload when triggered.
"""
return call_script(f"upload_controllers[{upload_id!r}]?.abort()")
class Upload(Component): class Upload(Component):
@ -94,7 +143,10 @@ class Upload(Component):
zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)} zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)}
# Create the component. # Create the component.
return super().create(zone, on_drop=upload_file, **upload_props) upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)
return super().create(
zone, on_drop=upload_file(upload_props["id"]), **upload_props
)
def get_event_triggers(self) -> dict[str, Union[Var, Any]]: def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.
@ -104,7 +156,7 @@ class Upload(Component):
""" """
return { return {
**super().get_event_triggers(), **super().get_event_triggers(),
EventTriggers.ON_DROP: lambda e0: [e0], constants.EventTriggers.ON_DROP: lambda e0: [e0],
} }
def _render(self): def _render(self):
@ -113,4 +165,15 @@ class Upload(Component):
return out return out
def _get_hooks(self) -> str | None: def _get_hooks(self) -> str | None:
return (super()._get_hooks() or "") + files_state return (
(super()._get_hooks() or "")
+ f"""
upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
"""
)
def _get_imports(self) -> imports.ImportDict:
return {
**super()._get_imports(),
f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")},
}

View File

@ -8,17 +8,23 @@ from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style from reflex.style import Style
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from reflex import constants
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.forms.input import Input from reflex.components.forms.input import Input
from reflex.components.layout.box import Box from reflex.components.layout.box import Box
from reflex.constants import EventTriggers from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.event import EventChain from reflex.utils import imports
from reflex.vars import BaseVar, Var from reflex.vars import BaseVar, CallableVar, ImportVar, Var
files_state: str DEFAULT_UPLOAD_ID: str
upload_file: BaseVar
selected_files: BaseVar @CallableVar
clear_selected_files: BaseVar def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
@CallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
@CallableEventSpec
def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...
def cancel_upload(upload_id: str) -> EventSpec: ...
class Upload(Component): class Upload(Component):
@overload @overload

View File

@ -174,13 +174,7 @@ class EventHandler(EventActionsMixin):
for arg in args: for arg in args:
# Special case for file uploads. # Special case for file uploads.
if isinstance(arg, FileUpload): if isinstance(arg, FileUpload):
return EventSpec( return arg.as_event_spec(handler=self)
handler=self,
client_handler_name="uploadFiles",
# `files` is defined in the Upload component's _use_hooks
args=((Var.create_safe("files"), Var.create_safe("files")),),
event_actions=self.event_actions.copy(),
)
# Otherwise, convert to JSON. # Otherwise, convert to JSON.
try: try:
@ -236,6 +230,50 @@ class EventSpec(EventActionsMixin):
) )
class CallableEventSpec(EventSpec):
"""Decorate an EventSpec-returning function to act as both a EventSpec and a function.
This is used as a compatibility shim for replacing EventSpec objects in the
API with functions that return a family of EventSpec.
"""
fn: Optional[Callable[..., EventSpec]] = None
def __init__(self, fn: Callable[..., EventSpec] | None = None, **kwargs):
"""Initialize a CallableEventSpec.
Args:
fn: The function to decorate.
**kwargs: The kwargs to pass to pydantic initializer
"""
if fn is not None:
default_event_spec = fn()
super().__init__(
fn=fn, # type: ignore
**default_event_spec.dict(),
**kwargs,
)
else:
super().__init__(**kwargs)
def __call__(self, *args, **kwargs) -> EventSpec:
"""Call the decorated function.
Args:
*args: The args to pass to the function.
**kwargs: The kwargs to pass to the function.
Returns:
The EventSpec returned from calling the function.
Raises:
TypeError: If the CallableEventSpec has no associated function.
"""
if self.fn is None:
raise TypeError("CallableEventSpec has no associated function.")
return self.fn(*args, **kwargs)
class EventChain(EventActionsMixin): class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order.""" """Container for a chain of events that will be executed in order."""
@ -267,7 +305,76 @@ class FrontendEvent(Base):
class FileUpload(Base): class FileUpload(Base):
"""Class to represent a file upload.""" """Class to represent a file upload."""
pass upload_id: Optional[str] = None
on_upload_progress: Optional[Union[EventHandler, Callable]] = None
@staticmethod
def on_upload_progress_args_spec(_prog: dict[str, int | float | bool]):
"""Args spec for on_upload_progress event handler.
Returns:
The arg mapping passed to backend event handler
"""
return [_prog]
def as_event_spec(self, handler: EventHandler) -> EventSpec:
"""Get the EventSpec for the file upload.
Args:
handler: The event handler.
Returns:
The event spec for the handler.
Raises:
ValueError: If the on_upload_progress is not a valid event handler.
"""
from reflex.components.forms.upload import DEFAULT_UPLOAD_ID
upload_id = self.upload_id or DEFAULT_UPLOAD_ID
spec_args = [
# `upload_files` is defined in state.js and assigned in the Upload component's _use_hooks
(Var.create_safe("files"), Var.create_safe(f"upload_files.{upload_id}[0]")),
(
Var.create_safe("upload_id"),
Var.create_safe(upload_id, _var_is_string=True),
),
]
if self.on_upload_progress is not None:
on_upload_progress = self.on_upload_progress
if isinstance(on_upload_progress, EventHandler):
events = [
call_event_handler(
on_upload_progress,
self.on_upload_progress_args_spec,
),
]
elif isinstance(on_upload_progress, Callable):
# Call the lambda to get the event chain.
events = call_event_fn(on_upload_progress, self.on_upload_progress_args_spec) # type: ignore
else:
raise ValueError(f"{on_upload_progress} is not a valid event handler.")
on_upload_progress_chain = EventChain(
events=events,
args_spec=self.on_upload_progress_args_spec,
)
formatted_chain = str(format.format_prop(on_upload_progress_chain))
spec_args.append(
(
Var.create_safe("on_upload_progress"),
BaseVar(
_var_name=formatted_chain.strip("{}"),
_var_type=EventChain,
),
),
)
return EventSpec(
handler=handler,
client_handler_name="uploadFiles",
args=tuple(spec_args),
event_actions=handler.event_actions.copy(),
)
# Alias for rx.upload_files # Alias for rx.upload_files

View File

@ -1590,3 +1590,33 @@ class NoRenderImportVar(ImportVar):
"""A import that doesn't need to be rendered.""" """A import that doesn't need to be rendered."""
render: Optional[bool] = False render: Optional[bool] = False
class CallableVar(BaseVar):
"""Decorate a Var-returning function to act as both a Var and a function.
This is used as a compatibility shim for replacing Var objects in the
API with functions that return a family of Var.
"""
def __init__(self, fn: Callable[..., BaseVar]):
"""Initialize a CallableVar.
Args:
fn: The function to decorate (must return Var)
"""
self.fn = fn
default_var = fn()
super().__init__(**dataclasses.asdict(default_var))
def __call__(self, *args, **kwargs) -> BaseVar:
"""Call the decorated function.
Args:
*args: The args to pass to the function.
**kwargs: The kwargs to pass to the function.
Returns:
The Var returned from calling the function.
"""
return self.fn(*args, **kwargs)

View File

@ -137,3 +137,7 @@ class NoRenderImportVar(ImportVar):
"""A import that doesn't need to be rendered.""" """A import that doesn't need to be rendered."""
def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ... def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
class CallableVar(BaseVar):
def __init__(self, fn: Callable[..., BaseVar]): ...
def __call__(self, *args, **kwargs) -> BaseVar: ...

View File

@ -52,8 +52,10 @@ def test_upload_component_render(upload_component):
# upload # upload
assert upload["name"] == "ReactDropzone" assert upload["name"] == "ReactDropzone"
assert upload["props"] == [ assert upload["props"] == [
"id={`default`}",
"multiple={true}", "multiple={true}",
"onDrop={e => setFiles((files) => e)}", "onDrop={e => upload_files.default[1]((files) => e)}",
"ref={ref_default}",
] ]
assert upload["args"] == ("getRootProps", "getInputProps") assert upload["args"] == ("getRootProps", "getInputProps")
@ -89,8 +91,10 @@ def test_upload_component_with_props_render(upload_component_with_props):
upload = upload_component_with_props.render() upload = upload_component_with_props.render()
assert upload["props"] == [ assert upload["props"] == [
"id={`default`}",
"maxFiles={2}", "maxFiles={2}",
"multiple={true}", "multiple={true}",
"noDrag={true}", "noDrag={true}",
"onDrop={e => setFiles((files) => e)}", "onDrop={e => upload_files.default[1]((files) => e)}",
"ref={ref_default}",
] ]

View File

@ -49,16 +49,7 @@ class FileUploadState(rx.State):
Args: Args:
files: The uploaded files. files: The uploaded files.
""" """
for file in files: pass
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]): async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file. """Handle the upload of a file.
@ -78,6 +69,15 @@ class FileUploadState(rx.State):
assert file.filename is not None assert file.filename is not None
self.img_list.append(file.filename) self.img_list.append(file.filename)
@rx.background
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
Args:
files: The uploaded files.
"""
pass
class FileStateBase1(rx.State): class FileStateBase1(rx.State):
"""The base state for a child FileUploadState.""" """The base state for a child FileUploadState."""
@ -97,16 +97,7 @@ class ChildFileUploadState(FileStateBase1):
Args: Args:
files: The uploaded files. files: The uploaded files.
""" """
for file in files: pass
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]): async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file. """Handle the upload of a file.
@ -126,6 +117,15 @@ class ChildFileUploadState(FileStateBase1):
assert file.filename is not None assert file.filename is not None
self.img_list.append(file.filename) self.img_list.append(file.filename)
@rx.background
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
Args:
files: The uploaded files.
"""
pass
class FileStateBase2(FileStateBase1): class FileStateBase2(FileStateBase1):
"""The parent state for a grandchild FileUploadState.""" """The parent state for a grandchild FileUploadState."""
@ -145,16 +145,7 @@ class GrandChildFileUploadState(FileStateBase2):
Args: Args:
files: The uploaded files. files: The uploaded files.
""" """
for file in files: pass
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
async def multi_handle_upload(self, files: List[rx.UploadFile]): async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file. """Handle the upload of a file.
@ -173,3 +164,12 @@ class GrandChildFileUploadState(FileStateBase2):
# Update the img var. # Update the img var.
assert file.filename is not None assert file.filename is not None
self.img_list.append(file.filename) self.img_list.append(file.filename)
@rx.background
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
Args:
files: The uploaded files.
"""
pass

View File

@ -746,23 +746,28 @@ async def test_upload_file(tmp_path, state, delta, token: str):
bio.write(data) bio.write(data)
state_name = state.get_full_name().partition(".")[2] or state.get_name() state_name = state.get_full_name().partition(".")[2] or state.get_name()
handler_prefix = f"{token}:{state_name}" request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.multi_handle_upload",
}
file1 = UploadFile( file1 = UploadFile(
filename=f"{handler_prefix}.multi_handle_upload:True:image1.jpg", filename=f"image1.jpg",
file=bio, file=bio,
) )
file2 = UploadFile( file2 = UploadFile(
filename=f"{handler_prefix}.multi_handle_upload:True:image2.jpg", filename=f"image2.jpg",
file=bio, file=bio,
) )
upload_fn = upload(app) upload_fn = upload(app)
await upload_fn([file1, file2]) streaming_response = await upload_fn(request_mock, [file1, file2])
state_update = StateUpdate(delta=delta, events=[], final=True) async for state_update in streaming_response.body_iterator:
assert (
state_update
== StateUpdate(delta=delta, events=[], final=True).json() + "\n"
)
app.event_namespace.emit.assert_called_with( # type: ignore
"event", state_update.json(), to=current_state.router.session.session_id
)
current_state = await app.state_manager.get_state(token) current_state = await app.state_manager.get_state(token)
state_dict = current_state.dict() state_dict = current_state.dict()
for substate in state.get_full_name().split(".")[1:]: for substate in state.get_full_name().split(".")[1:]:
@ -789,30 +794,20 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
tmp_path: Temporary path. tmp_path: Temporary path.
token: a Token. token: a Token.
""" """
data = b"This is binary data"
# Create a binary IO object and write data to it
bio = io.BytesIO()
bio.write(data)
state._tmp_path = tmp_path state._tmp_path = tmp_path
# The App state must be the "root" of the state tree # The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1) app = App(state=state if state is FileUploadState else FileStateBase1)
state_name = state.get_full_name().partition(".")[2] or state.get_name() state_name = state.get_full_name().partition(".")[2] or state.get_name()
handler_prefix = f"{token}:{state_name}" request_mock = unittest.mock.Mock()
request_mock.headers = {
file1 = UploadFile( "reflex-client-token": token,
filename=f"{handler_prefix}.handle_upload2:True:image1.jpg", "reflex-event-handler": f"{state_name}.handle_upload2",
file=bio, }
) file_mock = unittest.mock.Mock(filename="image1.jpg")
file2 = UploadFile(
filename=f"{handler_prefix}.handle_upload2:True:image2.jpg",
file=bio,
)
fn = upload(app) fn = upload(app)
with pytest.raises(ValueError) as err: with pytest.raises(ValueError) as err:
await fn([file1, file2]) await fn(request_mock, [file_mock])
assert ( assert (
err.value.args[0] err.value.args[0]
== f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]" == f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
@ -822,6 +817,42 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
await app.state_manager.redis.close() await app.state_manager.redis.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"state",
[FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
)
async def test_upload_file_background(state, tmp_path, token):
"""Test that an error is thrown handler is a background task.
Args:
state: The state class.
tmp_path: Temporary path.
token: a Token.
"""
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1)
state_name = state.get_full_name().partition(".")[2] or state.get_name()
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.bg_upload",
}
file_mock = unittest.mock.Mock(filename="image1.jpg")
fn = upload(app)
with pytest.raises(TypeError) as err:
await fn(request_mock, [file_mock])
assert (
err.value.args[0]
== f"@rx.background is not supported for upload handler `{state_name}.bg_upload`."
)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
class DynamicState(State): class DynamicState(State):
"""State class for testing dynamic route var. """State class for testing dynamic route var.