Support Form controls via name attribute (no ID or ref) (#2012)

This commit is contained in:
Masen Furer 2023-11-10 12:58:59 -08:00 committed by GitHub
parent 7a04652a6a
commit 5e6520cb5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 420 additions and 61 deletions

View File

@ -7,6 +7,7 @@ from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from reflex.testing import AppHarness
from reflex.utils import format
def FormSubmit():
@ -38,6 +39,8 @@ def FormSubmit():
rx.number_input(id="number_input"),
rx.checkbox(id="bool_input"),
rx.switch(id="bool_input2"),
rx.checkbox(id="bool_input3"),
rx.switch(id="bool_input4"),
rx.slider(id="slider_input"),
rx.range_slider(id="range_input"),
rx.radio_group(["option1", "option2"], id="radio_input"),
@ -62,11 +65,65 @@ def FormSubmit():
app.compile()
@pytest.fixture(scope="session")
def form_submit(tmp_path_factory) -> Generator[AppHarness, None, None]:
def FormSubmitName():
"""App with a form using on_submit."""
import reflex as rx
class FormState(rx.State):
form_data: dict = {}
def form_submit(self, form_data: dict):
self.form_data = form_data
app = rx.App(state=FormState)
@app.add_page
def index():
return rx.vstack(
rx.input(
value=FormState.router.session.client_token,
is_read_only=True,
id="token",
),
rx.form(
rx.vstack(
rx.input(name="name_input"),
rx.hstack(rx.pin_input(length=4, name="pin_input")),
rx.number_input(name="number_input"),
rx.checkbox(name="bool_input"),
rx.switch(name="bool_input2"),
rx.checkbox(name="bool_input3"),
rx.switch(name="bool_input4"),
rx.slider(name="slider_input"),
rx.range_slider(name="range_input"),
rx.radio_group(["option1", "option2"], name="radio_input"),
rx.select(["option1", "option2"], name="select_input"),
rx.text_area(name="text_area_input"),
rx.input(
name="debounce_input",
debounce_timeout=0,
on_change=rx.console_log,
),
rx.button("Submit", type_="submit"),
),
on_submit=FormState.form_submit,
custom_attrs={"action": "/invalid"},
),
rx.spacer(),
height="100vh",
)
app.compile()
@pytest.fixture(
scope="session", params=[FormSubmit, FormSubmitName], ids=["id", "name"]
)
def form_submit(request, tmp_path_factory) -> Generator[AppHarness, None, None]:
"""Start FormSubmit app at tmp_path via AppHarness.
Args:
request: pytest request fixture
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
@ -74,7 +131,7 @@ def form_submit(tmp_path_factory) -> Generator[AppHarness, None, None]:
"""
with AppHarness.create(
root=tmp_path_factory.mktemp("form_submit"),
app_source=FormSubmit, # type: ignore
app_source=request.param, # type: ignore
) as harness:
assert harness.app_instance is not None, "app is not running"
yield harness
@ -107,6 +164,7 @@ async def test_submit(driver, form_submit: AppHarness):
form_submit: harness for FormSubmit app
"""
assert form_submit.app_instance is not None, "app is not running"
by = By.ID if form_submit.app_source is FormSubmit else By.NAME
# get a reference to the connected client
token_input = driver.find_element(By.ID, "token")
@ -116,7 +174,7 @@ async def test_submit(driver, form_submit: AppHarness):
token = form_submit.poll_for_value(token_input)
assert token
name_input = driver.find_element(By.ID, "name_input")
name_input = driver.find_element(by, "name_input")
name_input.send_keys("foo")
pin_inputs = driver.find_elements(By.CLASS_NAME, "chakra-pin-input")
@ -141,7 +199,7 @@ async def test_submit(driver, form_submit: AppHarness):
textarea_input = driver.find_element(By.CLASS_NAME, "chakra-textarea")
textarea_input.send_keys("Some", Keys.ENTER, "Text")
debounce_input = driver.find_element(By.ID, "debounce_input")
debounce_input = driver.find_element(by, "debounce_input")
debounce_input.send_keys("bar baz")
time.sleep(1)
@ -158,11 +216,16 @@ async def test_submit(driver, form_submit: AppHarness):
form_data = await AppHarness._poll_for_async(get_form_data)
assert isinstance(form_data, dict)
form_data = format.collect_form_dict_names(form_data)
assert form_data["name_input"] == "foo"
assert form_data["pin_input"] == pin_values
assert form_data["number_input"] == "-3"
assert form_data["bool_input"] is True
assert form_data["bool_input2"] is True
assert form_data["bool_input"]
assert form_data["bool_input2"]
assert not form_data.get("bool_input3", False)
assert not form_data.get("bool_input4", False)
assert form_data["slider_input"] == "50"
assert form_data["range_input"] == ["25", "75"]
assert form_data["radio_input"] == "option2"

View File

@ -522,6 +522,10 @@ def VarOperations():
rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
rx.text(VarOperationState.list3.join(""), id="list_join"),
rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
rx.text(rx.Var.range(0, 3).join(","), id="list_join_range4"),
)
app.compile()
@ -688,6 +692,10 @@ def test_var_operations(driver, var_operations: AppHarness):
("list_reverse", "[2,1]"),
("list_join", "firstsecondthird"),
("list_join_comma", "first,second,third"),
("list_join_range1", "2,3,4"),
("list_join_range2", "2,4,6,8"),
("list_join_range3", "5,4,3,2,1"),
("list_join_range4", "0,1,2"),
# list, int
("list_mult_int", "[1,2,1,2,1,2,1,2,1,2]"),
("list_or_int", "[1,2]"),

View File

@ -0,0 +1,43 @@
/**
* Simulate the python range() builtin function.
* inspired by https://dev.to/guyariely/using-python-range-in-javascript-337p
*
* If needed outside of an iterator context, use `Array.from(range(10))` or
* spread syntax `[...range(10)]` to get an array.
*
* @param {number} start: the start or end of the range.
* @param {number} stop: the end of the range.
* @param {number} step: the step of the range.
* @returns {object} an object with a Symbol.iterator method over the range
*/
export default function range(start, stop, step) {
return {
[Symbol.iterator]() {
if (stop === undefined) {
stop = start;
start = 0;
}
if (step === undefined) {
step = 1;
}
let i = start - step;
return {
next() {
i += step;
if ((step > 0 && i < stop) || (step < 0 && i > stop)) {
return {
value: i,
done: false,
};
}
return {
value: undefined,
done: true,
};
},
};
},
};
}

View File

@ -573,7 +573,7 @@ export const getRefValues = (refs) => {
return;
}
// getAttribute is used by RangeSlider because it doesn't assign value
return refs.map((ref) => ref.current.value || ref.current.getAttribute("aria-valuenow"));
return refs.map((ref) => ref.current ? ref.current.value || ref.current.getAttribute("aria-valuenow") : null);
}
/**

View File

@ -41,6 +41,9 @@ DEFAULT_IMPORTS: imports.ImportDict = {
ImportVar(tag="StateContext"),
ImportVar(tag="ColorModeContext"),
},
"/utils/helpers/range.js": {
ImportVar(tag="range", is_default=True),
},
"": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)},
}

View File

@ -56,6 +56,9 @@ class Button(ChakraComponent):
# Components that are not allowed as children.
invalid_children: List[str] = ["Button", "MenuButton"]
# The name of the form field
name: Var[str]
class ButtonGroup(ChakraComponent):
"""A group of buttons."""

View File

@ -50,6 +50,9 @@ class Checkbox(ChakraComponent):
# The name of the input field in a checkbox (Useful for form submission).
name: Var[str]
# The value of the input field when checked (use is_checked prop for a bool)
value: Var[str] = Var.create(True) # type: ignore
# The spacing between the checkbox and its label text (0.5rem)
spacing: Var[str]

View File

@ -1,13 +1,36 @@
"""Form components."""
from __future__ import annotations
from typing import Any, Callable, Dict, List
from typing import Any, Dict
from jinja2 import Environment
from reflex.components.component import Component
from reflex.components.libs.chakra import ChakraComponent
from reflex.components.tags import Tag
from reflex.constants import EventTriggers
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.vars import Var
from reflex.event import EventChain
from reflex.utils import imports
from reflex.utils.format import format_event_chain
from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, Var, get_unique_variable_name
FORM_DATA = Var.create("form_data")
HANDLE_SUBMIT_JS_JINJA2 = Environment().from_string(
"""
const handleSubmit{{ handle_submit_unique_name }} = useCallback((ev) => {
const $form = ev.target
ev.preventDefault()
const {{ form_data }} = {...Object.fromEntries(new FormData($form).entries()), ...{{ field_ref_mapping }}}
{{ on_submit_event_chain }}
if ({{ reset_on_submit }}) {
$form.reset()
}
})
"""
)
class Form(ChakraComponent):
@ -18,35 +41,51 @@ class Form(ChakraComponent):
# What the form renders to.
as_: Var[str] = "form" # type: ignore
def _create_event_chain(
self,
event_trigger: str,
value: Var
| EventHandler
| EventSpec
| List[EventHandler | EventSpec]
| Callable[..., Any],
) -> EventChain | Var:
"""Override the event chain creation to preventDefault for on_submit.
# If true, the form will be cleared after submit.
reset_on_submit: Var[bool] = False # type: ignore
Args:
event_trigger: The event trigger.
value: The value of the event trigger.
# The name used to make this form's submit handler function unique
handle_submit_unique_name: Var[str] = get_unique_variable_name() # type: ignore
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"react": {imports.ImportVar(tag="useCallback")}},
)
def _get_hooks(self) -> str | None:
if EventTriggers.ON_SUBMIT not in self.event_triggers:
return
return HANDLE_SUBMIT_JS_JINJA2.render(
handle_submit_unique_name=self.handle_submit_unique_name,
form_data=FORM_DATA,
field_ref_mapping=serialize(self._get_form_refs()),
on_submit_event_chain=format_event_chain(
self.event_triggers[EventTriggers.ON_SUBMIT]
),
reset_on_submit=self.reset_on_submit,
)
def _render(self) -> Tag:
return (
super()
._render()
.remove_props("reset_on_submit", "handle_submit_unique_name")
)
def render(self) -> dict:
"""Render the component.
Returns:
The event chain.
The rendered component.
"""
chain = super()._create_event_chain(event_trigger, value)
if event_trigger == EventTriggers.ON_SUBMIT and isinstance(chain, EventChain):
return chain.prevent_default
return chain
self.event_triggers[EventTriggers.ON_SUBMIT] = BaseVar(
_var_name=f"handleSubmit{self.handle_submit_unique_name}",
_var_type=EventChain,
)
return super().render()
def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.
Returns:
A dict mapping the event trigger to the var that is passed to the handler.
"""
def _get_form_refs(self) -> Dict[str, Any]:
# Send all the input refs to the handler.
form_refs = {}
for ref in self.get_refs():
@ -60,10 +99,17 @@ class Form(ChakraComponent):
form_refs[ref[4:]] = Var.create(
f"getRefValue({ref})", _var_is_local=False
)
return form_refs
def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.
Returns:
A dict mapping the event trigger to the var that is passed to the handler.
"""
return {
**super().get_event_triggers(),
EventTriggers.ON_SUBMIT: lambda e0: [form_refs],
EventTriggers.ON_SUBMIT: lambda e0: [FORM_DATA],
}

