Support Form controls via name attribute (no ID or ref) (#2012)

This commit is contained in:
Masen Furer 2023-11-10 12:58:59 -08:00 committed by GitHub
parent 7a04652a6a
commit 5e6520cb5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 420 additions and 61 deletions

View File

@ -7,6 +7,7 @@ from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.keys import Keys
from reflex.testing import AppHarness from reflex.testing import AppHarness
from reflex.utils import format
def FormSubmit(): def FormSubmit():
@ -38,6 +39,8 @@ def FormSubmit():
rx.number_input(id="number_input"), rx.number_input(id="number_input"),
rx.checkbox(id="bool_input"), rx.checkbox(id="bool_input"),
rx.switch(id="bool_input2"), rx.switch(id="bool_input2"),
rx.checkbox(id="bool_input3"),
rx.switch(id="bool_input4"),
rx.slider(id="slider_input"), rx.slider(id="slider_input"),
rx.range_slider(id="range_input"), rx.range_slider(id="range_input"),
rx.radio_group(["option1", "option2"], id="radio_input"), rx.radio_group(["option1", "option2"], id="radio_input"),
@ -62,11 +65,65 @@ def FormSubmit():
app.compile() app.compile()
@pytest.fixture(scope="session") def FormSubmitName():
def form_submit(tmp_path_factory) -> Generator[AppHarness, None, None]: """App with a form using on_submit."""
import reflex as rx
class FormState(rx.State):
form_data: dict = {}
def form_submit(self, form_data: dict):
self.form_data = form_data
app = rx.App(state=FormState)
@app.add_page
def index():
return rx.vstack(
rx.input(
value=FormState.router.session.client_token,
is_read_only=True,
id="token",
),
rx.form(
rx.vstack(
rx.input(name="name_input"),
rx.hstack(rx.pin_input(length=4, name="pin_input")),
rx.number_input(name="number_input"),
rx.checkbox(name="bool_input"),
rx.switch(name="bool_input2"),
rx.checkbox(name="bool_input3"),
rx.switch(name="bool_input4"),
rx.slider(name="slider_input"),
rx.range_slider(name="range_input"),
rx.radio_group(["option1", "option2"], name="radio_input"),
rx.select(["option1", "option2"], name="select_input"),
rx.text_area(name="text_area_input"),
rx.input(
name="debounce_input",
debounce_timeout=0,
on_change=rx.console_log,
),
rx.button("Submit", type_="submit"),
),
on_submit=FormState.form_submit,
custom_attrs={"action": "/invalid"},
),
rx.spacer(),
height="100vh",
)
app.compile()
@pytest.fixture(
scope="session", params=[FormSubmit, FormSubmitName], ids=["id", "name"]
)
def form_submit(request, tmp_path_factory) -> Generator[AppHarness, None, None]:
"""Start FormSubmit app at tmp_path via AppHarness. """Start FormSubmit app at tmp_path via AppHarness.
Args: Args:
request: pytest request fixture
tmp_path_factory: pytest tmp_path_factory fixture tmp_path_factory: pytest tmp_path_factory fixture
Yields: Yields:
@ -74,7 +131,7 @@ def form_submit(tmp_path_factory) -> Generator[AppHarness, None, None]:
""" """
with AppHarness.create( with AppHarness.create(
root=tmp_path_factory.mktemp("form_submit"), root=tmp_path_factory.mktemp("form_submit"),
app_source=FormSubmit, # type: ignore app_source=request.param, # type: ignore
) as harness: ) as harness:
assert harness.app_instance is not None, "app is not running" assert harness.app_instance is not None, "app is not running"
yield harness yield harness
@ -107,6 +164,7 @@ async def test_submit(driver, form_submit: AppHarness):
form_submit: harness for FormSubmit app form_submit: harness for FormSubmit app
""" """
assert form_submit.app_instance is not None, "app is not running" assert form_submit.app_instance is not None, "app is not running"
by = By.ID if form_submit.app_source is FormSubmit else By.NAME
# get a reference to the connected client # get a reference to the connected client
token_input = driver.find_element(By.ID, "token") token_input = driver.find_element(By.ID, "token")
@ -116,7 +174,7 @@ async def test_submit(driver, form_submit: AppHarness):
token = form_submit.poll_for_value(token_input) token = form_submit.poll_for_value(token_input)
assert token assert token
name_input = driver.find_element(By.ID, "name_input") name_input = driver.find_element(by, "name_input")
name_input.send_keys("foo") name_input.send_keys("foo")
pin_inputs = driver.find_elements(By.CLASS_NAME, "chakra-pin-input") pin_inputs = driver.find_elements(By.CLASS_NAME, "chakra-pin-input")
@ -141,7 +199,7 @@ async def test_submit(driver, form_submit: AppHarness):
textarea_input = driver.find_element(By.CLASS_NAME, "chakra-textarea") textarea_input = driver.find_element(By.CLASS_NAME, "chakra-textarea")
textarea_input.send_keys("Some", Keys.ENTER, "Text") textarea_input.send_keys("Some", Keys.ENTER, "Text")
debounce_input = driver.find_element(By.ID, "debounce_input") debounce_input = driver.find_element(by, "debounce_input")
debounce_input.send_keys("bar baz") debounce_input.send_keys("bar baz")
time.sleep(1) time.sleep(1)
@ -158,11 +216,16 @@ async def test_submit(driver, form_submit: AppHarness):
form_data = await AppHarness._poll_for_async(get_form_data) form_data = await AppHarness._poll_for_async(get_form_data)
assert isinstance(form_data, dict) assert isinstance(form_data, dict)
form_data = format.collect_form_dict_names(form_data)
assert form_data["name_input"] == "foo" assert form_data["name_input"] == "foo"
assert form_data["pin_input"] == pin_values assert form_data["pin_input"] == pin_values
assert form_data["number_input"] == "-3" assert form_data["number_input"] == "-3"
assert form_data["bool_input"] is True assert form_data["bool_input"]
assert form_data["bool_input2"] is True assert form_data["bool_input2"]
assert not form_data.get("bool_input3", False)
assert not form_data.get("bool_input4", False)
assert form_data["slider_input"] == "50" assert form_data["slider_input"] == "50"
assert form_data["range_input"] == ["25", "75"] assert form_data["range_input"] == ["25", "75"]
assert form_data["radio_input"] == "option2" assert form_data["radio_input"] == "option2"

View File

@ -522,6 +522,10 @@ def VarOperations():
rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"), rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"),
rx.text(VarOperationState.list3.join(""), id="list_join"), rx.text(VarOperationState.list3.join(""), id="list_join"),
rx.text(VarOperationState.list3.join(","), id="list_join_comma"), rx.text(VarOperationState.list3.join(","), id="list_join_comma"),
rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"),
rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"),
rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"),
rx.text(rx.Var.range(0, 3).join(","), id="list_join_range4"),
) )
app.compile() app.compile()
@ -688,6 +692,10 @@ def test_var_operations(driver, var_operations: AppHarness):
("list_reverse", "[2,1]"), ("list_reverse", "[2,1]"),
("list_join", "firstsecondthird"), ("list_join", "firstsecondthird"),
("list_join_comma", "first,second,third"), ("list_join_comma", "first,second,third"),
("list_join_range1", "2,3,4"),
("list_join_range2", "2,4,6,8"),
("list_join_range3", "5,4,3,2,1"),
("list_join_range4", "0,1,2"),
# list, int # list, int
("list_mult_int", "[1,2,1,2,1,2,1,2,1,2]"), ("list_mult_int", "[1,2,1,2,1,2,1,2,1,2]"),
("list_or_int", "[1,2]"), ("list_or_int", "[1,2]"),

