diff --git a/reflex/app.py b/reflex/app.py index ae6faa2a3..a43fbecad 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -875,6 +875,7 @@ class App(Base): with executor: result_futures = [] + custom_components_future = None def _mark_complete(_=None): progress.advance(task) @@ -892,7 +893,10 @@ class App(Base): _submit_work(ExecutorSafeFunctions.compile_app) # Compile the custom components. - _submit_work(ExecutorSafeFunctions.compile_custom_components) + custom_components_future = executor.submit( + ExecutorSafeFunctions.compile_custom_components, + ) + custom_components_future.add_done_callback(_mark_complete) # Compile the root stylesheet with base styles. _submit_work(compiler.compile_root_stylesheet, self.stylesheets) @@ -913,12 +917,11 @@ class App(Base): for future in concurrent.futures.as_completed(result_futures): compile_results.append(future.result()) - # Get imports from AppWrap components. - all_imports.update(app_root.get_imports()) - - # Iterate through all the custom components and add their imports to the all_imports. - for component in custom_components: - all_imports.update(component.get_imports()) + # Special case for custom_components, since we need the compiled imports + # to install proper frontend packages. + *custom_components_result, custom_components_imports = custom_components_future.result() + compile_results.append(custom_components_result) + all_imports.update(custom_components_imports) progress.advance(task) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index c0b654c93..07d6377a6 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -186,7 +186,9 @@ def _compile_component(component: Component) -> str: return templates.COMPONENT.render(component=component) -def _compile_components(components: set[CustomComponent]) -> str: +def _compile_components( + components: set[CustomComponent], +) -> tuple[str, Dict[str, list[ImportVar]]]: """Compile the components. Args: @@ -208,9 +210,12 @@ def _compile_components(components: set[CustomComponent]) -> str: imports = utils.merge_imports(imports, component_imports) # Compile the components page. - return templates.COMPONENTS.render( - imports=utils.compile_imports(imports), - components=component_renders, + return ( + templates.COMPONENTS.render( + imports=utils.compile_imports(imports), + components=component_renders, + ), + imports, ) @@ -401,7 +406,9 @@ def compile_page( return output_path, code -def compile_components(components: set[CustomComponent]): +def compile_components( + components: set[CustomComponent], +) -> tuple[str, str, Dict[str, list[ImportVar]]]: """Compile the custom components. Args: @@ -414,8 +421,8 @@ def compile_components(components: set[CustomComponent]): output_path = utils.get_components_path() # Compile the components. - code = _compile_components(components) - return output_path, code + code, imports = _compile_components(components) + return output_path, code, imports def compile_stateful_components( diff --git a/reflex/components/component.py b/reflex/components/component.py index 71b870b8c..b085ee352 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1265,6 +1265,9 @@ class CustomComponent(Component): # The props of the component. props: Dict[str, Any] = {} + # Props that reference other components. + component_props: Dict[str, Component] = {} + def __init__(self, *args, **kwargs): """Initialize the custom component. @@ -1296,17 +1299,13 @@ class CustomComponent(Component): self.props[format.to_camel_case(key)] = value continue - # Convert the type to a Var, then get the type of the var. - if not types._issubclass(type_, Var): - type_ = Var[type_] - type_ = types.get_args(type_)[0] - # Handle subclasses of Base. - if types._issubclass(type_, Base): + if isinstance(value, Base): base_value = Var.create(value) # Track hooks and imports associated with Component instances. - if base_value is not None and types._issubclass(type_, Component): + if base_value is not None and isinstance(value, Component): + self.component_props[key] = value value = base_value._replace( merge_var_data=VarData( # type: ignore imports=value.get_imports(), @@ -1373,6 +1372,16 @@ class CustomComponent(Component): custom_components |= self.get_component(self).get_custom_components( seen=seen ) + + # Fetch custom components from props as well. + for child_component in self.component_props.values(): + if child_component.tag is None: + continue + if child_component.tag not in seen: + seen.add(child_component.tag) + if isinstance(child_component, CustomComponent): + custom_components |= {child_component} + custom_components |= child_component.get_custom_components(seen=seen) return custom_components def _render(self) -> Tag: diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 04468b2ba..8dfe6dedc 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -4,6 +4,7 @@ import pytest import reflex as rx from reflex.base import Base +from reflex.compiler.compiler import compile_components from reflex.components.base.bare import Bare from reflex.components.chakra.layout.box import Box from reflex.components.component import ( @@ -1269,3 +1270,41 @@ def test_deprecated_props(capsys): assert "type={`type1`}" in c2_1_render["props"] assert "min={`min1`}" in c2_1_render["props"] assert "max={`max1`}" in c2_1_render["props"] + + +def test_custom_component_get_imports(): + class Inner(Component): + tag = "Inner" + library = "inner" + + class Other(Component): + tag = "Other" + library = "other" + + @rx.memo + def wrapper(): + return Inner.create() + + @rx.memo + def outer(c: Component): + return Other.create(c) + + custom_comp = wrapper() + + # Inner is not imported directly, but it is imported by the custom component. + assert "inner" not in custom_comp.get_imports() + + # The imports are only resolved during compilation. + _, _, imports_inner = compile_components(custom_comp.get_custom_components()) + assert "inner" in imports_inner + + outer_comp = outer(c=wrapper()) + + # Libraries are not imported directly, but are imported by the custom component. + assert "inner" not in outer_comp.get_imports() + assert "other" not in outer_comp.get_imports() + + # The imports are only resolved during compilation. + _, _, imports_outer = compile_components(outer_comp.get_custom_components()) + assert "inner" in imports_outer + assert "other" in imports_outer