Set unique index vars in rx.foreach (#2126)

This commit is contained in:
Nikhil Rao 2023-11-03 16:20:42 -07:00 committed by GitHub
parent e6b02555f4
commit e703d87450
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 37 deletions

View File

@ -1,6 +1,7 @@
"""Create a list of components from an iterable.""" """Create a list of components from an iterable."""
from __future__ import annotations from __future__ import annotations
import typing
from typing import Any, Callable, Iterable from typing import Any, Callable, Iterable
from reflex.components.component import Component from reflex.components.component import Component
@ -47,15 +48,20 @@ class Foreach(Component):
f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)" f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)"
) )
arg = BaseVar(_var_name="_", _var_type=type_, _var_is_local=True) arg = BaseVar(_var_name="_", _var_type=type_, _var_is_local=True)
comp = IterTag(iterable=iterable, render_fn=render_fn).render_component(arg)
return cls( return cls(
iterable=iterable, iterable=iterable,
render_fn=render_fn, render_fn=render_fn,
children=[IterTag.render_component(render_fn, arg=arg)], children=[comp],
**props, **props,
) )
def _render(self) -> IterTag: def _render(self) -> IterTag:
return IterTag(iterable=self.iterable, render_fn=self.render_fn) return IterTag(
iterable=self.iterable,
render_fn=self.render_fn,
index_var_name=get_unique_variable_name(),
)
def render(self): def render(self):
"""Render the component. """Render the component.
@ -66,9 +72,9 @@ class Foreach(Component):
tag = self._render() tag = self._render()
try: try:
type_ = ( type_ = (
self.iterable._var_type tag.iterable._var_type
if self.iterable._var_type.mro()[0] == dict if tag.iterable._var_type.mro()[0] == dict
else self.iterable._var_type.__args__[0] else typing.get_args(tag.iterable._var_type)[0]
) )
except Exception: except Exception:
type_ = Any type_ = Any
@ -77,7 +83,7 @@ class Foreach(Component):
_var_type=type_, _var_type=type_,
) )
index_arg = tag.get_index_var_arg() index_arg = tag.get_index_var_arg()
component = tag.render_component(self.render_fn, arg) component = tag.render_component(arg)
return dict( return dict(
tag.add_props( tag.add_props(
**self.event_triggers, **self.event_triggers,

View File

@ -11,9 +11,6 @@ if TYPE_CHECKING:
from reflex.components.component import Component from reflex.components.component import Component
INDEX_VAR = "i"
class IterTag(Tag): class IterTag(Tag):
"""An iterator tag.""" """An iterator tag."""
@ -23,37 +20,40 @@ class IterTag(Tag):
# The component render function for each item in the iterable. # The component render function for each item in the iterable.
render_fn: Callable render_fn: Callable
@staticmethod # The name of the index var.
def get_index_var() -> Var: index_var_name: str = "i"
"""Get the index var for the tag.
def get_index_var(self) -> Var:
"""Get the index var for the tag (with curly braces).
This is used to reference the index var within the tag.
Returns: Returns:
The index var. The index var.
""" """
return BaseVar( return BaseVar(
_var_name=INDEX_VAR, _var_name=self.index_var_name,
_var_type=int, _var_type=int,
) )
@staticmethod def get_index_var_arg(self) -> Var:
def get_index_var_arg() -> Var: """Get the index var for the tag (without curly braces).
"""Get the index var for the tag.
This is used to render the index var in the .map() function.
Returns: Returns:
The index var. The index var.
""" """
return BaseVar( return BaseVar(
_var_name=INDEX_VAR, _var_name=self.index_var_name,
_var_type=int, _var_type=int,
_var_is_local=True, _var_is_local=True,
) )
@staticmethod def render_component(self, arg: Var) -> Component:
def render_component(render_fn: Callable, arg: Var) -> Component:
"""Render the component. """Render the component.
Args: Args:
render_fn: The render function.
arg: The argument to pass to the render function. arg: The argument to pass to the render function.
Returns: Returns:
@ -65,16 +65,16 @@ class IterTag(Tag):
from reflex.components.layout.fragment import Fragment from reflex.components.layout.fragment import Fragment
# Get the render function arguments. # Get the render function arguments.
args = inspect.getfullargspec(render_fn).args args = inspect.getfullargspec(self.render_fn).args
index = IterTag.get_index_var() index = self.get_index_var()
if len(args) == 1: if len(args) == 1:
# If the render function doesn't take the index as an argument. # If the render function doesn't take the index as an argument.
component = render_fn(arg) component = self.render_fn(arg)
else: else:
# If the render function takes the index as an argument. # If the render function takes the index as an argument.
assert len(args) == 2 assert len(args) == 2
component = render_fn(arg, index) component = self.render_fn(arg, index)
# Nested foreach components or cond must be wrapped in fragments. # Nested foreach components or cond must be wrapped in fragments.
if isinstance(component, (Foreach, Cond)): if isinstance(component, (Foreach, Cond)):

View File

@ -78,6 +78,9 @@ def display_nested_list_element(element: str, index: int):
return box(text(element[index])) return box(text(element[index]))
seen_index_vars = set()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"state_var, render_fn, render_dict", "state_var, render_fn, render_dict",
[ [
@ -86,7 +89,6 @@ def display_nested_list_element(element: str, index: int):
display_color, display_color,
{ {
"iterable_state": "for_each_state.colors_list", "iterable_state": "for_each_state.colors_list",
"arg_index": "i",
"iterable_type": "list", "iterable_type": "list",
}, },
), ),
@ -95,7 +97,6 @@ def display_nested_list_element(element: str, index: int):
display_color_name, display_color_name,
{ {
"iterable_state": "for_each_state.colors_dict_list", "iterable_state": "for_each_state.colors_dict_list",
"arg_index": "i",
"iterable_type": "list", "iterable_type": "list",
}, },
), ),
@ -104,7 +105,6 @@ def display_nested_list_element(element: str, index: int):
display_shade, display_shade,
{ {
"iterable_state": "for_each_state.colors_nested_dict_list", "iterable_state": "for_each_state.colors_nested_dict_list",
"arg_index": "i",
"iterable_type": "list", "iterable_type": "list",
}, },
), ),
@ -113,7 +113,6 @@ def display_nested_list_element(element: str, index: int):
display_primary_colors, display_primary_colors,
{ {
"iterable_state": "for_each_state.primary_color", "iterable_state": "for_each_state.primary_color",
"arg_index": "i",
"iterable_type": "dict", "iterable_type": "dict",
}, },
), ),
@ -122,7 +121,6 @@ def display_nested_list_element(element: str, index: int):
display_color_with_shades, display_color_with_shades,
{ {
"iterable_state": "for_each_state.color_with_shades", "iterable_state": "for_each_state.color_with_shades",
"arg_index": "i",
"iterable_type": "dict", "iterable_type": "dict",
}, },
), ),
@ -131,7 +129,6 @@ def display_nested_list_element(element: str, index: int):
display_nested_color_with_shades, display_nested_color_with_shades,
{ {
"iterable_state": "for_each_state.nested_colors_with_shades", "iterable_state": "for_each_state.nested_colors_with_shades",
"arg_index": "i",
"iterable_type": "dict", "iterable_type": "dict",
}, },
), ),
@ -140,7 +137,6 @@ def display_nested_list_element(element: str, index: int):
display_nested_color_with_shades_v2, display_nested_color_with_shades_v2,
{ {
"iterable_state": "for_each_state.nested_colors_with_shades", "iterable_state": "for_each_state.nested_colors_with_shades",
"arg_index": "i",
"iterable_type": "dict", "iterable_type": "dict",
}, },
), ),
@ -149,7 +145,6 @@ def display_nested_list_element(element: str, index: int):
display_color_tuple, display_color_tuple,
{ {
"iterable_state": "for_each_state.color_tuple", "iterable_state": "for_each_state.color_tuple",
"arg_index": "i",
"iterable_type": "tuple", "iterable_type": "tuple",
}, },
), ),
@ -158,7 +153,6 @@ def display_nested_list_element(element: str, index: int):
display_colors_set, display_colors_set,
{ {
"iterable_state": "for_each_state.colors_set", "iterable_state": "for_each_state.colors_set",
"arg_index": "i",
"iterable_type": "set", "iterable_type": "set",
}, },
), ),
@ -167,7 +161,6 @@ def display_nested_list_element(element: str, index: int):
lambda el, i: display_nested_list_element(el, i), lambda el, i: display_nested_list_element(el, i),
{ {
"iterable_state": "for_each_state.nested_colors_list", "iterable_state": "for_each_state.nested_colors_list",
"arg_index": "i",
"iterable_type": "list", "iterable_type": "list",
}, },
), ),
@ -184,8 +177,11 @@ def test_foreach_render(state_var, render_fn, render_dict):
component = Foreach.create(state_var, render_fn) component = Foreach.create(state_var, render_fn)
rend = component.render() rend = component.render()
arg_index = rend["arg_index"]
assert rend["iterable_state"] == render_dict["iterable_state"] assert rend["iterable_state"] == render_dict["iterable_state"]
assert arg_index._var_name == render_dict["arg_index"]
assert arg_index._var_type == int
assert rend["iterable_type"] == render_dict["iterable_type"] assert rend["iterable_type"] == render_dict["iterable_type"]
# Make sure the index vars are unique.
arg_index = rend["arg_index"]
assert arg_index._var_name not in seen_index_vars
assert arg_index._var_type == int
seen_index_vars.add(arg_index._var_name)