From 036afa951a5700b04722a0bcd7d3143dde07a12a Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 13 Mar 2024 13:41:17 -0700 Subject: [PATCH] Make @rx.memo work with state vars passed as props (#2810) * Make @rx.memo work with state vars passed as props Seems like this was a regression from the StatefulComponent refactor, because trying to pass a state Var to a CustomComponent gave undefined, likely due to `_get_vars` not accounting for `self.props` in CustomComponents. With this change, it works. Integration test added to `test_var_operations.py` * Allow CustomComponent props to be Component Avoid calling `.json()` on all Base types because the Var serializer already does that, but this way, more specific types (like Component) can be serialized differently. When the type is Component, attach a VarData with the imports and hooks to when the Var is rendered, it also carries the correct imports/hooks and does not throw frontend errors. --- integration/test_var_operations.py | 20 ++++++++++++++++++++ reflex/components/component.py | 29 ++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/integration/test_var_operations.py b/integration/test_var_operations.py index 768aa3d34..03bb31d6b 100644 --- a/integration/test_var_operations.py +++ b/integration/test_var_operations.py @@ -34,6 +34,14 @@ def VarOperations(): app = rx.App(state=rx.State) + @rx.memo + def memo_comp(list1: list[int], int_var1: int, id: str): + return rx.text(list1, int_var1, id=id) + + @rx.memo + def memo_comp_nested(int_var2: int, id: str): + return memo_comp(list1=[3, 4], int_var1=int_var2, id=id) + @app.add_page def index(): return rx.vstack( @@ -566,6 +574,15 @@ def VarOperations(): ), id="foreach_list_nested", ), + memo_comp( + list1=VarOperationState.list1, + int_var1=VarOperationState.int_var1, + id="memo_comp", + ), + memo_comp_nested( + int_var2=VarOperationState.int_var2, + id="memo_comp_nested", + ), ) @@ -759,6 +776,9 @@ def test_var_operations(driver, var_operations: AppHarness): ("foreach_list_arg", "1\n2"), ("foreach_list_ix", "1\n2"), ("foreach_list_nested", "1\n1\n2"), + # rx.memo component with state + ("memo_comp", "1210"), + ("memo_comp_nested", "345"), ] for tag, expected in tests: diff --git a/reflex/components/component.py b/reflex/components/component.py index 44e1ffcdc..71b870b8c 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1303,12 +1303,18 @@ class CustomComponent(Component): # Handle subclasses of Base. if types._issubclass(type_, Base): - try: - value = BaseVar( - _var_name=value.json(), _var_type=type_, _var_is_local=True + base_value = Var.create(value) + + # Track hooks and imports associated with Component instances. + if base_value is not None and types._issubclass(type_, Component): + value = base_value._replace( + merge_var_data=VarData( # type: ignore + imports=value.get_imports(), + hooks=value.get_hooks(), + ) ) - except Exception: - value = Var.create(value) + else: + value = base_value else: value = Var.create(value, _var_is_string=type(value) is str) @@ -1393,6 +1399,19 @@ class CustomComponent(Component): for name, prop in self.props.items() ] + def _get_vars(self, include_children: bool = False) -> list[Var]: + """Walk all Vars used in this component. + + Args: + include_children: Whether to include Vars from children. + + Returns: + Each var referenced by the component (props, styles, event handlers). + """ + return super()._get_vars(include_children=include_children) + [ + prop for prop in self.props.values() if isinstance(prop, Var) + ] + @lru_cache(maxsize=None) # noqa def get_component(self) -> Component: """Render the component.