diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index c39e427cf..052746ad8 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -2,7 +2,7 @@ from __future__ import annotations import os -from typing import Any, Type +from typing import Any, Callable, Type from urllib.parse import urlparse from pydantic.fields import ModelField @@ -290,7 +290,7 @@ def create_theme(style: ComponentStyle) -> dict: The base style for the app. """ # Get the global style from the style dict. - style_rules = Style({k: v for k, v in style.items() if not isinstance(k, type)}) + style_rules = Style({k: v for k, v in style.items() if not isinstance(k, Callable)}) root_style = { # Root styles. diff --git a/reflex/components/component.py b/reflex/components/component.py index 428b4b64f..971188f35 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -115,7 +115,7 @@ class BaseComponent(Base, ABC): # Map from component to styling. -ComponentStyle = Dict[Union[str, Type[BaseComponent]], Any] +ComponentStyle = Dict[Union[str, Type[BaseComponent], Callable], Any] ComponentChild = Union[types.PrimitiveType, Var, BaseComponent] @@ -600,10 +600,13 @@ class Component(BaseComponent, ABC): Returns: The component with the additional style. """ + component_style = None if type(self) in style: # Extract the style for this component. component_style = Style(style[type(self)]) - + if self.create in style: + component_style = Style(style[self.create]) + if component_style is not None: # Only add style props that are not overridden. component_style = { k: v for k, v in component_style.items() if k not in self.style diff --git a/scripts/pyi_generator.py b/scripts/pyi_generator.py index 4cd2972c4..8caf91060 100644 --- a/scripts/pyi_generator.py +++ b/scripts/pyi_generator.py @@ -704,7 +704,6 @@ class PyiGenerator: def _write_pyi_file(self, module_path: Path, source: str): relpath = str(_relative_to_pwd(module_path)).replace("\\", "/") - print(f"Writing {relpath}") pyi_content = [ f'"""Stub file for {relpath}"""', "# ------------------- DO NOT EDIT ----------------------", diff --git a/tests/components/test_component.py b/tests/components/test_component.py index e4bad40db..fb9479979 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -268,6 +268,23 @@ def test_add_style(component1, component2): assert c2.style["color"] == "black" +def test_add_style_create(component1, component2): + """Test that adding style works with the create method. + + Args: + component1: A test component. + component2: A test component. + """ + style = { + component1.create: Style({"color": "white"}), + component2.create: Style({"color": "black"}), + } + c1 = component1().add_style(style) # type: ignore + c2 = component2().add_style(style) # type: ignore + assert c1.style["color"] == "white" + assert c2.style["color"] == "black" + + def test_get_imports(component1, component2): """Test getting the imports of a component.