components deserve to be first class props (#4827)

* components deserve to be first class props

* default back to {}

* smarter yield

* how much does caching help?

* only hit the slower path on _are_fields_known

* remove the cache thingy

* cache the inner _get_component_prop_names

* oops

* dang it darglint

* refactor things a bit

* fix events
This commit is contained in:
Khaleel Al-Adhami 2025-02-19 11:27:33 -08:00 committed by GitHub
parent 762d975a87
commit abab18e165
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 259 additions and 137 deletions

View File

@ -2,9 +2,9 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Iterator from typing import Any, Iterator, Sequence
from reflex.components.component import Component, LiteralComponentVar from reflex.components.component import BaseComponent, Component, ComponentStyle
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless from reflex.components.tags.tagless import Tagless
from reflex.config import PerformanceMode, environment from reflex.config import PerformanceMode, environment
@ -12,7 +12,7 @@ from reflex.utils import console
from reflex.utils.decorator import once from reflex.utils.decorator import once
from reflex.utils.imports import ParsedImportDict from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData from reflex.vars.base import GLOBAL_CACHE, VarData
from reflex.vars.sequence import LiteralStringVar from reflex.vars.sequence import LiteralStringVar
@ -47,6 +47,11 @@ def validate_str(value: str):
) )
def _components_from_var(var: Var) -> Sequence[BaseComponent]:
var_data = var._get_all_var_data()
return var_data.components if var_data else ()
class Bare(Component): class Bare(Component):
"""A component with no tag.""" """A component with no tag."""
@ -80,8 +85,9 @@ class Bare(Component):
The hooks for the component. The hooks for the component.
""" """
hooks = super()._get_all_hooks_internal() hooks = super()._get_all_hooks_internal()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
hooks |= self.contents._var_value._get_all_hooks_internal() for component in _components_from_var(self.contents):
hooks |= component._get_all_hooks_internal()
return hooks return hooks
def _get_all_hooks(self) -> dict[str, VarData | None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
@ -91,18 +97,22 @@ class Bare(Component):
The hooks for the component. The hooks for the component.
""" """
hooks = super()._get_all_hooks() hooks = super()._get_all_hooks()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
hooks |= self.contents._var_value._get_all_hooks() for component in _components_from_var(self.contents):
hooks |= component._get_all_hooks()
return hooks return hooks
def _get_all_imports(self) -> ParsedImportDict: def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Include the imports for the component. """Include the imports for the component.
Args:
collapse: Whether to collapse the imports.
Returns: Returns:
The imports for the component. The imports for the component.
""" """
imports = super()._get_all_imports() imports = super()._get_all_imports(collapse=collapse)
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data() var_data = self.contents._get_all_var_data()
if var_data: if var_data:
imports |= {k: list(v) for k, v in var_data.imports} imports |= {k: list(v) for k, v in var_data.imports}
@ -115,8 +125,9 @@ class Bare(Component):
The dynamic imports. The dynamic imports.
""" """
dynamic_imports = super()._get_all_dynamic_imports() dynamic_imports = super()._get_all_dynamic_imports()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() for component in _components_from_var(self.contents):
dynamic_imports |= component._get_all_dynamic_imports()
return dynamic_imports return dynamic_imports
def _get_all_custom_code(self) -> set[str]: def _get_all_custom_code(self) -> set[str]:
@ -126,10 +137,24 @@ class Bare(Component):
The custom code. The custom code.
""" """
custom_code = super()._get_all_custom_code() custom_code = super()._get_all_custom_code()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
custom_code |= self.contents._var_value._get_all_custom_code() for component in _components_from_var(self.contents):
custom_code |= component._get_all_custom_code()
return custom_code return custom_code
def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
"""Get the components that should be wrapped in the app.
Returns:
The components that should be wrapped in the app.
"""
app_wrap_components = super()._get_all_app_wrap_components()
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
if isinstance(component, Component):
app_wrap_components |= component._get_all_app_wrap_components()
return app_wrap_components
def _get_all_refs(self) -> set[str]: def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component. """Get the refs for the children of the component.
@ -137,8 +162,9 @@ class Bare(Component):
The refs for the children. The refs for the children.
""" """
refs = super()._get_all_refs() refs = super()._get_all_refs()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
refs |= self.contents._var_value._get_all_refs() for component in _components_from_var(self.contents):
refs |= component._get_all_refs()
return refs return refs
def _render(self) -> Tag: def _render(self) -> Tag:
@ -148,6 +174,33 @@ class Bare(Component):
return Tagless(contents=f"{{{self.contents!s}}}") return Tagless(contents=f"{{{self.contents!s}}}")
return Tagless(contents=str(self.contents)) return Tagless(contents=str(self.contents))
def _add_style_recursive(
self, style: ComponentStyle, theme: Component | None = None
) -> Component:
"""Add style to the component and its children.
Args:
style: The style to add.
theme: The theme to add.
Returns:
The component with the style added.
"""
new_self = super()._add_style_recursive(style, theme)
are_components_touched = False
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
if isinstance(component, Component):
component._add_style_recursive(style, theme)
are_components_touched = True
if are_components_touched:
GLOBAL_CACHE.clear()
return new_self
def _get_vars( def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]: ) -> Iterator[Var]:

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import copy import copy
import dataclasses import dataclasses
import inspect
import typing import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import lru_cache, wraps from functools import lru_cache, wraps
@ -21,6 +22,8 @@ from typing import (
Set, Set,
Type, Type,
Union, Union,
get_args,
get_origin,
) )
from typing_extensions import Self from typing_extensions import Self
@ -43,6 +46,7 @@ from reflex.constants import (
from reflex.constants.compiler import SpecialAttributes from reflex.constants.compiler import SpecialAttributes
from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.event import ( from reflex.event import (
EventActionsMixin,
EventCallback, EventCallback,
EventChain, EventChain,
EventHandler, EventHandler,
@ -191,6 +195,25 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
return types._isinstance(obj, type_hint, nested=1) return types._isinstance(obj, type_hint, nested=1)
def _components_from(
component_or_var: Union[BaseComponent, Var],
) -> tuple[BaseComponent, ...]:
"""Get the components from a component or Var.
Args:
component_or_var: The component or Var to get the components from.
Returns:
The components.
"""
if isinstance(component_or_var, Var):
var_data = component_or_var._get_all_var_data()
return var_data.components if var_data else ()
if isinstance(component_or_var, BaseComponent):
return (component_or_var,)
return ()
class Component(BaseComponent, ABC): class Component(BaseComponent, ABC):
"""A component with style, event trigger and other props.""" """A component with style, event trigger and other props."""
@ -489,7 +512,7 @@ class Component(BaseComponent, ABC):
# Remove any keys that were added as events. # Remove any keys that were added as events.
for key in kwargs["event_triggers"]: for key in kwargs["event_triggers"]:
del kwargs[key] kwargs.pop(key, None)
# Place data_ and aria_ attributes into custom_attrs # Place data_ and aria_ attributes into custom_attrs
special_attributes = tuple( special_attributes = tuple(
@ -666,12 +689,21 @@ class Component(BaseComponent, ABC):
return set() return set()
@classmethod @classmethod
@lru_cache(maxsize=None) def _are_fields_known(cls) -> bool:
def get_component_props(cls) -> set[str]: """Check if all fields are known at compile time. True for most components.
"""Get the props that expected a component as value.
Returns: Returns:
The components props. Whether all fields are known at compile time.
"""
return True
@classmethod
@lru_cache(maxsize=None)
def _get_component_prop_names(cls) -> Set[str]:
"""Get the names of the component props. NOTE: This assumes all fields are known.
Returns:
The names of the component props.
""" """
return { return {
name name
@ -680,6 +712,26 @@ class Component(BaseComponent, ABC):
and types._issubclass(field.outer_type_, Component) and types._issubclass(field.outer_type_, Component)
} }
def _get_components_in_props(self) -> Sequence[BaseComponent]:
"""Get the components in the props.
Returns:
The components in the props
"""
if self._are_fields_known():
return [
component
for name in self._get_component_prop_names()
for component in _components_from(getattr(self, name))
]
return [
component
for prop in self.get_props()
if (value := getattr(self, prop)) is not None
and isinstance(value, (BaseComponent, Var))
for component in _components_from(value)
]
@classmethod @classmethod
def create(cls, *children, **props) -> Self: def create(cls, *children, **props) -> Self:
"""Create the component. """Create the component.
@ -1136,6 +1188,9 @@ class Component(BaseComponent, ABC):
if custom_code is not None: if custom_code is not None:
code.add(custom_code) code.add(custom_code)
for component in self._get_components_in_props():
code |= component._get_all_custom_code()
# Add the custom code from add_custom_code method. # Add the custom code from add_custom_code method.
for clz in self._iter_parent_classes_with_method("add_custom_code"): for clz in self._iter_parent_classes_with_method("add_custom_code"):
for item in clz.add_custom_code(self): for item in clz.add_custom_code(self):
@ -1163,7 +1218,7 @@ class Component(BaseComponent, ABC):
The dynamic imports. The dynamic imports.
""" """
# Store the import in a set to avoid duplicates. # Store the import in a set to avoid duplicates.
dynamic_imports = set() dynamic_imports: set[str] = set()
# Get dynamic import for this component. # Get dynamic import for this component.
dynamic_import = self._get_dynamic_imports() dynamic_import = self._get_dynamic_imports()
@ -1174,25 +1229,12 @@ class Component(BaseComponent, ABC):
for child in self.children: for child in self.children:
dynamic_imports |= child._get_all_dynamic_imports() dynamic_imports |= child._get_all_dynamic_imports()
for prop in self.get_component_props(): for component in self._get_components_in_props():
if getattr(self, prop) is not None: dynamic_imports |= component._get_all_dynamic_imports()
dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports()
# Return the dynamic imports # Return the dynamic imports
return dynamic_imports return dynamic_imports
def _get_props_imports(self) -> List[ParsedImportDict]:
"""Get the imports needed for components props.
Returns:
The imports for the components props of the component.
"""
return [
getattr(self, prop)._get_all_imports()
for prop in self.get_component_props()
if getattr(self, prop) is not None
]
def _should_transpile(self, dep: str | None) -> bool: def _should_transpile(self, dep: str | None) -> bool:
"""Check if a dependency should be transpiled. """Check if a dependency should be transpiled.
@ -1303,7 +1345,6 @@ class Component(BaseComponent, ABC):
) )
return imports.merge_imports( return imports.merge_imports(
*self._get_props_imports(),
self._get_dependencies_imports(), self._get_dependencies_imports(),
self._get_hooks_imports(), self._get_hooks_imports(),
_imports, _imports,
@ -1380,6 +1421,8 @@ class Component(BaseComponent, ABC):
for k in var_data.hooks for k in var_data.hooks
} }
) )
for component in var_data.components:
vars_hooks.update(component._get_all_hooks())
return vars_hooks return vars_hooks
def _get_events_hooks(self) -> dict[str, VarData | None]: def _get_events_hooks(self) -> dict[str, VarData | None]:
@ -1528,6 +1571,9 @@ class Component(BaseComponent, ABC):
refs.add(ref) refs.add(ref)
for child in self.children: for child in self.children:
refs |= child._get_all_refs() refs |= child._get_all_refs()
for component in self._get_components_in_props():
refs |= component._get_all_refs()
return refs return refs
def _get_all_custom_components( def _get_all_custom_components(
@ -1551,6 +1597,9 @@ class Component(BaseComponent, ABC):
if not isinstance(child, Component): if not isinstance(child, Component):
continue continue
custom_components |= child._get_all_custom_components(seen=seen) custom_components |= child._get_all_custom_components(seen=seen)
for component in self._get_components_in_props():
if isinstance(component, Component) and component.tag is not None:
custom_components |= component._get_all_custom_components(seen=seen)
return custom_components return custom_components
@property @property
@ -1614,17 +1663,65 @@ class CustomComponent(Component):
# The props of the component. # The props of the component.
props: Dict[str, Any] = {} props: Dict[str, Any] = {}
# Props that reference other components. def __init__(self, **kwargs):
component_props: Dict[str, Component] = {}
def __init__(self, *args, **kwargs):
"""Initialize the custom component. """Initialize the custom component.
Args: Args:
*args: The args to pass to the component.
**kwargs: The kwargs to pass to the component. **kwargs: The kwargs to pass to the component.
""" """
super().__init__(*args, **kwargs) component_fn = kwargs.get("component_fn")
# Set the props.
props_types = typing.get_type_hints(component_fn) if component_fn else {}
props = {key: value for key, value in kwargs.items() if key in props_types}
kwargs = {key: value for key, value in kwargs.items() if key not in props_types}
event_types = {
key
for key in props
if (
(get_origin((annotation := props_types.get(key))) or annotation)
== EventHandler
)
}
def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]:
type_ = props_types[key]
return (
args[0]
if (args := get_args(type_))
else (
annotation_args[1]
if get_origin(
(
annotation := inspect.getfullargspec(
component_fn
).annotations[key]
)
)
is typing.Annotated
and (annotation_args := get_args(annotation))
else no_args_event_spec
)
)
super().__init__(
event_triggers={
key: EventChain.create(
value=props[key],
args_spec=get_args_spec(key),
key=key,
)
for key in event_types
},
**kwargs,
)
to_camel_cased_props = {
format.to_camel_case(key) for key in props if key not in event_types
}
self.get_props = lambda: to_camel_cased_props # pyright: ignore [reportIncompatibleVariableOverride]
# Unset the style. # Unset the style.
self.style = Style() self.style = Style()
@ -1632,51 +1729,36 @@ class CustomComponent(Component):
# Set the tag to the name of the function. # Set the tag to the name of the function.
self.tag = format.to_title_case(self.component_fn.__name__) self.tag = format.to_title_case(self.component_fn.__name__)
# Get the event triggers defined in the component declaration. for key, value in props.items():
event_triggers_in_component_declaration = self.get_event_triggers()
# Set the props.
props = typing.get_type_hints(self.component_fn)
for key, value in kwargs.items():
# Skip kwargs that are not props. # Skip kwargs that are not props.
if key not in props: if key not in props_types:
continue continue
camel_cased_key = format.to_camel_case(key)
# Get the type based on the annotation. # Get the type based on the annotation.
type_ = props[key] type_ = props_types[key]
# Handle event chains. # Handle event chains.
if types._issubclass(type_, EventChain): if types._issubclass(type_, EventActionsMixin):
value = EventChain.create( inspect.getfullargspec(component_fn).annotations[key]
value=value, self.props[camel_cased_key] = EventChain.create(
args_spec=event_triggers_in_component_declaration.get( value=value, args_spec=get_args_spec(key), key=key
key, no_args_event_spec
),
key=key,
) )
self.props[format.to_camel_case(key)] = value
continue continue
# Handle subclasses of Base. value = LiteralVar.create(value)
if isinstance(value, Base): self.props[camel_cased_key] = value
base_value = LiteralVar.create(value) setattr(self, camel_cased_key, value)
# Track hooks and imports associated with Component instances. @classmethod
if base_value is not None and isinstance(value, Component): def _are_fields_known(cls) -> bool:
self.component_props[key] = value """Check if the fields are known.
value = base_value._replace(
merge_var_data=VarData(
imports=value._get_all_imports(),
hooks=value._get_all_hooks(),
)
)
else:
value = base_value
else:
value = LiteralVar.create(value)
# Set the prop. Returns:
self.props[format.to_camel_case(key)] = value Whether the fields are known.
"""
return False
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Check if the component is equal to another. """Check if the component is equal to another.
@ -1698,7 +1780,7 @@ class CustomComponent(Component):
return hash(self.tag) return hash(self.tag)
@classmethod @classmethod
def get_props(cls) -> Set[str]: # pyright: ignore [reportIncompatibleVariableOverride] def get_props(cls) -> Set[str]:
"""Get the props for the component. """Get the props for the component.
Returns: Returns:
@ -1735,27 +1817,8 @@ class CustomComponent(Component):
seen=seen seen=seen
) )
# Fetch custom components from props as well.
for child_component in self.component_props.values():
if child_component.tag is None:
continue
if child_component.tag not in seen:
seen.add(child_component.tag)
if isinstance(child_component, CustomComponent):
custom_components |= {child_component}
custom_components |= child_component._get_all_custom_components(
seen=seen
)
return custom_components return custom_components
def _render(self) -> Tag:
"""Define how to render the component in React.
Returns:
The tag to render.
"""
return super()._render(props=self.props)
def get_prop_vars(self) -> List[Var]: def get_prop_vars(self) -> List[Var]:
"""Get the prop vars. """Get the prop vars.
@ -1765,29 +1828,19 @@ class CustomComponent(Component):
return [ return [
Var( Var(
_js_expr=name, _js_expr=name,
_var_type=(prop._var_type if isinstance(prop, Var) else type(prop)), _var_type=(
prop._var_type
if isinstance(prop, Var)
else (
type(prop)
if not isinstance(prop, EventActionsMixin)
else EventChain
)
),
).guess_type() ).guess_type()
for name, prop in self.props.items() for name, prop in self.props.items()
] ]
def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]:
"""Walk all Vars used in this component.
Args:
include_children: Whether to include Vars from children.
ignore_ids: The ids to ignore.
Yields:
Each var referenced by the component (props, styles, event handlers).
"""
ignore_ids = ignore_ids or set()
yield from super()._get_vars(
include_children=include_children, ignore_ids=ignore_ids
)
yield from filter(lambda prop: isinstance(prop, Var), self.props.values())
@lru_cache(maxsize=None) # noqa: B019 @lru_cache(maxsize=None) # noqa: B019
def get_component(self) -> Component: def get_component(self) -> Component:
"""Render the component. """Render the component.
@ -2475,6 +2528,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
The VarData for the var. The VarData for the var.
""" """
return VarData.merge( return VarData.merge(
self._var_data,
VarData( VarData(
imports={ imports={
"@emotion/react": [ "@emotion/react": [
@ -2517,9 +2571,21 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
Returns: Returns:
The var. The var.
""" """
var_datas = [
var_data
for var in value._get_vars(include_children=True)
if (var_data := var._get_all_var_data())
]
return LiteralComponentVar( return LiteralComponentVar(
_js_expr="", _js_expr="",
_var_type=type(value), _var_type=type(value),
_var_data=_var_data, _var_data=VarData.merge(
_var_data,
*var_datas,
VarData(
components=(value,),
),
),
_var_value=value, _var_value=value,
) )

View File

@ -61,14 +61,6 @@ class Cond(MemoizationLeaf):
) )
) )
def _get_props_imports(self):
"""Get the imports needed for component's props.
Returns:
The imports for the component's props of the component.
"""
return []
def _render(self) -> Tag: def _render(self) -> Tag:
return CondTag( return CondTag(
cond=self.cond, cond=self.cond,

View File

@ -76,6 +76,7 @@ from reflex.utils.types import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.components.component import BaseComponent
from reflex.state import BaseState from reflex.state import BaseState
from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar
@ -132,6 +133,9 @@ class VarData:
# Position of the hook in the component # Position of the hook in the component
position: Hooks.HookPosition | None = None position: Hooks.HookPosition | None = None
# Components that are part of this var
components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
def __init__( def __init__(
self, self,
state: str = "", state: str = "",
@ -140,6 +144,7 @@ class VarData:
hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None, hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
components: Iterable[BaseComponent] | None = None,
): ):
"""Initialize the var data. """Initialize the var data.
@ -150,6 +155,7 @@ class VarData:
hooks: Hooks that need to be present in the component to render this var. hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback. deps: Dependencies of the var for useCallback.
position: Position of the hook in the component. position: Position of the hook in the component.
components: Components that are part of this var.
""" """
if isinstance(hooks, str): if isinstance(hooks, str):
hooks = [hooks] hooks = [hooks]
@ -164,6 +170,7 @@ class VarData:
object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None) object.__setattr__(self, "position", position or None)
object.__setattr__(self, "components", tuple(components or []))
if hooks and any(hooks.values()): if hooks and any(hooks.values()):
merged_var_data = VarData.merge(self, *hooks.values()) merged_var_data = VarData.merge(self, *hooks.values())
@ -174,6 +181,7 @@ class VarData:
object.__setattr__(self, "hooks", merged_var_data.hooks) object.__setattr__(self, "hooks", merged_var_data.hooks)
object.__setattr__(self, "deps", merged_var_data.deps) object.__setattr__(self, "deps", merged_var_data.deps)
object.__setattr__(self, "position", merged_var_data.position) object.__setattr__(self, "position", merged_var_data.position)
object.__setattr__(self, "components", merged_var_data.components)
def old_school_imports(self) -> ImportDict: def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict. """Return the imports as a mutable dict.
@ -242,17 +250,19 @@ class VarData:
else: else:
position = None position = None
if state or _imports or hooks or field_name or deps or position: components = tuple(
return VarData( component for var_data in all_var_datas for component in var_data.components
state=state, )
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
)
return None return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
components=components,
)
def __bool__(self) -> bool: def __bool__(self) -> bool:
"""Check if the var data is non-empty. """Check if the var data is non-empty.
@ -267,6 +277,7 @@ class VarData:
or self.field_name or self.field_name
or self.deps or self.deps
or self.position or self.position
or self.components
) )
@classmethod @classmethod

View File

@ -871,7 +871,7 @@ def test_create_custom_component(my_component):
""" """
component = CustomComponent(component_fn=my_component, prop1="test", prop2=1) component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
assert component.tag == "MyComponent" assert component.tag == "MyComponent"
assert component.get_props() == set() assert component.get_props() == {"prop1", "prop2"}
assert component._get_all_custom_components() == {component} assert component._get_all_custom_components() == {component}