add components to var data

This commit is contained in:
Khaleel Al-Adhami 2024-11-15 15:47:46 -08:00
parent 88cfb3b7e2
commit 702670ff26
3 changed files with 52 additions and 29 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from typing import Any, Iterator
from reflex.components.component import Component, LiteralComponentVar
from reflex.components.component import Component
from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict
@ -39,8 +39,11 @@ class Bare(Component):
The hooks for the component.
"""
hooks = super()._get_all_hooks_internal()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks_internal()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks_internal()
return hooks
def _get_all_hooks(self) -> dict[str, None]:
@ -50,8 +53,11 @@ class Bare(Component):
The hooks for the component.
"""
hooks = super()._get_all_hooks()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
hooks |= component._get_all_hooks()
return hooks
def _get_all_imports(self) -> ParsedImportDict:
@ -61,7 +67,7 @@ class Bare(Component):
The imports for the component.
"""
imports = super()._get_all_imports()
if isinstance(self.contents, LiteralComponentVar):
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
imports |= {k: list(v) for k, v in var_data.imports}
@ -74,8 +80,11 @@ class Bare(Component):
The dynamic imports.
"""
dynamic_imports = super()._get_all_dynamic_imports()
if isinstance(self.contents, LiteralComponentVar):
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
dynamic_imports |= component._get_all_dynamic_imports()
return dynamic_imports
def _get_all_custom_code(self) -> set[str]:
@ -85,8 +94,11 @@ class Bare(Component):
The custom code.
"""
custom_code = super()._get_all_custom_code()
if isinstance(self.contents, LiteralComponentVar):
custom_code |= self.contents._var_value._get_all_custom_code()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
custom_code |= component._get_all_custom_code()
return custom_code
def _get_all_refs(self) -> set[str]:
@ -96,8 +108,11 @@ class Bare(Component):
The refs for the children.
"""
refs = super()._get_all_refs()
if isinstance(self.contents, LiteralComponentVar):
refs |= self.contents._var_value._get_all_refs()
if isinstance(self.contents, Var):
var_data = self.contents._get_all_var_data()
if var_data:
for component in var_data.components:
refs |= component._get_all_refs()
return refs
def _render(self) -> Tag:

View File

@ -2521,17 +2521,14 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
"@emotion/react": [
ImportVar(tag="jsx"),
],
}
),
VarData(
imports=self._var_value._get_all_imports(),
),
VarData(
imports={
"react": [
ImportVar(tag="Fragment"),
],
}
},
components=(self._var_value,),
),
VarData(
imports=self._var_value._get_all_imports(),
),
)

View File

@ -76,6 +76,7 @@ from reflex.utils.types import (
)
if TYPE_CHECKING:
from reflex.components.component import BaseComponent
from reflex.state import BaseState
from .function import ArgsFunctionOperation
@ -166,12 +167,16 @@ class VarData:
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
# Components that need to be present in the component to render this var
components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple)
def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
components: Iterable[BaseComponent] | None = None,
):
"""Initialize the var data.
@ -180,6 +185,7 @@ class VarData:
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
components: Components that need to be present in the component to render this var.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
@ -190,6 +196,9 @@ class VarData:
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(
self, "components", tuple(components) if components is not None else tuple()
)
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
@ -235,15 +244,17 @@ class VarData:
*(var_data.imports for var_data in all_var_datas)
)
if state or _imports or hooks or field_name:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
)
components = tuple(
component for var_data in all_var_datas for component in var_data.components
)
return None
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
components=components,
)
def __bool__(self) -> bool:
"""Check if the var data is non-empty.
@ -251,7 +262,7 @@ class VarData:
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks or self.field_name)
return any(getattr(self, field.name) for field in dataclasses.fields(self))
@classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: