Merge branch 'masenf/custom_component_respect_value_not_annotation' into masenf/multiprocess-compile-fx-conflict

This commit is contained in:
Masen Furer 2024-03-14 20:25:20 -07:00
commit 04211a3234
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4
4 changed files with 79 additions and 21 deletions

View File

@ -875,6 +875,7 @@ class App(Base):
with executor:
result_futures = []
custom_components_future = None
def _mark_complete(_=None):
progress.advance(task)
@ -892,7 +893,10 @@ class App(Base):
_submit_work(ExecutorSafeFunctions.compile_app)
# Compile the custom components.
_submit_work(ExecutorSafeFunctions.compile_custom_components)
custom_components_future = executor.submit(
ExecutorSafeFunctions.compile_custom_components,
)
custom_components_future.add_done_callback(_mark_complete)
# Compile the root stylesheet with base styles.
_submit_work(compiler.compile_root_stylesheet, self.stylesheets)
@ -913,12 +917,11 @@ class App(Base):
for future in concurrent.futures.as_completed(result_futures):
compile_results.append(future.result())
# Get imports from AppWrap components.
all_imports.update(app_root.get_imports())
# Iterate through all the custom components and add their imports to the all_imports.
for component in custom_components:
all_imports.update(component.get_imports())
# Special case for custom_components, since we need the compiled imports
# to install proper frontend packages.
*custom_components_result, custom_components_imports = custom_components_future.result()
compile_results.append(custom_components_result)
all_imports.update(custom_components_imports)
progress.advance(task)

View File

@ -186,7 +186,9 @@ def _compile_component(component: Component) -> str:
return templates.COMPONENT.render(component=component)
def _compile_components(components: set[CustomComponent]) -> str:
def _compile_components(
components: set[CustomComponent],
) -> tuple[str, Dict[str, list[ImportVar]]]:
"""Compile the components.
Args:
@ -208,9 +210,12 @@ def _compile_components(components: set[CustomComponent]) -> str:
imports = utils.merge_imports(imports, component_imports)
# Compile the components page.
return templates.COMPONENTS.render(
imports=utils.compile_imports(imports),
components=component_renders,
return (
templates.COMPONENTS.render(
imports=utils.compile_imports(imports),
components=component_renders,
),
imports,
)
@ -401,7 +406,9 @@ def compile_page(
return output_path, code
def compile_components(components: set[CustomComponent]):
def compile_components(
components: set[CustomComponent],
) -> tuple[str, str, Dict[str, list[ImportVar]]]:
"""Compile the custom components.
Args:
@ -414,8 +421,8 @@ def compile_components(components: set[CustomComponent]):
output_path = utils.get_components_path()
# Compile the components.
code = _compile_components(components)
return output_path, code
code, imports = _compile_components(components)
return output_path, code, imports
def compile_stateful_components(

View File

@ -1265,6 +1265,9 @@ class CustomComponent(Component):
# The props of the component.
props: Dict[str, Any] = {}
# Props that reference other components.
component_props: Dict[str, Component] = {}
def __init__(self, *args, **kwargs):
"""Initialize the custom component.
@ -1296,17 +1299,13 @@ class CustomComponent(Component):
self.props[format.to_camel_case(key)] = value
continue
# Convert the type to a Var, then get the type of the var.
if not types._issubclass(type_, Var):
type_ = Var[type_]
type_ = types.get_args(type_)[0]
# Handle subclasses of Base.
if types._issubclass(type_, Base):
if isinstance(value, Base):
base_value = Var.create(value)
# Track hooks and imports associated with Component instances.
if base_value is not None and types._issubclass(type_, Component):
if base_value is not None and isinstance(value, Component):
self.component_props[key] = value
value = base_value._replace(
merge_var_data=VarData( # type: ignore
imports=value.get_imports(),
@ -1373,6 +1372,16 @@ class CustomComponent(Component):
custom_components |= self.get_component(self).get_custom_components(
seen=seen
)
# Fetch custom components from props as well.
for child_component in self.component_props.values():
if child_component.tag is None:
continue
if child_component.tag not in seen:
seen.add(child_component.tag)
if isinstance(child_component, CustomComponent):
custom_components |= {child_component}
custom_components |= child_component.get_custom_components(seen=seen)
return custom_components
def _render(self) -> Tag:

View File

@ -4,6 +4,7 @@ import pytest
import reflex as rx
from reflex.base import Base
from reflex.compiler.compiler import compile_components
from reflex.components.base.bare import Bare
from reflex.components.chakra.layout.box import Box
from reflex.components.component import (
@ -1269,3 +1270,41 @@ def test_deprecated_props(capsys):
assert "type={`type1`}" in c2_1_render["props"]
assert "min={`min1`}" in c2_1_render["props"]
assert "max={`max1`}" in c2_1_render["props"]
def test_custom_component_get_imports():
class Inner(Component):
tag = "Inner"
library = "inner"
class Other(Component):
tag = "Other"
library = "other"
@rx.memo
def wrapper():
return Inner.create()
@rx.memo
def outer(c: Component):
return Other.create(c)
custom_comp = wrapper()
# Inner is not imported directly, but it is imported by the custom component.
assert "inner" not in custom_comp.get_imports()
# The imports are only resolved during compilation.
_, _, imports_inner = compile_components(custom_comp.get_custom_components())
assert "inner" in imports_inner
outer_comp = outer(c=wrapper())
# Libraries are not imported directly, but are imported by the custom component.
assert "inner" not in outer_comp.get_imports()
assert "other" not in outer_comp.get_imports()
# The imports are only resolved during compilation.
_, _, imports_outer = compile_components(outer_comp.get_custom_components())
assert "inner" in imports_outer
assert "other" in imports_outer