[REF-723+] Upload with progress and cancellation (#1899)

This commit is contained in:
Masen Furer 2023-11-16 15:46:13 -08:00 committed by GitHub
parent e399b5a98c
commit 7eccc6d988
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 643 additions and 177 deletions

View File

@ -1,13 +1,14 @@
"""Integration tests for file upload."""
from __future__ import annotations
import asyncio
import time
from typing import Generator
import pytest
from selenium.webdriver.common.by import By
from reflex.testing import AppHarness
from reflex.testing import AppHarness, WebDriver
def UploadFile():
@ -16,12 +17,28 @@ def UploadFile():
class UploadState(rx.State):
_file_data: dict[str, str] = {}
event_order: list[str] = []
progress_dicts: list[dict] = []
async def handle_upload(self, files: list[rx.UploadFile]):
for file in files:
upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8")
async def handle_upload_secondary(self, files: list[rx.UploadFile]):
for file in files:
upload_data = await file.read()
self._file_data[file.filename or ""] = upload_data.decode("utf-8")
yield UploadState.chain_event
def upload_progress(self, progress):
assert progress
self.event_order.append("upload_progress")
self.progress_dicts.append(progress)
def chain_event(self):
self.event_order.append("chain_event")
def index():
return rx.vstack(
rx.input(
@ -29,6 +46,7 @@ def UploadFile():
is_read_only=True,
id="token",
),
rx.heading("Default Upload"),
rx.upload(
rx.vstack(
rx.button("Select File"),
@ -52,6 +70,47 @@ def UploadFile():
on_click=rx.clear_selected_files,
id="clear_button",
),
rx.heading("Secondary Upload"),
rx.upload(
rx.vstack(
rx.button("Select File"),
rx.text("Drag and drop files here or click to select files"),
),
id="secondary",
),
rx.button(
"Upload",
on_click=UploadState.handle_upload_secondary( # type: ignore
rx.upload_files(
upload_id="secondary",
on_upload_progress=UploadState.upload_progress,
),
),
id="upload_button_secondary",
),
rx.box(
rx.foreach(
rx.selected_files("secondary"),
lambda f: rx.text(f),
),
id="selected_files_secondary",
),
rx.button(
"Clear",
on_click=rx.clear_selected_files("secondary"),
id="clear_button_secondary",
),
rx.vstack(
rx.foreach(
UploadState.progress_dicts, # type: ignore
lambda d: rx.text(d.to_string()),
)
),
rx.button(
"Cancel",
on_click=rx.cancel_upload("secondary"),
id="cancel_button_secondary",
),
)
app = rx.App(state=UploadState)
@ -94,14 +153,18 @@ def driver(upload_file: AppHarness):
driver.quit()
@pytest.mark.parametrize("secondary", [False, True])
@pytest.mark.asyncio
async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
async def test_upload_file(
tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
):
"""Submit a file upload and check that it arrived on the backend.
Args:
tmp_path: pytest tmp_path fixture
upload_file: harness for UploadFile app.
driver: WebDriver instance.
secondary: whether to use the secondary upload form
"""
assert upload_file.app_instance is not None
token_input = driver.find_element(By.ID, "token")
@ -110,9 +173,13 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
token = upload_file.poll_for_value(token_input)
assert token is not None
upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
suffix = "_secondary" if secondary else ""
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
1 if secondary else 0
]
assert upload_box
upload_button = driver.find_element(By.ID, "upload_button")
upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
assert upload_button
exp_name = "test.txt"
@ -132,9 +199,15 @@ async def test_upload_file(tmp_path, upload_file: AppHarness, driver):
assert file_data[exp_name] == exp_contents
# check that the selected files are displayed
selected_files = driver.find_element(By.ID, "selected_files")
selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == exp_name
state = await upload_file.get_state(token)
if secondary:
# only the secondary form tracks progress and chain events
assert state.event_order.count("upload_progress") == 1
assert state.event_order.count("chain_event") == 1
@pytest.mark.asyncio
async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
@ -186,13 +259,17 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver):
assert file_data[exp_name] == exp_contents
def test_clear_files(tmp_path, upload_file: AppHarness, driver):
@pytest.mark.parametrize("secondary", [False, True])
def test_clear_files(
tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool
):
"""Select then clear several file uploads and check that they are cleared.
Args:
tmp_path: pytest tmp_path fixture
upload_file: harness for UploadFile app.
driver: WebDriver instance.
secondary: whether to use the secondary upload form.
"""
assert upload_file.app_instance is not None
token_input = driver.find_element(By.ID, "token")
@ -201,9 +278,13 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
token = upload_file.poll_for_value(token_input)
assert token is not None
upload_box = driver.find_element(By.XPATH, "//input[@type='file']")
suffix = "_secondary" if secondary else ""
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[
1 if secondary else 0
]
assert upload_box
upload_button = driver.find_element(By.ID, "upload_button")
upload_button = driver.find_element(By.ID, f"upload_button{suffix}")
assert upload_button
exp_files = {
@ -219,13 +300,56 @@ def test_clear_files(tmp_path, upload_file: AppHarness, driver):
time.sleep(0.2)
# check that the selected files are displayed
selected_files = driver.find_element(By.ID, "selected_files")
selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == "\n".join(exp_files)
clear_button = driver.find_element(By.ID, "clear_button")
clear_button = driver.find_element(By.ID, f"clear_button{suffix}")
assert clear_button
clear_button.click()
# check that the selected files are cleared
selected_files = driver.find_element(By.ID, "selected_files")
selected_files = driver.find_element(By.ID, f"selected_files{suffix}")
assert selected_files.text == ""
# TODO: drag and drop directory
# https://gist.github.com/florentbr/349b1ab024ca9f3de56e6bf8af2ac69e
@pytest.mark.asyncio
async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDriver):
"""Submit a large file upload and cancel it.
Args:
tmp_path: pytest tmp_path fixture
upload_file: harness for UploadFile app.
driver: WebDriver instance.
"""
assert upload_file.app_instance is not None
token_input = driver.find_element(By.ID, "token")
assert token_input
# wait for the backend connection to send the token
token = upload_file.poll_for_value(token_input)
assert token is not None
upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1]
upload_button = driver.find_element(By.ID, f"upload_button_secondary")
cancel_button = driver.find_element(By.ID, f"cancel_button_secondary")
exp_name = "large.txt"
target_file = tmp_path / exp_name
with target_file.open("wb") as f:
f.seek(1024 * 1024 * 256)
f.write(b"0")
upload_box.send_keys(str(target_file))
upload_button.click()
await asyncio.sleep(0.3)
cancel_button.click()
# look up the backend state and assert on progress
state = await upload_file.get_state(token)
assert state.progress_dicts
assert exp_name not in state._file_data
target_file.unlink()

