From 1bc1978c3167475748e35d71fa68af6cbb8b1185 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 18 Jul 2024 13:26:46 -0700 Subject: [PATCH 01/34] [REF-3135] Radix Primitive components should not ignore provided `class_name` prop (#3676) --- reflex/components/radix/primitives/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/components/radix/primitives/base.py b/reflex/components/radix/primitives/base.py index 3ce6549bf..857e80b5a 100644 --- a/reflex/components/radix/primitives/base.py +++ b/reflex/components/radix/primitives/base.py @@ -26,7 +26,7 @@ class RadixPrimitiveComponentWithClassName(RadixPrimitiveComponent): ._render() .add_props( **{ - "class_name": format.to_title_case(self.tag or ""), + "class_name": f"{format.to_title_case(self.tag or '')} {self.class_name or ''}", } ) ) From 1da606dd8e88f2c1b966f1d308c5a7a276b33bb2 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 18 Jul 2024 18:51:33 -0700 Subject: [PATCH 02/34] [REF-3197] Browser init workflow (#3673) --- reflex/components/el/elements/metadata.py | 6 ++- reflex/constants/base.py | 21 ++++++++++ reflex/reflex.py | 21 ++++++++-- reflex/utils/prerequisites.py | 36 +++++++++++++++++ reflex/utils/redir.py | 48 +++++++++++++++++++++++ 5 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 reflex/utils/redir.py diff --git a/reflex/components/el/elements/metadata.py b/reflex/components/el/elements/metadata.py index 77677df07..c19612abe 100644 --- a/reflex/components/el/elements/metadata.py +++ b/reflex/components/el/elements/metadata.py @@ -1,6 +1,6 @@ """Element classes. This is an auto-generated file. Do not edit. See ../generate.py.""" -from typing import Union +from typing import Set, Union from reflex.components.el.element import Element from reflex.vars import Var as Var @@ -64,6 +64,10 @@ class StyleEl(Element): # noqa: E742 media: Var[Union[str, int, bool]] + special_props: Set[Var] = { + Var.create_safe("suppressHydrationWarning", _var_is_string=False) + } + base = Base.create head = Head.create diff --git a/reflex/constants/base.py b/reflex/constants/base.py index c818fbf06..65d957d27 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -94,6 +94,27 @@ class Templates(SimpleNamespace): # The default template DEFAULT = "blank" + # The reflex.build frontend host + REFLEX_BUILD_FRONTEND = os.environ.get( + "REFLEX_BUILD_FRONTEND", "https://flexgen.reflex.run" + ) + + # The reflex.build backend host + REFLEX_BUILD_BACKEND = os.environ.get( + "REFLEX_BUILD_BACKEND", "https://rxh-prod-flexgen.fly.dev" + ) + + # The URL to redirect to reflex.build + REFLEX_BUILD_URL = ( + REFLEX_BUILD_FRONTEND + "/gen?reflex_init_token={reflex_init_token}" + ) + + # The URL to poll waiting for the user to select a generation. + REFLEX_BUILD_POLL_URL = REFLEX_BUILD_BACKEND + "/api/init/{reflex_init_token}" + + # The URL to fetch the generation's reflex code + REFLEX_BUILD_CODE_URL = REFLEX_BUILD_BACKEND + "/api/gen/{generation_hash}" + class Dirs(SimpleNamespace): """Folders used by the template system of Reflex.""" diff --git a/reflex/reflex.py b/reflex/reflex.py index a9e164477..b8a4ed03a 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -16,7 +16,7 @@ from reflex_cli.utils import dependency from reflex import constants from reflex.config import get_config from reflex.custom_components.custom_components import custom_components_cli -from reflex.utils import console, telemetry +from reflex.utils import console, redir, telemetry # Disable typer+rich integration for help panels typer.core.rich = False # type: ignore @@ -65,6 +65,7 @@ def _init( name: str, template: str | None = None, loglevel: constants.LogLevel = config.loglevel, + ai: bool = False, ): """Initialize a new Reflex app in the given directory.""" from reflex.utils import exec, prerequisites @@ -91,8 +92,16 @@ def _init( # Set up the web project. prerequisites.initialize_frontend_dependencies() - # Initialize the app. - prerequisites.initialize_app(app_name, template) + # Check if AI is requested and redirect the user to reflex.build. + if ai: + prerequisites.initialize_app(app_name, template=constants.Templates.DEFAULT) + generation_hash = redir.reflex_build_redirect() + prerequisites.initialize_main_module_index_from_generation( + app_name, generation_hash=generation_hash + ) + else: + # Initialize the app. + prerequisites.initialize_app(app_name, template) # Migrate Pynecone projects to Reflex. prerequisites.migrate_to_reflex() @@ -119,9 +128,13 @@ def init( loglevel: constants.LogLevel = typer.Option( config.loglevel, help="The log level to use." ), + ai: bool = typer.Option( + False, + help="Use AI to create the initial template. Cannot be used with existing app or `--template` option.", + ), ): """Initialize a new Reflex app in the current directory.""" - _init(name, template, loglevel) + _init(name, template, loglevel, ai) def _run( diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index b09c6fa23..c8455f259 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -16,6 +16,7 @@ import shutil import stat import sys import tempfile +import textwrap import zipfile from datetime import datetime from fileinput import FileInput @@ -1475,6 +1476,41 @@ def initialize_app(app_name: str, template: str | None = None): telemetry.send("init", template=template) +def initialize_main_module_index_from_generation(app_name: str, generation_hash: str): + """Overwrite the `index` function in the main module with reflex.build generated code. + + Args: + app_name: The name of the app. + generation_hash: The generation hash from reflex.build. + """ + # Download the reflex code for the generation. + resp = httpx.get( + constants.Templates.REFLEX_BUILD_CODE_URL.format( + generation_hash=generation_hash + ) + ).raise_for_status() + + def replace_content(_match): + return "\n".join( + [ + "def index() -> rx.Component:", + textwrap.indent("return " + resp.text, " "), + "", + "", + ], + ) + + main_module_path = Path(app_name, app_name + constants.Ext.PY) + main_module_code = main_module_path.read_text() + main_module_path.write_text( + re.sub( + r"def index\(\).*:\n([^\n]\s+.*\n+)+", + replace_content, + main_module_code, + ) + ) + + def format_address_width(address_width) -> int | None: """Cast address width to an int. diff --git a/reflex/utils/redir.py b/reflex/utils/redir.py new file mode 100644 index 000000000..461f055bb --- /dev/null +++ b/reflex/utils/redir.py @@ -0,0 +1,48 @@ +"""Utilities to handle redirection to browser UI.""" + +import time +import uuid +import webbrowser + +import httpx + +from .. import constants +from . import console + + +def open_browser_and_wait( + target_url: str, poll_url: str, interval: int = 1 +) -> httpx.Response: + """Open a browser window to target_url and request poll_url until it returns successfully. + + Args: + target_url: The URL to open in the browser. + poll_url: The URL to poll for success. + interval: The interval in seconds to wait between polling. + + Returns: + The response from the poll_url. + """ + if not webbrowser.open(target_url): + console.warn( + f"Unable to automatically open the browser. Please navigate to {target_url} in your browser." + ) + console.info("Complete the workflow in the browser to continue.") + while response := httpx.get(poll_url, follow_redirects=True): + if response.is_success: + break + time.sleep(interval) + return response + + +def reflex_build_redirect() -> str: + """Open the browser window to reflex.build and wait for the user to select a generation. + + Returns: + The selected generation hash. + """ + token = str(uuid.uuid4()) + target_url = constants.Templates.REFLEX_BUILD_URL.format(reflex_init_token=token) + poll_url = constants.Templates.REFLEX_BUILD_POLL_URL.format(reflex_init_token=token) + response = open_browser_and_wait(target_url, poll_url) + return response.json()["generation_hash"] From 921a5cd6326066961b665bedde622ae0e6780a50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 19 Jul 2024 18:02:50 +0200 Subject: [PATCH 03/34] fix typo (#3689) --- reflex/components/recharts/cartesian.py | 2 +- reflex/components/recharts/cartesian.pyi | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/reflex/components/recharts/cartesian.py b/reflex/components/recharts/cartesian.py index c9d211149..e3f086a3d 100644 --- a/reflex/components/recharts/cartesian.py +++ b/reflex/components/recharts/cartesian.py @@ -406,7 +406,7 @@ class Line(Cartesian): stroke: Var[Union[str, Color]] = Var.create_safe(Color("accent", 9)) # The width of the line stroke. - stoke_width: Var[int] + stroke_width: Var[int] # The dot is shown when mouse enter a line chart and this chart has tooltip. If false set, no active dot will not be drawn. If true set, active dot will be drawn which have the props calculated internally. dot: Var[Union[bool, Dict[str, Any]]] = Var.create_safe( diff --git a/reflex/components/recharts/cartesian.pyi b/reflex/components/recharts/cartesian.pyi index 989207ab2..21be32b46 100644 --- a/reflex/components/recharts/cartesian.pyi +++ b/reflex/components/recharts/cartesian.pyi @@ -1232,7 +1232,7 @@ class Line(Cartesian): ] ] = None, stroke: Optional[Union[Var[Union[Color, str]], str, Color]] = None, - stoke_width: Optional[Union[Var[int], int]] = None, + stroke_width: Optional[Union[Var[int], int]] = None, dot: Optional[ Union[Var[Union[Dict[str, Any], bool]], bool, Dict[str, Any]] ] = None, @@ -1344,7 +1344,7 @@ class Line(Cartesian): *children: The children of the component. type_: The interpolation type of line. And customized interpolation function can be set to type. It's the same as type in Area. stroke: The color of the line stroke. - stoke_width: The width of the line stroke. + stroke_width: The width of the line stroke. dot: The dot is shown when mouse enter a line chart and this chart has tooltip. If false set, no active dot will not be drawn. If true set, active dot will be drawn which have the props calculated internally. active_dot: The dot is shown when user enter an area chart and this chart has tooltip. If false set, no active dot will not be drawn. If true set, active dot will be drawn which have the props calculated internally. label: If false set, labels will not be drawn. If true set, labels will be drawn which have the props calculated internally. From 8b45d289bd700a4db570ecc2831d5401b3267068 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 19 Jul 2024 14:05:53 -0700 Subject: [PATCH 04/34] reflex init --ai fixups (#3691) * reflex.build polling: catch exceptions and retry Extend interval to 2 seconds to reduce server load. * Add function to check if a string looks like a generation hash * reflex init: allow --template to be a generation hash if --ai is specified --- reflex/reflex.py | 26 ++++++++++++++++++++------ reflex/utils/prerequisites.py | 12 ++++++++++++ reflex/utils/redir.py | 14 +++++++++----- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/reflex/reflex.py b/reflex/reflex.py index b8a4ed03a..224bace47 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -92,16 +92,30 @@ def _init( # Set up the web project. prerequisites.initialize_frontend_dependencies() - # Check if AI is requested and redirect the user to reflex.build. + # Integrate with reflex.build. + generation_hash = None if ai: - prerequisites.initialize_app(app_name, template=constants.Templates.DEFAULT) - generation_hash = redir.reflex_build_redirect() + if template is None: + # If AI is requested and no template specified, redirect the user to reflex.build. + generation_hash = redir.reflex_build_redirect() + elif prerequisites.is_generation_hash(template): + # Otherwise treat the template as a generation hash. + generation_hash = template + else: + console.error( + "Cannot use `--template` option with `--ai` option. Please remove `--template` option." + ) + raise typer.Exit(2) + template = constants.Templates.DEFAULT + + # Initialize the app. + prerequisites.initialize_app(app_name, template) + + # If a reflex.build generation hash is available, download the code and apply it to the main module. + if generation_hash: prerequisites.initialize_main_module_index_from_generation( app_name, generation_hash=generation_hash ) - else: - # Initialize the app. - prerequisites.initialize_app(app_name, template) # Migrate Pynecone projects to Reflex. prerequisites.migrate_to_reflex() diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index c8455f259..40a2338d3 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -1598,3 +1598,15 @@ def is_windows_bun_supported() -> bool: and cpu_info.model_name is not None and "ARM" not in cpu_info.model_name ) + + +def is_generation_hash(template: str) -> bool: + """Check if the template looks like a generation hash. + + Args: + template: The template name. + + Returns: + True if the template is composed of 32 or more hex characters. + """ + return re.match(r"^[0-9a-f]{32,}$", template) is not None diff --git a/reflex/utils/redir.py b/reflex/utils/redir.py index 461f055bb..1dbd989e9 100644 --- a/reflex/utils/redir.py +++ b/reflex/utils/redir.py @@ -11,7 +11,7 @@ from . import console def open_browser_and_wait( - target_url: str, poll_url: str, interval: int = 1 + target_url: str, poll_url: str, interval: int = 2 ) -> httpx.Response: """Open a browser window to target_url and request poll_url until it returns successfully. @@ -27,10 +27,14 @@ def open_browser_and_wait( console.warn( f"Unable to automatically open the browser. Please navigate to {target_url} in your browser." ) - console.info("Complete the workflow in the browser to continue.") - while response := httpx.get(poll_url, follow_redirects=True): - if response.is_success: - break + console.info("[b]Complete the workflow in the browser to continue.[/b]") + while True: + try: + response = httpx.get(poll_url, follow_redirects=True) + if response.is_success: + break + except httpx.RequestError as err: + console.info(f"Will retry after error occurred while polling: {err}.") time.sleep(interval) return response From decdc857be1b8928e27ec29f9307bee6ab683a07 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Mon, 22 Jul 2024 17:56:34 +0000 Subject: [PATCH 05/34] `rx.list` Dont set/hardcode `list_style_position` css prop (#3695) --- reflex/components/radix/themes/layout/list.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/reflex/components/radix/themes/layout/list.py b/reflex/components/radix/themes/layout/list.py index 31c3ed7b3..fb77b223d 100644 --- a/reflex/components/radix/themes/layout/list.py +++ b/reflex/components/radix/themes/layout/list.py @@ -70,7 +70,6 @@ class BaseList(Component): children = [Foreach.create(items, ListItem.create)] else: children = [ListItem.create(item) for item in items] # type: ignore - props["list_style_position"] = "outside" props["direction"] = "column" style = props.setdefault("style", {}) style["list_style_type"] = list_style_type @@ -86,7 +85,6 @@ class BaseList(Component): """ return { "direction": "column", - "list_style_position": "inside", } From 9666244a879a5e8d1e3b86df26973c8118116891 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Mon, 22 Jul 2024 19:05:50 +0000 Subject: [PATCH 06/34] [REF-3273] Add SVG circle, polygon and rect components (#3684) --- reflex/components/el/__init__.py | 1 + reflex/components/el/__init__.pyi | 13 +- reflex/components/el/elements/__init__.py | 4 +- reflex/components/el/elements/__init__.pyi | 14 +- reflex/components/el/elements/media.py | 116 ++++- reflex/components/el/elements/media.pyi | 541 ++++++++++++++++++++- reflex/utils/pyi_generator.py | 2 + 7 files changed, 643 insertions(+), 48 deletions(-) diff --git a/reflex/components/el/__init__.py b/reflex/components/el/__init__.py index 750e65dba..9fe1d89cd 100644 --- a/reflex/components/el/__init__.py +++ b/reflex/components/el/__init__.py @@ -10,6 +10,7 @@ _SUBMODULES: set[str] = {"elements"} _SUBMOD_ATTRS: dict[str, list[str]] = { f"elements.{k}": v for k, v in elements._MAPPING.items() } +_PYRIGHT_IGNORE_IMPORTS = elements._PYRIGHT_IGNORE_IMPORTS __getattr__, __dir__, __all__ = lazy_loader.attach( __name__, diff --git a/reflex/components/el/__init__.pyi b/reflex/components/el/__init__.pyi index 080c9c37c..d312db84e 100644 --- a/reflex/components/el/__init__.pyi +++ b/reflex/components/el/__init__.pyi @@ -3,6 +3,7 @@ # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ +from . import elements from .elements.forms import Button as Button from .elements.forms import Fieldset as Fieldset from .elements.forms import Form as Form @@ -91,7 +92,7 @@ from .elements.media import Defs as Defs from .elements.media import Embed as Embed from .elements.media import Iframe as Iframe from .elements.media import Img as Img -from .elements.media import Lineargradient as Lineargradient +from .elements.media import LinearGradient as LinearGradient from .elements.media import Map as Map from .elements.media import Object as Object from .elements.media import Path as Path @@ -104,19 +105,19 @@ from .elements.media import Track as Track from .elements.media import Video as Video from .elements.media import area as area from .elements.media import audio as audio -from .elements.media import defs as defs +from .elements.media import defs as defs # type: ignore from .elements.media import embed as embed from .elements.media import iframe as iframe from .elements.media import image as image from .elements.media import img as img -from .elements.media import lineargradient as lineargradient +from .elements.media import lineargradient as lineargradient # type: ignore from .elements.media import map as map from .elements.media import object as object -from .elements.media import path as path +from .elements.media import path as path # type: ignore from .elements.media import picture as picture from .elements.media import portal as portal from .elements.media import source as source -from .elements.media import stop as stop +from .elements.media import stop as stop # type: ignore from .elements.media import svg as svg from .elements.media import track as track from .elements.media import video as video @@ -230,3 +231,5 @@ from .elements.typography import ol as ol from .elements.typography import p as p from .elements.typography import pre as pre from .elements.typography import ul as ul + +_PYRIGHT_IGNORE_IMPORTS = elements._PYRIGHT_IGNORE_IMPORTS diff --git a/reflex/components/el/elements/__init__.py b/reflex/components/el/elements/__init__.py index 1c2684f1d..024ae8c3d 100644 --- a/reflex/components/el/elements/__init__.py +++ b/reflex/components/el/elements/__init__.py @@ -67,6 +67,7 @@ _MAPPING = { "svg", "defs", "lineargradient", + "LinearGradient", "stop", "path", ], @@ -129,12 +130,13 @@ _MAPPING = { } -EXCLUDE = ["del_", "Del", "image"] +EXCLUDE = ["del_", "Del", "image", "lineargradient", "LinearGradient"] for _, v in _MAPPING.items(): v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE]) _SUBMOD_ATTRS: dict[str, list[str]] = _MAPPING +_PYRIGHT_IGNORE_IMPORTS = ["stop", "lineargradient", "path", "defs"] __getattr__, __dir__, __all__ = lazy_loader.attach( __name__, submod_attrs=_SUBMOD_ATTRS, diff --git a/reflex/components/el/elements/__init__.pyi b/reflex/components/el/elements/__init__.pyi index 8c35756e5..4f218f361 100644 --- a/reflex/components/el/elements/__init__.pyi +++ b/reflex/components/el/elements/__init__.pyi @@ -91,7 +91,7 @@ from .media import Defs as Defs from .media import Embed as Embed from .media import Iframe as Iframe from .media import Img as Img -from .media import Lineargradient as Lineargradient +from .media import LinearGradient as LinearGradient from .media import Map as Map from .media import Object as Object from .media import Path as Path @@ -104,19 +104,19 @@ from .media import Track as Track from .media import Video as Video from .media import area as area from .media import audio as audio -from .media import defs as defs +from .media import defs as defs # type: ignore from .media import embed as embed from .media import iframe as iframe from .media import image as image from .media import img as img -from .media import lineargradient as lineargradient +from .media import lineargradient as lineargradient # type: ignore from .media import map as map from .media import object as object -from .media import path as path +from .media import path as path # type: ignore from .media import picture as picture from .media import portal as portal from .media import source as source -from .media import stop as stop +from .media import stop as stop # type: ignore from .media import svg as svg from .media import track as track from .media import video as video @@ -294,6 +294,7 @@ _MAPPING = { "svg", "defs", "lineargradient", + "LinearGradient", "stop", "path", ], @@ -347,6 +348,7 @@ _MAPPING = { "Del", ], } -EXCLUDE = ["del_", "Del", "image"] +EXCLUDE = ["del_", "Del", "image", "lineargradient", "LinearGradient"] for _, v in _MAPPING.items(): v.extend([mod.capitalize() for mod in v if mod not in EXCLUDE]) +_PYRIGHT_IGNORE_IMPORTS = ["stop", "lineargradient", "path", "defs"] diff --git a/reflex/components/el/elements/media.py b/reflex/components/el/elements/media.py index 8d56c78b4..b2bdc9e6f 100644 --- a/reflex/components/el/elements/media.py +++ b/reflex/components/el/elements/media.py @@ -2,8 +2,9 @@ from typing import Any, Union -from reflex import Component +from reflex import Component, ComponentNamespace from reflex.constants.colors import Color +from reflex.utils import console from reflex.vars import Var as Var from .base import BaseHTML @@ -309,6 +310,56 @@ class Svg(BaseHTML): """Display the svg element.""" tag = "svg" + # The width of the svg. + width: Var[Union[str, int]] + # The height of the svg. + height: Var[Union[str, int]] + # The XML namespace declaration. + xmlns: Var[str] + + +class Circle(BaseHTML): + """The SVG circle component.""" + + tag = "circle" + # The x-axis coordinate of the center of the circle. + cx: Var[Union[str, int]] + # The y-axis coordinate of the center of the circle. + cy: Var[Union[str, int]] + # The radius of the circle. + r: Var[Union[str, int]] + # The total length for the circle's circumference, in user units. + path_length: Var[int] + + +class Rect(BaseHTML): + """The SVG rect component.""" + + tag = "rect" + # The x coordinate of the rect. + x: Var[Union[str, int]] + # The y coordinate of the rect. + y: Var[Union[str, int]] + # The width of the rect + width: Var[Union[str, int]] + # The height of the rect. + height: Var[Union[str, int]] + # The horizontal corner radius of the rect. Defaults to ry if it is specified. + rx: Var[Union[str, int]] + # The vertical corner radius of the rect. Defaults to rx if it is specified. + ry: Var[Union[str, int]] + # The total length of the rectangle's perimeter, in user units. + path_length: Var[int] + + +class Polygon(BaseHTML): + """The SVG polygon component.""" + + tag = "polygon" + # defines the list of points (pairs of x,y absolute coordinates) required to draw the polygon. + points: Var[str] + # This prop lets specify the total length for the path, in user units. + path_length: Var[int] class Defs(BaseHTML): @@ -317,30 +368,30 @@ class Defs(BaseHTML): tag = "defs" -class Lineargradient(BaseHTML): +class LinearGradient(BaseHTML): """Display the linearGradient element.""" tag = "linearGradient" - # Units for the gradient + # Units for the gradient. gradient_units: Var[Union[str, bool]] - # Transform applied to the gradient + # Transform applied to the gradient. gradient_transform: Var[Union[str, bool]] - # Method used to spread the gradient + # Method used to spread the gradient. spread_method: Var[Union[str, bool]] - # X coordinate of the starting point of the gradient + # X coordinate of the starting point of the gradient. x1: Var[Union[str, int, bool]] - # X coordinate of the ending point of the gradient + # X coordinate of the ending point of the gradient. x2: Var[Union[str, int, bool]] - # Y coordinate of the starting point of the gradient + # Y coordinate of the starting point of the gradient. y1: Var[Union[str, int, bool]] - # Y coordinate of the ending point of the gradient + # Y coordinate of the ending point of the gradient. y2: Var[Union[str, int, bool]] @@ -349,13 +400,13 @@ class Stop(BaseHTML): tag = "stop" - # Offset of the gradient stop + # Offset of the gradient stop. offset: Var[Union[str, float, int]] - # Color of the gradient stop + # Color of the gradient stop. stop_color: Var[Union[str, Color, bool]] - # Opacity of the gradient stop + # Opacity of the gradient stop. stop_opacity: Var[Union[str, float, int, bool]] @@ -364,10 +415,23 @@ class Path(BaseHTML): tag = "path" - # Defines the shape of the path + # Defines the shape of the path. d: Var[Union[str, int, bool]] +class SVG(ComponentNamespace): + """SVG component namespace.""" + + circle = staticmethod(Circle.create) + rect = staticmethod(Rect.create) + polygon = staticmethod(Polygon.create) + path = staticmethod(Path.create) + stop = staticmethod(Stop.create) + linear_gradient = staticmethod(LinearGradient.create) + defs = staticmethod(Defs.create) + __call__ = staticmethod(Svg.create) + + area = Area.create audio = Audio.create image = img = Img.create @@ -380,8 +444,24 @@ object = Object.create picture = Picture.create portal = Portal.create source = Source.create -svg = Svg.create -defs = Defs.create -lineargradient = Lineargradient.create -stop = Stop.create -path = Path.create +svg = SVG() + + +def __getattr__(name: str): + if name in ("defs", "lineargradient", "stop", "path"): + console.deprecate( + f"`rx.el.{name}`", + reason=f"use `rx.el.svg.{'linear_gradient' if name =='lineargradient' else name}`", + deprecation_version="0.5.8", + removal_version="0.6.0", + ) + return ( + LinearGradient.create + if name == "lineargradient" + else globals()[name.capitalize()].create + ) + + try: + return globals()[name] + except KeyError: + raise AttributeError(f"module '{__name__} has no attribute '{name}'") from None diff --git a/reflex/components/el/elements/media.pyi b/reflex/components/el/elements/media.pyi index 7a9a064fb..ba5f14137 100644 --- a/reflex/components/el/elements/media.pyi +++ b/reflex/components/el/elements/media.pyi @@ -5,6 +5,7 @@ # ------------------------------------------------------ from typing import Any, Callable, Dict, Optional, Union, overload +from reflex import ComponentNamespace from reflex.constants.colors import Color from reflex.event import EventHandler, EventSpec from reflex.style import Style @@ -1563,6 +1564,9 @@ class Svg(BaseHTML): def create( # type: ignore cls, *children, + width: Optional[Union[Var[Union[int, str]], str, int]] = None, + height: Optional[Union[Var[Union[int, str]], str, int]] = None, + xmlns: Optional[Union[Var[str], str]] = None, access_key: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, auto_capitalize: Optional[ Union[Var[Union[bool, int, str]], str, int, bool] @@ -1644,6 +1648,383 @@ class Svg(BaseHTML): Args: *children: The children of the component. + width: The width of the svg. + height: The height of the svg. + xmlns: The XML namespace declaration. + access_key: Provides a hint for generating a keyboard shortcut for the current element. + auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. + content_editable: Indicates whether the element's content is editable. + context_menu: Defines the ID of a element which will serve as the element's context menu. + dir: Defines the text direction. Allowed values are ltr (Left-To-Right) or rtl (Right-To-Left) + draggable: Defines whether the element can be dragged. + enter_key_hint: Hints what media types the media element is able to play. + hidden: Defines whether the element is hidden. + input_mode: Defines the type of the element. + item_prop: Defines the name of the element for metadata purposes. + lang: Defines the language used in the element. + role: Defines the role of the element. + slot: Assigns a slot in a shadow DOM shadow tree to an element. + spell_check: Defines whether the element may be checked for spelling errors. + tab_index: Defines the position of the current element in the tabbing order. + title: Defines a tooltip for the element. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + +class Circle(BaseHTML): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + cx: Optional[Union[Var[Union[int, str]], str, int]] = None, + cy: Optional[Union[Var[Union[int, str]], str, int]] = None, + r: Optional[Union[Var[Union[int, str]], str, int]] = None, + path_length: Optional[Union[Var[int], int]] = None, + access_key: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + auto_capitalize: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + content_editable: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + context_menu: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + dir: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + draggable: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + enter_key_hint: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + hidden: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + input_mode: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + item_prop: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + lang: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + role: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + slot: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + spell_check: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + tab_index: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + title: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_focus: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_scroll: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + **props, + ) -> "Circle": + """Create the component. + + Args: + *children: The children of the component. + cx: The x-axis coordinate of the center of the circle. + cy: The y-axis coordinate of the center of the circle. + r: The radius of the circle. + path_length: The total length for the circle's circumference, in user units. + access_key: Provides a hint for generating a keyboard shortcut for the current element. + auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. + content_editable: Indicates whether the element's content is editable. + context_menu: Defines the ID of a element which will serve as the element's context menu. + dir: Defines the text direction. Allowed values are ltr (Left-To-Right) or rtl (Right-To-Left) + draggable: Defines whether the element can be dragged. + enter_key_hint: Hints what media types the media element is able to play. + hidden: Defines whether the element is hidden. + input_mode: Defines the type of the element. + item_prop: Defines the name of the element for metadata purposes. + lang: Defines the language used in the element. + role: Defines the role of the element. + slot: Assigns a slot in a shadow DOM shadow tree to an element. + spell_check: Defines whether the element may be checked for spelling errors. + tab_index: Defines the position of the current element in the tabbing order. + title: Defines a tooltip for the element. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + +class Rect(BaseHTML): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + x: Optional[Union[Var[Union[int, str]], str, int]] = None, + y: Optional[Union[Var[Union[int, str]], str, int]] = None, + width: Optional[Union[Var[Union[int, str]], str, int]] = None, + height: Optional[Union[Var[Union[int, str]], str, int]] = None, + rx: Optional[Union[Var[Union[int, str]], str, int]] = None, + ry: Optional[Union[Var[Union[int, str]], str, int]] = None, + path_length: Optional[Union[Var[int], int]] = None, + access_key: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + auto_capitalize: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + content_editable: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + context_menu: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + dir: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + draggable: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + enter_key_hint: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + hidden: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + input_mode: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + item_prop: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + lang: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + role: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + slot: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + spell_check: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + tab_index: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + title: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_focus: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_scroll: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + **props, + ) -> "Rect": + """Create the component. + + Args: + *children: The children of the component. + x: The x coordinate of the rect. + y: The y coordinate of the rect. + width: The width of the rect + height: The height of the rect. + rx: The horizontal corner radius of the rect. Defaults to ry if it is specified. + ry: The vertical corner radius of the rect. Defaults to rx if it is specified. + path_length: The total length of the rectangle's perimeter, in user units. + access_key: Provides a hint for generating a keyboard shortcut for the current element. + auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. + content_editable: Indicates whether the element's content is editable. + context_menu: Defines the ID of a element which will serve as the element's context menu. + dir: Defines the text direction. Allowed values are ltr (Left-To-Right) or rtl (Right-To-Left) + draggable: Defines whether the element can be dragged. + enter_key_hint: Hints what media types the media element is able to play. + hidden: Defines whether the element is hidden. + input_mode: Defines the type of the element. + item_prop: Defines the name of the element for metadata purposes. + lang: Defines the language used in the element. + role: Defines the role of the element. + slot: Assigns a slot in a shadow DOM shadow tree to an element. + spell_check: Defines whether the element may be checked for spelling errors. + tab_index: Defines the position of the current element in the tabbing order. + title: Defines a tooltip for the element. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + +class Polygon(BaseHTML): + @overload + @classmethod + def create( # type: ignore + cls, + *children, + points: Optional[Union[Var[str], str]] = None, + path_length: Optional[Union[Var[int], int]] = None, + access_key: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + auto_capitalize: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + content_editable: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + context_menu: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + dir: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + draggable: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + enter_key_hint: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + hidden: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + input_mode: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + item_prop: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + lang: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + role: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + slot: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + spell_check: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + tab_index: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + title: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_focus: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_scroll: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + **props, + ) -> "Polygon": + """Create the component. + + Args: + *children: The children of the component. + points: defines the list of points (pairs of x,y absolute coordinates) required to draw the polygon. + path_length: This prop lets specify the total length for the path, in user units. access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. @@ -1789,7 +2170,7 @@ class Defs(BaseHTML): """ ... -class Lineargradient(BaseHTML): +class LinearGradient(BaseHTML): @overload @classmethod def create( # type: ignore @@ -1878,18 +2259,18 @@ class Lineargradient(BaseHTML): Union[EventHandler, EventSpec, list, Callable, BaseVar] ] = None, **props, - ) -> "Lineargradient": + ) -> "LinearGradient": """Create the component. Args: *children: The children of the component. - gradient_units: Units for the gradient - gradient_transform: Transform applied to the gradient - spread_method: Method used to spread the gradient - x1: X coordinate of the starting point of the gradient - x2: X coordinate of the ending point of the gradient - y1: Y coordinate of the starting point of the gradient - y2: Y coordinate of the ending point of the gradient + gradient_units: Units for the gradient. + gradient_transform: Transform applied to the gradient. + spread_method: Method used to spread the gradient. + x1: X coordinate of the starting point of the gradient. + x2: X coordinate of the ending point of the gradient. + y1: Y coordinate of the starting point of the gradient. + y2: Y coordinate of the ending point of the gradient. access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. @@ -2013,9 +2394,9 @@ class Stop(BaseHTML): Args: *children: The children of the component. - offset: Offset of the gradient stop - stop_color: Color of the gradient stop - stop_opacity: Opacity of the gradient stop + offset: Offset of the gradient stop. + stop_color: Color of the gradient stop. + stop_opacity: Opacity of the gradient stop. access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. @@ -2133,7 +2514,135 @@ class Path(BaseHTML): Args: *children: The children of the component. - d: Defines the shape of the path + d: Defines the shape of the path. + access_key: Provides a hint for generating a keyboard shortcut for the current element. + auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. + content_editable: Indicates whether the element's content is editable. + context_menu: Defines the ID of a element which will serve as the element's context menu. + dir: Defines the text direction. Allowed values are ltr (Left-To-Right) or rtl (Right-To-Left) + draggable: Defines whether the element can be dragged. + enter_key_hint: Hints what media types the media element is able to play. + hidden: Defines whether the element is hidden. + input_mode: Defines the type of the element. + item_prop: Defines the name of the element for metadata purposes. + lang: Defines the language used in the element. + role: Defines the role of the element. + slot: Assigns a slot in a shadow DOM shadow tree to an element. + spell_check: Defines whether the element may be checked for spelling errors. + tab_index: Defines the position of the current element in the tabbing order. + title: Defines a tooltip for the element. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + +class SVG(ComponentNamespace): + circle = staticmethod(Circle.create) + rect = staticmethod(Rect.create) + polygon = staticmethod(Polygon.create) + path = staticmethod(Path.create) + stop = staticmethod(Stop.create) + linear_gradient = staticmethod(LinearGradient.create) + defs = staticmethod(Defs.create) + + @staticmethod + def __call__( + *children, + width: Optional[Union[Var[Union[int, str]], str, int]] = None, + height: Optional[Union[Var[Union[int, str]], str, int]] = None, + xmlns: Optional[Union[Var[str], str]] = None, + access_key: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + auto_capitalize: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + content_editable: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + context_menu: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + dir: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + draggable: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + enter_key_hint: Optional[ + Union[Var[Union[bool, int, str]], str, int, bool] + ] = None, + hidden: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + input_mode: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + item_prop: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + lang: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + role: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + slot: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + spell_check: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + tab_index: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + title: Optional[Union[Var[Union[bool, int, str]], str, int, bool]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_focus: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_scroll: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + **props, + ) -> "Svg": + """Create the component. + + Args: + *children: The children of the component. + width: The width of the svg. + height: The height of the svg. + xmlns: The XML namespace declaration. access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. @@ -2175,8 +2684,4 @@ object = Object.create picture = Picture.create portal = Portal.create source = Source.create -svg = Svg.create -defs = Defs.create -lineargradient = Lineargradient.create -stop = Stop.create -path = Path.create +svg = SVG() diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index f385796bd..d279e3910 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -882,6 +882,7 @@ class PyiGenerator: # retrieve the _SUBMODULES and _SUBMOD_ATTRS from an init file if present. sub_mods = getattr(mod, "_SUBMODULES", None) sub_mod_attrs = getattr(mod, "_SUBMOD_ATTRS", None) + pyright_ignore_imports = getattr(mod, "_PYRIGHT_IGNORE_IMPORTS", []) if not sub_mods and not sub_mod_attrs: return @@ -901,6 +902,7 @@ class PyiGenerator: # construct the import statement and handle special cases for aliases sub_mod_attrs_imports = [ f"from .{path} import {mod if not isinstance(mod, tuple) else mod[0]} as {mod if not isinstance(mod, tuple) else mod[1]}" + + (" # type: ignore" if mod in pyright_ignore_imports else "") for mod, path in sub_mod_attrs.items() ] sub_mod_attrs_imports.append("") From ea016314b0fa7442137701bf6be9be01d0d88c35 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Mon, 22 Jul 2024 12:45:23 -0700 Subject: [PATCH 07/34] [REF-3227] implement more literal vars (#3687) * implement more literal vars * fix super issue * pyright has a bug i think * oh we changed that * fix docs * literalize vars recursively * do what masen told me :D * use dynamic keys * forgot .create * adjust _var_value * dang it darglint * add test for serializing literal vars into js exprs * fix silly mistake * add handling for var and none * use create safe * is none bruh * implement function vars and do various modification * fix None issue * clear a lot of creates that did nothing * add tests to function vars * added simple fix smh * use fconcat to make an even more complicated test --- reflex/experimental/vars/__init__.py | 6 + reflex/experimental/vars/base.py | 563 +++++++++++++++++++++++++-- tests/test_var.py | 57 ++- 3 files changed, 594 insertions(+), 32 deletions(-) diff --git a/reflex/experimental/vars/__init__.py b/reflex/experimental/vars/__init__.py index 98fa802d3..c4b3e6913 100644 --- a/reflex/experimental/vars/__init__.py +++ b/reflex/experimental/vars/__init__.py @@ -3,10 +3,16 @@ from .base import ArrayVar as ArrayVar from .base import BooleanVar as BooleanVar from .base import ConcatVarOperation as ConcatVarOperation +from .base import FunctionStringVar as FunctionStringVar from .base import FunctionVar as FunctionVar from .base import ImmutableVar as ImmutableVar +from .base import LiteralArrayVar as LiteralArrayVar +from .base import LiteralBooleanVar as LiteralBooleanVar +from .base import LiteralNumberVar as LiteralNumberVar +from .base import LiteralObjectVar as LiteralObjectVar from .base import LiteralStringVar as LiteralStringVar from .base import LiteralVar as LiteralVar from .base import NumberVar as NumberVar from .base import ObjectVar as ObjectVar from .base import StringVar as StringVar +from .base import VarOperationCall as VarOperationCall diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index 258f8d6c3..af0d350f1 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -7,9 +7,10 @@ import json import re import sys from functools import cached_property -from typing import Any, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from reflex import constants +from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.utils import serializers, types from reflex.utils.exceptions import VarTypeError @@ -95,6 +96,11 @@ class ImmutableVar(Var): return hash((self._var_name, self._var_type, self._var_data)) def _get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ return self._var_data def _replace(self, merge_var_data=None, **kwargs: Any): @@ -275,10 +281,250 @@ class ArrayVar(ImmutableVar): class FunctionVar(ImmutableVar): """Base class for immutable function vars.""" + def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return ArgsFunctionOperation( + ("...args",), + VarOperationCall(self, *args, ImmutableVar.create_safe("...args")), + ) + + def call(self, *args: Var | Any) -> VarOperationCall: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return VarOperationCall(self, *args) + + +class FunctionStringVar(FunctionVar): + """Base class for immutable function vars from a string.""" + + def __init__(self, func: str, _var_data: VarData | None = None) -> None: + """Initialize the function var. + + Args: + func: The function to call. + _var_data: Additional hooks and imports associated with the Var. + """ + super(FunctionVar, self).__init__( + _var_name=func, + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class VarOperationCall(ImmutableVar): + """Base class for immutable vars that are the result of a function call.""" + + _func: Optional[FunctionVar] = dataclasses.field(default=None) + _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) + + def __init__( + self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None + ): + """Initialize the function call var. + + Args: + func: The function to call. + *args: The arguments to call the function with. + _var_data: Additional hooks and imports associated with the Var. + """ + super(VarOperationCall, self).__init__( + _var_name="", + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_func", func) + object.__setattr__(self, "_args", args) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._func._get_all_var_data() if self._func is not None else None, + *[var._get_all_var_data() for var in self._args], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + pass + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArgsFunctionOperation(FunctionVar): + """Base class for immutable function defined via arguments and return expression.""" + + _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + _return_expr: Union[Var, Any] = dataclasses.field(default=None) + + def __init__( + self, + args_names: Tuple[str, ...], + return_expr: Var | Any, + _var_data: VarData | None = None, + ) -> None: + """Initialize the function with arguments var. + + Args: + args_names: The names of the arguments. + return_expr: The return expression of the function. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArgsFunctionOperation, self).__init__( + _var_name=f"", + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_args_names", args_names) + object.__setattr__(self, "_return_expr", return_expr) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._return_expr._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + class LiteralVar(ImmutableVar): """Base class for immutable literal vars.""" + @classmethod + def create( + cls, + value: Any, + _var_data: VarData | None = None, + ) -> Var: + """Create a var from a value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + + Raises: + TypeError: If the value is not a supported type for LiteralVar. + """ + if isinstance(value, Var): + if _var_data is None: + return value + return value._replace(merge_var_data=_var_data) + + if value is None: + return ImmutableVar.create_safe("null", _var_data=_var_data) + + if isinstance(value, Base): + return LiteralObjectVar( + value.dict(), _var_type=type(value), _var_data=_var_data + ) + + if isinstance(value, str): + return LiteralStringVar.create(value, _var_data=_var_data) + + constructor = type_mapping.get(type(value)) + + if constructor is None: + raise TypeError(f"Unsupported type {type(value)} for LiteralVar.") + + return constructor(value, _var_data=_var_data) + def __post_init__(self): """Post-initialize the var.""" @@ -298,7 +544,25 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) class LiteralStringVar(LiteralVar): """Base class for immutable literal string vars.""" - _var_value: Optional[str] = dataclasses.field(default=None) + _var_value: str = dataclasses.field(default="") + + def __init__( + self, + _var_value: str, + _var_data: VarData | None = None, + ): + """Initialize the string var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralStringVar, self).__init__( + _var_name=f'"{_var_value}"', + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) @classmethod def create( @@ -316,7 +580,7 @@ class LiteralStringVar(LiteralVar): The var. """ if REFLEX_VAR_OPENING_TAG in value: - strings_and_vals: list[Var] = [] + strings_and_vals: list[Var | str] = [] offset = 0 # Initialize some methods for reading json. @@ -334,7 +598,7 @@ class LiteralStringVar(LiteralVar): while m := _decode_var_pattern.search(value): start, end = m.span() if start > 0: - strings_and_vals.append(LiteralStringVar.create(value[:start])) + strings_and_vals.append(value[:start]) serialized_data = m.group(1) @@ -364,17 +628,13 @@ class LiteralStringVar(LiteralVar): offset += end - start if value: - strings_and_vals.append(LiteralStringVar.create(value)) + strings_and_vals.append(value) - return ConcatVarOperation.create( - tuple(strings_and_vals), _var_data=_var_data - ) + return ConcatVarOperation(*strings_and_vals, _var_data=_var_data) - return cls( - _var_value=value, - _var_name=f'"{value}"', - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), + return LiteralStringVar( + value, + _var_data=_var_data, ) @@ -386,20 +646,33 @@ class LiteralStringVar(LiteralVar): class ConcatVarOperation(StringVar): """Representing a concatenation of literal string vars.""" - _var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple) + _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple) - def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None): + def __init__(self, *value: Var | str, _var_data: VarData | None = None): """Initialize the operation of concatenating literal string vars. Args: - _var_value: The list of vars to concatenate. + value: The values to concatenate. _var_data: Additional hooks and imports associated with the Var. """ super(ConcatVarOperation, self).__init__( _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str ) - object.__setattr__(self, "_var_value", _var_value) - object.__setattr__(self, "_var_name", self._cached_var_name) + object.__setattr__(self, "_var_value", value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) @cached_property def _cached_var_name(self) -> str: @@ -408,7 +681,16 @@ class ConcatVarOperation(StringVar): Returns: The name of the var. """ - return "+".join([str(element) for element in self._var_value]) + return ( + "(" + + "+".join( + [ + str(element) if isinstance(element, Var) else f'"{element}"' + for element in self._var_value + ] + ) + + ")" + ) @cached_property def _cached_get_all_var_data(self) -> ImmutableVarData | None: @@ -418,7 +700,12 @@ class ConcatVarOperation(StringVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - *[var._get_all_var_data() for var in self._var_value], self._var_data + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -433,22 +720,236 @@ class ConcatVarOperation(StringVar): """Post-initialize the var.""" pass - @classmethod - def create( - cls, - value: tuple[Var, ...], + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralBooleanVar(LiteralVar): + """Base class for immutable literal boolean vars.""" + + _var_value: bool = dataclasses.field(default=False) + + def __init__( + self, + _var_value: bool, _var_data: VarData | None = None, - ) -> ConcatVarOperation: - """Create a var from a tuple of values. + ): + """Initialize the boolean var. Args: - value: The value to create the var from. + _var_value: The value of the var. _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralBooleanVar, self).__init__( + _var_name="true" if _var_value else "false", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralNumberVar(LiteralVar): + """Base class for immutable literal number vars.""" + + _var_value: float | int = dataclasses.field(default=0) + + def __init__( + self, + _var_value: float | int, + _var_data: VarData | None = None, + ): + """Initialize the number var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralNumberVar, self).__init__( + _var_name=str(_var_value), + _var_type=type(_var_value), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralObjectVar(LiteralVar): + """Base class for immutable literal object vars.""" + + _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + default_factory=dict + ) + + def __init__( + self, + _var_value: dict[Var | Any, Var | Any], + _var_type: Type = dict, + _var_data: VarData | None = None, + ): + """Initialize the object var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralObjectVar, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "_var_value", + _var_value, + ) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. Returns: - The var. + The attribute of the var. """ - return ConcatVarOperation( - _var_value=value, - _var_data=_var_data, + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "{ " + + ", ".join( + [ + f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}" + for key, value in self._var_value.items() + ] + ) + + " }" ) + + @cached_property + def _get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + value._get_all_var_data() + for key, value in self._var_value + if isinstance(value, Var) + ], + *[ + key._get_all_var_data() + for key, value in self._var_value + if isinstance(key, Var) + ], + self._var_data, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralArrayVar(LiteralVar): + """Base class for immutable literal array vars.""" + + _var_value: Union[ + List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...] + ] = dataclasses.field(default_factory=list) + + def __init__( + self, + _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], + _var_data: VarData | None = None, + ): + """Initialize the array var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralArrayVar, self).__init__( + _var_name="", + _var_data=ImmutableVarData.merge(_var_data), + _var_type=list, + ) + object.__setattr__(self, "_var_value", _var_value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "[" + + ", ".join( + [str(LiteralVar.create(element)) for element in self._var_value] + ) + + "]" + ) + + @cached_property + def _get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, + ) + + +type_mapping = { + int: LiteralNumberVar, + float: LiteralNumberVar, + bool: LiteralBooleanVar, + dict: LiteralObjectVar, + list: LiteralArrayVar, + tuple: LiteralArrayVar, + set: LiteralArrayVar, +} diff --git a/tests/test_var.py b/tests/test_var.py index 78b3a2160..47d4f223b 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -8,9 +8,12 @@ from pandas import DataFrame from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.experimental.vars.base import ( + ArgsFunctionOperation, ConcatVarOperation, + FunctionStringVar, ImmutableVar, LiteralStringVar, + LiteralVar, ) from reflex.state import BaseState from reflex.utils.imports import ImportVar @@ -858,6 +861,58 @@ def test_state_with_initial_computed_var( assert runtime_dict[var_name] == expected_runtime +def test_literal_var(): + complicated_var = LiteralVar.create( + [ + {"a": 1, "b": 2, "c": {"d": 3, "e": 4}}, + [1, 2, 3, 4], + 9, + "string", + True, + False, + None, + set([1, 2, 3]), + ] + ) + assert ( + str(complicated_var) + == '[{ ["a"] : 1, ["b"] : 2, ["c"] : { ["d"] : 3, ["e"] : 4 } }, [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]' + ) + + +def test_function_var(): + addition_func = FunctionStringVar("((a, b) => a + b)") + assert str(addition_func.call(1, 2)) == "(((a, b) => a + b)(1, 2))" + + manual_addition_func = ArgsFunctionOperation( + ("a", "b"), + { + "args": [ImmutableVar.create_safe("a"), ImmutableVar.create_safe("b")], + "result": ImmutableVar.create_safe("a + b"), + }, + ) + assert ( + str(manual_addition_func.call(1, 2)) + == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))' + ) + + increment_func = addition_func(1) + assert ( + str(increment_func.call(2)) + == "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))" + ) + + create_hello_statement = ArgsFunctionOperation( + ("name",), f"Hello, {ImmutableVar.create_safe('name')}!" + ) + first_name = LiteralStringVar("Steven") + last_name = LiteralStringVar("Universe") + assert ( + str(create_hello_statement.call(f"{first_name} {last_name}")) + == '(((name) => (("Hello, "+name+"!")))(("Steven"+" "+"Universe")))' + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None @@ -931,7 +986,7 @@ def test_fstring_concat(): ), ) - assert str(string_concat) == '"foo"+imagination+"bar"+consequences+"baz"' + assert str(string_concat) == '("foo"+imagination+"bar"+consequences+"baz")' assert isinstance(string_concat, ConcatVarOperation) assert string_concat._get_all_var_data() == ImmutableVarData( state="fear", From b9927b6f498a1f67a7bea5fa787ecf04d0bdac98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Tue, 23 Jul 2024 22:58:15 +0200 Subject: [PATCH 08/34] notifying frontend about backend error looks better (#3491) --- reflex/app.py | 22 ++++++++++++++++++++-- reflex/components/core/banner.py | 14 ++++++++++++++ reflex/components/core/banner.pyi | 6 +++--- reflex/components/sonner/toast.py | 30 ++++++++++++++++++++++++++++-- reflex/components/sonner/toast.pyi | 21 ++++++++++++++------- reflex/state.py | 3 ++- tests/test_state.py | 29 ++++++++++++++++++++++++----- 7 files changed, 105 insertions(+), 20 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 658ba1a1f..7e40a95bf 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -112,11 +112,29 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec: EventSpec: The window alert event. """ + from reflex.components.sonner.toast import Toaster, toast + error = traceback.format_exc() console.error(f"[Reflex Backend Exception]\n {error}\n") - return window_alert("An error occurred. See logs for details.") + error_message = ( + ["Contact the website administrator."] + if is_prod_mode() + else [f"{type(exception).__name__}: {exception}.", "See logs for details."] + ) + if Toaster.is_used: + return toast( + level="error", + title="An error occurred.", + description="
".join(error_message), + position="top-center", + id="backend_error", + style={"width": "500px"}, + ) # type: ignore + else: + error_message.insert(0, "An error occurred.") + return window_alert("\n".join(error_message)) def default_overlay_component() -> Component: @@ -183,7 +201,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): # A component that is present on every page (defaults to the Connection Error banner). overlay_component: Optional[Union[Component, ComponentCallable]] = ( - default_overlay_component + default_overlay_component() ) # Error boundary component to wrap the app with. diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index b634ab75a..c6b46696c 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -153,6 +153,20 @@ useEffect(() => {{ hook, ] + @classmethod + def create(cls, *children, **props) -> Component: + """Create a connection toaster component. + + Args: + *children: The children of the component. + **props: The properties of the component. + + Returns: + The connection toaster component. + """ + Toaster.is_used = True + return super().create(*children, **props) + class ConnectionBanner(Component): """A connection banner component.""" diff --git a/reflex/components/core/banner.pyi b/reflex/components/core/banner.pyi index ddaafb153..b9b6d506f 100644 --- a/reflex/components/core/banner.pyi +++ b/reflex/components/core/banner.pyi @@ -187,7 +187,7 @@ class ConnectionToaster(Toaster): ] = None, **props, ) -> "ConnectionToaster": - """Create the component. + """Create a connection toaster component. Args: *children: The children of the component. @@ -211,10 +211,10 @@ class ConnectionToaster(Toaster): class_name: The class name for the component. autofocus: Whether the component should take the focus once the page is loaded custom_attrs: custom attribute - **props: The props of the component. + **props: The properties of the component. Returns: - The component. + The connection toaster component. """ ... diff --git a/reflex/components/sonner/toast.py b/reflex/components/sonner/toast.py index f8d1cc340..d4df31e82 100644 --- a/reflex/components/sonner/toast.py +++ b/reflex/components/sonner/toast.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union from reflex.base import Base from reflex.components.component import Component, ComponentNamespace @@ -211,6 +211,9 @@ class Toaster(Component): # Pauses toast timers when the page is hidden, e.g., when the tab is backgrounded, the browser is minimized, or the OS is locked. pause_when_page_is_hidden: Var[bool] + # Marked True when any Toast component is created. + is_used: ClassVar[bool] = False + def add_hooks(self) -> list[Var | str]: """Add hooks for the toaster component. @@ -231,7 +234,7 @@ class Toaster(Component): return [hook] @staticmethod - def send_toast(message: str, level: str | None = None, **props) -> EventSpec: + def send_toast(message: str = "", level: str | None = None, **props) -> EventSpec: """Send a toast message. Args: @@ -239,10 +242,19 @@ class Toaster(Component): level: The level of the toast. **props: The options for the toast. + Raises: + ValueError: If the Toaster component is not created. + Returns: The toast event. """ + if not Toaster.is_used: + raise ValueError( + "Toaster component must be created before sending a toast. (use `rx.toast.provider()`)" + ) toast_command = f"{toast_ref}.{level}" if level is not None else toast_ref + if message == "" and ("title" not in props or "description" not in props): + raise ValueError("Toast message or title or description must be provided.") if props: args = serialize(ToastProps(**props)) # type: ignore toast = f"{toast_command}(`{message}`, {args})" @@ -331,6 +343,20 @@ class Toaster(Component): ) return call_script(dismiss_action) + @classmethod + def create(cls, *children, **props) -> Component: + """Create a toaster component. + + Args: + *children: The children of the toaster. + **props: The properties of the toaster. + + Returns: + The toaster component. + """ + cls.is_used = True + return super().create(*children, **props) + # TODO: figure out why loading toast stay open forever # def toast_loading(message: str, **kwargs): diff --git a/reflex/components/sonner/toast.pyi b/reflex/components/sonner/toast.pyi index 82999a06a..7e5758b16 100644 --- a/reflex/components/sonner/toast.pyi +++ b/reflex/components/sonner/toast.pyi @@ -3,7 +3,7 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ -from typing import Any, Callable, Dict, Literal, Optional, Union, overload +from typing import Any, Callable, ClassVar, Dict, Literal, Optional, Union, overload from reflex.base import Base from reflex.components.component import Component, ComponentNamespace @@ -52,9 +52,13 @@ class ToastProps(PropsBase): def dict(self, *args, **kwargs) -> dict[str, Any]: ... class Toaster(Component): + is_used: ClassVar[bool] = False + def add_hooks(self) -> list[Var | str]: ... @staticmethod - def send_toast(message: str, level: str | None = None, **props) -> EventSpec: ... + def send_toast( + message: str = "", level: str | None = None, **props + ) -> EventSpec: ... @staticmethod def toast_info(message: str, **kwargs): ... @staticmethod @@ -158,10 +162,10 @@ class Toaster(Component): ] = None, **props, ) -> "Toaster": - """Create the component. + """Create a toaster component. Args: - *children: The children of the component. + *children: The children of the toaster. theme: the theme of the toast rich_colors: whether to show rich colors expand: whether to expand the toast @@ -182,10 +186,10 @@ class Toaster(Component): class_name: The class name for the component. autofocus: Whether the component should take the focus once the page is loaded custom_attrs: custom attribute - **props: The props of the component. + **props: The properties of the toaster. Returns: - The component. + The toaster component. """ ... @@ -200,7 +204,7 @@ class ToastNamespace(ComponentNamespace): @staticmethod def __call__( - message: str, level: Optional[str] = None, **props + message: str = "", level: Optional[str] = None, **props ) -> "Optional[EventSpec]": """Send a toast message. @@ -209,6 +213,9 @@ class ToastNamespace(ComponentNamespace): level: The level of the toast. **props: The options for the toast. + Raises: + ValueError: If the Toaster component is not created. + Returns: The toast event. """ diff --git a/reflex/state.py b/reflex/state.py index 49b5bd4a4..9313939dc 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -31,6 +31,8 @@ from typing import ( import dill from sqlalchemy.orm import DeclarativeBase +from reflex.config import get_config + try: import pydantic.v1 as pydantic except ModuleNotFoundError: @@ -42,7 +44,6 @@ from redis.exceptions import ResponseError from reflex import constants from reflex.base import Base -from reflex.config import get_config from reflex.event import ( BACKGROUND_TASK_MARKER, Event, diff --git a/tests/test_state.py b/tests/test_state.py index 2fc149389..c998944ef 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -19,6 +19,7 @@ import reflex.config from reflex import constants from reflex.app import App from reflex.base import Base +from reflex.components.sonner.toast import Toaster from reflex.constants import CompileVars, RouteVar, SocketEvent from reflex.event import Event, EventHandler from reflex.state import ( @@ -1527,7 +1528,6 @@ async def test_state_with_invalid_yield(capsys, mock_app): Args: capsys: Pytest fixture for capture standard streams. mock_app: Mock app fixture. - """ class StateWithInvalidYield(BaseState): @@ -1546,10 +1546,29 @@ async def test_state_with_invalid_yield(capsys, mock_app): rx.event.Event(token="fake_token", name="invalid_handler") ): assert not update.delta - assert update.events == rx.event.fix_events( - [rx.window_alert("An error occurred. See logs for details.")], - token="", - ) + if Toaster.is_used: + assert update.events == rx.event.fix_events( + [ + rx.toast( + title="An error occurred.", + description="TypeError: Your handler test_state_with_invalid_yield..StateWithInvalidYield.invalid_handler must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`).
See logs for details.", + level="error", + id="backend_error", + position="top-center", + style={"width": "500px"}, + ) # type: ignore + ], + token="", + ) + else: + assert update.events == rx.event.fix_events( + [ + rx.window_alert( + "An error occurred.\nContact the website administrator." + ) + ], + token="", + ) captured = capsys.readouterr() assert "must only return/yield: None, Events or other EventHandlers" in captured.out From 0845d2ee7689e91a4532a8b800314f71230febaa Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 23 Jul 2024 15:28:38 -0700 Subject: [PATCH 09/34] [REF-3184] [REF-3339] Background task locking improvements (#3696) * [REF-3184] Raise exception when encountering nested `async with self` blocks Avoid deadlock when the background task already holds the mutation lock for a given state. * [REF-3339] get_state from background task links to StateProxy When calling `get_state` from a background task, the resulting state instance is wrapped in a StateProxy that is bound to the original StateProxy and shares the same async context, lock, and mutability flag. * If StateProxy has a _self_parent_state_proxy, retrieve the correct substate * test_state fixup --- integration/test_background_task.py | 103 ++++++++++++++++++++++++++++ reflex/state.py | 64 ++++++++++++++--- tests/test_state.py | 2 +- 3 files changed, 160 insertions(+), 9 deletions(-) diff --git a/integration/test_background_task.py b/integration/test_background_task.py index 96a47e951..98b6e48ff 100644 --- a/integration/test_background_task.py +++ b/integration/test_background_task.py @@ -12,7 +12,10 @@ def BackgroundTask(): """Test that background tasks work as expected.""" import asyncio + import pytest + import reflex as rx + from reflex.state import ImmutableStateError class State(rx.State): counter: int = 0 @@ -71,6 +74,38 @@ def BackgroundTask(): self.racy_task(), self.racy_task(), self.racy_task(), self.racy_task() ) + @rx.background + async def nested_async_with_self(self): + async with self: + self.counter += 1 + with pytest.raises(ImmutableStateError): + async with self: + self.counter += 1 + + async def triple_count(self): + third_state = await self.get_state(ThirdState) + await third_state._triple_count() + + class OtherState(rx.State): + @rx.background + async def get_other_state(self): + async with self: + state = await self.get_state(State) + state.counter += 1 + await state.triple_count() + with pytest.raises(ImmutableStateError): + await state.triple_count() + with pytest.raises(ImmutableStateError): + state.counter += 1 + async with state: + state.counter += 1 + await state.triple_count() + + class ThirdState(rx.State): + async def _triple_count(self): + state = await self.get_state(State) + state.counter *= 3 + def index() -> rx.Component: return rx.vstack( rx.chakra.input( @@ -109,6 +144,16 @@ def BackgroundTask(): on_click=State.handle_racy_event, id="racy-increment", ), + rx.button( + "Nested Async with Self", + on_click=State.nested_async_with_self, + id="nested-async-with-self", + ), + rx.button( + "Increment from OtherState", + on_click=OtherState.get_other_state, + id="increment-from-other-state", + ), rx.button("Reset", on_click=State.reset_counter, id="reset"), ) @@ -230,3 +275,61 @@ def test_background_task( assert background_task._poll_for( lambda: not background_task.app_instance.background_tasks # type: ignore ) + + +def test_nested_async_with_self( + background_task: AppHarness, + driver: WebDriver, + token: str, +): + """Test that nested async with self in the same coroutine raises Exception. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + nested_async_with_self_button = driver.find_element(By.ID, "nested-async-with-self") + increment_button = driver.find_element(By.ID, "increment") + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + assert background_task._poll_for(lambda: counter.text == "0", timeout=5) + + nested_async_with_self_button.click() + assert background_task._poll_for(lambda: counter.text == "1", timeout=5) + + increment_button.click() + assert background_task._poll_for(lambda: counter.text == "2", timeout=5) + + +def test_get_state( + background_task: AppHarness, + driver: WebDriver, + token: str, +): + """Test that get_state returns a state bound to the correct StateProxy. + + Args: + background_task: harness for BackgroundTask app. + driver: WebDriver instance. + token: The token for the connected client. + """ + assert background_task.app_instance is not None + + # get a reference to all buttons + other_state_button = driver.find_element(By.ID, "increment-from-other-state") + increment_button = driver.find_element(By.ID, "increment") + + # get a reference to the counter + counter = driver.find_element(By.ID, "counter") + assert background_task._poll_for(lambda: counter.text == "0", timeout=5) + + other_state_button.click() + assert background_task._poll_for(lambda: counter.text == "12", timeout=5) + + increment_button.click() + assert background_task._poll_for(lambda: counter.text == "13", timeout=5) diff --git a/reflex/state.py b/reflex/state.py index 9313939dc..e29336042 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -202,7 +202,7 @@ def _no_chain_background_task( def _substate_key( token: str, - state_cls_or_name: BaseState | Type[BaseState] | str | list[str], + state_cls_or_name: BaseState | Type[BaseState] | str | Sequence[str], ) -> str: """Get the substate key. @@ -2029,19 +2029,38 @@ class StateProxy(wrapt.ObjectProxy): self.counter += 1 """ - def __init__(self, state_instance): + def __init__( + self, state_instance, parent_state_proxy: Optional["StateProxy"] = None + ): """Create a proxy for a state instance. + If `get_state` is used on a StateProxy, the resulting state will be + linked to the given state via parent_state_proxy. The first state in the + chain is the state that initiated the background task. + Args: state_instance: The state instance to proxy. + parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ super().__init__(state_instance) # compile is not relevant to backend logic self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP) - self._self_substate_path = state_instance.get_full_name().split(".") + self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_actx = None self._self_mutable = False self._self_actx_lock = asyncio.Lock() + self._self_actx_lock_holder = None + self._self_parent_state_proxy = parent_state_proxy + + def _is_mutable(self) -> bool: + """Check if the state is mutable. + + Returns: + Whether the state is mutable. + """ + if self._self_parent_state_proxy is not None: + return self._self_parent_state_proxy._is_mutable() + return self._self_mutable async def __aenter__(self) -> StateProxy: """Enter the async context manager protocol. @@ -2054,8 +2073,31 @@ class StateProxy(wrapt.ObjectProxy): Returns: This StateProxy instance in mutable mode. + + Raises: + ImmutableStateError: If the state is already mutable. """ + if self._self_parent_state_proxy is not None: + parent_state = ( + await self._self_parent_state_proxy.__aenter__() + ).__wrapped__ + super().__setattr__( + "__wrapped__", + await parent_state.get_state( + State.get_class_substate(self._self_substate_path) + ), + ) + return self + current_task = asyncio.current_task() + if ( + self._self_actx_lock.locked() + and current_task == self._self_actx_lock_holder + ): + raise ImmutableStateError( + "The state is already mutable. Do not nest `async with self` blocks." + ) await self._self_actx_lock.acquire() + self._self_actx_lock_holder = current_task self._self_actx = self._self_app.modify_state( token=_substate_key( self.__wrapped__.router.session.client_token, @@ -2077,12 +2119,16 @@ class StateProxy(wrapt.ObjectProxy): Args: exc_info: The exception info tuple. """ + if self._self_parent_state_proxy is not None: + await self._self_parent_state_proxy.__aexit__(*exc_info) + return if self._self_actx is None: return self._self_mutable = False try: await self._self_actx.__aexit__(*exc_info) finally: + self._self_actx_lock_holder = None self._self_actx_lock.release() self._self_actx = None @@ -2117,7 +2163,7 @@ class StateProxy(wrapt.ObjectProxy): Raises: ImmutableStateError: If the state is not in mutable mode. """ - if name in ["substates", "parent_state"] and not self._self_mutable: + if name in ["substates", "parent_state"] and not self._is_mutable(): raise ImmutableStateError( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." @@ -2157,7 +2203,7 @@ class StateProxy(wrapt.ObjectProxy): """ if ( name.startswith("_self_") # wrapper attribute - or self._self_mutable # lock held + or self._is_mutable() # lock held # non-persisted state attribute or name in self.__wrapped__.get_skip_vars() ): @@ -2181,7 +2227,7 @@ class StateProxy(wrapt.ObjectProxy): Raises: ImmutableStateError: If the state is not in mutable mode. """ - if not self._self_mutable: + if not self._is_mutable(): raise ImmutableStateError( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." @@ -2200,12 +2246,14 @@ class StateProxy(wrapt.ObjectProxy): Raises: ImmutableStateError: If the state is not in mutable mode. """ - if not self._self_mutable: + if not self._is_mutable(): raise ImmutableStateError( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." ) - return await self.__wrapped__.get_state(state_cls) + return type(self)( + await self.__wrapped__.get_state(state_cls), parent_state_proxy=self + ) def _as_state_update(self, *args, **kwargs) -> StateUpdate: """Temporarily allow mutability to access parent_state. diff --git a/tests/test_state.py b/tests/test_state.py index c998944ef..18d740015 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1825,7 +1825,7 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): sp = StateProxy(grandchild_state) assert sp.__wrapped__ == grandchild_state - assert sp._self_substate_path == grandchild_state.get_full_name().split(".") + assert sp._self_substate_path == tuple(grandchild_state.get_full_name().split(".")) assert sp._self_app is mock_app assert not sp._self_mutable assert sp._self_actx is None From ede5cd1f2c1ac56da9523d9a3ff6536aa0dae53d Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 25 Jul 2024 09:34:14 -0700 Subject: [PATCH 10/34] [REF-3321] implement var operation decorator (#3698) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * implement var operation decorator * use older syntax * use cast and older syntax * use something even simpler * add some tests * use old union tactics * that's not how you do things * implement arithmetic operations while we're at it * add test * even more operations * can't use __bool__ * thanos snap * forgot ruff * use default factory * dang it darglint * i know i should have done that but * convert values into literalvars * make test pass * use older union tactics * add test to string var * pright why do you hate me 🥺 --- reflex/experimental/vars/__init__.py | 26 +- reflex/experimental/vars/base.py | 644 ++----------- reflex/experimental/vars/function.py | 214 +++++ reflex/experimental/vars/number.py | 1295 ++++++++++++++++++++++++++ reflex/experimental/vars/sequence.py | 1039 +++++++++++++++++++++ reflex/vars.py | 8 +- tests/test_var.py | 67 +- 7 files changed, 2714 insertions(+), 579 deletions(-) create mode 100644 reflex/experimental/vars/function.py create mode 100644 reflex/experimental/vars/number.py create mode 100644 reflex/experimental/vars/sequence.py diff --git a/reflex/experimental/vars/__init__.py b/reflex/experimental/vars/__init__.py index c4b3e6913..945cf25fc 100644 --- a/reflex/experimental/vars/__init__.py +++ b/reflex/experimental/vars/__init__.py @@ -1,18 +1,20 @@ """Experimental Immutable-Based Var System.""" -from .base import ArrayVar as ArrayVar -from .base import BooleanVar as BooleanVar -from .base import ConcatVarOperation as ConcatVarOperation -from .base import FunctionStringVar as FunctionStringVar -from .base import FunctionVar as FunctionVar from .base import ImmutableVar as ImmutableVar -from .base import LiteralArrayVar as LiteralArrayVar -from .base import LiteralBooleanVar as LiteralBooleanVar -from .base import LiteralNumberVar as LiteralNumberVar from .base import LiteralObjectVar as LiteralObjectVar -from .base import LiteralStringVar as LiteralStringVar from .base import LiteralVar as LiteralVar -from .base import NumberVar as NumberVar from .base import ObjectVar as ObjectVar -from .base import StringVar as StringVar -from .base import VarOperationCall as VarOperationCall +from .base import var_operation as var_operation +from .function import FunctionStringVar as FunctionStringVar +from .function import FunctionVar as FunctionVar +from .function import VarOperationCall as VarOperationCall +from .number import BooleanVar as BooleanVar +from .number import LiteralBooleanVar as LiteralBooleanVar +from .number import LiteralNumberVar as LiteralNumberVar +from .number import NumberVar as NumberVar +from .sequence import ArrayJoinOperation as ArrayJoinOperation +from .sequence import ArrayVar as ArrayVar +from .sequence import ConcatVarOperation as ConcatVarOperation +from .sequence import LiteralArrayVar as LiteralArrayVar +from .sequence import LiteralStringVar as LiteralStringVar +from .sequence import StringVar as StringVar diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index af0d350f1..55b5673bd 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -3,15 +3,22 @@ from __future__ import annotations import dataclasses -import json -import re +import functools import sys -from functools import cached_property -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Optional, + Type, + TypeVar, + Union, +) + +from typing_extensions import ParamSpec from reflex import constants from reflex.base import Base -from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.utils import serializers, types from reflex.utils.exceptions import VarTypeError from reflex.vars import ( @@ -80,11 +87,12 @@ class ImmutableVar(Var): """Post-initialize the var.""" # Decode any inline Var markup and apply it to the instance _var_data, _var_name = _decode_var_immutable(self._var_name) - if _var_data: + + if _var_data or _var_name != self._var_name: self.__init__( - _var_name, - self._var_type, - ImmutableVarData.merge(self._var_data, _var_data), + _var_name=_var_name, + _var_type=self._var_type, + _var_data=ImmutableVarData.merge(self._var_data, _var_data), ) def __hash__(self) -> int: @@ -255,232 +263,13 @@ class ImmutableVar(Var): _global_vars[hashed_var] = self # Encode the _var_data into the formatted output for tracking purposes. - return f"{REFLEX_VAR_OPENING_TAG}{hashed_var}{REFLEX_VAR_CLOSING_TAG}{self._var_name}" - - -class StringVar(ImmutableVar): - """Base class for immutable string vars.""" - - -class NumberVar(ImmutableVar): - """Base class for immutable number vars.""" - - -class BooleanVar(ImmutableVar): - """Base class for immutable boolean vars.""" + return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}" class ObjectVar(ImmutableVar): """Base class for immutable object vars.""" -class ArrayVar(ImmutableVar): - """Base class for immutable array vars.""" - - -class FunctionVar(ImmutableVar): - """Base class for immutable function vars.""" - - def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: - """Call the function with the given arguments. - - Args: - *args: The arguments to call the function with. - - Returns: - The function call operation. - """ - return ArgsFunctionOperation( - ("...args",), - VarOperationCall(self, *args, ImmutableVar.create_safe("...args")), - ) - - def call(self, *args: Var | Any) -> VarOperationCall: - """Call the function with the given arguments. - - Args: - *args: The arguments to call the function with. - - Returns: - The function call operation. - """ - return VarOperationCall(self, *args) - - -class FunctionStringVar(FunctionVar): - """Base class for immutable function vars from a string.""" - - def __init__(self, func: str, _var_data: VarData | None = None) -> None: - """Initialize the function var. - - Args: - func: The function to call. - _var_data: Additional hooks and imports associated with the Var. - """ - super(FunctionVar, self).__init__( - _var_name=func, - _var_type=Callable, - _var_data=ImmutableVarData.merge(_var_data), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class VarOperationCall(ImmutableVar): - """Base class for immutable vars that are the result of a function call.""" - - _func: Optional[FunctionVar] = dataclasses.field(default=None) - _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) - - def __init__( - self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None - ): - """Initialize the function call var. - - Args: - func: The function to call. - *args: The arguments to call the function with. - _var_data: Additional hooks and imports associated with the Var. - """ - super(VarOperationCall, self).__init__( - _var_name="", - _var_type=Callable, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_func", func) - object.__setattr__(self, "_args", args) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - self._func._get_all_var_data() if self._func is not None else None, - *[var._get_all_var_data() for var in self._args], - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data - - def __post_init__(self): - """Post-initialize the var.""" - pass - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArgsFunctionOperation(FunctionVar): - """Base class for immutable function defined via arguments and return expression.""" - - _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) - _return_expr: Union[Var, Any] = dataclasses.field(default=None) - - def __init__( - self, - args_names: Tuple[str, ...], - return_expr: Var | Any, - _var_data: VarData | None = None, - ) -> None: - """Initialize the function with arguments var. - - Args: - args_names: The names of the arguments. - return_expr: The return expression of the function. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArgsFunctionOperation, self).__init__( - _var_name=f"", - _var_type=Callable, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_args_names", args_names) - object.__setattr__(self, "_return_expr", return_expr) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - self._return_expr._get_all_var_data(), - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data - - def __post_init__(self): - """Post-initialize the var.""" - - class LiteralVar(ImmutableVar): """Base class for immutable literal vars.""" @@ -515,9 +304,22 @@ class LiteralVar(ImmutableVar): value.dict(), _var_type=type(value), _var_data=_var_data ) + from .number import LiteralBooleanVar, LiteralNumberVar + from .sequence import LiteralArrayVar, LiteralStringVar + if isinstance(value, str): return LiteralStringVar.create(value, _var_data=_var_data) + type_mapping = { + int: LiteralNumberVar, + float: LiteralNumberVar, + bool: LiteralBooleanVar, + dict: LiteralObjectVar, + list: LiteralArrayVar, + tuple: LiteralArrayVar, + set: LiteralArrayVar, + } + constructor = type_mapping.get(type(value)) if constructor is None: @@ -529,256 +331,6 @@ class LiteralVar(ImmutableVar): """Post-initialize the var.""" -# Compile regex for finding reflex var tags. -_decode_var_pattern_re = ( - rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" -) -_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralStringVar(LiteralVar): - """Base class for immutable literal string vars.""" - - _var_value: str = dataclasses.field(default="") - - def __init__( - self, - _var_value: str, - _var_data: VarData | None = None, - ): - """Initialize the string var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralStringVar, self).__init__( - _var_name=f'"{_var_value}"', - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - - @classmethod - def create( - cls, - value: str, - _var_data: VarData | None = None, - ) -> LiteralStringVar | ConcatVarOperation: - """Create a var from a string value. - - Args: - value: The value to create the var from. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - if REFLEX_VAR_OPENING_TAG in value: - strings_and_vals: list[Var | str] = [] - offset = 0 - - # Initialize some methods for reading json. - var_data_config = VarData().__config__ - - def json_loads(s): - try: - return var_data_config.json_loads(s) - except json.decoder.JSONDecodeError: - return var_data_config.json_loads( - var_data_config.json_loads(f'"{s}"') - ) - - # Find all tags. - while m := _decode_var_pattern.search(value): - start, end = m.span() - if start > 0: - strings_and_vals.append(value[:start]) - - serialized_data = m.group(1) - - if serialized_data[1:].isnumeric(): - # This is a global immutable var. - var = _global_vars[int(serialized_data)] - strings_and_vals.append(var) - value = value[(end + len(var._var_name)) :] - else: - data = json_loads(serialized_data) - string_length = data.pop("string_length", None) - var_data = VarData.parse_obj(data) - - # Use string length to compute positions of interpolations. - if string_length is not None: - realstart = start + offset - var_data.interpolations = [ - (realstart, realstart + string_length) - ] - strings_and_vals.append( - ImmutableVar.create_safe( - value[end : (end + string_length)], _var_data=var_data - ) - ) - value = value[(end + string_length) :] - - offset += end - start - - if value: - strings_and_vals.append(value) - - return ConcatVarOperation(*strings_and_vals, _var_data=_var_data) - - return LiteralStringVar( - value, - _var_data=_var_data, - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ConcatVarOperation(StringVar): - """Representing a concatenation of literal string vars.""" - - _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple) - - def __init__(self, *value: Var | str, _var_data: VarData | None = None): - """Initialize the operation of concatenating literal string vars. - - Args: - value: The values to concatenate. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ConcatVarOperation, self).__init__( - _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str - ) - object.__setattr__(self, "_var_value", value) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return ( - "(" - + "+".join( - [ - str(element) if isinstance(element, Var) else f'"{element}"' - for element in self._var_value - ] - ) - + ")" - ) - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - *[ - var._get_all_var_data() - for var in self._var_value - if isinstance(var, Var) - ], - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data - - def __post_init__(self): - """Post-initialize the var.""" - pass - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralBooleanVar(LiteralVar): - """Base class for immutable literal boolean vars.""" - - _var_value: bool = dataclasses.field(default=False) - - def __init__( - self, - _var_value: bool, - _var_data: VarData | None = None, - ): - """Initialize the boolean var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralBooleanVar, self).__init__( - _var_name="true" if _var_value else "false", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralNumberVar(LiteralVar): - """Base class for immutable literal number vars.""" - - _var_value: float | int = dataclasses.field(default=0) - - def __init__( - self, - _var_value: float | int, - _var_data: VarData | None = None, - ): - """Initialize the number var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralNumberVar, self).__init__( - _var_name=str(_var_value), - _var_type=type(_var_value), - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - - @dataclasses.dataclass( eq=False, frozen=True, @@ -828,7 +380,7 @@ class LiteralObjectVar(LiteralVar): return self._cached_var_name return super(type(self), self).__getattr__(name) - @cached_property + @functools.cached_property def _cached_var_name(self) -> str: """The name of the var. @@ -846,8 +398,8 @@ class LiteralObjectVar(LiteralVar): + " }" ) - @cached_property - def _get_all_var_data(self) -> ImmutableVarData | None: + @functools.cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: """Get all VarData associated with the Var. Returns: @@ -867,89 +419,59 @@ class LiteralObjectVar(LiteralVar): self._var_data, ) - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralArrayVar(LiteralVar): - """Base class for immutable literal array vars.""" - - _var_value: Union[ - List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...] - ] = dataclasses.field(default_factory=list) - - def __init__( - self, - _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], - _var_data: VarData | None = None, - ): - """Initialize the array var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralArrayVar, self).__init__( - _var_name="", - _var_data=ImmutableVarData.merge(_var_data), - _var_type=list, - ) - object.__setattr__(self, "_var_value", _var_value) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return ( - "[" - + ", ".join( - [str(LiteralVar.create(element)) for element in self._var_value] - ) - + "]" - ) - - @cached_property def _get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. + """Wrapper method for cached property. Returns: The VarData of the components and all of its children. """ - return ImmutableVarData.merge( - *[ - var._get_all_var_data() - for var in self._var_value - if isinstance(var, Var) - ], - self._var_data, - ) + return self._cached_get_all_var_data -type_mapping = { - int: LiteralNumberVar, - float: LiteralNumberVar, - bool: LiteralBooleanVar, - dict: LiteralObjectVar, - list: LiteralArrayVar, - tuple: LiteralArrayVar, - set: LiteralArrayVar, -} +P = ParamSpec("P") +T = TypeVar("T", bound=ImmutableVar) + + +def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P, T]]: + """Decorator for creating a var operation. + + Example: + ```python + @var_operation(output=NumberVar) + def add(a: NumberVar, b: NumberVar): + return f"({a} + {b})" + ``` + + Args: + output: The output type of the operation. + + Returns: + The decorator. + """ + + def decorator(func: Callable[P, str], output=output): + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + args_vars = [ + LiteralVar.create(arg) if not isinstance(arg, Var) else arg + for arg in args + ] + kwargs_vars = { + key: LiteralVar.create(value) if not isinstance(value, Var) else value + for key, value in kwargs.items() + } + return output( + _var_name=func(*args_vars, **kwargs_vars), # type: ignore + _var_data=VarData.merge( + *[arg._get_all_var_data() for arg in args if isinstance(arg, Var)], + *[ + arg._get_all_var_data() + for arg in kwargs.values() + if isinstance(arg, Var) + ], + ), + ) + + return wrapper + + return decorator diff --git a/reflex/experimental/vars/function.py b/reflex/experimental/vars/function.py new file mode 100644 index 000000000..f1cf83886 --- /dev/null +++ b/reflex/experimental/vars/function.py @@ -0,0 +1,214 @@ +"""Immutable function vars.""" + +from __future__ import annotations + +import dataclasses +import sys +from functools import cached_property +from typing import Any, Callable, Optional, Tuple, Union + +from reflex.experimental.vars.base import ImmutableVar, LiteralVar +from reflex.vars import ImmutableVarData, Var, VarData + + +class FunctionVar(ImmutableVar): + """Base class for immutable function vars.""" + + def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return ArgsFunctionOperation( + ("...args",), + VarOperationCall(self, *args, ImmutableVar.create_safe("...args")), + ) + + def call(self, *args: Var | Any) -> VarOperationCall: + """Call the function with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The function call operation. + """ + return VarOperationCall(self, *args) + + +class FunctionStringVar(FunctionVar): + """Base class for immutable function vars from a string.""" + + def __init__(self, func: str, _var_data: VarData | None = None) -> None: + """Initialize the function var. + + Args: + func: The function to call. + _var_data: Additional hooks and imports associated with the Var. + """ + super(FunctionVar, self).__init__( + _var_name=func, + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class VarOperationCall(ImmutableVar): + """Base class for immutable vars that are the result of a function call.""" + + _func: Optional[FunctionVar] = dataclasses.field(default=None) + _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) + + def __init__( + self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None + ): + """Initialize the function call var. + + Args: + func: The function to call. + *args: The arguments to call the function with. + _var_data: Additional hooks and imports associated with the Var. + """ + super(VarOperationCall, self).__init__( + _var_name="", + _var_type=Any, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_func", func) + object.__setattr__(self, "_args", args) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._func._get_all_var_data() if self._func is not None else None, + *[var._get_all_var_data() for var in self._args], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + pass + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArgsFunctionOperation(FunctionVar): + """Base class for immutable function defined via arguments and return expression.""" + + _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + _return_expr: Union[Var, Any] = dataclasses.field(default=None) + + def __init__( + self, + args_names: Tuple[str, ...], + return_expr: Var | Any, + _var_data: VarData | None = None, + ) -> None: + """Initialize the function with arguments var. + + Args: + args_names: The names of the arguments. + return_expr: The return expression of the function. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArgsFunctionOperation, self).__init__( + _var_name=f"", + _var_type=Callable, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_args_names", args_names) + object.__setattr__(self, "_return_expr", return_expr) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._return_expr._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" diff --git a/reflex/experimental/vars/number.py b/reflex/experimental/vars/number.py new file mode 100644 index 000000000..6b74bc336 --- /dev/null +++ b/reflex/experimental/vars/number.py @@ -0,0 +1,1295 @@ +"""Immutable number vars.""" + +from __future__ import annotations + +import dataclasses +import sys +from functools import cached_property +from typing import Any, Union + +from reflex.experimental.vars.base import ( + ImmutableVar, + LiteralVar, +) +from reflex.vars import ImmutableVarData, Var, VarData + + +class NumberVar(ImmutableVar): + """Base class for immutable number vars.""" + + def __add__(self, other: number_types | boolean_types) -> NumberAddOperation: + """Add two numbers. + + Args: + other: The other number. + + Returns: + The number addition operation. + """ + return NumberAddOperation(self, +other) + + def __radd__(self, other: number_types | boolean_types) -> NumberAddOperation: + """Add two numbers. + + Args: + other: The other number. + + Returns: + The number addition operation. + """ + return NumberAddOperation(+other, self) + + def __sub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: + """Subtract two numbers. + + Args: + other: The other number. + + Returns: + The number subtraction operation. + """ + return NumberSubtractOperation(self, +other) + + def __rsub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: + """Subtract two numbers. + + Args: + other: The other number. + + Returns: + The number subtraction operation. + """ + return NumberSubtractOperation(+other, self) + + def __abs__(self) -> NumberAbsoluteOperation: + """Get the absolute value of the number. + + Returns: + The number absolute operation. + """ + return NumberAbsoluteOperation(self) + + def __mul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: + """Multiply two numbers. + + Args: + other: The other number. + + Returns: + The number multiplication operation. + """ + return NumberMultiplyOperation(self, +other) + + def __rmul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: + """Multiply two numbers. + + Args: + other: The other number. + + Returns: + The number multiplication operation. + """ + return NumberMultiplyOperation(+other, self) + + def __truediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: + """Divide two numbers. + + Args: + other: The other number. + + Returns: + The number true division operation. + """ + return NumberTrueDivision(self, +other) + + def __rtruediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: + """Divide two numbers. + + Args: + other: The other number. + + Returns: + The number true division operation. + """ + return NumberTrueDivision(+other, self) + + def __floordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: + """Floor divide two numbers. + + Args: + other: The other number. + + Returns: + The number floor division operation. + """ + return NumberFloorDivision(self, +other) + + def __rfloordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: + """Floor divide two numbers. + + Args: + other: The other number. + + Returns: + The number floor division operation. + """ + return NumberFloorDivision(+other, self) + + def __mod__(self, other: number_types | boolean_types) -> NumberModuloOperation: + """Modulo two numbers. + + Args: + other: The other number. + + Returns: + The number modulo operation. + """ + return NumberModuloOperation(self, +other) + + def __rmod__(self, other: number_types | boolean_types) -> NumberModuloOperation: + """Modulo two numbers. + + Args: + other: The other number. + + Returns: + The number modulo operation. + """ + return NumberModuloOperation(+other, self) + + def __pow__(self, other: number_types | boolean_types) -> NumberExponentOperation: + """Exponentiate two numbers. + + Args: + other: The other number. + + Returns: + The number exponent operation. + """ + return NumberExponentOperation(self, +other) + + def __rpow__(self, other: number_types | boolean_types) -> NumberExponentOperation: + """Exponentiate two numbers. + + Args: + other: The other number. + + Returns: + The number exponent operation. + """ + return NumberExponentOperation(+other, self) + + def __neg__(self) -> NumberNegateOperation: + """Negate the number. + + Returns: + The number negation operation. + """ + return NumberNegateOperation(self) + + def __and__(self, other: number_types | boolean_types) -> BooleanAndOperation: + """Boolean AND two numbers. + + Args: + other: The other number. + + Returns: + The boolean AND operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanAndOperation(self.bool(), boolified_other) + + def __rand__(self, other: number_types | boolean_types) -> BooleanAndOperation: + """Boolean AND two numbers. + + Args: + other: The other number. + + Returns: + The boolean AND operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanAndOperation(boolified_other, self.bool()) + + def __or__(self, other: number_types | boolean_types) -> BooleanOrOperation: + """Boolean OR two numbers. + + Args: + other: The other number. + + Returns: + The boolean OR operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanOrOperation(self.bool(), boolified_other) + + def __ror__(self, other: number_types | boolean_types) -> BooleanOrOperation: + """Boolean OR two numbers. + + Args: + other: The other number. + + Returns: + The boolean OR operation. + """ + boolified_other = other.bool() if isinstance(other, Var) else bool(other) + return BooleanOrOperation(boolified_other, self.bool()) + + def __invert__(self) -> BooleanNotOperation: + """Boolean NOT the number. + + Returns: + The boolean NOT operation. + """ + return BooleanNotOperation(self.bool()) + + def __pos__(self) -> NumberVar: + """Positive the number. + + Returns: + The number. + """ + return self + + def __round__(self) -> NumberRoundOperation: + """Round the number. + + Returns: + The number round operation. + """ + return NumberRoundOperation(self) + + def __ceil__(self) -> NumberCeilOperation: + """Ceil the number. + + Returns: + The number ceil operation. + """ + return NumberCeilOperation(self) + + def __floor__(self) -> NumberFloorOperation: + """Floor the number. + + Returns: + The number floor operation. + """ + return NumberFloorOperation(self) + + def __trunc__(self) -> NumberTruncOperation: + """Trunc the number. + + Returns: + The number trunc operation. + """ + return NumberTruncOperation(self) + + def __lt__(self, other: number_types | boolean_types) -> LessThanOperation: + """Less than comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return LessThanOperation(self, +other) + + def __le__(self, other: number_types | boolean_types) -> LessThanOrEqualOperation: + """Less than or equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return LessThanOrEqualOperation(self, +other) + + def __eq__(self, other: number_types | boolean_types) -> EqualOperation: + """Equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return EqualOperation(self, +other) + + def __ne__(self, other: number_types | boolean_types) -> NotEqualOperation: + """Not equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return NotEqualOperation(self, +other) + + def __gt__(self, other: number_types | boolean_types) -> GreaterThanOperation: + """Greater than comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return GreaterThanOperation(self, +other) + + def __ge__( + self, other: number_types | boolean_types + ) -> GreaterThanOrEqualOperation: + """Greater than or equal comparison. + + Args: + other: The other number. + + Returns: + The result of the comparison. + """ + return GreaterThanOrEqualOperation(self, +other) + + def bool(self) -> NotEqualOperation: + """Boolean conversion. + + Returns: + The boolean value of the number. + """ + return NotEqualOperation(self, 0) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class BinaryNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of a binary operation.""" + + a: number_types = dataclasses.field(default=0) + b: number_types = dataclasses.field(default=0) + + def __init__( + self, + a: number_types, + b: number_types, + _var_data: VarData | None = None, + ): + """Initialize the binary number operation var. + + Args: + a: The first number. + b: The second number. + _var_data: Additional hooks and imports associated with the Var. + """ + super(BinaryNumberOperation, self).__init__( + _var_name="", + _var_type=float, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError( + "BinaryNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(BinaryNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return ImmutableVarData.merge( + first_value._get_all_var_data(), + second_value._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class UnaryNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of a unary operation.""" + + a: number_types = dataclasses.field(default=0) + + def __init__( + self, + a: number_types, + _var_data: VarData | None = None, + ): + """Initialize the unary number operation var. + + Args: + a: The number. + _var_data: Additional hooks and imports associated with the Var. + """ + super(UnaryNumberOperation, self).__init__( + _var_name="", + _var_type=float, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "UnaryNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(UnaryNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return ImmutableVarData.merge(value._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class NumberAddOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of an addition operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} + {str(second_value)})" + + +class NumberSubtractOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a subtraction operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} - {str(second_value)})" + + +class NumberAbsoluteOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of an absolute operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.abs({str(value)})" + + +class NumberMultiplyOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a multiplication operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} * {str(second_value)})" + + +class NumberNegateOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a negation operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"-({str(value)})" + + +class NumberTrueDivision(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a true division operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} / {str(second_value)})" + + +class NumberFloorDivision(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a floor division operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"Math.floor({str(first_value)} / {str(second_value)})" + + +class NumberModuloOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of a modulo operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} % {str(second_value)})" + + +class NumberExponentOperation(BinaryNumberOperation): + """Base class for immutable number vars that are the result of an exponent operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + return f"({str(first_value)} ** {str(second_value)})" + + +class NumberRoundOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a round operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.round({str(value)})" + + +class NumberCeilOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a ceil operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.ceil({str(value)})" + + +class NumberFloorOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a floor operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.floor({str(value)})" + + +class NumberTruncOperation(UnaryNumberOperation): + """Base class for immutable number vars that are the result of a trunc operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + return f"Math.trunc({str(value)})" + + +class BooleanVar(ImmutableVar): + """Base class for immutable boolean vars.""" + + def __and__(self, other: bool) -> BooleanAndOperation: + """AND two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean AND operation. + """ + return BooleanAndOperation(self, other) + + def __rand__(self, other: bool) -> BooleanAndOperation: + """AND two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean AND operation. + """ + return BooleanAndOperation(other, self) + + def __or__(self, other: bool) -> BooleanOrOperation: + """OR two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean OR operation. + """ + return BooleanOrOperation(self, other) + + def __ror__(self, other: bool) -> BooleanOrOperation: + """OR two booleans. + + Args: + other: The other boolean. + + Returns: + The boolean OR operation. + """ + return BooleanOrOperation(other, self) + + def __invert__(self) -> BooleanNotOperation: + """NOT the boolean. + + Returns: + The boolean NOT operation. + """ + return BooleanNotOperation(self) + + def __int__(self) -> BooleanToIntOperation: + """Convert the boolean to an int. + + Returns: + The boolean to int operation. + """ + return BooleanToIntOperation(self) + + def __pos__(self) -> BooleanToIntOperation: + """Convert the boolean to an int. + + Returns: + The boolean to int operation. + """ + return BooleanToIntOperation(self) + + def bool(self) -> BooleanVar: + """Boolean conversion. + + Returns: + The boolean value of the boolean. + """ + return self + + def __lt__(self, other: boolean_types | number_types) -> LessThanOperation: + """Less than comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return LessThanOperation(+self, +other) + + def __le__(self, other: boolean_types | number_types) -> LessThanOrEqualOperation: + """Less than or equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return LessThanOrEqualOperation(+self, +other) + + def __eq__(self, other: boolean_types | number_types) -> EqualOperation: + """Equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return EqualOperation(+self, +other) + + def __ne__(self, other: boolean_types | number_types) -> NotEqualOperation: + """Not equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return NotEqualOperation(+self, +other) + + def __gt__(self, other: boolean_types | number_types) -> GreaterThanOperation: + """Greater than comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return GreaterThanOperation(+self, +other) + + def __ge__( + self, other: boolean_types | number_types + ) -> GreaterThanOrEqualOperation: + """Greater than or equal comparison. + + Args: + other: The other boolean. + + Returns: + The result of the comparison. + """ + return GreaterThanOrEqualOperation(+self, +other) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class BooleanToIntOperation(NumberVar): + """Base class for immutable number vars that are the result of a boolean to int operation.""" + + a: boolean_types = dataclasses.field(default=False) + + def __init__( + self, + a: boolean_types, + _var_data: VarData | None = None, + ): + """Initialize the boolean to int operation var. + + Args: + a: The boolean. + _var_data: Additional hooks and imports associated with the Var. + """ + super(BooleanToIntOperation, self).__init__( + _var_name="", + _var_type=int, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"({str(self.a)} ? 1 : 0)" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(BooleanToIntOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class NumberComparisonOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a comparison operation.""" + + a: number_types = dataclasses.field(default=0) + b: number_types = dataclasses.field(default=0) + + def __init__( + self, + a: number_types, + b: number_types, + _var_data: VarData | None = None, + ): + """Initialize the comparison operation var. + + Args: + a: The first value. + b: The second value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(NumberComparisonOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError("ComparisonOperation must implement _cached_var_name") + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(NumberComparisonOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return ImmutableVarData.merge( + first_value._get_all_var_data(), second_value._get_all_var_data() + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class GreaterThanOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a greater than operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} > {str(second_value)})" + + +class GreaterThanOrEqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a greater than or equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} >= {str(second_value)})" + + +class LessThanOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a less than operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} < {str(second_value)})" + + +class LessThanOrEqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a less than or equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} <= {str(second_value)})" + + +class EqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of an equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} == {str(second_value)})" + + +class NotEqualOperation(NumberComparisonOperation): + """Base class for immutable boolean vars that are the result of a not equal operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} != {str(second_value)})" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LogicalOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a logical operation.""" + + a: boolean_types = dataclasses.field(default=False) + b: boolean_types = dataclasses.field(default=False) + + def __init__( + self, a: boolean_types, b: boolean_types, _var_data: VarData | None = None + ): + """Initialize the logical operation var. + + Args: + a: The first value. + b: The second value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LogicalOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError("LogicalOperation must implement _cached_var_name") + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(LogicalOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return ImmutableVarData.merge( + first_value._get_all_var_data(), second_value._get_all_var_data() + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class BooleanAndOperation(LogicalOperation): + """Base class for immutable boolean vars that are the result of a logical AND operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} && {str(second_value)})" + + +class BooleanOrOperation(LogicalOperation): + """Base class for immutable boolean vars that are the result of a logical OR operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) + return f"({str(first_value)} || {str(second_value)})" + + +class BooleanNotOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a logical NOT operation.""" + + a: boolean_types = dataclasses.field() + + def __init__(self, a: boolean_types, _var_data: VarData | None = None): + """Initialize the logical NOT operation var. + + Args: + a: The value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(BooleanNotOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + return f"!({str(value)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(BooleanNotOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) + return ImmutableVarData.merge(value._get_all_var_data()) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralBooleanVar(LiteralVar, BooleanVar): + """Base class for immutable literal boolean vars.""" + + _var_value: bool = dataclasses.field(default=False) + + def __init__( + self, + _var_value: bool, + _var_data: VarData | None = None, + ): + """Initialize the boolean var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralBooleanVar, self).__init__( + _var_name="true" if _var_value else "false", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralNumberVar(LiteralVar, NumberVar): + """Base class for immutable literal number vars.""" + + _var_value: float | int = dataclasses.field(default=0) + + def __init__( + self, + _var_value: float | int, + _var_data: VarData | None = None, + ): + """Initialize the number var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralNumberVar, self).__init__( + _var_name=str(_var_value), + _var_type=type(_var_value), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + def __hash__(self) -> int: + """Hash the var. + + Returns: + The hash of the var. + """ + return hash(self._var_value) + + +number_types = Union[NumberVar, LiteralNumberVar, int, float] +boolean_types = Union[BooleanVar, LiteralBooleanVar, bool] diff --git a/reflex/experimental/vars/sequence.py b/reflex/experimental/vars/sequence.py new file mode 100644 index 000000000..c0e8bb9d7 --- /dev/null +++ b/reflex/experimental/vars/sequence.py @@ -0,0 +1,1039 @@ +"""Collection of string classes and utilities.""" + +from __future__ import annotations + +import dataclasses +import functools +import json +import re +import sys +from functools import cached_property +from typing import Any, List, Set, Tuple, Union + +from reflex import constants +from reflex.constants.base import REFLEX_VAR_OPENING_TAG +from reflex.experimental.vars.base import ( + ImmutableVar, + LiteralVar, +) +from reflex.experimental.vars.number import BooleanVar, NotEqualOperation, NumberVar +from reflex.vars import ImmutableVarData, Var, VarData, _global_vars + + +class StringVar(ImmutableVar): + """Base class for immutable string vars.""" + + def __add__(self, other: StringVar | str) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(self, other) + + def __radd__(self, other: StringVar | str) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(other, self) + + def __mul__(self, other: int) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(*[self for _ in range(other)]) + + def __rmul__(self, other: int) -> ConcatVarOperation: + """Concatenate two strings. + + Args: + other: The other string. + + Returns: + The string concatenation operation. + """ + return ConcatVarOperation(*[self for _ in range(other)]) + + def __getitem__(self, i: slice | int) -> StringSliceOperation | StringItemOperation: + """Get a slice of the string. + + Args: + i: The slice. + + Returns: + The string slice operation. + """ + if isinstance(i, slice): + return StringSliceOperation(self, i) + return StringItemOperation(self, i) + + def length(self) -> StringLengthOperation: + """Get the length of the string. + + Returns: + The string length operation. + """ + return StringLengthOperation(self) + + def lower(self) -> StringLowerOperation: + """Convert the string to lowercase. + + Returns: + The string lower operation. + """ + return StringLowerOperation(self) + + def upper(self) -> StringUpperOperation: + """Convert the string to uppercase. + + Returns: + The string upper operation. + """ + return StringUpperOperation(self) + + def strip(self) -> StringStripOperation: + """Strip the string. + + Returns: + The string strip operation. + """ + return StringStripOperation(self) + + def bool(self) -> NotEqualOperation: + """Boolean conversion. + + Returns: + The boolean value of the string. + """ + return NotEqualOperation(self.length(), 0) + + def reversed(self) -> StringReverseOperation: + """Reverse the string. + + Returns: + The string reverse operation. + """ + return StringReverseOperation(self) + + def contains(self, other: StringVar | str) -> StringContainsOperation: + """Check if the string contains another string. + + Args: + other: The other string. + + Returns: + The string contains operation. + """ + return StringContainsOperation(self, other) + + def split(self, separator: StringVar | str = "") -> StringSplitOperation: + """Split the string. + + Args: + separator: The separator. + + Returns: + The string split operation. + """ + return StringSplitOperation(self, separator) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringToNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of a string to number operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__(self, a: StringVar | str, _var_data: VarData | None = None): + """Initialize the string to number operation var. + + Args: + a: The string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringToNumberOperation, self).__init__( + _var_name="", + _var_type=float, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "StringToNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringToNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class StringLengthOperation(StringToNumberOperation): + """Base class for immutable number vars that are the result of a string length operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.length" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringToStringOperation(StringVar): + """Base class for immutable string vars that are the result of a string to string operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__(self, a: StringVar | str, _var_data: VarData | None = None): + """Initialize the string to string operation var. + + Args: + a: The string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringToStringOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "StringToStringOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringToStringOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class StringLowerOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string lower operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.toLowerCase()" + + +class StringUpperOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string upper operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.toUpperCase()" + + +class StringStripOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string strip operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.trim()" + + +class StringReverseOperation(StringToStringOperation): + """Base class for immutable string vars that are the result of a string reverse operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.split('').reverse().join('')" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringContainsOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a string contains operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + b: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__( + self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None + ): + """Initialize the string contains operation var. + + Args: + a: The first string. + b: The second string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringContainsOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__( + self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.includes({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringContainsOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringSliceOperation(StringVar): + """Base class for immutable string vars that are the result of a string slice operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) + + def __init__( + self, a: StringVar | str, _slice: slice, _var_data: VarData | None = None + ): + """Initialize the string slice operation var. + + Args: + a: The string. + _slice: The slice. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringSliceOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__(self, "_slice", _slice) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + + Raises: + ValueError: If the slice step is zero. + """ + start, end, step = self._slice.start, self._slice.stop, self._slice.step + + if step is not None and step < 0: + actual_start = end + 1 if end is not None else 0 + actual_end = start + 1 if start is not None else self.a.length() + return str( + StringSliceOperation( + StringReverseOperation( + StringSliceOperation(self.a, slice(actual_start, actual_end)) + ), + slice(None, None, -step), + ) + ) + + start = ( + LiteralVar.create(start) + if start is not None + else ImmutableVar.create_safe("undefined") + ) + end = ( + LiteralVar.create(end) + if end is not None + else ImmutableVar.create_safe("undefined") + ) + + if step is None: + return f"{str(self.a)}.slice({str(start)}, {str(end)})" + if step == 0: + raise ValueError("slice step cannot be zero") + return f"{str(self.a)}.slice({str(start)}, {str(end)}).split('').filter((_, i) => i % {str(step)} === 0).join('')" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringSliceOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), + self.start._get_all_var_data(), + self.end._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringItemOperation(StringVar): + """Base class for immutable string vars that are the result of a string item operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + i: int = dataclasses.field(default=0) + + def __init__(self, a: StringVar | str, i: int, _var_data: VarData | None = None): + """Initialize the string item operation var. + + Args: + a: The string. + i: The index. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringItemOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__(self, "i", i) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.at({str(self.i)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringItemOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class ArrayJoinOperation(StringVar): + """Base class for immutable string vars that are the result of an array join operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + b: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__( + self, a: ArrayVar | list, b: StringVar | str, _var_data: VarData | None = None + ): + """Initialize the array join operation var. + + Args: + a: The array. + b: The separator. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayJoinOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralArrayVar.create(a) + ) + object.__setattr__( + self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.join({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayJoinOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +# Compile regex for finding reflex var tags. +_decode_var_pattern_re = ( + rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" +) +_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralStringVar(LiteralVar, StringVar): + """Base class for immutable literal string vars.""" + + _var_value: str = dataclasses.field(default="") + + def __init__( + self, + _var_value: str, + _var_data: VarData | None = None, + ): + """Initialize the string var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralStringVar, self).__init__( + _var_name=f'"{_var_value}"', + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_var_value", _var_value) + + @classmethod + def create( + cls, + value: str, + _var_data: VarData | None = None, + ) -> LiteralStringVar | ConcatVarOperation: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + if REFLEX_VAR_OPENING_TAG in value: + strings_and_vals: list[Var | str] = [] + offset = 0 + + # Initialize some methods for reading json. + var_data_config = VarData().__config__ + + def json_loads(s): + try: + return var_data_config.json_loads(s) + except json.decoder.JSONDecodeError: + return var_data_config.json_loads( + var_data_config.json_loads(f'"{s}"') + ) + + # Find all tags + while m := _decode_var_pattern.search(value): + start, end = m.span() + if start > 0: + strings_and_vals.append(value[:start]) + + serialized_data = m.group(1) + + if serialized_data.isnumeric() or ( + serialized_data[0] == "-" and serialized_data[1:].isnumeric() + ): + # This is a global immutable var. + var = _global_vars[int(serialized_data)] + strings_and_vals.append(var) + value = value[(end + len(var._var_name)) :] + else: + data = json_loads(serialized_data) + string_length = data.pop("string_length", None) + var_data = VarData.parse_obj(data) + + # Use string length to compute positions of interpolations. + if string_length is not None: + realstart = start + offset + var_data.interpolations = [ + (realstart, realstart + string_length) + ] + strings_and_vals.append( + ImmutableVar.create_safe( + value[end : (end + string_length)], _var_data=var_data + ) + ) + value = value[(end + string_length) :] + + offset += end - start + + if value: + strings_and_vals.append(value) + + return ConcatVarOperation(*strings_and_vals, _var_data=_var_data) + + return LiteralStringVar( + value, + _var_data=_var_data, + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ConcatVarOperation(StringVar): + """Representing a concatenation of literal string vars.""" + + _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple) + + def __init__(self, *value: Var | str, _var_data: VarData | None = None): + """Initialize the operation of concatenating literal string vars. + + Args: + value: The values to concatenate. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ConcatVarOperation, self).__init__( + _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str + ) + object.__setattr__(self, "_var_value", value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "(" + + "+".join( + [ + str(element) if isinstance(element, Var) else f'"{element}"' + for element in self._var_value + ] + ) + + ")" + ) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + pass + + +class ArrayVar(ImmutableVar): + """Base class for immutable array vars.""" + + from reflex.experimental.vars.sequence import StringVar + + def join(self, sep: StringVar | str = "") -> ArrayJoinOperation: + """Join the elements of the array. + + Args: + sep: The separator between elements. + + Returns: + The joined elements. + """ + from reflex.experimental.vars.sequence import ArrayJoinOperation + + return ArrayJoinOperation(self, sep) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralArrayVar(LiteralVar, ArrayVar): + """Base class for immutable literal array vars.""" + + _var_value: Union[ + List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...] + ] = dataclasses.field(default_factory=list) + + def __init__( + self, + _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], + _var_data: VarData | None = None, + ): + """Initialize the array var. + + Args: + _var_value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralArrayVar, self).__init__( + _var_name="", + _var_data=ImmutableVarData.merge(_var_data), + _var_type=list, + ) + object.__setattr__(self, "_var_value", _var_value) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @functools.cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "[" + + ", ".join( + [str(LiteralVar.create(element)) for element in self._var_value] + ) + + "]" + ) + + @functools.cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + var._get_all_var_data() + for var in self._var_value + if isinstance(var, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringSplitOperation(ArrayVar): + """Base class for immutable array vars that are the result of a string split operation.""" + + a: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + b: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__( + self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None + ): + """Initialize the string split operation var. + + Args: + a: The string. + b: The separator. + _var_data: Additional hooks and imports associated with the Var. + """ + super(StringSplitOperation, self).__init__( + _var_name="", + _var_type=list, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) + ) + object.__setattr__( + self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.split({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringSplitOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data diff --git a/reflex/vars.py b/reflex/vars.py index c6ad4eed5..f857cf03e 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -379,7 +379,9 @@ def _decode_var_immutable(value: str) -> tuple[ImmutableVarData | None, str]: serialized_data = m.group(1) - if serialized_data[1:].isnumeric(): + if serialized_data.isnumeric() or ( + serialized_data[0] == "-" and serialized_data[1:].isnumeric() + ): # This is a global immutable var. var = _global_vars[int(serialized_data)] var_data = var._var_data @@ -473,7 +475,9 @@ def _decode_var(value: str) -> tuple[VarData | None, str]: serialized_data = m.group(1) - if serialized_data[1:].isnumeric(): + if serialized_data.isnumeric() or ( + serialized_data[0] == "-" and serialized_data[1:].isnumeric() + ): # This is a global immutable var. var = _global_vars[int(serialized_data)] var_data = var._var_data diff --git a/tests/test_var.py b/tests/test_var.py index 47d4f223b..761375464 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -1,4 +1,5 @@ import json +import math import typing from typing import Dict, List, Set, Tuple, Union @@ -8,13 +9,17 @@ from pandas import DataFrame from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.experimental.vars.base import ( - ArgsFunctionOperation, - ConcatVarOperation, - FunctionStringVar, ImmutableVar, - LiteralStringVar, LiteralVar, + var_operation, ) +from reflex.experimental.vars.function import ArgsFunctionOperation, FunctionStringVar +from reflex.experimental.vars.number import ( + LiteralBooleanVar, + LiteralNumberVar, + NumberVar, +) +from reflex.experimental.vars.sequence import ConcatVarOperation, LiteralStringVar from reflex.state import BaseState from reflex.utils.imports import ImportVar from reflex.vars import ( @@ -913,6 +918,60 @@ def test_function_var(): ) +def test_var_operation(): + @var_operation(output=NumberVar) + def add(a: Union[NumberVar, int], b: Union[NumberVar, int]) -> str: + return f"({a} + {b})" + + assert str(add(1, 2)) == "(1 + 2)" + assert str(add(a=4, b=-9)) == "(4 + -9)" + + five = LiteralNumberVar(5) + seven = add(2, five) + + assert isinstance(seven, NumberVar) + + +def test_string_operations(): + basic_string = LiteralStringVar.create("Hello, World!") + + assert str(basic_string.length()) == '"Hello, World!".length' + assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()' + assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()' + assert str(basic_string.strip()) == '"Hello, World!".trim()' + assert str(basic_string.contains("World")) == '"Hello, World!".includes("World")' + assert ( + str(basic_string.split(" ").join(",")) == '"Hello, World!".split(" ").join(",")' + ) + + +def test_all_number_operations(): + starting_number = LiteralNumberVar(-5.4) + + complicated_number = (((-(starting_number + 1)) * 2 / 3) // 2 % 3) ** 2 + + assert ( + str(complicated_number) + == "((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)" + ) + + even_more_complicated_number = ~( + abs(math.floor(complicated_number)) | 2 & 3 & round(complicated_number) + ) + + assert ( + str(even_more_complicated_number) + == "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) != 0) || (true && (Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)) != 0))))" + ) + + assert str(LiteralNumberVar(5) > False) == "(5 > 0)" + assert str(LiteralBooleanVar(False) < 5) == "((false ? 1 : 0) < 5)" + assert ( + str(LiteralBooleanVar(False) < LiteralBooleanVar(True)) + == "((false ? 1 : 0) < (true ? 1 : 0))" + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None From d389f4b5cab8bff5060d757ff4e12785c36b735f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Thu, 25 Jul 2024 19:50:18 +0200 Subject: [PATCH 11/34] fix var warning (#3704) --- reflex/components/recharts/cartesian.py | 2 +- reflex/components/recharts/general.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/reflex/components/recharts/cartesian.py b/reflex/components/recharts/cartesian.py index e3f086a3d..710fef19b 100644 --- a/reflex/components/recharts/cartesian.py +++ b/reflex/components/recharts/cartesian.py @@ -299,7 +299,7 @@ class Area(Cartesian): fill: Var[Union[str, Color]] = Var.create_safe(Color("accent", 5)) # The interpolation type of area. And customized interpolation function can be set to type. 'basis' | 'basisClosed' | 'basisOpen' | 'bumpX' | 'bumpY' | 'bump' | 'linear' | 'linearClosed' | 'natural' | 'monotoneX' | 'monotoneY' | 'monotone' | 'step' | 'stepBefore' | 'stepAfter' | - type_: Var[LiteralAreaType] = Var.create_safe("monotone") + type_: Var[LiteralAreaType] = Var.create_safe("monotone", _var_is_string=True) # If false set, dots will not be drawn. If true set, dots will be drawn which have the props calculated internally. dot: Var[Union[bool, Dict[str, Any]]] diff --git a/reflex/components/recharts/general.py b/reflex/components/recharts/general.py index 581907785..613e6fbf0 100644 --- a/reflex/components/recharts/general.py +++ b/reflex/components/recharts/general.py @@ -234,7 +234,7 @@ class LabelList(Recharts): fill: Var[Union[str, Color]] = Var.create_safe(Color("gray", 10)) # The stroke color of each label - stroke: Var[Union[str, Color]] = Var.create_safe("none") + stroke: Var[Union[str, Color]] = Var.create_safe("none", _var_is_string=True) responsive_container = ResponsiveContainer.create From c4346c262438165f109ba556f51f15ac8acbab1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Thu, 25 Jul 2024 19:51:44 +0200 Subject: [PATCH 12/34] update init prompt to use new templates from reflex-dev/templates (#3677) --- reflex/utils/prerequisites.py | 104 ++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 36 deletions(-) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 40a2338d3..765560129 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -1311,39 +1311,63 @@ def migrate_to_reflex(): print(line, end="") -def fetch_app_templates() -> dict[str, Template]: - """Fetch the list of app templates from the Reflex backend server. +RELEASES_URL = f"https://api.github.com/repos/reflex-dev/templates/releases" + + +def fetch_app_templates(version: str) -> dict[str, Template]: + """Fetch a dict of templates from the templates repo using github API. + + Args: + version: The version of the templates to fetch. Returns: - The name and download URL as a dictionary. + The dict of templates. """ - config = get_config() - if not config.cp_backend_url: - console.info( - "Skip fetching App templates. No backend URL is specified in the config." - ) - return {} - try: - response = httpx.get( - f"{config.cp_backend_url}{constants.Templates.APP_TEMPLATES_ROUTE}" - ) + + def get_release_by_tag(tag: str) -> dict | None: + response = httpx.get(RELEASES_URL) response.raise_for_status() - return { - template["name"]: Template.parse_obj(template) - for template in response.json() - } - except httpx.HTTPError as ex: - console.info(f"Failed to fetch app templates: {ex}") - return {} - except (TypeError, KeyError, json.JSONDecodeError) as tkje: - console.info(f"Unable to process server response for app templates: {tkje}") + releases = response.json() + for release in releases: + if release["tag_name"] == f"v{tag}": + return release + return None + + release = get_release_by_tag(version) + if release is None: + console.warn(f"No templates known for version {version}") return {} + assets = release.get("assets", []) + asset = next((a for a in assets if a["name"] == "templates.json"), None) + if asset is None: + console.warn(f"Templates metadata not found for version {version}") + return {} + else: + templates_url = asset["browser_download_url"] -def create_config_init_app_from_remote_template( - app_name: str, - template_url: str, -): + templates_data = httpx.get(templates_url, follow_redirects=True).json()["templates"] + + for template in templates_data: + if template["name"] == "blank": + template["code_url"] = "" + continue + template["code_url"] = next( + ( + a["browser_download_url"] + for a in assets + if a["name"] == f"{template['name']}.zip" + ), + None, + ) + return { + tp["name"]: Template.parse_obj(tp) + for tp in templates_data + if not tp["hidden"] and tp["code_url"] is not None + } + + +def create_config_init_app_from_remote_template(app_name: str, template_url: str): """Create new rxconfig and initialize app using a remote template. Args: @@ -1437,15 +1461,20 @@ def initialize_app(app_name: str, template: str | None = None): telemetry.send("reinit") return - # Get the available templates - templates: dict[str, Template] = fetch_app_templates() + templates: dict[str, Template] = {} - # Prompt for a template if not provided. - if template is None and len(templates) > 0: - template = prompt_for_template(list(templates.values())) - elif template is None: - template = constants.Templates.DEFAULT - assert template is not None + # Don't fetch app templates if the user directly asked for DEFAULT. + if template is None or (template != constants.Templates.DEFAULT): + try: + # Get the available templates + templates = fetch_app_templates(constants.Reflex.VERSION) + if template is None and len(templates) > 0: + template = prompt_for_template(list(templates.values())) + except Exception as e: + console.warn("Failed to fetch templates. Falling back to default template.") + console.debug(f"Error while fetching templates: {e}") + finally: + template = template or constants.Templates.DEFAULT # If the blank template is selected, create a blank app. if template == constants.Templates.DEFAULT: @@ -1468,9 +1497,12 @@ def initialize_app(app_name: str, template: str | None = None): else: console.error(f"Template `{template}` not found.") raise typer.Exit(1) + + if template_url is None: + return + create_config_init_app_from_remote_template( - app_name=app_name, - template_url=template_url, + app_name=app_name, template_url=template_url ) telemetry.send("init", template=template) From a4e3f0560104fadf75f76ee4bcef0a720f5022b8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 26 Jul 2024 17:10:08 -0700 Subject: [PATCH 13/34] [REF-3375] useMemo on generateUUID props to maintain consistent value (#3708) When using rx.vars.get_uuid_string_var, wrap the prop Var in `useMemo` so that the value remains consistent across re-renders of the component. Fix #3707 --- reflex/vars.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/reflex/vars.py b/reflex/vars.py index f857cf03e..911915534 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -2552,17 +2552,25 @@ class CallableVar(BaseVar): def get_uuid_string_var() -> Var: - """Return a var that generates UUIDs via .web/utils/state.js. + """Return a Var that generates a single memoized UUID via .web/utils/state.js. + + useMemo with an empty dependency array ensures that the generated UUID is + consistent across re-renders of the component. Returns: - the var to generate UUIDs at runtime. + A Var that generates a UUID at runtime. """ from reflex.utils.imports import ImportVar unique_uuid_var_data = VarData( - imports={f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="generateUUID")}} # type: ignore + imports={ + f"/{constants.Dirs.STATE_PATH}": {ImportVar(tag="generateUUID")}, # type: ignore + "react": "useMemo", + } ) return BaseVar( - _var_name="generateUUID()", _var_type=str, _var_data=unique_uuid_var_data + _var_name="useMemo(generateUUID, [])", + _var_type=str, + _var_data=unique_uuid_var_data, ) From 2e726f1bb94a2a03d4aec362a487141c1032dd7e Mon Sep 17 00:00:00 2001 From: paoloemilioserra Date: Mon, 29 Jul 2024 02:51:08 +0200 Subject: [PATCH 14/34] Update vars.py (#3659) Prevent a validation error from pydantic/v1 that cannot find _var_name, etc. in __dataclass_fields__ --- reflex/vars.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/reflex/vars.py b/reflex/vars.py index 911915534..00f02804c 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -2168,6 +2168,24 @@ class ComputedVar(Var, property): # Interval at which the computed var should be updated _update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None) + # The name of the var. + _var_name: str = dataclasses.field() + + # The type of the var. + _var_type: Type = dataclasses.field(default=Any) + + # Whether this is a local javascript variable. + _var_is_local: bool = dataclasses.field(default=False) + + # Whether the var is a string literal. + _var_is_string: bool = dataclasses.field(default=False) + + # _var_full_name should be prefixed with _var_state + _var_full_name_needs_state_prefix: bool = dataclasses.field(default=False) + + # Extra metadata associated with the Var + _var_data: Optional[VarData] = dataclasses.field(default=None) + def __init__( self, fget: Callable[[BaseState], Any], From 4cdd87d8511063644014807ad9cabef2f6f9ad2c Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 29 Jul 2024 09:52:45 -0700 Subject: [PATCH 15/34] Bump to v0.5.8 (#3716) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4900510cb..0feb287c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "reflex" -version = "0.5.7" +version = "0.5.8" description = "Web apps in pure Python." license = "Apache-2.0" authors = [ From 800685da6815d5aaf2fdf7d266a926d461023037 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Mon, 29 Jul 2024 16:57:33 -0700 Subject: [PATCH 16/34] fix silly bug when style is set directly to breakpoints (#3719) * fix silly bug when style is set directly to breakpoints * add helpful comment Co-authored-by: Masen Furer --------- Co-authored-by: Masen Furer --- reflex/components/component.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/reflex/components/component.py b/reflex/components/component.py index 1786dd060..bb3e9053f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -24,6 +24,7 @@ from typing import ( import reflex.state from reflex.base import Base from reflex.compiler.templates import STATEFUL_COMPONENT +from reflex.components.core.breakpoints import Breakpoints from reflex.components.tags import Tag from reflex.constants import ( Dirs, @@ -466,6 +467,12 @@ class Component(BaseComponent, ABC): # Merge styles, the later ones overriding keys in the earlier ones. style = {k: v for style_dict in style for k, v in style_dict.items()} + if isinstance(style, Breakpoints): + style = { + # Assign the Breakpoints to the self-referential selector to avoid squashing down to a regular dict. + "&": style, + } + kwargs["style"] = Style( { **self.get_fields()["style"].default, From 06833f6d8d5d9c55a82dd53e30ccb9780b691565 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 29 Jul 2024 17:17:51 -0700 Subject: [PATCH 17/34] [REF-3203] Find a DOM event-like object in addEvents (#3706) --- reflex/.templates/web/utils/state.js | 7 ++++++- reflex/utils/format.py | 9 +++++---- tests/components/base/test_script.py | 6 +++--- tests/components/test_component.py | 2 +- tests/utils/test_format.py | 10 +++++----- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index f67ce6858..81ac40100 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -647,7 +647,12 @@ export const useEventLoop = ( const [connectErrors, setConnectErrors] = useState([]); // Function to add new events to the event queue. - const addEvents = (events, _e, event_actions) => { + const addEvents = (events, args, event_actions) => { + if (!(args instanceof Array)) { + args = [args]; + } + const _e = args.filter((o) => o?.preventDefault !== undefined)[0] + if (event_actions?.preventDefault && _e?.preventDefault) { _e.preventDefault(); } diff --git a/reflex/utils/format.py b/reflex/utils/format.py index e163ebaac..59bbbd91c 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -438,15 +438,16 @@ def format_prop( sig = inspect.signature(prop.args_spec) # type: ignore if sig.parameters: arg_def = ",".join(f"_{p}" for p in sig.parameters) - arg_def = f"({arg_def})" + arg_def_expr = f"[{arg_def}]" else: # add a default argument for addEvents if none were specified in prop.args_spec # used to trigger the preventDefault() on the event. - arg_def = "(_e)" + arg_def = "...args" + arg_def_expr = "args" chain = ",".join([format_event(event) for event in prop.events]) - event = f"addEvents([{chain}], {arg_def}, {json_dumps(prop.event_actions)})" - prop = f"{arg_def} => {event}" + event = f"addEvents([{chain}], {arg_def_expr}, {json_dumps(prop.event_actions)})" + prop = f"({arg_def}) => {event}" # Handle other types. elif isinstance(prop, str): diff --git a/tests/components/base/test_script.py b/tests/components/base/test_script.py index e06914258..372178d5f 100644 --- a/tests/components/base/test_script.py +++ b/tests/components/base/test_script.py @@ -58,14 +58,14 @@ def test_script_event_handler(): ) render_dict = component.render() assert ( - f'onReady={{(_e) => addEvents([Event("{EvState.get_full_name()}.on_ready", {{}})], (_e), {{}})}}' + f'onReady={{(...args) => addEvents([Event("{EvState.get_full_name()}.on_ready", {{}})], args, {{}})}}' in render_dict["props"] ) assert ( - f'onLoad={{(_e) => addEvents([Event("{EvState.get_full_name()}.on_load", {{}})], (_e), {{}})}}' + f'onLoad={{(...args) => addEvents([Event("{EvState.get_full_name()}.on_load", {{}})], args, {{}})}}' in render_dict["props"] ) assert ( - f'onError={{(_e) => addEvents([Event("{EvState.get_full_name()}.on_error", {{}})], (_e), {{}})}}' + f'onError={{(...args) => addEvents([Event("{EvState.get_full_name()}.on_error", {{}})], args, {{}})}}' in render_dict["props"] ) diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 64354ada9..78c42f177 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -826,7 +826,7 @@ def test_component_event_trigger_arbitrary_args(): assert comp.render()["props"][0] == ( "onFoo={(__e,_alpha,_bravo,_charlie) => addEvents(" f'[Event("{C1State.get_full_name()}.mock_handler", {{_e:__e.target.value,_bravo:_bravo["nested"],_charlie:((_charlie.custom) + (42))}})], ' - "(__e,_alpha,_bravo,_charlie), {})}" + "[__e,_alpha,_bravo,_charlie], {})}" ) diff --git a/tests/utils/test_format.py b/tests/utils/test_format.py index 7037a3798..95ebc047b 100644 --- a/tests/utils/test_format.py +++ b/tests/utils/test_format.py @@ -477,7 +477,7 @@ def test_format_match( events=[EventSpec(handler=EventHandler(fn=mock_event))], args_spec=lambda: [], ), - '{(_e) => addEvents([Event("mock_event", {})], (_e), {})}', + '{(...args) => addEvents([Event("mock_event", {})], args, {})}', ), ( EventChain( @@ -495,9 +495,9 @@ def test_format_match( ), ) ], - args_spec=lambda: [], + args_spec=lambda e: [e.target.value], ), - '{(_e) => addEvents([Event("mock_event", {arg:_e.target.value})], (_e), {})}', + '{(_e) => addEvents([Event("mock_event", {arg:_e.target.value})], [_e], {})}', ), ( EventChain( @@ -505,7 +505,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"stopPropagation": True}, ), - '{(_e) => addEvents([Event("mock_event", {})], (_e), {"stopPropagation": true})}', + '{(...args) => addEvents([Event("mock_event", {})], args, {"stopPropagation": true})}', ), ( EventChain( @@ -513,7 +513,7 @@ def test_format_match( args_spec=lambda: [], event_actions={"preventDefault": True}, ), - '{(_e) => addEvents([Event("mock_event", {})], (_e), {"preventDefault": true})}', + '{(...args) => addEvents([Event("mock_event", {})], args, {"preventDefault": true})}', ), ({"a": "red", "b": "blue"}, '{{"a": "red", "b": "blue"}}'), (BaseVar(_var_name="var", _var_type="int"), "{var}"), From 1c400043c646ba723fbfb87564fcb459d36fd75f Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 30 Jul 2024 15:38:32 -0700 Subject: [PATCH 18/34] [REF-3328] Implement __getitem__ for ArrayVar (#3705) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * half of the way there * implement __getitem__ for array * add some tests * add fixes to pyright * fix default factory * implement array operations * format code * fix pyright issue * give up * add object operations * add test for merge * pyright 🥺 * use str isntead of _var_name Co-authored-by: Masen Furer * wrong var_type * make to much nicer * add subclass checking * enhance types * use builtin list type * improve typing even more * i'm awaiting october * use even better typing * add hash, json, and guess type method * fix pyright issues * add a test and fix lots of errors * fix pyright once again * add type inference to list --------- Co-authored-by: Masen Furer --- reflex/experimental/vars/__init__.py | 4 +- reflex/experimental/vars/base.py | 254 +++--- reflex/experimental/vars/function.py | 78 +- reflex/experimental/vars/number.py | 165 +++- reflex/experimental/vars/object.py | 627 +++++++++++++++ reflex/experimental/vars/sequence.py | 1061 ++++++++++++++++++++------ reflex/vars.py | 8 + reflex/vars.pyi | 1 + tests/test_var.py | 97 ++- 9 files changed, 1971 insertions(+), 324 deletions(-) create mode 100644 reflex/experimental/vars/object.py diff --git a/reflex/experimental/vars/__init__.py b/reflex/experimental/vars/__init__.py index 945cf25fc..8fa5196ff 100644 --- a/reflex/experimental/vars/__init__.py +++ b/reflex/experimental/vars/__init__.py @@ -1,9 +1,7 @@ """Experimental Immutable-Based Var System.""" from .base import ImmutableVar as ImmutableVar -from .base import LiteralObjectVar as LiteralObjectVar from .base import LiteralVar as LiteralVar -from .base import ObjectVar as ObjectVar from .base import var_operation as var_operation from .function import FunctionStringVar as FunctionStringVar from .function import FunctionVar as FunctionVar @@ -12,6 +10,8 @@ from .number import BooleanVar as BooleanVar from .number import LiteralBooleanVar as LiteralBooleanVar from .number import LiteralNumberVar as LiteralNumberVar from .number import NumberVar as NumberVar +from .object import LiteralObjectVar as LiteralObjectVar +from .object import ObjectVar as ObjectVar from .sequence import ArrayJoinOperation as ArrayJoinOperation from .sequence import ArrayVar as ArrayVar from .sequence import ConcatVarOperation as ConcatVarOperation diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index 55b5673bd..dadcc38bd 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -4,18 +4,19 @@ from __future__ import annotations import dataclasses import functools +import inspect import sys from typing import ( + TYPE_CHECKING, Any, Callable, - Dict, Optional, Type, TypeVar, - Union, + overload, ) -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, get_origin from reflex import constants from reflex.base import Base @@ -30,6 +31,17 @@ from reflex.vars import ( _global_vars, ) +if TYPE_CHECKING: + from .function import FunctionVar, ToFunctionOperation + from .number import ( + BooleanVar, + NumberVar, + ToBooleanVarOperation, + ToNumberVarOperation, + ) + from .object import ObjectVar, ToObjectOperation + from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation + @dataclasses.dataclass( eq=False, @@ -43,7 +55,7 @@ class ImmutableVar(Var): _var_name: str = dataclasses.field() # The type of the var. - _var_type: Type = dataclasses.field(default=Any) + _var_type: types.GenericType = dataclasses.field(default=Any) # Extra metadata associated with the Var _var_data: Optional[ImmutableVarData] = dataclasses.field(default=None) @@ -265,9 +277,138 @@ class ImmutableVar(Var): # Encode the _var_data into the formatted output for tracking purposes. return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}" + @overload + def to( + self, output: Type[NumberVar], var_type: type[int] | type[float] = float + ) -> ToNumberVarOperation: ... -class ObjectVar(ImmutableVar): - """Base class for immutable object vars.""" + @overload + def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ... + + @overload + def to( + self, + output: Type[ArrayVar], + var_type: type[list] | type[tuple] | type[set] = list, + ) -> ToArrayOperation: ... + + @overload + def to(self, output: Type[StringVar]) -> ToStringOperation: ... + + @overload + def to( + self, output: Type[ObjectVar], var_type: types.GenericType = dict + ) -> ToObjectOperation: ... + + @overload + def to( + self, output: Type[FunctionVar], var_type: Type[Callable] = Callable + ) -> ToFunctionOperation: ... + + @overload + def to( + self, output: Type[OUTPUT], var_type: types.GenericType | None = None + ) -> OUTPUT: ... + + def to( + self, output: Type[OUTPUT], var_type: types.GenericType | None = None + ) -> Var: + """Convert the var to a different type. + + Args: + output: The output type. + var_type: The type of the var. + + Raises: + TypeError: If the var_type is not a supported type for the output. + + Returns: + The converted var. + """ + from .number import ( + BooleanVar, + NumberVar, + ToBooleanVarOperation, + ToNumberVarOperation, + ) + + fixed_type = ( + var_type + if var_type is None or inspect.isclass(var_type) + else get_origin(var_type) + ) + + if issubclass(output, NumberVar): + if fixed_type is not None and not issubclass(fixed_type, (int, float)): + raise TypeError( + f"Unsupported type {var_type} for NumberVar. Must be int or float." + ) + return ToNumberVarOperation(self, var_type or float) + if issubclass(output, BooleanVar): + return ToBooleanVarOperation(self) + + from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation + + if issubclass(output, ArrayVar): + if fixed_type is not None and not issubclass( + fixed_type, (list, tuple, set) + ): + raise TypeError( + f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set." + ) + return ToArrayOperation(self, var_type or list) + if issubclass(output, StringVar): + return ToStringOperation(self) + + from .object import ObjectVar, ToObjectOperation + + if issubclass(output, ObjectVar): + return ToObjectOperation(self, var_type or dict) + + from .function import FunctionVar, ToFunctionOperation + + if issubclass(output, FunctionVar): + if fixed_type is not None and not issubclass(fixed_type, Callable): + raise TypeError( + f"Unsupported type {var_type} for FunctionVar. Must be Callable." + ) + return ToFunctionOperation(self, var_type or Callable) + + return output( + _var_name=self._var_name, + _var_type=self._var_type if var_type is None else var_type, + _var_data=self._var_data, + ) + + def guess_type(self) -> ImmutableVar: + """Guess the type of the var. + + Returns: + The guessed type. + """ + from .number import NumberVar + from .object import ObjectVar + from .sequence import ArrayVar, StringVar + + if self._var_type is Any: + return self + + var_type = self._var_type + + fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type) + + if issubclass(fixed_type, (int, float)): + return self.to(NumberVar, var_type) + if issubclass(fixed_type, dict): + return self.to(ObjectVar, var_type) + if issubclass(fixed_type, (list, tuple, set)): + return self.to(ArrayVar, var_type) + if issubclass(fixed_type, str): + return self.to(StringVar) + return self + + +OUTPUT = TypeVar("OUTPUT", bound=ImmutableVar) class LiteralVar(ImmutableVar): @@ -299,6 +440,8 @@ class LiteralVar(ImmutableVar): if value is None: return ImmutableVar.create_safe("null", _var_data=_var_data) + from .object import LiteralObjectVar + if isinstance(value, Base): return LiteralObjectVar( value.dict(), _var_type=type(value), _var_data=_var_data @@ -330,102 +473,15 @@ class LiteralVar(ImmutableVar): def __post_init__(self): """Post-initialize the var.""" + def json(self) -> str: + """Serialize the var to a JSON string. -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralObjectVar(LiteralVar): - """Base class for immutable literal object vars.""" - - _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( - default_factory=dict - ) - - def __init__( - self, - _var_value: dict[Var | Any, Var | Any], - _var_type: Type = dict, - _var_data: VarData | None = None, - ): - """Initialize the object var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. + Raises: + NotImplementedError: If the method is not implemented. """ - super(LiteralObjectVar, self).__init__( - _var_name="", - _var_type=_var_type, - _var_data=ImmutableVarData.merge(_var_data), + raise NotImplementedError( + "LiteralVar subclasses must implement the json method." ) - object.__setattr__( - self, - "_var_value", - _var_value, - ) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @functools.cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return ( - "{ " - + ", ".join( - [ - f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}" - for key, value in self._var_value.items() - ] - ) - + " }" - ) - - @functools.cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - *[ - value._get_all_var_data() - for key, value in self._var_value - if isinstance(value, Var) - ], - *[ - key._get_all_var_data() - for key, value in self._var_value - if isinstance(key, Var) - ], - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data P = ParamSpec("P") diff --git a/reflex/experimental/vars/function.py b/reflex/experimental/vars/function.py index f1cf83886..adce1329d 100644 --- a/reflex/experimental/vars/function.py +++ b/reflex/experimental/vars/function.py @@ -5,7 +5,7 @@ from __future__ import annotations import dataclasses import sys from functools import cached_property -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Type, Union from reflex.experimental.vars.base import ImmutableVar, LiteralVar from reflex.vars import ImmutableVarData, Var, VarData @@ -212,3 +212,79 @@ class ArgsFunctionOperation(FunctionVar): def __post_init__(self): """Post-initialize the var.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToFunctionOperation(FunctionVar): + """Base class of converting a var to a function.""" + + _original_var: Var = dataclasses.field( + default_factory=lambda: LiteralVar.create(None) + ) + + def __init__( + self, + original_var: Var, + _var_type: Type[Callable] = Callable, + _var_data: VarData | None = None, + ) -> None: + """Initialize the function with arguments var. + + Args: + original_var: The original var to convert to a function. + _var_type: The type of the function. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToFunctionOperation, self).__init__( + _var_name=f"", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_var", original_var) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self._original_var) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_var._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data diff --git a/reflex/experimental/vars/number.py b/reflex/experimental/vars/number.py index 6b74bc336..c83c5c4d2 100644 --- a/reflex/experimental/vars/number.py +++ b/reflex/experimental/vars/number.py @@ -3,6 +3,7 @@ from __future__ import annotations import dataclasses +import json import sys from functools import cached_property from typing import Any, Union @@ -1253,6 +1254,22 @@ class LiteralBooleanVar(LiteralVar, BooleanVar): ) object.__setattr__(self, "_var_value", _var_value) + def __hash__(self) -> int: + """Hash the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_value)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return "true" if self._var_value else "false" + @dataclasses.dataclass( eq=False, @@ -1288,8 +1305,154 @@ class LiteralNumberVar(LiteralVar, NumberVar): Returns: The hash of the var. """ - return hash(self._var_value) + return hash((self.__class__.__name__, self._var_value)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return json.dumps(self._var_value) number_types = Union[NumberVar, LiteralNumberVar, int, float] boolean_types = Union[BooleanVar, LiteralBooleanVar, bool] + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToNumberVarOperation(NumberVar): + """Base class for immutable number vars that are the result of a number operation.""" + + _original_value: Var = dataclasses.field( + default_factory=lambda: LiteralNumberVar(0) + ) + + def __init__( + self, + _original_value: Var, + _var_type: type[int] | type[float] = float, + _var_data: VarData | None = None, + ): + """Initialize the number var. + + Args: + _original_value: The original value. + _var_type: The type of the Var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToNumberVarOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_value", _original_value) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self._original_value) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToNumberVarOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_value._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToBooleanVarOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a boolean operation.""" + + _original_value: Var = dataclasses.field( + default_factory=lambda: LiteralBooleanVar(False) + ) + + def __init__( + self, + _original_value: Var, + _var_data: VarData | None = None, + ): + """Initialize the boolean var. + + Args: + _original_value: The original value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToBooleanVarOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_value", _original_value) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self._original_value) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToBooleanVarOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_value._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data diff --git a/reflex/experimental/vars/object.py b/reflex/experimental/vars/object.py new file mode 100644 index 000000000..4522473c7 --- /dev/null +++ b/reflex/experimental/vars/object.py @@ -0,0 +1,627 @@ +"""Classes for immutable object vars.""" + +from __future__ import annotations + +import dataclasses +import sys +import typing +from functools import cached_property +from typing import Any, Dict, List, Tuple, Type, Union + +from reflex.experimental.vars.base import ImmutableVar, LiteralVar +from reflex.experimental.vars.sequence import ArrayVar, unionize +from reflex.vars import ImmutableVarData, Var, VarData + + +class ObjectVar(ImmutableVar): + """Base class for immutable object vars.""" + + def _key_type(self) -> Type: + """Get the type of the keys of the object. + + Returns: + The type of the keys of the object. + """ + return ImmutableVar + + def _value_type(self) -> Type: + """Get the type of the values of the object. + + Returns: + The type of the values of the object. + """ + return ImmutableVar + + def keys(self) -> ObjectKeysOperation: + """Get the keys of the object. + + Returns: + The keys of the object. + """ + return ObjectKeysOperation(self) + + def values(self) -> ObjectValuesOperation: + """Get the values of the object. + + Returns: + The values of the object. + """ + return ObjectValuesOperation(self) + + def entries(self) -> ObjectEntriesOperation: + """Get the entries of the object. + + Returns: + The entries of the object. + """ + return ObjectEntriesOperation(self) + + def merge(self, other: ObjectVar) -> ObjectMergeOperation: + """Merge two objects. + + Args: + other: The other object to merge. + + Returns: + The merged object. + """ + return ObjectMergeOperation(self, other) + + def __getitem__(self, key: Var | Any) -> ImmutableVar: + """Get an item from the object. + + Args: + key: The key to get from the object. + + Returns: + The item from the object. + """ + return ObjectItemOperation(self, key).guess_type() + + def __getattr__(self, name) -> ObjectItemOperation: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + return ObjectItemOperation(self, name) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralObjectVar(LiteralVar, ObjectVar): + """Base class for immutable literal object vars.""" + + _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + default_factory=dict + ) + + def __init__( + self, + _var_value: dict[Var | Any, Var | Any], + _var_type: Type | None = None, + _var_data: VarData | None = None, + ): + """Initialize the object var. + + Args: + _var_value: The value of the var. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralObjectVar, self).__init__( + _var_name="", + _var_type=( + Dict[ + unionize(*map(type, _var_value.keys())), + unionize(*map(type, _var_value.values())), + ] + if _var_type is None + else _var_type + ), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "_var_value", + _var_value, + ) + object.__delattr__(self, "_var_name") + + def _key_type(self) -> Type: + """Get the type of the keys of the object. + + Returns: + The type of the keys of the object. + """ + args_list = typing.get_args(self._var_type) + return args_list[0] if args_list else Any + + def _value_type(self) -> Type: + """Get the type of the values of the object. + + Returns: + The type of the values of the object. + """ + args_list = typing.get_args(self._var_type) + return args_list[1] if args_list else Any + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "({ " + + ", ".join( + [ + f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}" + for key, value in self._var_value.items() + ] + ) + + " })" + ) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + value._get_all_var_data() + for key, value in self._var_value + if isinstance(value, Var) + ], + *[ + key._get_all_var_data() + for key, value in self._var_value + if isinstance(key, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def json(self) -> str: + """Get the JSON representation of the object. + + Returns: + The JSON representation of the object. + """ + return ( + "{" + + ", ".join( + [ + f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}" + for key, value in self._var_value.items() + ] + ) + + "}" + ) + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_name)) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ObjectToArrayOperation(ArrayVar): + """Base class for object to array operations.""" + + value: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + + def __init__( + self, + _var_value: ObjectVar, + _var_type: Type = list, + _var_data: VarData | None = None, + ): + """Initialize the object to array operation. + + Args: + _var_value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectToArrayOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "value", _var_value) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Raises: + NotImplementedError: Must implement _cached_var_name. + """ + raise NotImplementedError( + "ObjectToArrayOperation must implement _cached_var_name" + ) + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.value._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +class ObjectKeysOperation(ObjectToArrayOperation): + """Operation to get the keys of an object.""" + + def __init__( + self, + value: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object keys operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectKeysOperation, self).__init__( + value, List[value._key_type()], _var_data + ) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.keys({self.value._var_name})" + + +class ObjectValuesOperation(ObjectToArrayOperation): + """Operation to get the values of an object.""" + + def __init__( + self, + value: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object values operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectValuesOperation, self).__init__( + value, List[value._value_type()], _var_data + ) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.values({self.value._var_name})" + + +class ObjectEntriesOperation(ObjectToArrayOperation): + """Operation to get the entries of an object.""" + + def __init__( + self, + value: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object entries operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectEntriesOperation, self).__init__( + value, List[Tuple[value._key_type(), value._value_type()]], _var_data + ) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.entries({self.value._var_name})" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ObjectMergeOperation(ObjectVar): + """Operation to merge two objects.""" + + left: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + right: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + + def __init__( + self, + left: ObjectVar, + right: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object merge operation. + + Args: + left: The left object to merge. + right: The right object to merge. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectMergeOperation, self).__init__( + _var_name="", + _var_type=left._var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "left", left) + object.__setattr__(self, "right", right) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.assign({self.left._var_name}, {self.right._var_name})" + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.left._get_all_var_data(), + self.right._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ObjectItemOperation(ImmutableVar): + """Operation to get an item from an object.""" + + value: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + + def __init__( + self, + value: ObjectVar, + key: Var | Any, + _var_data: VarData | None = None, + ): + """Initialize the object item operation. + + Args: + value: The value of the operation. + key: The key to get from the object. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectItemOperation, self).__init__( + _var_name="", + _var_type=value._value_type(), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "value", value) + object.__setattr__( + self, "key", key if isinstance(key, Var) else LiteralVar.create(key) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"{str(self.value)}[{str(self.key)}]" + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.value._get_all_var_data(), + self.key._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToObjectOperation(ObjectVar): + """Operation to convert a var to an object.""" + + _original_var: Var = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + + def __init__( + self, + _original_var: Var, + _var_type: Type = dict, + _var_data: VarData | None = None, + ): + """Initialize the to object operation. + + Args: + _original_var: The original var to convert. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ToObjectOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_var", _original_var) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return str(self._original_var) + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_var._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data diff --git a/reflex/experimental/vars/sequence.py b/reflex/experimental/vars/sequence.py index c0e8bb9d7..8db1300ec 100644 --- a/reflex/experimental/vars/sequence.py +++ b/reflex/experimental/vars/sequence.py @@ -4,11 +4,15 @@ from __future__ import annotations import dataclasses import functools +import inspect import json import re import sys +import typing from functools import cached_property -from typing import Any, List, Set, Tuple, Union +from typing import Any, List, Set, Tuple, Type, Union, overload + +from typing_extensions import get_origin from reflex import constants from reflex.constants.base import REFLEX_VAR_OPENING_TAG @@ -16,7 +20,13 @@ from reflex.experimental.vars.base import ( ImmutableVar, LiteralVar, ) -from reflex.experimental.vars.number import BooleanVar, NotEqualOperation, NumberVar +from reflex.experimental.vars.number import ( + BooleanVar, + LiteralNumberVar, + NotEqualOperation, + NumberVar, +) +from reflex.utils.types import GenericType from reflex.vars import ImmutableVarData, Var, VarData, _global_vars @@ -67,7 +77,15 @@ class StringVar(ImmutableVar): """ return ConcatVarOperation(*[self for _ in range(other)]) - def __getitem__(self, i: slice | int) -> StringSliceOperation | StringItemOperation: + @overload + def __getitem__(self, i: slice) -> ArrayJoinOperation: ... + + @overload + def __getitem__(self, i: int | NumberVar) -> StringItemOperation: ... + + def __getitem__( + self, i: slice | int | NumberVar + ) -> ArrayJoinOperation | StringItemOperation: """Get a slice of the string. Args: @@ -77,16 +95,16 @@ class StringVar(ImmutableVar): The string slice operation. """ if isinstance(i, slice): - return StringSliceOperation(self, i) + return self.split()[i].join() return StringItemOperation(self, i) - def length(self) -> StringLengthOperation: + def length(self) -> NumberVar: """Get the length of the string. Returns: The string length operation. """ - return StringLengthOperation(self) + return self.split().length() def lower(self) -> StringLowerOperation: """Convert the string to lowercase. @@ -120,13 +138,13 @@ class StringVar(ImmutableVar): """ return NotEqualOperation(self.length(), 0) - def reversed(self) -> StringReverseOperation: + def reversed(self) -> ArrayJoinOperation: """Reverse the string. Returns: The string reverse operation. """ - return StringReverseOperation(self) + return self.split().reverse().join() def contains(self, other: StringVar | str) -> StringContainsOperation: """Check if the string contains another string. @@ -151,85 +169,6 @@ class StringVar(ImmutableVar): return StringSplitOperation(self, separator) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringToNumberOperation(NumberVar): - """Base class for immutable number vars that are the result of a string to number operation.""" - - a: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - - def __init__(self, a: StringVar | str, _var_data: VarData | None = None): - """Initialize the string to number operation var. - - Args: - a: The string. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringToNumberOperation, self).__init__( - _var_name="", - _var_type=float, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) - object.__delattr__(self, "_var_name") - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError( - "StringToNumberOperation must implement _cached_var_name" - ) - - def __getattr__(self, name: str) -> Any: - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute value. - """ - if name == "_var_name": - return self._cached_var_name - getattr(super(StringToNumberOperation, self), name) - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) - - def _get_all_var_data(self) -> ImmutableVarData | None: - return self._cached_get_all_var_data - - -class StringLengthOperation(StringToNumberOperation): - """Base class for immutable number vars that are the result of a string length operation.""" - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self.a)}.length" - - @dataclasses.dataclass( eq=False, frozen=True, @@ -338,19 +277,6 @@ class StringStripOperation(StringToStringOperation): return f"{str(self.a)}.trim()" -class StringReverseOperation(StringToStringOperation): - """Base class for immutable string vars that are the result of a string reverse operation.""" - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self.a)}.split('').reverse().join('')" - - @dataclasses.dataclass( eq=False, frozen=True, @@ -426,112 +352,6 @@ class StringContainsOperation(BooleanVar): return self._cached_get_all_var_data -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringSliceOperation(StringVar): - """Base class for immutable string vars that are the result of a string slice operation.""" - - a: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) - - def __init__( - self, a: StringVar | str, _slice: slice, _var_data: VarData | None = None - ): - """Initialize the string slice operation var. - - Args: - a: The string. - _slice: The slice. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringSliceOperation, self).__init__( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) - object.__setattr__(self, "_slice", _slice) - object.__delattr__(self, "_var_name") - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - - Raises: - ValueError: If the slice step is zero. - """ - start, end, step = self._slice.start, self._slice.stop, self._slice.step - - if step is not None and step < 0: - actual_start = end + 1 if end is not None else 0 - actual_end = start + 1 if start is not None else self.a.length() - return str( - StringSliceOperation( - StringReverseOperation( - StringSliceOperation(self.a, slice(actual_start, actual_end)) - ), - slice(None, None, -step), - ) - ) - - start = ( - LiteralVar.create(start) - if start is not None - else ImmutableVar.create_safe("undefined") - ) - end = ( - LiteralVar.create(end) - if end is not None - else ImmutableVar.create_safe("undefined") - ) - - if step is None: - return f"{str(self.a)}.slice({str(start)}, {str(end)})" - if step == 0: - raise ValueError("slice step cannot be zero") - return f"{str(self.a)}.slice({str(start)}, {str(end)}).split('').filter((_, i) => i % {str(step)} === 0).join('')" - - def __getattr__(self, name: str) -> Any: - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute value. - """ - if name == "_var_name": - return self._cached_var_name - getattr(super(StringSliceOperation, self), name) - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - self.a._get_all_var_data(), - self.start._get_all_var_data(), - self.end._get_all_var_data(), - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - return self._cached_get_all_var_data - - @dataclasses.dataclass( eq=False, frozen=True, @@ -543,9 +363,11 @@ class StringItemOperation(StringVar): a: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - i: int = dataclasses.field(default=0) + i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) - def __init__(self, a: StringVar | str, i: int, _var_data: VarData | None = None): + def __init__( + self, a: StringVar | str, i: int | NumberVar, _var_data: VarData | None = None + ): """Initialize the string item operation var. Args: @@ -561,7 +383,7 @@ class StringItemOperation(StringVar): object.__setattr__( self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) ) - object.__setattr__(self, "i", i) + object.__setattr__(self, "i", i if isinstance(i, Var) else LiteralNumberVar(i)) object.__delattr__(self, "_var_name") @cached_property @@ -593,7 +415,9 @@ class StringItemOperation(StringVar): Returns: The VarData of the components and all of its children. """ - return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data + ) def _get_all_var_data(self) -> ImmutableVarData | None: return self._cached_get_all_var_data @@ -608,7 +432,7 @@ class ArrayJoinOperation(StringVar): ) def __init__( - self, a: ArrayVar | list, b: StringVar | str, _var_data: VarData | None = None + self, a: ArrayVar, b: StringVar | str, _var_data: VarData | None = None ): """Initialize the array join operation var. @@ -622,9 +446,7 @@ class ArrayJoinOperation(StringVar): _var_type=str, _var_data=ImmutableVarData.merge(_var_data), ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralArrayVar.create(a) - ) + object.__setattr__(self, "a", a) object.__setattr__( self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) ) @@ -777,6 +599,22 @@ class LiteralStringVar(LiteralVar, StringVar): _var_data=_var_data, ) + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_value)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return json.dumps(self._var_value) + @dataclasses.dataclass( eq=False, @@ -879,6 +717,94 @@ class ArrayVar(ImmutableVar): return ArrayJoinOperation(self, sep) + def reverse(self) -> ArrayReverseOperation: + """Reverse the array. + + Returns: + The reversed array. + """ + return ArrayReverseOperation(self) + + @overload + def __getitem__(self, i: slice) -> ArraySliceOperation: ... + + @overload + def __getitem__(self, i: int | NumberVar) -> ImmutableVar: ... + + def __getitem__( + self, i: slice | int | NumberVar + ) -> ArraySliceOperation | ImmutableVar: + """Get a slice of the array. + + Args: + i: The slice. + + Returns: + The array slice operation. + """ + if isinstance(i, slice): + return ArraySliceOperation(self, i) + return ArrayItemOperation(self, i).guess_type() + + def length(self) -> NumberVar: + """Get the length of the array. + + Returns: + The length of the array. + """ + return ArrayLengthOperation(self) + + @overload + @classmethod + def range(cls, stop: int | NumberVar, /) -> RangeOperation: ... + + @overload + @classmethod + def range( + cls, + start: int | NumberVar, + end: int | NumberVar, + step: int | NumberVar = 1, + /, + ) -> RangeOperation: ... + + @classmethod + def range( + cls, + first_endpoint: int | NumberVar, + second_endpoint: int | NumberVar | None = None, + step: int | NumberVar | None = None, + ) -> RangeOperation: + """Create a range of numbers. + + Args: + first_endpoint: The end of the range if second_endpoint is not provided, otherwise the start of the range. + second_endpoint: The end of the range. + step: The step of the range. + + Returns: + The range of numbers. + """ + if second_endpoint is None: + start = 0 + end = first_endpoint + else: + start = first_endpoint + end = second_endpoint + + return RangeOperation(start, end, step or 1) + + def contains(self, other: Any) -> ArrayContainsOperation: + """Check if the array contains an element. + + Args: + other: The element to check for. + + Returns: + The array contains operation. + """ + return ArrayContainsOperation(self, other) + @dataclasses.dataclass( eq=False, @@ -894,19 +820,25 @@ class LiteralArrayVar(LiteralVar, ArrayVar): def __init__( self, - _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], + _var_value: list[Var | Any] | tuple[Var | Any, ...] | set[Var | Any], + _var_type: type[list] | type[tuple] | type[set] | None = None, _var_data: VarData | None = None, ): """Initialize the array var. Args: _var_value: The value of the var. + _var_type: The type of the var. _var_data: Additional hooks and imports associated with the Var. """ super(LiteralArrayVar, self).__init__( _var_name="", _var_data=ImmutableVarData.merge(_var_data), - _var_type=list, + _var_type=( + List[unionize(*map(type, _var_value))] + if _var_type is None + else _var_type + ), ) object.__setattr__(self, "_var_value", _var_value) object.__delattr__(self, "_var_name") @@ -963,6 +895,28 @@ class LiteralArrayVar(LiteralVar, ArrayVar): """ return self._cached_get_all_var_data + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_name)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return ( + "[" + + ", ".join( + [LiteralVar.create(element).json() for element in self._var_value] + ) + + "]" + ) + @dataclasses.dataclass( eq=False, @@ -991,7 +945,7 @@ class StringSplitOperation(ArrayVar): """ super(StringSplitOperation, self).__init__( _var_name="", - _var_type=list, + _var_type=List[str], _var_data=ImmutableVarData.merge(_var_data), ) object.__setattr__( @@ -1037,3 +991,676 @@ class StringSplitOperation(ArrayVar): def _get_all_var_data(self) -> ImmutableVarData | None: return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayToArrayOperation(ArrayVar): + """Base class for immutable array vars that are the result of an array to array operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + + def __init__(self, a: ArrayVar, _var_data: VarData | None = None): + """Initialize the array to array operation var. + + Args: + a: The string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayToArrayOperation, self).__init__( + _var_name="", + _var_type=a._var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "ArrayToArrayOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayToArrayOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArraySliceOperation(ArrayVar): + """Base class for immutable string vars that are the result of a string slice operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) + + def __init__(self, a: ArrayVar, _slice: slice, _var_data: VarData | None = None): + """Initialize the string slice operation var. + + Args: + a: The string. + _slice: The slice. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArraySliceOperation, self).__init__( + _var_name="", + _var_type=a._var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "_slice", _slice) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + + Raises: + ValueError: If the slice step is zero. + """ + start, end, step = self._slice.start, self._slice.stop, self._slice.step + + normalized_start = ( + LiteralVar.create(start) + if start is not None + else ImmutableVar.create_safe("undefined") + ) + normalized_end = ( + LiteralVar.create(end) + if end is not None + else ImmutableVar.create_safe("undefined") + ) + if step is None: + return ( + f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})" + ) + if not isinstance(step, Var): + if step < 0: + actual_start = end + 1 if end is not None else 0 + actual_end = start + 1 if start is not None else self.a.length() + return str( + ArraySliceOperation( + ArrayReverseOperation( + ArraySliceOperation(self.a, slice(actual_start, actual_end)) + ), + slice(None, None, -step), + ) + ) + if step == 0: + raise ValueError("slice step cannot be zero") + return f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0)" + + actual_start_reverse = end + 1 if end is not None else 0 + actual_end_reverse = start + 1 if start is not None else self.a.length() + + return f"{str(self.step)} > 0 ? {str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0) : {str(self.a)}.slice({str(actual_start_reverse)}, {str(actual_end_reverse)}).reverse().filter((_, i) => i % {str(-step)} === 0)" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArraySliceOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), + *[ + slice_value._get_all_var_data() + for slice_value in ( + self._slice.start, + self._slice.stop, + self._slice.step, + ) + if slice_value is not None and isinstance(slice_value, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class ArrayReverseOperation(ArrayToArrayOperation): + """Base class for immutable string vars that are the result of a string reverse operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.reverse()" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayToNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of an array to number operation.""" + + a: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar([]), + ) + + def __init__(self, a: ArrayVar, _var_data: VarData | None = None): + """Initialize the string to number operation var. + + Args: + a: The array. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayToNumberOperation, self).__init__( + _var_name="", + _var_type=int, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a)) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "StringToNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayToNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class ArrayLengthOperation(ArrayToNumberOperation): + """Base class for immutable number vars that are the result of an array length operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.length" + + +def unionize(*args: Type) -> Type: + """Unionize the types. + + Args: + args: The types to unionize. + + Returns: + The unionized types. + """ + if not args: + return Any + first, *rest = args + if not rest: + return first + return Union[first, unionize(*rest)] + + +def is_tuple_type(t: GenericType) -> bool: + """Check if a type is a tuple type. + + Args: + t: The type to check. + + Returns: + Whether the type is a tuple type. + """ + if inspect.isclass(t): + return issubclass(t, tuple) + return get_origin(t) is tuple + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayItemOperation(ImmutableVar): + """Base class for immutable array vars that are the result of an array item operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + + def __init__( + self, + a: ArrayVar, + i: NumberVar | int, + _var_data: VarData | None = None, + ): + """Initialize the array item operation var. + + Args: + a: The array. + i: The index. + _var_data: Additional hooks and imports associated with the Var. + """ + args = typing.get_args(a._var_type) + if args and isinstance(i, int) and is_tuple_type(a._var_type): + element_type = args[i % len(args)] + else: + element_type = unionize(*args) + super(ArrayItemOperation, self).__init__( + _var_name="", + _var_type=element_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a)) + object.__setattr__( + self, + "i", + i if isinstance(i, Var) else LiteralNumberVar(i), + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.at({str(self.i)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayItemOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class RangeOperation(ArrayVar): + """Base class for immutable array vars that are the result of a range operation.""" + + start: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + end: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + step: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(1)) + + def __init__( + self, + start: NumberVar | int, + end: NumberVar | int, + step: NumberVar | int, + _var_data: VarData | None = None, + ): + """Initialize the range operation var. + + Args: + start: The start of the range. + end: The end of the range. + step: The step of the range. + _var_data: Additional hooks and imports associated with the Var. + """ + super(RangeOperation, self).__init__( + _var_name="", + _var_type=List[int], + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "start", + start if isinstance(start, Var) else LiteralNumberVar(start), + ) + object.__setattr__( + self, + "end", + end if isinstance(end, Var) else LiteralNumberVar(end), + ) + object.__setattr__( + self, + "step", + step if isinstance(step, Var) else LiteralNumberVar(step), + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + start, end, step = self.start, self.end, self.step + return f"Array.from({{ length: ({str(end)} - {str(start)}) / {str(step)} }}, (_, i) => {str(start)} + i * {str(step)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(RangeOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.start._get_all_var_data(), + self.end._get_all_var_data(), + self.step._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayContainsOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of an array contains operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + b: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + + def __init__(self, a: ArrayVar, b: Any | Var, _var_data: VarData | None = None): + """Initialize the array contains operation var. + + Args: + a: The array. + b: The element to check for. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayContainsOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b)) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.includes({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayContainsOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToStringOperation(StringVar): + """Base class for immutable string vars that are the result of a to string operation.""" + + original_var: Var = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__(self, original_var: Var, _var_data: VarData | None = None): + """Initialize the to string operation var. + + Args: + original_var: The original var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToStringOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "original_var", + original_var, + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self.original_var) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToStringOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.original_var._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToArrayOperation(ArrayVar): + """Base class for immutable array vars that are the result of a to array operation.""" + + original_var: Var = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + + def __init__( + self, + original_var: Var, + _var_type: type[list] | type[set] | type[tuple] = list, + _var_data: VarData | None = None, + ): + """Initialize the to array operation var. + + Args: + original_var: The original var. + _var_type: The type of the array. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToArrayOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "original_var", + original_var, + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self.original_var) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToArrayOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.original_var._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data diff --git a/reflex/vars.py b/reflex/vars.py index 00f02804c..c5a66a20d 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1997,6 +1997,14 @@ class Var: """ return self._var_data + def json(self) -> str: + """Serialize the var to a JSON string. + + Raises: + NotImplementedError: If the method is not implemented. + """ + raise NotImplementedError("Var subclasses must implement the json method.") + @property def _var_name_unwrapped(self) -> str: """Get the var str without wrapping in curly braces. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 4aa6afc33..77a878086 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -151,6 +151,7 @@ class Var: def _var_full_name(self) -> str: ... def _var_set_state(self, state: Type[BaseState] | str) -> Any: ... def _get_all_var_data(self) -> VarData: ... + def json(self) -> str: ... @dataclass(eq=False) class BaseVar(Var): diff --git a/tests/test_var.py b/tests/test_var.py index 761375464..66599db72 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -19,7 +19,13 @@ from reflex.experimental.vars.number import ( LiteralNumberVar, NumberVar, ) -from reflex.experimental.vars.sequence import ConcatVarOperation, LiteralStringVar +from reflex.experimental.vars.object import LiteralObjectVar +from reflex.experimental.vars.sequence import ( + ArrayVar, + ConcatVarOperation, + LiteralArrayVar, + LiteralStringVar, +) from reflex.state import BaseState from reflex.utils.imports import ImportVar from reflex.vars import ( @@ -881,7 +887,7 @@ def test_literal_var(): ) assert ( str(complicated_var) - == '[{ ["a"] : 1, ["b"] : 2, ["c"] : { ["d"] : 3, ["e"] : 4 } }, [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]' + == '[({ ["a"] : 1, ["b"] : 2, ["c"] : ({ ["d"] : 3, ["e"] : 4 }) }), [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]' ) @@ -898,7 +904,7 @@ def test_function_var(): ) assert ( str(manual_addition_func.call(1, 2)) - == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))' + == '(((a, b) => (({ ["args"] : [a, b], ["result"] : a + b })))(1, 2))' ) increment_func = addition_func(1) @@ -935,7 +941,7 @@ def test_var_operation(): def test_string_operations(): basic_string = LiteralStringVar.create("Hello, World!") - assert str(basic_string.length()) == '"Hello, World!".length' + assert str(basic_string.length()) == '"Hello, World!".split("").length' assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()' assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()' assert str(basic_string.strip()) == '"Hello, World!".trim()' @@ -972,6 +978,89 @@ def test_all_number_operations(): ) +def test_index_operation(): + array_var = LiteralArrayVar([1, 2, 3, 4, 5]) + assert str(array_var[0]) == "[1, 2, 3, 4, 5].at(0)" + assert str(array_var[1:2]) == "[1, 2, 3, 4, 5].slice(1, 2)" + assert ( + str(array_var[1:4:2]) + == "[1, 2, 3, 4, 5].slice(1, 4).filter((_, i) => i % 2 === 0)" + ) + assert ( + str(array_var[::-1]) + == "[1, 2, 3, 4, 5].slice(0, [1, 2, 3, 4, 5].length).reverse().slice(undefined, undefined).filter((_, i) => i % 1 === 0)" + ) + assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].reverse()" + assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)" + + +def test_array_operations(): + array_var = LiteralArrayVar.create([1, 2, 3, 4, 5]) + + assert str(array_var.length()) == "[1, 2, 3, 4, 5].length" + assert str(array_var.contains(3)) == "[1, 2, 3, 4, 5].includes(3)" + assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].reverse()" + assert ( + str(ArrayVar.range(10)) + == "Array.from({ length: (10 - 0) / 1 }, (_, i) => 0 + i * 1)" + ) + assert ( + str(ArrayVar.range(1, 10)) + == "Array.from({ length: (10 - 1) / 1 }, (_, i) => 1 + i * 1)" + ) + assert ( + str(ArrayVar.range(1, 10, 2)) + == "Array.from({ length: (10 - 1) / 2 }, (_, i) => 1 + i * 2)" + ) + assert ( + str(ArrayVar.range(1, 10, -1)) + == "Array.from({ length: (10 - 1) / -1 }, (_, i) => 1 + i * -1)" + ) + + +def test_object_operations(): + object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3}) + + assert ( + str(object_var.keys()) == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + ) + assert ( + str(object_var.values()) + == 'Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + ) + assert ( + str(object_var.entries()) + == 'Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + ) + assert str(object_var.a) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]' + assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]' + assert ( + str(object_var.merge(LiteralObjectVar({"c": 4, "d": 5}))) + == 'Object.assign(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }), ({ ["c"] : 4, ["d"] : 5 }))' + ) + + +def test_type_chains(): + object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3}) + assert object_var._var_type is Dict[str, int] + assert (object_var.keys()._var_type, object_var.values()._var_type) == ( + List[str], + List[int], + ) + assert ( + str(object_var.keys()[0].upper()) # type: ignore + == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(0).toUpperCase()' + ) + assert ( + str(object_var.entries()[1][1] - 1) # type: ignore + == '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(1).at(1) - 1)' + ) + assert ( + str(object_var["c"] + object_var["b"]) # type: ignore + == '(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["c"] + ({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["b"])' + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None From 2629366b235845018b8a4ea72acf717f4b054837 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:00:52 +0200 Subject: [PATCH 19/34] fix initial_value for computed_var (#3726) * fix initial_value for computed_var * fix initial_value in pyi --- reflex/vars.py | 2 +- reflex/vars.pyi | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/reflex/vars.py b/reflex/vars.py index c5a66a20d..ffaf16455 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -2482,7 +2482,7 @@ class ComputedVar(Var, property): def computed_var( fget: Callable[[BaseState], Any] | None = None, - initial_value: Any | None = None, + initial_value: Any | types.Unset = types.Unset(), cache: bool = False, deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 77a878086..47d433374 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -190,7 +190,7 @@ class ComputedVar(Var): @overload def computed_var( fget: Callable[[BaseState], Any] | None = None, - initial_value: Any | None = None, + initial_value: Any | types.Unset = types.Unset(), cache: bool = False, deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, @@ -202,7 +202,7 @@ def computed_var(fget: Callable[[Any], Any]) -> ComputedVar: ... @overload def cached_var( fget: Callable[[BaseState], Any] | None = None, - initial_value: Any | None = None, + initial_value: Any | types.Unset = types.Unset(), deps: Optional[List[Union[str, Var]]] = None, auto_deps: bool = True, interval: Optional[Union[datetime.timedelta, int]] = None, From 129adc941a6bb2ae5a9de6c010144c17a93823ee Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:37:53 +0200 Subject: [PATCH 20/34] add test for initial state dict (#3727) --- tests/test_state.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_state.py b/tests/test_state.py index 18d740015..aa5705b09 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1012,6 +1012,21 @@ def interdependent_state() -> BaseState: return s +def test_interdependent_state_initial_dict() -> None: + s = InterdependentState() + state_name = s.get_name() + d = s.dict(initial=True)[state_name] + d.pop("router") + assert d == { + "x": 0, + "v1": 0, + "v1x2": 0, + "v2x2": 2, + "v1x2x2": 0, + "v3x2": 2, + } + + def test_not_dirty_computed_var_from_var( interdependent_state: InterdependentState, ) -> None: From ad14f383296f1748a45b2eb8fda2449bba6e5a69 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 31 Jul 2024 12:01:17 -0700 Subject: [PATCH 21/34] add type hinting to existing types (#3729) * add type hinting to existing types * dang it darglint * i cannot --- reflex/experimental/vars/base.py | 52 ++++++- reflex/experimental/vars/function.py | 2 +- reflex/experimental/vars/number.py | 4 +- reflex/experimental/vars/object.py | 225 ++++++++++++++++++++++++--- reflex/experimental/vars/sequence.py | 170 +++++++++++++++----- tests/test_var.py | 27 +++- 6 files changed, 415 insertions(+), 65 deletions(-) diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index dadcc38bd..7da1e6537 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -10,9 +10,15 @@ from typing import ( TYPE_CHECKING, Any, Callable, + Dict, + Generic, + List, Optional, + Set, + Tuple, Type, TypeVar, + Union, overload, ) @@ -42,13 +48,15 @@ if TYPE_CHECKING: from .object import ObjectVar, ToObjectOperation from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation +VAR_TYPE = TypeVar("VAR_TYPE") + @dataclasses.dataclass( eq=False, frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class ImmutableVar(Var): +class ImmutableVar(Var, Generic[VAR_TYPE]): """Base class for immutable vars.""" # The name of the var. @@ -405,6 +413,8 @@ class ImmutableVar(Var): return self.to(ArrayVar, var_type) if issubclass(fixed_type, str): return self.to(StringVar) + if issubclass(fixed_type, Base): + return self.to(ObjectVar, var_type) return self @@ -531,3 +541,43 @@ def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P return wrapper return decorator + + +def unionize(*args: Type) -> Type: + """Unionize the types. + + Args: + args: The types to unionize. + + Returns: + The unionized types. + """ + if not args: + return Any + first, *rest = args + if not rest: + return first + return Union[first, unionize(*rest)] + + +def figure_out_type(value: Any) -> Type: + """Figure out the type of the value. + + Args: + value: The value to figure out the type of. + + Returns: + The type of the value. + """ + if isinstance(value, list): + return List[unionize(*(figure_out_type(v) for v in value))] + if isinstance(value, set): + return Set[unionize(*(figure_out_type(v) for v in value))] + if isinstance(value, tuple): + return Tuple[unionize(*(figure_out_type(v) for v in value)), ...] + if isinstance(value, dict): + return Dict[ + unionize(*(figure_out_type(k) for k in value)), + unionize(*(figure_out_type(v) for v in value.values())), + ] + return type(value) diff --git a/reflex/experimental/vars/function.py b/reflex/experimental/vars/function.py index adce1329d..4514a482d 100644 --- a/reflex/experimental/vars/function.py +++ b/reflex/experimental/vars/function.py @@ -11,7 +11,7 @@ from reflex.experimental.vars.base import ImmutableVar, LiteralVar from reflex.vars import ImmutableVarData, Var, VarData -class FunctionVar(ImmutableVar): +class FunctionVar(ImmutableVar[Callable]): """Base class for immutable function vars.""" def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: diff --git a/reflex/experimental/vars/number.py b/reflex/experimental/vars/number.py index c83c5c4d2..6bd3a7ff7 100644 --- a/reflex/experimental/vars/number.py +++ b/reflex/experimental/vars/number.py @@ -15,7 +15,7 @@ from reflex.experimental.vars.base import ( from reflex.vars import ImmutableVarData, Var, VarData -class NumberVar(ImmutableVar): +class NumberVar(ImmutableVar[Union[int, float]]): """Base class for immutable number vars.""" def __add__(self, other: number_types | boolean_types) -> NumberAddOperation: @@ -693,7 +693,7 @@ class NumberTruncOperation(UnaryNumberOperation): return f"Math.trunc({str(value)})" -class BooleanVar(ImmutableVar): +class BooleanVar(ImmutableVar[bool]): """Base class for immutable boolean vars.""" def __and__(self, other: bool) -> BooleanAndOperation: diff --git a/reflex/experimental/vars/object.py b/reflex/experimental/vars/object.py index 4522473c7..a227f0d7c 100644 --- a/reflex/experimental/vars/object.py +++ b/reflex/experimental/vars/object.py @@ -6,23 +6,69 @@ import dataclasses import sys import typing from functools import cached_property -from typing import Any, Dict, List, Tuple, Type, Union +from inspect import isclass +from typing import ( + Any, + Dict, + List, + NoReturn, + Tuple, + Type, + TypeVar, + Union, + get_args, + overload, +) -from reflex.experimental.vars.base import ImmutableVar, LiteralVar -from reflex.experimental.vars.sequence import ArrayVar, unionize +from typing_extensions import get_origin + +from reflex.experimental.vars.base import ( + ImmutableVar, + LiteralVar, + figure_out_type, +) +from reflex.experimental.vars.number import NumberVar +from reflex.experimental.vars.sequence import ArrayVar, StringVar +from reflex.utils.exceptions import VarAttributeError +from reflex.utils.types import GenericType, get_attribute_access_type from reflex.vars import ImmutableVarData, Var, VarData +OBJECT_TYPE = TypeVar("OBJECT_TYPE") -class ObjectVar(ImmutableVar): +KEY_TYPE = TypeVar("KEY_TYPE") +VALUE_TYPE = TypeVar("VALUE_TYPE") + +ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE") + +OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") + + +class ObjectVar(ImmutableVar[OBJECT_TYPE]): """Base class for immutable object vars.""" + @overload + def _key_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> KEY_TYPE: ... + + @overload + def _key_type(self) -> Type: ... + def _key_type(self) -> Type: """Get the type of the keys of the object. Returns: The type of the keys of the object. """ - return ImmutableVar + fixed_type = ( + self._var_type if isclass(self._var_type) else get_origin(self._var_type) + ) + args = get_args(self._var_type) if issubclass(fixed_type, dict) else () + return args[0] if args else Any + + @overload + def _value_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> VALUE_TYPE: ... + + @overload + def _value_type(self) -> Type: ... def _value_type(self) -> Type: """Get the type of the values of the object. @@ -30,9 +76,21 @@ class ObjectVar(ImmutableVar): Returns: The type of the values of the object. """ - return ImmutableVar + fixed_type = ( + self._var_type if isclass(self._var_type) else get_origin(self._var_type) + ) + args = get_args(self._var_type) if issubclass(fixed_type, dict) else () + return args[1] if args else Any - def keys(self) -> ObjectKeysOperation: + @overload + def keys( + self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + ) -> ArrayVar[List[KEY_TYPE]]: ... + + @overload + def keys(self) -> ArrayVar: ... + + def keys(self) -> ArrayVar: """Get the keys of the object. Returns: @@ -40,7 +98,15 @@ class ObjectVar(ImmutableVar): """ return ObjectKeysOperation(self) - def values(self) -> ObjectValuesOperation: + @overload + def values( + self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + ) -> ArrayVar[List[VALUE_TYPE]]: ... + + @overload + def values(self) -> ArrayVar: ... + + def values(self) -> ArrayVar: """Get the values of the object. Returns: @@ -48,7 +114,15 @@ class ObjectVar(ImmutableVar): """ return ObjectValuesOperation(self) - def entries(self) -> ObjectEntriesOperation: + @overload + def entries( + self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + ) -> ArrayVar[List[Tuple[KEY_TYPE, VALUE_TYPE]]]: ... + + @overload + def entries(self) -> ArrayVar: ... + + def entries(self) -> ArrayVar: """Get the entries of the object. Returns: @@ -67,6 +141,53 @@ class ObjectVar(ImmutableVar): """ return ObjectMergeOperation(self, other) + # NoReturn is used here to catch when key value is Any + @overload + def __getitem__( + self: ObjectVar[Dict[KEY_TYPE, NoReturn]], + key: Var | Any, + ) -> ImmutableVar: ... + + @overload + def __getitem__( + self: ( + ObjectVar[Dict[KEY_TYPE, int]] + | ObjectVar[Dict[KEY_TYPE, float]] + | ObjectVar[Dict[KEY_TYPE, int | float]] + ), + key: Var | Any, + ) -> NumberVar: ... + + @overload + def __getitem__( + self: ObjectVar[Dict[KEY_TYPE, str]], + key: Var | Any, + ) -> StringVar: ... + + @overload + def __getitem__( + self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]], + key: Var | Any, + ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... + + @overload + def __getitem__( + self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]], + key: Var | Any, + ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... + + @overload + def __getitem__( + self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]], + key: Var | Any, + ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... + + @overload + def __getitem__( + self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + key: Var | Any, + ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + def __getitem__(self, key: Var | Any) -> ImmutableVar: """Get an item from the object. @@ -78,16 +199,78 @@ class ObjectVar(ImmutableVar): """ return ObjectItemOperation(self, key).guess_type() - def __getattr__(self, name) -> ObjectItemOperation: + # NoReturn is used here to catch when key value is Any + @overload + def __getattr__( + self: ObjectVar[Dict[KEY_TYPE, NoReturn]], + name: str, + ) -> ImmutableVar: ... + + @overload + def __getattr__( + self: ( + ObjectVar[Dict[KEY_TYPE, int]] + | ObjectVar[Dict[KEY_TYPE, float]] + | ObjectVar[Dict[KEY_TYPE, int | float]] + ), + name: str, + ) -> NumberVar: ... + + @overload + def __getattr__( + self: ObjectVar[Dict[KEY_TYPE, str]], + name: str, + ) -> StringVar: ... + + @overload + def __getattr__( + self: ObjectVar[Dict[KEY_TYPE, list[ARRAY_INNER_TYPE]]], + name: str, + ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... + + @overload + def __getattr__( + self: ObjectVar[Dict[KEY_TYPE, set[ARRAY_INNER_TYPE]]], + name: str, + ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... + + @overload + def __getattr__( + self: ObjectVar[Dict[KEY_TYPE, tuple[ARRAY_INNER_TYPE, ...]]], + name: str, + ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... + + @overload + def __getattr__( + self: ObjectVar[Dict[KEY_TYPE, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], + name: str, + ) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... + + def __getattr__(self, name) -> ImmutableVar: """Get an attribute of the var. Args: name: The name of the attribute. + Raises: + VarAttributeError: The State var has no such attribute or may have been annotated wrongly. + Returns: The attribute of the var. """ - return ObjectItemOperation(self, name) + fixed_type = ( + self._var_type if isclass(self._var_type) else get_origin(self._var_type) + ) + if not issubclass(fixed_type, dict): + attribute_type = get_attribute_access_type(self._var_type, name) + if attribute_type is None: + raise VarAttributeError( + f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated " + f"wrongly." + ) + return ObjectItemOperation(self, name, attribute_type).guess_type() + else: + return ObjectItemOperation(self, name).guess_type() @dataclasses.dataclass( @@ -95,7 +278,7 @@ class ObjectVar(ImmutableVar): frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class LiteralObjectVar(LiteralVar, ObjectVar): +class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]): """Base class for immutable literal object vars.""" _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( @@ -103,9 +286,9 @@ class LiteralObjectVar(LiteralVar, ObjectVar): ) def __init__( - self, - _var_value: dict[Var | Any, Var | Any], - _var_type: Type | None = None, + self: LiteralObjectVar[OBJECT_TYPE], + _var_value: OBJECT_TYPE, + _var_type: Type[OBJECT_TYPE] | None = None, _var_data: VarData | None = None, ): """Initialize the object var. @@ -117,14 +300,7 @@ class LiteralObjectVar(LiteralVar, ObjectVar): """ super(LiteralObjectVar, self).__init__( _var_name="", - _var_type=( - Dict[ - unionize(*map(type, _var_value.keys())), - unionize(*map(type, _var_value.values())), - ] - if _var_type is None - else _var_type - ), + _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type), _var_data=ImmutableVarData.merge(_var_data), ) object.__setattr__( @@ -489,6 +665,7 @@ class ObjectItemOperation(ImmutableVar): self, value: ObjectVar, key: Var | Any, + _var_type: GenericType | None = None, _var_data: VarData | None = None, ): """Initialize the object item operation. @@ -500,7 +677,7 @@ class ObjectItemOperation(ImmutableVar): """ super(ObjectItemOperation, self).__init__( _var_name="", - _var_type=value._value_type(), + _var_type=value._value_type() if _var_type is None else _var_type, _var_data=ImmutableVarData.merge(_var_data), ) object.__setattr__(self, "value", value) diff --git a/reflex/experimental/vars/sequence.py b/reflex/experimental/vars/sequence.py index 8db1300ec..f622159a6 100644 --- a/reflex/experimental/vars/sequence.py +++ b/reflex/experimental/vars/sequence.py @@ -10,7 +10,18 @@ import re import sys import typing from functools import cached_property -from typing import Any, List, Set, Tuple, Type, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Set, + Tuple, + TypeVar, + Union, + overload, +) from typing_extensions import get_origin @@ -19,6 +30,8 @@ from reflex.constants.base import REFLEX_VAR_OPENING_TAG from reflex.experimental.vars.base import ( ImmutableVar, LiteralVar, + figure_out_type, + unionize, ) from reflex.experimental.vars.number import ( BooleanVar, @@ -29,8 +42,11 @@ from reflex.experimental.vars.number import ( from reflex.utils.types import GenericType from reflex.vars import ImmutableVarData, Var, VarData, _global_vars +if TYPE_CHECKING: + from .object import ObjectVar -class StringVar(ImmutableVar): + +class StringVar(ImmutableVar[str]): """Base class for immutable string vars.""" def __add__(self, other: StringVar | str) -> ConcatVarOperation: @@ -699,7 +715,17 @@ class ConcatVarOperation(StringVar): pass -class ArrayVar(ImmutableVar): +ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set]) + +OTHER_TUPLE = TypeVar("OTHER_TUPLE") + +INNER_ARRAY_VAR = TypeVar("INNER_ARRAY_VAR") + +KEY_TYPE = TypeVar("KEY_TYPE") +VALUE_TYPE = TypeVar("VALUE_TYPE") + + +class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): """Base class for immutable array vars.""" from reflex.experimental.vars.sequence import StringVar @@ -717,7 +743,7 @@ class ArrayVar(ImmutableVar): return ArrayJoinOperation(self, sep) - def reverse(self) -> ArrayReverseOperation: + def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]: """Reverse the array. Returns: @@ -726,14 +752,98 @@ class ArrayVar(ImmutableVar): return ArrayReverseOperation(self) @overload - def __getitem__(self, i: slice) -> ArraySliceOperation: ... + def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ... + + @overload + def __getitem__( + self: ( + ArrayVar[Tuple[int, OTHER_TUPLE]] + | ArrayVar[Tuple[float, OTHER_TUPLE]] + | ArrayVar[Tuple[int | float, OTHER_TUPLE]] + ), + i: Literal[0, -2], + ) -> NumberVar: ... + + @overload + def __getitem__( + self: ( + ArrayVar[Tuple[OTHER_TUPLE, int]] + | ArrayVar[Tuple[OTHER_TUPLE, float]] + | ArrayVar[Tuple[OTHER_TUPLE, int | float]] + ), + i: Literal[1, -1], + ) -> NumberVar: ... + + @overload + def __getitem__( + self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2] + ) -> StringVar: ... + + @overload + def __getitem__( + self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1] + ) -> StringVar: ... + + @overload + def __getitem__( + self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2] + ) -> BooleanVar: ... + + @overload + def __getitem__( + self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1] + ) -> BooleanVar: ... + + @overload + def __getitem__( + self: ( + ARRAY_VAR_OF_LIST_ELEMENT[int] + | ARRAY_VAR_OF_LIST_ELEMENT[float] + | ARRAY_VAR_OF_LIST_ELEMENT[int | float] + ), + i: int | NumberVar, + ) -> NumberVar: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar + ) -> StringVar: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar + ) -> BooleanVar: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]], + i: int | NumberVar, + ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]], + i: int | NumberVar, + ) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[INNER_ARRAY_VAR, ...]], + i: int | NumberVar, + ) -> ArrayVar[Tuple[INNER_ARRAY_VAR, ...]]: ... + + @overload + def __getitem__( + self: ARRAY_VAR_OF_LIST_ELEMENT[Dict[KEY_TYPE, VALUE_TYPE]], + i: int | NumberVar, + ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ... @overload def __getitem__(self, i: int | NumberVar) -> ImmutableVar: ... def __getitem__( self, i: slice | int | NumberVar - ) -> ArraySliceOperation | ImmutableVar: + ) -> ArrayVar[ARRAY_VAR_TYPE] | ImmutableVar: """Get a slice of the array. Args: @@ -756,7 +866,7 @@ class ArrayVar(ImmutableVar): @overload @classmethod - def range(cls, stop: int | NumberVar, /) -> RangeOperation: ... + def range(cls, stop: int | NumberVar, /) -> ArrayVar[List[int]]: ... @overload @classmethod @@ -766,7 +876,7 @@ class ArrayVar(ImmutableVar): end: int | NumberVar, step: int | NumberVar = 1, /, - ) -> RangeOperation: ... + ) -> ArrayVar[List[int]]: ... @classmethod def range( @@ -774,7 +884,7 @@ class ArrayVar(ImmutableVar): first_endpoint: int | NumberVar, second_endpoint: int | NumberVar | None = None, step: int | NumberVar | None = None, - ) -> RangeOperation: + ) -> ArrayVar[List[int]]: """Create a range of numbers. Args: @@ -794,7 +904,7 @@ class ArrayVar(ImmutableVar): return RangeOperation(start, end, step or 1) - def contains(self, other: Any) -> ArrayContainsOperation: + def contains(self, other: Any) -> BooleanVar: """Check if the array contains an element. Args: @@ -806,12 +916,21 @@ class ArrayVar(ImmutableVar): return ArrayContainsOperation(self, other) +LIST_ELEMENT = TypeVar("LIST_ELEMENT") + +ARRAY_VAR_OF_LIST_ELEMENT = Union[ + ArrayVar[List[LIST_ELEMENT]], + ArrayVar[Set[LIST_ELEMENT]], + ArrayVar[Tuple[LIST_ELEMENT, ...]], +] + + @dataclasses.dataclass( eq=False, frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class LiteralArrayVar(LiteralVar, ArrayVar): +class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): """Base class for immutable literal array vars.""" _var_value: Union[ @@ -819,9 +938,9 @@ class LiteralArrayVar(LiteralVar, ArrayVar): ] = dataclasses.field(default_factory=list) def __init__( - self, - _var_value: list[Var | Any] | tuple[Var | Any, ...] | set[Var | Any], - _var_type: type[list] | type[tuple] | type[set] | None = None, + self: LiteralArrayVar[ARRAY_VAR_TYPE], + _var_value: ARRAY_VAR_TYPE, + _var_type: type[ARRAY_VAR_TYPE] | None = None, _var_data: VarData | None = None, ): """Initialize the array var. @@ -834,11 +953,7 @@ class LiteralArrayVar(LiteralVar, ArrayVar): super(LiteralArrayVar, self).__init__( _var_name="", _var_data=ImmutableVarData.merge(_var_data), - _var_type=( - List[unionize(*map(type, _var_value))] - if _var_type is None - else _var_type - ), + _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type), ) object.__setattr__(self, "_var_value", _var_value) object.__delattr__(self, "_var_name") @@ -1261,23 +1376,6 @@ class ArrayLengthOperation(ArrayToNumberOperation): return f"{str(self.a)}.length" -def unionize(*args: Type) -> Type: - """Unionize the types. - - Args: - args: The types to unionize. - - Returns: - The unionized types. - """ - if not args: - return Any - first, *rest = args - if not rest: - return first - return Union[first, unionize(*rest)] - - def is_tuple_type(t: GenericType) -> bool: """Check if a type is a tuple type. diff --git a/tests/test_var.py b/tests/test_var.py index 66599db72..5c67d9924 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -1042,7 +1042,7 @@ def test_object_operations(): def test_type_chains(): object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3}) - assert object_var._var_type is Dict[str, int] + assert (object_var._key_type(), object_var._value_type()) == (str, int) assert (object_var.keys()._var_type, object_var.values()._var_type) == ( List[str], List[int], @@ -1061,6 +1061,31 @@ def test_type_chains(): ) +def test_nested_dict(): + arr = LiteralArrayVar([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]]) + + assert ( + str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)' + ) + + +def nested_base(): + class Boo(Base): + foo: str + bar: int + + class Foo(Base): + bar: Boo + baz: int + + parent_obj = LiteralVar.create(Foo(bar=Boo(foo="bar", bar=5), baz=5)) + + assert ( + str(parent_obj.bar.foo) + == '({ ["bar"] : ({ ["foo"] : "bar", ["bar"] : 5 }), ["baz"] : 5 })["bar"]["foo"]' + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None From 76627c207691af997948e82a8192b01e0d4caa81 Mon Sep 17 00:00:00 2001 From: Alek Petuskey Date: Thu, 1 Aug 2024 13:24:39 -0700 Subject: [PATCH 22/34] Add comments to html metadata component (#3731) --- reflex/components/el/elements/metadata.py | 27 +++++++++++++++++++++- reflex/components/el/elements/metadata.pyi | 13 +++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/reflex/components/el/elements/metadata.py b/reflex/components/el/elements/metadata.py index c19612abe..9a4d18b73 100644 --- a/reflex/components/el/elements/metadata.py +++ b/reflex/components/el/elements/metadata.py @@ -29,24 +29,49 @@ class Link(BaseHTML): # noqa: E742 tag = "link" + # Specifies the CORS settings for the linked resource cross_origin: Var[Union[str, int, bool]] + + # Specifies the URL of the linked document/resource href: Var[Union[str, int, bool]] + + # Specifies the language of the text in the linked document href_lang: Var[Union[str, int, bool]] + + # Allows a browser to check the fetched link for integrity integrity: Var[Union[str, int, bool]] + + # Specifies on what device the linked document will be displayed media: Var[Union[str, int, bool]] + + # Specifies the referrer policy of the linked document referrer_policy: Var[Union[str, int, bool]] + + # Specifies the relationship between the current document and the linked one rel: Var[Union[str, int, bool]] + + # Specifies the sizes of icons for visual media sizes: Var[Union[str, int, bool]] + + # Specifies the MIME type of the linked document type: Var[Union[str, int, bool]] class Meta(BaseHTML): # Inherits common attributes from BaseHTML """Display the meta element.""" - tag = "meta" + tag = "meta" # The HTML tag for this element is + + # Specifies the character encoding for the HTML document char_set: Var[Union[str, int, bool]] + + # Defines the content of the metadata content: Var[Union[str, int, bool]] + + # Provides an HTTP header for the information/value of the content attribute http_equiv: Var[Union[str, int, bool]] + + # Specifies a name for the metadata name: Var[Union[str, int, bool]] diff --git a/reflex/components/el/elements/metadata.pyi b/reflex/components/el/elements/metadata.pyi index d4d68adb6..e08c1d723 100644 --- a/reflex/components/el/elements/metadata.pyi +++ b/reflex/components/el/elements/metadata.pyi @@ -346,6 +346,15 @@ class Link(BaseHTML): Args: *children: The children of the component. + cross_origin: Specifies the CORS settings for the linked resource + href: Specifies the URL of the linked document/resource + href_lang: Specifies the language of the text in the linked document + integrity: Allows a browser to check the fetched link for integrity + media: Specifies on what device the linked document will be displayed + referrer_policy: Specifies the referrer policy of the linked document + rel: Specifies the relationship between the current document and the linked one + sizes: Specifies the sizes of icons for visual media + type: Specifies the MIME type of the linked document access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. @@ -466,6 +475,10 @@ class Meta(BaseHTML): Args: *children: The children of the component. + char_set: Specifies the character encoding for the HTML document + content: Defines the content of the metadata + http_equiv: Provides an HTTP header for the information/value of the content attribute + name: Specifies a name for the metadata access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. From c7e30522bc224e960487d2084e45f0d273109357 Mon Sep 17 00:00:00 2001 From: Manas Gupta <53006261+Manas1820@users.noreply.github.com> Date: Tue, 6 Aug 2024 02:32:47 +0530 Subject: [PATCH 23/34] fix: add verification for path /404 (#3723) Co-authored-by: coolstorm --- reflex/.templates/web/utils/client_side_routing.js | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/reflex/.templates/web/utils/client_side_routing.js b/reflex/.templates/web/utils/client_side_routing.js index 75fb581c8..1718c8e61 100644 --- a/reflex/.templates/web/utils/client_side_routing.js +++ b/reflex/.templates/web/utils/client_side_routing.js @@ -23,7 +23,12 @@ export const useClientSideRouting = () => { router.replace({ pathname: window.location.pathname, query: window.location.search.slice(1), - }) + }).then(()=>{ + // Check if the current route is /404 + if (router.pathname === '/404') { + setRouteNotFound(true); // Mark as an actual 404 + } + }) .catch((e) => { setRouteNotFound(true) // navigation failed, so this is a real 404 }) From 7d9ed7e2ce376ac71909dfc955fa62f2fb5ce5ce Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 5 Aug 2024 14:03:30 -0700 Subject: [PATCH 24/34] Use the new state name when setting `is_hydrated` to false (#3738) --- reflex/.templates/web/utils/state.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 81ac40100..26b2d0d0c 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -782,7 +782,7 @@ export const useEventLoop = ( // Route after the initial page hydration. useEffect(() => { const change_start = () => { - const main_state_dispatch = dispatch["state"] + const main_state_dispatch = dispatch["reflex___state____state"] if (main_state_dispatch !== undefined) { main_state_dispatch({ is_hydrated: false }) } From 3309c0e53356256268294286c44ea09f50014e0a Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 5 Aug 2024 14:03:49 -0700 Subject: [PATCH 25/34] Use `._is_mutable()` to account for parent state proxy (#3739) When a parent state proxy is set, also allow child StateProxy._self_mutable to override the parent's `_is_mutable()`. --- reflex/state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index e29336042..b0c6646ce 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2059,7 +2059,7 @@ class StateProxy(wrapt.ObjectProxy): Whether the state is mutable. """ if self._self_parent_state_proxy is not None: - return self._self_parent_state_proxy._is_mutable() + return self._self_parent_state_proxy._is_mutable() or self._self_mutable return self._self_mutable async def __aenter__(self) -> StateProxy: @@ -3302,7 +3302,7 @@ class ImmutableMutableProxy(MutableProxy): Raises: ImmutableStateError: if the StateProxy is not mutable. """ - if not self._self_state._self_mutable: + if not self._self_state._is_mutable(): raise ImmutableStateError( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." From 1131caebe825f4266856ab7ca73fb76f0490c1e6 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 6 Aug 2024 13:43:12 -0700 Subject: [PATCH 26/34] bump to 0.5.9 (#3746) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0feb287c8..656ff091a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "reflex" -version = "0.5.8" +version = "0.5.9" description = "Web apps in pure Python." license = "Apache-2.0" authors = [ From 1a3a6f4a34b3e6fc95dda258402f6503c2bb7a34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Fri, 9 Aug 2024 21:08:16 +0200 Subject: [PATCH 27/34] add message when installing requirements.txt is needed for chosen template during init (#3750) --- reflex/constants/base.py | 2 ++ reflex/utils/prerequisites.py | 11 ++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/reflex/constants/base.py b/reflex/constants/base.py index 65d957d27..a858a69b1 100644 --- a/reflex/constants/base.py +++ b/reflex/constants/base.py @@ -77,6 +77,8 @@ class Reflex(SimpleNamespace): os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) + RELEASES_URL = f"https://api.github.com/repos/reflex-dev/templates/releases" + class ReflexHostingCLI(SimpleNamespace): """Base constants concerning Reflex Hosting CLI.""" diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 765560129..72eb75b90 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -1311,9 +1311,6 @@ def migrate_to_reflex(): print(line, end="") -RELEASES_URL = f"https://api.github.com/repos/reflex-dev/templates/releases" - - def fetch_app_templates(version: str) -> dict[str, Template]: """Fetch a dict of templates from the templates repo using github API. @@ -1325,7 +1322,7 @@ def fetch_app_templates(version: str) -> dict[str, Template]: """ def get_release_by_tag(tag: str) -> dict | None: - response = httpx.get(RELEASES_URL) + response = httpx.get(constants.Reflex.RELEASES_URL) response.raise_for_status() releases = response.json() for release in releases: @@ -1437,7 +1434,11 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str template_code_dir_name=template_name, template_dir=template_dir, ) - + req_file = Path("requirements.txt") + if req_file.exists() and len(req_file.read_text().splitlines()) > 1: + console.info( + "Run `pip install -r requirements.txt` to install the required python packages for this template." + ) # Clean up the temp directories. shutil.rmtree(temp_dir) shutil.rmtree(unzip_dir) From 911c2af044ec2702d40045ef2913603505eac02e Mon Sep 17 00:00:00 2001 From: Shubhankar Dimri Date: Sat, 10 Aug 2024 00:42:49 +0530 Subject: [PATCH 28/34] #3752 bugfix add domain for XAxis (#3764) --- reflex/components/recharts/cartesian.py | 3 +++ reflex/components/recharts/cartesian.pyi | 2 ++ 2 files changed, 5 insertions(+) diff --git a/reflex/components/recharts/cartesian.py b/reflex/components/recharts/cartesian.py index 710fef19b..0f2ec2f32 100644 --- a/reflex/components/recharts/cartesian.py +++ b/reflex/components/recharts/cartesian.py @@ -138,6 +138,9 @@ class XAxis(Axis): # Ensures that all datapoints within a chart contribute to its domain calculation, even when they are hidden include_hidden: Var[bool] = Var.create_safe(False) + # The range of the axis. Work best in conjuction with allow_data_overflow. + domain: Var[List] + class YAxis(Axis): """A YAxis component in Recharts.""" diff --git a/reflex/components/recharts/cartesian.pyi b/reflex/components/recharts/cartesian.pyi index 21be32b46..498ed7444 100644 --- a/reflex/components/recharts/cartesian.pyi +++ b/reflex/components/recharts/cartesian.pyi @@ -192,6 +192,7 @@ class XAxis(Axis): ] = None, x_axis_id: Optional[Union[Var[Union[int, str]], str, int]] = None, include_hidden: Optional[Union[Var[bool], bool]] = None, + domain: Optional[Union[Var[List], List]] = None, data_key: Optional[Union[Var[Union[int, str]], str, int]] = None, hide: Optional[Union[Var[bool], bool]] = None, width: Optional[Union[Var[Union[int, str]], str, int]] = None, @@ -320,6 +321,7 @@ class XAxis(Axis): orientation: The orientation of axis 'top' | 'bottom' x_axis_id: The id of x-axis which is corresponding to the data. include_hidden: Ensures that all datapoints within a chart contribute to its domain calculation, even when they are hidden + domain: The range of the axis. Work best in conjuction with allow_data_overflow. data_key: The key of data displayed in the axis. hide: If set true, the axis do not display in the chart. width: The width of axis which is usually calculated internally. From b58ce1082eef49d5c3dc49cfe9948def1f21cfc3 Mon Sep 17 00:00:00 2001 From: benedikt-bartscher <31854409+benedikt-bartscher@users.noreply.github.com> Date: Mon, 12 Aug 2024 02:42:39 +0200 Subject: [PATCH 29/34] fix appharness app_source typing (#3777) --- reflex/testing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/reflex/testing.py b/reflex/testing.py index bde218c5e..52bff03c0 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -115,7 +115,7 @@ class AppHarness: app_name: str app_source: Optional[ - types.FunctionType | types.ModuleType | str | functools.partial + types.FunctionType | types.ModuleType | str | functools.partial[Any] ] app_path: pathlib.Path app_module_path: pathlib.Path @@ -134,7 +134,9 @@ class AppHarness: def create( cls, root: pathlib.Path, - app_source: Optional[types.FunctionType | types.ModuleType | str] = None, + app_source: Optional[ + types.FunctionType | types.ModuleType | str | functools.partial[Any] + ] = None, app_name: Optional[str] = None, ) -> "AppHarness": """Create an AppHarness instance at root. From 0eff63eed4580679c2c04232a0d42ee0205f87ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Mon, 12 Aug 2024 12:04:55 +0200 Subject: [PATCH 30/34] fix import clash between connectionToaster and hooks.useState (#3749) --- reflex/experimental/hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/experimental/hooks.py b/reflex/experimental/hooks.py index 090c2b39d..848ad7fb7 100644 --- a/reflex/experimental/hooks.py +++ b/reflex/experimental/hooks.py @@ -7,7 +7,7 @@ from reflex.vars import Var, VarData def _compose_react_imports(tags: list[str]) -> dict[str, list[ImportVar]]: - return {"react": [ImportVar(tag=tag, install=False) for tag in tags]} + return {"react": [ImportVar(tag=tag) for tag in tags]} def const(name, value) -> Var: From 910bcdb017aba7a05b4efa6501c5f3dbff6f235d Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Mon, 12 Aug 2024 03:06:34 -0700 Subject: [PATCH 31/34] use different registry when in china, fixes #3700 (#3702) --- reflex/constants/installer.py | 2 ++ reflex/utils/prerequisites.py | 10 ++++++++ reflex/utils/registry.py | 48 +++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 reflex/utils/registry.py diff --git a/reflex/constants/installer.py b/reflex/constants/installer.py index fbee4cd5a..4a7027ee8 100644 --- a/reflex/constants/installer.py +++ b/reflex/constants/installer.py @@ -51,6 +51,8 @@ class Bun(SimpleNamespace): WINDOWS_INSTALL_URL = ( "https://raw.githubusercontent.com/reflex-dev/reflex/main/scripts/install.ps1" ) + # Path of the bunfig file + CONFIG_PATH = "bunfig.toml" # FNM config. diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 72eb75b90..a80bcd814 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -38,6 +38,7 @@ from reflex.compiler import templates from reflex.config import Config, get_config from reflex.utils import console, path_ops, processes from reflex.utils.format import format_library_name +from reflex.utils.registry import _get_best_registry CURRENTLY_INSTALLING_NODE = False @@ -577,6 +578,15 @@ def initialize_package_json(): code = _compile_package_json() output_path.write_text(code) + best_registry = _get_best_registry() + bun_config_path = get_web_dir() / constants.Bun.CONFIG_PATH + bun_config_path.write_text( + f""" +[install] +registry = "{best_registry}" +""" + ) + def init_reflex_json(project_hash: int | None): """Write the hash of the Reflex project to a REFLEX_JSON. diff --git a/reflex/utils/registry.py b/reflex/utils/registry.py new file mode 100644 index 000000000..551292f2d --- /dev/null +++ b/reflex/utils/registry.py @@ -0,0 +1,48 @@ +"""Utilities for working with registries.""" + +import httpx + +from reflex.utils import console + + +def latency(registry: str) -> int: + """Get the latency of a registry. + + Args: + registry (str): The URL of the registry. + + Returns: + int: The latency of the registry in microseconds. + """ + try: + return httpx.get(registry).elapsed.microseconds + except httpx.HTTPError: + console.info(f"Failed to connect to {registry}.") + return 10_000_000 + + +def average_latency(registry, attempts: int = 3) -> int: + """Get the average latency of a registry. + + Args: + registry (str): The URL of the registry. + attempts (int): The number of attempts to make. Defaults to 10. + + Returns: + int: The average latency of the registry in microseconds. + """ + return sum(latency(registry) for _ in range(attempts)) // attempts + + +def _get_best_registry() -> str: + """Get the best registry based on latency. + + Returns: + str: The best registry. + """ + registries = [ + "https://registry.npmjs.org", + "https://r.cnpmjs.org", + ] + + return min(registries, key=average_latency) From 634c0916f6a893caa56b9aa8eb7a0c1159debd99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Tue, 13 Aug 2024 23:14:38 +0200 Subject: [PATCH 32/34] do not reload compilation if using local app in AppHarness (#3790) * do not reload if using local app * Update reflex/testing.py Co-authored-by: Masen Furer --------- Co-authored-by: Masen Furer --- reflex/testing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/reflex/testing.py b/reflex/testing.py index 52bff03c0..c52396fc6 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -276,7 +276,10 @@ class AppHarness: before_decorated_pages = reflex.app.DECORATED_PAGES[self.app_name].copy() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) - self.app_module = reflex.utils.prerequisites.get_compiled_app(reload=True) + self.app_module = reflex.utils.prerequisites.get_compiled_app( + # Do not reload the module for pre-existing apps (only apps generated from source) + reload=self.app_source is not None + ) # Save the pages that were added during testing self._decorated_pages = [ p From b01c4b6c6a0d5642079b73be6a5c27193730e425 Mon Sep 17 00:00:00 2001 From: Alek Petuskey Date: Tue, 13 Aug 2024 15:49:43 -0700 Subject: [PATCH 33/34] Bump memory on relevant actions (#3781) Co-authored-by: Alek Petuskey --- .github/workflows/benchmarks.yml | 2 +- .github/workflows/integration_tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index c8ab75adb..b849dd328 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -17,7 +17,7 @@ defaults: env: PYTHONIOENCODING: 'utf8' TELEMETRY_ENABLED: false - NODE_OPTIONS: '--max_old_space_size=4096' + NODE_OPTIONS: '--max_old_space_size=8192' PR_TITLE: ${{ github.event.pull_request.title }} jobs: diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 8fa787aba..58f7a668a 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -29,7 +29,7 @@ env: # - Best effort print lines that contain illegal chars (map to some default char, etc.) PYTHONIOENCODING: 'utf8' TELEMETRY_ENABLED: false - NODE_OPTIONS: '--max_old_space_size=4096' + NODE_OPTIONS: '--max_old_space_size=8192' PR_TITLE: ${{ github.event.pull_request.title }} jobs: From 4190a7983572335311ce9572c42b22fb766250ad Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Thu, 15 Aug 2024 16:30:32 +0000 Subject: [PATCH 34/34] [REF-3334] Validate Toast Props (#3793) --- reflex/app.py | 2 +- reflex/components/sonner/toast.py | 28 +++++++++++++++++++++++++++- reflex/components/sonner/toast.pyi | 5 +++++ tests/test_state.py | 2 +- 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 7e40a95bf..69a5ff978 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -125,8 +125,8 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec: ) if Toaster.is_used: return toast( + "An error occurred.", level="error", - title="An error occurred.", description="
".join(error_message), position="top-center", id="backend_error", diff --git a/reflex/components/sonner/toast.py b/reflex/components/sonner/toast.py index d4df31e82..e00eaad4e 100644 --- a/reflex/components/sonner/toast.py +++ b/reflex/components/sonner/toast.py @@ -4,6 +4,8 @@ from __future__ import annotations from typing import Any, ClassVar, Literal, Optional, Union +from pydantic import ValidationError + from reflex.base import Base from reflex.components.component import Component, ComponentNamespace from reflex.components.lucide.icon import Icon @@ -27,7 +29,6 @@ LiteralPosition = Literal[ "bottom-right", ] - toast_ref = Var.create_safe("refs['__toast']", _var_is_string=False) @@ -128,6 +129,24 @@ class ToastProps(PropsBase): # Function that gets called when the toast disappears automatically after it's timeout (duration` prop). on_auto_close: Optional[Any] + def __init__(self, **kwargs): + """Initialize the props. + + Args: + kwargs: Kwargs to initialize the props. + + Raises: + ValueError: If invalid props are passed on instantiation. + """ + try: + super().__init__(**kwargs) + except ValidationError as e: + invalid_fields = ", ".join([error["loc"][0] for error in e.errors()]) # type: ignore + supported_props_str = ", ".join(f'"{field}"' for field in self.get_fields()) + raise ValueError( + f"Invalid prop(s) {invalid_fields} for rx.toast. Supported props are {supported_props_str}" + ) from None + def dict(self, *args, **kwargs) -> dict[str, Any]: """Convert the object to a dictionary. @@ -159,6 +178,13 @@ class ToastProps(PropsBase): ) return d + class Config: + """Pydantic config.""" + + arbitrary_types_allowed = True + use_enum_values = True + extra = "forbid" + class Toaster(Component): """A Toaster Component for displaying toast notifications.""" diff --git a/reflex/components/sonner/toast.pyi b/reflex/components/sonner/toast.pyi index 7e5758b16..c7e626915 100644 --- a/reflex/components/sonner/toast.pyi +++ b/reflex/components/sonner/toast.pyi @@ -51,6 +51,11 @@ class ToastProps(PropsBase): def dict(self, *args, **kwargs) -> dict[str, Any]: ... + class Config: + arbitrary_types_allowed = True + use_enum_values = True + extra = "forbid" + class Toaster(Component): is_used: ClassVar[bool] = False diff --git a/tests/test_state.py b/tests/test_state.py index aa5705b09..1a014739d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1565,7 +1565,7 @@ async def test_state_with_invalid_yield(capsys, mock_app): assert update.events == rx.event.fix_events( [ rx.toast( - title="An error occurred.", + "An error occurred.", description="TypeError: Your handler test_state_with_invalid_yield..StateWithInvalidYield.invalid_handler must only return/yield: None, Events or other EventHandlers referenced by their class (not using `self`).
See logs for details.", level="error", id="backend_error",