diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 4b601544f..b07bdfcfa 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -4,12 +4,12 @@ from __future__ import annotations from typing import Any, Iterator -from reflex.components.component import Component +from reflex.components.component import Component, ComponentStyle from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless from reflex.utils.imports import ParsedImportDict from reflex.vars import BooleanVar, ObjectVar, Var -from reflex.vars.base import VarData +from reflex.vars.base import VarData, get_var_caching, set_var_caching class Bare(Component): @@ -141,6 +141,31 @@ class Bare(Component): return Tagless(contents=f"{{{self.contents!s}}}") return Tagless(contents=str(self.contents)) + def _add_style_recursive( + self, style: ComponentStyle, theme: Component | None = None + ) -> Component: + """Add style to the component and its children. + + Args: + style: The style to add. + theme: The theme to add. + + Returns: + The component with the style added. + """ + new_self = super()._add_style_recursive(style, theme) + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + if isinstance(component, Component): + component._add_style_recursive(style, theme) + if get_var_caching(): + set_var_caching(False) + str(new_self) + set_var_caching(True) + return new_self + def _get_vars( self, include_children: bool = False, ignore_ids: set[int] | None = None ) -> Iterator[Var]: diff --git a/reflex/components/component.py b/reflex/components/component.py index d4d2ca612..fe6b9faa3 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -931,7 +931,6 @@ class Component(BaseComponent, ABC): """ from reflex.components.base.bare import Bare from reflex.components.base.fragment import Fragment - from reflex.components.core.foreach import Foreach no_valid_parents_defined = all(child._valid_parents == [] for child in children) if ( @@ -942,7 +941,7 @@ class Component(BaseComponent, ABC): return comp_name = type(self).__name__ - allowed_components = [comp.__name__ for comp in (Fragment, Foreach)] + allowed_components = [comp.__name__ for comp in (Fragment,)] def validate_child(child): child_name = type(child).__name__ @@ -1974,8 +1973,6 @@ class StatefulComponent(BaseComponent): Returns: The stateful component or None if the component should not be memoized. """ - from reflex.components.core.foreach import Foreach - if component._memoization_mode.disposition == MemoizationDisposition.NEVER: # Never memoize this component. return None @@ -2004,10 +2001,6 @@ class StatefulComponent(BaseComponent): # Skip BaseComponent and StatefulComponent children. if not isinstance(child, Component): continue - # Always consider Foreach something that must be memoized by the parent. - if isinstance(child, Foreach): - should_memoize = True - break child = cls._child_var(child) if isinstance(child, Var) and child._get_all_var_data(): should_memoize = True @@ -2057,12 +2050,9 @@ class StatefulComponent(BaseComponent): The Var from the child component or the child itself (for regular cases). """ from reflex.components.base.bare import Bare - from reflex.components.core.foreach import Foreach if isinstance(child, Bare): return child.contents - if isinstance(child, Foreach): - return child.iterable return child @classmethod diff --git a/reflex/components/core/__init__.py b/reflex/components/core/__init__.py index 237dac11a..534035f12 100644 --- a/reflex/components/core/__init__.py +++ b/reflex/components/core/__init__.py @@ -25,7 +25,6 @@ _SUBMOD_ATTRS: dict[str, list[str]] = { "debounce": ["DebounceInput", "debounce_input"], "foreach": [ "foreach", - "Foreach", ], "html": ["html", "Html"], "match": [ diff --git a/reflex/components/core/__init__.pyi b/reflex/components/core/__init__.pyi index 902433d66..2f1fb2084 100644 --- a/reflex/components/core/__init__.pyi +++ b/reflex/components/core/__init__.pyi @@ -21,7 +21,6 @@ from .cond import color_mode_cond as color_mode_cond from .cond import cond as cond from .debounce import DebounceInput as DebounceInput from .debounce import debounce_input as debounce_input -from .foreach import Foreach as Foreach from .foreach import foreach as foreach from .html import Html as Html from .html import html as html diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index c9fbe5bc5..4b6192e0b 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -2,15 +2,11 @@ from __future__ import annotations -import inspect -from typing import Any, Callable, Iterable +from typing import Callable, Iterable -from reflex.components.base.fragment import Fragment -from reflex.components.component import Component -from reflex.components.tags import IterTag -from reflex.constants import MemoizationMode -from reflex.state import ComponentState from reflex.vars.base import LiteralVar, Var +from reflex.vars.object import ObjectVar +from reflex.vars.sequence import ArrayVar class ForeachVarError(TypeError): @@ -21,116 +17,32 @@ class ForeachRenderError(TypeError): """Raised when there is an error with the foreach render function.""" -class Foreach(Component): - """A component that takes in an iterable and a render function and renders a list of components.""" +def foreach( + iterable: Var[Iterable] | Iterable, + render_fn: Callable, +) -> Var: + """Create a foreach component. - _memoization_mode = MemoizationMode(recursive=False) + Args: + iterable: The iterable to create components from. + render_fn: A function from the render args to the component. - # The iterable to create components from. - iterable: Var[Iterable] + Returns: + The foreach component. - # A function from the render args to the component. - render_fn: Callable = Fragment.create + Raises: + ForeachVarError: If the iterable is of type Any. + TypeError: If the render function is a ComponentState. + """ + iterable = LiteralVar.create(iterable) + if isinstance(iterable, ObjectVar): + iterable = iterable.items() - @classmethod - def create( - cls, - iterable: Var[Iterable] | Iterable, - render_fn: Callable, - ) -> Foreach: - """Create a foreach component. - - Args: - iterable: The iterable to create components from. - render_fn: A function from the render args to the component. - - Returns: - The foreach component. - - Raises: - ForeachVarError: If the iterable is of type Any. - TypeError: If the render function is a ComponentState. - """ - iterable = LiteralVar.create(iterable) - if iterable._var_type == Any: - raise ForeachVarError( - f"Could not foreach over var `{iterable!s}` of type Any. " - "(If you are trying to foreach over a state var, add a type annotation to the var). " - "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) - - if ( - hasattr(render_fn, "__qualname__") - and render_fn.__qualname__ == ComponentState.create.__qualname__ - ): - raise TypeError( - "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." - ) - - component = cls( - iterable=iterable, - render_fn=render_fn, - ) - # Keep a ref to a rendered component to determine correct imports/hooks/styles. - component.children = [component._render().render_component()] - return component - - def _render(self) -> IterTag: - props = {} - - render_sig = inspect.signature(self.render_fn) - params = list(render_sig.parameters.values()) - - # Validate the render function signature. - if len(params) == 0 or len(params) > 2: - raise ForeachRenderError( - "Expected 1 or 2 parameters in foreach render function, got " - f"{[p.name for p in params]}. See " - "https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) - - if len(params) >= 1: - # Determine the arg var name based on the params accepted by render_fn. - props["arg_var_name"] = params[0].name - - if len(params) == 2: - # Determine the index var name based on the params accepted by render_fn. - props["index_var_name"] = params[1].name - else: - # Otherwise, use a deterministic index, based on the render function bytecode. - code_hash = ( - hash(self.render_fn.__code__) - .to_bytes( - length=8, - byteorder="big", - signed=True, - ) - .hex() - ) - props["index_var_name"] = f"index_{code_hash}" - - return IterTag( - iterable=self.iterable, - render_fn=self.render_fn, - children=self.children, - **props, + if not isinstance(iterable, ArrayVar): + raise ForeachVarError( + f"Could not foreach over var `{iterable!s}` of type {iterable._var_type!s}. " + "(If you are trying to foreach over a state var, add a type annotation to the var). " + "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" ) - def render(self): - """Render the component. - - Returns: - The dictionary for template of component. - """ - tag = self._render() - - return dict( - tag, - iterable_state=str(tag.iterable), - arg_name=tag.arg_var_name, - arg_index=tag.get_index_var_arg(), - iterable_type=tag.iterable._var_type.mro()[0].__name__, - ) - - -foreach = Foreach.create + return iterable.foreach(render_fn) diff --git a/reflex/components/radix/primitives/slider.py b/reflex/components/radix/primitives/slider.py index 68f39e32c..eafdd65b4 100644 --- a/reflex/components/radix/primitives/slider.py +++ b/reflex/components/radix/primitives/slider.py @@ -188,7 +188,7 @@ class Slider(ComponentNamespace): else: children = [ track, - # Foreach.create(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001 + # foreach(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001 ] return SliderRoot.create(*children, **props) diff --git a/reflex/components/radix/themes/layout/list.py b/reflex/components/radix/themes/layout/list.py index a306e19a4..937342c7b 100644 --- a/reflex/components/radix/themes/layout/list.py +++ b/reflex/components/radix/themes/layout/list.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import Any, Iterable, Literal, Union from reflex.components.component import Component, ComponentNamespace -from reflex.components.core.foreach import Foreach +from reflex.components.core.foreach import foreach from reflex.components.el.elements.typography import Li, Ol, Ul from reflex.components.lucide.icon import Icon from reflex.components.markdown.markdown import MarkdownComponentMap @@ -70,7 +70,7 @@ class BaseList(Component, MarkdownComponentMap): if not children and items is not None: if isinstance(items, Var): - children = [Foreach.create(items, ListItem.create)] + children = [foreach(items, ListItem.create)] else: children = [ListItem.create(item) for item in items] # type: ignore props["direction"] = "column" diff --git a/reflex/components/tags/__init__.py b/reflex/components/tags/__init__.py index 8c8b73ab4..330bcc279 100644 --- a/reflex/components/tags/__init__.py +++ b/reflex/components/tags/__init__.py @@ -1,4 +1,3 @@ """Representations for React tags.""" -from .iter_tag import IterTag from .tag import Tag diff --git a/reflex/components/tags/iter_tag.py b/reflex/components/tags/iter_tag.py deleted file mode 100644 index 076a993e8..000000000 --- a/reflex/components/tags/iter_tag.py +++ /dev/null @@ -1,141 +0,0 @@ -"""Tag to loop through a list of components.""" - -from __future__ import annotations - -import dataclasses -import inspect -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args - -from reflex.components.tags.tag import Tag -from reflex.utils import types -from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name - -if TYPE_CHECKING: - from reflex.components.component import Component - - -@dataclasses.dataclass() -class IterTag(Tag): - """An iterator tag.""" - - # The var to iterate over. - iterable: Var[Iterable] = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - - # The component render function for each item in the iterable. - render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x) - - # The name of the arg var. - arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) - - # The name of the index var. - index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) - - def get_iterable_var_type(self) -> types.GenericType: - """Get the type of the iterable var. - - Returns: - The type of the iterable var. - """ - iterable = self.iterable - try: - if iterable._var_type.mro()[0] is dict: - # Arg is a tuple of (key, value). - return Tuple[get_args(iterable._var_type)] - elif iterable._var_type.mro()[0] is tuple: - # Arg is a union of any possible values in the tuple. - return Union[get_args(iterable._var_type)] - else: - return get_args(iterable._var_type)[0] - except Exception: - return Any - - def get_index_var(self) -> Var: - """Get the index var for the tag (with curly braces). - - This is used to reference the index var within the tag. - - Returns: - The index var. - """ - return Var( - _js_expr=self.index_var_name, - _var_type=int, - ).guess_type() - - def get_arg_var(self) -> Var: - """Get the arg var for the tag (with curly braces). - - This is used to reference the arg var within the tag. - - Returns: - The arg var. - """ - return Var( - _js_expr=self.arg_var_name, - _var_type=self.get_iterable_var_type(), - ).guess_type() - - def get_index_var_arg(self) -> Var: - """Get the index var for the tag (without curly braces). - - This is used to render the index var in the .map() function. - - Returns: - The index var. - """ - return Var( - _js_expr=self.index_var_name, - _var_type=int, - ).guess_type() - - def get_arg_var_arg(self) -> Var: - """Get the arg var for the tag (without curly braces). - - This is used to render the arg var in the .map() function. - - Returns: - The arg var. - """ - return Var( - _js_expr=self.arg_var_name, - _var_type=self.get_iterable_var_type(), - ).guess_type() - - def render_component(self) -> Component: - """Render the component. - - Raises: - ValueError: If the render function takes more than 2 arguments. - - Returns: - The rendered component. - """ - # Import here to avoid circular imports. - from reflex.components.base.fragment import Fragment - from reflex.components.core.foreach import Foreach - - # Get the render function arguments. - args = inspect.getfullargspec(self.render_fn).args - arg = self.get_arg_var() - index = self.get_index_var() - - if len(args) == 1: - # If the render function doesn't take the index as an argument. - component = self.render_fn(arg) - else: - # If the render function takes the index as an argument. - if len(args) != 2: - raise ValueError("The render function must take 2 arguments.") - component = self.render_fn(arg, index) - - # Nested foreach components or cond must be wrapped in fragments. - if isinstance(component, (Foreach, Var)): - component = Fragment.create(component) - - # Set the component key. - if component.key is None: - component.key = index - - return component diff --git a/reflex/utils/types.py b/reflex/utils/types.py index a27be80c5..2c67b193a 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -890,12 +890,23 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo Returns: Whether the type hint is a subclass of the other type hint. """ + if isinstance(possible_subclass, Sequence) and isinstance( + possible_superclass, Sequence + ): + return all( + typehint_issubclass(subclass, superclass) + for subclass, superclass in zip(possible_subclass, possible_superclass) + ) if possible_subclass is possible_superclass: return True if possible_superclass is Any: return True if possible_subclass is Any: return False + if isinstance( + possible_subclass, (TypeVar, typing_extensions.TypeVar) + ) or isinstance(possible_superclass, (TypeVar, typing_extensions.TypeVar)): + return True provided_type_origin = get_origin(possible_subclass) accepted_type_origin = get_origin(possible_superclass) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index b0f27bd64..52e293f82 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -151,6 +151,28 @@ def unwrap_reflex_callalbe( return args +_VAR_CACHING = True + + +def get_var_caching() -> bool: + """Get the var caching status. + + Returns: + The var caching status. + """ + return _VAR_CACHING + + +def set_var_caching(value: bool): + """Set the var caching status. + + Args: + value: The value to set the var caching status to. + """ + global _VAR_CACHING + _VAR_CACHING = value + + @dataclasses.dataclass( eq=False, frozen=True, @@ -1186,6 +1208,25 @@ class Var(Generic[VAR_TYPE]): """ return self + def __getattribute__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute. + """ + if not _VAR_CACHING: + try: + self_dict = object.__getattribute__(self, "__dict__") + for key in self_dict: + if key.startswith("_cached_"): + del self_dict[key] + except Exception: + pass + return super().__getattribute__(name) + def __getattr__(self, name: str): """Get an attribute of the var. diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 6b94e07bd..5b7c1bfb6 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -741,7 +741,8 @@ if TYPE_CHECKING: def map_array_operation( array: Var[Sequence[INNER_ARRAY_VAR]], function: Var[ - ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR] + ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR] + | ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR] | ReflexCallable[[], ANOTHER_ARRAY_VAR] ], ) -> CustomVarOperationReturn[Sequence[ANOTHER_ARRAY_VAR]]: @@ -973,7 +974,8 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)): def foreach( self: ArrayVar[Sequence[INNER_ARRAY_VAR]], - fn: Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR] + fn: Callable[[Var[INNER_ARRAY_VAR], NumberVar[int]], ANOTHER_ARRAY_VAR] + | Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR] | Callable[[], ANOTHER_ARRAY_VAR], ) -> ArrayVar[Sequence[ANOTHER_ARRAY_VAR]]: """Apply a function to each element of the array. @@ -987,21 +989,36 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)): Raises: VarTypeError: If the function takes more than one argument. """ + from reflex.state import ComponentState + from .function import ArgsFunctionOperation if not callable(fn): raise_unsupported_operand_types("foreach", (type(self), type(fn))) # get the number of arguments of the function num_args = len(inspect.signature(fn).parameters) - if num_args > 1: + if num_args > 2: raise VarTypeError( - "The function passed to foreach should take at most one argument." + "The function passed to foreach should take at most two arguments." + ) + + if ( + hasattr(fn, "__qualname__") + and fn.__qualname__ == ComponentState.create.__qualname__ + ): + raise TypeError( + "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." ) if num_args == 0: - return_value = fn() # type: ignore + fn_result = fn() # pyright: ignore [reportCallIssue] + return_value = Var.create(fn_result) simple_function_var: FunctionVar[ReflexCallable[[], ANOTHER_ARRAY_VAR]] = ( - ArgsFunctionOperation.create((), return_value) + ArgsFunctionOperation.create( + (), + return_value, + _var_type=ReflexCallable[[], return_value._var_type], + ) ) return map_array_operation(self, simple_function_var).guess_type() @@ -1021,11 +1038,40 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)): ).guess_type(), ) + if num_args == 1: + fn_result = fn(first_arg) # pyright: ignore [reportCallIssue] + + return_value = Var.create(fn_result) + + function_var = cast( + Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]], + ArgsFunctionOperation.create( + (arg_name,), + return_value, + _var_type=ReflexCallable[[first_arg_type], return_value._var_type], + ), + ) + + return map_array_operation.call(self, function_var).guess_type() + + second_arg = cast( + NumberVar[int], + Var( + _js_expr=get_unique_variable_name(), + _var_type=int, + ).guess_type(), + ) + + fn_result = fn(first_arg, second_arg) # pyright: ignore [reportCallIssue] + + return_value = Var.create(fn_result) + function_var = cast( - Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]], + Var[ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]], ArgsFunctionOperation.create( - (arg_name,), - Var.create(fn(first_arg)), # type: ignore + (arg_name, second_arg._js_expr), + return_value, + _var_type=ReflexCallable[[first_arg_type, int], return_value._var_type], ), ) diff --git a/tests/units/components/core/test_foreach.py b/tests/units/components/core/test_foreach.py index ddc385f65..c1659540b 100644 --- a/tests/units/components/core/test_foreach.py +++ b/tests/units/components/core/test_foreach.py @@ -6,16 +6,11 @@ import pytest from reflex import el from reflex.base import Base from reflex.components.component import Component -from reflex.components.core.foreach import ( - Foreach, - ForeachRenderError, - ForeachVarError, - foreach, -) +from reflex.components.core.foreach import ForeachVarError, foreach from reflex.components.radix.themes.layout.box import box from reflex.components.radix.themes.typography.text import text from reflex.state import BaseState, ComponentState -from reflex.vars.base import Var +from reflex.utils.exceptions import VarTypeError from reflex.vars.number import NumberVar from reflex.vars.sequence import ArrayVar @@ -141,143 +136,35 @@ def display_color_index_tuple(color): seen_index_vars = set() -@pytest.mark.parametrize( - "state_var, render_fn, render_dict", - [ - ( - ForEachState.colors_list, - display_color, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.colors_dict_list, - display_color_name, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.colors_nested_dict_list, - display_shade, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.primary_color, - display_primary_colors, - { - "iterable_state": f"{ForEachState.get_full_name()}.primary_color", - "iterable_type": "dict", - }, - ), - ( - ForEachState.color_with_shades, - display_color_with_shades, - { - "iterable_state": f"{ForEachState.get_full_name()}.color_with_shades", - "iterable_type": "dict", - }, - ), - ( - ForEachState.nested_colors_with_shades, - display_nested_color_with_shades, - { - "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades", - "iterable_type": "dict", - }, - ), - ( - ForEachState.nested_colors_with_shades, - display_nested_color_with_shades_v2, - { - "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades", - "iterable_type": "dict", - }, - ), - ( - ForEachState.color_tuple, - display_color_tuple, - { - "iterable_state": f"{ForEachState.get_full_name()}.color_tuple", - "iterable_type": "tuple", - }, - ), - ( - ForEachState.colors_set, - display_colors_set, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_set", - "iterable_type": "set", - }, - ), - ( - ForEachState.nested_colors_list, - lambda el, i: display_nested_list_element(el, i), - { - "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.color_index_tuple, - display_color_index_tuple, - { - "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple", - "iterable_type": "tuple", - }, - ), - ], -) -def test_foreach_render(state_var, render_fn, render_dict): - """Test that the foreach component renders without error. - - Args: - state_var: the state var. - render_fn: The render callable - render_dict: return dict on calling `component.render` - """ - component = Foreach.create(state_var, render_fn) - - rend = component.render() - assert rend["iterable_state"] == render_dict["iterable_state"] - assert rend["iterable_type"] == render_dict["iterable_type"] - - # Make sure the index vars are unique. - arg_index = rend["arg_index"] - assert isinstance(arg_index, Var) - assert arg_index._js_expr not in seen_index_vars - assert arg_index._var_type is int - seen_index_vars.add(arg_index._js_expr) - - def test_foreach_bad_annotations(): """Test that the foreach component raises a ForeachVarError if the iterable is of type Any.""" with pytest.raises(ForeachVarError): - Foreach.create( + foreach( ForEachState.bad_annotation_list, - lambda sublist: Foreach.create(sublist, lambda color: text(color)), + lambda sublist: foreach(sublist, lambda color: text(color)), ) def test_foreach_no_param_in_signature(): - """Test that the foreach component raises a ForeachRenderError if no parameters are passed.""" - with pytest.raises(ForeachRenderError): - Foreach.create( - ForEachState.colors_list, - lambda: text("color"), - ) + """Test that the foreach component DOES NOT raise an error if no parameters are passed.""" + foreach( + ForEachState.colors_list, + lambda: text("color"), + ) + + +def test_foreach_with_index(): + """Test that the foreach component works with an index.""" + foreach( + ForEachState.colors_list, + lambda color, index: text(color, index), + ) def test_foreach_too_many_params_in_signature(): """Test that the foreach component raises a ForeachRenderError if too many parameters are passed.""" - with pytest.raises(ForeachRenderError): - Foreach.create( + with pytest.raises(VarTypeError): + foreach( ForEachState.colors_list, lambda color, index, extra: text(color), ) @@ -292,13 +179,13 @@ def test_foreach_component_styles(): ) ) component._add_style_recursive({box: {"color": "red"}}) - assert 'css={({ ["color"] : "red" })}' in str(component) + assert '{ ["css"] : ({ ["color"] : "red" }) }' in str(component) def test_foreach_component_state(): """Test that using a component state to render in the foreach raises an error.""" with pytest.raises(TypeError): - Foreach.create( + foreach( ForEachState.colors_list, ComponentStateTest.create, ) @@ -306,7 +193,7 @@ def test_foreach_component_state(): def test_foreach_default_factory(): """Test that the default factory is called.""" - _ = Foreach.create( + _ = foreach( ForEachState.default_factory_list, lambda tag: text(tag.name), ) diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 64b798cf3..2544674f7 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -1446,7 +1446,6 @@ def test_instantiate_all_components(): untested_components = { "Card", "DebounceInput", - "Foreach", "FormControl", "Html", "Icon", @@ -2147,14 +2146,11 @@ def test_add_style_foreach(): page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i))) page._add_style_recursive(Style()) - # Expect only a single child of the foreach on the python side - assert len(page.children[0].children) == 1 - # Expect the style to be added to the child of the foreach - assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0]) + assert '({ ["css"] : ({ ["color"] : "red" }) }),' in str(page.children[0]) # Expect only one instance of this CSS dict in the rendered page - assert str(page).count('css={({ ["color"] : "red" })}') == 1 + assert str(page).count('({ ["css"] : ({ ["color"] : "red" }) }),') == 1 class TriggerState(rx.State):