View File

@ -55,6 +55,9 @@ class Input(ChakraComponent):
# "lg" | "md" | "sm" | "xs"
size: Var[LiteralButtonSize]
# The name of the form field
name: Var[str]
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),

View File

@ -68,6 +68,9 @@ class NumberInput(ChakraComponent):
# "outline" | "filled" | "flushed" | "unstyled"
variant: Var[LiteralInputVariant]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -4,10 +4,11 @@ from __future__ import annotations
from typing import Any, Optional, Union
from reflex.components.component import Component
from reflex.components.layout import Foreach
from reflex.components.libs.chakra import ChakraComponent, LiteralInputVariant
from reflex.components.tags.tag import Tag
from reflex.constants import EventTriggers
from reflex.utils import format
from reflex.utils.imports import ImportDict, merge_imports
from reflex.vars import Var
@ -58,6 +59,20 @@ class PinInput(ChakraComponent):
# "outline" | "flushed" | "filled" | "unstyled"
variant: Var[LiteralInputVariant]
# The name of the form field
name: Var[str]
def _get_imports(self) -> ImportDict:
"""Include PinInputField explicitly because it may not be a child component at compile time.
Returns:
The merged import dict.
"""
return merge_imports(
super()._get_imports(),
PinInputField().get_imports(), # type: ignore
)
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler.
@ -70,13 +85,24 @@ class PinInput(ChakraComponent):
EventTriggers.ON_COMPLETE: lambda e0: [e0],
}
def get_ref(self):
"""Return a reference because we actually attached the ref to the PinInputFields.
def get_ref(self) -> str | None:
"""Override ref handling to handle array refs.
PinInputFields may be created dynamically, so it's not possible
to compute their ref at compile time, so we return a cheating
guess if the id is specified.
The `ref` for this outer component will always be stripped off, so what
is returned here only matters for form ref collection purposes.
Returns:
None.
"""
return None
if any(isinstance(c, PinInputField) for c in self.children):
return None
if self.id:
return format.format_array_ref(self.id, idx=self.length)
return super().get_ref()
def _get_ref_hook(self) -> Optional[str]:
"""Override the base _get_ref_hook to handle array refs.
@ -86,10 +112,22 @@ class PinInput(ChakraComponent):
"""
if self.id:
ref = format.format_array_ref(self.id, None)
refs_declaration = Var.range(self.length).foreach(
lambda: Var.create_safe("useRef(null)", _var_is_string=False),
)
refs_declaration._var_is_local = True
if ref:
return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));"
return f"const {ref} = {refs_declaration}"
return super()._get_ref_hook()
def _render(self) -> Tag:
"""Override the base _render to remove the fake get_ref.
Returns:
The rendered component.
"""
return super()._render().remove_props("ref")
@classmethod
def create(cls, *children, **props) -> Component:
"""Create a pin input component.
@ -104,22 +142,17 @@ class PinInput(ChakraComponent):
Returns:
The pin input component.
"""
if not children and "length" in props:
_id = props.get("id", None)
length = props["length"]
if _id:
children = [
Foreach.create(
list(range(length)), # type: ignore
lambda ref, i: PinInputField.create(
key=i,
id=_id,
index=i,
),
)
]
else:
children = [PinInputField()] * length
if children:
props.pop("length", None)
elif "length" in props:
field_props = {}
if "id" in props:
field_props["id"] = props["id"]
if "name" in props:
field_props["name"] = props["name"]
children = [
PinInputField.for_length(props["length"], **field_props),
]
return super().create(*children, **props)
@ -132,6 +165,29 @@ class PinInputField(ChakraComponent):
# Default to None because it is assigned by PinInput when created.
index: Optional[Var[int]] = None
# The name of the form field
name: Var[str]
@classmethod
def for_length(cls, length: Var | int, **props) -> Var:
"""Create a PinInputField for a PinInput with a given length.
Args:
length: The length of the PinInput.
props: The props of each PinInputField (name will become indexed).
Returns:
The PinInputField.
"""
name = props.get("name")
def _create(i):
if name is not None:
props["name"] = f"{name}-{i}"
return PinInputField.create(**props, index=i, key=i)
return Var.range(length).foreach(_create) # type: ignore
def _get_ref_hook(self) -> Optional[str]:
return None

View File

@ -23,6 +23,9 @@ class RadioGroup(ChakraComponent):
# The default value.
default_value: Var[Any]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> Dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -45,6 +45,9 @@ class RangeSlider(ChakraComponent):
# The minimum distance between slider thumbs. Useful for preventing the thumbs from being too close together.
min_steps_between_thumbs: Var[int]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -46,6 +46,9 @@ class Select(ChakraComponent):
# The size of the select.
size: Var[str]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> Dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -66,6 +66,9 @@ class Slider(ChakraComponent):
# Maximum width of the slider.
max_w: Var[str]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -34,6 +34,9 @@ class Switch(ChakraComponent):
# The name of the input field in a switch (Useful for form submission).
name: Var[str]
# The value of the input field when checked (use is_checked prop for a bool)
value: Var[str] = Var.create(True) # type: ignore
# The spacing between the switch and its label text (0.5rem)
spacing: Var[str]

View File

@ -45,6 +45,9 @@ class TextArea(ChakraComponent):
# "outline" | "filled" | "flushed" | "unstyled"
variant: Var[LiteralInputVariant]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler.

View File

@ -83,8 +83,9 @@ class Tag(Base):
The tag with the props removed.
"""
for name in args:
if name in self.props:
del self.props[name]
prop_name = format.to_camel_case(name)
if prop_name in self.props:
del self.props[prop_name]
return self
@staticmethod

