what if i deleted rx.foreach

This commit is contained in:
Khaleel Al-Adhami 2025-01-17 18:16:43 -08:00
parent b10f6e836d
commit 5d6b51c561
14 changed files with 188 additions and 424 deletions

View File

@ -4,12 +4,12 @@ from __future__ import annotations
from typing import Any, Iterator from typing import Any, Iterator
from reflex.components.component import Component from reflex.components.component import Component, ComponentStyle
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData from reflex.vars.base import VarData, get_var_caching, set_var_caching
class Bare(Component): class Bare(Component):
@ -141,6 +141,31 @@ class Bare(Component):
return Tagless(contents=f"{{{self.contents!s}}}") return Tagless(contents=f"{{{self.contents!s}}}")
return Tagless(contents=str(self.contents)) return Tagless(contents=str(self.contents))
def _add_style_recursive(
self, style: ComponentStyle, theme: Component | None = None
) -> Component:
"""Add style to the component and its children.
Args:
style: The style to add.
theme: The theme to add.
Returns:
The component with the style added.
"""
new_self = super()._add_style_recursive(style, theme)
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
if isinstance(component, Component):
component._add_style_recursive(style, theme)
if get_var_caching():
set_var_caching(False)
str(new_self)
set_var_caching(True)
return new_self
def _get_vars( def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]: ) -> Iterator[Var]:

View File

