diff --git a/reflex/components/component.py b/reflex/components/component.py index 889aa8d59..2eef065bd 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -668,20 +668,43 @@ class Component(BaseComponent, ABC): children: The children of the component. """ - skip_parentable = all(child._valid_parents == [] for child in children) - if not self._invalid_children and not self._valid_children and skip_parentable: + no_valid_parents_defined = all(child._valid_parents == [] for child in children) + if ( + not self._invalid_children + and not self._valid_children + and no_valid_parents_defined + ): return comp_name = type(self).__name__ + allowed_components = ["Fragment", "Foreach", "Cond", "Match"] - def validate_invalid_child(child_name): - if child_name in self._invalid_children: + def validate_child(child): + child_name = type(child).__name__ + + # Iterate through the immediate children of fragment + if child_name == "Fragment": + for c in child.children: + validate_child(c) + + if child_name == "Cond": + validate_child(child.comp1) + validate_child(child.comp2) + + if child_name == "Match": + for cases in child.match_cases: + validate_child(cases[-1]) + validate_child(child.default) + + if self._invalid_children and child_name in self._invalid_children: raise ValueError( f"The component `{comp_name}` cannot have `{child_name}` as a child component" ) - def validate_valid_child(child_name): - if child_name not in self._valid_children: + if self._valid_children and child_name not in [ + *self._valid_children, + *allowed_components, + ]: valid_child_list = ", ".join( [f"`{v_child}`" for v_child in self._valid_children] ) @@ -689,26 +712,19 @@ class Component(BaseComponent, ABC): f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead." ) - def validate_vaild_parent(child_name, valid_parents): - if comp_name not in valid_parents: + if child._valid_parents and comp_name not in [ + *child._valid_parents, + *allowed_components, + ]: valid_parent_list = ", ".join( - [f"`{v_parent}`" for v_parent in valid_parents] + [f"`{v_parent}`" for v_parent in child._valid_parents] ) raise ValueError( f"The component `{child_name}` can only be a child of the components: {valid_parent_list}. Got `{comp_name}` instead." ) for child in children: - name = type(child).__name__ - - if self._invalid_children: - validate_invalid_child(name) - - if self._valid_children: - validate_valid_child(name) - - if child._valid_parents: - validate_vaild_parent(name, child._valid_parents) + validate_child(child) @staticmethod def _get_vars_from_event_triggers( diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 3c0c09d7e..63ec6b31e 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -948,3 +948,235 @@ def test_instantiate_all_components(): component = getattr(rx, component_name) if isinstance(component, type) and issubclass(component, Component): component.create() + + +class InvalidParentComponent(Component): + """Invalid Parent Component.""" + + ... + + +class ValidComponent1(Component): + """Test valid component.""" + + _valid_children = ["ValidComponent2"] + + +class ValidComponent2(Component): + """Test valid component.""" + + ... + + +class ValidComponent3(Component): + """Test valid component.""" + + _valid_parents = ["ValidComponent2"] + + +class ValidComponent4(Component): + """Test valid component.""" + + _invalid_children = ["InvalidComponent"] + + +class InvalidComponent(Component): + """Test invalid component.""" + + ... + + +valid_component1 = ValidComponent1.create +valid_component2 = ValidComponent2.create +invalid_component = InvalidComponent.create +valid_component3 = ValidComponent3.create +invalid_parent = InvalidParentComponent.create +valid_component4 = ValidComponent4.create + + +def test_validate_valid_children(): + valid_component1(valid_component2()) + valid_component1( + rx.fragment(valid_component2()), + ) + valid_component1( + rx.fragment( + rx.fragment( + rx.fragment(valid_component2()), + ), + ), + ) + + valid_component1( + rx.cond( # type: ignore + True, + rx.fragment(valid_component2()), + rx.fragment( + rx.foreach(Var.create([1, 2, 3]), lambda x: valid_component2(x)) # type: ignore + ), + ) + ) + + valid_component1( + rx.cond( + True, + valid_component2(), + rx.fragment( + rx.match( + "condition", + ("first", valid_component2()), + rx.fragment(valid_component2(rx.text("default"))), + ) + ), + ) + ) + + valid_component1( + rx.match( + "condition", + ("first", valid_component2()), + ("second", "third", rx.fragment(valid_component2())), + ( + "fourth", + rx.cond(True, valid_component2(), rx.fragment(valid_component2())), + ), + ( + "fifth", + rx.match( + "nested_condition", + ("nested_first", valid_component2()), + rx.fragment(valid_component2()), + ), + valid_component2(), + ), + ) + ) + + +def test_validate_valid_parents(): + valid_component2(valid_component3()) + valid_component2( + rx.fragment(valid_component3()), + ) + valid_component1( + rx.fragment( + valid_component2( + rx.fragment(valid_component3()), + ), + ), + ) + + valid_component2( + rx.cond( # type: ignore + True, + rx.fragment(valid_component3()), + rx.fragment( + rx.foreach( + Var.create([1, 2, 3]), # type: ignore + lambda x: valid_component2(valid_component3(x)), + ) + ), + ) + ) + + valid_component2( + rx.cond( + True, + valid_component3(), + rx.fragment( + rx.match( + "condition", + ("first", valid_component3()), + rx.fragment(valid_component3(rx.text("default"))), + ) + ), + ) + ) + + valid_component2( + rx.match( + "condition", + ("first", valid_component3()), + ("second", "third", rx.fragment(valid_component3())), + ( + "fourth", + rx.cond(True, valid_component3(), rx.fragment(valid_component3())), + ), + ( + "fifth", + rx.match( + "nested_condition", + ("nested_first", valid_component3()), + rx.fragment(valid_component3()), + ), + valid_component3(), + ), + ) + ) + + +def test_validate_invalid_children(): + with pytest.raises(ValueError): + valid_component4(invalid_component()) + + with pytest.raises(ValueError): + valid_component4( + rx.fragment(invalid_component()), + ) + + with pytest.raises(ValueError): + valid_component2( + rx.fragment( + valid_component4( + rx.fragment(invalid_component()), + ), + ), + ) + + with pytest.raises(ValueError): + valid_component4( + rx.cond( # type: ignore + True, + rx.fragment(invalid_component()), + rx.fragment( + rx.foreach(Var.create([1, 2, 3]), lambda x: invalid_component(x)) # type: ignore + ), + ) + ) + + with pytest.raises(ValueError): + valid_component4( + rx.cond( + True, + invalid_component(), + rx.fragment( + rx.match( + "condition", + ("first", invalid_component()), + rx.fragment(invalid_component(rx.text("default"))), + ) + ), + ) + ) + + with pytest.raises(ValueError): + valid_component4( + rx.match( + "condition", + ("first", invalid_component()), + ("second", "third", rx.fragment(invalid_component())), + ( + "fourth", + rx.cond(True, invalid_component(), rx.fragment(valid_component2())), + ), + ( + "fifth", + rx.match( + "nested_condition", + ("nested_first", invalid_component()), + rx.fragment(invalid_component()), + ), + invalid_component(), + ), + ) + )