Validate component children (#1647)

This commit is contained in:
Elijah Ahianyo 2023-08-23 22:56:27 +00:00 committed by GitHub
parent 457173eed7
commit 217a5806ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 21 deletions

View File

@ -66,6 +66,10 @@ class Component(Base, ABC):
# components that cannot be children # components that cannot be children
invalid_children: List[str] = [] invalid_children: List[str] = []
# components that are only allowed as children
valid_children: List[str] = []
# custom attribute # custom attribute
custom_attrs: Dict[str, str] = {} custom_attrs: Dict[str, str] = {}
@ -103,9 +107,10 @@ class Component(Base, ABC):
TypeError: If an invalid prop is passed. TypeError: If an invalid prop is passed.
""" """
# Set the id and children initially. # Set the id and children initially.
children = kwargs.get("children", [])
initial_kwargs = { initial_kwargs = {
"id": kwargs.get("id"), "id": kwargs.get("id"),
"children": kwargs.get("children", []), "children": children,
**{ **{
prop: Var.create(kwargs[prop]) prop: Var.create(kwargs[prop])
for prop in self.get_initial_props() for prop in self.get_initial_props()
@ -114,6 +119,8 @@ class Component(Base, ABC):
} }
super().__init__(**initial_kwargs) super().__init__(**initial_kwargs)
self._validate_component_children(children)
# Get the component fields, triggers, and props. # Get the component fields, triggers, and props.
fields = self.get_fields() fields = self.get_fields()
triggers = self.get_triggers() triggers = self.get_triggers()
@ -381,6 +388,7 @@ class Component(Base, ABC):
else Bare.create(contents=Var.create(child, is_string=True)) else Bare.create(contents=Var.create(child, is_string=True))
for child in children for child in children
] ]
return cls(children=children, **props) return cls(children=children, **props)
def _add_style(self, style): def _add_style(self, style):
@ -435,30 +443,44 @@ class Component(Base, ABC):
), ),
autofocus=self.autofocus, autofocus=self.autofocus,
) )
self._validate_component_children(
rendered_dict["name"], rendered_dict["children"]
)
return rendered_dict return rendered_dict
def _validate_component_children(self, comp_name: str, children: List[Dict]): def _validate_component_children(self, children: List[Component]):
"""Validate the children components. """Validate the children components.
Args: Args:
comp_name: name of the component. children: The children of the component.
children: list of children components.
Raises:
ValueError: when an unsupported component is matched.
""" """
if not self.invalid_children: if not self.invalid_children and not self.valid_children:
return return
for child in children:
name = child["name"] comp_name = type(self).__name__
if name in self.invalid_children:
def validate_invalid_child(child_name):
if child_name in self.invalid_children:
raise ValueError( raise ValueError(
f"The component `{comp_name.lower()}` cannot have `{name.lower()}` as a child component" f"The component `{comp_name}` cannot have `{child_name}` as a child component"
) )
def validate_valid_child(child_name):
if child_name not in self.valid_children:
valid_child_list = ", ".join(
[f"`{v_child}`" for v_child in self.valid_children]
)
raise ValueError(
f"The component `{comp_name}` only allows the components: {valid_child_list} as children. Got `{child_name}` instead."
)
for child in children:
name = type(child).__name__
if self.invalid_children:
validate_invalid_child(name)
if self.valid_children:
validate_valid_child(name)
def _get_custom_code(self) -> Optional[str]: def _get_custom_code(self) -> Optional[str]:
"""Get custom code for the component. """Get custom code for the component.

View File

@ -1,4 +1,5 @@
"""A button component.""" """A button component."""
from typing import List
from reflex.components.libs.chakra import ChakraComponent from reflex.components.libs.chakra import ChakraComponent
from reflex.vars import Var from reflex.vars import Var
@ -42,6 +43,9 @@ class Button(ChakraComponent):
# The type of button. # The type of button.
type_: Var[str] type_: Var[str]
# Components that are not allowed as children.
invalid_children: List[str] = ["Button", "MenuButton"]
class ButtonGroup(ChakraComponent): class ButtonGroup(ChakraComponent):
"""A group of buttons.""" """A group of buttons."""

View File

@ -1,6 +1,6 @@
"""Menu components.""" """Menu components."""
from typing import Set from typing import List, Set
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.libs.chakra import ChakraComponent from reflex.components.libs.chakra import ChakraComponent
@ -100,6 +100,9 @@ class MenuButton(ChakraComponent):
# The variant of the menu button. # The variant of the menu button.
variant: Var[str] variant: Var[str]
# Components that are not allowed as children.
invalid_children: List[str] = ["Button", "MenuButton"]
# The tag to use for the menu button. # The tag to use for the menu button.
as_: Var[str] as_: Var[str]

View File

@ -121,13 +121,47 @@ def component5() -> Type[Component]:
""" """
class TestComponent5(Component): class TestComponent5(Component):
tag = "Tag" tag = "RandomComponent"
invalid_children: List[str] = ["Text"] invalid_children: List[str] = ["Text"]
valid_children: List[str] = ["Text"]
return TestComponent5 return TestComponent5
@pytest.fixture
def component6() -> Type[Component]:
"""A test component.
Returns:
A test component.
"""
class TestComponent6(Component):
tag = "RandomComponent"
invalid_children: List[str] = ["Text"]
return TestComponent6
@pytest.fixture
def component7() -> Type[Component]:
"""A test component.
Returns:
A test component.
"""
class TestComponent7(Component):
tag = "RandomComponent"
valid_children: List[str] = ["Text"]
return TestComponent7
@pytest.fixture @pytest.fixture
def on_click1() -> EventHandler: def on_click1() -> EventHandler:
"""A sample on click function. """A sample on click function.
@ -461,16 +495,40 @@ def test_get_hooks_nested2(component3, component4):
) )
def test_unsupported_child_components(component5): @pytest.mark.parametrize("fixture", ["component5", "component6"])
"""Test that a value error is raised when an unsupported component is provided as a child. def test_unsupported_child_components(fixture, request):
"""Test that a value error is raised when an unsupported component (a child component found in the
component's invalid children list) is provided as a child.
Args: Args:
component5: the test component fixture: the test component as a fixture.
request: Pytest request.
""" """
component = request.getfixturevalue(fixture)
with pytest.raises(ValueError) as err: with pytest.raises(ValueError) as err:
comp = component5.create(rx.text("testing component")) comp = component.create(rx.text("testing component"))
comp.render() comp.render()
assert ( assert (
err.value.args[0] err.value.args[0]
== f"The component `tag` cannot have `text` as a child component" == f"The component `{component.__name__}` cannot have `Text` as a child component"
)
@pytest.mark.parametrize("fixture", ["component5", "component7"])
def test_component_with_only_valid_children(fixture, request):
"""Test that a value error is raised when an unsupported component (a child component not found in the
component's valid children list) is provided as a child.
Args:
fixture: the test component as a fixture.
request: Pytest request.
"""
component = request.getfixturevalue(fixture)
with pytest.raises(ValueError) as err:
comp = component.create(rx.box("testing component"))
comp.render()
assert (
err.value.args[0]
== f"The component `{component.__name__}` only allows the components: `Text` as children. "
f"Got `Box` instead."
) )