View File

@ -0,0 +1,43 @@
/**
* Simulate the python range() builtin function.
* inspired by https://dev.to/guyariely/using-python-range-in-javascript-337p
*
* If needed outside of an iterator context, use `Array.from(range(10))` or
* spread syntax `[...range(10)]` to get an array.
*
* @param {number} start: the start or end of the range.
* @param {number} stop: the end of the range.
* @param {number} step: the step of the range.
* @returns {object} an object with a Symbol.iterator method over the range
*/
export default function range(start, stop, step) {
return {
[Symbol.iterator]() {
if (stop === undefined) {
stop = start;
start = 0;
}
if (step === undefined) {
step = 1;
}
let i = start - step;
return {
next() {
i += step;
if ((step > 0 && i < stop) || (step < 0 && i > stop)) {
return {
value: i,
done: false,
};
}
return {
value: undefined,
done: true,
};
},
};
},
};
}

View File

@ -573,7 +573,7 @@ export const getRefValues = (refs) => {
return; return;
} }
// getAttribute is used by RangeSlider because it doesn't assign value // getAttribute is used by RangeSlider because it doesn't assign value
return refs.map((ref) => ref.current.value || ref.current.getAttribute("aria-valuenow")); return refs.map((ref) => ref.current ? ref.current.value || ref.current.getAttribute("aria-valuenow") : null);
} }
/** /**

View File

@ -41,6 +41,9 @@ DEFAULT_IMPORTS: imports.ImportDict = {
ImportVar(tag="StateContext"), ImportVar(tag="StateContext"),
ImportVar(tag="ColorModeContext"), ImportVar(tag="ColorModeContext"),
}, },
"/utils/helpers/range.js": {
ImportVar(tag="range", is_default=True),
},
"": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)}, "": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)},
} }

View File

@ -56,6 +56,9 @@ class Button(ChakraComponent):
# Components that are not allowed as children. # Components that are not allowed as children.
invalid_children: List[str] = ["Button", "MenuButton"] invalid_children: List[str] = ["Button", "MenuButton"]
# The name of the form field
name: Var[str]
class ButtonGroup(ChakraComponent): class ButtonGroup(ChakraComponent):
"""A group of buttons.""" """A group of buttons."""

View File

@ -50,6 +50,9 @@ class Checkbox(ChakraComponent):
# The name of the input field in a checkbox (Useful for form submission). # The name of the input field in a checkbox (Useful for form submission).
name: Var[str] name: Var[str]
# The value of the input field when checked (use is_checked prop for a bool)
value: Var[str] = Var.create(True) # type: ignore
# The spacing between the checkbox and its label text (0.5rem) # The spacing between the checkbox and its label text (0.5rem)
spacing: Var[str] spacing: Var[str]

View File

@ -1,13 +1,36 @@
"""Form components.""" """Form components."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Dict, List from typing import Any, Dict
from jinja2 import Environment
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.libs.chakra import ChakraComponent from reflex.components.libs.chakra import ChakraComponent
from reflex.components.tags import Tag
from reflex.constants import EventTriggers from reflex.constants import EventTriggers
from reflex.event import EventChain, EventHandler, EventSpec from reflex.event import EventChain
from reflex.vars import Var from reflex.utils import imports
from reflex.utils.format import format_event_chain
from reflex.utils.serializers import serialize
from reflex.vars import BaseVar, Var, get_unique_variable_name
FORM_DATA = Var.create("form_data")
HANDLE_SUBMIT_JS_JINJA2 = Environment().from_string(
"""
const handleSubmit{{ handle_submit_unique_name }} = useCallback((ev) => {
const $form = ev.target
ev.preventDefault()
const {{ form_data }} = {...Object.fromEntries(new FormData($form).entries()), ...{{ field_ref_mapping }}}
{{ on_submit_event_chain }}
if ({{ reset_on_submit }}) {
$form.reset()
}
})
"""
)
class Form(ChakraComponent): class Form(ChakraComponent):
@ -18,35 +41,51 @@ class Form(ChakraComponent):
# What the form renders to. # What the form renders to.
as_: Var[str] = "form" # type: ignore as_: Var[str] = "form" # type: ignore
def _create_event_chain( # If true, the form will be cleared after submit.
self, reset_on_submit: Var[bool] = False # type: ignore
event_trigger: str,
value: Var
| EventHandler
| EventSpec
| List[EventHandler | EventSpec]
| Callable[..., Any],
) -> EventChain | Var:
"""Override the event chain creation to preventDefault for on_submit.
Args: # The name used to make this form's submit handler function unique
event_trigger: The event trigger. handle_submit_unique_name: Var[str] = get_unique_variable_name() # type: ignore
value: The value of the event trigger.
def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports(
super()._get_imports(),
{"react": {imports.ImportVar(tag="useCallback")}},
)
def _get_hooks(self) -> str | None:
if EventTriggers.ON_SUBMIT not in self.event_triggers:
return
return HANDLE_SUBMIT_JS_JINJA2.render(
handle_submit_unique_name=self.handle_submit_unique_name,
form_data=FORM_DATA,
field_ref_mapping=serialize(self._get_form_refs()),
on_submit_event_chain=format_event_chain(
self.event_triggers[EventTriggers.ON_SUBMIT]
),
reset_on_submit=self.reset_on_submit,
)
def _render(self) -> Tag:
return (
super()
._render()
.remove_props("reset_on_submit", "handle_submit_unique_name")
)
def render(self) -> dict:
"""Render the component.
Returns: Returns:
The event chain. The rendered component.
""" """
chain = super()._create_event_chain(event_trigger, value) self.event_triggers[EventTriggers.ON_SUBMIT] = BaseVar(
if event_trigger == EventTriggers.ON_SUBMIT and isinstance(chain, EventChain): _var_name=f"handleSubmit{self.handle_submit_unique_name}",
return chain.prevent_default _var_type=EventChain,
return chain )
return super().render()
def get_event_triggers(self) -> Dict[str, Any]: def _get_form_refs(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.
Returns:
A dict mapping the event trigger to the var that is passed to the handler.
"""
# Send all the input refs to the handler. # Send all the input refs to the handler.
form_refs = {} form_refs = {}
for ref in self.get_refs(): for ref in self.get_refs():
@ -60,10 +99,17 @@ class Form(ChakraComponent):
form_refs[ref[4:]] = Var.create( form_refs[ref[4:]] = Var.create(
f"getRefValue({ref})", _var_is_local=False f"getRefValue({ref})", _var_is_local=False
) )
return form_refs
def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler.
Returns:
A dict mapping the event trigger to the var that is passed to the handler.
"""
return { return {
**super().get_event_triggers(), **super().get_event_triggers(),
EventTriggers.ON_SUBMIT: lambda e0: [form_refs], EventTriggers.ON_SUBMIT: lambda e0: [FORM_DATA],
} }

View File

@ -55,6 +55,9 @@ class Input(ChakraComponent):
# "lg" | "md" | "sm" | "xs" # "lg" | "md" | "sm" | "xs"
size: Var[LiteralButtonSize] size: Var[LiteralButtonSize]
# The name of the form field
name: Var[str]
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
return imports.merge_imports( return imports.merge_imports(
super()._get_imports(), super()._get_imports(),

View File

@ -68,6 +68,9 @@ class NumberInput(ChakraComponent):
# "outline" | "filled" | "flushed" | "unstyled" # "outline" | "filled" | "flushed" | "unstyled"
variant: Var[LiteralInputVariant] variant: Var[LiteralInputVariant]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> Dict[str, Any]: def get_event_triggers(self) -> Dict[str, Any]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -4,10 +4,11 @@ from __future__ import annotations
from typing import Any, Optional, Union from typing import Any, Optional, Union
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.layout import Foreach
from reflex.components.libs.chakra import ChakraComponent, LiteralInputVariant from reflex.components.libs.chakra import ChakraComponent, LiteralInputVariant
from reflex.components.tags.tag import Tag
from reflex.constants import EventTriggers from reflex.constants import EventTriggers
from reflex.utils import format from reflex.utils import format
from reflex.utils.imports import ImportDict, merge_imports
from reflex.vars import Var from reflex.vars import Var
@ -58,6 +59,20 @@ class PinInput(ChakraComponent):
# "outline" | "flushed" | "filled" | "unstyled" # "outline" | "flushed" | "filled" | "unstyled"
variant: Var[LiteralInputVariant] variant: Var[LiteralInputVariant]
# The name of the form field
name: Var[str]
def _get_imports(self) -> ImportDict:
"""Include PinInputField explicitly because it may not be a child component at compile time.
Returns:
The merged import dict.
"""
return merge_imports(
super()._get_imports(),
PinInputField().get_imports(), # type: ignore
)
def get_event_triggers(self) -> dict[str, Union[Var, Any]]: def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.
@ -70,13 +85,24 @@ class PinInput(ChakraComponent):
EventTriggers.ON_COMPLETE: lambda e0: [e0], EventTriggers.ON_COMPLETE: lambda e0: [e0],
} }
def get_ref(self): def get_ref(self) -> str | None:
"""Return a reference because we actually attached the ref to the PinInputFields. """Override ref handling to handle array refs.
PinInputFields may be created dynamically, so it's not possible
to compute their ref at compile time, so we return a cheating
guess if the id is specified.
The `ref` for this outer component will always be stripped off, so what
is returned here only matters for form ref collection purposes.
Returns: Returns:
None. None.
""" """
return None if any(isinstance(c, PinInputField) for c in self.children):
return None
if self.id:
return format.format_array_ref(self.id, idx=self.length)
return super().get_ref()
def _get_ref_hook(self) -> Optional[str]: def _get_ref_hook(self) -> Optional[str]:
"""Override the base _get_ref_hook to handle array refs. """Override the base _get_ref_hook to handle array refs.
@ -86,10 +112,22 @@ class PinInput(ChakraComponent):
""" """
if self.id: if self.id:
ref = format.format_array_ref(self.id, None) ref = format.format_array_ref(self.id, None)
refs_declaration = Var.range(self.length).foreach(
lambda: Var.create_safe("useRef(null)", _var_is_string=False),
)
refs_declaration._var_is_local = True
if ref: if ref:
return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));" return f"const {ref} = {refs_declaration}"
return super()._get_ref_hook() return super()._get_ref_hook()
def _render(self) -> Tag:
"""Override the base _render to remove the fake get_ref.
Returns:
The rendered component.
"""
return super()._render().remove_props("ref")
@classmethod @classmethod
def create(cls, *children, **props) -> Component: def create(cls, *children, **props) -> Component:
"""Create a pin input component. """Create a pin input component.
@ -104,22 +142,17 @@ class PinInput(ChakraComponent):
Returns: Returns:
The pin input component. The pin input component.
""" """
if not children and "length" in props: if children:
_id = props.get("id", None) props.pop("length", None)
length = props["length"] elif "length" in props:
if _id: field_props = {}
children = [ if "id" in props:
Foreach.create( field_props["id"] = props["id"]
list(range(length)), # type: ignore if "name" in props:
lambda ref, i: PinInputField.create( field_props["name"] = props["name"]
key=i, children = [
id=_id, PinInputField.for_length(props["length"], **field_props),
index=i, ]
),
)
]
else:
children = [PinInputField()] * length
return super().create(*children, **props) return super().create(*children, **props)
@ -132,6 +165,29 @@ class PinInputField(ChakraComponent):
# Default to None because it is assigned by PinInput when created. # Default to None because it is assigned by PinInput when created.
index: Optional[Var[int]] = None index: Optional[Var[int]] = None
# The name of the form field
name: Var[str]
@classmethod
def for_length(cls, length: Var | int, **props) -> Var:
"""Create a PinInputField for a PinInput with a given length.
Args:
length: The length of the PinInput.
props: The props of each PinInputField (name will become indexed).
Returns:
The PinInputField.
"""
name = props.get("name")
def _create(i):
if name is not None:
props["name"] = f"{name}-{i}"
return PinInputField.create(**props, index=i, key=i)
return Var.range(length).foreach(_create) # type: ignore
def _get_ref_hook(self) -> Optional[str]: def _get_ref_hook(self) -> Optional[str]:
return None return None

View File

@ -23,6 +23,9 @@ class RadioGroup(ChakraComponent):
# The default value. # The default value.
default_value: Var[Any] default_value: Var[Any]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> Dict[str, Union[Var, Any]]: def get_event_triggers(self) -> Dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -45,6 +45,9 @@ class RangeSlider(ChakraComponent):
# The minimum distance between slider thumbs. Useful for preventing the thumbs from being too close together. # The minimum distance between slider thumbs. Useful for preventing the thumbs from being too close together.
min_steps_between_thumbs: Var[int] min_steps_between_thumbs: Var[int]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> dict[str, Union[Var, Any]]: def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -46,6 +46,9 @@ class Select(ChakraComponent):
# The size of the select. # The size of the select.
size: Var[str] size: Var[str]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> Dict[str, Union[Var, Any]]: def get_event_triggers(self) -> Dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -66,6 +66,9 @@ class Slider(ChakraComponent):
# Maximum width of the slider. # Maximum width of the slider.
max_w: Var[str] max_w: Var[str]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> dict[str, Union[Var, Any]]: def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -34,6 +34,9 @@ class Switch(ChakraComponent):
# The name of the input field in a switch (Useful for form submission). # The name of the input field in a switch (Useful for form submission).
name: Var[str] name: Var[str]
# The value of the input field when checked (use is_checked prop for a bool)
value: Var[str] = Var.create(True) # type: ignore
# The spacing between the switch and its label text (0.5rem) # The spacing between the switch and its label text (0.5rem)
spacing: Var[str] spacing: Var[str]

View File

@ -45,6 +45,9 @@ class TextArea(ChakraComponent):
# "outline" | "filled" | "flushed" | "unstyled" # "outline" | "filled" | "flushed" | "unstyled"
variant: Var[LiteralInputVariant] variant: Var[LiteralInputVariant]
# The name of the form field
name: Var[str]
def get_event_triggers(self) -> dict[str, Union[Var, Any]]: def get_event_triggers(self) -> dict[str, Union[Var, Any]]:
"""Get the event triggers that pass the component's value to the handler. """Get the event triggers that pass the component's value to the handler.

View File

@ -83,8 +83,9 @@ class Tag(Base):
The tag with the props removed. The tag with the props removed.
""" """
for name in args: for name in args:
if name in self.props: prop_name = format.to_camel_case(name)
del self.props[name] if prop_name in self.props:
del self.props[prop_name]
return self return self
@staticmethod @staticmethod

View File

@ -43,6 +43,7 @@ from reflex.event import (
) )
from reflex.utils import console, format, prerequisites, types from reflex.utils import console, format, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.vars import BaseVar, ComputedVar, Var from reflex.vars import BaseVar, ComputedVar, Var
Delta = Dict[str, Any] Delta = Dict[str, Any]
@ -2016,6 +2017,25 @@ class MutableProxy(wrapt.ObjectProxy):
return copy.deepcopy(self.__wrapped__, memo=memo) return copy.deepcopy(self.__wrapped__, memo=memo)
@serializer
def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType:
"""Serialize the wrapped value of a MutableProxy.
Args:
mp: The MutableProxy to serialize.
Returns:
The serialized wrapped object.
Raises:
ValueError: when the wrapped object is not serializable.
"""
value = serialize(mp.__wrapped__)
if value is None:
raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}")
return value
class ImmutableMutableProxy(MutableProxy): class ImmutableMutableProxy(MutableProxy):
"""A proxy for a mutable object that tracks changes. """A proxy for a mutable object that tracks changes.

View File

@ -627,6 +627,28 @@ def unwrap_vars(value: str) -> str:
) )
def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]:
"""Collapse keys with consecutive suffixes into a single list value.
Separators dash and underscore are removed, unless this would overwrite an existing key.
Args:
form_dict: The dict to collapse.
Returns:
The collapsed dict.
"""
ending_digit_regex = re.compile(r"^(.*?)[_-]?(\d+)$")
collapsed = {}
for k in sorted(form_dict):
m = ending_digit_regex.match(k)
if m:
collapsed.setdefault(m.group(1), []).append(form_dict[k])
# collapsing never overwrites valid data from the form_dict
collapsed.update(form_dict)
return collapsed
def format_data_editor_column(col: str | dict): def format_data_editor_column(col: str | dict):
"""Format a given column into the proper format. """Format a given column into the proper format.

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import contextlib import contextlib
import dataclasses import dataclasses
import dis import dis
import inspect
import json import json
import random import random
import string import string
@ -1138,17 +1139,76 @@ class Var:
Returns: Returns:
A var representing foreach operation. A var representing foreach operation.
Raises:
TypeError: If the var is not a list.
""" """
inner_types = get_args(self._var_type)
if not inner_types:
raise TypeError(
f"Cannot foreach over non-sequence var {self._var_full_name} of type {self._var_type}."
)
arg = BaseVar( arg = BaseVar(
_var_name=get_unique_variable_name(), _var_name=get_unique_variable_name(),
_var_type=self._var_type, _var_type=inner_types[0],
) )
index = BaseVar(
_var_name=get_unique_variable_name(),
_var_type=int,
)
fn_signature = inspect.signature(fn)
fn_args = (arg, index)
fn_ret = fn(*fn_args[: len(fn_signature.parameters)])
return BaseVar( return BaseVar(
_var_name=f"{self._var_full_name}.map(({arg._var_name}, i) => {fn(arg, key='i')})", _var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})",
_var_type=self._var_type, _var_type=self._var_type,
_var_is_local=self._var_is_local, _var_is_local=self._var_is_local,
) )
@classmethod
def range(
cls,
v1: Var | int = 0,
v2: Var | int | None = None,
step: Var | int | None = None,
) -> Var:
"""Return an iterator over indices from v1 to v2 (or 0 to v1).
Args:
v1: The start of the range or end of range if v2 is not given.
v2: The end of the range.
step: The number of numbers between each item.
Returns:
A var representing range operation.
Raises:
TypeError: If the var is not an int.
"""
if not isinstance(v1, Var):
v1 = Var.create_safe(v1)
if v1._var_type != int:
raise TypeError(f"Cannot get range on non-int var {v1._var_full_name}.")
if not isinstance(v2, Var):
v2 = Var.create(v2)
if v2 is None:
v2 = Var.create_safe("undefined")
elif v2._var_type != int:
raise TypeError(f"Cannot get range on non-int var {v2._var_full_name}.")
if not isinstance(step, Var):
step = Var.create(step)
if step is None:
step = Var.create_safe(1)
elif step._var_type != int:
raise TypeError(f"Cannot get range on non-int var {step._var_full_name}.")
return BaseVar(
_var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))",
_var_type=list[int],
_var_is_local=False,
)
def to(self, type_: Type) -> Var: def to(self, type_: Type) -> Var:
"""Convert the type of the var. """Convert the type of the var.

View File

@ -85,6 +85,13 @@ class Var:
def contains(self, other: Any) -> Var: ... def contains(self, other: Any) -> Var: ...
def reverse(self) -> Var: ... def reverse(self) -> Var: ...
def foreach(self, fn: Callable) -> Var: ... def foreach(self, fn: Callable) -> Var: ...
@classmethod
def range(
cls,
v1: Var | int = 0,
v2: Var | int | None = None,
step: Var | int | None = None,
) -> Var: ...
def to(self, type_: Type) -> Var: ... def to(self, type_: Type) -> Var: ...
@property @property
def _var_full_name(self) -> str: ... def _var_full_name(self) -> str: ...