View File

@ -32,6 +32,11 @@ let event_processing = false
// Array holding pending events to be processed.
const event_queue = [];
// Pending upload promises, by id
const upload_controllers = {};
// Upload files state by id
export const upload_files = {};
/**
* Generate a UUID (Used for session tokens).
* Taken from: https://stackoverflow.com/questions/105034/how-do-i-create-a-guid-uuid
@ -235,14 +240,22 @@ export const applyEvent = async (event, socket) => {
/**
* Send an event to the server via REST.
* @param event The current event.
* @param state The state with the event queue.
* @param socket The socket object to send the response event(s) on.
*
* @returns Whether the event was sent.
*/
export const applyRestEvent = async (event) => {
export const applyRestEvent = async (event, socket) => {
let eventSent = false;
if (event.handler == "uploadFiles") {
eventSent = await uploadFiles(event.name, event.payload.files);
// Start upload, but do not wait for it, which would block other events.
uploadFiles(
event.name,
event.payload.files,
event.payload.upload_id,
event.payload.on_upload_progress,
socket
);
return false;
}
return eventSent;
};
@ -283,7 +296,7 @@ export const processEvent = async (
let eventSent = false
// Process events with handlers via REST and all others via websockets.
if (event.handler) {
eventSent = await applyRestEvent(event);
eventSent = await applyRestEvent(event, socket);
} else {
eventSent = await applyEvent(event, socket);
}
@ -347,50 +360,86 @@ export const connect = async (
*
* @param state The state to apply the delta to.
* @param handler The handler to use.
* @param upload_id The upload id to use.
* @param on_upload_progress The function to call on upload progress.
* @param socket the websocket connection
*
* @returns Whether the files were uploaded.
* @returns The response from posting to the UPLOADURL endpoint.
*/
export const uploadFiles = async (handler, files) => {
export const uploadFiles = async (handler, files, upload_id, on_upload_progress, socket) => {
// return if there's no file to upload
if (files.length == 0) {
return false;
}
const headers = {
"Content-Type": files[0].type,
};
if (upload_controllers[upload_id]) {
console.log("Upload already in progress for ", upload_id)
return false;
}
let resp_idx = 0;
const eventHandler = (progressEvent) => {
// handle any delta / event streamed from the upload event handler
const chunks = progressEvent.event.target.responseText.trim().split("\n")
chunks.slice(resp_idx).map((chunk) => {
try {
socket._callbacks.$event.map((f) => {
f(chunk)
})
resp_idx += 1
} catch (e) {
console.log("Error parsing chunk", chunk, e)
return
}
})
}
const controller = new AbortController()
const config = {
headers: {
"Reflex-Client-Token": getToken(),
"Reflex-Event-Handler": handler,
},
signal: controller.signal,
onDownloadProgress: eventHandler,
}
if (on_upload_progress) {
config["onUploadProgress"] = on_upload_progress
}
const formdata = new FormData();
// Add the token and handler to the file name.
for (let i = 0; i < files.length; i++) {
files.forEach((file) => {
formdata.append(
"files",
files[i],
getToken() + ":" + handler + ":" + files[i].name
file,
file.path || file.name
);
}
})
// Send the file to the server.
await axios.post(UPLOADURL, formdata, headers)
.then(() => { return true; })
.catch(
error => {
if (error.response) {
// The request was made and the server responded with a status code
// that falls out of the range of 2xx
console.log(error.response.data);
} else if (error.request) {
// The request was made but no response was received
// `error.request` is an instance of XMLHttpRequest in the browser and an instance of
// http.ClientRequest in node.js
console.log(error.request);
} else {
// Something happened in setting up the request that triggered an Error
console.log(error.message);
}
return false;
}
)
upload_controllers[upload_id] = controller
try {
return await axios.post(UPLOADURL, formdata, config)
} catch (error) {
if (error.response) {
// The request was made and the server responded with a status code
// that falls out of the range of 2xx
console.log(error.response.data);
} else if (error.request) {
// The request was made but no response was received
// `error.request` is an instance of XMLHttpRequest in the browser and an instance of
// http.ClientRequest in node.js
console.log(error.request);
} else {
// Something happened in setting up the request that triggered an Error
console.log(error.message);
}
return false;
} finally {
delete upload_controllers[upload_id]
}
};
/**

View File

@ -229,6 +229,7 @@ _ALL_COMPONENTS = [
_ALL_COMPONENTS += [to_snake_case(component) for component in _ALL_COMPONENTS]
_ALL_COMPONENTS += [
"cancel_upload",
"components",
"color_mode_cond",
"desktop_only",

View File

@ -58,6 +58,9 @@ from reflex.components import ConnectionModal as ConnectionModal
from reflex.components import Container as Container
from reflex.components import DataTable as DataTable
from reflex.components import DataEditor as DataEditor
from reflex.components import DataEditorTheme as DataEditorTheme
from reflex.components import DatePicker as DatePicker
from reflex.components import DateTimePicker as DateTimePicker
from reflex.components import DebounceInput as DebounceInput
from reflex.components import Divider as Divider
from reflex.components import Drawer as Drawer
@ -265,6 +268,9 @@ from reflex.components import connection_modal as connection_modal
from reflex.components import container as container
from reflex.components import data_table as data_table
from reflex.components import data_editor as data_editor
from reflex.components import data_editor_theme as data_editor_theme
from reflex.components import date_picker as date_picker
from reflex.components import date_time_picker as date_time_picker
from reflex.components import debounce_input as debounce_input
from reflex.components import divider as divider
from reflex.components import drawer as drawer
@ -421,7 +427,9 @@ from reflex.components import visually_hidden as visually_hidden
from reflex.components import vstack as vstack
from reflex.components import wrap as wrap
from reflex.components import wrap_item as wrap_item
from reflex.components import cancel_upload as cancel_upload
from reflex import components as components
from reflex.components import color_mode_cond as color_mode_cond
from reflex.components import desktop_only as desktop_only
from reflex.components import mobile_only as mobile_only
from reflex.components import tablet_only as tablet_only
@ -429,7 +437,9 @@ from reflex.components import mobile_and_tablet as mobile_and_tablet
from reflex.components import tablet_and_desktop as tablet_and_desktop
from reflex.components import selected_files as selected_files
from reflex.components import clear_selected_files as clear_selected_files
from reflex.components import EditorButtonList as EditorButtonList
from reflex.components import EditorOptions as EditorOptions
from reflex.components import NoSSRComponent as NoSSRComponent
from reflex.components.component import memo as memo
from reflex.components.graphing import recharts as recharts
from reflex import config as config

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import contextlib
import inspect
import functools
import os
from multiprocessing.pool import ThreadPool
from typing import (
@ -17,10 +17,13 @@ from typing import (
Set,
Type,
Union,
get_args,
get_type_hints,
)
from fastapi import FastAPI, UploadFile
from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware import cors
from fastapi.responses import StreamingResponse
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from socketio import ASGIApp, AsyncNamespace, AsyncServer
from starlette_admin.contrib.sqla.admin import Admin
@ -880,62 +883,90 @@ def upload(app: App):
The upload function.
"""
async def upload_file(files: List[UploadFile]):
async def upload_file(request: Request, files: List[UploadFile]):
"""Upload a file.
Args:
request: The FastAPI request object.
files: The file(s) to upload.
Returns:
StreamingResponse yielding newline-delimited JSON of StateUpdate
emitted by the upload handler.
Raises:
ValueError: if there are no args with supported annotation.
TypeError: if a background task is used as the handler.
HTTPException: when the request does not include token / handler headers.
"""
assert files[0].filename is not None
token, handler = files[0].filename.split(":")[:2]
for file in files:
assert file.filename is not None
file.filename = file.filename.split(":")[-1]
token = request.headers.get("reflex-client-token")
handler = request.headers.get("reflex-event-handler")
if not token or not handler:
raise HTTPException(
status_code=400,
detail="Missing reflex-client-token or reflex-event-handler header.",
)
# Get the state for the session.
async with app.state_manager.modify_state(token) as state:
# get the current session ID
sid = state.router.session.session_id
# get the current state(parent state/substate)
path = handler.split(".")[:-1]
current_state = state.get_substate(path)
handler_upload_param = ()
state = await app.state_manager.get_state(token)
# get handler function
func = getattr(current_state, handler.split(".")[-1])
# get the current session ID
# get the current state(parent state/substate)
path = handler.split(".")[:-1]
current_state = state.get_substate(path)
handler_upload_param = ()
# check if there exists any handler args with annotation, List[UploadFile]
for k, v in inspect.getfullargspec(
func.fn if isinstance(func, EventHandler) else func
).annotations.items():
if types.is_generic_alias(v) and types._issubclass(
v.__args__[0], UploadFile
):
handler_upload_param = (k, v)
break
# get handler function
func = getattr(type(current_state), handler.split(".")[-1])
if not handler_upload_param:
raise ValueError(
f"`{handler}` handler should have a parameter annotated as List["
f"rx.UploadFile]"
# check if there exists any handler args with annotation, List[UploadFile]
if isinstance(func, EventHandler):
if func.is_background:
raise TypeError(
f"@rx.background is not supported for upload handler `{handler}`.",
)
func = func.fn
if isinstance(func, functools.partial):
func = func.func
for k, v in get_type_hints(func).items():
if types.is_generic_alias(v) and types._issubclass(
get_args(v)[0],
UploadFile,
):
handler_upload_param = (k, v)
break
event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: files},
if not handler_upload_param:
raise ValueError(
f"`{handler}` handler should have a parameter annotated as "
"List[rx.UploadFile]"
)
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
# Send update to client
await app.event_namespace.emit_update( # type: ignore
update=update,
sid=sid,
)
event = Event(
token=token,
name=handler,
payload={handler_upload_param[0]: files},
)
async def _ndjson_updates():
"""Process the upload event, generating ndjson updates.
Yields:
Each state update as JSON followed by a new line.
"""
# Process the event.
async with app.state_manager.modify_state(token) as state:
async for update in state._process(event):
# Postprocess the event.
update = await app.postprocess(state, event, update)
yield update.json() + "\n"
# Stream updates to client
return StreamingResponse(
_ndjson_updates(),
media_type="application/x-ndjson",
)
return upload_file

View File

@ -46,12 +46,18 @@ from .select import Option, Select
from .slider import Slider, SliderFilledTrack, SliderMark, SliderThumb, SliderTrack
from .switch import Switch
from .textarea import TextArea
from .upload import Upload, clear_selected_files, selected_files
from .upload import (
Upload,
cancel_upload,
clear_selected_files,
selected_files,
)
helpers = [
"color_mode_cond",
"selected_files",
"cancel_upload",
"clear_selected_files",
"selected_files",
]
__all__ = [f for f in dir() if f[0].isupper()] + helpers # type: ignore

View File

@ -3,26 +3,75 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from reflex import constants
from reflex.components.component import Component
from reflex.components.forms.input import Input
from reflex.components.layout.box import Box
from reflex.constants import EventTriggers
from reflex.event import EventChain
from reflex.vars import BaseVar, Var
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.utils import imports
from reflex.vars import BaseVar, CallableVar, ImportVar, Var
files_state: str = "const [files, setFiles] = useState([]);"
upload_file: BaseVar = BaseVar(
_var_name="e => setFiles((files) => e)", _var_type=EventChain
)
DEFAULT_UPLOAD_ID: str = "default"
# Use this var along with the Upload component to render the list of selected files.
selected_files: BaseVar = BaseVar(
_var_name="files.map((f) => f.name)", _var_type=List[str]
)
clear_selected_files: BaseVar = BaseVar(
_var_name="_e => setFiles((files) => [])", _var_type=EventChain
)
@CallableVar
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
"""Get the file upload drop trigger.
This var is passed to the dropzone component to update the file list when a
drop occurs.
Args:
id_: The id of the upload to get the drop trigger for.
Returns:
A var referencing the file upload drop trigger.
"""
return BaseVar(
_var_name=f"e => upload_files.{id_}[1]((files) => e)",
_var_type=EventChain,
)
@CallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar:
"""Get the list of selected files.
Args:
id_: The id of the upload to get the selected files for.
Returns:
A var referencing the list of selected file paths.
"""
return BaseVar(
_var_name=f"(upload_files.{id_} ? upload_files.{id_}[0]?.map((f) => (f.path || f.name)) : [])",
_var_type=List[str],
)
@CallableEventSpec
def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec:
"""Clear the list of selected files.
Args:
id_: The id of the upload to clear.
Returns:
An event spec that clears the list of selected files when triggered.
"""
return call_script(f"upload_files.{id_}[1]((files) => [])")
def cancel_upload(upload_id: str) -> EventSpec:
"""Cancel an upload.
Args:
upload_id: The id of the upload to cancel.
Returns:
An event spec that cancels the upload when triggered.
"""
return call_script(f"upload_controllers[{upload_id!r}]?.abort()")
class Upload(Component):
@ -94,7 +143,10 @@ class Upload(Component):
zone.special_props = {BaseVar(_var_name="{...getRootProps()}", _var_type=None)}
# Create the component.
return super().create(zone, on_drop=upload_file, **upload_props)
upload_props["id"] = props.get("id", DEFAULT_UPLOAD_ID)
return super().create(
zone, on_drop=upload_file(upload_props["id"]), **upload_props
)
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler.
@ -104,7 +156,7 @@ class Upload(Component):
"""
return {
**super().get_event_triggers(),
EventTriggers.ON_DROP: lambda e0: [e0],
constants.EventTriggers.ON_DROP: lambda e0: [e0],
}
def _render(self):
@ -113,4 +165,15 @@ class Upload(Component):
return out
def _get_hooks(self) -> str | None:
return (super()._get_hooks() or "") + files_state
return (
(super()._get_hooks() or "")
+ f"""
upload_files.{self.id or DEFAULT_UPLOAD_ID} = useState([]);
"""
)
def _get_imports(self) -> imports.ImportDict:
return {
**super()._get_imports(),
f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="upload_files")},
}

View File

@ -8,17 +8,23 @@ from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style
from typing import Any, Dict, List, Optional, Union
from reflex import constants
from reflex.components.component import Component
from reflex.components.forms.input import Input
from reflex.components.layout.box import Box
from reflex.constants import EventTriggers
from reflex.event import EventChain
from reflex.vars import BaseVar, Var
from reflex.event import CallableEventSpec, EventChain, EventSpec, call_script
from reflex.utils import imports
from reflex.vars import BaseVar, CallableVar, ImportVar, Var
files_state: str
upload_file: BaseVar
selected_files: BaseVar
clear_selected_files: BaseVar
DEFAULT_UPLOAD_ID: str
@CallableVar
def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
@CallableVar
def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ...
@CallableEventSpec
def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ...
def cancel_upload(upload_id: str) -> EventSpec: ...
class Upload(Component):
@overload

View File

@ -174,13 +174,7 @@ class EventHandler(EventActionsMixin):
for arg in args:
# Special case for file uploads.
if isinstance(arg, FileUpload):
return EventSpec(
handler=self,
client_handler_name="uploadFiles",
# `files` is defined in the Upload component's _use_hooks
args=((Var.create_safe("files"), Var.create_safe("files")),),
event_actions=self.event_actions.copy(),
)
return arg.as_event_spec(handler=self)
# Otherwise, convert to JSON.
try:
@ -236,6 +230,50 @@ class EventSpec(EventActionsMixin):
)
class CallableEventSpec(EventSpec):
"""Decorate an EventSpec-returning function to act as both a EventSpec and a function.
This is used as a compatibility shim for replacing EventSpec objects in the
API with functions that return a family of EventSpec.
"""
fn: Optional[Callable[..., EventSpec]] = None
def __init__(self, fn: Callable[..., EventSpec] | None = None, **kwargs):
"""Initialize a CallableEventSpec.
Args:
fn: The function to decorate.
**kwargs: The kwargs to pass to pydantic initializer
"""
if fn is not None:
default_event_spec = fn()
super().__init__(
fn=fn, # type: ignore
**default_event_spec.dict(),
**kwargs,
)
else:
super().__init__(**kwargs)
def __call__(self, *args, **kwargs) -> EventSpec:
"""Call the decorated function.
Args:
*args: The args to pass to the function.
**kwargs: The kwargs to pass to the function.
Returns:
The EventSpec returned from calling the function.
Raises:
TypeError: If the CallableEventSpec has no associated function.
"""
if self.fn is None:
raise TypeError("CallableEventSpec has no associated function.")
return self.fn(*args, **kwargs)
class EventChain(EventActionsMixin):
"""Container for a chain of events that will be executed in order."""
@ -267,7 +305,76 @@ class FrontendEvent(Base):
class FileUpload(Base):
"""Class to represent a file upload."""
pass
upload_id: Optional[str] = None
on_upload_progress: Optional[Union[EventHandler, Callable]] = None
@staticmethod
def on_upload_progress_args_spec(_prog: dict[str, int | float | bool]):
"""Args spec for on_upload_progress event handler.
Returns:
The arg mapping passed to backend event handler
"""
return [_prog]
def as_event_spec(self, handler: EventHandler) -> EventSpec:
"""Get the EventSpec for the file upload.
Args:
handler: The event handler.
Returns:
The event spec for the handler.
Raises:
ValueError: If the on_upload_progress is not a valid event handler.
"""
from reflex.components.forms.upload import DEFAULT_UPLOAD_ID
upload_id = self.upload_id or DEFAULT_UPLOAD_ID
spec_args = [
# `upload_files` is defined in state.js and assigned in the Upload component's _use_hooks
(Var.create_safe("files"), Var.create_safe(f"upload_files.{upload_id}[0]")),
(
Var.create_safe("upload_id"),
Var.create_safe(upload_id, _var_is_string=True),
),
]
if self.on_upload_progress is not None:
on_upload_progress = self.on_upload_progress
if isinstance(on_upload_progress, EventHandler):
events = [
call_event_handler(
on_upload_progress,
self.on_upload_progress_args_spec,
),
]
elif isinstance(on_upload_progress, Callable):
# Call the lambda to get the event chain.
events = call_event_fn(on_upload_progress, self.on_upload_progress_args_spec) # type: ignore
else:
raise ValueError(f"{on_upload_progress} is not a valid event handler.")
on_upload_progress_chain = EventChain(
events=events,
args_spec=self.on_upload_progress_args_spec,
)
formatted_chain = str(format.format_prop(on_upload_progress_chain))
spec_args.append(
(
Var.create_safe("on_upload_progress"),
BaseVar(
_var_name=formatted_chain.strip("{}"),
_var_type=EventChain,
),
),
)
return EventSpec(
handler=handler,
client_handler_name="uploadFiles",
args=tuple(spec_args),
event_actions=handler.event_actions.copy(),
)
# Alias for rx.upload_files

View File

@ -1590,3 +1590,33 @@ class NoRenderImportVar(ImportVar):
"""A import that doesn't need to be rendered."""
render: Optional[bool] = False
class CallableVar(BaseVar):
"""Decorate a Var-returning function to act as both a Var and a function.
This is used as a compatibility shim for replacing Var objects in the
API with functions that return a family of Var.
"""
def __init__(self, fn: Callable[..., BaseVar]):
"""Initialize a CallableVar.
Args:
fn: The function to decorate (must return Var)
"""
self.fn = fn
default_var = fn()
super().__init__(**dataclasses.asdict(default_var))
def __call__(self, *args, **kwargs) -> BaseVar:
"""Call the decorated function.
Args:
*args: The args to pass to the function.
**kwargs: The kwargs to pass to the function.
Returns:
The Var returned from calling the function.
"""
return self.fn(*args, **kwargs)

View File

@ -137,3 +137,7 @@ class NoRenderImportVar(ImportVar):
"""A import that doesn't need to be rendered."""
def get_local_storage(key: Optional[Union[Var, str]] = ...) -> BaseVar: ...
class CallableVar(BaseVar):
def __init__(self, fn: Callable[..., BaseVar]): ...
def __call__(self, *args, **kwargs) -> BaseVar: ...

View File

@ -52,8 +52,10 @@ def test_upload_component_render(upload_component):
# upload
assert upload["name"] == "ReactDropzone"
assert upload["props"] == [
"id={`default`}",
"multiple={true}",
"onDrop={e => setFiles((files) => e)}",
"onDrop={e => upload_files.default[1]((files) => e)}",
"ref={ref_default}",
]
assert upload["args"] == ("getRootProps", "getInputProps")
@ -89,8 +91,10 @@ def test_upload_component_with_props_render(upload_component_with_props):
upload = upload_component_with_props.render()
assert upload["props"] == [
"id={`default`}",
"maxFiles={2}",
"multiple={true}",
"noDrag={true}",
"onDrop={e => setFiles((files) => e)}",
"onDrop={e => upload_files.default[1]((files) => e)}",
"ref={ref_default}",
]

View File

@ -49,16 +49,7 @@ class FileUploadState(rx.State):
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
pass
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
@ -78,6 +69,15 @@ class FileUploadState(rx.State):
assert file.filename is not None
self.img_list.append(file.filename)
@rx.background
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
Args:
files: The uploaded files.
"""
pass
class FileStateBase1(rx.State):
"""The base state for a child FileUploadState."""
@ -97,16 +97,7 @@ class ChildFileUploadState(FileStateBase1):
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
pass
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
@ -126,6 +117,15 @@ class ChildFileUploadState(FileStateBase1):
assert file.filename is not None
self.img_list.append(file.filename)
@rx.background
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
Args:
files: The uploaded files.
"""
pass
class FileStateBase2(FileStateBase1):
"""The parent state for a grandchild FileUploadState."""
@ -145,16 +145,7 @@ class GrandChildFileUploadState(FileStateBase2):
Args:
files: The uploaded files.
"""
for file in files:
upload_data = await file.read()
outfile = f"{self._tmp_path}/{file.filename}"
# Save the file.
with open(outfile, "wb") as file_object:
file_object.write(upload_data)
# Update the img var.
self.img_list.append(file.filename)
pass
async def multi_handle_upload(self, files: List[rx.UploadFile]):
"""Handle the upload of a file.
@ -173,3 +164,12 @@ class GrandChildFileUploadState(FileStateBase2):
# Update the img var.
assert file.filename is not None
self.img_list.append(file.filename)
@rx.background
async def bg_upload(self, files: List[rx.UploadFile]):
"""Background task cannot be upload handler.
Args:
files: The uploaded files.
"""
pass

View File

@ -746,23 +746,28 @@ async def test_upload_file(tmp_path, state, delta, token: str):
bio.write(data)
state_name = state.get_full_name().partition(".")[2] or state.get_name()
handler_prefix = f"{token}:{state_name}"
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.multi_handle_upload",
}
file1 = UploadFile(
filename=f"{handler_prefix}.multi_handle_upload:True:image1.jpg",
filename=f"image1.jpg",
file=bio,
)
file2 = UploadFile(
filename=f"{handler_prefix}.multi_handle_upload:True:image2.jpg",
filename=f"image2.jpg",
file=bio,
)
upload_fn = upload(app)
await upload_fn([file1, file2])
state_update = StateUpdate(delta=delta, events=[], final=True)
streaming_response = await upload_fn(request_mock, [file1, file2])
async for state_update in streaming_response.body_iterator:
assert (
state_update
== StateUpdate(delta=delta, events=[], final=True).json() + "\n"
)
app.event_namespace.emit.assert_called_with( # type: ignore
"event", state_update.json(), to=current_state.router.session.session_id
)
current_state = await app.state_manager.get_state(token)
state_dict = current_state.dict()
for substate in state.get_full_name().split(".")[1:]:
@ -789,30 +794,20 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
tmp_path: Temporary path.
token: a Token.
"""
data = b"This is binary data"
# Create a binary IO object and write data to it
bio = io.BytesIO()
bio.write(data)
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1)
state_name = state.get_full_name().partition(".")[2] or state.get_name()
handler_prefix = f"{token}:{state_name}"
file1 = UploadFile(
filename=f"{handler_prefix}.handle_upload2:True:image1.jpg",
file=bio,
)
file2 = UploadFile(
filename=f"{handler_prefix}.handle_upload2:True:image2.jpg",
file=bio,
)
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.handle_upload2",
}
file_mock = unittest.mock.Mock(filename="image1.jpg")
fn = upload(app)
with pytest.raises(ValueError) as err:
await fn([file1, file2])
await fn(request_mock, [file_mock])
assert (
err.value.args[0]
== f"`{state_name}.handle_upload2` handler should have a parameter annotated as List[rx.UploadFile]"
@ -822,6 +817,42 @@ async def test_upload_file_without_annotation(state, tmp_path, token):
await app.state_manager.redis.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"state",
[FileUploadState, ChildFileUploadState, GrandChildFileUploadState],
)
async def test_upload_file_background(state, tmp_path, token):
"""Test that an error is thrown handler is a background task.
Args:
state: The state class.
tmp_path: Temporary path.
token: a Token.
"""
state._tmp_path = tmp_path
# The App state must be the "root" of the state tree
app = App(state=state if state is FileUploadState else FileStateBase1)
state_name = state.get_full_name().partition(".")[2] or state.get_name()
request_mock = unittest.mock.Mock()
request_mock.headers = {
"reflex-client-token": token,
"reflex-event-handler": f"{state_name}.bg_upload",
}
file_mock = unittest.mock.Mock(filename="image1.jpg")
fn = upload(app)
with pytest.raises(TypeError) as err:
await fn(request_mock, [file_mock])
assert (
err.value.args[0]
== f"@rx.background is not supported for upload handler `{state_name}.bg_upload`."
)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.redis.close()
class DynamicState(State):
"""State class for testing dynamic route var.