View File

@ -43,6 +43,7 @@ from reflex.event import (
)
from reflex.utils import console, format, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.vars import BaseVar, ComputedVar, Var
Delta = Dict[str, Any]
@ -2016,6 +2017,25 @@ class MutableProxy(wrapt.ObjectProxy):
return copy.deepcopy(self.__wrapped__, memo=memo)
@serializer
def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType:
"""Serialize the wrapped value of a MutableProxy.
Args:
mp: The MutableProxy to serialize.
Returns:
The serialized wrapped object.
Raises:
ValueError: when the wrapped object is not serializable.
"""
value = serialize(mp.__wrapped__)
if value is None:
raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}")
return value
class ImmutableMutableProxy(MutableProxy):
"""A proxy for a mutable object that tracks changes.

View File

@ -627,6 +627,28 @@ def unwrap_vars(value: str) -> str:
)
def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:
"""Collapse keys with consecutive suffixes into a single list value.
Separators dash and underscore are removed, unless this would overwrite an existing key.
Args:
form_dict: The dict to collapse.
Returns:
The collapsed dict.
"""
ending_digit_regex = re.compile(r"^(.*?)[_-]?(\d+)$")
collapsed = {}
for k in sorted(form_dict):
m = ending_digit_regex.match(k)
if m:
collapsed.setdefault(m.group(1), []).append(form_dict[k])
# collapsing never overwrites valid data from the form_dict
collapsed.update(form_dict)
return collapsed
def format_data_editor_column(col: str | dict):
"""Format a given column into the proper format.

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import contextlib
import dataclasses
import dis
import inspect
import json
import random
import string
@ -1138,17 +1139,76 @@ class Var:
Returns:
A var representing foreach operation.
Raises:
TypeError: If the var is not a list.
"""
inner_types = get_args(self._var_type)
if not inner_types:
raise TypeError(
f"Cannot foreach over non-sequence var {self._var_full_name} of type {self._var_type}."
)
arg = BaseVar(
_var_name=get_unique_variable_name(),
_var_type=self._var_type,
_var_type=inner_types[0],
)
index = BaseVar(
_var_name=get_unique_variable_name(),
_var_type=int,
)
fn_signature = inspect.signature(fn)
fn_args = (arg, index)
fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
return BaseVar(
_var_name=f"{self._var_full_name}.map(({arg._var_name}, i) => {fn(arg, key='i')})",
_var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
_var_type=self._var_type,
_var_is_local=self._var_is_local,
)
@classmethod
def range(
cls,
v1: Var | int = 0,
v2: Var | int | None = None,
step: Var | int | None = None,
) -> Var:
"""Return an iterator over indices from v1 to v2 (or 0 to v1).
Args:
v1: The start of the range or end of range if v2 is not given.
v2: The end of the range.
step: The number of numbers between each item.
Returns:
A var representing range operation.
Raises:
TypeError: If the var is not an int.
"""
if not isinstance(v1, Var):
v1 = Var.create_safe(v1)
if v1._var_type != int:
raise TypeError(f"Cannot get range on non-int var {v1._var_full_name}.")
if not isinstance(v2, Var):
v2 = Var.create(v2)
if v2 is None:
v2 = Var.create_safe("undefined")
elif v2._var_type != int:
raise TypeError(f"Cannot get range on non-int var {v2._var_full_name}.")
if not isinstance(step, Var):
step = Var.create(step)
if step is None:
step = Var.create_safe(1)
elif step._var_type != int:
raise TypeError(f"Cannot get range on non-int var {step._var_full_name}.")
return BaseVar(
_var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
_var_type=list[int],
_var_is_local=False,
)
def to(self, type_: Type) -> Var:
"""Convert the type of the var.

View File

@ -85,6 +85,13 @@ class Var:
def contains(self, other: Any) -> Var: ...
def reverse(self) -> Var: ...
def foreach(self, fn: Callable) -> Var: ...
@classmethod
def range(
cls,
v1: Var | int = 0,
v2: Var | int | None = None,
step: Var | int | None = None,
) -> Var: ...
def to(self, type_: Type) -> Var: ...
@property
def _var_full_name(self) -> str: ...