Fix custom components special props (#1956)

This commit is contained in:
Nikhil Rao 2023-10-12 15:27:41 -07:00 committed by GitHub
parent b13e9c92e3
commit 7019708638
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 11 deletions

View File

@ -540,7 +540,7 @@ class Component(Base, ABC):
if self.valid_children:
validate_valid_child(name)
def _get_custom_code(self) -> Optional[str]:
def _get_custom_code(self) -> str | None:
"""Get custom code for the component.
Returns:
@ -569,7 +569,7 @@ class Component(Base, ABC):
# Return the code.
return code
def _get_dynamic_imports(self) -> Optional[str]:
def _get_dynamic_imports(self) -> str | None:
"""Get dynamic import for the component.
Returns:
@ -667,7 +667,7 @@ class Component(Base, ABC):
if hook
)
def _get_hooks(self) -> Optional[str]:
def _get_hooks(self) -> str | None:
"""Get the React hooks for this component.
Downstream components should override this method to add their own hooks.
@ -697,7 +697,7 @@ class Component(Base, ABC):
return code
def get_ref(self) -> Optional[str]:
def get_ref(self) -> str | None:
"""Get the name of the ref for the component.
Returns:
@ -723,7 +723,7 @@ class Component(Base, ABC):
return refs
def get_custom_components(
self, seen: Optional[Set[str]] = None
self, seen: set[str] | None = None
) -> Set[CustomComponent]:
"""Get all the custom components used by the component.
@ -846,7 +846,7 @@ class CustomComponent(Component):
return set()
def get_custom_components(
self, seen: Optional[Set[str]] = None
self, seen: set[str] | None = None
) -> Set[CustomComponent]:
"""Get all the custom components used by the component.
@ -875,7 +875,10 @@ class CustomComponent(Component):
Returns:
The tag to render.
"""
return Tag(name=self.tag).add_props(**self.props)
return Tag(
name=self.tag if not self.alias else self.alias,
special_props=self.special_props,
).add_props(**self.props)
def get_prop_vars(self) -> List[BaseVar]:
"""Get the prop vars.
@ -914,6 +917,8 @@ def custom_component(
@wraps(component_fn)
def wrapper(*children, **props) -> CustomComponent:
# Remove the children from the props.
props.pop("children", None)
return CustomComponent(component_fn=component_fn, children=children, **props)
return wrapper

View File

@ -6,7 +6,7 @@ import textwrap
from typing import Any, Callable, Dict, Union
from reflex.compiler import utils
from reflex.components.component import Component
from reflex.components.component import Component, CustomComponent
from reflex.components.datadisplay.list import ListItem, OrderedList, UnorderedList
from reflex.components.navigation import Link
from reflex.components.tags.tag import Tag
@ -19,6 +19,7 @@ from reflex.vars import ImportVar, Var
# Special vars used in the component map.
_CHILDREN = Var.create_safe("children", is_local=False)
_PROPS = Var.create_safe("...props", is_local=False)
_MOCK_ARG = Var.create_safe("")
# Special remark plugins.
_REMARK_MATH = Var.create_safe("remarkMath", is_local=False)
@ -122,6 +123,25 @@ class Markdown(Component):
# Create the component.
return super().create(src, component_map=component_map, **props)
def get_custom_components(
self, seen: set[str] | None = None
) -> set[CustomComponent]:
"""Get all the custom components used by the component.
Args:
seen: The tags of the components that have already been seen.
Returns:
The set of custom components.
"""
custom_components = super().get_custom_components(seen=seen)
# Get the custom components for each tag.
for component in self.component_map.values():
custom_components |= component(_MOCK_ARG).get_custom_components(seen=seen)
return custom_components
def _get_imports(self) -> imports.ImportDict:
# Import here to avoid circular imports.
from reflex.components.datadisplay.code import Code, CodeBlock
@ -145,9 +165,7 @@ class Markdown(Component):
# Get the imports for each component.
for component in self.component_map.values():
imports = utils.merge_imports(
imports, component(Var.create("")).get_imports()
)
imports = utils.merge_imports(imports, component(_MOCK_ARG).get_imports())
# Get the imports for the code components.
imports = utils.merge_imports(