From 702670ff26dfcf969d81b126cb9a4f25993a1f0c Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 15 Nov 2024 15:47:46 -0800 Subject: [PATCH] add components to var data --- reflex/components/base/bare.py | 39 +++++++++++++++++++++++----------- reflex/components/component.py | 13 +++++------- reflex/vars/base.py | 29 +++++++++++++++++-------- 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index c70b4c844..4758181e3 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import Any, Iterator -from reflex.components.component import Component, LiteralComponentVar +from reflex.components.component import Component from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless from reflex.utils.imports import ParsedImportDict @@ -39,8 +39,11 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks_internal() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks_internal() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + hooks |= component._get_all_hooks_internal() return hooks def _get_all_hooks(self) -> dict[str, None]: @@ -50,8 +53,11 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + hooks |= component._get_all_hooks() return hooks def _get_all_imports(self) -> ParsedImportDict: @@ -61,7 +67,7 @@ class Bare(Component): The imports for the component. """ imports = super()._get_all_imports() - if isinstance(self.contents, LiteralComponentVar): + if isinstance(self.contents, Var): var_data = self.contents._get_all_var_data() if var_data: imports |= {k: list(v) for k, v in var_data.imports} @@ -74,8 +80,11 @@ class Bare(Component): The dynamic imports. """ dynamic_imports = super()._get_all_dynamic_imports() - if isinstance(self.contents, LiteralComponentVar): - dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + dynamic_imports |= component._get_all_dynamic_imports() return dynamic_imports def _get_all_custom_code(self) -> set[str]: @@ -85,8 +94,11 @@ class Bare(Component): The custom code. """ custom_code = super()._get_all_custom_code() - if isinstance(self.contents, LiteralComponentVar): - custom_code |= self.contents._var_value._get_all_custom_code() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + custom_code |= component._get_all_custom_code() return custom_code def _get_all_refs(self) -> set[str]: @@ -96,8 +108,11 @@ class Bare(Component): The refs for the children. """ refs = super()._get_all_refs() - if isinstance(self.contents, LiteralComponentVar): - refs |= self.contents._var_value._get_all_refs() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + refs |= component._get_all_refs() return refs def _render(self) -> Tag: diff --git a/reflex/components/component.py b/reflex/components/component.py index face5d557..8cf0001d2 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -2521,17 +2521,14 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): "@emotion/react": [ ImportVar(tag="jsx"), ], - } - ), - VarData( - imports=self._var_value._get_all_imports(), - ), - VarData( - imports={ "react": [ ImportVar(tag="Fragment"), ], - } + }, + components=(self._var_value,), + ), + VarData( + imports=self._var_value._get_all_imports(), ), ) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 8945d130b..c8aa99731 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -76,6 +76,7 @@ from reflex.utils.types import ( ) if TYPE_CHECKING: + from reflex.components.component import BaseComponent from reflex.state import BaseState from .function import ArgsFunctionOperation @@ -166,12 +167,16 @@ class VarData: # Hooks that need to be present in the component to render this var hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + # Components that need to be present in the component to render this var + components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple) + def __init__( self, state: str = "", field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: dict[str, None] | None = None, + components: Iterable[BaseComponent] | None = None, ): """Initialize the var data. @@ -180,6 +185,7 @@ class VarData: field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. + components: Components that need to be present in the component to render this var. """ immutable_imports: ImmutableParsedImportDict = tuple( sorted( @@ -190,6 +196,9 @@ class VarData: object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) + object.__setattr__( + self, "components", tuple(components) if components is not None else tuple() + ) def old_school_imports(self) -> ImportDict: """Return the imports as a mutable dict. @@ -235,15 +244,17 @@ class VarData: *(var_data.imports for var_data in all_var_datas) ) - if state or _imports or hooks or field_name: - return VarData( - state=state, - field_name=field_name, - imports=_imports, - hooks=hooks, - ) + components = tuple( + component for var_data in all_var_datas for component in var_data.components + ) - return None + return VarData( + state=state, + field_name=field_name, + imports=_imports, + hooks=hooks, + components=components, + ) def __bool__(self) -> bool: """Check if the var data is non-empty. @@ -251,7 +262,7 @@ class VarData: Returns: True if any field is set to a non-default value. """ - return bool(self.state or self.imports or self.hooks or self.field_name) + return any(getattr(self, field.name) for field in dataclasses.fields(self)) @classmethod def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: