diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 73b0680d3..f3efc00ba 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import Any, Iterator -from reflex.components.component import Component, LiteralComponentVar +from reflex.components.component import Component, ComponentStyle from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless from reflex.config import PerformanceMode, environment @@ -12,7 +12,7 @@ from reflex.utils import console from reflex.utils.decorator import once from reflex.utils.imports import ParsedImportDict from reflex.vars import BooleanVar, ObjectVar, Var -from reflex.vars.base import VarData +from reflex.vars.base import GLOBAL_CACHE, VarData from reflex.vars.sequence import LiteralStringVar @@ -80,8 +80,11 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks_internal() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks_internal() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + hooks |= component._get_all_hooks_internal() return hooks def _get_all_hooks(self) -> dict[str, VarData | None]: @@ -91,18 +94,24 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + hooks |= component._get_all_hooks() return hooks - def _get_all_imports(self) -> ParsedImportDict: + def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict: """Include the imports for the component. + Args: + collapse: Whether to collapse the imports. + Returns: The imports for the component. """ - imports = super()._get_all_imports() - if isinstance(self.contents, LiteralComponentVar): + imports = super()._get_all_imports(collapse=collapse) + if isinstance(self.contents, Var): var_data = self.contents._get_all_var_data() if var_data: imports |= {k: list(v) for k, v in var_data.imports} @@ -115,8 +124,11 @@ class Bare(Component): The dynamic imports. """ dynamic_imports = super()._get_all_dynamic_imports() - if isinstance(self.contents, LiteralComponentVar): - dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + dynamic_imports |= component._get_all_dynamic_imports() return dynamic_imports def _get_all_custom_code(self) -> set[str]: @@ -126,10 +138,28 @@ class Bare(Component): The custom code. """ custom_code = super()._get_all_custom_code() - if isinstance(self.contents, LiteralComponentVar): - custom_code |= self.contents._var_value._get_all_custom_code() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + custom_code |= component._get_all_custom_code() return custom_code + def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]: + """Get the components that should be wrapped in the app. + + Returns: + The components that should be wrapped in the app. + """ + app_wrap_components = super()._get_all_app_wrap_components() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + if isinstance(component, Component): + app_wrap_components |= component._get_all_app_wrap_components() + return app_wrap_components + def _get_all_refs(self) -> set[str]: """Get the refs for the children of the component. @@ -137,8 +167,11 @@ class Bare(Component): The refs for the children. """ refs = super()._get_all_refs() - if isinstance(self.contents, LiteralComponentVar): - refs |= self.contents._var_value._get_all_refs() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + refs |= component._get_all_refs() return refs def _render(self) -> Tag: @@ -148,6 +181,35 @@ class Bare(Component): return Tagless(contents=f"{{{self.contents!s}}}") return Tagless(contents=str(self.contents)) + def _add_style_recursive( + self, style: ComponentStyle, theme: Component | None = None + ) -> Component: + """Add style to the component and its children. + + Args: + style: The style to add. + theme: The theme to add. + + Returns: + The component with the style added. + """ + new_self = super()._add_style_recursive(style, theme) + + are_components_touched = False + + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + if isinstance(component, Component): + component._add_style_recursive(style, theme) + are_components_touched = True + + if are_components_touched: + GLOBAL_CACHE.clear() + + return new_self + def _get_vars( self, include_children: bool = False, ignore_ids: set[int] | None = None ) -> Iterator[Var]: diff --git a/reflex/components/component.py b/reflex/components/component.py index 005f7791d..8453333d4 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -191,6 +191,25 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool: return types._isinstance(obj, type_hint, nested=1) +def _components_from( + component_or_var: Union[BaseComponent, Var], +) -> tuple[BaseComponent, ...]: + """Get the components from a component or Var. + + Args: + component_or_var: The component or Var to get the components from. + + Returns: + The components. + """ + if isinstance(component_or_var, Var): + var_data = component_or_var._get_all_var_data() + return var_data.components if var_data else () + if isinstance(component_or_var, BaseComponent): + return (component_or_var,) + return () + + class Component(BaseComponent, ABC): """A component with style, event trigger and other props.""" @@ -665,20 +684,15 @@ class Component(BaseComponent, ABC): """ return set() - @classmethod - @lru_cache(maxsize=None) - def get_component_props(cls) -> set[str]: - """Get the props that expected a component as value. + def _get_components_in_props(self) -> Iterator[BaseComponent]: + """Get the components in the props. - Returns: - The components props. + Yields: + The components in the props. """ - return { - name - for name, field in cls.get_fields().items() - if name in cls.get_props() - and types._issubclass(field.outer_type_, Component) - } + for prop in self.get_props(): + value = getattr(self, prop) + yield from _components_from(value) @classmethod def create(cls, *children, **props) -> Self: @@ -1136,6 +1150,9 @@ class Component(BaseComponent, ABC): if custom_code is not None: code.add(custom_code) + for component in self._get_components_in_props(): + code |= component._get_all_custom_code() + # Add the custom code from add_custom_code method. for clz in self._iter_parent_classes_with_method("add_custom_code"): for item in clz.add_custom_code(self): @@ -1163,7 +1180,7 @@ class Component(BaseComponent, ABC): The dynamic imports. """ # Store the import in a set to avoid duplicates. - dynamic_imports = set() + dynamic_imports: set[str] = set() # Get dynamic import for this component. dynamic_import = self._get_dynamic_imports() @@ -1174,25 +1191,12 @@ class Component(BaseComponent, ABC): for child in self.children: dynamic_imports |= child._get_all_dynamic_imports() - for prop in self.get_component_props(): - if getattr(self, prop) is not None: - dynamic_imports |= getattr(self, prop)._get_all_dynamic_imports() + for component in self._get_components_in_props(): + dynamic_imports |= component._get_all_dynamic_imports() # Return the dynamic imports return dynamic_imports - def _get_props_imports(self) -> List[ParsedImportDict]: - """Get the imports needed for components props. - - Returns: - The imports for the components props of the component. - """ - return [ - getattr(self, prop)._get_all_imports() - for prop in self.get_component_props() - if getattr(self, prop) is not None - ] - def _should_transpile(self, dep: str | None) -> bool: """Check if a dependency should be transpiled. @@ -1303,7 +1307,6 @@ class Component(BaseComponent, ABC): ) return imports.merge_imports( - *self._get_props_imports(), self._get_dependencies_imports(), self._get_hooks_imports(), _imports, @@ -1380,6 +1383,8 @@ class Component(BaseComponent, ABC): for k in var_data.hooks } ) + for component in var_data.components: + vars_hooks.update(component._get_all_hooks()) return vars_hooks def _get_events_hooks(self) -> dict[str, VarData | None]: @@ -1528,6 +1533,9 @@ class Component(BaseComponent, ABC): refs.add(ref) for child in self.children: refs |= child._get_all_refs() + for component in self._get_components_in_props(): + refs |= component._get_all_refs() + return refs def _get_all_custom_components( @@ -1551,6 +1559,9 @@ class Component(BaseComponent, ABC): if not isinstance(child, Component): continue custom_components |= child._get_all_custom_components(seen=seen) + for component in self._get_components_in_props(): + if isinstance(component, Component) and component.tag is not None: + custom_components |= component._get_all_custom_components(seen=seen) return custom_components @property @@ -1614,17 +1625,25 @@ class CustomComponent(Component): # The props of the component. props: Dict[str, Any] = {} - # Props that reference other components. - component_props: Dict[str, Component] = {} - - def __init__(self, *args, **kwargs): + def __init__(self, **kwargs): """Initialize the custom component. Args: - *args: The args to pass to the component. **kwargs: The kwargs to pass to the component. """ - super().__init__(*args, **kwargs) + component_fn = kwargs.get("component_fn") + + # Set the props. + props_types = typing.get_type_hints(component_fn) + props = {key: value for key, value in kwargs.items() if key in props_types} + kwargs = {key: value for key, value in kwargs.items() if key not in props_types} + + super().__init__( + **kwargs, + ) + + to_camel_cased_props = {format.to_camel_case(key) for key in props} + self.get_props = lambda: to_camel_cased_props # pyright: ignore [reportIncompatibleVariableOverride] # Unset the style. self.style = Style() @@ -1635,15 +1654,15 @@ class CustomComponent(Component): # Get the event triggers defined in the component declaration. event_triggers_in_component_declaration = self.get_event_triggers() - # Set the props. - props = typing.get_type_hints(self.component_fn) - for key, value in kwargs.items(): + for key, value in props.items(): # Skip kwargs that are not props. - if key not in props: + if key not in props_types: continue + camel_cased_key = format.to_camel_case(key) + # Get the type based on the annotation. - type_ = props[key] + type_ = props_types[key] # Handle event chains. if types._issubclass(type_, EventChain): @@ -1654,29 +1673,14 @@ class CustomComponent(Component): ), key=key, ) - self.props[format.to_camel_case(key)] = value + self.props[camel_cased_key] = value continue - # Handle subclasses of Base. - if isinstance(value, Base): - base_value = LiteralVar.create(value) - - # Track hooks and imports associated with Component instances. - if base_value is not None and isinstance(value, Component): - self.component_props[key] = value - value = base_value._replace( - merge_var_data=VarData( - imports=value._get_all_imports(), - hooks=value._get_all_hooks(), - ) - ) - else: - value = base_value - else: - value = LiteralVar.create(value) + value = LiteralVar.create(value) # Set the prop. - self.props[format.to_camel_case(key)] = value + self.props[camel_cased_key] = value + setattr(self, camel_cased_key, value) def __eq__(self, other: Any) -> bool: """Check if the component is equal to another. @@ -1698,7 +1702,7 @@ class CustomComponent(Component): return hash(self.tag) @classmethod - def get_props(cls) -> Set[str]: # pyright: ignore [reportIncompatibleVariableOverride] + def get_props(cls) -> Set[str]: """Get the props for the component. Returns: @@ -1735,27 +1739,8 @@ class CustomComponent(Component): seen=seen ) - # Fetch custom components from props as well. - for child_component in self.component_props.values(): - if child_component.tag is None: - continue - if child_component.tag not in seen: - seen.add(child_component.tag) - if isinstance(child_component, CustomComponent): - custom_components |= {child_component} - custom_components |= child_component._get_all_custom_components( - seen=seen - ) return custom_components - def _render(self) -> Tag: - """Define how to render the component in React. - - Returns: - The tag to render. - """ - return super()._render(props=self.props) - def get_prop_vars(self) -> List[Var]: """Get the prop vars. @@ -1770,24 +1755,6 @@ class CustomComponent(Component): for name, prop in self.props.items() ] - def _get_vars( - self, include_children: bool = False, ignore_ids: set[int] | None = None - ) -> Iterator[Var]: - """Walk all Vars used in this component. - - Args: - include_children: Whether to include Vars from children. - ignore_ids: The ids to ignore. - - Yields: - Each var referenced by the component (props, styles, event handlers). - """ - ignore_ids = ignore_ids or set() - yield from super()._get_vars( - include_children=include_children, ignore_ids=ignore_ids - ) - yield from filter(lambda prop: isinstance(prop, Var), self.props.values()) - @lru_cache(maxsize=None) # noqa: B019 def get_component(self) -> Component: """Render the component. @@ -2475,6 +2442,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): The VarData for the var. """ return VarData.merge( + self._var_data, VarData( imports={ "@emotion/react": [ @@ -2517,9 +2485,21 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): Returns: The var. """ + var_datas = [ + var_data + for var in value._get_vars(include_children=True) + if (var_data := var._get_all_var_data()) + ] + return LiteralComponentVar( _js_expr="", _var_type=type(value), - _var_data=_var_data, + _var_data=VarData.merge( + _var_data, + *var_datas, + VarData( + components=(value,), + ), + ), _var_value=value, ) diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 6f9110a16..a76a8b800 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -61,14 +61,6 @@ class Cond(MemoizationLeaf): ) ) - def _get_props_imports(self): - """Get the imports needed for component's props. - - Returns: - The imports for the component's props of the component. - """ - return [] - def _render(self) -> Tag: return CondTag( cond=self.cond, diff --git a/reflex/vars/base.py b/reflex/vars/base.py index a6786b18a..c01bf20f7 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -73,6 +73,7 @@ from reflex.utils.types import ( ) if TYPE_CHECKING: + from reflex.components.component import BaseComponent from reflex.state import BaseState from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar @@ -129,6 +130,9 @@ class VarData: # Position of the hook in the component position: Hooks.HookPosition | None = None + # Components that are part of this var + components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple) + def __init__( self, state: str = "", @@ -137,6 +141,7 @@ class VarData: hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, + components: Iterable[BaseComponent] | None = None, ): """Initialize the var data. @@ -147,6 +152,7 @@ class VarData: hooks: Hooks that need to be present in the component to render this var. deps: Dependencies of the var for useCallback. position: Position of the hook in the component. + components: Components that are part of this var. """ if isinstance(hooks, str): hooks = [hooks] @@ -161,6 +167,7 @@ class VarData: object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "position", position or None) + object.__setattr__(self, "components", tuple(components or [])) if hooks and any(hooks.values()): merged_var_data = VarData.merge(self, *hooks.values()) @@ -171,6 +178,7 @@ class VarData: object.__setattr__(self, "hooks", merged_var_data.hooks) object.__setattr__(self, "deps", merged_var_data.deps) object.__setattr__(self, "position", merged_var_data.position) + object.__setattr__(self, "components", merged_var_data.components) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -239,17 +247,19 @@ class VarData: else: position = None - if state or _imports or hooks or field_name or deps or position: - return VarData( - state=state, - field_name=field_name, - imports=_imports, - hooks=hooks, - deps=deps, - position=position, - ) + components = tuple( + component for var_data in all_var_datas for component in var_data.components + ) - return None + return VarData( + state=state, + field_name=field_name, + imports=_imports, + hooks=hooks, + deps=deps, + position=position, + components=components, + ) def __bool__(self) -> bool: """Check if the var data is non-empty. @@ -264,6 +274,7 @@ class VarData: or self.field_name or self.deps or self.position + or self.components ) @classmethod diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 8cffa6e0e..377f98f97 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -871,7 +871,7 @@ def test_create_custom_component(my_component): """ component = CustomComponent(component_fn=my_component, prop1="test", prop2=1) assert component.tag == "MyComponent" - assert component.get_props() == set() + assert component.get_props() == {"prop1", "prop2"} assert component._get_all_custom_components() == {component}