components deserve to be first class props (#4827)

* components deserve to be first class props

* default back to {}

* smarter yield

* how much does caching help?

* only hit the slower path on _are_fields_known

* remove the cache thingy

* cache the inner _get_component_prop_names

* oops

* dang it darglint

* refactor things a bit

* fix events
This commit is contained in:
Khaleel Al-Adhami 2025-02-19 11:27:33 -08:00 committed by GitHub
parent 762d975a87
commit abab18e165
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 259 additions and 137 deletions

View File

@ -2,9 +2,9 @@
from __future__ import annotations
from typing import Any, Iterator
from typing import Any, Iterator, Sequence
from reflex.components.component import Component, LiteralComponentVar
from reflex.components.component import BaseComponent, Component, ComponentStyle
from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless
from reflex.config import PerformanceMode, environment
@ -12,7 +12,7 @@ from reflex.utils import console
from reflex.utils.decorator import once
from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var
from reflex.vars.base import VarData
from reflex.vars.base import GLOBAL_CACHE, VarData
from reflex.vars.sequence import LiteralStringVar
@ -47,6 +47,11 @@ def validate_str(value: str):
)
def _components_from_var(var: Var) -> Sequence[BaseComponent]:
var_data = var._get_all_var_data()
return var_data.components if var_data else ()
class Bare(Component):
"""A component with no tag."""
@ -80,8 +85,9 @@ class Bare(Component):
The hooks for the component.
"""
hooks = super()._get_all_hooks_internal()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks_internal()
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
hooks |= component._get_all_hooks_internal()
return hooks
def _get_all_hooks(self) -> dict[str, VarData | None]:
@ -91,18 +97,22 @@ class Bare(Component):
The hooks for the component.
"""
hooks = super()._get_all_hooks()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks()
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
hooks |= component._get_all_hooks()
return hooks
def _get_all_imports(self) -> ParsedImportDict:
def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict:
"""Include the imports for the component.
Args:
collapse: Whether to collapse the imports.
Returns:
The imports for the component.
"""
imports = super()._get_all_imports()
if isinstance(self.contents, LiteralComponentVar):
imports = super()._get_all_imports(collapse=collapse)
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
imports |= {k: list(v) for k, v in var_data.imports}
@ -115,8 +125,9 @@ class Bare(Component):
The dynamic imports.
"""
dynamic_imports = super()._get_all_dynamic_imports()
if isinstance(self.contents, LiteralComponentVar):
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports()
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
dynamic_imports |= component._get_all_dynamic_imports()
return dynamic_imports
def _get_all_custom_code(self) -> set[str]:
@ -126,10 +137,24 @@ class Bare(Component):
The custom code.
"""
custom_code = super()._get_all_custom_code()
if isinstance(self.contents, LiteralComponentVar):
custom_code |= self.contents._var_value._get_all_custom_code()
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
custom_code |= component._get_all_custom_code()
return custom_code
def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]:
"""Get the components that should be wrapped in the app.
Returns:
The components that should be wrapped in the app.
"""
app_wrap_components = super()._get_all_app_wrap_components()
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
if isinstance(component, Component):
app_wrap_components |= component._get_all_app_wrap_components()
return app_wrap_components
def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component.
@ -137,8 +162,9 @@ class Bare(Component):
The refs for the children.
"""
refs = super()._get_all_refs()
if isinstance(self.contents, LiteralComponentVar):
refs |= self.contents._var_value._get_all_refs()
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
refs |= component._get_all_refs()
return refs
def _render(self) -> Tag:
@ -148,6 +174,33 @@ class Bare(Component):
return Tagless(contents=f"{{{self.contents!s}}}")
return Tagless(contents=str(self.contents))
def _add_style_recursive(
self, style: ComponentStyle, theme: Component | None = None
) -> Component:
"""Add style to the component and its children.
Args:
style: The style to add.
theme: The theme to add.
Returns:
The component with the style added.
"""
new_self = super()._add_style_recursive(style, theme)
are_components_touched = False
if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
if isinstance(component, Component):
component._add_style_recursive(style, theme)
are_components_touched = True
if are_components_touched:
GLOBAL_CACHE.clear()
return new_self
def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]:

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import copy
import dataclasses
import inspect
import typing
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
@ -21,6 +22,8 @@ from typing import (
Set,
Type,
Union,
get_args,
get_origin,
)
from typing_extensions import Self
@ -43,6 +46,7 @@ from reflex.constants import (
from reflex.constants.compiler import SpecialAttributes
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.event import (
EventActionsMixin,
EventCallback,
EventChain,
EventHandler,
@ -191,6 +195,25 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
return types._isinstance(obj, type_hint, nested=1)
def _components_from(
component_or_var: Union[BaseComponent, Var],
) -> tuple[BaseComponent, ...]:
"""Get the components from a component or Var.
Args:
component_or_var: The component or Var to get the components from.
Returns:
The components.
"""
if isinstance(component_or_var, Var):
var_data = component_or_var._get_all_var_data()
return var_data.components if var_data else ()
if isinstance(component_or_var, BaseComponent):
return (component_or_var,)
return ()
class Component(BaseComponent, ABC):
"""A component with style, event trigger and other props."""
@ -489,7 +512,7 @@ class Component(BaseComponent, ABC):
# Remove any keys that were added as events.
for key in kwargs["event_triggers"]:
del kwargs[key]
kwargs.pop(key, None)
# Place data_ and aria_ attributes into custom_attrs
special_attributes = tuple(
@ -666,12 +689,21 @@ class Component(BaseComponent, ABC):
return set()
@classmethod
@lru_cache(maxsize=None)
def get_component_props(cls) -> set[str]:
"""Get the props that expected a component as value.
def _are_fields_known(cls) -> bool:
"""Check if all fields are known at compile time. True for most components.
Returns:
The components props.
Whether all fields are known at compile time.
"""
return True
@classmethod
@lru_cache(maxsize=None)
def _get_component_prop_names(cls) -> Set[str]:
"""Get the names of the component props. NOTE: This assumes all fields are known.
Returns:
The names of the component props.
"""
return {
name
@ -680,6 +712,26 @@ class Component(BaseComponent, ABC):
and types._issubclass(field.outer_type_, Component)
}
def _get_components_in_props(self) -> Sequence[BaseComponent]:
"""Get the components in the props.
Returns:
The components in the props
"""
if self._are_fields_known():
return [
component
for name in self._get_component_prop_names()
for component in _components_from(getattr(self, name))
]
return [
component
for prop in self.get_props()
if (value := getattr(self, prop)) is not None
and isinstance(value, (BaseComponent, Var))
for component in _components_from(value)
]
@classmethod
def create(cls, *children, **props) -> Self:
"""Create the component.
@ -1136,6 +1188,9 @@ class Component(BaseComponent, ABC):
if custom_code is not None:
code.add(custom_code)
for component in self._get_components_in_props():
code |= component._get_all_custom_code()
# Add the custom code from add_custom_code method.
for clz in self._iter_parent_classes_with_method("add_custom_code"):
for item in clz.add_custom_code(self):
@ -1163,7 +1218,7 @@ class Component(BaseComponent, ABC):
The dynamic imports.
"""
# Store the import in a set to avoid duplicates.
dynamic_imports = set()
dynamic_imports: set[str] = set()
# Get dynamic import for this component.
dynamic_import = self._get_dynamic_imports()
@ -1174,25 +1229,12 @@ class Component(BaseComponent, ABC):
for child in self.children:
dynamic_imports |= child._get_all_dynamic_imports()
for prop in self.get_component_props():
if getattr(self, prop) is not None:
dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports()
for component in self._get_components_in_props():
dynamic_imports |= component._get_all_dynamic_imports()
# Return the dynamic imports
return dynamic_imports
def _get_props_imports(self) -> List[ParsedImportDict]:
"""Get the imports needed for components props.
Returns:
The imports for the components props of the component.
"""
return [
getattr(self, prop)._get_all_imports()
for prop in self.get_component_props()
if getattr(self, prop) is not None
]
def _should_transpile(self, dep: str | None) -> bool:
"""Check if a dependency should be transpiled.
@ -1303,7 +1345,6 @@ class Component(BaseComponent, ABC):
)
return imports.merge_imports(
*self._get_props_imports(),
self._get_dependencies_imports(),
self._get_hooks_imports(),
_imports,
@ -1380,6 +1421,8 @@ class Component(BaseComponent, ABC):
for k in var_data.hooks
}
)
for component in var_data.components:
vars_hooks.update(component._get_all_hooks())
return vars_hooks
def _get_events_hooks(self) -> dict[str, VarData | None]:
@ -1528,6 +1571,9 @@ class Component(BaseComponent, ABC):
refs.add(ref)
for child in self.children:
refs |= child._get_all_refs()
for component in self._get_components_in_props():
refs |= component._get_all_refs()
return refs
def _get_all_custom_components(
@ -1551,6 +1597,9 @@ class Component(BaseComponent, ABC):
if not isinstance(child, Component):
continue
custom_components |= child._get_all_custom_components(seen=seen)
for component in self._get_components_in_props():
if isinstance(component, Component) and component.tag is not None:
custom_components |= component._get_all_custom_components(seen=seen)
return custom_components
@property
@ -1614,17 +1663,65 @@ class CustomComponent(Component):
# The props of the component.
props: Dict[str, Any] = {}
# Props that reference other components.
component_props: Dict[str, Component] = {}
def __init__(self, *args, **kwargs):
def __init__(self, **kwargs):
"""Initialize the custom component.
Args:
*args: The args to pass to the component.
**kwargs: The kwargs to pass to the component.
"""
super().__init__(*args, **kwargs)
component_fn = kwargs.get("component_fn")
# Set the props.
props_types = typing.get_type_hints(component_fn) if component_fn else {}
props = {key: value for key, value in kwargs.items() if key in props_types}
kwargs = {key: value for key, value in kwargs.items() if key not in props_types}
event_types = {
key
for key in props
if (
(get_origin((annotation := props_types.get(key))) or annotation)
== EventHandler
)
}
def get_args_spec(key: str) -> types.ArgsSpec | Sequence[types.ArgsSpec]:
type_ = props_types[key]
return (
args[0]
if (args := get_args(type_))
else (
annotation_args[1]
if get_origin(
(
annotation := inspect.getfullargspec(
component_fn
).annotations[key]
)
)
is typing.Annotated
and (annotation_args := get_args(annotation))
else no_args_event_spec
)
)
super().__init__(
event_triggers={
key: EventChain.create(
value=props[key],
args_spec=get_args_spec(key),
key=key,
)
for key in event_types
},
**kwargs,
)
to_camel_cased_props = {
format.to_camel_case(key) for key in props if key not in event_types
}
self.get_props = lambda: to_camel_cased_props # pyright: ignore [reportIncompatibleVariableOverride]
# Unset the style.
self.style = Style()
@ -1632,51 +1729,36 @@ class CustomComponent(Component):
# Set the tag to the name of the function.
self.tag = format.to_title_case(self.component_fn.__name__)
# Get the event triggers defined in the component declaration.
event_triggers_in_component_declaration = self.get_event_triggers()
# Set the props.
props = typing.get_type_hints(self.component_fn)
for key, value in kwargs.items():
for key, value in props.items():
# Skip kwargs that are not props.
if key not in props:
if key not in props_types:
continue
camel_cased_key = format.to_camel_case(key)
# Get the type based on the annotation.
type_ = props[key]
type_ = props_types[key]
# Handle event chains.
if types._issubclass(type_, EventChain):
value = EventChain.create(
value=value,
args_spec=event_triggers_in_component_declaration.get(
key, no_args_event_spec
),
key=key,
if types._issubclass(type_, EventActionsMixin):
inspect.getfullargspec(component_fn).annotations[key]
self.props[camel_cased_key] = EventChain.create(
value=value, args_spec=get_args_spec(key), key=key
)
self.props[format.to_camel_case(key)] = value
continue
# Handle subclasses of Base.
if isinstance(value, Base):
base_value = LiteralVar.create(value)
value = LiteralVar.create(value)
self.props[camel_cased_key] = value
setattr(self, camel_cased_key, value)
# Track hooks and imports associated with Component instances.
if base_value is not None and isinstance(value, Component):
self.component_props[key] = value
value = base_value._replace(
merge_var_data=VarData(
imports=value._get_all_imports(),
hooks=value._get_all_hooks(),
)
)
else:
value = base_value
else:
value = LiteralVar.create(value)
@classmethod
def _are_fields_known(cls) -> bool:
"""Check if the fields are known.
# Set the prop.
self.props[format.to_camel_case(key)] = value
Returns:
Whether the fields are known.
"""
return False
def __eq__(self, other: Any) -> bool:
"""Check if the component is equal to another.
@ -1698,7 +1780,7 @@ class CustomComponent(Component):
return hash(self.tag)
@classmethod
def get_props(cls) -> Set[str]: # pyright: ignore [reportIncompatibleVariableOverride]
def get_props(cls) -> Set[str]:
"""Get the props for the component.
Returns:
@ -1735,27 +1817,8 @@ class CustomComponent(Component):
seen=seen
)
# Fetch custom components from props as well.
for child_component in self.component_props.values():
if child_component.tag is None:
continue
if child_component.tag not in seen:
seen.add(child_component.tag)
if isinstance(child_component, CustomComponent):
custom_components |= {child_component}
custom_components |= child_component._get_all_custom_components(
seen=seen
)
return custom_components
def _render(self) -> Tag:
"""Define how to render the component in React.
Returns:
The tag to render.
"""
return super()._render(props=self.props)
def get_prop_vars(self) -> List[Var]:
"""Get the prop vars.
@ -1765,29 +1828,19 @@ class CustomComponent(Component):
return [
Var(
_js_expr=name,
_var_type=(prop._var_type if isinstance(prop, Var) else type(prop)),
_var_type=(
prop._var_type
if isinstance(prop, Var)
else (
type(prop)
if not isinstance(prop, EventActionsMixin)
else EventChain
)
),
).guess_type()
for name, prop in self.props.items()
]
def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None
) -> Iterator[Var]:
"""Walk all Vars used in this component.
Args:
include_children: Whether to include Vars from children.
ignore_ids: The ids to ignore.
Yields:
Each var referenced by the component (props, styles, event handlers).
"""
ignore_ids = ignore_ids or set()
yield from super()._get_vars(
include_children=include_children, ignore_ids=ignore_ids
)
yield from filter(lambda prop: isinstance(prop, Var), self.props.values())
@lru_cache(maxsize=None) # noqa: B019
def get_component(self) -> Component:
"""Render the component.
@ -2475,6 +2528,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
The VarData for the var.
"""
return VarData.merge(
self._var_data,
VarData(
imports={
"@emotion/react": [
@ -2517,9 +2571,21 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
Returns:
The var.
"""
var_datas = [
var_data
for var in value._get_vars(include_children=True)
if (var_data := var._get_all_var_data())
]
return LiteralComponentVar(
_js_expr="",
_var_type=type(value),
_var_data=_var_data,
_var_data=VarData.merge(
_var_data,
*var_datas,
VarData(
components=(value,),
),
),
_var_value=value,
)

View File

@ -61,14 +61,6 @@ class Cond(MemoizationLeaf):
)
)
def _get_props_imports(self):
"""Get the imports needed for component's props.
Returns:
The imports for the component's props of the component.
"""
return []
def _render(self) -> Tag:
return CondTag(
cond=self.cond,

View File

@ -76,6 +76,7 @@ from reflex.utils.types import (
)
if TYPE_CHECKING:
from reflex.components.component import BaseComponent
from reflex.state import BaseState
from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar
@ -132,6 +133,9 @@ class VarData:
# Position of the hook in the component
position: Hooks.HookPosition | None = None
# Components that are part of this var
components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
def __init__(
self,
state: str = "",
@ -140,6 +144,7 @@ class VarData:
hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
components: Iterable[BaseComponent] | None = None,
):
"""Initialize the var data.
@ -150,6 +155,7 @@ class VarData:
hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback.
position: Position of the hook in the component.
components: Components that are part of this var.
"""
if isinstance(hooks, str):
hooks = [hooks]
@ -164,6 +170,7 @@ class VarData:
object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None)
object.__setattr__(self, "components", tuple(components or []))
if hooks and any(hooks.values()):
merged_var_data = VarData.merge(self, *hooks.values())
@ -174,6 +181,7 @@ class VarData:
object.__setattr__(self, "hooks", merged_var_data.hooks)
object.__setattr__(self, "deps", merged_var_data.deps)
object.__setattr__(self, "position", merged_var_data.position)
object.__setattr__(self, "components", merged_var_data.components)
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
@ -242,17 +250,19 @@ class VarData:
else:
position = None
if state or _imports or hooks or field_name or deps or position:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
)
components = tuple(
component for var_data in all_var_datas for component in var_data.components
)
return None
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
components=components,
)
def __bool__(self) -> bool:
"""Check if the var data is non-empty.
@ -267,6 +277,7 @@ class VarData:
or self.field_name
or self.deps
or self.position
or self.components
)
@classmethod

View File

@ -871,7 +871,7 @@ def test_create_custom_component(my_component):
"""
component = CustomComponent(component_fn=my_component, prop1="test", prop2=1)
assert component.tag == "MyComponent"
assert component.get_props() == set()
assert component.get_props() == {"prop1", "prop2"}
assert component._get_all_custom_components() == {component}