diff --git a/integration/test_form_submit.py b/integration/test_form_submit.py index 25296b3d4..4a9f6c2d1 100644 --- a/integration/test_form_submit.py +++ b/integration/test_form_submit.py @@ -7,6 +7,7 @@ from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys from reflex.testing import AppHarness +from reflex.utils import format def FormSubmit(): @@ -38,6 +39,8 @@ def FormSubmit(): rx.number_input(id="number_input"), rx.checkbox(id="bool_input"), rx.switch(id="bool_input2"), + rx.checkbox(id="bool_input3"), + rx.switch(id="bool_input4"), rx.slider(id="slider_input"), rx.range_slider(id="range_input"), rx.radio_group(["option1", "option2"], id="radio_input"), @@ -62,11 +65,65 @@ def FormSubmit(): app.compile() -@pytest.fixture(scope="session") -def form_submit(tmp_path_factory) -> Generator[AppHarness, None, None]: +def FormSubmitName(): + """App with a form using on_submit.""" + import reflex as rx + + class FormState(rx.State): + form_data: dict = {} + + def form_submit(self, form_data: dict): + self.form_data = form_data + + app = rx.App(state=FormState) + + @app.add_page + def index(): + return rx.vstack( + rx.input( + value=FormState.router.session.client_token, + is_read_only=True, + id="token", + ), + rx.form( + rx.vstack( + rx.input(name="name_input"), + rx.hstack(rx.pin_input(length=4, name="pin_input")), + rx.number_input(name="number_input"), + rx.checkbox(name="bool_input"), + rx.switch(name="bool_input2"), + rx.checkbox(name="bool_input3"), + rx.switch(name="bool_input4"), + rx.slider(name="slider_input"), + rx.range_slider(name="range_input"), + rx.radio_group(["option1", "option2"], name="radio_input"), + rx.select(["option1", "option2"], name="select_input"), + rx.text_area(name="text_area_input"), + rx.input( + name="debounce_input", + debounce_timeout=0, + on_change=rx.console_log, + ), + rx.button("Submit", type_="submit"), + ), + on_submit=FormState.form_submit, + custom_attrs={"action": "/invalid"}, + ), + rx.spacer(), + height="100vh", + ) + + app.compile() + + +@pytest.fixture( + scope="session", params=[FormSubmit, FormSubmitName], ids=["id", "name"] +) +def form_submit(request, tmp_path_factory) -> Generator[AppHarness, None, None]: """Start FormSubmit app at tmp_path via AppHarness. Args: + request: pytest request fixture tmp_path_factory: pytest tmp_path_factory fixture Yields: @@ -74,7 +131,7 @@ def form_submit(tmp_path_factory) -> Generator[AppHarness, None, None]: """ with AppHarness.create( root=tmp_path_factory.mktemp("form_submit"), - app_source=FormSubmit, # type: ignore + app_source=request.param, # type: ignore ) as harness: assert harness.app_instance is not None, "app is not running" yield harness @@ -107,6 +164,7 @@ async def test_submit(driver, form_submit: AppHarness): form_submit: harness for FormSubmit app """ assert form_submit.app_instance is not None, "app is not running" + by = By.ID if form_submit.app_source is FormSubmit else By.NAME # get a reference to the connected client token_input = driver.find_element(By.ID, "token") @@ -116,7 +174,7 @@ async def test_submit(driver, form_submit: AppHarness): token = form_submit.poll_for_value(token_input) assert token - name_input = driver.find_element(By.ID, "name_input") + name_input = driver.find_element(by, "name_input") name_input.send_keys("foo") pin_inputs = driver.find_elements(By.CLASS_NAME, "chakra-pin-input") @@ -141,7 +199,7 @@ async def test_submit(driver, form_submit: AppHarness): textarea_input = driver.find_element(By.CLASS_NAME, "chakra-textarea") textarea_input.send_keys("Some", Keys.ENTER, "Text") - debounce_input = driver.find_element(By.ID, "debounce_input") + debounce_input = driver.find_element(by, "debounce_input") debounce_input.send_keys("bar baz") time.sleep(1) @@ -158,11 +216,16 @@ async def test_submit(driver, form_submit: AppHarness): form_data = await AppHarness._poll_for_async(get_form_data) assert isinstance(form_data, dict) + form_data = format.collect_form_dict_names(form_data) + assert form_data["name_input"] == "foo" assert form_data["pin_input"] == pin_values assert form_data["number_input"] == "-3" - assert form_data["bool_input"] is True - assert form_data["bool_input2"] is True + assert form_data["bool_input"] + assert form_data["bool_input2"] + assert not form_data.get("bool_input3", False) + assert not form_data.get("bool_input4", False) + assert form_data["slider_input"] == "50" assert form_data["range_input"] == ["25", "75"] assert form_data["radio_input"] == "option2" diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index 451452731..6eac82c4c 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -522,6 +522,10 @@ def VarOperations(): rx.text(VarOperationState.str_var4.split(" ").to_string(), id="str_split"), rx.text(VarOperationState.list3.join(""), id="list_join"), rx.text(VarOperationState.list3.join(","), id="list_join_comma"), + rx.text(rx.Var.range(2, 5).join(","), id="list_join_range1"), + rx.text(rx.Var.range(2, 10, 2).join(","), id="list_join_range2"), + rx.text(rx.Var.range(5, 0, -1).join(","), id="list_join_range3"), + rx.text(rx.Var.range(0, 3).join(","), id="list_join_range4"), ) app.compile() @@ -688,6 +692,10 @@ def test_var_operations(driver, var_operations: AppHarness): ("list_reverse", "[2,1]"), ("list_join", "firstsecondthird"), ("list_join_comma", "first,second,third"), + ("list_join_range1", "2,3,4"), + ("list_join_range2", "2,4,6,8"), + ("list_join_range3", "5,4,3,2,1"), + ("list_join_range4", "0,1,2"), # list, int ("list_mult_int", "[1,2,1,2,1,2,1,2,1,2]"), ("list_or_int", "[1,2]"), diff --git a/reflex/.templates/web/utils/helpers/range.js b/reflex/.templates/web/utils/helpers/range.js new file mode 100644 index 000000000..7d1aedaaf --- /dev/null +++ b/reflex/.templates/web/utils/helpers/range.js @@ -0,0 +1,43 @@ +/** + * Simulate the python range() builtin function. + * inspired by https://dev.to/guyariely/using-python-range-in-javascript-337p + * + * If needed outside of an iterator context, use `Array.from(range(10))` or + * spread syntax `[...range(10)]` to get an array. + * + * @param {number} start: the start or end of the range. + * @param {number} stop: the end of the range. + * @param {number} step: the step of the range. + * @returns {object} an object with a Symbol.iterator method over the range + */ +export default function range(start, stop, step) { + return { + [Symbol.iterator]() { + if (stop === undefined) { + stop = start; + start = 0; + } + if (step === undefined) { + step = 1; + } + + let i = start - step; + + return { + next() { + i += step; + if ((step > 0 && i < stop) || (step < 0 && i > stop)) { + return { + value: i, + done: false, + }; + } + return { + value: undefined, + done: true, + }; + }, + }; + }, + }; + } \ No newline at end of file diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 6ffe72ded..0243ce31a 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -573,7 +573,7 @@ export const getRefValues = (refs) => { return; } // getAttribute is used by RangeSlider because it doesn't assign value - return refs.map((ref) => ref.current.value || ref.current.getAttribute("aria-valuenow")); + return refs.map((ref) => ref.current ? ref.current.value || ref.current.getAttribute("aria-valuenow") : null); } /** diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index 7ee5c6ded..aa2e5f2b9 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -41,6 +41,9 @@ DEFAULT_IMPORTS: imports.ImportDict = { ImportVar(tag="StateContext"), ImportVar(tag="ColorModeContext"), }, + "/utils/helpers/range.js": { + ImportVar(tag="range", is_default=True), + }, "": {ImportVar(tag="focus-visible/dist/focus-visible", install=False)}, } diff --git a/reflex/components/forms/button.py b/reflex/components/forms/button.py index 517d76c14..f8990f0a9 100644 --- a/reflex/components/forms/button.py +++ b/reflex/components/forms/button.py @@ -56,6 +56,9 @@ class Button(ChakraComponent): # Components that are not allowed as children. invalid_children: List[str] = ["Button", "MenuButton"] + # The name of the form field + name: Var[str] + class ButtonGroup(ChakraComponent): """A group of buttons.""" diff --git a/reflex/components/forms/checkbox.py b/reflex/components/forms/checkbox.py index 9444ed5ff..51bf37a49 100644 --- a/reflex/components/forms/checkbox.py +++ b/reflex/components/forms/checkbox.py @@ -50,6 +50,9 @@ class Checkbox(ChakraComponent): # The name of the input field in a checkbox (Useful for form submission). name: Var[str] + # The value of the input field when checked (use is_checked prop for a bool) + value: Var[str] = Var.create(True) # type: ignore + # The spacing between the checkbox and its label text (0.5rem) spacing: Var[str] diff --git a/reflex/components/forms/form.py b/reflex/components/forms/form.py index f17642ece..d0de50858 100644 --- a/reflex/components/forms/form.py +++ b/reflex/components/forms/form.py @@ -1,13 +1,36 @@ """Form components.""" from __future__ import annotations -from typing import Any, Callable, Dict, List +from typing import Any, Dict + +from jinja2 import Environment from reflex.components.component import Component from reflex.components.libs.chakra import ChakraComponent +from reflex.components.tags import Tag from reflex.constants import EventTriggers -from reflex.event import EventChain, EventHandler, EventSpec -from reflex.vars import Var +from reflex.event import EventChain +from reflex.utils import imports +from reflex.utils.format import format_event_chain +from reflex.utils.serializers import serialize +from reflex.vars import BaseVar, Var, get_unique_variable_name + +FORM_DATA = Var.create("form_data") +HANDLE_SUBMIT_JS_JINJA2 = Environment().from_string( + """ + const handleSubmit{{ handle_submit_unique_name }} = useCallback((ev) => { + const $form = ev.target + ev.preventDefault() + const {{ form_data }} = {...Object.fromEntries(new FormData($form).entries()), ...{{ field_ref_mapping }}} + + {{ on_submit_event_chain }} + + if ({{ reset_on_submit }}) { + $form.reset() + } + }) + """ +) class Form(ChakraComponent): @@ -18,35 +41,51 @@ class Form(ChakraComponent): # What the form renders to. as_: Var[str] = "form" # type: ignore - def _create_event_chain( - self, - event_trigger: str, - value: Var - | EventHandler - | EventSpec - | List[EventHandler | EventSpec] - | Callable[..., Any], - ) -> EventChain | Var: - """Override the event chain creation to preventDefault for on_submit. + # If true, the form will be cleared after submit. + reset_on_submit: Var[bool] = False # type: ignore - Args: - event_trigger: The event trigger. - value: The value of the event trigger. + # The name used to make this form's submit handler function unique + handle_submit_unique_name: Var[str] = get_unique_variable_name() # type: ignore + + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), + {"react": {imports.ImportVar(tag="useCallback")}}, + ) + + def _get_hooks(self) -> str | None: + if EventTriggers.ON_SUBMIT not in self.event_triggers: + return + return HANDLE_SUBMIT_JS_JINJA2.render( + handle_submit_unique_name=self.handle_submit_unique_name, + form_data=FORM_DATA, + field_ref_mapping=serialize(self._get_form_refs()), + on_submit_event_chain=format_event_chain( + self.event_triggers[EventTriggers.ON_SUBMIT] + ), + reset_on_submit=self.reset_on_submit, + ) + + def _render(self) -> Tag: + return ( + super() + ._render() + .remove_props("reset_on_submit", "handle_submit_unique_name") + ) + + def render(self) -> dict: + """Render the component. Returns: - The event chain. + The rendered component. """ - chain = super()._create_event_chain(event_trigger, value) - if event_trigger == EventTriggers.ON_SUBMIT and isinstance(chain, EventChain): - return chain.prevent_default - return chain + self.event_triggers[EventTriggers.ON_SUBMIT] = BaseVar( + _var_name=f"handleSubmit{self.handle_submit_unique_name}", + _var_type=EventChain, + ) + return super().render() - def get_event_triggers(self) -> Dict[str, Any]: - """Get the event triggers that pass the component's value to the handler. - - Returns: - A dict mapping the event trigger to the var that is passed to the handler. - """ + def _get_form_refs(self) -> Dict[str, Any]: # Send all the input refs to the handler. form_refs = {} for ref in self.get_refs(): @@ -60,10 +99,17 @@ class Form(ChakraComponent): form_refs[ref[4:]] = Var.create( f"getRefValue({ref})", _var_is_local=False ) + return form_refs + def get_event_triggers(self) -> Dict[str, Any]: + """Get the event triggers that pass the component's value to the handler. + + Returns: + A dict mapping the event trigger to the var that is passed to the handler. + """ return { **super().get_event_triggers(), - EventTriggers.ON_SUBMIT: lambda e0: [form_refs], + EventTriggers.ON_SUBMIT: lambda e0: [FORM_DATA], } diff --git a/reflex/components/forms/input.py b/reflex/components/forms/input.py index 75bbac4cd..b4af66b02 100644 --- a/reflex/components/forms/input.py +++ b/reflex/components/forms/input.py @@ -55,6 +55,9 @@ class Input(ChakraComponent): # "lg" | "md" | "sm" | "xs" size: Var[LiteralButtonSize] + # The name of the form field + name: Var[str] + def _get_imports(self) -> imports.ImportDict: return imports.merge_imports( super()._get_imports(), diff --git a/reflex/components/forms/numberinput.py b/reflex/components/forms/numberinput.py index 1f93c6e8a..2fe90ad96 100644 --- a/reflex/components/forms/numberinput.py +++ b/reflex/components/forms/numberinput.py @@ -68,6 +68,9 @@ class NumberInput(ChakraComponent): # "outline" | "filled" | "flushed" | "unstyled" variant: Var[LiteralInputVariant] + # The name of the form field + name: Var[str] + def get_event_triggers(self) -> Dict[str, Any]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/forms/pininput.py b/reflex/components/forms/pininput.py index 4d81df7f1..c090e1571 100644 --- a/reflex/components/forms/pininput.py +++ b/reflex/components/forms/pininput.py @@ -4,10 +4,11 @@ from __future__ import annotations from typing import Any, Optional, Union from reflex.components.component import Component -from reflex.components.layout import Foreach from reflex.components.libs.chakra import ChakraComponent, LiteralInputVariant +from reflex.components.tags.tag import Tag from reflex.constants import EventTriggers from reflex.utils import format +from reflex.utils.imports import ImportDict, merge_imports from reflex.vars import Var @@ -58,6 +59,20 @@ class PinInput(ChakraComponent): # "outline" | "flushed" | "filled" | "unstyled" variant: Var[LiteralInputVariant] + # The name of the form field + name: Var[str] + + def _get_imports(self) -> ImportDict: + """Include PinInputField explicitly because it may not be a child component at compile time. + + Returns: + The merged import dict. + """ + return merge_imports( + super()._get_imports(), + PinInputField().get_imports(), # type: ignore + ) + def get_event_triggers(self) -> dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. @@ -70,13 +85,24 @@ class PinInput(ChakraComponent): EventTriggers.ON_COMPLETE: lambda e0: [e0], } - def get_ref(self): - """Return a reference because we actually attached the ref to the PinInputFields. + def get_ref(self) -> str | None: + """Override ref handling to handle array refs. + + PinInputFields may be created dynamically, so it's not possible + to compute their ref at compile time, so we return a cheating + guess if the id is specified. + + The `ref` for this outer component will always be stripped off, so what + is returned here only matters for form ref collection purposes. Returns: None. """ - return None + if any(isinstance(c, PinInputField) for c in self.children): + return None + if self.id: + return format.format_array_ref(self.id, idx=self.length) + return super().get_ref() def _get_ref_hook(self) -> Optional[str]: """Override the base _get_ref_hook to handle array refs. @@ -86,10 +112,22 @@ class PinInput(ChakraComponent): """ if self.id: ref = format.format_array_ref(self.id, None) + refs_declaration = Var.range(self.length).foreach( + lambda: Var.create_safe("useRef(null)", _var_is_string=False), + ) + refs_declaration._var_is_local = True if ref: - return f"const {ref} = Array.from({{length:{self.length}}}, () => useRef(null));" + return f"const {ref} = {refs_declaration}" return super()._get_ref_hook() + def _render(self) -> Tag: + """Override the base _render to remove the fake get_ref. + + Returns: + The rendered component. + """ + return super()._render().remove_props("ref") + @classmethod def create(cls, *children, **props) -> Component: """Create a pin input component. @@ -104,22 +142,17 @@ class PinInput(ChakraComponent): Returns: The pin input component. """ - if not children and "length" in props: - _id = props.get("id", None) - length = props["length"] - if _id: - children = [ - Foreach.create( - list(range(length)), # type: ignore - lambda ref, i: PinInputField.create( - key=i, - id=_id, - index=i, - ), - ) - ] - else: - children = [PinInputField()] * length + if children: + props.pop("length", None) + elif "length" in props: + field_props = {} + if "id" in props: + field_props["id"] = props["id"] + if "name" in props: + field_props["name"] = props["name"] + children = [ + PinInputField.for_length(props["length"], **field_props), + ] return super().create(*children, **props) @@ -132,6 +165,29 @@ class PinInputField(ChakraComponent): # Default to None because it is assigned by PinInput when created. index: Optional[Var[int]] = None + # The name of the form field + name: Var[str] + + @classmethod + def for_length(cls, length: Var | int, **props) -> Var: + """Create a PinInputField for a PinInput with a given length. + + Args: + length: The length of the PinInput. + props: The props of each PinInputField (name will become indexed). + + Returns: + The PinInputField. + """ + name = props.get("name") + + def _create(i): + if name is not None: + props["name"] = f"{name}-{i}" + return PinInputField.create(**props, index=i, key=i) + + return Var.range(length).foreach(_create) # type: ignore + def _get_ref_hook(self) -> Optional[str]: return None diff --git a/reflex/components/forms/radio.py b/reflex/components/forms/radio.py index 81fdcb262..ecf5afa52 100644 --- a/reflex/components/forms/radio.py +++ b/reflex/components/forms/radio.py @@ -23,6 +23,9 @@ class RadioGroup(ChakraComponent): # The default value. default_value: Var[Any] + # The name of the form field + name: Var[str] + def get_event_triggers(self) -> Dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/forms/rangeslider.py b/reflex/components/forms/rangeslider.py index 9879d0c69..6eb0b688c 100644 --- a/reflex/components/forms/rangeslider.py +++ b/reflex/components/forms/rangeslider.py @@ -45,6 +45,9 @@ class RangeSlider(ChakraComponent): # The minimum distance between slider thumbs. Useful for preventing the thumbs from being too close together. min_steps_between_thumbs: Var[int] + # The name of the form field + name: Var[str] + def get_event_triggers(self) -> dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/forms/select.py b/reflex/components/forms/select.py index ffca7758a..61c3dd28a 100644 --- a/reflex/components/forms/select.py +++ b/reflex/components/forms/select.py @@ -46,6 +46,9 @@ class Select(ChakraComponent): # The size of the select. size: Var[str] + # The name of the form field + name: Var[str] + def get_event_triggers(self) -> Dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/forms/slider.py b/reflex/components/forms/slider.py index 090e4c732..ebf21ce14 100644 --- a/reflex/components/forms/slider.py +++ b/reflex/components/forms/slider.py @@ -66,6 +66,9 @@ class Slider(ChakraComponent): # Maximum width of the slider. max_w: Var[str] + # The name of the form field + name: Var[str] + def get_event_triggers(self) -> dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/forms/switch.py b/reflex/components/forms/switch.py index e6eb82728..301a0bcec 100644 --- a/reflex/components/forms/switch.py +++ b/reflex/components/forms/switch.py @@ -34,6 +34,9 @@ class Switch(ChakraComponent): # The name of the input field in a switch (Useful for form submission). name: Var[str] + # The value of the input field when checked (use is_checked prop for a bool) + value: Var[str] = Var.create(True) # type: ignore + # The spacing between the switch and its label text (0.5rem) spacing: Var[str] diff --git a/reflex/components/forms/textarea.py b/reflex/components/forms/textarea.py index 794b26ed5..a7f3a9b3d 100644 --- a/reflex/components/forms/textarea.py +++ b/reflex/components/forms/textarea.py @@ -45,6 +45,9 @@ class TextArea(ChakraComponent): # "outline" | "filled" | "flushed" | "unstyled" variant: Var[LiteralInputVariant] + # The name of the form field + name: Var[str] + def get_event_triggers(self) -> dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/tags/tag.py b/reflex/components/tags/tag.py index 7c6344369..0122e6b96 100644 --- a/reflex/components/tags/tag.py +++ b/reflex/components/tags/tag.py @@ -83,8 +83,9 @@ class Tag(Base): The tag with the props removed. """ for name in args: - if name in self.props: - del self.props[name] + prop_name = format.to_camel_case(name) + if prop_name in self.props: + del self.props[prop_name] return self @staticmethod diff --git a/reflex/state.py b/reflex/state.py index 9462c901b..c140d6c94 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -43,6 +43,7 @@ from reflex.event import ( ) from reflex.utils import console, format, prerequisites, types from reflex.utils.exceptions import ImmutableStateError, LockExpiredError +from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.vars import BaseVar, ComputedVar, Var Delta = Dict[str, Any] @@ -2016,6 +2017,25 @@ class MutableProxy(wrapt.ObjectProxy): return copy.deepcopy(self.__wrapped__, memo=memo) +@serializer +def serialize_mutable_proxy(mp: MutableProxy) -> SerializedType: + """Serialize the wrapped value of a MutableProxy. + + Args: + mp: The MutableProxy to serialize. + + Returns: + The serialized wrapped object. + + Raises: + ValueError: when the wrapped object is not serializable. + """ + value = serialize(mp.__wrapped__) + if value is None: + raise ValueError(f"Cannot serialize {type(mp.__wrapped__)}") + return value + + class ImmutableMutableProxy(MutableProxy): """A proxy for a mutable object that tracks changes. diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 7d18460b5..efc88f16a 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -627,6 +627,28 @@ def unwrap_vars(value: str) -> str: ) +def collect_form_dict_names(form_dict: dict[str, Any]) -> dict[str, Any]: + """Collapse keys with consecutive suffixes into a single list value. + + Separators dash and underscore are removed, unless this would overwrite an existing key. + + Args: + form_dict: The dict to collapse. + + Returns: + The collapsed dict. + """ + ending_digit_regex = re.compile(r"^(.*?)[_-]?(\d+)$") + collapsed = {} + for k in sorted(form_dict): + m = ending_digit_regex.match(k) + if m: + collapsed.setdefault(m.group(1), []).append(form_dict[k]) + # collapsing never overwrites valid data from the form_dict + collapsed.update(form_dict) + return collapsed + + def format_data_editor_column(col: str | dict): """Format a given column into the proper format. diff --git a/reflex/vars.py b/reflex/vars.py index e5a7357e7..33d37ebb0 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -4,6 +4,7 @@ from __future__ import annotations import contextlib import dataclasses import dis +import inspect import json import random import string @@ -1138,17 +1139,76 @@ class Var: Returns: A var representing foreach operation. + + Raises: + TypeError: If the var is not a list. """ + inner_types = get_args(self._var_type) + if not inner_types: + raise TypeError( + f"Cannot foreach over non-sequence var {self._var_full_name} of type {self._var_type}." + ) arg = BaseVar( _var_name=get_unique_variable_name(), - _var_type=self._var_type, + _var_type=inner_types[0], ) + index = BaseVar( + _var_name=get_unique_variable_name(), + _var_type=int, + ) + fn_signature = inspect.signature(fn) + fn_args = (arg, index) + fn_ret = fn(*fn_args[: len(fn_signature.parameters)]) return BaseVar( - _var_name=f"{self._var_full_name}.map(({arg._var_name}, i) => {fn(arg, key='i')})", + _var_name=f"{self._var_full_name}.map(({arg._var_name}, {index._var_name}) => {fn_ret})", _var_type=self._var_type, _var_is_local=self._var_is_local, ) + @classmethod + def range( + cls, + v1: Var | int = 0, + v2: Var | int | None = None, + step: Var | int | None = None, + ) -> Var: + """Return an iterator over indices from v1 to v2 (or 0 to v1). + + Args: + v1: The start of the range or end of range if v2 is not given. + v2: The end of the range. + step: The number of numbers between each item. + + Returns: + A var representing range operation. + + Raises: + TypeError: If the var is not an int. + """ + if not isinstance(v1, Var): + v1 = Var.create_safe(v1) + if v1._var_type != int: + raise TypeError(f"Cannot get range on non-int var {v1._var_full_name}.") + if not isinstance(v2, Var): + v2 = Var.create(v2) + if v2 is None: + v2 = Var.create_safe("undefined") + elif v2._var_type != int: + raise TypeError(f"Cannot get range on non-int var {v2._var_full_name}.") + + if not isinstance(step, Var): + step = Var.create(step) + if step is None: + step = Var.create_safe(1) + elif step._var_type != int: + raise TypeError(f"Cannot get range on non-int var {step._var_full_name}.") + + return BaseVar( + _var_name=f"Array.from(range({v1._var_full_name}, {v2._var_full_name}, {step._var_name}))", + _var_type=list[int], + _var_is_local=False, + ) + def to(self, type_: Type) -> Var: """Convert the type of the var. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 1040e202b..9df2f7ece 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -85,6 +85,13 @@ class Var: def contains(self, other: Any) -> Var: ... def reverse(self) -> Var: ... def foreach(self, fn: Callable) -> Var: ... + @classmethod + def range( + cls, + v1: Var | int = 0, + v2: Var | int | None = None, + step: Var | int | None = None, + ) -> Var: ... def to(self, type_: Type) -> Var: ... @property def _var_full_name(self) -> str: ...