components deserve to be first class props

This commit is contained in:
Khaleel Al-Adhami 2025-02-14 14:55:55 -08:00
parent 10bae9577c
commit f158709314
5 changed files with 177 additions and 132 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any, Iterator from typing import Any, Iterator
from reflex.components.component import Component, LiteralComponentVar from reflex.components.component import Component, ComponentStyle
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless from reflex.components.tags.tagless import Tagless
from reflex.config import PerformanceMode, environment from reflex.config import PerformanceMode, environment
@ -12,7 +12,7 @@ from reflex.utils import console
from reflex.utils.decorator import once from reflex.utils.decorator import once
from reflex.utils.imports import ParsedImportDict from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData from reflex.vars.base import GLOBAL_CACHE, VarData
from reflex.vars.sequence import LiteralStringVar from reflex.vars.sequence import LiteralStringVar
@ -80,8 +80,11 @@ class Bare(Component):
The hooks for the component. The hooks for the component.
""" """
hooks = super()._get_all_hooks_internal() hooks = super()._get_all_hooks_internal()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
hooks |= self.contents._var_value._get_all_hooks_internal() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks_internal()
return hooks return hooks
def _get_all_hooks(self) -> dict[str, VarData | None]: def _get_all_hooks(self) -> dict[str, VarData | None]:
@ -91,18 +94,24 @@ class Bare(Component):
The hooks for the component. The hooks for the component.
""" """
hooks = super()._get_all_hooks() hooks = super()._get_all_hooks()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
hooks |= self.contents._var_value._get_all_hooks() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks()
return hooks return hooks
def _get_all_imports(self) -> ParsedImportDict: def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Include the imports for the component. """Include the imports for the component.
Args:
collapse: Whether to collapse the imports.
Returns: Returns:
The imports for the component. The imports for the component.
""" """
imports = super()._get_all_imports() imports = super()._get_all_imports(collapse=collapse)
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data() var_data = self.contents._get_all_var_data()
if var_data: if var_data:
imports |= {k: list(v) for k, v in var_data.imports} imports |= {k: list(v) for k, v in var_data.imports}
@ -115,8 +124,11 @@ class Bare(Component):
The dynamic imports. The dynamic imports.
""" """
dynamic_imports = super()._get_all_dynamic_imports() dynamic_imports = super()._get_all_dynamic_imports()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
dynamic_imports |= component._get_all_dynamic_imports()
return dynamic_imports return dynamic_imports
def _get_all_custom_code(self) -> set[str]: def _get_all_custom_code(self) -> set[str]:
@ -126,10 +138,28 @@ class Bare(Component):
The custom code. The custom code.
""" """
custom_code = super()._get_all_custom_code() custom_code = super()._get_all_custom_code()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
custom_code |= self.contents._var_value._get_all_custom_code() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
custom_code |= component._get_all_custom_code()
return custom_code return custom_code
def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
"""Get the components that should be wrapped in the app.
Returns:
The components that should be wrapped in the app.
"""
app_wrap_components = super()._get_all_app_wrap_components()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
if isinstance(component, Component):
app_wrap_components |= component._get_all_app_wrap_components()
return app_wrap_components
def _get_all_refs(self) -> set[str]: def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component. """Get the refs for the children of the component.
@ -137,8 +167,11 @@ class Bare(Component):
The refs for the children. The refs for the children.
""" """
refs = super()._get_all_refs() refs = super()._get_all_refs()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
refs |= self.contents._var_value._get_all_refs() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
refs |= component._get_all_refs()
return refs return refs
def _render(self) -> Tag: def _render(self) -> Tag:
@ -148,6 +181,35 @@ class Bare(Component):
return Tagless(contents=f"{{{self.contents!s}}}") return Tagless(contents=f"{{{self.contents!s}}}")
return Tagless(contents=str(self.contents)) return Tagless(contents=str(self.contents))
def _add_style_recursive(
self, style: ComponentStyle, theme: Component | None = None
) -> Component:
"""Add style to the component and its children.
Args:
style: The style to add.
theme: The theme to add.
Returns:
The component with the style added.
"""
new_self = super()._add_style_recursive(style, theme)
are_components_touched = False
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
if isinstance(component, Component):
component._add_style_recursive(style, theme)
are_components_touched = True
if are_components_touched:
GLOBAL_CACHE.clear()
return new_self
def _get_vars( def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]: ) -> Iterator[Var]:

View File

