Fix custom components special props (#1956)

This commit is contained in:
Nikhil Rao 2023-10-12 15:27:41 -07:00 committed by GitHub
parent b13e9c92e3
commit 7019708638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 11 deletions

View File

@ -540,7 +540,7 @@ class Component(Base, ABC):
if self.valid_children: if self.valid_children:
validate_valid_child(name) validate_valid_child(name)
def _get_custom_code(self) -> Optional[str]: def _get_custom_code(self) -> str | None:
"""Get custom code for the component. """Get custom code for the component.
Returns: Returns:
@ -569,7 +569,7 @@ class Component(Base, ABC):
# Return the code. # Return the code.
return code return code
def _get_dynamic_imports(self) -> Optional[str]: def _get_dynamic_imports(self) -> str | None:
"""Get dynamic import for the component. """Get dynamic import for the component.
Returns: Returns:
@ -667,7 +667,7 @@ class Component(Base, ABC):
if hook if hook
) )
def _get_hooks(self) -> Optional[str]: def _get_hooks(self) -> str | None:
"""Get the React hooks for this component. """Get the React hooks for this component.
Downstream components should override this method to add their own hooks. Downstream components should override this method to add their own hooks.
@ -697,7 +697,7 @@ class Component(Base, ABC):
return code return code
def get_ref(self) -> Optional[str]: def get_ref(self) -> str | None:
"""Get the name of the ref for the component. """Get the name of the ref for the component.
Returns: Returns:
@ -723,7 +723,7 @@ class Component(Base, ABC):
return refs return refs
def get_custom_components( def get_custom_components(
self, seen: Optional[Set[str]] = None self, seen: set[str] | None = None
) -> Set[CustomComponent]: ) -> Set[CustomComponent]:
"""Get all the custom components used by the component. """Get all the custom components used by the component.
@ -846,7 +846,7 @@ class CustomComponent(Component):
return set() return set()
def get_custom_components( def get_custom_components(
self, seen: Optional[Set[str]] = None self, seen: set[str] | None = None
) -> Set[CustomComponent]: ) -> Set[CustomComponent]:
"""Get all the custom components used by the component. """Get all the custom components used by the component.
@ -875,7 +875,10 @@ class CustomComponent(Component):
Returns: Returns:
The tag to render. The tag to render.
""" """
return Tag(name=self.tag).add_props(**self.props) return Tag(
name=self.tag if not self.alias else self.alias,
special_props=self.special_props,
).add_props(**self.props)
def get_prop_vars(self) -> List[BaseVar]: def get_prop_vars(self) -> List[BaseVar]:
"""Get the prop vars. """Get the prop vars.
@ -914,6 +917,8 @@ def custom_component(
@wraps(component_fn) @wraps(component_fn)
def wrapper(*children, **props) -> CustomComponent: def wrapper(*children, **props) -> CustomComponent:
# Remove the children from the props.
props.pop("children", None)
return CustomComponent(component_fn=component_fn, children=children, **props) return CustomComponent(component_fn=component_fn, children=children, **props)
return wrapper return wrapper

View File

@ -6,7 +6,7 @@ import textwrap
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
from reflex.compiler import utils from reflex.compiler import utils
from reflex.components.component import Component from reflex.components.component import Component, CustomComponent
from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList
from reflex.components.navigation import Link from reflex.components.navigation import Link
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
@ -19,6 +19,7 @@ from reflex.vars import ImportVar, Var
# Special vars used in the component map. # Special vars used in the component map.
_CHILDREN = Var.create_safe("children", is_local=False) _CHILDREN = Var.create_safe("children", is_local=False)
_PROPS = Var.create_safe("...props", is_local=False) _PROPS = Var.create_safe("...props", is_local=False)
_MOCK_ARG = Var.create_safe("")
# Special remark plugins. # Special remark plugins.
_REMARK_MATH = Var.create_safe("remarkMath", is_local=False) _REMARK_MATH = Var.create_safe("remarkMath", is_local=False)
@ -122,6 +123,25 @@ class Markdown(Component):
# Create the component. # Create the component.
return super().create(src, component_map=component_map, **props) return super().create(src, component_map=component_map, **props)
def get_custom_components(
self, seen: set[str] | None = None
) -> set[CustomComponent]:
"""Get all the custom components used by the component.
Args:
seen: The tags of the components that have already been seen.
Returns:
The set of custom components.
"""
custom_components = super().get_custom_components(seen=seen)
# Get the custom components for each tag.
for component in self.component_map.values():
custom_components |= component(_MOCK_ARG).get_custom_components(seen=seen)
return custom_components
def _get_imports(self) -> imports.ImportDict: def _get_imports(self) -> imports.ImportDict:
# Import here to avoid circular imports. # Import here to avoid circular imports.
from reflex.components.datadisplay.code import Code, CodeBlock from reflex.components.datadisplay.code import Code, CodeBlock
@ -145,9 +165,7 @@ class Markdown(Component):
# Get the imports for each component. # Get the imports for each component.
for component in self.component_map.values(): for component in self.component_map.values():
imports = utils.merge_imports( imports = utils.merge_imports(imports, component(_MOCK_ARG).get_imports())
imports, component(Var.create("")).get_imports()
)
# Get the imports for the code components. # Get the imports for the code components.
imports = utils.merge_imports( imports = utils.merge_imports(