diff --git a/reflex/__init__.py b/reflex/__init__.py index 3cf8bfe8f..4b2dc7e8a 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -214,6 +214,7 @@ COMPONENTS_CORE_MAPPING: dict = { "components.core.match": ["match"], "components.core.clipboard": ["clipboard"], "components.core.colors": ["color"], + "components.core.breakpoints": ["breakpoints"], "components.core.responsive": [ "desktop_only", "mobile_and_tablet", diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 77ac16456..0b55c6661 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -131,6 +131,7 @@ from .components.core.html import html as html from .components.core.match import match as match from .components.core.clipboard import clipboard as clipboard from .components.core.colors import color as color +from .components.core.breakpoints import breakpoints as breakpoints from .components.core.responsive import desktop_only as desktop_only from .components.core.responsive import mobile_and_tablet as mobile_and_tablet from .components.core.responsive import mobile_only as mobile_only diff --git a/reflex/app.py b/reflex/app.py index 489755d88..bb3e1d403 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -52,6 +52,7 @@ from reflex.components.component import ( evaluate_style_namespaces, ) from reflex.components.core.banner import connection_pulser, connection_toaster +from reflex.components.core.breakpoints import set_breakpoints from reflex.components.core.client_side_routing import ( Default404Page, wait_for_client_redirect, @@ -245,6 +246,9 @@ class App(LifespanMixin, Base): "rx.BaseState cannot be subclassed multiple times. use rx.State instead" ) + if "breakpoints" in self.style: + set_breakpoints(self.style.pop("breakpoints")) + # Add middleware. self.middleware.append(HydrateMiddleware()) diff --git a/reflex/components/core/__init__.py b/reflex/components/core/__init__.py index a2332cabd..fbe0bdc84 100644 --- a/reflex/components/core/__init__.py +++ b/reflex/components/core/__init__.py @@ -32,6 +32,7 @@ _SUBMOD_ATTRS: dict[str, list[str]] = { "match", "Match", ], + "breakpoints": ["breakpoints", "set_breakpoints"], "responsive": [ "desktop_only", "mobile_and_tablet", diff --git a/reflex/components/core/__init__.pyi b/reflex/components/core/__init__.pyi index 0d11eade4..2004cefa7 100644 --- a/reflex/components/core/__init__.pyi +++ b/reflex/components/core/__init__.pyi @@ -27,6 +27,8 @@ from .html import html as html from .html import Html as Html from .match import match as match from .match import Match as Match +from .breakpoints import breakpoints as breakpoints +from .breakpoints import set_breakpoints as set_breakpoints from .responsive import desktop_only as desktop_only from .responsive import mobile_and_tablet as mobile_and_tablet from .responsive import mobile_only as mobile_only diff --git a/reflex/components/core/breakpoints.py b/reflex/components/core/breakpoints.py new file mode 100644 index 000000000..b9817a179 --- /dev/null +++ b/reflex/components/core/breakpoints.py @@ -0,0 +1,66 @@ +"""Breakpoints utility.""" + +from typing import Optional, Tuple + +breakpoints_values = ["30em", "48em", "62em", "80em", "96em"] + + +def set_breakpoints(values: Tuple[str, str, str, str, str]): + """Overwrite default breakpoint values. + + Args: + values: CSS values in order defining the breakpoints of responsive layouts + """ + breakpoints_values.clear() + breakpoints_values.extend(values) + + +class Breakpoints(dict): + """A responsive styling helper.""" + + @classmethod + def create( + cls, + custom: Optional[dict] = None, + initial=None, + xs=None, + sm=None, + md=None, + lg=None, + xl=None, + ): + """Create a new instance of the helper. Only provide a custom component OR use named props. + + Args: + custom: Custom mapping using CSS values or variables. + initial: Styling when in the inital width + xs: Styling when in the extra-small width + sm: Styling when in the small width + md: Styling when in the medium width + lg: Styling when in the large width + xl: Styling when in the extra-large width + + Raises: + ValueError: If both custom and any other named parameters are provided. + + Returns: + The responsive mapping. + """ + thresholds = [initial, xs, sm, md, lg, xl] + + if custom is not None: + if any((threshold is not None for threshold in thresholds)): + raise ValueError("Named props cannot be used with custom thresholds") + + return Breakpoints(custom) + else: + return Breakpoints( + { + k: v + for k, v in zip(["0px", *breakpoints_values], thresholds) + if v is not None + } + ) + + +breakpoints = Breakpoints.create diff --git a/reflex/style.py b/reflex/style.py index ae2aa9059..2bb3dea76 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -5,6 +5,7 @@ from __future__ import annotations from typing import Any, Literal, Tuple, Type from reflex import constants +from reflex.components.core.breakpoints import Breakpoints, breakpoints_values from reflex.event import EventChain from reflex.utils import format from reflex.utils.imports import ImportVar @@ -86,8 +87,6 @@ toggle_color_mode = _color_mode_var( _var_type=EventChain, ) -breakpoints = ["0", "30em", "48em", "62em", "80em", "96em"] - STYLE_PROP_SHORTHAND_MAPPING = { "paddingX": ("paddingInlineStart", "paddingInlineEnd"), "paddingY": ("paddingTop", "paddingBottom"), @@ -100,16 +99,16 @@ STYLE_PROP_SHORTHAND_MAPPING = { } -def media_query(breakpoint_index: int): +def media_query(breakpoint_expr: str): """Create a media query selector. Args: - breakpoint_index: The index of the breakpoint to use. + breakpoint_expr: The CSS expression representing the breakpoint. Returns: The media query selector used as a key in emotion css dict. """ - return f"@media screen and (min-width: {breakpoints[breakpoint_index]})" + return f"@media screen and (min-width: {breakpoint_expr})" def convert_item(style_item: str | Var) -> tuple[str, VarData | None]: @@ -189,6 +188,10 @@ def convert(style_dict): update_out_dict(return_val, keys) # Combine all the collected VarData instances. var_data = VarData.merge(var_data, new_var_data) + + if isinstance(style_dict, Breakpoints): + out = Breakpoints(out) + return out, var_data @@ -295,14 +298,22 @@ def format_as_emotion(style_dict: dict[str, Any]) -> Style | None: for orig_key, value in style_dict.items(): key = _format_emotion_style_pseudo_selector(orig_key) - if isinstance(value, list): - # Apply media queries from responsive value list. - mbps = { - media_query(bp): ( - bp_value if isinstance(bp_value, dict) else {key: bp_value} - ) - for bp, bp_value in enumerate(value) - } + if isinstance(value, (Breakpoints, list)): + if isinstance(value, Breakpoints): + mbps = { + media_query(bp): ( + bp_value if isinstance(bp_value, dict) else {key: bp_value} + ) + for bp, bp_value in value.items() + } + else: + # Apply media queries from responsive value list. + mbps = { + media_query([0, *breakpoints_values][bp]): ( + bp_value if isinstance(bp_value, dict) else {key: bp_value} + ) + for bp, bp_value in enumerate(value) + } if key.startswith("&:"): emotion_style[key] = mbps else: