Allow Component.add_style to return a regular dict (#3264)

* Allow `Component.add_style` to return a regular dict

It's more convenient to allow returning a regular dict without having to import
and wrap the value in `rx.style.Style`.

If the dict contains any Var or encoded VarData f-strings, these will be picked
up when the plain dicts are passed to Style.update().

Because Style.update already merges VarData, there is no reason to explicitly
merge it again in this function; this change keeps the merging logic inside the
Style class.

* Test for Style.update with existing Style with _var_data and kwargs

Should retain the _var_data from the original Style instance

* style: Avoid losing VarData in Style.update

If a Style class with _var_data is passed to `Style.update` along with kwargs,
then the _var_data was lost in the double-splat dictionary expansion.

Instead, only apply the kwargs to an existing or new Style instance to retain
_var_data and properly convert values.

* add_style return annotation is Dict[str, Any]

* nit: use lowercase dict in annotation
This commit is contained in:
Masen Furer 2024-05-10 20:25:04 -07:00 committed by GitHub
parent aa2cf80f70
commit 3715462eb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 78 additions and 8 deletions

View File

@ -782,7 +782,7 @@ class Component(BaseComponent, ABC):
return cls(children=children, **props)
def add_style(self) -> Style | None:
def add_style(self) -> dict[str, Any] | None:
"""Add style to the component.
Downstream components can override this method to return a style dict
@ -802,20 +802,16 @@ class Component(BaseComponent, ABC):
The style to add.
"""
styles = []
vars = []
# Walk the MRO to call all `add_style` methods.
for base in self._iter_parent_classes_with_method("add_style"):
s = base.add_style(self) # type: ignore
if s is not None:
styles.append(s)
vars.append(s._var_data)
_style = Style()
for s in reversed(styles):
_style.update(s)
_style._var_data = VarData.merge(*vars)
return _style
def _get_component_style(self, styles: ComponentStyle) -> Style | None:

View File

@ -180,12 +180,15 @@ class Style(dict):
style_dict: The style dictionary.
kwargs: Other key value pairs to apply to the dict update.
"""
if kwargs:
style_dict = {**(style_dict or {}), **kwargs}
if not isinstance(style_dict, Style):
converted_dict = type(self)(style_dict)
else:
converted_dict = style_dict
if kwargs:
if converted_dict is None:
converted_dict = type(self)(kwargs)
else:
converted_dict.update(kwargs)
# Combine our VarData with that of any Vars in the style_dict that was passed.
self._var_data = VarData.merge(self._var_data, converted_dict._var_data)
super().update(converted_dict)

View File

@ -1953,6 +1953,55 @@ def test_component_add_custom_code():
}
def test_add_style_embedded_vars(test_state: BaseState):
"""Test that add_style works with embedded vars when returning a plain dict.
Args:
test_state: A test state.
"""
v0 = Var.create_safe("parent")._replace(
merge_var_data=VarData(hooks={"useParent": None}), # type: ignore
)
v1 = rx.color("plum", 10)
v2 = Var.create_safe("text")._replace(
merge_var_data=VarData(hooks={"useText": None}), # type: ignore
)
class ParentComponent(Component):
def add_style(self):
return Style(
{
"fake_parent": v0,
}
)
class StyledComponent(ParentComponent):
tag = "StyledComponent"
def add_style(self):
return {
"color": v1,
"fake": v2,
"margin": f"{test_state.num}%",
}
page = rx.vstack(StyledComponent.create())
page._add_style_recursive(Style())
assert (
"const test_state = useContext(StateContexts.test_state)"
in page._get_all_hooks_internal()
)
assert "useText" in page._get_all_hooks_internal()
assert "useParent" in page._get_all_hooks_internal()
assert (
str(page).count(
'css={{"fakeParent": "parent", "color": "var(--plum-10)", "fake": "text", "margin": `${test_state.num}%`}}'
)
== 1
)
def test_add_style_foreach():
class StyledComponent(Component):
tag = "StyledComponent"

View File

@ -8,7 +8,7 @@ import reflex as rx
from reflex import style
from reflex.components.component import evaluate_style_namespaces
from reflex.style import Style
from reflex.vars import Var
from reflex.vars import Var, VarData
test_style = [
({"a": 1}, {"a": 1}),
@ -503,3 +503,25 @@ def test_evaluate_style_namespaces():
assert rx.text.__call__ not in style_dict
style_dict = evaluate_style_namespaces(style_dict) # type: ignore
assert rx.text.__call__ in style_dict
def test_style_update_with_var_data():
"""Test that .update with a Style containing VarData works."""
red_var = Var.create_safe("red")._replace(
merge_var_data=VarData(hooks={"const red = true": None}), # type: ignore
)
blue_var = Var.create_safe("blue", _var_is_local=False)._replace(
merge_var_data=VarData(hooks={"const blue = true": None}), # type: ignore
)
s1 = Style(
{
"color": red_var,
}
)
s2 = Style()
s2.update(s1, background_color=f"{blue_var}ish")
assert s2 == {"color": "red", "backgroundColor": "`${blue}ish`"}
assert s2._var_data is not None
assert "const red = true" in s2._var_data.hooks
assert "const blue = true" in s2._var_data.hooks