diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 8b4d7b216..d2c5a1f6c 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -4,9 +4,10 @@ from __future__ import annotations from typing import Any, Iterator -from reflex.components.component import Component +from reflex.components.component import Component, LiteralComponentVar from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless +from reflex.utils.imports import ParsedImportDict from reflex.vars import BooleanVar, ObjectVar, Var @@ -31,6 +32,72 @@ class Bare(Component): contents = str(contents) if contents is not None else "" return cls(contents=contents) # type: ignore + def _get_all_hooks_internal(self) -> dict[str, None]: + """Include the hooks for the component. + + Returns: + The hooks for the component. + """ + hooks = super()._get_all_hooks_internal() + if isinstance(self.contents, LiteralComponentVar): + hooks |= self.contents._var_value._get_all_hooks_internal() + return hooks + + def _get_all_hooks(self) -> dict[str, None]: + """Include the hooks for the component. + + Returns: + The hooks for the component. + """ + hooks = super()._get_all_hooks() + if isinstance(self.contents, LiteralComponentVar): + hooks |= self.contents._var_value._get_all_hooks() + return hooks + + def _get_all_imports(self) -> ParsedImportDict: + """Include the imports for the component. + + Returns: + The imports for the component. + """ + imports = super()._get_all_imports() + if isinstance(self.contents, LiteralComponentVar): + imports |= self.contents._var_value._get_all_imports() + return imports + + def _get_all_dynamic_imports(self) -> set[str]: + """Get dynamic imports for the component. + + Returns: + The dynamic imports. + """ + dynamic_imports = super()._get_all_dynamic_imports() + if isinstance(self.contents, LiteralComponentVar): + dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() + return dynamic_imports + + def _get_all_custom_code(self) -> set[str]: + """Get custom code for the component. + + Returns: + The custom code. + """ + custom_code = super()._get_all_custom_code() + if isinstance(self.contents, LiteralComponentVar): + custom_code |= self.contents._var_value._get_all_custom_code() + return custom_code + + def _get_all_refs(self) -> set[str]: + """Get the refs for the children of the component. + + Returns: + The refs for the children. + """ + refs = super()._get_all_refs() + if isinstance(self.contents, LiteralComponentVar): + refs |= self.contents._var_value._get_all_refs() + return refs + def _render(self) -> Tag: if isinstance(self.contents, Var): if isinstance(self.contents, (BooleanVar, ObjectVar)): diff --git a/reflex/components/component.py b/reflex/components/component.py index 607598ddf..f85fb8ee4 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -65,7 +65,8 @@ from reflex.vars.base import ( Var, cached_property_no_lock, ) -from reflex.vars.function import FunctionStringVar +from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar +from reflex.vars.number import ternary_operation from reflex.vars.object import ObjectVar from reflex.vars.sequence import LiteralArrayVar @@ -2350,7 +2351,7 @@ class MemoizationLeaf(Component): load_dynamic_serializer() -class ComponentVar(Var[Component], python_types=Component): +class ComponentVar(Var[Component], python_types=BaseComponent): """A Var that represents a Component.""" @@ -2365,15 +2366,68 @@ def empty_component() -> Component: return Bare.create("") -def render_dict_to_var(tag: dict) -> Var: +def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) -> Var: """Convert a render dict to a Var. Args: tag: The render dict. + imported_names: The names of the imported components. Returns: The Var. """ + if not isinstance(tag, dict): + if isinstance(tag, Component): + return render_dict_to_var(tag.render(), imported_names) + return Var.create(tag) + + if "iterable" in tag: + function_return = Var.create( + [ + render_dict_to_var(child.render(), imported_names) + for child in tag["children"] + ] + ) + + func = ArgsFunctionOperation.create( + (tag["arg_var_name"], tag["index_var_name"]), + function_return, + ) + + return FunctionStringVar.create("Array.prototype.map.call").call( + tag["iterable"] + if not isinstance(tag["iterable"], ObjectVar) + else tag["iterable"].items(), + func, + ) + + if tag["name"] == "match": + element = tag["cond"] + + conditionals = tag["default"] + + for case in tag["match_cases"][::-1]: + condition = case[0].to_string() == element.to_string() + for pattern in case[1:-1]: + condition = condition | (pattern.to_string() == element.to_string()) + + conditionals = ternary_operation( + condition, + case[-1], + conditionals, + ) + + return conditionals + + if "cond" in tag: + return ternary_operation( + tag["cond"], + render_dict_to_var(tag["true_value"], imported_names), + render_dict_to_var(tag["false_value"], imported_names) + if tag["false_value"] is not None + else Var.create(None), + ) + props = {} special_props = [] @@ -2394,7 +2448,16 @@ def render_dict_to_var(tag: dict) -> Var: contents = tag["contents"][1:-1] if tag["contents"] else None - tag_name = Var(tag.get("name") or "Fragment") + raw_tag_name = tag.get("name") + tag_name = Var(raw_tag_name or "Fragment") + + tag_name = ( + Var.create(raw_tag_name) + if raw_tag_name + and raw_tag_name.split(".")[0] not in imported_names + and raw_tag_name.lower() == raw_tag_name + else tag_name + ) return FunctionStringVar.create( "jsx", @@ -2402,7 +2465,7 @@ def render_dict_to_var(tag: dict) -> Var: tag_name, props, *([Var(contents)] if contents is not None else []), - *[render_dict_to_var(child) for child in tag["children"]], + *[render_dict_to_var(child, imported_names) for child in tag["children"]], ) @@ -2413,7 +2476,7 @@ def render_dict_to_var(tag: dict) -> Var: class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): """A Var that represents a Component.""" - _var_value: Component = dataclasses.field(default_factory=empty_component) + _var_value: BaseComponent = dataclasses.field(default_factory=empty_component) @cached_property_no_lock def _cached_var_name(self) -> str: @@ -2422,7 +2485,13 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): Returns: The name of the var. """ - return str(render_dict_to_var(self._var_value.render())) + var_data = self._get_all_var_data() + if var_data is not None: + # flatten imports + imported_names = {j.alias or j.name for i in var_data.imports for j in i[1]} + else: + imported_names = set() + return str(render_dict_to_var(self._var_value.render(), imported_names)) @cached_property_no_lock def _cached_get_all_var_data(self) -> VarData | None: @@ -2440,7 +2509,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): } ), VarData( - imports=self._var_value._get_all_imports(collapse=True), + imports=self._var_value._get_all_imports(), ), *( [ @@ -2463,7 +2532,7 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__,)) + return hash((self.__class__.__name__, self._js_expr)) @classmethod def create(