diff --git a/pynecone/compiler/compiler.py b/pynecone/compiler/compiler.py index 308c5b3fe..d16b0801c 100644 --- a/pynecone/compiler/compiler.py +++ b/pynecone/compiler/compiler.py @@ -70,6 +70,7 @@ def _compile_page(component: Component, state: Type[State]) -> str: state=utils.compile_state(state), events=utils.compile_events(state), effects=utils.compile_effects(state), + hooks=path_ops.join(component.get_hooks()), render=component.render(), ) diff --git a/pynecone/compiler/templates.py b/pynecone/compiler/templates.py index 39fb89c9b..47e97c9a8 100644 --- a/pynecone/compiler/templates.py +++ b/pynecone/compiler/templates.py @@ -76,6 +76,7 @@ PAGE = path_ops.join( "{state}", "{events}", "{effects}", + "{hooks}", "return (", "{render}", ")", diff --git a/pynecone/components/component.py b/pynecone/components/component.py index 09eda8db2..10bf27434 100644 --- a/pynecone/components/component.py +++ b/pynecone/components/component.py @@ -459,6 +459,29 @@ class Component(Base, ABC): self._get_imports(), *[child.get_imports() for child in self.children] ) + def _get_hooks(self) -> Optional[str]: + return None + + def get_hooks(self) -> Set[str]: + """Get javascript code for react hooks. + + Returns: + The code that should appear just before returning the rendered component. + """ + # Store the code in a set to avoid duplicates. + code = set() + + # Add the hook code for this component. + hooks = self._get_hooks() + if hooks is not None: + code.add(hooks) + + # Add the hook code for the children. + for child in self.children: + code.update(child.get_hooks()) + + return code + def get_custom_components( self, seen: Optional[Set[str]] = None ) -> Set[CustomComponent]: diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 100589a37..a921498fa 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -83,6 +83,36 @@ def component2() -> Type[Component]: return TestComponent2 +@pytest.fixture +def component3() -> Type[Component]: + """A test component with hook defined. + + Returns: + A test component. + """ + + class TestComponent3(Component): + def _get_hooks(self) -> str: + return "const a = () => true" + + return TestComponent3 + + +@pytest.fixture +def component4() -> Type[Component]: + """A test component with hook defined. + + Returns: + A test component. + """ + + class TestComponent4(Component): + def _get_hooks(self) -> str: + return "const b = () => false" + + return TestComponent4 + + @pytest.fixture def on_click1() -> EventHandler: """A sample on click function. @@ -363,3 +393,42 @@ def test_invalid_event_handler_args(component2, TestState): component2.create(on_open=TestState.do_something) with pytest.raises(ValueError): component2.create(on_open=[TestState.do_something_arg, TestState.do_something]) + + +def test_get_hooks_nested(component1, component2, component3): + """Test that a component returns hooks from child components. + + Args: + component1: test component. + component2: another component. + component3: component with hooks defined. + """ + c = component1.create( + component2.create(arr=[]), + component3.create(), + component3.create(), + component3.create(), + text="a", + number=1, + ) + assert c.get_hooks() == component3().get_hooks() + + +def test_get_hooks_nested2(component3, component4): + """Test that a component returns both when parent and child have hooks. + + Args: + component3: component with hooks defined. + component4: component with different hooks defined. + """ + exp_hooks = component3().get_hooks().union(component4().get_hooks()) + assert component3.create(component4.create()).get_hooks() == exp_hooks + assert component4.create(component3.create()).get_hooks() == exp_hooks + assert ( + component4.create( + component3.create(), + component4.create(), + component3.create(), + ).get_hooks() + == exp_hooks + )