From 3715462eb47c0adb9d33db476bd5c2cb71dafaa7 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 10 May 2024 20:25:04 -0700 Subject: [PATCH] Allow `Component.add_style` to return a regular dict (#3264) * Allow `Component.add_style` to return a regular dict It's more convenient to allow returning a regular dict without having to import and wrap the value in `rx.style.Style`. If the dict contains any Var or encoded VarData f-strings, these will be picked up when the plain dicts are passed to Style.update(). Because Style.update already merges VarData, there is no reason to explicitly merge it again in this function; this change keeps the merging logic inside the Style class. * Test for Style.update with existing Style with _var_data and kwargs Should retain the _var_data from the original Style instance * style: Avoid losing VarData in Style.update If a Style class with _var_data is passed to `Style.update` along with kwargs, then the _var_data was lost in the double-splat dictionary expansion. Instead, only apply the kwargs to an existing or new Style instance to retain _var_data and properly convert values. * add_style return annotation is Dict[str, Any] * nit: use lowercase dict in annotation --- reflex/components/component.py | 6 +--- reflex/style.py | 7 +++-- tests/components/test_component.py | 49 ++++++++++++++++++++++++++++++ tests/test_style.py | 24 ++++++++++++++- 4 files changed, 78 insertions(+), 8 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 7cc26348c..39a97792e 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -782,7 +782,7 @@ class Component(BaseComponent, ABC): return cls(children=children, **props) - def add_style(self) -> Style | None: + def add_style(self) -> dict[str, Any] | None: """Add style to the component. Downstream components can override this method to return a style dict @@ -802,20 +802,16 @@ class Component(BaseComponent, ABC): The style to add. """ styles = [] - vars = [] # Walk the MRO to call all `add_style` methods. for base in self._iter_parent_classes_with_method("add_style"): s = base.add_style(self) # type: ignore if s is not None: styles.append(s) - vars.append(s._var_data) _style = Style() for s in reversed(styles): _style.update(s) - - _style._var_data = VarData.merge(*vars) return _style def _get_component_style(self, styles: ComponentStyle) -> Style | None: diff --git a/reflex/style.py b/reflex/style.py index d77c2bb7c..e48aa3dd8 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -180,12 +180,15 @@ class Style(dict): style_dict: The style dictionary. kwargs: Other key value pairs to apply to the dict update. """ - if kwargs: - style_dict = {**(style_dict or {}), **kwargs} if not isinstance(style_dict, Style): converted_dict = type(self)(style_dict) else: converted_dict = style_dict + if kwargs: + if converted_dict is None: + converted_dict = type(self)(kwargs) + else: + converted_dict.update(kwargs) # Combine our VarData with that of any Vars in the style_dict that was passed. self._var_data = VarData.merge(self._var_data, converted_dict._var_data) super().update(converted_dict) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index c2a28c84d..6245746c9 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1953,6 +1953,55 @@ def test_component_add_custom_code(): } +def test_add_style_embedded_vars(test_state: BaseState): + """Test that add_style works with embedded vars when returning a plain dict. + + Args: + test_state: A test state. + """ + v0 = Var.create_safe("parent")._replace( + merge_var_data=VarData(hooks={"useParent": None}), # type: ignore + ) + v1 = rx.color("plum", 10) + v2 = Var.create_safe("text")._replace( + merge_var_data=VarData(hooks={"useText": None}), # type: ignore + ) + + class ParentComponent(Component): + def add_style(self): + return Style( + { + "fake_parent": v0, + } + ) + + class StyledComponent(ParentComponent): + tag = "StyledComponent" + + def add_style(self): + return { + "color": v1, + "fake": v2, + "margin": f"{test_state.num}%", + } + + page = rx.vstack(StyledComponent.create()) + page._add_style_recursive(Style()) + + assert ( + "const test_state = useContext(StateContexts.test_state)" + in page._get_all_hooks_internal() + ) + assert "useText" in page._get_all_hooks_internal() + assert "useParent" in page._get_all_hooks_internal() + assert ( + str(page).count( + 'css={{"fakeParent": "parent", "color": "var(--plum-10)", "fake": "text", "margin": `${test_state.num}%`}}' + ) + == 1 + ) + + def test_add_style_foreach(): class StyledComponent(Component): tag = "StyledComponent" diff --git a/tests/test_style.py b/tests/test_style.py index ccc7b6569..825d72a9d 100644 --- a/tests/test_style.py +++ b/tests/test_style.py @@ -8,7 +8,7 @@ import reflex as rx from reflex import style from reflex.components.component import evaluate_style_namespaces from reflex.style import Style -from reflex.vars import Var +from reflex.vars import Var, VarData test_style = [ ({"a": 1}, {"a": 1}), @@ -503,3 +503,25 @@ def test_evaluate_style_namespaces(): assert rx.text.__call__ not in style_dict style_dict = evaluate_style_namespaces(style_dict) # type: ignore assert rx.text.__call__ in style_dict + + +def test_style_update_with_var_data(): + """Test that .update with a Style containing VarData works.""" + red_var = Var.create_safe("red")._replace( + merge_var_data=VarData(hooks={"const red = true": None}), # type: ignore + ) + blue_var = Var.create_safe("blue", _var_is_local=False)._replace( + merge_var_data=VarData(hooks={"const blue = true": None}), # type: ignore + ) + + s1 = Style( + { + "color": red_var, + } + ) + s2 = Style() + s2.update(s1, background_color=f"{blue_var}ish") + assert s2 == {"color": "red", "backgroundColor": "`${blue}ish`"} + assert s2._var_data is not None + assert "const red = true" in s2._var_data.hooks + assert "const blue = true" in s2._var_data.hooks