From e703d8745002006ef69160040a2b291565ef2944 Mon Sep 17 00:00:00 2001 From: Nikhil Rao Date: Fri, 3 Nov 2023 16:20:42 -0700 Subject: [PATCH] Set unique index vars in rx.foreach (#2126) --- reflex/components/layout/foreach.py | 18 ++++++++----- reflex/components/tags/iter_tag.py | 36 ++++++++++++------------- tests/components/layout/test_foreach.py | 22 +++++++-------- 3 files changed, 39 insertions(+), 37 deletions(-) diff --git a/reflex/components/layout/foreach.py b/reflex/components/layout/foreach.py index 325615577..dc00e767c 100644 --- a/reflex/components/layout/foreach.py +++ b/reflex/components/layout/foreach.py @@ -1,6 +1,7 @@ """Create a list of components from an iterable.""" from __future__ import annotations +import typing from typing import Any, Callable, Iterable from reflex.components.component import Component @@ -47,15 +48,20 @@ class Foreach(Component): f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)" ) arg = BaseVar(_var_name="_", _var_type=type_, _var_is_local=True) + comp = IterTag(iterable=iterable, render_fn=render_fn).render_component(arg) return cls( iterable=iterable, render_fn=render_fn, - children=[IterTag.render_component(render_fn, arg=arg)], + children=[comp], **props, ) def _render(self) -> IterTag: - return IterTag(iterable=self.iterable, render_fn=self.render_fn) + return IterTag( + iterable=self.iterable, + render_fn=self.render_fn, + index_var_name=get_unique_variable_name(), + ) def render(self): """Render the component. @@ -66,9 +72,9 @@ class Foreach(Component): tag = self._render() try: type_ = ( - self.iterable._var_type - if self.iterable._var_type.mro()[0] == dict - else self.iterable._var_type.__args__[0] + tag.iterable._var_type + if tag.iterable._var_type.mro()[0] == dict + else typing.get_args(tag.iterable._var_type)[0] ) except Exception: type_ = Any @@ -77,7 +83,7 @@ class Foreach(Component): _var_type=type_, ) index_arg = tag.get_index_var_arg() - component = tag.render_component(self.render_fn, arg) + component = tag.render_component(arg) return dict( tag.add_props( **self.event_triggers, diff --git a/reflex/components/tags/iter_tag.py b/reflex/components/tags/iter_tag.py index e70e2416b..1900a1f70 100644 --- a/reflex/components/tags/iter_tag.py +++ b/reflex/components/tags/iter_tag.py @@ -11,9 +11,6 @@ if TYPE_CHECKING: from reflex.components.component import Component -INDEX_VAR = "i" - - class IterTag(Tag): """An iterator tag.""" @@ -23,37 +20,40 @@ class IterTag(Tag): # The component render function for each item in the iterable. render_fn: Callable - @staticmethod - def get_index_var() -> Var: - """Get the index var for the tag. + # The name of the index var. + index_var_name: str = "i" + + def get_index_var(self) -> Var: + """Get the index var for the tag (with curly braces). + + This is used to reference the index var within the tag. Returns: The index var. """ return BaseVar( - _var_name=INDEX_VAR, + _var_name=self.index_var_name, _var_type=int, ) - @staticmethod - def get_index_var_arg() -> Var: - """Get the index var for the tag. + def get_index_var_arg(self) -> Var: + """Get the index var for the tag (without curly braces). + + This is used to render the index var in the .map() function. Returns: The index var. """ return BaseVar( - _var_name=INDEX_VAR, + _var_name=self.index_var_name, _var_type=int, _var_is_local=True, ) - @staticmethod - def render_component(render_fn: Callable, arg: Var) -> Component: + def render_component(self, arg: Var) -> Component: """Render the component. Args: - render_fn: The render function. arg: The argument to pass to the render function. Returns: @@ -65,16 +65,16 @@ class IterTag(Tag): from reflex.components.layout.fragment import Fragment # Get the render function arguments. - args = inspect.getfullargspec(render_fn).args - index = IterTag.get_index_var() + args = inspect.getfullargspec(self.render_fn).args + index = self.get_index_var() if len(args) == 1: # If the render function doesn't take the index as an argument. - component = render_fn(arg) + component = self.render_fn(arg) else: # If the render function takes the index as an argument. assert len(args) == 2 - component = render_fn(arg, index) + component = self.render_fn(arg, index) # Nested foreach components or cond must be wrapped in fragments. if isinstance(component, (Foreach, Cond)): diff --git a/tests/components/layout/test_foreach.py b/tests/components/layout/test_foreach.py index 5b407d77f..11c8be563 100644 --- a/tests/components/layout/test_foreach.py +++ b/tests/components/layout/test_foreach.py @@ -78,6 +78,9 @@ def display_nested_list_element(element: str, index: int): return box(text(element[index])) +seen_index_vars = set() + + @pytest.mark.parametrize( "state_var, render_fn, render_dict", [ @@ -86,7 +89,6 @@ def display_nested_list_element(element: str, index: int): display_color, { "iterable_state": "for_each_state.colors_list", - "arg_index": "i", "iterable_type": "list", }, ), @@ -95,7 +97,6 @@ def display_nested_list_element(element: str, index: int): display_color_name, { "iterable_state": "for_each_state.colors_dict_list", - "arg_index": "i", "iterable_type": "list", }, ), @@ -104,7 +105,6 @@ def display_nested_list_element(element: str, index: int): display_shade, { "iterable_state": "for_each_state.colors_nested_dict_list", - "arg_index": "i", "iterable_type": "list", }, ), @@ -113,7 +113,6 @@ def display_nested_list_element(element: str, index: int): display_primary_colors, { "iterable_state": "for_each_state.primary_color", - "arg_index": "i", "iterable_type": "dict", }, ), @@ -122,7 +121,6 @@ def display_nested_list_element(element: str, index: int): display_color_with_shades, { "iterable_state": "for_each_state.color_with_shades", - "arg_index": "i", "iterable_type": "dict", }, ), @@ -131,7 +129,6 @@ def display_nested_list_element(element: str, index: int): display_nested_color_with_shades, { "iterable_state": "for_each_state.nested_colors_with_shades", - "arg_index": "i", "iterable_type": "dict", }, ), @@ -140,7 +137,6 @@ def display_nested_list_element(element: str, index: int): display_nested_color_with_shades_v2, { "iterable_state": "for_each_state.nested_colors_with_shades", - "arg_index": "i", "iterable_type": "dict", }, ), @@ -149,7 +145,6 @@ def display_nested_list_element(element: str, index: int): display_color_tuple, { "iterable_state": "for_each_state.color_tuple", - "arg_index": "i", "iterable_type": "tuple", }, ), @@ -158,7 +153,6 @@ def display_nested_list_element(element: str, index: int): display_colors_set, { "iterable_state": "for_each_state.colors_set", - "arg_index": "i", "iterable_type": "set", }, ), @@ -167,7 +161,6 @@ def display_nested_list_element(element: str, index: int): lambda el, i: display_nested_list_element(el, i), { "iterable_state": "for_each_state.nested_colors_list", - "arg_index": "i", "iterable_type": "list", }, ), @@ -184,8 +177,11 @@ def test_foreach_render(state_var, render_fn, render_dict): component = Foreach.create(state_var, render_fn) rend = component.render() - arg_index = rend["arg_index"] assert rend["iterable_state"] == render_dict["iterable_state"] - assert arg_index._var_name == render_dict["arg_index"] - assert arg_index._var_type == int assert rend["iterable_type"] == render_dict["iterable_type"] + + # Make sure the index vars are unique. + arg_index = rend["arg_index"] + assert arg_index._var_name not in seen_index_vars + assert arg_index._var_type == int + seen_index_vars.add(arg_index._var_name)