diff --git a/reflex/components/component.py b/reflex/components/component.py index 0f4756fd0..2b8f89bf3 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -66,6 +66,10 @@ class Component(Base, ABC): # components that cannot be children invalid_children: List[str] = [] + + # components that are only allowed as children + valid_children: List[str] = [] + # custom attribute custom_attrs: Dict[str, str] = {} @@ -103,9 +107,10 @@ class Component(Base, ABC): TypeError: If an invalid prop is passed. """ # Set the id and children initially. + children = kwargs.get("children", []) initial_kwargs = { "id": kwargs.get("id"), - "children": kwargs.get("children", []), + "children": children, **{ prop: Var.create(kwargs[prop]) for prop in self.get_initial_props() @@ -114,6 +119,8 @@ class Component(Base, ABC): } super().__init__(**initial_kwargs) + self._validate_component_children(children) + # Get the component fields, triggers, and props. fields = self.get_fields() triggers = self.get_triggers() @@ -381,6 +388,7 @@ class Component(Base, ABC): else Bare.create(contents=Var.create(child, is_string=True)) for child in children ] + return cls(children=children, **props) def _add_style(self, style): @@ -435,30 +443,44 @@ class Component(Base, ABC): ), autofocus=self.autofocus, ) - self._validate_component_children( - rendered_dict["name"], rendered_dict["children"] - ) return rendered_dict - def _validate_component_children(self, comp_name: str, children: List[Dict]): + def _validate_component_children(self, children: List[Component]): """Validate the children components. Args: - comp_name: name of the component. - children: list of children components. + children: The children of the component. - Raises: - ValueError: when an unsupported component is matched. """ - if not self.invalid_children: + if not self.invalid_children and not self.valid_children: return - for child in children: - name = child["name"] - if name in self.invalid_children: + + comp_name = type(self).__name__ + + def validate_invalid_child(child_name): + if child_name in self.invalid_children: raise ValueError( - f"The component `{comp_name.lower()}` cannot have `{name.lower()}` as a child component" + f"The component `{comp_name}` cannot have `{child_name}` as a child component" ) + def validate_valid_child(child_name): + if child_name not in self.valid_children: + valid_child_list = ", ".join( + [f"`{v_child}`" for v_child in self.valid_children] + ) + raise ValueError( + f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead." + ) + + for child in children: + name = type(child).__name__ + + if self.invalid_children: + validate_invalid_child(name) + + if self.valid_children: + validate_valid_child(name) + def _get_custom_code(self) -> Optional[str]: """Get custom code for the component. diff --git a/reflex/components/forms/button.py b/reflex/components/forms/button.py index 3ec573bd9..5c336f8e8 100644 --- a/reflex/components/forms/button.py +++ b/reflex/components/forms/button.py @@ -1,4 +1,5 @@ """A button component.""" +from typing import List from reflex.components.libs.chakra import ChakraComponent from reflex.vars import Var @@ -42,6 +43,9 @@ class Button(ChakraComponent): # The type of button. type_: Var[str] + # Components that are not allowed as children. + invalid_children: List[str] = ["Button", "MenuButton"] + class ButtonGroup(ChakraComponent): """A group of buttons.""" diff --git a/reflex/components/overlay/menu.py b/reflex/components/overlay/menu.py index 2c23407b0..f477b36cc 100644 --- a/reflex/components/overlay/menu.py +++ b/reflex/components/overlay/menu.py @@ -1,6 +1,6 @@ """Menu components.""" -from typing import Set +from typing import List, Set from reflex.components.component import Component from reflex.components.libs.chakra import ChakraComponent @@ -100,6 +100,9 @@ class MenuButton(ChakraComponent): # The variant of the menu button. variant: Var[str] + # Components that are not allowed as children. + invalid_children: List[str] = ["Button", "MenuButton"] + # The tag to use for the menu button. as_: Var[str] diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 23274b80d..c676e543a 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -121,13 +121,47 @@ def component5() -> Type[Component]: """ class TestComponent5(Component): - tag = "Tag" + tag = "RandomComponent" invalid_children: List[str] = ["Text"] + valid_children: List[str] = ["Text"] + return TestComponent5 +@pytest.fixture +def component6() -> Type[Component]: + """A test component. + + Returns: + A test component. + """ + + class TestComponent6(Component): + tag = "RandomComponent" + + invalid_children: List[str] = ["Text"] + + return TestComponent6 + + +@pytest.fixture +def component7() -> Type[Component]: + """A test component. + + Returns: + A test component. + """ + + class TestComponent7(Component): + tag = "RandomComponent" + + valid_children: List[str] = ["Text"] + + return TestComponent7 + + @pytest.fixture def on_click1() -> EventHandler: """A sample on click function. @@ -461,16 +495,40 @@ def test_get_hooks_nested2(component3, component4): ) -def test_unsupported_child_components(component5): - """Test that a value error is raised when an unsupported component is provided as a child. +@pytest.mark.parametrize("fixture", ["component5", "component6"]) +def test_unsupported_child_components(fixture, request): + """Test that a value error is raised when an unsupported component (a child component found in the + component's invalid children list) is provided as a child. Args: - component5: the test component + fixture: the test component as a fixture. + request: Pytest request. """ + component = request.getfixturevalue(fixture) with pytest.raises(ValueError) as err: - comp = component5.create(rx.text("testing component")) + comp = component.create(rx.text("testing component")) comp.render() assert ( err.value.args[0] - == f"The component `tag` cannot have `text` as a child component" + == f"The component `{component.__name__}` cannot have `Text` as a child component" + ) + + +@pytest.mark.parametrize("fixture", ["component5", "component7"]) +def test_component_with_only_valid_children(fixture, request): + """Test that a value error is raised when an unsupported component (a child component not found in the + component's valid children list) is provided as a child. + + Args: + fixture: the test component as a fixture. + request: Pytest request. + """ + component = request.getfixturevalue(fixture) + with pytest.raises(ValueError) as err: + comp = component.create(rx.box("testing component")) + comp.render() + assert ( + err.value.args[0] + == f"The component `{component.__name__}` only allows the components: `Text` as children. " + f"Got `Box` instead." )