Set unique index vars in rx.foreach (#2126)
This commit is contained in:
parent
e6b02555f4
commit
e703d87450
@ -1,6 +1,7 @@
|
||||
"""Create a list of components from an iterable."""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
from reflex.components.component import Component
|
||||
@ -47,15 +48,20 @@ class Foreach(Component):
|
||||
f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)"
|
||||
)
|
||||
arg = BaseVar(_var_name="_", _var_type=type_, _var_is_local=True)
|
||||
comp = IterTag(iterable=iterable, render_fn=render_fn).render_component(arg)
|
||||
return cls(
|
||||
iterable=iterable,
|
||||
render_fn=render_fn,
|
||||
children=[IterTag.render_component(render_fn, arg=arg)],
|
||||
children=[comp],
|
||||
**props,
|
||||
)
|
||||
|
||||
def _render(self) -> IterTag:
|
||||
return IterTag(iterable=self.iterable, render_fn=self.render_fn)
|
||||
return IterTag(
|
||||
iterable=self.iterable,
|
||||
render_fn=self.render_fn,
|
||||
index_var_name=get_unique_variable_name(),
|
||||
)
|
||||
|
||||
def render(self):
|
||||
"""Render the component.
|
||||
@ -66,9 +72,9 @@ class Foreach(Component):
|
||||
tag = self._render()
|
||||
try:
|
||||
type_ = (
|
||||
self.iterable._var_type
|
||||
if self.iterable._var_type.mro()[0] == dict
|
||||
else self.iterable._var_type.__args__[0]
|
||||
tag.iterable._var_type
|
||||
if tag.iterable._var_type.mro()[0] == dict
|
||||
else typing.get_args(tag.iterable._var_type)[0]
|
||||
)
|
||||
except Exception:
|
||||
type_ = Any
|
||||
@ -77,7 +83,7 @@ class Foreach(Component):
|
||||
_var_type=type_,
|
||||
)
|
||||
index_arg = tag.get_index_var_arg()
|
||||
component = tag.render_component(self.render_fn, arg)
|
||||
component = tag.render_component(arg)
|
||||
return dict(
|
||||
tag.add_props(
|
||||
**self.event_triggers,
|
||||
|
@ -11,9 +11,6 @@ if TYPE_CHECKING:
|
||||
from reflex.components.component import Component
|
||||
|
||||
|
||||
INDEX_VAR = "i"
|
||||
|
||||
|
||||
class IterTag(Tag):
|
||||
"""An iterator tag."""
|
||||
|
||||
@ -23,37 +20,40 @@ class IterTag(Tag):
|
||||
# The component render function for each item in the iterable.
|
||||
render_fn: Callable
|
||||
|
||||
@staticmethod
|
||||
def get_index_var() -> Var:
|
||||
"""Get the index var for the tag.
|
||||
# The name of the index var.
|
||||
index_var_name: str = "i"
|
||||
|
||||
def get_index_var(self) -> Var:
|
||||
"""Get the index var for the tag (with curly braces).
|
||||
|
||||
This is used to reference the index var within the tag.
|
||||
|
||||
Returns:
|
||||
The index var.
|
||||
"""
|
||||
return BaseVar(
|
||||
_var_name=INDEX_VAR,
|
||||
_var_name=self.index_var_name,
|
||||
_var_type=int,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_index_var_arg() -> Var:
|
||||
"""Get the index var for the tag.
|
||||
def get_index_var_arg(self) -> Var:
|
||||
"""Get the index var for the tag (without curly braces).
|
||||
|
||||
This is used to render the index var in the .map() function.
|
||||
|
||||
Returns:
|
||||
The index var.
|
||||
"""
|
||||
return BaseVar(
|
||||
_var_name=INDEX_VAR,
|
||||
_var_name=self.index_var_name,
|
||||
_var_type=int,
|
||||
_var_is_local=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def render_component(render_fn: Callable, arg: Var) -> Component:
|
||||
def render_component(self, arg: Var) -> Component:
|
||||
"""Render the component.
|
||||
|
||||
Args:
|
||||
render_fn: The render function.
|
||||
arg: The argument to pass to the render function.
|
||||
|
||||
Returns:
|
||||
@ -65,16 +65,16 @@ class IterTag(Tag):
|
||||
from reflex.components.layout.fragment import Fragment
|
||||
|
||||
# Get the render function arguments.
|
||||
args = inspect.getfullargspec(render_fn).args
|
||||
index = IterTag.get_index_var()
|
||||
args = inspect.getfullargspec(self.render_fn).args
|
||||
index = self.get_index_var()
|
||||
|
||||
if len(args) == 1:
|
||||
# If the render function doesn't take the index as an argument.
|
||||
component = render_fn(arg)
|
||||
component = self.render_fn(arg)
|
||||
else:
|
||||
# If the render function takes the index as an argument.
|
||||
assert len(args) == 2
|
||||
component = render_fn(arg, index)
|
||||
component = self.render_fn(arg, index)
|
||||
|
||||
# Nested foreach components or cond must be wrapped in fragments.
|
||||
if isinstance(component, (Foreach, Cond)):
|
||||
|
@ -78,6 +78,9 @@ def display_nested_list_element(element: str, index: int):
|
||||
return box(text(element[index]))
|
||||
|
||||
|
||||
seen_index_vars = set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state_var, render_fn, render_dict",
|
||||
[
|
||||
@ -86,7 +89,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_color,
|
||||
{
|
||||
"iterable_state": "for_each_state.colors_list",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
@ -95,7 +97,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_color_name,
|
||||
{
|
||||
"iterable_state": "for_each_state.colors_dict_list",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
@ -104,7 +105,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_shade,
|
||||
{
|
||||
"iterable_state": "for_each_state.colors_nested_dict_list",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
@ -113,7 +113,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_primary_colors,
|
||||
{
|
||||
"iterable_state": "for_each_state.primary_color",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "dict",
|
||||
},
|
||||
),
|
||||
@ -122,7 +121,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_color_with_shades,
|
||||
{
|
||||
"iterable_state": "for_each_state.color_with_shades",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "dict",
|
||||
},
|
||||
),
|
||||
@ -131,7 +129,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_nested_color_with_shades,
|
||||
{
|
||||
"iterable_state": "for_each_state.nested_colors_with_shades",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "dict",
|
||||
},
|
||||
),
|
||||
@ -140,7 +137,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_nested_color_with_shades_v2,
|
||||
{
|
||||
"iterable_state": "for_each_state.nested_colors_with_shades",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "dict",
|
||||
},
|
||||
),
|
||||
@ -149,7 +145,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_color_tuple,
|
||||
{
|
||||
"iterable_state": "for_each_state.color_tuple",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "tuple",
|
||||
},
|
||||
),
|
||||
@ -158,7 +153,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
display_colors_set,
|
||||
{
|
||||
"iterable_state": "for_each_state.colors_set",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "set",
|
||||
},
|
||||
),
|
||||
@ -167,7 +161,6 @@ def display_nested_list_element(element: str, index: int):
|
||||
lambda el, i: display_nested_list_element(el, i),
|
||||
{
|
||||
"iterable_state": "for_each_state.nested_colors_list",
|
||||
"arg_index": "i",
|
||||
"iterable_type": "list",
|
||||
},
|
||||
),
|
||||
@ -184,8 +177,11 @@ def test_foreach_render(state_var, render_fn, render_dict):
|
||||
component = Foreach.create(state_var, render_fn)
|
||||
|
||||
rend = component.render()
|
||||
arg_index = rend["arg_index"]
|
||||
assert rend["iterable_state"] == render_dict["iterable_state"]
|
||||
assert arg_index._var_name == render_dict["arg_index"]
|
||||
assert arg_index._var_type == int
|
||||
assert rend["iterable_type"] == render_dict["iterable_type"]
|
||||
|
||||
# Make sure the index vars are unique.
|
||||
arg_index = rend["arg_index"]
|
||||
assert arg_index._var_name not in seen_index_vars
|
||||
assert arg_index._var_type == int
|
||||
seen_index_vars.add(arg_index._var_name)
|
||||
|
Loading…
Reference in New Issue
Block a user