From 19d8f6c7522a48e54f6247ad36426fdfe96e0f32 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 May 2024 14:46:27 -0700 Subject: [PATCH] [REF-2523] Implement new public Component API (#3203) --- reflex/__init__.py | 1 + reflex/__init__.pyi | 1 + reflex/components/component.py | 141 ++++++++++++++++++++ tests/components/test_component.py | 205 +++++++++++++++++++++++++++++ 4 files changed, 348 insertions(+) diff --git a/reflex/__init__.py b/reflex/__init__.py index ffb029550..f395ba52b 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -167,6 +167,7 @@ _MAPPING = { "reflex.style": ["style", "toggle_color_mode"], "reflex.testing": ["testing"], "reflex.utils": ["utils"], + "reflex.utils.imports": ["ImportVar"], "reflex.vars": ["vars", "cached_var", "Var"], } diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index ec89a4fde..76c47f5b1 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -150,6 +150,7 @@ from reflex import style as style from reflex.style import toggle_color_mode as toggle_color_mode from reflex import testing as testing from reflex import utils as utils +from reflex.utils.imports import ImportVar as ImportVar from reflex import vars as vars from reflex.vars import cached_var as cached_var from reflex.vars import Var as Var diff --git a/reflex/components/component.py b/reflex/components/component.py index 013181a58..0f387b1d5 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -213,6 +213,91 @@ class Component(BaseComponent, ABC): # State class associated with this component instance State: Optional[Type[reflex.state.State]] = None + def add_imports(self) -> dict[str, str | ImportVar | list[str | ImportVar]]: + """Add imports for the component. + + This method should be implemented by subclasses to add new imports for the component. + + Implementations do NOT need to call super(). The result of calling + add_imports in each parent class will be merged internally. + + Returns: + The additional imports for this component subclass. + + The format of the return value is a dictionary where the keys are the + library names (with optional npm-style version specifications) mapping + to a single name to be imported, or a list names to be imported. + + For advanced use cases, the values can be ImportVar instances (for + example, to provide an alias or mark that an import is the default + export from the given library). + + ```python + return { + "react": "useEffect", + "react-draggable": ["DraggableCore", rx.ImportVar(tag="Draggable", is_default=True)], + } + ``` + """ + return {} + + def add_hooks(self) -> list[str]: + """Add hooks inside the component function. + + Hooks are pieces of literal Javascript code that is inserted inside the + React component function. + + Each logical hook should be a separate string in the list. + + Common strings will be deduplicated and inserted into the component + function only once, so define const variables and other identical code + in their own strings to avoid defining the same const or hook multiple + times. + + If a hook depends on specific data from the component instance, be sure + to use unique values inside the string to _avoid_ deduplication. + + Implementations do NOT need to call super(). The result of calling + add_hooks in each parent class will be merged and deduplicated internally. + + Returns: + The additional hooks for this component subclass. + + ```python + return [ + "const [count, setCount] = useState(0);", + "useEffect(() => { setCount((prev) => prev + 1); console.log(`mounted ${count} times`); }, []);", + ] + ``` + """ + return [] + + def add_custom_code(self) -> list[str]: + """Add custom Javascript code into the page that contains this component. + + Custom code is inserted at module level, after any imports. + + Each string of custom code is deduplicated per-page, so take care to + avoid defining the same const or function differently from different + component instances. + + Custom code is useful for defining global functions or constants which + can then be referenced inside hooks or used by component vars. + + Implementations do NOT need to call super(). The result of calling + add_custom_code in each parent class will be merged and deduplicated internally. + + Returns: + The additional custom code for this component subclass. + + ```python + return [ + "const translatePoints = (event) => { return { x: event.clientX, y: event.clientY }; };", + ] + ``` + """ + return [] + @classmethod def __init_subclass__(cls, **kwargs): """Set default properties. @@ -949,6 +1034,30 @@ class Component(BaseComponent, ABC): return True return False + @classmethod + def _iter_parent_classes_with_method(cls, method: str) -> Iterator[Type[Component]]: + """Iterate through parent classes that define a given method. + + Used for handling the `add_*` API functions that internally simulate a super() call chain. + + Args: + method: The method to look for. + + Yields: + The parent classes that define the method (differently than the base). + """ + seen_methods = set([getattr(Component, method)]) + for clz in cls.mro(): + if clz is Component: + break + if not issubclass(clz, Component): + continue + method_func = getattr(clz, method, None) + if not callable(method_func) or method_func in seen_methods: + continue + seen_methods.add(method_func) + yield clz + def _get_custom_code(self) -> str | None: """Get custom code for the component. @@ -971,6 +1080,11 @@ class Component(BaseComponent, ABC): if custom_code is not None: code.add(custom_code) + # Add the custom code from add_custom_code method. + for clz in self._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(self): + code.add(item) + # Add the custom code for the children. for child in self.children: code |= child._get_all_custom_code() @@ -1106,6 +1220,26 @@ class Component(BaseComponent, ABC): var._var_data.imports for var in self._get_vars() if var._var_data ] + # If any subclass implements add_imports, merge the imports. + def _make_list( + value: str | ImportVar | list[str | ImportVar], + ) -> list[str | ImportVar]: + if isinstance(value, (str, ImportVar)): + return [value] + return value + + _added_import_dicts = [] + for clz in self._iter_parent_classes_with_method("add_imports"): + _added_import_dicts.append( + { + package: [ + ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag + for tag in _make_list(maybe_tags) + ] + for package, maybe_tags in clz.add_imports(self).items() + } + ) + return imports.merge_imports( *self._get_props_imports(), self._get_dependencies_imports(), @@ -1113,6 +1247,7 @@ class Component(BaseComponent, ABC): _imports, event_imports, *var_imports, + *_added_import_dicts, ) def _get_all_imports(self, collapse: bool = False) -> imports.ImportDict: @@ -1248,6 +1383,12 @@ class Component(BaseComponent, ABC): if hooks is not None: code[hooks] = None + # Add the hook code from add_hooks for each parent class (this is reversed to preserve + # the order of the hooks in the final output) + for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))): + for hook in clz.add_hooks(self): + code[hook] = None + # Add the hook code for the children. for child in self.children: code = {**code, **child._get_all_hooks()} diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 15ceee7e4..96c1b6962 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1746,3 +1746,208 @@ def test_invalid_event_trigger(): with pytest.raises(ValueError): trigger_comp(on_b=rx.console_log("log")) + + +@pytest.mark.parametrize( + "tags", + ( + ["Component"], + ["Component", "useState"], + [ImportVar(tag="Component")], + [ImportVar(tag="Component"), ImportVar(tag="useState")], + ["Component", ImportVar(tag="useState")], + ), +) +def test_component_add_imports(tags): + def _list_to_import_vars(tags: List[str]) -> List[ImportVar]: + return [ + ImportVar(tag=tag) if not isinstance(tag, ImportVar) else tag + for tag in tags + ] + + class BaseComponent(Component): + def _get_imports(self) -> imports.ImportDict: + return {} + + class Reference(Component): + def _get_imports(self) -> imports.ImportDict: + return imports.merge_imports( + super()._get_imports(), + {"react": _list_to_import_vars(tags)}, + {"foo": [ImportVar(tag="bar")]}, + ) + + class TestBase(Component): + def add_imports( + self, + ) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]: + return {"foo": "bar"} + + class Test(TestBase): + def add_imports( + self, + ) -> Dict[str, Union[str, ImportVar, List[str], List[ImportVar]]]: + return {"react": (tags[0] if len(tags) == 1 else tags)} + + baseline = Reference.create() + test = Test.create() + + assert baseline._get_all_imports() == { + "react": _list_to_import_vars(tags), + "foo": [ImportVar(tag="bar")], + } + assert test._get_all_imports() == baseline._get_all_imports() + + +def test_component_add_hooks(): + class BaseComponent(Component): + def _get_hooks(self): + return "const hook1 = 42" + + class ChildComponent1(BaseComponent): + pass + + class GrandchildComponent1(ChildComponent1): + def add_hooks(self): + return [ + "const hook2 = 43", + "const hook3 = 44", + ] + + class GreatGrandchildComponent1(GrandchildComponent1): + def add_hooks(self): + return [ + "const hook4 = 45", + ] + + class GrandchildComponent2(ChildComponent1): + def _get_hooks(self): + return "const hook5 = 46" + + class GreatGrandchildComponent2(GrandchildComponent2): + def add_hooks(self): + return [ + "const hook2 = 43", + "const hook6 = 47", + ] + + assert list(BaseComponent()._get_all_hooks()) == ["const hook1 = 42"] + assert list(ChildComponent1()._get_all_hooks()) == ["const hook1 = 42"] + assert list(GrandchildComponent1()._get_all_hooks()) == [ + "const hook1 = 42", + "const hook2 = 43", + "const hook3 = 44", + ] + assert list(GreatGrandchildComponent1()._get_all_hooks()) == [ + "const hook1 = 42", + "const hook2 = 43", + "const hook3 = 44", + "const hook4 = 45", + ] + assert list(GrandchildComponent2()._get_all_hooks()) == ["const hook5 = 46"] + assert list(GreatGrandchildComponent2()._get_all_hooks()) == [ + "const hook5 = 46", + "const hook2 = 43", + "const hook6 = 47", + ] + assert list( + BaseComponent.create( + GrandchildComponent1.create(GreatGrandchildComponent2()), + GreatGrandchildComponent1(), + )._get_all_hooks(), + ) == [ + "const hook1 = 42", + "const hook2 = 43", + "const hook3 = 44", + "const hook5 = 46", + "const hook6 = 47", + "const hook4 = 45", + ] + assert list( + Fragment.create( + GreatGrandchildComponent2(), + GreatGrandchildComponent1(), + )._get_all_hooks() + ) == [ + "const hook5 = 46", + "const hook2 = 43", + "const hook6 = 47", + "const hook1 = 42", + "const hook3 = 44", + "const hook4 = 45", + ] + + +def test_component_add_custom_code(): + class BaseComponent(Component): + def _get_custom_code(self): + return "const custom_code1 = 42" + + class ChildComponent1(BaseComponent): + pass + + class GrandchildComponent1(ChildComponent1): + def add_custom_code(self): + return [ + "const custom_code2 = 43", + "const custom_code3 = 44", + ] + + class GreatGrandchildComponent1(GrandchildComponent1): + def add_custom_code(self): + return [ + "const custom_code4 = 45", + ] + + class GrandchildComponent2(ChildComponent1): + def _get_custom_code(self): + return "const custom_code5 = 46" + + class GreatGrandchildComponent2(GrandchildComponent2): + def add_custom_code(self): + return [ + "const custom_code2 = 43", + "const custom_code6 = 47", + ] + + assert BaseComponent()._get_all_custom_code() == {"const custom_code1 = 42"} + assert ChildComponent1()._get_all_custom_code() == {"const custom_code1 = 42"} + assert GrandchildComponent1()._get_all_custom_code() == { + "const custom_code1 = 42", + "const custom_code2 = 43", + "const custom_code3 = 44", + } + assert GreatGrandchildComponent1()._get_all_custom_code() == { + "const custom_code1 = 42", + "const custom_code2 = 43", + "const custom_code3 = 44", + "const custom_code4 = 45", + } + assert GrandchildComponent2()._get_all_custom_code() == {"const custom_code5 = 46"} + assert GreatGrandchildComponent2()._get_all_custom_code() == { + "const custom_code2 = 43", + "const custom_code5 = 46", + "const custom_code6 = 47", + } + assert BaseComponent.create( + GrandchildComponent1.create(GreatGrandchildComponent2()), + GreatGrandchildComponent1(), + )._get_all_custom_code() == { + "const custom_code1 = 42", + "const custom_code2 = 43", + "const custom_code3 = 44", + "const custom_code4 = 45", + "const custom_code5 = 46", + "const custom_code6 = 47", + } + assert Fragment.create( + GreatGrandchildComponent2(), + GreatGrandchildComponent1(), + )._get_all_custom_code() == { + "const custom_code1 = 42", + "const custom_code2 = 43", + "const custom_code3 = 44", + "const custom_code4 = 45", + "const custom_code5 = 46", + "const custom_code6 = 47", + }