@ -931,7 +931,6 @@ class Component(BaseComponent, ABC):
""" """
from reflex.components.base.bare import Bare from reflex.components.base.bare import Bare
from reflex.components.base.fragment import Fragment from reflex.components.base.fragment import Fragment
from reflex.components.core.foreach import Foreach
no_valid_parents_defined = all(child._valid_parents == [] for child in children) no_valid_parents_defined = all(child._valid_parents == [] for child in children)
if ( if (
@ -942,7 +941,7 @@ class Component(BaseComponent, ABC):
return return
comp_name = type(self).__name__ comp_name = type(self).__name__
allowed_components = [comp.__name__ for comp in (Fragment, Foreach)] allowed_components = [comp.__name__ for comp in (Fragment,)]
def validate_child(child): def validate_child(child):
child_name = type(child).__name__ child_name = type(child).__name__
@ -1974,8 +1973,6 @@ class StatefulComponent(BaseComponent):
Returns: Returns:
The stateful component or None if the component should not be memoized. The stateful component or None if the component should not be memoized.
""" """
from reflex.components.core.foreach import Foreach
if component._memoization_mode.disposition == MemoizationDisposition.NEVER: if component._memoization_mode.disposition == MemoizationDisposition.NEVER:
# Never memoize this component. # Never memoize this component.
return None return None
@ -2004,10 +2001,6 @@ class StatefulComponent(BaseComponent):
# Skip BaseComponent and StatefulComponent children. # Skip BaseComponent and StatefulComponent children.
if not isinstance(child, Component): if not isinstance(child, Component):
continue continue
# Always consider Foreach something that must be memoized by the parent.
if isinstance(child, Foreach):
should_memoize = True
break
child = cls._child_var(child) child = cls._child_var(child)
if isinstance(child, Var) and child._get_all_var_data(): if isinstance(child, Var) and child._get_all_var_data():
should_memoize = True should_memoize = True
@ -2057,12 +2050,9 @@ class StatefulComponent(BaseComponent):
The Var from the child component or the child itself (for regular cases). The Var from the child component or the child itself (for regular cases).
""" """
from reflex.components.base.bare import Bare from reflex.components.base.bare import Bare
from reflex.components.core.foreach import Foreach
if isinstance(child, Bare): if isinstance(child, Bare):
return child.contents return child.contents
if isinstance(child, Foreach):
return child.iterable
return child return child
@classmethod @classmethod

View File

@ -25,7 +25,6 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
"debounce": ["DebounceInput", "debounce_input"], "debounce": ["DebounceInput", "debounce_input"],
"foreach": [ "foreach": [
"foreach", "foreach",
"Foreach",
], ],
"html": ["html", "Html"], "html": ["html", "Html"],
"match": [ "match": [

View File

@ -21,7 +21,6 @@ from .cond import color_mode_cond as color_mode_cond
from .cond import cond as cond from .cond import cond as cond
from .debounce import DebounceInput as DebounceInput from .debounce import DebounceInput as DebounceInput
from .debounce import debounce_input as debounce_input from .debounce import debounce_input as debounce_input
from .foreach import Foreach as Foreach
from .foreach import foreach as foreach from .foreach import foreach as foreach
from .html import Html as Html from .html import Html as Html
from .html import html as html from .html import html as html

View File

@ -2,15 +2,11 @@
from __future__ import annotations from __future__ import annotations
import inspect from typing import Callable, Iterable
from typing import Any, Callable, Iterable
from reflex.components.base.fragment import Fragment
from reflex.components.component import Component
from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode
from reflex.state import ComponentState
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
from reflex.vars.object import ObjectVar
from reflex.vars.sequence import ArrayVar
class ForeachVarError(TypeError): class ForeachVarError(TypeError):
@ -21,116 +17,32 @@ class ForeachRenderError(TypeError):
"""Raised when there is an error with the foreach render function.""" """Raised when there is an error with the foreach render function."""
class Foreach(Component): def foreach(
"""A component that takes in an iterable and a render function and renders a list of components.""" iterable: Var[Iterable] | Iterable,
render_fn: Callable,
) -> Var:
"""Create a foreach component.
_memoization_mode = MemoizationMode(recursive=False) Args:
iterable: The iterable to create components from.
render_fn: A function from the render args to the component.
# The iterable to create components from. Returns:
iterable: Var[Iterable] The foreach component.
# A function from the render args to the component. Raises:
render_fn: Callable = Fragment.create ForeachVarError: If the iterable is of type Any.
TypeError: If the render function is a ComponentState.
"""
iterable = LiteralVar.create(iterable)
if isinstance(iterable, ObjectVar):
iterable = iterable.items()
@classmethod if not isinstance(iterable, ArrayVar):
def create( raise ForeachVarError(
cls, f"Could not foreach over var `{iterable!s}` of type {iterable._var_type!s}. "
iterable: Var[Iterable] | Iterable, "(If you are trying to foreach over a state var, add a type annotation to the var). "
render_fn: Callable, "See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
) -> Foreach:
"""Create a foreach component.
Args:
iterable: The iterable to create components from.
render_fn: A function from the render args to the component.
Returns:
The foreach component.
Raises:
ForeachVarError: If the iterable is of type Any.
TypeError: If the render function is a ComponentState.
"""
iterable = LiteralVar.create(iterable)
if iterable._var_type == Any:
raise ForeachVarError(
f"Could not foreach over var `{iterable!s}` of type Any. "
"(If you are trying to foreach over a state var, add a type annotation to the var). "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
)
if (
hasattr(render_fn, "__qualname__")
and render_fn.__qualname__ == ComponentState.create.__qualname__
):
raise TypeError(
"Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
)
component = cls(
iterable=iterable,
render_fn=render_fn,
)
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
return component
def _render(self) -> IterTag:
props = {}
render_sig = inspect.signature(self.render_fn)
params = list(render_sig.parameters.values())
# Validate the render function signature.
if len(params) == 0 or len(params) > 2:
raise ForeachRenderError(
"Expected 1 or 2 parameters in foreach render function, got "
f"{[p.name for p in params]}. See "
"https://reflex.dev/docs/library/dynamic-rendering/foreach/"
)
if len(params) >= 1:
# Determine the arg var name based on the params accepted by render_fn.
props["arg_var_name"] = params[0].name
if len(params) == 2:
# Determine the index var name based on the params accepted by render_fn.
props["index_var_name"] = params[1].name
else:
# Otherwise, use a deterministic index, based on the render function bytecode.
code_hash = (
hash(self.render_fn.__code__)
.to_bytes(
length=8,
byteorder="big",
signed=True,
)
.hex()
)
props["index_var_name"] = f"index_{code_hash}"
return IterTag(
iterable=self.iterable,
render_fn=self.render_fn,
children=self.children,
**props,
) )
def render(self): return iterable.foreach(render_fn)
"""Render the component.
Returns:
The dictionary for template of component.
"""
tag = self._render()
return dict(
tag,
iterable_state=str(tag.iterable),
arg_name=tag.arg_var_name,
arg_index=tag.get_index_var_arg(),
iterable_type=tag.iterable._var_type.mro()[0].__name__,
)
foreach = Foreach.create

View File

@ -188,7 +188,7 @@ class Slider(ComponentNamespace):
else: else:
children = [ children = [
track, track,
# Foreach.create(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001 # foreach(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001
] ]
return SliderRoot.create(*children, **props) return SliderRoot.create(*children, **props)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from typing import Any, Iterable, Literal, Union from typing import Any, Iterable, Literal, Union
from reflex.components.component import Component, ComponentNamespace from reflex.components.component import Component, ComponentNamespace
from reflex.components.core.foreach import Foreach from reflex.components.core.foreach import foreach
from reflex.components.el.elements.typography import Li, Ol, Ul from reflex.components.el.elements.typography import Li, Ol, Ul
from reflex.components.lucide.icon import Icon from reflex.components.lucide.icon import Icon
from reflex.components.markdown.markdown import MarkdownComponentMap from reflex.components.markdown.markdown import MarkdownComponentMap
@ -70,7 +70,7 @@ class BaseList(Component, MarkdownComponentMap):
if not children and items is not None: if not children and items is not None:
if isinstance(items, Var): if isinstance(items, Var):
children = [Foreach.create(items, ListItem.create)] children = [foreach(items, ListItem.create)]
else: else:
children = [ListItem.create(item) for item in items] # type: ignore children = [ListItem.create(item) for item in items] # type: ignore
props["direction"] = "column" props["direction"] = "column"

View File

@ -1,4 +1,3 @@
"""Representations for React tags.""" """Representations for React tags."""
from .iter_tag import IterTag
from .tag import Tag from .tag import Tag

View File

@ -1,141 +0,0 @@
"""Tag to loop through a list of components."""
from __future__ import annotations
import dataclasses
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args
from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
if TYPE_CHECKING:
from reflex.components.component import Component
@dataclasses.dataclass()
class IterTag(Tag):
"""An iterator tag."""
# The var to iterate over.
iterable: Var[Iterable] = dataclasses.field(
default_factory=lambda: LiteralArrayVar.create([])
)
# The component render function for each item in the iterable.
render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x)
# The name of the arg var.
arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
# The name of the index var.
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
def get_iterable_var_type(self) -> types.GenericType:
"""Get the type of the iterable var.
Returns:
The type of the iterable var.
"""
iterable = self.iterable
try:
if iterable._var_type.mro()[0] is dict:
# Arg is a tuple of (key, value).
return Tuple[get_args(iterable._var_type)]
elif iterable._var_type.mro()[0] is tuple:
# Arg is a union of any possible values in the tuple.
return Union[get_args(iterable._var_type)]
else:
return get_args(iterable._var_type)[0]
except Exception:
return Any
def get_index_var(self) -> Var:
"""Get the index var for the tag (with curly braces).
This is used to reference the index var within the tag.
Returns:
The index var.
"""
return Var(
_js_expr=self.index_var_name,
_var_type=int,
).guess_type()
def get_arg_var(self) -> Var:
"""Get the arg var for the tag (with curly braces).
This is used to reference the arg var within the tag.
Returns:
The arg var.
"""
return Var(
_js_expr=self.arg_var_name,
_var_type=self.get_iterable_var_type(),
).guess_type()
def get_index_var_arg(self) -> Var:
"""Get the index var for the tag (without curly braces).
This is used to render the index var in the .map() function.
Returns:
The index var.
"""
return Var(
_js_expr=self.index_var_name,
_var_type=int,
).guess_type()
def get_arg_var_arg(self) -> Var:
"""Get the arg var for the tag (without curly braces).
This is used to render the arg var in the .map() function.
Returns:
The arg var.
"""
return Var(
_js_expr=self.arg_var_name,
_var_type=self.get_iterable_var_type(),
).guess_type()
def render_component(self) -> Component:
"""Render the component.
Raises:
ValueError: If the render function takes more than 2 arguments.
Returns:
The rendered component.
"""
# Import here to avoid circular imports.
from reflex.components.base.fragment import Fragment
from reflex.components.core.foreach import Foreach
# Get the render function arguments.
args = inspect.getfullargspec(self.render_fn).args
arg = self.get_arg_var()
index = self.get_index_var()
if len(args) == 1:
# If the render function doesn't take the index as an argument.
component = self.render_fn(arg)
else:
# If the render function takes the index as an argument.
if len(args) != 2:
raise ValueError("The render function must take 2 arguments.")
component = self.render_fn(arg, index)
# Nested foreach components or cond must be wrapped in fragments.
if isinstance(component, (Foreach, Var)):
component = Fragment.create(component)
# Set the component key.
if component.key is None:
component.key = index
return component

View File

@ -890,12 +890,23 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo
Returns: Returns:
Whether the type hint is a subclass of the other type hint. Whether the type hint is a subclass of the other type hint.
""" """
if isinstance(possible_subclass, Sequence) and isinstance(
possible_superclass, Sequence
):
return all(
typehint_issubclass(subclass, superclass)
for subclass, superclass in zip(possible_subclass, possible_superclass)
)
if possible_subclass is possible_superclass: if possible_subclass is possible_superclass:
return True return True
if possible_superclass is Any: if possible_superclass is Any:
return True return True
if possible_subclass is Any: if possible_subclass is Any:
return False return False
if isinstance(
possible_subclass, (TypeVar, typing_extensions.TypeVar)
) or isinstance(possible_superclass, (TypeVar, typing_extensions.TypeVar)):
return True
provided_type_origin = get_origin(possible_subclass) provided_type_origin = get_origin(possible_subclass)
accepted_type_origin = get_origin(possible_superclass) accepted_type_origin = get_origin(possible_superclass)

View File

@ -151,6 +151,28 @@ def unwrap_reflex_callalbe(
return args return args
_VAR_CACHING = True
def get_var_caching() -> bool:
"""Get the var caching status.
Returns:
The var caching status.
"""
return _VAR_CACHING
def set_var_caching(value: bool):
"""Set the var caching status.
Args:
value: The value to set the var caching status to.
"""
global _VAR_CACHING
_VAR_CACHING = value
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
@ -1186,6 +1208,25 @@ class Var(Generic[VAR_TYPE]):
""" """
return self return self
def __getattribute__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute.
"""
if not _VAR_CACHING:
try:
self_dict = object.__getattribute__(self, "__dict__")
for key in self_dict:
if key.startswith("_cached_"):
del self_dict[key]
except Exception:
pass
return super().__getattribute__(name)
def __getattr__(self, name: str): def __getattr__(self, name: str):
"""Get an attribute of the var. """Get an attribute of the var.

View File

@ -741,7 +741,8 @@ if TYPE_CHECKING:
def map_array_operation( def map_array_operation(
array: Var[Sequence[INNER_ARRAY_VAR]], array: Var[Sequence[INNER_ARRAY_VAR]],
function: Var[ function: Var[
ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR] ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]
| ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]
| ReflexCallable[[], ANOTHER_ARRAY_VAR] | ReflexCallable[[], ANOTHER_ARRAY_VAR]
], ],
) -> CustomVarOperationReturn[Sequence[ANOTHER_ARRAY_VAR]]: ) -> CustomVarOperationReturn[Sequence[ANOTHER_ARRAY_VAR]]:
@ -973,7 +974,8 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
def foreach( def foreach(
self: ArrayVar[Sequence[INNER_ARRAY_VAR]], self: ArrayVar[Sequence[INNER_ARRAY_VAR]],
fn: Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR] fn: Callable[[Var[INNER_ARRAY_VAR], NumberVar[int]], ANOTHER_ARRAY_VAR]
| Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR]
| Callable[[], ANOTHER_ARRAY_VAR], | Callable[[], ANOTHER_ARRAY_VAR],
) -> ArrayVar[Sequence[ANOTHER_ARRAY_VAR]]: ) -> ArrayVar[Sequence[ANOTHER_ARRAY_VAR]]:
"""Apply a function to each element of the array. """Apply a function to each element of the array.
@ -987,21 +989,36 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
Raises: Raises:
VarTypeError: If the function takes more than one argument. VarTypeError: If the function takes more than one argument.
""" """
from reflex.state import ComponentState
from .function import ArgsFunctionOperation from .function import ArgsFunctionOperation
if not callable(fn): if not callable(fn):
raise_unsupported_operand_types("foreach", (type(self), type(fn))) raise_unsupported_operand_types("foreach", (type(self), type(fn)))
# get the number of arguments of the function # get the number of arguments of the function
num_args = len(inspect.signature(fn).parameters) num_args = len(inspect.signature(fn).parameters)
if num_args > 1: if num_args > 2:
raise VarTypeError( raise VarTypeError(
"The function passed to foreach should take at most one argument." "The function passed to foreach should take at most two arguments."
)
if (
hasattr(fn, "__qualname__")
and fn.__qualname__ == ComponentState.create.__qualname__
):
raise TypeError(
"Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet."
) )
if num_args == 0: if num_args == 0:
return_value = fn() # type: ignore fn_result = fn() # pyright: ignore [reportCallIssue]
return_value = Var.create(fn_result)
simple_function_var: FunctionVar[ReflexCallable[[], ANOTHER_ARRAY_VAR]] = ( simple_function_var: FunctionVar[ReflexCallable[[], ANOTHER_ARRAY_VAR]] = (
ArgsFunctionOperation.create((), return_value) ArgsFunctionOperation.create(
(),
return_value,
_var_type=ReflexCallable[[], return_value._var_type],
)
) )
return map_array_operation(self, simple_function_var).guess_type() return map_array_operation(self, simple_function_var).guess_type()
@ -1021,11 +1038,40 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)):
).guess_type(), ).guess_type(),
) )
if num_args == 1:
fn_result = fn(first_arg) # pyright: ignore [reportCallIssue]
return_value = Var.create(fn_result)
function_var = cast(
Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]],
ArgsFunctionOperation.create(
(arg_name,),
return_value,
_var_type=ReflexCallable[[first_arg_type], return_value._var_type],
),
)
return map_array_operation.call(self, function_var).guess_type()
second_arg = cast(
NumberVar[int],
Var(
_js_expr=get_unique_variable_name(),
_var_type=int,
).guess_type(),
)
fn_result = fn(first_arg, second_arg) # pyright: ignore [reportCallIssue]
return_value = Var.create(fn_result)
function_var = cast( function_var = cast(
Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]], Var[ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]],
ArgsFunctionOperation.create( ArgsFunctionOperation.create(
(arg_name,), (arg_name, second_arg._js_expr),
Var.create(fn(first_arg)), # type: ignore return_value,
_var_type=ReflexCallable[[first_arg_type, int], return_value._var_type],
), ),
) )

View File

@ -6,16 +6,11 @@ import pytest
from reflex import el from reflex import el
from reflex.base import Base from reflex.base import Base
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.core.foreach import ( from reflex.components.core.foreach import ForeachVarError, foreach
Foreach,
ForeachRenderError,
ForeachVarError,
foreach,
)
from reflex.components.radix.themes.layout.box import box from reflex.components.radix.themes.layout.box import box
from reflex.components.radix.themes.typography.text import text from reflex.components.radix.themes.typography.text import text
from reflex.state import BaseState, ComponentState from reflex.state import BaseState, ComponentState
from reflex.vars.base import Var from reflex.utils.exceptions import VarTypeError
from reflex.vars.number import NumberVar from reflex.vars.number import NumberVar
from reflex.vars.sequence import ArrayVar from reflex.vars.sequence import ArrayVar
@ -141,143 +136,35 @@ def display_color_index_tuple(color):
seen_index_vars = set() seen_index_vars = set()
@pytest.mark.parametrize(
"state_var, render_fn, render_dict",
[
(
ForEachState.colors_list,
display_color,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_list",
"iterable_type": "list",
},
),
(
ForEachState.colors_dict_list,
display_color_name,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list",
"iterable_type": "list",
},
),
(
ForEachState.colors_nested_dict_list,
display_shade,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list",
"iterable_type": "list",
},
),
(
ForEachState.primary_color,
display_primary_colors,
{
"iterable_state": f"{ForEachState.get_full_name()}.primary_color",
"iterable_type": "dict",
},
),
(
ForEachState.color_with_shades,
display_color_with_shades,
{
"iterable_state": f"{ForEachState.get_full_name()}.color_with_shades",
"iterable_type": "dict",
},
),
(
ForEachState.nested_colors_with_shades,
display_nested_color_with_shades,
{
"iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades",
"iterable_type": "dict",
},
),
(
ForEachState.nested_colors_with_shades,
display_nested_color_with_shades_v2,
{
"iterable_state": f"{ForEachState.get_full_name()}.nested_colors_with_shades",
"iterable_type": "dict",
},
),
(
ForEachState.color_tuple,
display_color_tuple,
{
"iterable_state": f"{ForEachState.get_full_name()}.color_tuple",
"iterable_type": "tuple",
},
),
(
ForEachState.colors_set,
display_colors_set,
{
"iterable_state": f"{ForEachState.get_full_name()}.colors_set",
"iterable_type": "set",
},
),
(
ForEachState.nested_colors_list,
lambda el, i: display_nested_list_element(el, i),
{
"iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list",
"iterable_type": "list",
},
),
(
ForEachState.color_index_tuple,
display_color_index_tuple,
{
"iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple",
"iterable_type": "tuple",
},
),
],
)
def test_foreach_render(state_var, render_fn, render_dict):
"""Test that the foreach component renders without error.
Args:
state_var: the state var.
render_fn: The render callable
render_dict: return dict on calling `component.render`
"""
component = Foreach.create(state_var, render_fn)
rend = component.render()
assert rend["iterable_state"] == render_dict["iterable_state"]
assert rend["iterable_type"] == render_dict["iterable_type"]
# Make sure the index vars are unique.
arg_index = rend["arg_index"]
assert isinstance(arg_index, Var)
assert arg_index._js_expr not in seen_index_vars
assert arg_index._var_type is int
seen_index_vars.add(arg_index._js_expr)
def test_foreach_bad_annotations(): def test_foreach_bad_annotations():
"""Test that the foreach component raises a ForeachVarError if the iterable is of type Any.""" """Test that the foreach component raises a ForeachVarError if the iterable is of type Any."""
with pytest.raises(ForeachVarError): with pytest.raises(ForeachVarError):
Foreach.create( foreach(
ForEachState.bad_annotation_list, ForEachState.bad_annotation_list,
lambda sublist: Foreach.create(sublist, lambda color: text(color)), lambda sublist: foreach(sublist, lambda color: text(color)),
) )
def test_foreach_no_param_in_signature(): def test_foreach_no_param_in_signature():
"""Test that the foreach component raises a ForeachRenderError if no parameters are passed.""" """Test that the foreach component DOES NOT raise an error if no parameters are passed."""
with pytest.raises(ForeachRenderError): foreach(
Foreach.create( ForEachState.colors_list,
ForEachState.colors_list, lambda: text("color"),
lambda: text("color"), )
)
def test_foreach_with_index():
"""Test that the foreach component works with an index."""
foreach(
ForEachState.colors_list,
lambda color, index: text(color, index),
)
def test_foreach_too_many_params_in_signature(): def test_foreach_too_many_params_in_signature():
"""Test that the foreach component raises a ForeachRenderError if too many parameters are passed.""" """Test that the foreach component raises a ForeachRenderError if too many parameters are passed."""
with pytest.raises(ForeachRenderError): with pytest.raises(VarTypeError):
Foreach.create( foreach(
ForEachState.colors_list, ForEachState.colors_list,
lambda color, index, extra: text(color), lambda color, index, extra: text(color),
) )
@ -292,13 +179,13 @@ def test_foreach_component_styles():
) )
) )
component._add_style_recursive({box: {"color": "red"}}) component._add_style_recursive({box: {"color": "red"}})
assert 'css={({ ["color"] : "red" })}' in str(component) assert '{ ["css"] : ({ ["color"] : "red" }) }' in str(component)
def test_foreach_component_state(): def test_foreach_component_state():
"""Test that using a component state to render in the foreach raises an error.""" """Test that using a component state to render in the foreach raises an error."""
with pytest.raises(TypeError): with pytest.raises(TypeError):
Foreach.create( foreach(
ForEachState.colors_list, ForEachState.colors_list,
ComponentStateTest.create, ComponentStateTest.create,
) )
@ -306,7 +193,7 @@ def test_foreach_component_state():
def test_foreach_default_factory(): def test_foreach_default_factory():
"""Test that the default factory is called.""" """Test that the default factory is called."""
_ = Foreach.create( _ = foreach(
ForEachState.default_factory_list, ForEachState.default_factory_list,
lambda tag: text(tag.name), lambda tag: text(tag.name),
) )

View File

@ -1446,7 +1446,6 @@ def test_instantiate_all_components():
untested_components = { untested_components = {
"Card", "Card",
"DebounceInput", "DebounceInput",
"Foreach",
"FormControl", "FormControl",
"Html", "Html",
"Icon", "Icon",
@ -2147,14 +2146,11 @@ def test_add_style_foreach():
page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i))) page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i)))
page._add_style_recursive(Style()) page._add_style_recursive(Style())
# Expect only a single child of the foreach on the python side
assert len(page.children[0].children) == 1
# Expect the style to be added to the child of the foreach # Expect the style to be added to the child of the foreach
assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0]) assert '({ ["css"] : ({ ["color"] : "red" }) }),' in str(page.children[0])
# Expect only one instance of this CSS dict in the rendered page # Expect only one instance of this CSS dict in the rendered page
assert str(page).count('css={({ ["color"] : "red" })}') == 1 assert str(page).count('({ ["css"] : ({ ["color"] : "red" }) }),') == 1
class TriggerState(rx.State): class TriggerState(rx.State):