diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index 8bca0a3db..88f2886a8 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -2,16 +2,24 @@ from __future__ import annotations import inspect -from hashlib import md5 from typing import Any, Callable, Iterable from reflex.components.base.fragment import Fragment from reflex.components.component import Component from reflex.components.tags import IterTag from reflex.constants import MemoizationMode +from reflex.utils import console from reflex.vars import Var +class ForeachVarError(TypeError): + """Raised when the iterable type is Any.""" + + +class ForeachRenderError(TypeError): + """Raised when there is an error with the foreach render function.""" + + class Foreach(Component): """A component that takes in an iterable and a render function and renders a list of components.""" @@ -24,56 +32,84 @@ class Foreach(Component): render_fn: Callable = Fragment.create @classmethod - def create(cls, iterable: Var[Iterable], render_fn: Callable, **props) -> Foreach: + def create( + cls, + iterable: Var[Iterable] | Iterable, + render_fn: Callable, + **props, + ) -> Foreach: """Create a foreach component. Args: iterable: The iterable to create components from. render_fn: A function from the render args to the component. - **props: The attributes to pass to each child component. + **props: The attributes to pass to each child component (deprecated). Returns: The foreach component. Raises: - TypeError: If the iterable is of type Any. + ForeachVarError: If the iterable is of type Any. """ - iterable = Var.create(iterable) # type: ignore + if props: + console.deprecate( + feature_name="Passing props to rx.foreach", + reason="it does not have the intended effect and may be confusing", + deprecation_version="0.5.0", + removal_version="0.6.0", + ) + iterable = Var.create_safe(iterable) if iterable._var_type == Any: - raise TypeError( - f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)" + raise ForeachVarError( + f"Could not foreach over var `{iterable._var_full_name}` of type Any. " + "(If you are trying to foreach over a state var, add a type annotation to the var). " + "See https://reflex.dev/docs/library/layout/foreach/" ) component = cls( iterable=iterable, render_fn=render_fn, - **props, ) - # Keep a ref to a rendered component to determine correct imports. - component.children = [ - component._render(props=dict(index_var_name="i")).render_component() - ] + # Keep a ref to a rendered component to determine correct imports/hooks/styles. + component.children = [component._render().render_component()] return component - def _render(self, props: dict[str, Any] | None = None) -> IterTag: - props = {} if props is None else props.copy() + def _render(self) -> IterTag: + props = {} - # Determine the arg var name based on the params accepted by render_fn. render_sig = inspect.signature(self.render_fn) params = list(render_sig.parameters.values()) - if len(params) >= 1: - props.setdefault("arg_var_name", params[0].name) - if len(params) >= 2: + # Validate the render function signature. + if len(params) == 0 or len(params) > 2: + raise ForeachRenderError( + "Expected 1 or 2 parameters in foreach render function, got " + f"{[p.name for p in params]}. See https://reflex.dev/docs/library/layout/foreach/" + ) + + if len(params) >= 1: + # Determine the arg var name based on the params accepted by render_fn. + props["arg_var_name"] = params[0].name + + if len(params) == 2: # Determine the index var name based on the params accepted by render_fn. - props.setdefault("index_var_name", params[1].name) - elif "index_var_name" not in props: - # Otherwise, use a deterministic index, based on the rendered code. - code_hash = md5(str(self.children[0].render()).encode("utf-8")).hexdigest() - props.setdefault("index_var_name", f"index_{code_hash}") + props["index_var_name"] = params[1].name + else: + # Otherwise, use a deterministic index, based on the render function bytecode. + code_hash = ( + hash(self.render_fn.__code__) + .to_bytes( + length=8, + byteorder="big", + signed=True, + ) + .hex() + ) + props["index_var_name"] = f"index_{code_hash}" return IterTag( iterable=self.iterable, render_fn=self.render_fn, + children=self.children, **props, ) @@ -84,19 +120,9 @@ class Foreach(Component): The dictionary for template of component. """ tag = self._render() - component = tag.render_component() return dict( - tag.add_props( - **self.event_triggers, - key=self.key, - sx=self.style, - id=self.id, - class_name=self.class_name, - ).set( - children=[component.render()], - props=tag.format_props(), - ), + tag, iterable_state=tag.iterable._var_full_name, arg_name=tag.arg_var_name, arg_index=tag.get_index_var_arg(), diff --git a/tests/components/core/test_foreach.py b/tests/components/core/test_foreach.py index 34e43e94d..9691ed50e 100644 --- a/tests/components/core/test_foreach.py +++ b/tests/components/core/test_foreach.py @@ -2,17 +2,11 @@ from typing import Dict, List, Set, Tuple, Union import pytest -from reflex.components import box, foreach, text -from reflex.components.core import Foreach +from reflex.components import box, el, foreach, text +from reflex.components.core.foreach import Foreach, ForeachRenderError, ForeachVarError from reflex.state import BaseState from reflex.vars import Var -try: - # When pydantic v2 is installed - from pydantic.v1 import ValidationError # type: ignore -except ImportError: - from pydantic import ValidationError - class ForEachState(BaseState): """A state for testing the ForEach component.""" @@ -84,12 +78,12 @@ def display_nested_color_with_shades_v2(color): def display_color_tuple(color): assert color._var_type == str - return box(text(color, "tuple")) + return box(text(color)) def display_colors_set(color): assert color._var_type == str - return box(text(color, "set")) + return box(text(color)) def display_nested_list_element(element: Var[str], index: Var[int]): @@ -100,7 +94,7 @@ def display_nested_list_element(element: Var[str], index: Var[int]): def display_color_index_tuple(color): assert color._var_type == Union[int, str] - return box(text(color, "index_tuple")) + return box(text(color)) seen_index_vars = set() @@ -215,24 +209,46 @@ def test_foreach_render(state_var, render_fn, render_dict): # Make sure the index vars are unique. arg_index = rend["arg_index"] + assert isinstance(arg_index, Var) assert arg_index._var_name not in seen_index_vars assert arg_index._var_type == int seen_index_vars.add(arg_index._var_name) def test_foreach_bad_annotations(): - """Test that the foreach component raises a TypeError if the iterable is of type Any.""" - with pytest.raises(TypeError): + """Test that the foreach component raises a ForeachVarError if the iterable is of type Any.""" + with pytest.raises(ForeachVarError): Foreach.create( - ForEachState.bad_annotation_list, # type: ignore + ForEachState.bad_annotation_list, lambda sublist: Foreach.create(sublist, lambda color: text(color)), ) def test_foreach_no_param_in_signature(): - """Test that the foreach component raises a TypeError if no parameters are passed.""" - with pytest.raises(ValidationError): + """Test that the foreach component raises a ForeachRenderError if no parameters are passed.""" + with pytest.raises(ForeachRenderError): Foreach.create( - ForEachState.colors_list, # type: ignore + ForEachState.colors_list, lambda: text("color"), ) + + +def test_foreach_too_many_params_in_signature(): + """Test that the foreach component raises a ForeachRenderError if too many parameters are passed.""" + with pytest.raises(ForeachRenderError): + Foreach.create( + ForEachState.colors_list, + lambda color, index, extra: text(color), + ) + + +def test_foreach_component_styles(): + """Test that the foreach component works with global component styles.""" + component = el.div( + foreach( + ForEachState.colors_list, + display_color, + ) + ) + component._add_style_recursive({box: {"color": "red"}}) + assert 'css={{"color": "red"}}' in str(component) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 96c1b6962..c2a28c84d 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1951,3 +1951,24 @@ def test_component_add_custom_code(): "const custom_code5 = 46", "const custom_code6 = 47", } + + +def test_add_style_foreach(): + class StyledComponent(Component): + tag = "StyledComponent" + ix: Var[int] + + def add_style(self): + return Style({"color": "red"}) + + page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i))) + page._add_style_recursive(Style()) + + # Expect only a single child of the foreach on the python side + assert len(page.children[0].children) == 1 + + # Expect the style to be added to the child of the foreach + assert 'css={{"color": "red"}}' in str(page.children[0].children[0]) + + # Expect only one instance of this CSS dict in the rendered page + assert str(page).count('css={{"color": "red"}}') == 1