@ -191,6 +191,25 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
return types._isinstance(obj, type_hint, nested=1) return types._isinstance(obj, type_hint, nested=1)
def _components_from(
component_or_var: Union[BaseComponent, Var],
) -> tuple[BaseComponent, ...]:
"""Get the components from a component or Var.
Args:
component_or_var: The component or Var to get the components from.
Returns:
The components.
"""
if isinstance(component_or_var, Var):
var_data = component_or_var._get_all_var_data()
return var_data.components if var_data else ()
if isinstance(component_or_var, BaseComponent):
return (component_or_var,)
return ()
class Component(BaseComponent, ABC): class Component(BaseComponent, ABC):
"""A component with style, event trigger and other props.""" """A component with style, event trigger and other props."""
@ -665,20 +684,15 @@ class Component(BaseComponent, ABC):
""" """
return set() return set()
@classmethod def _get_components_in_props(self) -> Iterator[BaseComponent]:
@lru_cache(maxsize=None) """Get the components in the props.
def get_component_props(cls) -> set[str]:
"""Get the props that expected a component as value.
Returns: Yields:
The components props. The components in the props.
""" """
return { for prop in self.get_props():
name value = getattr(self, prop)
for name, field in cls.get_fields().items() yield from _components_from(value)
if name in cls.get_props()
and types._issubclass(field.outer_type_, Component)
}
@classmethod @classmethod
def create(cls, *children, **props) -> Self: def create(cls, *children, **props) -> Self:
@ -1136,6 +1150,9 @@ class Component(BaseComponent, ABC):
if custom_code is not None: if custom_code is not None:
code.add(custom_code) code.add(custom_code)
for component in self._get_components_in_props():
code |= component._get_all_custom_code()
# Add the custom code from add_custom_code method. # Add the custom code from add_custom_code method.
for clz in self._iter_parent_classes_with_method("add_custom_code"): for clz in self._iter_parent_classes_with_method("add_custom_code"):
for item in clz.add_custom_code(self): for item in clz.add_custom_code(self):
@ -1163,7 +1180,7 @@ class Component(BaseComponent, ABC):
The dynamic imports. The dynamic imports.
""" """
# Store the import in a set to avoid duplicates. # Store the import in a set to avoid duplicates.
dynamic_imports = set() dynamic_imports: set[str] = set()
# Get dynamic import for this component. # Get dynamic import for this component.
dynamic_import = self._get_dynamic_imports() dynamic_import = self._get_dynamic_imports()
@ -1174,25 +1191,12 @@ class Component(BaseComponent, ABC):
for child in self.children: for child in self.children:
dynamic_imports |= child._get_all_dynamic_imports() dynamic_imports |= child._get_all_dynamic_imports()
for prop in self.get_component_props(): for component in self._get_components_in_props():
if getattr(self, prop) is not None: dynamic_imports |= component._get_all_dynamic_imports()
dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports()
# Return the dynamic imports # Return the dynamic imports
return dynamic_imports return dynamic_imports
def _get_props_imports(self) -> List[ParsedImportDict]:
"""Get the imports needed for components props.
Returns:
The imports for the components props of the component.
"""
return [
getattr(self, prop)._get_all_imports()
for prop in self.get_component_props()
if getattr(self, prop) is not None
]
def _should_transpile(self, dep: str | None) -> bool: def _should_transpile(self, dep: str | None) -> bool:
"""Check if a dependency should be transpiled. """Check if a dependency should be transpiled.
@ -1303,7 +1307,6 @@ class Component(BaseComponent, ABC):
) )
return imports.merge_imports( return imports.merge_imports(
*self._get_props_imports(),
self._get_dependencies_imports(), self._get_dependencies_imports(),
self._get_hooks_imports(), self._get_hooks_imports(),
_imports, _imports,
@ -1380,6 +1383,8 @@ class Component(BaseComponent, ABC):
for k in var_data.hooks for k in var_data.hooks
} }
) )
for component in var_data.components:
vars_hooks.update(component._get_all_hooks())
return vars_hooks return vars_hooks
def _get_events_hooks(self) -> dict[str, VarData | None]: def _get_events_hooks(self) -> dict[str, VarData | None]:
@ -1528,6 +1533,9 @@ class Component(BaseComponent, ABC):
refs.add(ref) refs.add(ref)
for child in self.children: for child in self.children:
refs |= child._get_all_refs() refs |= child._get_all_refs()
for component in self._get_components_in_props():
refs |= component._get_all_refs()
return refs return refs
def _get_all_custom_components( def _get_all_custom_components(
@ -1551,6 +1559,9 @@ class Component(BaseComponent, ABC):
if not isinstance(child, Component): if not isinstance(child, Component):
continue continue
custom_components |= child._get_all_custom_components(seen=seen) custom_components |= child._get_all_custom_components(seen=seen)
for component in self._get_components_in_props():
if isinstance(component, Component) and component.tag is not None:
custom_components |= component._get_all_custom_components(seen=seen)
return custom_components return custom_components
@property @property
@ -1614,17 +1625,25 @@ class CustomComponent(Component):
# The props of the component. # The props of the component.
props: Dict[str, Any] = {} props: Dict[str, Any] = {}
# Props that reference other components. def __init__(self, **kwargs):
component_props: Dict[str, Component] = {}
def __init__(self, *args, **kwargs):
"""Initialize the custom component. """Initialize the custom component.
Args: Args:
*args: The args to pass to the component.
**kwargs: The kwargs to pass to the component. **kwargs: The kwargs to pass to the component.
""" """
super().__init__(*args, **kwargs) component_fn = kwargs.get("component_fn")
# Set the props.
props_types = typing.get_type_hints(component_fn)
props = {key: value for key, value in kwargs.items() if key in props_types}
kwargs = {key: value for key, value in kwargs.items() if key not in props_types}
super().__init__(
**kwargs,
)
to_camel_cased_props = {format.to_camel_case(key) for key in props}
self.get_props = lambda: to_camel_cased_props # pyright: ignore [reportIncompatibleVariableOverride]
# Unset the style. # Unset the style.
self.style = Style() self.style = Style()
@ -1635,15 +1654,15 @@ class CustomComponent(Component):
# Get the event triggers defined in the component declaration. # Get the event triggers defined in the component declaration.
event_triggers_in_component_declaration = self.get_event_triggers() event_triggers_in_component_declaration = self.get_event_triggers()
# Set the props. for key, value in props.items():
props = typing.get_type_hints(self.component_fn)
for key, value in kwargs.items():
# Skip kwargs that are not props. # Skip kwargs that are not props.
if key not in props: if key not in props_types:
continue continue
camel_cased_key = format.to_camel_case(key)
# Get the type based on the annotation. # Get the type based on the annotation.
type_ = props[key] type_ = props_types[key]
# Handle event chains. # Handle event chains.
if types._issubclass(type_, EventChain): if types._issubclass(type_, EventChain):
@ -1654,29 +1673,14 @@ class CustomComponent(Component):
), ),
key=key, key=key,
) )
self.props[format.to_camel_case(key)] = value self.props[camel_cased_key] = value
continue continue
# Handle subclasses of Base. value = LiteralVar.create(value)
if isinstance(value, Base):
base_value = LiteralVar.create(value)
# Track hooks and imports associated with Component instances.
if base_value is not None and isinstance(value, Component):
self.component_props[key] = value
value = base_value._replace(
merge_var_data=VarData(
imports=value._get_all_imports(),
hooks=value._get_all_hooks(),
)
)
else:
value = base_value
else:
value = LiteralVar.create(value)
# Set the prop. # Set the prop.
self.props[format.to_camel_case(key)] = value self.props[camel_cased_key] = value
setattr(self, camel_cased_key, value)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Check if the component is equal to another. """Check if the component is equal to another.
@ -1698,7 +1702,7 @@ class CustomComponent(Component):
return hash(self.tag) return hash(self.tag)
@classmethod @classmethod
def get_props(cls) -> Set[str]: # pyright: ignore [reportIncompatibleVariableOverride] def get_props(cls) -> Set[str]:
"""Get the props for the component. """Get the props for the component.
Returns: Returns:
@ -1735,27 +1739,8 @@ class CustomComponent(Component):
seen=seen seen=seen
) )
# Fetch custom components from props as well.
for child_component in self.component_props.values():
if child_component.tag is None:
continue
if child_component.tag not in seen:
seen.add(child_component.tag)
if isinstance(child_component, CustomComponent):
custom_components |= {child_component}
custom_components |= child_component._get_all_custom_components(
seen=seen
)
return custom_components return custom_components
def _render(self) -> Tag:
"""Define how to render the component in React.
Returns:
The tag to render.
"""
return super()._render(props=self.props)
def get_prop_vars(self) -> List[Var]: def get_prop_vars(self) -> List[Var]:
"""Get the prop vars. """Get the prop vars.
@ -1770,24 +1755,6 @@ class CustomComponent(Component):
for name, prop in self.props.items() for name, prop in self.props.items()
] ]
def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]:
"""Walk all Vars used in this component.
Args:
include_children: Whether to include Vars from children.
ignore_ids: The ids to ignore.
Yields:
Each var referenced by the component (props, styles, event handlers).
"""
ignore_ids = ignore_ids or set()
yield from super()._get_vars(
include_children=include_children, ignore_ids=ignore_ids
)
yield from filter(lambda prop: isinstance(prop, Var), self.props.values())
@lru_cache(maxsize=None) # noqa: B019 @lru_cache(maxsize=None) # noqa: B019
def get_component(self) -> Component: def get_component(self) -> Component:
"""Render the component. """Render the component.
@ -2475,6 +2442,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
The VarData for the var. The VarData for the var.
""" """
return VarData.merge( return VarData.merge(
self._var_data,
VarData( VarData(
imports={ imports={
"@emotion/react": [ "@emotion/react": [
@ -2517,9 +2485,21 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
Returns: Returns:
The var. The var.
""" """
var_datas = [
var_data
for var in value._get_vars(include_children=True)
if (var_data := var._get_all_var_data())
]
return LiteralComponentVar( return LiteralComponentVar(
_js_expr="", _js_expr="",
_var_type=type(value), _var_type=type(value),
_var_data=_var_data, _var_data=VarData.merge(
_var_data,
*var_datas,
VarData(
components=(value,),
),
),
_var_value=value, _var_value=value,
) )

View File

@ -61,14 +61,6 @@ class Cond(MemoizationLeaf):
) )
) )
def _get_props_imports(self):
"""Get the imports needed for component's props.
Returns:
The imports for the component's props of the component.
"""
return []
def _render(self) -> Tag: def _render(self) -> Tag:
return CondTag( return CondTag(
cond=self.cond, cond=self.cond,

View File

@ -73,6 +73,7 @@ from reflex.utils.types import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.components.component import BaseComponent
from reflex.state import BaseState from reflex.state import BaseState
from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar
@ -129,6 +130,9 @@ class VarData:
# Position of the hook in the component # Position of the hook in the component
position: Hooks.HookPosition | None = None position: Hooks.HookPosition | None = None
# Components that are part of this var
components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
def __init__( def __init__(
self, self,
state: str = "", state: str = "",
@ -137,6 +141,7 @@ class VarData:
hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None, hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
components: Iterable[BaseComponent] | None = None,
): ):
"""Initialize the var data. """Initialize the var data.
@ -147,6 +152,7 @@ class VarData:
hooks: Hooks that need to be present in the component to render this var. hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback. deps: Dependencies of the var for useCallback.
position: Position of the hook in the component. position: Position of the hook in the component.
components: Components that are part of this var.
""" """
if isinstance(hooks, str): if isinstance(hooks, str):
hooks = [hooks] hooks = [hooks]
@ -161,6 +167,7 @@ class VarData:
object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None) object.__setattr__(self, "position", position or None)
object.__setattr__(self, "components", tuple(components or []))
if hooks and any(hooks.values()): if hooks and any(hooks.values()):
merged_var_data = VarData.merge(self, *hooks.values()) merged_var_data = VarData.merge(self, *hooks.values())
@ -171,6 +178,7 @@ class VarData:
object.__setattr__(self, "hooks", merged_var_data.hooks) object.__setattr__(self, "hooks", merged_var_data.hooks)
object.__setattr__(self, "deps", merged_var_data.deps) object.__setattr__(self, "deps", merged_var_data.deps)
object.__setattr__(self, "position", merged_var_data.position) object.__setattr__(self, "position", merged_var_data.position)
object.__setattr__(self, "components", merged_var_data.components)
def old_school_imports(self) -> ImportDict: def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict. """Return the imports as a mutable dict.
@ -239,17 +247,19 @@ class VarData:
else: else:
position = None position = None
if state or _imports or hooks or field_name or deps or position: components = tuple(
return VarData( component for var_data in all_var_datas for component in var_data.components
state=state, )
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
)
return None return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
components=components,
)
def __bool__(self) -> bool: def __bool__(self) -> bool:
"""Check if the var data is non-empty. """Check if the var data is non-empty.
@ -264,6 +274,7 @@ class VarData:
or self.field_name or self.field_name
or self.deps or self.deps
or self.position or self.position
or self.components
) )
@classmethod @classmethod

View File

@ -871,7 +871,7 @@ def test_create_custom_component(my_component):
""" """
component = CustomComponent(component_fn=my_component, prop1="test", prop2=1) component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
assert component.tag == "MyComponent" assert component.tag == "MyComponent"
assert component.get_props() == set() assert component.get_props() == {"prop1", "prop2"}
assert component._get_all_custom_components() == {component} assert component._get_all_custom_components() == {component}