add components to var data

This commit is contained in:
Khaleel Al-Adhami 2024-11-15 15:47:46 -08:00
parent 88cfb3b7e2
commit 702670ff26
3 changed files with 52 additions and 29 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any, Iterator from typing import Any, Iterator
from reflex.components.component import Component, LiteralComponentVar from reflex.components.component import Component
from reflex.components.tags import Tag from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict from reflex.utils.imports import ParsedImportDict
@ -39,8 +39,11 @@ class Bare(Component):
The hooks for the component. The hooks for the component.
""" """
hooks = super()._get_all_hooks_internal() hooks = super()._get_all_hooks_internal()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
hooks |= self.contents._var_value._get_all_hooks_internal() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks_internal()
return hooks return hooks
def _get_all_hooks(self) -> dict[str, None]: def _get_all_hooks(self) -> dict[str, None]:
@ -50,8 +53,11 @@ class Bare(Component):
The hooks for the component. The hooks for the component.
""" """
hooks = super()._get_all_hooks() hooks = super()._get_all_hooks()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
hooks |= self.contents._var_value._get_all_hooks() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks()
return hooks return hooks
def _get_all_imports(self) -> ParsedImportDict: def _get_all_imports(self) -> ParsedImportDict:
@ -61,7 +67,7 @@ class Bare(Component):
The imports for the component. The imports for the component.
""" """
imports = super()._get_all_imports() imports = super()._get_all_imports()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data() var_data = self.contents._get_all_var_data()
if var_data: if var_data:
imports |= {k: list(v) for k, v in var_data.imports} imports |= {k: list(v) for k, v in var_data.imports}
@ -74,8 +80,11 @@ class Bare(Component):
The dynamic imports. The dynamic imports.
""" """
dynamic_imports = super()._get_all_dynamic_imports() dynamic_imports = super()._get_all_dynamic_imports()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
dynamic_imports |= component._get_all_dynamic_imports()
return dynamic_imports return dynamic_imports
def _get_all_custom_code(self) -> set[str]: def _get_all_custom_code(self) -> set[str]:
@ -85,8 +94,11 @@ class Bare(Component):
The custom code. The custom code.
""" """
custom_code = super()._get_all_custom_code() custom_code = super()._get_all_custom_code()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
custom_code |= self.contents._var_value._get_all_custom_code() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
custom_code |= component._get_all_custom_code()
return custom_code return custom_code
def _get_all_refs(self) -> set[str]: def _get_all_refs(self) -> set[str]:
@ -96,8 +108,11 @@ class Bare(Component):
The refs for the children. The refs for the children.
""" """
refs = super()._get_all_refs() refs = super()._get_all_refs()
if isinstance(self.contents, LiteralComponentVar): if isinstance(self.contents, Var):
refs |= self.contents._var_value._get_all_refs() var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
refs |= component._get_all_refs()
return refs return refs
def _render(self) -> Tag: def _render(self) -> Tag:

View File

@ -2521,17 +2521,14 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
"@emotion/react": [ "@emotion/react": [
ImportVar(tag="jsx"), ImportVar(tag="jsx"),
], ],
}
),
VarData(
imports=self._var_value._get_all_imports(),
),
VarData(
imports={
"react": [ "react": [
ImportVar(tag="Fragment"), ImportVar(tag="Fragment"),
], ],
} },
components=(self._var_value,),
),
VarData(
imports=self._var_value._get_all_imports(),
), ),
) )

View File

@ -76,6 +76,7 @@ from reflex.utils.types import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.components.component import BaseComponent
from reflex.state import BaseState from reflex.state import BaseState
from .function import ArgsFunctionOperation from .function import ArgsFunctionOperation
@ -166,12 +167,16 @@ class VarData:
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Components that need to be present in the component to render this var
components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
def __init__( def __init__(
self, self,
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None, hooks: dict[str, None] | None = None,
components: Iterable[BaseComponent] | None = None,
): ):
"""Initialize the var data. """Initialize the var data.
@ -180,6 +185,7 @@ class VarData:
field_name: The name of the field in the state. field_name: The name of the field in the state.
imports: Imports needed to render this var. imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var. hooks: Hooks that need to be present in the component to render this var.
components: Components that need to be present in the component to render this var.
""" """
immutable_imports: ImmutableParsedImportDict = tuple( immutable_imports: ImmutableParsedImportDict = tuple(
sorted( sorted(
@ -190,6 +196,9 @@ class VarData:
object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {})) object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(
self, "components", tuple(components) if components is not None else tuple()
)
def old_school_imports(self) -> ImportDict: def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict. """Return the imports as a mutable dict.
@ -235,15 +244,17 @@ class VarData:
*(var_data.imports for var_data in all_var_datas) *(var_data.imports for var_data in all_var_datas)
) )
if state or _imports or hooks or field_name: components = tuple(
return VarData( component for var_data in all_var_datas for component in var_data.components
state=state, )
field_name=field_name,
imports=_imports,
hooks=hooks,
)
return None return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
components=components,
)
def __bool__(self) -> bool: def __bool__(self) -> bool:
"""Check if the var data is non-empty. """Check if the var data is non-empty.
@ -251,7 +262,7 @@ class VarData:
Returns: Returns:
True if any field is set to a non-default value. True if any field is set to a non-default value.
""" """
return bool(self.state or self.imports or self.hooks or self.field_name) return any(getattr(self, field.name) for field in dataclasses.fields(self))
@classmethod @classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: