diff --git a/reflex/components/plotly/plotly.py b/reflex/components/plotly/plotly.py index ed4b1a3fc..4554dfd28 100644 --- a/reflex/components/plotly/plotly.py +++ b/reflex/components/plotly/plotly.py @@ -4,14 +4,20 @@ from __future__ import annotations from typing import Any, Dict, List from reflex.base import Base -from reflex.components.component import NoSSRComponent +from reflex.components.component import Component, NoSSRComponent +from reflex.components.core.cond import color_mode_cond from reflex.event import EventHandler +from reflex.utils import console from reflex.vars import Var try: - from plotly.graph_objects import Figure + from plotly.graph_objects import Figure, layout + + Template = layout.Template except ImportError: + console.warn("Plotly is not installed. Please run `pip install plotly`.") Figure = Any # type: ignore + Template = Any # type: ignore def _event_data_signature(e0: Var) -> List[Any]: @@ -84,17 +90,13 @@ def _null_signature() -> List[Any]: return [] -class PlotlyLib(NoSSRComponent): - """A component that wraps a plotly lib.""" +class Plotly(NoSSRComponent): + """Display a plotly graph.""" library = "react-plotly.js@2.6.0" lib_dependencies: List[str] = ["plotly.js@2.22.0"] - -class Plotly(PlotlyLib): - """Display a plotly graph.""" - tag = "Plot" is_default = True @@ -105,6 +107,9 @@ class Plotly(PlotlyLib): # The layout of the graph. layout: Var[Dict] + # The template for visual appearance of the graph. + template: Var[Template] + # The config of the graph. config: Var[Dict] @@ -171,6 +176,17 @@ class Plotly(PlotlyLib): # Fired when a hovered element is no longer hovered. on_unhover: EventHandler[_event_points_data_signature] + def add_imports(self) -> dict[str, str]: + """Add imports for the plotly component. + + Returns: + The imports for the plotly component. + """ + return { + # For merging plotly data/layout/templates. + "mergician@v2.0.2": "mergician" + } + def add_custom_code(self) -> list[str]: """Add custom codes for processing the plotly points data. @@ -210,14 +226,62 @@ const extractPoints = (points) => { """, ] + @classmethod + def create(cls, *children, **props) -> Component: + """Create the Plotly component. + + Args: + *children: The children of the component. + **props: The properties of the component. + + Returns: + The Plotly component. + """ + from plotly.io import templates + + responsive_template = color_mode_cond( + light=Var.create_safe(templates["plotly"]).to(dict), + dark=Var.create_safe(templates["plotly_dark"]).to(dict), + ) + if isinstance(responsive_template, Var): + # Mark the conditional Var as a Template to avoid type mismatch + responsive_template = responsive_template.to(Template) + props.setdefault("template", responsive_template) + return super().create(*children, **props) + + def _exclude_props(self) -> set[str]: + # These props are handled specially in the _render function + return {"data", "layout", "template"} + def _render(self): tag = super()._render() figure = self.data.to(dict) - if self.layout is None: - tag.remove_props("data", "layout") + merge_dicts = [] # Data will be merged and spread from these dict Vars + if self.layout is not None: + # Why is this not a literal dict? Great question... it didn't work + # reliably because of how _var_name_unwrapped strips the outer curly + # brackets if any of the contained Vars depend on state. + layout_dict = Var.create_safe( + f"{{'layout': {self.layout.to(dict)._var_name_unwrapped}}}" + ).to(dict) + merge_dicts.append(layout_dict) + if self.template is not None: + template_dict = Var.create_safe( + {"layout": {"template": self.template.to(dict)}} + ) + template_dict._var_data = None # To avoid stripping outer curly brackets + merge_dicts.append(template_dict) + if merge_dicts: + tag.special_props.add( + # Merge all dictionaries and spread the result over props. + Var.create_safe( + f"{{...mergician({figure._var_name_unwrapped}," + f"{','.join(md._var_name_unwrapped for md in merge_dicts)})}}", + ), + ) + else: + # Spread the figure dict over props, nothing to merge. tag.special_props.add( Var.create_safe(f"{{...{figure._var_name_unwrapped}}}") ) - else: - tag.add_props(data=figure["data"]) return tag diff --git a/reflex/components/plotly/plotly.pyi b/reflex/components/plotly/plotly.pyi index b8a57f781..46dbab62d 100644 --- a/reflex/components/plotly/plotly.pyi +++ b/reflex/components/plotly/plotly.pyi @@ -9,97 +9,28 @@ from reflex.event import EventChain, EventHandler, EventSpec from reflex.style import Style from typing import Any, Dict, List from reflex.base import Base -from reflex.components.component import NoSSRComponent +from reflex.components.component import Component, NoSSRComponent +from reflex.components.core.cond import color_mode_cond from reflex.event import EventHandler +from reflex.utils import console from reflex.vars import Var try: - from plotly.graph_objects import Figure # type: ignore + from plotly.graph_objects import Figure, layout # type: ignore + + Template = layout.Template except ImportError: + console.warn("Plotly is not installed. Please run `pip install plotly`.") Figure = Any # type: ignore + Template = Any class _ButtonClickData(Base): menu: Any button: Any active: Any -class PlotlyLib(NoSSRComponent): - @overload - @classmethod - def create( # type: ignore - cls, - *children, - style: Optional[Style] = None, - key: Optional[Any] = None, - id: Optional[Any] = None, - class_name: Optional[Any] = None, - autofocus: Optional[bool] = None, - custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, - on_blur: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_click: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_context_menu: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_double_click: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_focus: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mount: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mouse_down: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mouse_enter: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mouse_leave: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mouse_move: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mouse_out: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mouse_over: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_mouse_up: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_scroll: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - on_unmount: Optional[ - Union[EventHandler, EventSpec, list, function, BaseVar] - ] = None, - **props - ) -> "PlotlyLib": - """Create the component. - - Args: - *children: The children of the component. - style: The style of the component. - key: A unique key for the component. - id: The id for the component. - class_name: The class name for the component. - autofocus: Whether the component should take the focus once the page is loaded - custom_attrs: custom attribute - **props: The props of the component. - - Returns: - The component. - """ - ... - -class Plotly(PlotlyLib): +class Plotly(NoSSRComponent): + def add_imports(self) -> dict[str, str]: ... def add_custom_code(self) -> list[str]: ... @overload @classmethod @@ -108,6 +39,7 @@ class Plotly(PlotlyLib): *children, data: Optional[Union[Var[Figure], Figure]] = None, # type: ignore layout: Optional[Union[Var[Dict], Dict]] = None, + template: Optional[Union[Var[Template], Template]] = None, # type: ignore config: Optional[Union[Var[Dict], Dict]] = None, use_resize_handler: Optional[Union[Var[bool], bool]] = None, style: Optional[Style] = None, @@ -217,12 +149,13 @@ class Plotly(PlotlyLib): ] = None, **props ) -> "Plotly": - """Create the component. + """Create the Plotly component. Args: *children: The children of the component. data: The figure to display. This can be a plotly figure or a plotly data json. layout: The layout of the graph. + template: The template for visual appearance of the graph. config: The config of the graph. use_resize_handler: If true, the graph will resize when the window is resized. style: The style of the component. @@ -231,9 +164,9 @@ class Plotly(PlotlyLib): class_name: The class name for the component. autofocus: Whether the component should take the focus once the page is loaded custom_attrs: custom attribute - **props: The props of the component. + **props: The properties of the component. Returns: - The component. + The Plotly component. """ ... diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 3b5352a04..cc27aac70 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -32,7 +32,7 @@ logger = logging.getLogger("pyi_generator") PWD = Path(".").resolve() EXCLUDED_FILES = [ - # "app.py", + "app.py", "component.py", "bare.py", "foreach.py", @@ -856,7 +856,11 @@ class PyiGenerator: mode=black.mode.Mode(is_pyi=True), ).splitlines(): # Bit of a hack here, since the AST cannot represent comments. - if "def create(" in formatted_line or "Figure" in formatted_line: + if ( + "def create(" in formatted_line + or "Figure" in formatted_line + or "Var[Template]" in formatted_line + ): pyi_content.append(formatted_line + " # type: ignore") else: pyi_content.append(formatted_line) @@ -956,6 +960,7 @@ class PyiGenerator: target_path.is_file() and target_path.suffix == ".py" and target_path.name not in EXCLUDED_FILES + and "reflex/components" in str(target_path) ): file_targets.append(target_path) continue diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index aa9b6e484..c35be95d3 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -314,7 +314,7 @@ except ImportError: pass try: - from plotly.graph_objects import Figure + from plotly.graph_objects import Figure, layout from plotly.io import to_json @serializer @@ -329,6 +329,21 @@ try: """ return json.loads(str(to_json(figure))) + @serializer + def serialize_template(template: layout.Template) -> dict: + """Serialize a plotly template. + + Args: + template: The template to serialize. + + Returns: + The serialized template. + """ + return { + "data": json.loads(str(to_json(template.data))), + "layout": json.loads(str(to_json(template.layout))), + } + except ImportError: pass