improve behavior

This commit is contained in:
Khaleel Al-Adhami 2024-10-23 14:57:33 -07:00
parent 49ab17d6d6
commit 56b703d2de
2 changed files with 146 additions and 10 deletions

View File

@ -4,9 +4,10 @@ from __future__ import annotations
from typing import Any, Iterator
from reflex.components.component import Component
from reflex.components.component import Component, LiteralComponentVar
from reflex.components.tags import Tag
from reflex.components.tags.tagless import Tagless
from reflex.utils.imports import ParsedImportDict
from reflex.vars import BooleanVar, ObjectVar, Var
@ -31,6 +32,72 @@ class Bare(Component):
contents = str(contents) if contents is not None else ""
return cls(contents=contents) # type: ignore
def _get_all_hooks_internal(self) -> dict[str, None]:
"""Include the hooks for the component.
Returns:
The hooks for the component.
"""
hooks = super()._get_all_hooks_internal()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks_internal()
return hooks
def _get_all_hooks(self) -> dict[str, None]:
"""Include the hooks for the component.
Returns:
The hooks for the component.
"""
hooks = super()._get_all_hooks()
if isinstance(self.contents, LiteralComponentVar):
hooks |= self.contents._var_value._get_all_hooks()
return hooks
def _get_all_imports(self) -> ParsedImportDict:
"""Include the imports for the component.
Returns:
The imports for the component.
"""
imports = super()._get_all_imports()
if isinstance(self.contents, LiteralComponentVar):
imports |= self.contents._var_value._get_all_imports()
return imports
def _get_all_dynamic_imports(self) -> set[str]:
"""Get dynamic imports for the component.
Returns:
The dynamic imports.
"""
dynamic_imports = super()._get_all_dynamic_imports()
if isinstance(self.contents, LiteralComponentVar):
dynamic_imports |= self.contents._var_value._get_all_dynamic_imports()
return dynamic_imports
def _get_all_custom_code(self) -> set[str]:
"""Get custom code for the component.
Returns:
The custom code.
"""
custom_code = super()._get_all_custom_code()
if isinstance(self.contents, LiteralComponentVar):
custom_code |= self.contents._var_value._get_all_custom_code()
return custom_code
def _get_all_refs(self) -> set[str]:
"""Get the refs for the children of the component.
Returns:
The refs for the children.
"""
refs = super()._get_all_refs()
if isinstance(self.contents, LiteralComponentVar):
refs |= self.contents._var_value._get_all_refs()
return refs
def _render(self) -> Tag:
if isinstance(self.contents, Var):
if isinstance(self.contents, (BooleanVar, ObjectVar)):

View File

@ -65,7 +65,8 @@ from reflex.vars.base import (
Var,
cached_property_no_lock,
)
from reflex.vars.function import FunctionStringVar
from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar
from reflex.vars.number import ternary_operation
from reflex.vars.object import ObjectVar
from reflex.vars.sequence import LiteralArrayVar
@ -2350,7 +2351,7 @@ class MemoizationLeaf(Component):
load_dynamic_serializer()
class ComponentVar(Var[Component], python_types=Component):
class ComponentVar(Var[Component], python_types=BaseComponent):
"""A Var that represents a Component."""
@ -2365,15 +2366,68 @@ def empty_component() -> Component:
return Bare.create("")
def render_dict_to_var(tag: dict) -> Var:
def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) -> Var:
"""Convert a render dict to a Var.
Args:
tag: The render dict.
imported_names: The names of the imported components.
Returns:
The Var.
"""
if not isinstance(tag, dict):
if isinstance(tag, Component):
return render_dict_to_var(tag.render(), imported_names)
return Var.create(tag)
if "iterable" in tag:
function_return = Var.create(
[
render_dict_to_var(child.render(), imported_names)
for child in tag["children"]
]
)
func = ArgsFunctionOperation.create(
(tag["arg_var_name"], tag["index_var_name"]),
function_return,
)
return FunctionStringVar.create("Array.prototype.map.call").call(
tag["iterable"]
if not isinstance(tag["iterable"], ObjectVar)
else tag["iterable"].items(),
func,
)
if tag["name"] == "match":
element = tag["cond"]
conditionals = tag["default"]
for case in tag["match_cases"][::-1]:
condition = case[0].to_string() == element.to_string()
for pattern in case[1:-1]:
condition = condition | (pattern.to_string() == element.to_string())
conditionals = ternary_operation(
condition,
case[-1],
conditionals,
)
return conditionals
if "cond" in tag:
return ternary_operation(
tag["cond"],
render_dict_to_var(tag["true_value"], imported_names),
render_dict_to_var(tag["false_value"], imported_names)
if tag["false_value"] is not None
else Var.create(None),
)
props = {}
special_props = []
@ -2394,7 +2448,16 @@ def render_dict_to_var(tag: dict) -> Var:
contents = tag["contents"][1:-1] if tag["contents"] else None
tag_name = Var(tag.get("name") or "Fragment")
raw_tag_name = tag.get("name")
tag_name = Var(raw_tag_name or "Fragment")
tag_name = (
Var.create(raw_tag_name)
if raw_tag_name
and raw_tag_name.split(".")[0] not in imported_names
and raw_tag_name.lower() == raw_tag_name
else tag_name
)
return FunctionStringVar.create(
"jsx",
@ -2402,7 +2465,7 @@ def render_dict_to_var(tag: dict) -> Var:
tag_name,
props,
*([Var(contents)] if contents is not None else []),
*[render_dict_to_var(child) for child in tag["children"]],
*[render_dict_to_var(child, imported_names) for child in tag["children"]],
)
@ -2413,7 +2476,7 @@ def render_dict_to_var(tag: dict) -> Var:
class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
"""A Var that represents a Component."""
_var_value: Component = dataclasses.field(default_factory=empty_component)
_var_value: BaseComponent = dataclasses.field(default_factory=empty_component)
@cached_property_no_lock
def _cached_var_name(self) -> str:
@ -2422,7 +2485,13 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
Returns:
The name of the var.
"""
return str(render_dict_to_var(self._var_value.render()))
var_data = self._get_all_var_data()
if var_data is not None:
# flatten imports
imported_names = {j.alias or j.name for i in var_data.imports for j in i[1]}
else:
imported_names = set()
return str(render_dict_to_var(self._var_value.render(), imported_names))
@cached_property_no_lock
def _cached_get_all_var_data(self) -> VarData | None:
@ -2440,7 +2509,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
}
),
VarData(
imports=self._var_value._get_all_imports(collapse=True),
imports=self._var_value._get_all_imports(),
),
*(
[
@ -2463,7 +2532,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__,))
return hash((self.__class__.__name__, self._js_expr))
@classmethod
def create(