Set unique index vars in rx.foreach (#2126)

This commit is contained in:
Nikhil Rao 2023-11-03 16:20:42 -07:00 committed by GitHub
parent e6b02555f4
commit e703d87450
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 37 deletions

View File

@ -1,6 +1,7 @@
"""Create a list of components from an iterable."""
from __future__ import annotations
import typing
from typing import Any, Callable, Iterable
from reflex.components.component import Component
@ -47,15 +48,20 @@ class Foreach(Component):
f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)"
)
arg = BaseVar(_var_name="_", _var_type=type_, _var_is_local=True)
comp = IterTag(iterable=iterable, render_fn=render_fn).render_component(arg)
return cls(
iterable=iterable,
render_fn=render_fn,
children=[IterTag.render_component(render_fn, arg=arg)],
children=[comp],
**props,
)
def _render(self) -> IterTag:
return IterTag(iterable=self.iterable, render_fn=self.render_fn)
return IterTag(
iterable=self.iterable,
render_fn=self.render_fn,
index_var_name=get_unique_variable_name(),
)
def render(self):
"""Render the component.
@ -66,9 +72,9 @@ class Foreach(Component):
tag = self._render()
try:
type_ = (
self.iterable._var_type
if self.iterable._var_type.mro()[0] == dict
else self.iterable._var_type.__args__[0]
tag.iterable._var_type
if tag.iterable._var_type.mro()[0] == dict
else typing.get_args(tag.iterable._var_type)[0]
)
except Exception:
type_ = Any
@ -77,7 +83,7 @@ class Foreach(Component):
_var_type=type_,
)
index_arg = tag.get_index_var_arg()
component = tag.render_component(self.render_fn, arg)
component = tag.render_component(arg)
return dict(
tag.add_props(
**self.event_triggers,

View File

@ -11,9 +11,6 @@ if TYPE_CHECKING:
from reflex.components.component import Component
INDEX_VAR = "i"
class IterTag(Tag):
"""An iterator tag."""
@ -23,37 +20,40 @@ class IterTag(Tag):
# The component render function for each item in the iterable.
render_fn: Callable
@staticmethod
def get_index_var() -> Var:
"""Get the index var for the tag.
# The name of the index var.
index_var_name: str = "i"
def get_index_var(self) -> Var:
"""Get the index var for the tag (with curly braces).
This is used to reference the index var within the tag.
Returns:
The index var.
"""
return BaseVar(
_var_name=INDEX_VAR,
_var_name=self.index_var_name,
_var_type=int,
)
@staticmethod
def get_index_var_arg() -> Var:
"""Get the index var for the tag.
def get_index_var_arg(self) -> Var:
"""Get the index var for the tag (without curly braces).
This is used to render the index var in the .map() function.
Returns:
The index var.
"""
return BaseVar(
_var_name=INDEX_VAR,
_var_name=self.index_var_name,
_var_type=int,
_var_is_local=True,
)
@staticmethod
def render_component(render_fn: Callable, arg: Var) -> Component:
def render_component(self, arg: Var) -> Component:
"""Render the component.
Args:
render_fn: The render function.
arg: The argument to pass to the render function.
Returns:
@ -65,16 +65,16 @@ class IterTag(Tag):
from reflex.components.layout.fragment import Fragment
# Get the render function arguments.
args = inspect.getfullargspec(render_fn).args
index = IterTag.get_index_var()
args = inspect.getfullargspec(self.render_fn).args
index = self.get_index_var()
if len(args) == 1:
# If the render function doesn't take the index as an argument.
component = render_fn(arg)
component = self.render_fn(arg)
else:
# If the render function takes the index as an argument.
assert len(args) == 2
component = render_fn(arg, index)
component = self.render_fn(arg, index)
# Nested foreach components or cond must be wrapped in fragments.
if isinstance(component, (Foreach, Cond)):

View File

@ -78,6 +78,9 @@ def display_nested_list_element(element: str, index: int):
return box(text(element[index]))
seen_index_vars = set()
@pytest.mark.parametrize(
"state_var, render_fn, render_dict",
[
@ -86,7 +89,6 @@ def display_nested_list_element(element: str, index: int):
display_color,
{
"iterable_state": "for_each_state.colors_list",
"arg_index": "i",
"iterable_type": "list",
},
),
@ -95,7 +97,6 @@ def display_nested_list_element(element: str, index: int):
display_color_name,
{
"iterable_state": "for_each_state.colors_dict_list",
"arg_index": "i",
"iterable_type": "list",
},
),
@ -104,7 +105,6 @@ def display_nested_list_element(element: str, index: int):
display_shade,
{
"iterable_state": "for_each_state.colors_nested_dict_list",
"arg_index": "i",
"iterable_type": "list",
},
),
@ -113,7 +113,6 @@ def display_nested_list_element(element: str, index: int):
display_primary_colors,
{
"iterable_state": "for_each_state.primary_color",
"arg_index": "i",
"iterable_type": "dict",
},
),
@ -122,7 +121,6 @@ def display_nested_list_element(element: str, index: int):
display_color_with_shades,
{
"iterable_state": "for_each_state.color_with_shades",
"arg_index": "i",
"iterable_type": "dict",
},
),
@ -131,7 +129,6 @@ def display_nested_list_element(element: str, index: int):
display_nested_color_with_shades,
{
"iterable_state": "for_each_state.nested_colors_with_shades",
"arg_index": "i",
"iterable_type": "dict",
},
),
@ -140,7 +137,6 @@ def display_nested_list_element(element: str, index: int):
display_nested_color_with_shades_v2,
{
"iterable_state": "for_each_state.nested_colors_with_shades",
"arg_index": "i",
"iterable_type": "dict",
},
),
@ -149,7 +145,6 @@ def display_nested_list_element(element: str, index: int):
display_color_tuple,
{
"iterable_state": "for_each_state.color_tuple",
"arg_index": "i",
"iterable_type": "tuple",
},
),
@ -158,7 +153,6 @@ def display_nested_list_element(element: str, index: int):
display_colors_set,
{
"iterable_state": "for_each_state.colors_set",
"arg_index": "i",
"iterable_type": "set",
},
),
@ -167,7 +161,6 @@ def display_nested_list_element(element: str, index: int):
lambda el, i: display_nested_list_element(el, i),
{
"iterable_state": "for_each_state.nested_colors_list",
"arg_index": "i",
"iterable_type": "list",
},
),
@ -184,8 +177,11 @@ def test_foreach_render(state_var, render_fn, render_dict):
component = Foreach.create(state_var, render_fn)
rend = component.render()
arg_index = rend["arg_index"]
assert rend["iterable_state"] == render_dict["iterable_state"]
assert arg_index._var_name == render_dict["arg_index"]
assert arg_index._var_type == int
assert rend["iterable_type"] == render_dict["iterable_type"]
# Make sure the index vars are unique.
arg_index = rend["arg_index"]
assert arg_index._var_name not in seen_index_vars
assert arg_index._var_type == int
seen_index_vars.add(arg_index._var_name)