diff --git a/pyproject.toml b/pyproject.toml index 67611a867..7c9f4df32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1" gunicorn = ">=20.1.0,<24.0" jinja2 = ">=3.1.2,<4.0" psutil = ">=5.9.4,<7.0" -pydantic = ">=1.10.2,<3.0" +pydantic = ">=1.10.17,<3.0" python-multipart = ">=0.0.5,<0.1" python-socketio = ">=5.7.0,<6.0" redis = ">=4.3.5,<6.0" @@ -55,7 +55,7 @@ typing_extensions = ">=4.6.0" [tool.poetry.group.dev.dependencies] pytest = ">=7.1.2,<9.0" pytest-mock = ">=3.10.0,<4.0" -pyright = ">=1.1.392, <1.2" +pyright = ">=1.1.392.post0,<1.2" darglint = ">=1.8.1,<2.0" dill = ">=0.3.8" toml = ">=0.10.2,<1.0" diff --git a/reflex/.templates/jinja/web/pages/utils.js.jinja2 b/reflex/.templates/jinja/web/pages/utils.js.jinja2 index c883dadcb..176a5d063 100644 --- a/reflex/.templates/jinja/web/pages/utils.js.jinja2 +++ b/reflex/.templates/jinja/web/pages/utils.js.jinja2 @@ -6,12 +6,6 @@ {% filter indent(width=indent_width) %} {%- if component is not mapping %} {{- component }} - {%- elif "iterable" in component %} - {{- render_iterable_tag(component) }} - {%- elif component.name == "match"%} - {{- render_match_tag(component) }} - {%- elif "cond" in component %} - {{- render_condition_tag(component) }} {%- elif component.children|length %} {{- render_tag(component) }} {%- else %} @@ -44,30 +38,6 @@ {%- endmacro %} -{# Rendering condition component. #} -{# Args: #} -{# component: component dictionary #} -{% macro render_condition_tag(component) %} -{ {{- component.cond_state }} ? ( - {{ render(component.true_value) }} -) : ( - {{ render(component.false_value) }} -)} -{%- endmacro %} - - -{# Rendering iterable component. #} -{# Args: #} -{# component: component dictionary #} -{% macro render_iterable_tag(component) %} -<>{ {{ component.iterable_state }}.map(({{ component.arg_name }}, {{ component.arg_index }}) => ( - {% for child in component.children %} - {{ render(child) }} - {% endfor %} -))} -{%- endmacro %} - - {# Rendering props of a component. #} {# Args: #} {# component: component dictionary #} @@ -75,29 +45,6 @@ {% if props|length %} {{ props|join(" ") }}{% endif %} {% endmacro %} -{# Rendering Match component. #} -{# Args: #} -{# component: component dictionary #} -{% macro render_match_tag(component) %} -{ - (() => { - switch (JSON.stringify({{ component.cond._js_expr }})) { - {% for case in component.match_cases %} - {% for condition in case[:-1] %} - case JSON.stringify({{ condition._js_expr }}): - {% endfor %} - return {{ render(case[-1]) }}; - break; - {% endfor %} - default: - return {{ render(component.default) }}; - break; - } - })() - } -{%- endmacro %} - - {# Rendering content with args. #} {# Args: #} {# component: component dictionary #} diff --git a/reflex/.templates/web/utils/helpers/range.js b/reflex/.templates/web/utils/helpers/range.js index 7d1aedaaf..8ff97fc67 100644 --- a/reflex/.templates/web/utils/helpers/range.js +++ b/reflex/.templates/web/utils/helpers/range.js @@ -1,43 +1,43 @@ /** * Simulate the python range() builtin function. * inspired by https://dev.to/guyariely/using-python-range-in-javascript-337p - * + * * If needed outside of an iterator context, use `Array.from(range(10))` or * spread syntax `[...range(10)]` to get an array. - * + * * @param {number} start: the start or end of the range. * @param {number} stop: the end of the range. * @param {number} step: the step of the range. * @returns {object} an object with a Symbol.iterator method over the range */ export default function range(start, stop, step) { - return { - [Symbol.iterator]() { - if (stop === undefined) { - stop = start; - start = 0; - } - if (step === undefined) { - step = 1; - } - - let i = start - step; - - return { - next() { - i += step; - if ((step > 0 && i < stop) || (step < 0 && i > stop)) { - return { - value: i, - done: false, - }; - } + return { + [Symbol.iterator]() { + if ((stop ?? undefined) === undefined) { + stop = start; + start = 0; + } + if ((step ?? undefined) === undefined) { + step = 1; + } + + let i = start - step; + + return { + next() { + i += step; + if ((step > 0 && i < stop) || (step < 0 && i > stop)) { return { - value: undefined, - done: true, + value: i, + done: false, }; - }, - }; - }, - }; - } \ No newline at end of file + } + return { + value: undefined, + done: true, + }; + }, + }; + }, + }; +} diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 2f09ac2de..af078d2f9 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -926,6 +926,45 @@ export const isTrue = (val) => { return Boolean(val); }; +/** + * Returns a copy of a section of an array. + * @param {Array | string} arrayLike The array to slice. + * @param {[number, number, number]} slice The slice to apply. + * @returns The sliced array. + */ +export const atSlice = (arrayLike, slice) => { + const array = [...arrayLike]; + const [startSlice, endSlice, stepSlice] = slice; + if (stepSlice ?? null === null) { + return array.slice(startSlice ?? undefined, endSlice ?? undefined); + } + const step = stepSlice ?? 1; + if (step > 0) { + return array + .slice(startSlice ?? undefined, endSlice ?? undefined) + .filter((_, i) => i % step === 0); + } + const actualStart = (endSlice ?? null) === null ? 0 : endSlice + 1; + const actualEnd = + (startSlice ?? null) === null ? array.length : startSlice + 1; + return array + .slice(actualStart, actualEnd) + .reverse() + .filter((_, i) => i % step === 0); +}; + +/** + * Get the value at a slice or index. + * @param {Array | string} arrayLike The array to get the value from. + * @param {number | [number, number, number]} sliceOrIndex The slice or index to get the value at. + * @returns The value at the slice or index. + */ +export const atSliceOrIndex = (arrayLike, sliceOrIndex) => { + return Array.isArray(sliceOrIndex) + ? atSlice(arrayLike, sliceOrIndex) + : arrayLike.at(sliceOrIndex); +}; + /** * Get the value from a ref. * @param ref The ref to get the value from. diff --git a/reflex/base.py b/reflex/base.py index f6bbb8ce4..c900f0039 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -5,15 +5,9 @@ from __future__ import annotations import os from typing import TYPE_CHECKING, Any, List, Type -try: - import pydantic.v1.main as pydantic_main - from pydantic.v1 import BaseModel - from pydantic.v1.fields import ModelField -except ModuleNotFoundError: - if not TYPE_CHECKING: - import pydantic.main as pydantic_main - from pydantic import BaseModel - from pydantic.fields import ModelField +import pydantic.v1.main as pydantic_main +from pydantic.v1 import BaseModel +from pydantic.v1.fields import ModelField def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None: @@ -50,7 +44,7 @@ if TYPE_CHECKING: from reflex.vars import Var -class Base(BaseModel): # pyright: ignore [reportPossiblyUnboundVariable] +class Base(BaseModel): """The base class subclassed by all Reflex classes. This class wraps Pydantic and provides common methods such as diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 73b0680d3..c8784e443 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import Any, Iterator -from reflex.components.component import Component, LiteralComponentVar +from reflex.components.component import Component, ComponentStyle from reflex.components.tags import Tag from reflex.components.tags.tagless import Tagless from reflex.config import PerformanceMode, environment @@ -12,7 +12,7 @@ from reflex.utils import console from reflex.utils.decorator import once from reflex.utils.imports import ParsedImportDict from reflex.vars import BooleanVar, ObjectVar, Var -from reflex.vars.base import VarData +from reflex.vars.base import GLOBAL_CACHE, VarData from reflex.vars.sequence import LiteralStringVar @@ -80,8 +80,11 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks_internal() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks_internal() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + hooks |= component._get_all_hooks_internal() return hooks def _get_all_hooks(self) -> dict[str, VarData | None]: @@ -91,18 +94,24 @@ class Bare(Component): The hooks for the component. """ hooks = super()._get_all_hooks() - if isinstance(self.contents, LiteralComponentVar): - hooks |= self.contents._var_value._get_all_hooks() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + hooks |= component._get_all_hooks() return hooks - def _get_all_imports(self) -> ParsedImportDict: + def _get_all_imports(self, collapse: bool = False) -> ParsedImportDict: """Include the imports for the component. + Args: + collapse: Whether to collapse the imports. + Returns: The imports for the component. """ - imports = super()._get_all_imports() - if isinstance(self.contents, LiteralComponentVar): + imports = super()._get_all_imports(collapse=collapse) + if isinstance(self.contents, Var): var_data = self.contents._get_all_var_data() if var_data: imports |= {k: list(v) for k, v in var_data.imports} @@ -115,8 +124,11 @@ class Bare(Component): The dynamic imports. """ dynamic_imports = super()._get_all_dynamic_imports() - if isinstance(self.contents, LiteralComponentVar): - dynamic_imports |= self.contents._var_value._get_all_dynamic_imports() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + dynamic_imports |= component._get_all_dynamic_imports() return dynamic_imports def _get_all_custom_code(self) -> set[str]: @@ -126,10 +138,28 @@ class Bare(Component): The custom code. """ custom_code = super()._get_all_custom_code() - if isinstance(self.contents, LiteralComponentVar): - custom_code |= self.contents._var_value._get_all_custom_code() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + custom_code |= component._get_all_custom_code() return custom_code + def _get_all_app_wrap_components(self) -> dict[tuple[int, str], Component]: + """Get the components that should be wrapped in the app. + + Returns: + The components that should be wrapped in the app. + """ + app_wrap_components = super()._get_all_app_wrap_components() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + if isinstance(component, Component): + app_wrap_components |= component._get_all_app_wrap_components() + return app_wrap_components + def _get_all_refs(self) -> set[str]: """Get the refs for the children of the component. @@ -137,8 +167,11 @@ class Bare(Component): The refs for the children. """ refs = super()._get_all_refs() - if isinstance(self.contents, LiteralComponentVar): - refs |= self.contents._var_value._get_all_refs() + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + refs |= component._get_all_refs() return refs def _render(self) -> Tag: @@ -148,6 +181,30 @@ class Bare(Component): return Tagless(contents=f"{{{self.contents!s}}}") return Tagless(contents=str(self.contents)) + def _add_style_recursive( + self, style: ComponentStyle, theme: Component | None = None + ) -> Component: + """Add style to the component and its children. + + Args: + style: The style to add. + theme: The theme to add. + + Returns: + The component with the style added. + """ + new_self = super()._add_style_recursive(style, theme) + if isinstance(self.contents, Var): + var_data = self.contents._get_all_var_data() + if var_data: + for component in var_data.components: + if isinstance(component, Component): + component._add_style_recursive(style, theme) + + GLOBAL_CACHE.clear() + + return new_self + def _get_vars( self, include_children: bool = False, ignore_ids: set[int] | None = None ) -> Iterator[Var]: diff --git a/reflex/components/component.py b/reflex/components/component.py index d27bddf78..da2c198bb 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -65,8 +65,7 @@ from reflex.vars.base import ( Var, cached_property_no_lock, ) -from reflex.vars.function import ArgsFunctionOperation, FunctionStringVar -from reflex.vars.number import ternary_operation +from reflex.vars.function import FunctionStringVar from reflex.vars.object import ObjectVar from reflex.vars.sequence import LiteralArrayVar @@ -889,10 +888,8 @@ class Component(BaseComponent, ABC): children: The children of the component. """ + from reflex.components.base.bare import Bare from reflex.components.base.fragment import Fragment - from reflex.components.core.cond import Cond - from reflex.components.core.foreach import Foreach - from reflex.components.core.match import Match no_valid_parents_defined = all(child._valid_parents == [] for child in children) if ( @@ -903,9 +900,7 @@ class Component(BaseComponent, ABC): return comp_name = type(self).__name__ - allowed_components = [ - comp.__name__ for comp in (Fragment, Foreach, Cond, Match) - ] + allowed_components = [comp.__name__ for comp in (Fragment,)] def validate_child(child: Any): child_name = type(child).__name__ @@ -915,24 +910,39 @@ class Component(BaseComponent, ABC): for c in child.children: validate_child(c) - if isinstance(child, Cond): - validate_child(child.comp1) - validate_child(child.comp2) - - if isinstance(child, Match): - for cases in child.match_cases: - validate_child(cases[-1]) - validate_child(child.default) + if ( + isinstance(child, Bare) + and child.contents is not None + and isinstance(child.contents, Var) + ): + var_data = child.contents._get_all_var_data() + if var_data is not None: + for c in var_data.components: + validate_child(c) if self._invalid_children and child_name in self._invalid_children: raise ValueError( f"The component `{comp_name}` cannot have `{child_name}` as a child component" ) - if self._valid_children and child_name not in [ - *self._valid_children, - *allowed_components, - ]: + valid_children = self._valid_children + allowed_components + + def child_is_in_valid(child_component: Any): + if type(child_component).__name__ in valid_children: + return True + + if ( + not isinstance(child_component, Bare) + or child_component.contents is None + or not isinstance(child_component.contents, Var) + or (var_data := child_component.contents._get_all_var_data()) + is None + ): + return False + + return all(child_is_in_valid(c) for c in var_data.components) + + if self._valid_children and not child_is_in_valid(child): valid_child_list = ", ".join( [f"`{v_child}`" for v_child in self._valid_children] ) @@ -1918,8 +1928,6 @@ class StatefulComponent(BaseComponent): Returns: The stateful component or None if the component should not be memoized. """ - from reflex.components.core.foreach import Foreach - if component._memoization_mode.disposition == MemoizationDisposition.NEVER: # Never memoize this component. return None @@ -1948,10 +1956,6 @@ class StatefulComponent(BaseComponent): # Skip BaseComponent and StatefulComponent children. if not isinstance(child, Component): continue - # Always consider Foreach something that must be memoized by the parent. - if isinstance(child, Foreach): - should_memoize = True - break child = cls._child_var(child) if isinstance(child, Var) and child._get_all_var_data(): should_memoize = True @@ -2001,18 +2005,9 @@ class StatefulComponent(BaseComponent): The Var from the child component or the child itself (for regular cases). """ from reflex.components.base.bare import Bare - from reflex.components.core.cond import Cond - from reflex.components.core.foreach import Foreach - from reflex.components.core.match import Match if isinstance(child, Bare): return child.contents - if isinstance(child, Cond): - return child.cond - if isinstance(child, Foreach): - return child.iterable - if isinstance(child, Match): - return child.cond return child @classmethod @@ -2359,53 +2354,6 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) -> return render_dict_to_var(tag.render(), imported_names) return Var.create(tag) - if "iterable" in tag: - function_return = Var.create( - [ - render_dict_to_var(child.render(), imported_names) - for child in tag["children"] - ] - ) - - func = ArgsFunctionOperation.create( - (tag["arg_var_name"], tag["index_var_name"]), - function_return, - ) - - return FunctionStringVar.create("Array.prototype.map.call").call( - tag["iterable"] - if not isinstance(tag["iterable"], ObjectVar) - else tag["iterable"].items(), - func, - ) - - if tag["name"] == "match": - element = tag["cond"] - - conditionals = render_dict_to_var(tag["default"], imported_names) - - for case in tag["match_cases"][::-1]: - condition = case[0].to_string() == element.to_string() - for pattern in case[1:-1]: - condition = condition | (pattern.to_string() == element.to_string()) - - conditionals = ternary_operation( - condition, - render_dict_to_var(case[-1], imported_names), - conditionals, - ) - - return conditionals - - if "cond" in tag: - return ternary_operation( - tag["cond"], - render_dict_to_var(tag["true_value"], imported_names), - render_dict_to_var(tag["false_value"], imported_names) - if tag["false_value"] is not None - else Var.create(None), - ) - props = {} special_props = [] @@ -2485,17 +2433,14 @@ class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): "@emotion/react": [ ImportVar(tag="jsx"), ], - } - ), - VarData( - imports=self._var_value._get_all_imports(), - ), - VarData( - imports={ "react": [ ImportVar(tag="Fragment"), ], - } + }, + components=(self._var_value,), + ), + VarData( + imports=self._var_value._get_all_imports(), ), ) diff --git a/reflex/components/core/__init__.py b/reflex/components/core/__init__.py index fbe0bdc84..534035f12 100644 --- a/reflex/components/core/__init__.py +++ b/reflex/components/core/__init__.py @@ -21,16 +21,14 @@ _SUBMOD_ATTRS: dict[str, list[str]] = { "colors": [ "color", ], - "cond": ["Cond", "color_mode_cond", "cond"], + "cond": ["color_mode_cond", "cond"], "debounce": ["DebounceInput", "debounce_input"], "foreach": [ "foreach", - "Foreach", ], "html": ["html", "Html"], "match": [ "match", - "Match", ], "breakpoints": ["breakpoints", "set_breakpoints"], "responsive": [ diff --git a/reflex/components/core/__init__.pyi b/reflex/components/core/__init__.pyi index ea9275334..2f1fb2084 100644 --- a/reflex/components/core/__init__.pyi +++ b/reflex/components/core/__init__.pyi @@ -17,16 +17,13 @@ from .breakpoints import set_breakpoints as set_breakpoints from .clipboard import Clipboard as Clipboard from .clipboard import clipboard as clipboard from .colors import color as color -from .cond import Cond as Cond from .cond import color_mode_cond as color_mode_cond from .cond import cond as cond from .debounce import DebounceInput as DebounceInput from .debounce import debounce_input as debounce_input -from .foreach import Foreach as Foreach from .foreach import foreach as foreach from .html import Html as Html from .html import html as html -from .match import Match as Match from .match import match as match from .responsive import desktop_only as desktop_only from .responsive import mobile_and_tablet as mobile_and_tablet diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index d96f1655a..d5d047322 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -5,6 +5,7 @@ from __future__ import annotations from typing import Optional from reflex import constants +from reflex.components.base.fragment import Fragment from reflex.components.component import Component from reflex.components.core.cond import cond from reflex.components.el.elements.typography import Div @@ -163,7 +164,7 @@ class ConnectionToaster(Toaster): return super().create(*children, **props) -class ConnectionBanner(Component): +class ConnectionBanner(Fragment): """A connection banner component.""" @classmethod @@ -190,10 +191,10 @@ class ConnectionBanner(Component): position="fixed", ) - return cond(has_connection_errors, comp) + return super().create(cond(has_connection_errors, comp)) -class ConnectionModal(Component): +class ConnectionModal(Fragment): """A connection status modal window.""" @classmethod @@ -208,16 +209,18 @@ class ConnectionModal(Component): """ if not comp: comp = Text.create(*default_connection_error()) - return cond( - has_too_many_connection_errors, - DialogRoot.create( - DialogContent.create( - DialogTitle.create("Connection Error"), - comp, + return super().create( + cond( + has_too_many_connection_errors, + DialogRoot.create( + DialogContent.create( + DialogTitle.create("Connection Error"), + comp, + ), + open=has_too_many_connection_errors, + z_index=9999, ), - open=has_too_many_connection_errors, - z_index=9999, - ), + ) ) diff --git a/reflex/components/core/banner.pyi b/reflex/components/core/banner.pyi index 707076310..166a8ff0f 100644 --- a/reflex/components/core/banner.pyi +++ b/reflex/components/core/banner.pyi @@ -5,6 +5,7 @@ # ------------------------------------------------------ from typing import Any, Dict, Literal, Optional, Union, overload +from reflex.components.base.fragment import Fragment from reflex.components.component import Component from reflex.components.el.elements.typography import Div from reflex.components.lucide.icon import Icon @@ -137,7 +138,7 @@ class ConnectionToaster(Toaster): """ ... -class ConnectionBanner(Component): +class ConnectionBanner(Fragment): @overload @classmethod def create( # type: ignore @@ -176,7 +177,7 @@ class ConnectionBanner(Component): """ ... -class ConnectionModal(Component): +class ConnectionModal(Fragment): @overload @classmethod def create( # type: ignore diff --git a/reflex/components/core/client_side_routing.py b/reflex/components/core/client_side_routing.py index 0fc40de5f..43d705ada 100644 --- a/reflex/components/core/client_side_routing.py +++ b/reflex/components/core/client_side_routing.py @@ -41,7 +41,7 @@ class ClientSideRouting(Component): return "" -def wait_for_client_redirect(component: Component) -> Component: +def wait_for_client_redirect(component: Component) -> Var[Component]: """Wait for a redirect to occur before rendering a component. This prevents the 404 page from flashing while the redirect is happening. @@ -53,9 +53,9 @@ def wait_for_client_redirect(component: Component) -> Component: The conditionally rendered component. """ return cond( - condition=route_not_found, - c1=component, - c2=ClientSideRouting.create(), + route_not_found, + component, + ClientSideRouting.create(), ) diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 6f9110a16..d2e7c26b0 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -2,126 +2,31 @@ from __future__ import annotations -from typing import Any, Dict, Optional, overload +from typing import Any, TypeVar, Union, overload from reflex.components.base.fragment import Fragment -from reflex.components.component import BaseComponent, Component, MemoizationLeaf -from reflex.components.tags import CondTag, Tag -from reflex.constants import Dirs +from reflex.components.component import BaseComponent, Component from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode -from reflex.utils.imports import ImportDict, ImportVar -from reflex.vars import VarData +from reflex.utils import types from reflex.vars.base import LiteralVar, Var from reflex.vars.number import ternary_operation -_IS_TRUE_IMPORT: ImportDict = { - f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], -} + +@overload +def cond( + condition: Any, c1: BaseComponent | Var[BaseComponent], c2: Any = None, / +) -> Var[Component]: ... -class Cond(MemoizationLeaf): - """Render one of two components based on a condition.""" - - # The cond to determine which component to render. - cond: Var[Any] - - # The component to render if the cond is true. - comp1: BaseComponent | None = None - # The component to render if the cond is false. - comp2: BaseComponent | None = None - - @classmethod - def create( - cls, - cond: Var, - comp1: BaseComponent, - comp2: Optional[BaseComponent] = None, - ) -> Component: - """Create a conditional component. - - Args: - cond: The cond to determine which component to render. - comp1: The component to render if the cond is true. - comp2: The component to render if the cond is false. - - Returns: - The conditional component. - """ - # Wrap everything in fragments. - if type(comp1).__name__ != "Fragment": - comp1 = Fragment.create(comp1) - if comp2 is None or type(comp2).__name__ != "Fragment": - comp2 = Fragment.create(comp2) if comp2 else Fragment.create() - return Fragment.create( - cls( - cond=cond, - comp1=comp1, - comp2=comp2, - children=[comp1, comp2], - ) - ) - - def _get_props_imports(self): - """Get the imports needed for component's props. - - Returns: - The imports for the component's props of the component. - """ - return [] - - def _render(self) -> Tag: - return CondTag( - cond=self.cond, - true_value=self.comp1.render(), # pyright: ignore [reportOptionalMemberAccess] - false_value=self.comp2.render(), # pyright: ignore [reportOptionalMemberAccess] - ) - - def render(self) -> Dict: - """Render the component. - - Returns: - The dictionary for template of component. - """ - tag = self._render() - return dict( - tag.add_props( - **self.event_triggers, - key=self.key, - sx=self.style, - id=self.id, - class_name=self.class_name, - ).set( - props=tag.format_props(), - ), - cond_state=f"isTrue({self.cond!s})", - ) - - def add_imports(self) -> ImportDict: - """Add imports for the Cond component. - - Returns: - The import dict for the component. - """ - var_data = VarData.merge(self.cond._get_all_var_data()) - - imports = var_data.old_school_imports() if var_data else {} - - return {**imports, **_IS_TRUE_IMPORT} +T = TypeVar("T") +V = TypeVar("V") @overload -def cond(condition: Any, c1: Component, c2: Any) -> Component: ... # pyright: ignore [reportOverlappingOverload] +def cond(condition: Any, c1: T | Var[T], c2: V | Var[V], /) -> Var[T | V]: ... -@overload -def cond(condition: Any, c1: Component) -> Component: ... - - -@overload -def cond(condition: Any, c1: Any, c2: Any) -> Var: ... - - -def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var: +def cond(condition: Any, c1: Any, c2: Any = None, /) -> Var: """Create a conditional component or Prop. Args: @@ -137,48 +42,40 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var: """ # Convert the condition to a Var. cond_var = LiteralVar.create(condition) - if cond_var is None: - raise ValueError("The condition must be set.") - # If the first component is a component, create a Cond component. - if isinstance(c1, BaseComponent): - if c2 is not None and not isinstance(c2, BaseComponent): - raise ValueError("Both arguments must be components.") - return Cond.create(cond_var, c1, c2) + # If the first component is a component, create a Fragment if the second component is not set. + if isinstance(c1, BaseComponent) or ( + isinstance(c1, Var) + and types.safe_typehint_issubclass( + c1._var_type, Union[BaseComponent, list[BaseComponent]] + ) + ): + c2 = c2 if c2 is not None else Fragment.create() - # Otherwise, create a conditional Var. # Check that the second argument is valid. - if isinstance(c2, BaseComponent): - raise ValueError("Both arguments must be props.") if c2 is None: raise ValueError("For conditional vars, the second argument must be set.") - def create_var(cond_part: Any) -> Var[Any]: - return LiteralVar.create(cond_part) - - # convert the truth and false cond parts into vars so the _var_data can be obtained. - c1 = create_var(c1) - c2 = create_var(c2) - # Create the conditional var. return ternary_operation( - cond_var.bool()._replace( - merge_var_data=VarData(imports=_IS_TRUE_IMPORT), - ), + cond_var.bool(), c1, c2, ) @overload -def color_mode_cond(light: Component, dark: Component | None = None) -> Component: ... # pyright: ignore [reportOverlappingOverload] +def color_mode_cond( + light: BaseComponent | Var[BaseComponent], + dark: BaseComponent | Var[BaseComponent] | None = ..., +) -> Var[Component]: ... @overload -def color_mode_cond(light: Any, dark: Any = None) -> Var: ... +def color_mode_cond(light: T | Var[T], dark: V | Var[V]) -> Var[T | V]: ... -def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: +def color_mode_cond(light: Any, dark: Any = None) -> Var: """Create a component or Prop based on color_mode. Args: @@ -193,3 +90,9 @@ def color_mode_cond(light: Any, dark: Any = None) -> Var | Component: light, dark, ) + + +class Cond: + """Create a conditional component or Prop.""" + + create = staticmethod(cond) diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index 13db48575..92acc72f9 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -2,16 +2,9 @@ from __future__ import annotations -import functools -import inspect -from typing import Any, Callable, Iterable +from typing import Callable, Iterable -from reflex.components.base.fragment import Fragment -from reflex.components.component import Component -from reflex.components.tags import IterTag -from reflex.constants import MemoizationMode -from reflex.state import ComponentState -from reflex.utils.exceptions import UntypedVarError +from reflex.vars import ArrayVar, ObjectVar, StringVar from reflex.vars.base import LiteralVar, Var @@ -23,149 +16,42 @@ class ForeachRenderError(TypeError): """Raised when there is an error with the foreach render function.""" -class Foreach(Component): - """A component that takes in an iterable and a render function and renders a list of components.""" +def foreach( + iterable: Var[Iterable] | Iterable, + render_fn: Callable, +) -> Var: + """Create a foreach component. - _memoization_mode = MemoizationMode(recursive=False) + Args: + iterable: The iterable to create components from. + render_fn: A function from the render args to the component. - # The iterable to create components from. - iterable: Var[Iterable] + Returns: + The foreach component. - # A function from the render args to the component. - render_fn: Callable = Fragment.create + Raises: + ForeachVarError: If the iterable is of type Any. + TypeError: If the render function is a ComponentState. + UntypedVarError: If the iterable is of type Any without a type annotation. + """ + iterable = LiteralVar.create(iterable).guess_type() - @classmethod - def create( - cls, - iterable: Var[Iterable] | Iterable, - render_fn: Callable, - ) -> Foreach: - """Create a foreach component. + if isinstance(iterable, ObjectVar): + iterable = iterable.entries() - Args: - iterable: The iterable to create components from. - render_fn: A function from the render args to the component. + if isinstance(iterable, StringVar): + iterable = iterable.split() - Returns: - The foreach component. - - Raises: - ForeachVarError: If the iterable is of type Any. - TypeError: If the render function is a ComponentState. - UntypedVarError: If the iterable is of type Any without a type annotation. - """ - from reflex.vars import ArrayVar, ObjectVar, StringVar - - iterable = LiteralVar.create(iterable).guess_type() - - if iterable._var_type == Any: - raise ForeachVarError( - f"Could not foreach over var `{iterable!s}` of type Any. " - "(If you are trying to foreach over a state var, add a type annotation to the var). " - "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) - - if ( - hasattr(render_fn, "__qualname__") - and render_fn.__qualname__ == ComponentState.create.__qualname__ - ): - raise TypeError( - "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." - ) - - if isinstance(iterable, ObjectVar): - iterable = iterable.entries() - - if isinstance(iterable, StringVar): - iterable = iterable.split() - - if not isinstance(iterable, ArrayVar): - raise ForeachVarError( - f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. " - "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) - - component = cls( - iterable=iterable, - render_fn=render_fn, - ) - try: - # Keep a ref to a rendered component to determine correct imports/hooks/styles. - component.children = [component._render().render_component()] - except UntypedVarError as e: - raise UntypedVarError( - f"Could not foreach over var `{iterable!s}` without a type annotation. " - "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) from e - return component - - def _render(self) -> IterTag: - props = {} - - render_sig = inspect.signature(self.render_fn) - params = list(render_sig.parameters.values()) - - # Validate the render function signature. - if len(params) == 0 or len(params) > 2: - raise ForeachRenderError( - "Expected 1 or 2 parameters in foreach render function, got " - f"{[p.name for p in params]}. See " - "https://reflex.dev/docs/library/dynamic-rendering/foreach/" - ) - - if len(params) >= 1: - # Determine the arg var name based on the params accepted by render_fn. - props["arg_var_name"] = params[0].name - - if len(params) == 2: - # Determine the index var name based on the params accepted by render_fn. - props["index_var_name"] = params[1].name - else: - render_fn = self.render_fn - # Otherwise, use a deterministic index, based on the render function bytecode. - code_hash = ( - hash( - getattr( - render_fn, - "__code__", - ( - repr(self.render_fn) - if not isinstance(render_fn, functools.partial) - else render_fn.func.__code__ - ), - ) - ) - .to_bytes( - length=8, - byteorder="big", - signed=True, - ) - .hex() - ) - props["index_var_name"] = f"index_{code_hash}" - - return IterTag( - iterable=self.iterable, - render_fn=self.render_fn, - children=self.children, - **props, + if not isinstance(iterable, ArrayVar): + raise ForeachVarError( + f"Could not foreach over var `{iterable!s}` of type {iterable._var_type}. " + "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" ) - def render(self): - """Render the component. - - Returns: - The dictionary for template of component. - """ - tag = self._render() - - return dict( - tag, - iterable_state=str(tag.iterable), - arg_name=tag.arg_var_name, - arg_index=tag.get_index_var_arg(), - iterable_type=tag.iterable._var_type.mro()[0].__name__, - ) + return iterable.foreach(render_fn) -foreach = Foreach.create +class Foreach: + """Create a foreach component.""" + + create = staticmethod(foreach) diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index 2d936544a..4db9bf5bd 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -1,274 +1,161 @@ """rx.match.""" -import textwrap -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Union, cast + +from typing_extensions import Unpack from reflex.components.base import Fragment -from reflex.components.component import BaseComponent, Component, MemoizationLeaf -from reflex.components.tags import MatchTag, Tag -from reflex.style import Style -from reflex.utils import format, types +from reflex.components.component import BaseComponent +from reflex.utils import types from reflex.utils.exceptions import MatchTypeError -from reflex.utils.imports import ImportDict -from reflex.vars import VarData -from reflex.vars.base import LiteralVar, Var +from reflex.vars.base import VAR_TYPE, Var +from reflex.vars.number import MatchOperation + +CASE_TYPE = tuple[Unpack[tuple[Any, ...]], Var[VAR_TYPE] | VAR_TYPE] -class Match(MemoizationLeaf): - """Match cases based on a condition.""" +def _process_match_cases(cases: tuple[CASE_TYPE[VAR_TYPE], ...]): + """Process the individual match cases. - # The condition to determine which case to match. - cond: Var[Any] + Args: + cases: The match cases. - # The list of match cases to be matched. - match_cases: List[Any] = [] - - # The catchall case to match. - default: Any - - @classmethod - def create(cls, cond: Any, *cases) -> Union[Component, Var]: - """Create a Match Component. - - Args: - cond: The condition to determine which case to match. - cases: This list of cases to match. - - Returns: - The match component. - - Raises: - ValueError: When a default case is not provided for cases with Var return types. - """ - match_cond_var = cls._create_condition_var(cond) - cases, default = cls._process_cases(list(cases)) - match_cases = cls._process_match_cases(cases) - - cls._validate_return_types(match_cases) - - if default is None and types._issubclass(type(match_cases[0][-1]), Var): + Raises: + ValueError: If the default case is not the last case or the tuple elements are less than 2. + """ + for case in cases: + if not isinstance(case, tuple): raise ValueError( - "For cases with return types as Vars, a default case must be provided" + "rx.match should have tuples of cases and a default case as the last argument." ) - return cls._create_match_cond_var_or_component( - match_cond_var, match_cases, default - ) - - @classmethod - def _create_condition_var(cls, cond: Any) -> Var: - """Convert the condition to a Var. - - Args: - cond: The condition. - - Returns: - The condition as a base var - - Raises: - ValueError: If the condition is not provided. - """ - match_cond_var = LiteralVar.create(cond) - - if match_cond_var is None: - raise ValueError("The condition must be set") - return match_cond_var - - @classmethod - def _process_cases( - cls, cases: List - ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]: - """Process the list of match cases and the catchall default case. - - Args: - cases: The list of match cases. - - Returns: - The default case and the list of match case tuples. - - Raises: - ValueError: If there are multiple default cases. - """ - default = None - - if len([case for case in cases if not isinstance(case, tuple)]) > 1: - raise ValueError("rx.match can only have one default case.") - - if not cases: - raise ValueError("rx.match should have at least one case.") - - # Get the default case which should be the last non-tuple arg - if not isinstance(cases[-1], tuple): - default = cases.pop() - default = ( - cls._create_case_var_with_var_data(default) - if not isinstance(default, BaseComponent) - else default + # There should be at least two elements in a case tuple(a condition and return value) + if len(case) < 2: + raise ValueError( + "A case tuple should have at least a match case element and a return value." ) - return cases, default - @classmethod - def _create_case_var_with_var_data(cls, case_element: Any) -> Var: - """Convert a case element into a Var.If the case - is a Style type, we extract the var data and merge it with the - newly created Var. +def _validate_return_types(*return_values: Any) -> bool: + """Validate that match cases have the same return types. - Args: - case_element: The case element. + Args: + return_values: The return values of the match cases. - Returns: - The case element Var. - """ - _var_data = case_element._var_data if isinstance(case_element, Style) else None - case_element = LiteralVar.create(case_element, _var_data=_var_data) - return case_element + Returns: + True if all cases have the same return types. - @classmethod - def _process_match_cases(cls, cases: List) -> List[List[Var]]: - """Process the individual match cases. + Raises: + MatchTypeError: If the return types of cases are different. + """ - Args: - cases: The match cases. - - Returns: - The processed match cases. - - Raises: - ValueError: If the default case is not the last case or the tuple elements are less than 2. - """ - match_cases = [] - for case in cases: - if not isinstance(case, tuple): - raise ValueError( - "rx.match should have tuples of cases and a default case as the last argument." - ) - # There should be at least two elements in a case tuple(a condition and return value) - if len(case) < 2: - raise ValueError( - "A case tuple should have at least a match case element and a return value." - ) - - case_list = [] - for element in case: - # convert all non component element to vars. - el = ( - cls._create_case_var_with_var_data(element) - if not isinstance(element, BaseComponent) - else element - ) - if not isinstance(el, (Var, BaseComponent)): - raise ValueError("Case element must be a var or component") - case_list.append(el) - - match_cases.append(case_list) - - return match_cases - - @classmethod - def _validate_return_types(cls, match_cases: List[List[Var]]) -> None: - """Validate that match cases have the same return types. - - Args: - match_cases: The match cases. - - Raises: - MatchTypeError: If the return types of cases are different. - """ - first_case_return = match_cases[0][-1] - return_type = type(first_case_return) - - if isinstance(first_case_return, BaseComponent): - return_type = BaseComponent - elif isinstance(first_case_return, Var): - return_type = Var - - for index, case in enumerate(match_cases): - if not types._issubclass(type(case[-1]), return_type): - raise MatchTypeError( - f"Match cases should have the same return types. Case {index} with return " - f"value `{case[-1]._js_expr if isinstance(case[-1], Var) else textwrap.shorten(str(case[-1]), width=250)}`" - f" of type {type(case[-1])!r} is not {return_type}" - ) - - @classmethod - def _create_match_cond_var_or_component( - cls, - match_cond_var: Var, - match_cases: List[List[Var]], - default: Optional[Union[Var, BaseComponent]], - ) -> Union[Component, Var]: - """Create and return the match condition var or component. - - Args: - match_cond_var: The match condition. - match_cases: The list of match cases. - default: The default case. - - Returns: - The match component wrapped in a fragment or the match var. - - Raises: - ValueError: If the return types are not vars when creating a match var for Var types. - """ - if default is None and types._issubclass( - type(match_cases[0][-1]), BaseComponent - ): - default = Fragment.create() - - if types._issubclass(type(match_cases[0][-1]), BaseComponent): - return Fragment.create( - cls( - cond=match_cond_var, - match_cases=match_cases, - default=default, - children=[case[-1] for case in match_cases] + [default], # pyright: ignore [reportArgumentType] - ) + def is_component_or_component_var(obj: Any) -> bool: + return types._isinstance(obj, BaseComponent) or ( + isinstance(obj, Var) + and types.safe_typehint_issubclass( + obj._var_type, Union[list[BaseComponent], BaseComponent] ) - - # Validate the match cases (as well as the default case) to have Var return types. - if any( - case for case in match_cases if not isinstance(case[-1], Var) - ) or not isinstance(default, Var): - raise ValueError("Return types of match cases should be Vars.") - - return Var( - _js_expr=format.format_match( - cond=str(match_cond_var), - match_cases=match_cases, - default=default, # pyright: ignore [reportArgumentType] - ), - _var_type=default._var_type, # pyright: ignore [reportAttributeAccessIssue,reportOptionalMemberAccess] - _var_data=VarData.merge( - match_cond_var._get_all_var_data(), - *[el._get_all_var_data() for case in match_cases for el in case], - default._get_all_var_data(), # pyright: ignore [reportAttributeAccessIssue, reportOptionalMemberAccess] - ), ) - def _render(self) -> Tag: - return MatchTag( - cond=self.cond, match_cases=self.match_cases, default=self.default + def type_of_return_value(obj: Any) -> Any: + if isinstance(obj, Var): + return obj._var_type + return type(obj) + + is_return_type_component = [ + is_component_or_component_var(return_type) for return_type in return_values + ] + + if any(is_return_type_component) and not all(is_return_type_component): + non_component_return_types = [ + (type_of_return_value(return_value), i) + for i, return_value in enumerate(return_values) + if not is_return_type_component[i] + ] + raise MatchTypeError( + "Match cases should have the same return types. " + + "Expected return types to be of type Component or Var[Component]. " + + ". ".join( + [ + f"Return type of case {i} is {return_type}" + for return_type, i in non_component_return_types + ] + ) ) - def render(self) -> Dict: - """Render the component. - - Returns: - The dictionary for template of component. - """ - tag = self._render() - tag.name = "match" - return dict(tag) - - def add_imports(self) -> ImportDict: - """Add imports for the Match component. - - Returns: - The import dict. - """ - var_data = VarData.merge(self.cond._get_all_var_data()) - return var_data.old_school_imports() if var_data else {} + return all(is_return_type_component) -match = Match.create +def _create_match_var( + match_cond_var: Var, + match_cases: tuple[CASE_TYPE[VAR_TYPE], ...], + default: VAR_TYPE | Var[VAR_TYPE], +) -> Var[VAR_TYPE]: + """Create the match var. + + Args: + match_cond_var: The match condition var. + match_cases: The match cases. + default: The default case. + + Returns: + The match var. + """ + return MatchOperation.create(match_cond_var, match_cases, default) + + +def match( + cond: Any, + *cases: Unpack[ + tuple[Unpack[tuple[CASE_TYPE[VAR_TYPE], ...]], Var[VAR_TYPE] | VAR_TYPE] + ], +) -> Var[VAR_TYPE]: + """Create a match var. + + Args: + cond: The condition to match. + cases: The match cases. Each case should be a tuple with the first elements as the match case and the last element as the return value. The last argument should be the default case. + + Returns: + The match var. + + Raises: + ValueError: If the default case is not the last case or the tuple elements are less than 2. + """ + default = types.Unset() + + if len([case for case in cases if not isinstance(case, tuple)]) > 1: + raise ValueError("rx.match can only have one default case.") + + if not cases: + raise ValueError("rx.match should have at least one case.") + + # Get the default case which should be the last non-tuple arg + if not isinstance(cases[-1], tuple): + default = cases[-1] + actual_cases = cases[:-1] + else: + actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases) + + _process_match_cases(actual_cases) + + is_component_match = _validate_return_types( + *[case[-1] for case in actual_cases], + *([default] if not isinstance(default, types.Unset) else []), + ) + + if isinstance(default, types.Unset) and not is_component_match: + raise ValueError( + "For cases with return types as Vars, a default case must be provided" + ) + + if isinstance(default, types.Unset): + default = Fragment.create() + + default = cast(Var[VAR_TYPE] | VAR_TYPE, default) + + return _create_match_var( + cond, + actual_cases, + default, + ) diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index 897b89608..6c86d3c44 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -29,7 +29,7 @@ from reflex.event import ( from reflex.utils import format from reflex.utils.imports import ImportVar from reflex.vars import VarData -from reflex.vars.base import CallableVar, Var, get_unique_variable_name +from reflex.vars.base import Var, get_unique_variable_name from reflex.vars.sequence import LiteralStringVar DEFAULT_UPLOAD_ID: str = "default" @@ -45,7 +45,6 @@ upload_files_context_var_data: VarData = VarData( ) -@CallableVar def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var: """Get the file upload drop trigger. @@ -75,7 +74,6 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var: ) -@CallableVar def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> Var: """Get the list of selected files. diff --git a/reflex/components/core/upload.pyi b/reflex/components/core/upload.pyi index 6ed96a15e..d1ddceb4d 100644 --- a/reflex/components/core/upload.pyi +++ b/reflex/components/core/upload.pyi @@ -13,14 +13,12 @@ from reflex.event import CallableEventSpec, EventSpec, EventType from reflex.style import Style from reflex.utils.imports import ImportVar from reflex.vars import VarData -from reflex.vars.base import CallableVar, Var +from reflex.vars.base import Var DEFAULT_UPLOAD_ID: str upload_files_context_var_data: VarData -@CallableVar def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> Var: ... -@CallableVar def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> Var: ... @CallableEventSpec def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ... diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index dfac0452a..6ca82a716 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent): # Fired when editing is finished. on_finished_editing: EventHandler[ - passthrough_event_spec(Union[GridCell, None], tuple[int, int]) + passthrough_event_spec(Union[GridCell, None], tuple[int, int]) # pyright: ignore[reportArgumentType] ] # Fired when a row is appended. diff --git a/reflex/components/datadisplay/shiki_code_block.py b/reflex/components/datadisplay/shiki_code_block.py index a4aaec1d4..1930dd02e 100644 --- a/reflex/components/datadisplay/shiki_code_block.py +++ b/reflex/components/datadisplay/shiki_code_block.py @@ -828,7 +828,7 @@ class ShikiHighLevelCodeBlock(ShikiCodeBlock): if isinstance(code, Var): return string_replace_operation( code, StringVar(_js_expr=f"/{regex_pattern}/g", _var_type=str), "" - ) + ).guess_type() if isinstance(code, str): return re.sub(regex_pattern, "", code) diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 51d3dd3dd..f3ea2bb85 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -114,8 +114,8 @@ class MarkdownComponentMap: explicit_return = explicit_return or cls._explicit_return return ArgsFunctionOperation.create( - args_names=(DestructuredArg(fields=tuple(fn_args)),), - return_expr=fn_body, + (DestructuredArg(fields=tuple(fn_args)),), + fn_body, explicit_return=explicit_return, _var_data=var_data, ) diff --git a/reflex/components/radix/primitives/slider.py b/reflex/components/radix/primitives/slider.py index 6136e3171..a6f8245ff 100644 --- a/reflex/components/radix/primitives/slider.py +++ b/reflex/components/radix/primitives/slider.py @@ -188,7 +188,7 @@ class Slider(ComponentNamespace): else: children = [ track, - # Foreach.create(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001 + # foreach(props.get("value"), lambda e: SliderThumb.create()), # foreach doesn't render Thumbs properly # noqa: ERA001 ] return SliderRoot.create(*children, **props) diff --git a/reflex/components/radix/themes/color_mode.py b/reflex/components/radix/themes/color_mode.py index d9b7c0b02..b1ab6f8c5 100644 --- a/reflex/components/radix/themes/color_mode.py +++ b/reflex/components/radix/themes/color_mode.py @@ -20,7 +20,7 @@ from __future__ import annotations from typing import Any, Dict, List, Literal, Optional, Union, get_args from reflex.components.component import BaseComponent -from reflex.components.core.cond import Cond, color_mode_cond, cond +from reflex.components.core.cond import color_mode_cond, cond from reflex.components.lucide.icon import Icon from reflex.components.radix.themes.components.dropdown_menu import dropdown_menu from reflex.components.radix.themes.components.switch import Switch @@ -40,28 +40,23 @@ DEFAULT_LIGHT_ICON: Icon = Icon.create(tag="sun") DEFAULT_DARK_ICON: Icon = Icon.create(tag="moon") -class ColorModeIcon(Cond): - """Displays the current color mode as an icon.""" +def color_mode_icon( + light_component: BaseComponent | None = None, + dark_component: BaseComponent | None = None, +): + """Create a color mode icon component. - @classmethod - def create( - cls, - light_component: BaseComponent | None = None, - dark_component: BaseComponent | None = None, - ): - """Create an icon component based on color_mode. + Args: + light_component: The component to render in light mode. + dark_component: The component to render in dark mode. - Args: - light_component: the component to display when color mode is default - dark_component: the component to display when color mode is dark (non-default) - - Returns: - The conditionally rendered component - """ - return color_mode_cond( - light=light_component or DEFAULT_LIGHT_ICON, - dark=dark_component or DEFAULT_DARK_ICON, - ) + Returns: + The color mode icon component. + """ + return color_mode_cond( + light=light_component or DEFAULT_LIGHT_ICON, + dark=dark_component or DEFAULT_DARK_ICON, + ) LiteralPosition = Literal["top-left", "top-right", "bottom-left", "bottom-right"] @@ -144,7 +139,7 @@ class ColorModeIconButton(IconButton): if allow_system: - def color_mode_item(_color_mode: str): + def color_mode_item(_color_mode: Literal["light", "dark", "system"]): return dropdown_menu.item( _color_mode.title(), on_click=set_color_mode(_color_mode) ) @@ -152,7 +147,7 @@ class ColorModeIconButton(IconButton): return dropdown_menu.root( dropdown_menu.trigger( super().create( - ColorModeIcon.create(), + color_mode_icon(), ), **props, ), @@ -163,7 +158,7 @@ class ColorModeIconButton(IconButton): ), ) return IconButton.create( - ColorModeIcon.create(), + color_mode_icon(), on_click=toggle_color_mode, **props, ) @@ -197,7 +192,7 @@ class ColorModeSwitch(Switch): class ColorModeNamespace(Var): """Namespace for color mode components.""" - icon = staticmethod(ColorModeIcon.create) + icon = staticmethod(color_mode_icon) button = staticmethod(ColorModeIconButton.create) switch = staticmethod(ColorModeSwitch.create) diff --git a/reflex/components/radix/themes/color_mode.pyi b/reflex/components/radix/themes/color_mode.pyi index 3b92b752d..8b9cd0ce2 100644 --- a/reflex/components/radix/themes/color_mode.pyi +++ b/reflex/components/radix/themes/color_mode.pyi @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Literal, Optional, Union, overload from reflex.components.component import BaseComponent from reflex.components.core.breakpoints import Breakpoints -from reflex.components.core.cond import Cond from reflex.components.lucide.icon import Icon from reflex.components.radix.themes.components.switch import Switch from reflex.event import EventType @@ -19,48 +18,10 @@ from .components.icon_button import IconButton DEFAULT_LIGHT_ICON: Icon DEFAULT_DARK_ICON: Icon -class ColorModeIcon(Cond): - @overload - @classmethod - def create( # type: ignore - cls, - *children, - cond: Optional[Union[Any, Var[Any]]] = None, - comp1: Optional[BaseComponent] = None, - comp2: Optional[BaseComponent] = None, - style: Optional[Style] = None, - key: Optional[Any] = None, - id: Optional[Any] = None, - class_name: Optional[Any] = None, - autofocus: Optional[bool] = None, - custom_attrs: Optional[Dict[str, Union[Var, Any]]] = None, - on_blur: Optional[EventType[()]] = None, - on_click: Optional[EventType[()]] = None, - on_context_menu: Optional[EventType[()]] = None, - on_double_click: Optional[EventType[()]] = None, - on_focus: Optional[EventType[()]] = None, - on_mount: Optional[EventType[()]] = None, - on_mouse_down: Optional[EventType[()]] = None, - on_mouse_enter: Optional[EventType[()]] = None, - on_mouse_leave: Optional[EventType[()]] = None, - on_mouse_move: Optional[EventType[()]] = None, - on_mouse_out: Optional[EventType[()]] = None, - on_mouse_over: Optional[EventType[()]] = None, - on_mouse_up: Optional[EventType[()]] = None, - on_scroll: Optional[EventType[()]] = None, - on_unmount: Optional[EventType[()]] = None, - **props, - ) -> "ColorModeIcon": - """Create an icon component based on color_mode. - - Args: - light_component: the component to display when color mode is default - dark_component: the component to display when color mode is dark (non-default) - - Returns: - The conditionally rendered component - """ - ... +def color_mode_icon( + light_component: BaseComponent | None = None, + dark_component: BaseComponent | None = None, +): ... LiteralPosition = Literal["top-left", "top-right", "bottom-left", "bottom-right"] position_values: List[str] @@ -440,7 +401,7 @@ class ColorModeSwitch(Switch): ... class ColorModeNamespace(Var): - icon = staticmethod(ColorModeIcon.create) + icon = staticmethod(color_mode_icon) button = staticmethod(ColorModeIconButton.create) switch = staticmethod(ColorModeSwitch.create) diff --git a/reflex/components/radix/themes/components/icon_button.py b/reflex/components/radix/themes/components/icon_button.py index aafb9e1eb..89821e991 100644 --- a/reflex/components/radix/themes/components/icon_button.py +++ b/reflex/components/radix/themes/components/icon_button.py @@ -6,7 +6,7 @@ from typing import Literal from reflex.components.component import Component from reflex.components.core.breakpoints import Responsive -from reflex.components.core.match import Match +from reflex.components.core.match import match from reflex.components.el import elements from reflex.components.lucide import Icon from reflex.style import Style @@ -77,7 +77,7 @@ class IconButton(elements.Button, RadixLoadingProp, RadixThemesComponent): if isinstance(props["size"], str): children[0].size = RADIX_TO_LUCIDE_SIZE[props["size"]] else: - size_map_var = Match.create( + size_map_var = match( props["size"], *list(RADIX_TO_LUCIDE_SIZE.items()), 12, diff --git a/reflex/components/radix/themes/layout/list.py b/reflex/components/radix/themes/layout/list.py index 04fcb6ae5..31aaed2c2 100644 --- a/reflex/components/radix/themes/layout/list.py +++ b/reflex/components/radix/themes/layout/list.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import Any, Iterable, Literal, Union from reflex.components.component import Component, ComponentNamespace -from reflex.components.core.foreach import Foreach +from reflex.components.core.foreach import foreach from reflex.components.el.elements.typography import Li, Ol, Ul from reflex.components.lucide.icon import Icon from reflex.components.markdown.markdown import MarkdownComponentMap @@ -70,7 +70,7 @@ class BaseList(Component, MarkdownComponentMap): if not children and items is not None: if isinstance(items, Var): - children = [Foreach.create(items, ListItem.create)] + children = [foreach(items, ListItem.create)] else: children = [ListItem.create(item) for item in items] props["direction"] = "column" diff --git a/reflex/components/tags/__init__.py b/reflex/components/tags/__init__.py index 993da11fe..330bcc279 100644 --- a/reflex/components/tags/__init__.py +++ b/reflex/components/tags/__init__.py @@ -1,6 +1,3 @@ """Representations for React tags.""" -from .cond_tag import CondTag -from .iter_tag import IterTag -from .match_tag import MatchTag from .tag import Tag diff --git a/reflex/components/tags/cond_tag.py b/reflex/components/tags/cond_tag.py deleted file mode 100644 index b4d0fe469..000000000 --- a/reflex/components/tags/cond_tag.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Tag to conditionally render components.""" - -import dataclasses -from typing import Any, Dict, Optional - -from reflex.components.tags.tag import Tag -from reflex.vars.base import Var - - -@dataclasses.dataclass() -class CondTag(Tag): - """A conditional tag.""" - - # The condition to determine which component to render. - cond: Var[Any] = dataclasses.field(default_factory=lambda: Var.create(True)) - - # The code to render if the condition is true. - true_value: Dict = dataclasses.field(default_factory=dict) - - # The code to render if the condition is false. - false_value: Optional[Dict] = None diff --git a/reflex/components/tags/iter_tag.py b/reflex/components/tags/iter_tag.py deleted file mode 100644 index 221b65ca9..000000000 --- a/reflex/components/tags/iter_tag.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Tag to loop through a list of components.""" - -from __future__ import annotations - -import dataclasses -import inspect -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args - -from reflex.components.tags.tag import Tag -from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name - -if TYPE_CHECKING: - from reflex.components.component import Component - - -@dataclasses.dataclass() -class IterTag(Tag): - """An iterator tag.""" - - # The var to iterate over. - iterable: Var[Iterable] = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - - # The component render function for each item in the iterable. - render_fn: Callable = dataclasses.field(default_factory=lambda: lambda x: x) - - # The name of the arg var. - arg_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) - - # The name of the index var. - index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) - - def get_iterable_var_type(self) -> Type: - """Get the type of the iterable var. - - Returns: - The type of the iterable var. - """ - iterable = self.iterable - try: - if iterable._var_type.mro()[0] is dict: - # Arg is a tuple of (key, value). - return Tuple[get_args(iterable._var_type)] # pyright: ignore [reportReturnType] - elif iterable._var_type.mro()[0] is tuple: - # Arg is a union of any possible values in the tuple. - return Union[get_args(iterable._var_type)] # pyright: ignore [reportReturnType] - else: - return get_args(iterable._var_type)[0] - except Exception: - return Any # pyright: ignore [reportReturnType] - - def get_index_var(self) -> Var: - """Get the index var for the tag (with curly braces). - - This is used to reference the index var within the tag. - - Returns: - The index var. - """ - return Var( - _js_expr=self.index_var_name, - _var_type=int, - ).guess_type() - - def get_arg_var(self) -> Var: - """Get the arg var for the tag (with curly braces). - - This is used to reference the arg var within the tag. - - Returns: - The arg var. - """ - return Var( - _js_expr=self.arg_var_name, - _var_type=self.get_iterable_var_type(), - ).guess_type() - - def get_index_var_arg(self) -> Var: - """Get the index var for the tag (without curly braces). - - This is used to render the index var in the .map() function. - - Returns: - The index var. - """ - return Var( - _js_expr=self.index_var_name, - _var_type=int, - ).guess_type() - - def get_arg_var_arg(self) -> Var: - """Get the arg var for the tag (without curly braces). - - This is used to render the arg var in the .map() function. - - Returns: - The arg var. - """ - return Var( - _js_expr=self.arg_var_name, - _var_type=self.get_iterable_var_type(), - ).guess_type() - - def render_component(self) -> Component: - """Render the component. - - Raises: - ValueError: If the render function takes more than 2 arguments. - - Returns: - The rendered component. - """ - # Import here to avoid circular imports. - from reflex.components.base.fragment import Fragment - from reflex.components.core.cond import Cond - from reflex.components.core.foreach import Foreach - - # Get the render function arguments. - args = inspect.getfullargspec(self.render_fn).args - arg = self.get_arg_var() - index = self.get_index_var() - - if len(args) == 1: - # If the render function doesn't take the index as an argument. - component = self.render_fn(arg) - else: - # If the render function takes the index as an argument. - if len(args) != 2: - raise ValueError("The render function must take 2 arguments.") - component = self.render_fn(arg, index) - - # Nested foreach components or cond must be wrapped in fragments. - if isinstance(component, (Foreach, Cond)): - component = Fragment.create(component) - - # If the component is a tuple, unpack and wrap it in a fragment. - if isinstance(component, tuple): - component = Fragment.create(*component) - - # Set the component key. - if component.key is None: - component.key = index - - return component diff --git a/reflex/components/tags/match_tag.py b/reflex/components/tags/match_tag.py deleted file mode 100644 index 01eedb296..000000000 --- a/reflex/components/tags/match_tag.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Tag to conditionally match cases.""" - -import dataclasses -from typing import Any, List - -from reflex.components.tags.tag import Tag -from reflex.vars.base import Var - - -@dataclasses.dataclass() -class MatchTag(Tag): - """A match tag.""" - - # The condition to determine which case to match. - cond: Var[Any] = dataclasses.field(default_factory=lambda: Var.create(True)) - - # The list of match cases to be matched. - match_cases: List[Any] = dataclasses.field(default_factory=list) - - # The catchall case to match. - default: Any = dataclasses.field(default=Var.create(None)) diff --git a/reflex/event.py b/reflex/event.py index c2eb8db3a..ee9db2f53 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -600,14 +600,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default -T = TypeVar("T") -U = TypeVar("U") +EVENT_T = TypeVar("EVENT_T") +EVENT_U = TypeVar("EVENT_U") + +Ts = TypeVarTuple("Ts") -class IdentityEventReturn(Generic[T], Protocol): +class IdentityEventReturn(Generic[Unpack[Ts]], Protocol): """Protocol for an identity event return.""" - def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]: + def __call__(self, *values: Unpack[Ts]) -> tuple[Unpack[Ts]]: """Return the input values. Args: @@ -620,22 +622,26 @@ class IdentityEventReturn(Generic[T], Protocol): @overload -def passthrough_event_spec( # pyright: ignore [reportOverlappingOverload] - event_type: Type[T], / -) -> Callable[[Var[T]], Tuple[Var[T]]]: ... +def passthrough_event_spec( + event_type: Type[EVENT_T], / +) -> IdentityEventReturn[Var[EVENT_T]]: ... @overload def passthrough_event_spec( - event_type_1: Type[T], event_type2: Type[U], / -) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ... + event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], / +) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ... @overload -def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ... +def passthrough_event_spec( + *event_types: Unpack[tuple[Type[EVENT_T]]], +) -> IdentityEventReturn[Unpack[tuple[Var[EVENT_T], ...]]]: ... -def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # pyright: ignore [reportInconsistentOverload] +def passthrough_event_spec( # pyright: ignore[reportInconsistentOverload] + *event_types: Type[EVENT_T], +) -> IdentityEventReturn[Unpack[tuple[Var[EVENT_T], ...]]]: """A helper function that returns the input event as output. Args: @@ -645,7 +651,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # A function that returns the input event as output. """ - def inner(*values: Var[T]) -> Tuple[Var[T], ...]: + def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]: return values inner_type = tuple(Var[event_type] for event_type in event_types) @@ -780,7 +786,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec: return None fn.__qualname__ = name - fn.__signature__ = sig # pyright: ignore [reportFunctionMemberAccess] + fn.__signature__ = sig # pyright: ignore[reportFunctionMemberAccess] return EventSpec( handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE), args=tuple( diff --git a/reflex/state.py b/reflex/state.py index 77c352cfa..ff1a4b2ca 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -607,8 +607,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if cls._item_is_event_handler(name, fn) } - for mixin in cls._mixins(): # pyright: ignore [reportAssignmentType] - for name, value in mixin.__dict__.items(): + for mixin_class in cls._mixins(): + for name, value in mixin_class.__dict__.items(): if name in cls.inherited_vars: continue if is_computed_var(value): @@ -619,7 +619,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls.computed_vars[newcv._js_expr] = newcv cls.vars[newcv._js_expr] = newcv continue - if types.is_backend_base_variable(name, mixin): # pyright: ignore [reportArgumentType] + if types.is_backend_base_variable(name, mixin_class): cls.backend_vars[name] = copy.deepcopy(value) continue if events.get(name) is not None: @@ -3653,6 +3653,9 @@ def get_state_manager() -> StateManager: return prerequisites.get_and_validate_app().app.state_manager +DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__") + + class MutableProxy(wrapt.ObjectProxy): """A proxy for a mutable object that tracks changes.""" @@ -3724,12 +3727,7 @@ class MutableProxy(wrapt.ObjectProxy): cls.__dataclass_proxies__[wrapper_cls_name] = type( wrapper_cls_name, (cls,), - { - dataclasses._FIELDS: getattr( # pyright: ignore [reportAttributeAccessIssue] - wrapped_cls, - dataclasses._FIELDS, # pyright: ignore [reportAttributeAccessIssue] - ), - }, + {DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)}, ) cls = cls.__dataclass_proxies__[wrapper_cls_name] return super().__new__(cls) @@ -3878,11 +3876,11 @@ class MutableProxy(wrapt.ObjectProxy): if ( isinstance(self.__wrapped__, Base) and __name not in self.__never_wrap_base_attrs__ - and hasattr(value, "__func__") + and (value_func := getattr(value, "__func__", None)) ): # Wrap methods called on Base subclasses, which might do _anything_ return wrapt.FunctionWrapper( - functools.partial(value.__func__, self), # pyright: ignore [reportFunctionMemberAccess] + functools.partial(value_func, self), self._wrap_recursive_decorator, ) diff --git a/reflex/style.py b/reflex/style.py index 192835ca3..1d818ed06 100644 --- a/reflex/style.py +++ b/reflex/style.py @@ -12,7 +12,7 @@ from reflex.utils.exceptions import ReflexError from reflex.utils.imports import ImportVar from reflex.utils.types import get_origin from reflex.vars import VarData -from reflex.vars.base import CallableVar, LiteralVar, Var +from reflex.vars.base import LiteralVar, Var from reflex.vars.function import FunctionVar from reflex.vars.object import ObjectVar @@ -48,7 +48,6 @@ def _color_mode_var(_js_expr: str, _var_type: Type = str) -> Var: ).guess_type() -@CallableVar def set_color_mode( new_color_mode: LiteralColorMode | Var[LiteralColorMode] | None = None, ) -> Var[EventChain]: diff --git a/reflex/testing.py b/reflex/testing.py index e463ddea7..754edce8d 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -68,10 +68,8 @@ try: from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports] WebElement, ) - - has_selenium = True except ImportError: - has_selenium = False + webdriver = None # The timeout (minutes) to check for the port. DEFAULT_TIMEOUT = 15 @@ -296,11 +294,13 @@ class AppHarness: if p not in before_decorated_pages ] self.app_instance = self.app_module.app - if self.app_instance and isinstance( - self.app_instance._state_manager, StateManagerRedis - ): + if self.app_instance is None: + raise RuntimeError("App was not initialized.") + if isinstance(self.app_instance._state_manager, StateManagerRedis): # Create our own redis connection for testing. - self.state_manager = StateManagerRedis.create(self.app_instance._state) # pyright: ignore [reportArgumentType] + if self.app_instance._state is None: + raise RuntimeError("App state is not initialized.") + self.state_manager = StateManagerRedis.create(self.app_instance._state) else: self.state_manager = ( self.app_instance._state_manager if self.app_instance else None @@ -615,7 +615,7 @@ class AppHarness: Raises: RuntimeError: when selenium is not importable or frontend is not running """ - if not has_selenium: + if webdriver is None: raise RuntimeError( "Frontend functionality requires `selenium` to be installed, " "and it could not be imported." diff --git a/reflex/utils/console.py b/reflex/utils/console.py index 70bbb0e82..18bd93ff6 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -201,10 +201,13 @@ def _get_first_non_framework_frame() -> FrameType | None: # Exclude utility modules that should never be the source of deprecated reflex usage. exclude_modules = [click, rx, typer, typing_extensions] exclude_roots = [ - p.parent.resolve() - if (p := Path(m.__file__)).name == "__init__.py" # pyright: ignore [reportArgumentType] - else p.resolve() + ( + p.parent.resolve() + if (p := Path(m.__file__)).name == "__init__.py" + else p.resolve() + ) for m in exclude_modules + if m.__file__ ] # Specifically exclude the reflex cli module. if reflex_bin := shutil.which(b"reflex"): diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 214c845f8..2f23d8441 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -4,9 +4,8 @@ from __future__ import annotations import inspect import json -import os import re -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from reflex import constants from reflex.constants.state import FRONTEND_EVENT_STATE @@ -136,22 +135,6 @@ def wrap( return f"{open * num}{text}{close * num}" -def indent(text: str, indent_level: int = 2) -> str: - """Indent the given text by the given indent level. - - Args: - text: The text to indent. - indent_level: The indent level. - - Returns: - The indented text. - """ - lines = text.splitlines() - if len(lines) < 2: - return text - return os.linesep.join(f"{' ' * indent_level}{line}" for line in lines) + os.linesep - - def to_snake_case(text: str) -> str: """Convert a string to snake case. @@ -239,80 +222,6 @@ def make_default_page_title(app_name: str, route: str) -> str: return to_title_case(title) -def _escape_js_string(string: str) -> str: - """Escape the string for use as a JS string literal. - - Args: - string: The string to escape. - - Returns: - The escaped string. - """ - - # TODO: we may need to re-vist this logic after new Var API is implemented. - def escape_outside_segments(segment: str): - """Escape backticks in segments outside of `${}`. - - Args: - segment: The part of the string to escape. - - Returns: - The escaped or unescaped segment. - """ - if segment.startswith("${") and segment.endswith("}"): - # Return the `${}` segment unchanged - return segment - else: - # Escape backticks in the segment - segment = segment.replace(r"\`", "`") - segment = segment.replace("`", r"\`") - return segment - - # Split the string into parts, keeping the `${}` segments - parts = re.split(r"(\$\{.*?\})", string) - escaped_parts = [escape_outside_segments(part) for part in parts] - escaped_string = "".join(escaped_parts) - return escaped_string - - -def _wrap_js_string(string: str) -> str: - """Wrap string so it looks like {`string`}. - - Args: - string: The string to wrap. - - Returns: - The wrapped string. - """ - string = wrap(string, "`") - string = wrap(string, "{") - return string - - -def format_string(string: str) -> str: - """Format the given string as a JS string literal.. - - Args: - string: The string to format. - - Returns: - The formatted string. - """ - return _wrap_js_string(_escape_js_string(string)) - - -def format_var(var: Var) -> str: - """Format the given Var as a javascript value. - - Args: - var: The Var to format. - - Returns: - The formatted Var. - """ - return str(var) - - def format_route(route: str, format_case: bool = True) -> str: """Format the given route. @@ -335,40 +244,6 @@ def format_route(route: str, format_case: bool = True) -> str: return route -def format_match( - cond: str | Var, - match_cases: List[List[Var]], - default: Var, -) -> str: - """Format a match expression whose return type is a Var. - - Args: - cond: The condition. - match_cases: The list of cases to match. - default: The default case. - - Returns: - The formatted match expression - - """ - switch_code = f"(() => {{ switch (JSON.stringify({cond})) {{" - - for case in match_cases: - conditions = case[:-1] - return_value = case[-1] - - case_conditions = " ".join( - [f"case JSON.stringify({condition!s}):" for condition in conditions] - ) - case_code = f"{case_conditions} return ({return_value!s}); break;" - switch_code += case_code - - switch_code += f"default: return ({default!s}); break;" - switch_code += "};})()" - - return switch_code - - def format_prop( prop: Union[Var, EventChain, ComponentStyle, str], ) -> Union[int, float, str]: diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 93c0e89d3..5386ea0a4 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -16,7 +16,7 @@ from itertools import chain from multiprocessing import Pool, cpu_count from pathlib import Path from types import ModuleType, SimpleNamespace -from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin +from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin from reflex.components.component import Component from reflex.utils import types as rx_types @@ -230,7 +230,9 @@ def _generate_imports( """ return [ *[ - ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values]) # pyright: ignore [reportCallIssue] + ast.ImportFrom( + module=name, names=[ast.alias(name=val) for val in values], level=0 + ) for name, values in DEFAULT_IMPORTS.items() ], ast.Import([ast.alias("reflex")]), @@ -429,18 +431,15 @@ def type_to_ast(typ: Any, cls: type) -> ast.AST: return ast.Name(id=base_name) # Convert all type arguments recursively - arg_nodes = [type_to_ast(arg, cls) for arg in args] + arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args]) # Special case for single-argument types (like List[T] or Optional[T]) if len(arg_nodes) == 1: slice_value = arg_nodes[0] else: - slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load()) # pyright: ignore [reportArgumentType] - + slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load()) return ast.Subscript( - value=ast.Name(id=base_name), - slice=ast.Index(value=slice_value), # pyright: ignore [reportArgumentType] - ctx=ast.Load(), + value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load() ) @@ -635,7 +634,7 @@ def _generate_component_create_functiondef( ), ), ast.Expr( - value=ast.Constant(value=Ellipsis), + value=ast.Constant(...), ), ], decorator_list=[ @@ -646,8 +645,8 @@ def _generate_component_create_functiondef( else [ast.Name(id="classmethod")] ), ], - lineno=node.lineno if node is not None else None, # pyright: ignore [reportArgumentType] returns=ast.Constant(value=clz.__name__), + lineno=node.lineno if node is not None else None, # pyright: ignore[reportArgumentType] ) return definition @@ -695,7 +694,6 @@ def _generate_staticmethod_call_functiondef( ), ], decorator_list=[ast.Name(id="staticmethod")], - lineno=node.lineno if node is not None else None, # pyright: ignore [reportArgumentType] returns=ast.Constant( value=_get_type_hint( typing.get_type_hints(clz.__call__).get("return", None), @@ -703,6 +701,7 @@ def _generate_staticmethod_call_functiondef( is_optional=False, ) ), + lineno=node.lineno if node is not None else None, # pyright: ignore[reportArgumentType] ) return definition @@ -723,6 +722,9 @@ def _generate_namespace_call_functiondef( Returns: The create functiondef node for the ast. + + Raises: + TypeError: If the __call__ method does not have a __func__. """ # add the imports needed by get_type_hint later type_hint_globals.update( @@ -737,7 +739,12 @@ def _generate_namespace_call_functiondef( # Determine which class is wrapped by the namespace __call__ method component_clz = clz.__call__.__self__ - if clz.__call__.__func__.__name__ != "create": # pyright: ignore [reportFunctionMemberAccess] + func = getattr(clz.__call__, "__func__", None) + + if func is None: + raise TypeError(f"__call__ method on {clz_name} does not have a __func__") + + if func.__name__ != "create": return None definition = _generate_component_create_functiondef( @@ -920,7 +927,7 @@ class StubGenerator(ast.NodeTransformer): node.body.append(call_definition) if not node.body: # We should never return an empty body. - node.body.append(ast.Expr(value=ast.Constant(value=Ellipsis))) + node.body.append(ast.Expr(value=ast.Constant(...))) self.current_class = None return node @@ -947,9 +954,9 @@ class StubGenerator(ast.NodeTransformer): if node.name.startswith("_") and node.name != "__call__": return None # remove private methods - if node.body[-1] != ast.Expr(value=ast.Constant(value=Ellipsis)): + if node.body[-1] != ast.Expr(value=ast.Constant(...)): # Blank out the function body for public functions. - node.body = [ast.Expr(value=ast.Constant(value=Ellipsis))] + node.body = [ast.Expr(value=ast.Constant(...))] return node def visit_Assign(self, node: ast.Assign) -> ast.Assign | None: diff --git a/reflex/utils/types.py b/reflex/utils/types.py index b432319e0..6090893d7 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -7,6 +7,7 @@ import dataclasses import inspect import sys import types +from collections import abc from functools import cached_property, lru_cache, wraps from typing import ( TYPE_CHECKING, @@ -23,6 +24,7 @@ from typing import ( Sequence, Tuple, Type, + TypeVar, Union, _GenericAlias, # pyright: ignore [reportAttributeAccessIssue] get_args, @@ -31,6 +33,7 @@ from typing import ( from typing import get_origin as get_origin_og import sqlalchemy +import typing_extensions from typing_extensions import is_typeddict import reflex @@ -68,13 +71,13 @@ else: # Potential GenericAlias types for isinstance checks. -GenericAliasTypes = [_GenericAlias] +_GenericAliasTypes: list[type] = [_GenericAlias] with contextlib.suppress(ImportError): # For newer versions of Python. from types import GenericAlias - GenericAliasTypes.append(GenericAlias) + _GenericAliasTypes.append(GenericAlias) with contextlib.suppress(ImportError): # For older versions of Python. @@ -82,9 +85,9 @@ with contextlib.suppress(ImportError): _SpecialGenericAlias, # pyright: ignore [reportAttributeAccessIssue] ) - GenericAliasTypes.append(_SpecialGenericAlias) + _GenericAliasTypes.append(_SpecialGenericAlias) -GenericAliasTypes = tuple(GenericAliasTypes) +GenericAliasTypes = tuple(_GenericAliasTypes) # Potential Union types for isinstance checks (UnionType added in py3.10). UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,) @@ -183,7 +186,7 @@ def is_generic_alias(cls: GenericType) -> bool: return isinstance(cls, GenericAliasTypes) # pyright: ignore [reportArgumentType] -def unionize(*args: GenericType) -> Type: +def unionize(*args: GenericType) -> GenericType: """Unionize the types. Args: @@ -417,7 +420,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None @lru_cache() -def get_base_class(cls: GenericType) -> Type: +def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]: """Get the base class of a class. Args: @@ -437,7 +440,14 @@ def get_base_class(cls: GenericType) -> Type: return type(get_args(cls)[0]) if is_union(cls): - return tuple(get_base_class(arg) for arg in get_args(cls)) # pyright: ignore [reportReturnType] + base_classes = [] + for arg in get_args(cls): + sub_base_classes = get_base_class(arg) + if isinstance(sub_base_classes, tuple): + base_classes.extend(sub_base_classes) + else: + base_classes.append(sub_base_classes) + return tuple(base_classes) return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls @@ -847,18 +857,22 @@ StateBases = get_base_class(StateVar) StateIterBases = get_base_class(StateIterVar) -def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]): - """Check if a class is a subclass of another class. Returns False if internal error occurs. +def safe_issubclass(cls: Any, class_or_tuple: Any, /) -> bool: + """Check if a class is a subclass of another class or a tuple of classes. Args: cls: The class to check. - cls_check: The class to check against. + class_or_tuple: The class or tuple of classes to check against. Returns: - Whether the class is a subclass of the other class. + Whether the class is a subclass of the other class or tuple of classes. """ + if cls is class_or_tuple or ( + isinstance(class_or_tuple, tuple) and cls in class_or_tuple + ): + return True try: - return issubclass(cls, cls_check) + return issubclass(cls, class_or_tuple) except TypeError: return False @@ -873,17 +887,32 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo Returns: Whether the type hint is a subclass of the other type hint. """ + if isinstance(possible_subclass, Sequence) and isinstance( + possible_superclass, Sequence + ): + return all( + typehint_issubclass(subclass, superclass) + for subclass, superclass in zip( + possible_subclass, possible_superclass, strict=False + ) + ) + if possible_subclass is possible_superclass: + return True if possible_superclass is Any: return True if possible_subclass is Any: return False + if isinstance( + possible_subclass, (TypeVar, typing_extensions.TypeVar) + ) or isinstance(possible_superclass, (TypeVar, typing_extensions.TypeVar)): + return True provided_type_origin = get_origin(possible_subclass) accepted_type_origin = get_origin(possible_superclass) if provided_type_origin is None and accepted_type_origin is None: # In this case, we are dealing with a non-generic type, so we can use issubclass - return issubclass(possible_subclass, possible_superclass) + return safe_issubclass(possible_subclass, possible_superclass) # Remove this check when Python 3.10 is the minimum supported version if hasattr(types, "UnionType"): @@ -898,24 +927,64 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo provided_args = get_args(possible_subclass) accepted_args = get_args(possible_superclass) - if accepted_type_origin is Union: - if provided_type_origin is not Union: - return any( - typehint_issubclass(possible_subclass, accepted_arg) - for accepted_arg in accepted_args - ) + if provided_type_origin is Union: return all( - any( - typehint_issubclass(provided_arg, accepted_arg) - for accepted_arg in accepted_args - ) + typehint_issubclass(provided_arg, possible_superclass) for provided_arg in provided_args ) + if accepted_type_origin is Union: + return any( + typehint_issubclass(possible_subclass, accepted_arg) + for accepted_arg in accepted_args + ) + + # Check specifically for Sequence and Iterable + if (accepted_type_origin or possible_superclass) in ( + Sequence, + abc.Sequence, + Iterable, + abc.Iterable, + ): + iterable_type = accepted_args[0] if accepted_args else Any + + if provided_type_origin is None: + if not safe_issubclass( + possible_subclass, (accepted_type_origin or possible_superclass) + ): + return False + + if safe_issubclass(possible_subclass, str) and not isinstance( + iterable_type, TypeVar + ): + return typehint_issubclass(str, iterable_type) + return True + + if not safe_issubclass( + provided_type_origin, (accepted_type_origin or possible_superclass) + ): + return False + + if not isinstance(iterable_type, (TypeVar, typing_extensions.TypeVar)): + if provided_type_origin in (list, tuple, set): + # Ensure all specific types are compatible with accepted types + return all( + typehint_issubclass(provided_arg, iterable_type) + for provided_arg in provided_args + if provided_arg is not ... # Ellipsis in Tuples + ) + if possible_subclass is dict: + # Ensure all specific types are compatible with accepted types + return all( + typehint_issubclass(provided_arg, iterable_type) + for provided_arg in provided_args[:1] + ) + return True + # Check if the origin of both types is the same (e.g., list for List[int]) - # This probably should be issubclass instead of == - if (provided_type_origin or possible_subclass) != ( - accepted_type_origin or possible_superclass + if not safe_issubclass( + provided_type_origin or possible_subclass, + accepted_type_origin or possible_superclass, ): return False @@ -927,5 +996,21 @@ def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> boo for provided_arg, accepted_arg in zip( provided_args, accepted_args, strict=False ) - if accepted_arg is not Any + if accepted_arg is not Any and not isinstance(accepted_arg, TypeVar) ) + + +def safe_typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: + """Check if a type hint is a subclass of another type hint. + + Args: + possible_subclass: The type hint to check. + possible_superclass: The type hint to check against. + + Returns: + Whether the type hint is a subclass of the other type hint. + """ + try: + return typehint_issubclass(possible_subclass, possible_superclass) + except Exception: + return False diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 593c60f3e..d10e0bfbf 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -13,7 +13,7 @@ import re import string import uuid import warnings -from types import CodeType, FunctionType +from types import CodeType, EllipsisType, FunctionType from typing import ( TYPE_CHECKING, Any, @@ -41,7 +41,14 @@ from typing import ( ) from sqlalchemy.orm import DeclarativeBase -from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override +from typing_extensions import ( + ParamSpec, + Protocol, + TypeGuard, + deprecated, + get_type_hints, + override, +) from reflex import constants from reflex.base import Base @@ -50,7 +57,6 @@ from reflex.utils import console, exceptions, imports, serializers, types from reflex.utils.exceptions import ( ComputedVarSignatureError, UntypedComputedVarError, - VarAttributeError, VarDependencyError, VarTypeError, ) @@ -69,24 +75,80 @@ from reflex.utils.types import ( get_origin, has_args, safe_issubclass, + typehint_issubclass, unionize, ) if TYPE_CHECKING: + from reflex.components.component import BaseComponent from reflex.state import BaseState + from .function import ArgsFunctionOperation from .number import BooleanVar, LiteralBooleanVar, LiteralNumberVar, NumberVar from .object import LiteralObjectVar, ObjectVar from .sequence import ArrayVar, LiteralArrayVar, LiteralStringVar, StringVar VAR_TYPE = TypeVar("VAR_TYPE", covariant=True) +VALUE = TypeVar("VALUE") +INT_OR_FLOAT = TypeVar("INT_OR_FLOAT", int, float) +FAKE_VAR_TYPE = TypeVar("FAKE_VAR_TYPE") OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE") STRING_T = TypeVar("STRING_T", bound=str) SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence) warnings.filterwarnings("ignore", message="fields may not start with an underscore") +P = ParamSpec("P") +R = TypeVar("R") + + +class ReflexCallable(Protocol[P, R]): + """Protocol for a callable.""" + + __call__: Callable[P, R] + + +ReflexCallableParams = Union[EllipsisType, Tuple[GenericType, ...]] + + +def unwrap_reflex_callalbe( + callable_type: GenericType, +) -> Tuple[ReflexCallableParams, GenericType]: + """Unwrap the ReflexCallable type. + + Args: + callable_type: The ReflexCallable type to unwrap. + + Returns: + The unwrapped ReflexCallable type. + """ + if callable_type is ReflexCallable: + return Ellipsis, Any + + origin = get_origin(callable_type) + + if origin is not ReflexCallable: + if origin in types.UnionTypes: + args = get_args(callable_type) + params: List[ReflexCallableParams] = [] + return_types: List[GenericType] = [] + for arg in args: + param, return_type = unwrap_reflex_callalbe(arg) + if param not in params: + params.append(param) + return_types.append(return_type) + return ( + Ellipsis if len(params) > 1 else params[0], + unionize(*return_types), + ) + return Ellipsis, Any + + args = get_args(callable_type) + if not args or len(args) != 2: + return Ellipsis, Any + return args + @dataclasses.dataclass( eq=False, @@ -98,6 +160,7 @@ class VarSubclassEntry: var_subclass: Type[Var] to_var_subclass: Type[ToOperation] python_types: Tuple[GenericType, ...] + is_subclass: Callable[[GenericType], bool] | None _var_subclasses: List[VarSubclassEntry] = [] @@ -123,6 +186,9 @@ class VarData: # Hooks that need to be present in the component to render this var hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + # Components that need to be present in the component to render this var + components: Tuple[BaseComponent, ...] = dataclasses.field(default_factory=tuple) + # Dependencies of the var deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) @@ -135,6 +201,7 @@ class VarData: field_name: str = "", imports: ImportDict | ParsedImportDict | None = None, hooks: Mapping[str, VarData | None] | Sequence[str] | str | None = None, + components: Iterable[BaseComponent] | None = None, deps: list[Var] | None = None, position: Hooks.HookPosition | None = None, ): @@ -145,6 +212,7 @@ class VarData: field_name: The name of the field in the state. imports: Imports needed to render this var. hooks: Hooks that need to be present in the component to render this var. + components: Components that need to be present in the component to render this var. deps: Dependencies of the var for useCallback. position: Position of the hook in the component. """ @@ -159,6 +227,9 @@ class VarData: object.__setattr__(self, "field_name", field_name) object.__setattr__(self, "imports", immutable_imports) object.__setattr__(self, "hooks", tuple(hooks or {})) + object.__setattr__( + self, "components", tuple(components) if components is not None else () + ) object.__setattr__(self, "deps", tuple(deps or [])) object.__setattr__(self, "position", position or None) @@ -183,16 +254,11 @@ class VarData: def merge(*all: VarData | None) -> VarData | None: """Merge multiple var data objects. - Args: - *all: The var data objects to merge. - - Raises: - ReflexError: If trying to merge VarData with different positions. - Returns: The merged var data object. - # noqa: DAR102 *all + Raises: + ReflexError: If the positions of the var data objects are different. """ all_var_datas = list(filter(None, all)) @@ -230,6 +296,11 @@ class VarData: if var_data.position is not None } ) + + components = tuple( + component for var_data in all_var_datas for component in var_data.components + ) + if positions: if len(positions) > 1: raise exceptions.ReflexError( @@ -239,17 +310,15 @@ class VarData: else: position = None - if state or _imports or hooks or field_name or deps or position: - return VarData( - state=state, - field_name=field_name, - imports=_imports, - hooks=hooks, - deps=deps, - position=position, - ) - - return None + return VarData( + state=state, + field_name=field_name, + imports=_imports, + hooks=hooks, + deps=deps, + position=position, + components=components, + ) def __bool__(self) -> bool: """Check if the var data is non-empty. @@ -257,14 +326,7 @@ class VarData: Returns: True if any field is set to a non-default value. """ - return bool( - self.state - or self.imports - or self.hooks - or self.field_name - or self.deps - or self.position - ) + return any(getattr(self, field.name) for field in dataclasses.fields(self)) @classmethod def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData: @@ -428,6 +490,7 @@ class Var(Generic[VAR_TYPE]): cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset(), default_type: GenericType = types.Unset(), + is_subclass: Callable[[GenericType], bool] | types.Unset = types.Unset(), **kwargs, ): """Initialize the subclass. @@ -435,11 +498,12 @@ class Var(Generic[VAR_TYPE]): Args: python_types: The python types that the var represents. default_type: The default type of the var. Defaults to the first python type. + is_subclass: A function to check if a type is a subclass of the var. **kwargs: Additional keyword arguments. """ super().__init_subclass__(**kwargs) - if python_types or default_type: + if python_types or default_type or is_subclass: python_types = ( (python_types if isinstance(python_types, tuple) else (python_types,)) if python_types @@ -469,7 +533,14 @@ class Var(Generic[VAR_TYPE]): ) ToVarOperation.__name__ = new_to_var_operation_name - _var_subclasses.append(VarSubclassEntry(cls, ToVarOperation, python_types)) + _var_subclasses.append( + VarSubclassEntry( + cls, + ToVarOperation, + python_types, + is_subclass if not isinstance(is_subclass, types.Unset) else None, + ) + ) def __post_init__(self): """Post-initialize the var.""" @@ -478,9 +549,11 @@ class Var(Generic[VAR_TYPE]): if _var_data or _js_expr != self._js_expr: self.__init__( - _js_expr=_js_expr, - _var_type=self._var_type, - _var_data=VarData.merge(self._var_data, _var_data), + **{ + **dataclasses.asdict(self), + "_js_expr": _js_expr, + "_var_data": VarData.merge(self._var_data, _var_data), + } ) def __hash__(self) -> int: @@ -709,19 +782,22 @@ class Var(Generic[VAR_TYPE]): return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}" @overload - def to(self, output: Type[str]) -> StringVar: ... + def to(self, output: Type[bool]) -> BooleanVar: ... # pyright: ignore [reportOverlappingOverload] @overload - def to(self, output: Type[bool]) -> BooleanVar: ... + def to(self, output: Type[int]) -> NumberVar[int]: ... @overload - def to(self, output: type[int] | type[float]) -> NumberVar: ... + def to(self, output: type[float]) -> NumberVar[float]: ... + + @overload + def to(self, output: Type[str]) -> StringVar: ... # pyright: ignore [reportOverlappingOverload] @overload def to( self, - output: type[list] | type[tuple] | type[set], - ) -> ArrayVar: ... + output: type[Sequence[VALUE]] | type[set[VALUE]], + ) -> ArrayVar[Sequence[VALUE]]: ... @overload def to( @@ -769,8 +845,11 @@ class Var(Generic[VAR_TYPE]): # If the first argument is a python type, we map it to the corresponding Var type. for var_subclass in _var_subclasses[::-1]: - if fixed_output_type in var_subclass.python_types or safe_issubclass( - fixed_output_type, var_subclass.python_types + if ( + var_subclass.python_types + and safe_issubclass(fixed_output_type, var_subclass.python_types) + ) or ( + var_subclass.is_subclass and var_subclass.is_subclass(fixed_output_type) ): return self.to(var_subclass.var_subclass, output) @@ -810,17 +889,29 @@ class Var(Generic[VAR_TYPE]): return self + # We use `NoReturn` here to catch `Var[Any]` and `Var[Unknown]` cases first. @overload - def guess_type(self: Var[NoReturn]) -> Var[Any]: ... # pyright: ignore [reportOverlappingOverload] - - @overload - def guess_type(self: Var[str]) -> StringVar: ... + def guess_type(self: Var[NoReturn]) -> Var: ... # pyright: ignore [reportOverlappingOverload] @overload def guess_type(self: Var[bool]) -> BooleanVar: ... @overload - def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ... + def guess_type(self: Var[INT_OR_FLOAT]) -> NumberVar[INT_OR_FLOAT]: ... + + @overload + def guess_type(self: Var[str]) -> StringVar: ... # pyright: ignore [reportOverlappingOverload] + + @overload + def guess_type(self: Var[Sequence[VALUE]]) -> ArrayVar[Sequence[VALUE]]: ... + + @overload + def guess_type(self: Var[Set[VALUE]]) -> ArrayVar[Set[VALUE]]: ... + + @overload + def guess_type( + self: Var[Dict[VALUE, OTHER_VAR_TYPE]], + ) -> ObjectVar[Dict[VALUE, OTHER_VAR_TYPE]]: ... @overload def guess_type(self: Var[BASE_TYPE]) -> ObjectVar[BASE_TYPE]: ... @@ -837,12 +928,13 @@ class Var(Generic[VAR_TYPE]): Raises: TypeError: If the type is not supported for guessing. """ - from .number import NumberVar from .object import ObjectVar var_type = self._var_type + if var_type is None: return self.to(None) + if types.is_optional(var_type): var_type = types.get_args(var_type)[0] @@ -854,10 +946,15 @@ class Var(Generic[VAR_TYPE]): if fixed_type in types.UnionTypes: inner_types = get_args(var_type) - if all( - inspect.isclass(t) and issubclass(t, (int, float)) for t in inner_types - ): - return self.to(NumberVar, self._var_type) + for var_subclass in _var_subclasses: + if all( + ( + safe_issubclass(t, var_subclass.python_types) + or (var_subclass.is_subclass and var_subclass.is_subclass(t)) + ) + for t in inner_types + ): + return self.to(var_subclass.var_subclass, self._var_type) if can_use_in_object_var(var_type): return self.to(ObjectVar, self._var_type) @@ -875,7 +972,9 @@ class Var(Generic[VAR_TYPE]): return self.to(None) for var_subclass in _var_subclasses[::-1]: - if issubclass(fixed_type, var_subclass.python_types): + if safe_issubclass(fixed_type, var_subclass.python_types) or ( + var_subclass.is_subclass and var_subclass.is_subclass(fixed_type) + ): return self.to(var_subclass.var_subclass, self._var_type) if can_use_in_object_var(fixed_type): @@ -1009,7 +1108,7 @@ class Var(Generic[VAR_TYPE]): """ from .number import equal_operation - return equal_operation(self, other) + return equal_operation(self, other).guess_type() def __ne__(self, other: Var | Any) -> BooleanVar: """Check if the current object is not equal to the given object. @@ -1022,7 +1121,7 @@ class Var(Generic[VAR_TYPE]): """ from .number import equal_operation - return ~equal_operation(self, other) + return (~equal_operation(self, other)).guess_type() def bool(self) -> BooleanVar: """Convert the var to a boolean. @@ -1032,7 +1131,7 @@ class Var(Generic[VAR_TYPE]): """ from .number import boolify - return boolify(self) + return boolify(self) # pyright: ignore [reportReturnType] def __and__(self, other: Var | Any) -> Var: """Perform a logical AND operation on the current instance and another variable. @@ -1084,7 +1183,7 @@ class Var(Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical NOT operation. """ - return ~self.bool() + return (~self.bool()).guess_type() def to_string(self, use_json: bool = True) -> StringVar: """Convert the var to a string. @@ -1197,7 +1296,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def range(cls, stop: int | NumberVar, /) -> ArrayVar[List[int]]: ... + def range(cls, stop: int | NumberVar, /) -> ArrayVar[Sequence[int]]: ... @overload @classmethod @@ -1207,15 +1306,16 @@ class Var(Generic[VAR_TYPE]): end: int | NumberVar, step: int | NumberVar = 1, /, - ) -> ArrayVar[List[int]]: ... + ) -> ArrayVar[Sequence[int]]: ... @classmethod def range( cls, - first_endpoint: int | NumberVar, - second_endpoint: int | NumberVar | None = None, - step: int | NumberVar | None = None, - ) -> ArrayVar[List[int]]: + first_endpoint: int | Var[int], + second_endpoint: int | Var[int] | None = None, + step: int | Var[int] | None = None, + /, + ) -> ArrayVar[Sequence[int]]: """Create a range of numbers. Args: @@ -1228,42 +1328,13 @@ class Var(Generic[VAR_TYPE]): """ from .sequence import ArrayVar + if step is None: + return ArrayVar.range(first_endpoint, second_endpoint) + return ArrayVar.range(first_endpoint, second_endpoint, step) if not TYPE_CHECKING: - def __getattr__(self, name: str): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Raises: - VarAttributeError: If the attribute does not exist. - UntypedVarError: If the var type is Any. - TypeError: If the var type is Any. - - # noqa: DAR101 self - """ - if name.startswith("_"): - raise VarAttributeError(f"Attribute {name} not found.") - - if name == "contains": - raise TypeError( - f"Var of type {self._var_type} does not support contains check." - ) - if name == "reverse": - raise TypeError("Cannot reverse non-list var.") - - if self._var_type is Any: - raise exceptions.UntypedVarError( - f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`." - ) - - raise VarAttributeError( - f"The State var has no attribute '{name}' or may have been annotated wrongly.", - ) - def __bool__(self) -> bool: """Raise exception if using Var in a boolean context. @@ -1308,6 +1379,28 @@ VAR_SUBCLASS = TypeVar("VAR_SUBCLASS", bound=Var) VAR_INSIDE = TypeVar("VAR_INSIDE") +class VarWithDefault(Var[VAR_TYPE]): + """Annotate an optional argument.""" + + def __init__(self, default_value: VAR_TYPE): + """Initialize the default value. + + Args: + default_value: The default value. + """ + super().__init__("") + self._default = default_value + + @property + def default(self) -> Var[VAR_TYPE]: + """Get the default value. + + Returns: + The default value. + """ + return Var.create(self._default) + + class ToOperation: """A var operation that converts a var to another type.""" @@ -1449,9 +1542,6 @@ class LiteralVar(Var): Raises: TypeError: If the value is not a supported type for LiteralVar. """ - from .object import LiteralObjectVar - from .sequence import ArrayVar, LiteralStringVar - if isinstance(value, Var): if _var_data is None: return value @@ -1464,6 +1554,9 @@ class LiteralVar(Var): from reflex.event import EventHandler from reflex.utils.format import get_event_handler_parts + from .object import LiteralObjectVar + from .sequence import LiteralStringVar + if isinstance(value, EventHandler): return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value)))) @@ -1562,72 +1655,189 @@ def get_python_literal(value: Union[LiteralVar, Any]) -> Any | None: return value +def validate_arg(type_hint: GenericType) -> Callable[[Any], str | None]: + """Create a validator for an argument. + + Args: + type_hint: The type hint of the argument. + + Returns: + The validator. + """ + + def validate(value: Any): + if isinstance(value, LiteralVar): + if not _isinstance(value._var_value, type_hint): + return f"Expected {type_hint} but got {value._var_value} of type {type(value._var_value)}." + elif isinstance(value, Var): + if not typehint_issubclass(value._var_type, type_hint): + return f"Expected {type_hint} but got {value._var_type}." + else: + if not _isinstance(value, type_hint): + return f"Expected {type_hint} but got {value} of type {type(value)}." + + return validate + + P = ParamSpec("P") T = TypeVar("T") +V1 = TypeVar("V1") +V2 = TypeVar("V2") +V3 = TypeVar("V3") +V4 = TypeVar("V4") +V5 = TypeVar("V5") -# NoReturn is used to match CustomVarOperationReturn with no type hint. -@overload -def var_operation( # pyright: ignore [reportOverlappingOverload] - func: Callable[P, CustomVarOperationReturn[NoReturn]], -) -> Callable[P, Var]: ... +class TypeComputer(Protocol): + """A protocol for type computers.""" + + def __call__(self, *args: Var) -> Tuple[GenericType, Union[TypeComputer, None]]: + """Compute the type of the operation. + + Args: + *args: The arguments to compute the type of. + """ + ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[bool]], -) -> Callable[P, BooleanVar]: ... - - -NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float]) + func: Callable[[Var[V1], Var[V2], Var[V3]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[NUMBER_T]], -) -> Callable[P, NumberVar[NUMBER_T]]: ... + func: Callable[[Var[V1], Var[V2], VarWithDefault[V3]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, VarWithDefault[V3]], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[str]], -) -> Callable[P, StringVar]: ... - - -LIST_T = TypeVar("LIST_T", bound=Sequence) + func: Callable[ + [ + Var[V1], + VarWithDefault[V2], + VarWithDefault[V3], + ], + CustomVarOperationReturn[T], + ], +) -> ArgsFunctionOperation[ + ReflexCallable[ + [ + V1, + VarWithDefault[V2], + VarWithDefault[V3], + ], + T, + ] +]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[LIST_T]], -) -> Callable[P, ArrayVar[LIST_T]]: ... - - -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping) + func: Callable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + ], + CustomVarOperationReturn[T], + ], +) -> ArgsFunctionOperation[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V1], + VarWithDefault[V1], + ], + T, + ] +]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]], -) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ... + func: Callable[[Var[V1], Var[V2]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1, V2], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[T]], -) -> Callable[P, Var[T]]: ... + func: Callable[ + [ + Var[V1], + VarWithDefault[V2], + ], + CustomVarOperationReturn[T], + ], +) -> ArgsFunctionOperation[ + ReflexCallable[ + [ + V1, + VarWithDefault[V2], + ], + T, + ] +]: ... -def var_operation( # pyright: ignore [reportInconsistentOverload] - func: Callable[P, CustomVarOperationReturn[T]], -) -> Callable[P, Var[T]]: +@overload +def var_operation( + func: Callable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + ], + CustomVarOperationReturn[T], + ], +) -> ArgsFunctionOperation[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + ], + T, + ] +]: ... + + +@overload +def var_operation( + func: Callable[[Var[V1]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1], T]]: ... + + +@overload +def var_operation( + func: Callable[ + [VarWithDefault[V1]], + CustomVarOperationReturn[T], + ], +) -> ArgsFunctionOperation[ + ReflexCallable[ + [VarWithDefault[V1]], + T, + ] +]: ... + + +@overload +def var_operation( + func: Callable[[], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[], T]]: ... + + +def var_operation( + func: Callable[..., CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[..., T]]: """Decorator for creating a var operation. Example: ```python @var_operation - def add(a: NumberVar, b: NumberVar): - return custom_var_operation(f"{a} + {b}") + def add(a: Var[int], b: Var[int]): + return var_operation_return(f"{a} + {b}") ``` Args: @@ -1635,27 +1845,93 @@ def var_operation( # pyright: ignore [reportInconsistentOverload] Returns: The decorated function. + + Raises: + TypeError: If the function has keyword-only arguments or arguments without Var type hints. """ + from .function import ArgsFunctionOperation, ReflexCallable - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]: - func_args = list(inspect.signature(func).parameters) - args_vars = { - func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg) - for i, arg in enumerate(args) - } - kwargs_vars = { - key: LiteralVar.create(value) if not isinstance(value, Var) else value - for key, value in kwargs.items() - } + func_name = func.__name__ - return CustomVarOperation.create( - name=func.__name__, - args=tuple(list(args_vars.items()) + list(kwargs_vars.items())), - return_var=func(*args_vars.values(), **kwargs_vars), # pyright: ignore [reportCallIssue, reportReturnType] - ).guess_type() + func_arg_spec = inspect.getfullargspec(func) + func_signature = inspect.signature(func) - return wrapper + if func_arg_spec.kwonlyargs: + raise TypeError(f"Function {func_name} cannot have keyword-only arguments.") + if func_arg_spec.varargs: + raise TypeError(f"Function {func_name} cannot have variable arguments.") + + arg_names = func_arg_spec.args + + arg_default_values: Sequence[inspect.Parameter.empty | VarWithDefault] = tuple( + ( + default_value + if isinstance( + (default_value := func_signature.parameters[arg_name].default), + VarWithDefault, + ) + else inspect.Parameter.empty() + ) + for arg_name in arg_names + ) + + type_hints = get_type_hints(func) + + if not all( + (get_origin((type_hint := type_hints.get(arg_name, Any))) or type_hint) + in (Var, VarWithDefault) + and len(get_args(type_hint)) <= 1 + for arg_name in arg_names + ): + raise TypeError( + f"Function {func_name} must have type hints of the form `Var[Type]`." + ) + + args_with_type_hints = tuple( + (arg_name, (args[0] if (args := get_args(type_hints[arg_name])) else Any)) + for arg_name in arg_names + ) + + arg_vars = tuple( + ( + Var("_" + arg_name, _var_type=arg_python_type) + if not isinstance(arg_python_type, TypeVar) + else Var("_" + arg_name) + ) + for arg_name, arg_python_type in args_with_type_hints + ) + + custom_operation_return = func(*arg_vars) + + def simplified_operation(*args): + return func(*args)._js_expr + + args_operation = ArgsFunctionOperation.create( + tuple(map(str, arg_vars)), + custom_operation_return, + default_values=arg_default_values, + validators=tuple( + validate_arg(arg_type) + if not isinstance(arg_type, TypeVar) + else validate_arg(arg_type.__bound__ or Any) + for _, arg_type in args_with_type_hints + ), + function_name=func_name, + type_computer=custom_operation_return._type_computer, + _raw_js_function=custom_operation_return._raw_js_function, + _original_var_operation=simplified_operation, + _var_type=ReflexCallable[ + tuple( # pyright: ignore [reportInvalidTypeArguments] + arg_python_type + if isinstance(arg_default_values[i], inspect.Parameter) + else VarWithDefault[arg_python_type] + for i, (_, arg_python_type) in enumerate(args_with_type_hints) + ), + custom_operation_return._var_type, + ], + ) + + return args_operation def figure_out_type(value: Any) -> types.GenericType: @@ -1843,128 +2119,8 @@ class CachedVarOperation: ) -def and_operation(a: Var | Any, b: Var | Any) -> Var: - """Perform a logical AND operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical AND operation. - """ - return _and_operation(a, b) - - -@var_operation -def _and_operation(a: Var, b: Var): - """Perform a logical AND operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical AND operation. - """ - return var_operation_return( - js_expression=f"({a} && {b})", - var_type=unionize(a._var_type, b._var_type), - ) - - -def or_operation(a: Var | Any, b: Var | Any) -> Var: - """Perform a logical OR operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical OR operation. - """ - return _or_operation(a, b) - - -@var_operation -def _or_operation(a: Var, b: Var): - """Perform a logical OR operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical OR operation. - """ - return var_operation_return( - js_expression=f"({a} || {b})", - var_type=unionize(a._var_type, b._var_type), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - slots=True, -) -class CallableVar(Var): - """Decorate a Var-returning function to act as both a Var and a function. - - This is used as a compatibility shim for replacing Var objects in the - API with functions that return a family of Var. - """ - - fn: Callable[..., Var] = dataclasses.field( - default_factory=lambda: lambda: Var(_js_expr="undefined") - ) - original_var: Var = dataclasses.field( - default_factory=lambda: Var(_js_expr="undefined") - ) - - def __init__(self, fn: Callable[..., Var]): - """Initialize a CallableVar. - - Args: - fn: The function to decorate (must return Var) - """ - original_var = fn() - super(CallableVar, self).__init__( - _js_expr=original_var._js_expr, - _var_type=original_var._var_type, - _var_data=VarData.merge(original_var._get_all_var_data()), - ) - object.__setattr__(self, "fn", fn) - object.__setattr__(self, "original_var", original_var) - - def __call__(self, *args: Any, **kwargs: Any) -> Var: - """Call the decorated function. - - Args: - *args: The args to pass to the function. - **kwargs: The kwargs to pass to the function. - - Returns: - The Var returned from calling the function. - """ - return self.fn(*args, **kwargs) - - def __hash__(self) -> int: - """Calculate the hash of the object. - - Returns: - The hash of the object. - """ - return hash((type(self).__name__, self.original_var)) - - RETURN_TYPE = TypeVar("RETURN_TYPE") -DICT_KEY = TypeVar("DICT_KEY") -DICT_VAL = TypeVar("DICT_VAL") - -LIST_INSIDE = TypeVar("LIST_INSIDE") - class FakeComputedVarBaseClass(property): """A fake base class for ComputedVar to avoid inheriting from property.""" @@ -2244,38 +2400,10 @@ class ComputedVar(Var[RETURN_TYPE]): @overload def __get__( - self: ComputedVar[list[LIST_INSIDE]], + self: ComputedVar[SEQUENCE_TYPE], instance: None, owner: Type, - ) -> ArrayVar[list[LIST_INSIDE]]: ... - - @overload - def __get__( - self: ComputedVar[tuple[LIST_INSIDE, ...]], - instance: None, - owner: Type, - ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ... - - @overload - def __get__( - self: ComputedVar[BASE_TYPE], - instance: None, - owner: Type, - ) -> ObjectVar[BASE_TYPE]: ... - - @overload - def __get__( - self: ComputedVar[SQLA_TYPE], - instance: None, - owner: Type, - ) -> ObjectVar[SQLA_TYPE]: ... - - if TYPE_CHECKING: - - @overload - def __get__( - self: ComputedVar[DATACLASS_TYPE], instance: None, owner: Any - ) -> ObjectVar[DATACLASS_TYPE]: ... + ) -> ArrayVar[SEQUENCE_TYPE]: ... @overload def __get__(self, instance: None, owner: Type) -> ComputedVar[RETURN_TYPE]: ... @@ -2428,7 +2556,7 @@ class ComputedVar(Var[RETURN_TYPE]): f"field name, got {dep!r}." ) - def _determine_var_type(self) -> Type: + def _determine_var_type(self) -> GenericType: """Get the type of the var. Returns: @@ -2511,17 +2639,10 @@ class AsyncComputedVar(ComputedVar[RETURN_TYPE]): @overload def __get__( - self: AsyncComputedVar[list[LIST_INSIDE]], + self: AsyncComputedVar[SEQUENCE_TYPE], instance: None, owner: Type, - ) -> ArrayVar[list[LIST_INSIDE]]: ... - - @overload - def __get__( - self: AsyncComputedVar[tuple[LIST_INSIDE, ...]], - instance: None, - owner: Type, - ) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ... + ) -> ArrayVar[SEQUENCE_TYPE]: ... @overload def __get__( @@ -2713,22 +2834,34 @@ def computed_var( RETURN = TypeVar("RETURN") +@dataclasses.dataclass( + eq=False, + frozen=True, + slots=True, +) class CustomVarOperationReturn(Var[RETURN]): """Base class for custom var operations.""" + _type_computer: TypeComputer | None = dataclasses.field(default=None) + _raw_js_function: str | None = dataclasses.field(default=None) + @classmethod def create( cls, js_expression: str, _var_type: Type[RETURN] | None = None, + _type_computer: TypeComputer | None = None, _var_data: VarData | None = None, + _raw_js_function: str | None = None, ) -> CustomVarOperationReturn[RETURN]: """Create a CustomVarOperation. Args: js_expression: The JavaScript expression to evaluate. _var_type: The type of the var. + _type_computer: A function to compute the type of the var given the arguments. _var_data: Additional hooks and imports associated with the Var. + _raw_js_function: If provided, it will be used when the operation is being called with all of its arguments at once. Returns: The CustomVarOperation. @@ -2736,29 +2869,37 @@ class CustomVarOperationReturn(Var[RETURN]): return CustomVarOperationReturn( _js_expr=js_expression, _var_type=_var_type or Any, + _type_computer=_type_computer, _var_data=_var_data, + _raw_js_function=_raw_js_function, ) def var_operation_return( js_expression: str, var_type: Type[RETURN] | None = None, + type_computer: Optional[TypeComputer] = None, var_data: VarData | None = None, + _raw_js_function: str | None = None, ) -> CustomVarOperationReturn[RETURN]: """Shortcut for creating a CustomVarOperationReturn. Args: js_expression: The JavaScript expression to evaluate. var_type: The type of the var. + type_computer: A function to compute the type of the var given the arguments. var_data: Additional hooks and imports associated with the Var. + _raw_js_function: If provided, it will be used when the operation is being called with all of its arguments at once. Returns: The CustomVarOperationReturn. """ return CustomVarOperationReturn.create( - js_expression, - var_type, - var_data, + js_expression=js_expression, + _var_type=var_type, + _type_computer=type_computer, + _var_data=var_data, + _raw_js_function=_raw_js_function, ) @@ -3029,8 +3170,12 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: var_datas.append(sub._var_data) elif not isinstance(sub, str): # Recurse into dict values. - if hasattr(sub, "values") and callable(sub.values): - var_datas.extend(_extract_var_data(sub.values())) # pyright: ignore [reportArgumentType] + if ( + (values_fn := getattr(sub, "values", None)) is not None + and callable(values_fn) + and isinstance((values := values_fn()), Iterable) + ): + var_datas.extend(_extract_var_data(values)) # Recurse into iterable values (or dict keys). var_datas.extend(_extract_var_data(sub)) @@ -3039,9 +3184,9 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: var_datas.append(value._var_data) else: # Recurse when value is a dict itself. - values = getattr(value, "values", None) - if callable(values): - var_datas.extend(_extract_var_data(values())) # pyright: ignore [reportArgumentType] + values_fn = getattr(value, "values", None) + if callable(values_fn) and isinstance((values := values_fn()), Iterable): + var_datas.extend(_extract_var_data(values)) return var_datas @@ -3309,22 +3454,23 @@ class Field(Generic[FIELD_TYPE]): def __get__(self: Field[bool], instance: None, owner: Any) -> BooleanVar: ... @overload - def __get__( - self: Field[int] | Field[float] | Field[int | float], instance: None, owner: Any - ) -> NumberVar: ... + def __get__(self: Field[int], instance: None, owner: Any) -> NumberVar[int]: ... @overload - def __get__(self: Field[str], instance: None, owner: Any) -> StringVar: ... + def __get__(self: Field[float], instance: None, owner: Any) -> NumberVar[float]: ... + + @overload + def __get__(self: Field[str], instance: None, owner: Any) -> StringVar[str]: ... @overload def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ... @overload def __get__( - self: Field[List[V]] | Field[Set[V]] | Field[Tuple[V, ...]], + self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]], instance: None, owner: Any, - ) -> ArrayVar[List[V]]: ... + ) -> ArrayVar[Sequence[V]]: ... @overload def __get__( @@ -3373,3 +3519,151 @@ def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]: The Field. """ return value # pyright: ignore [reportReturnType] + + +def and_operation(a: Var | Any, b: Var | Any) -> Var: + """Perform a logical AND operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical AND operation. + """ + return _and_operation(a, b) + + +def or_operation(a: Var | Any, b: Var | Any) -> Var: + """Perform a logical OR operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical OR operation. + """ + return _or_operation(a, b) + + +def passthrough_unary_type_computer(no_args: GenericType) -> TypeComputer: + """Create a type computer for unary operations. + + Args: + no_args: The type to return when no arguments are provided. + + Returns: + The type computer. + """ + + def type_computer(*args: Var): + if not args: + return (no_args, type_computer) + return (ReflexCallable[[], args[0]._var_type], None) + + return type_computer + + +def unary_type_computer( + no_args: GenericType, computer: Callable[[Var], GenericType] +) -> TypeComputer: + """Create a type computer for unary operations. + + Args: + no_args: The type to return when no arguments are provided. + computer: The function to compute the type. + + Returns: + The type computer. + """ + + def type_computer(*args: Var): + if not args: + return (no_args, type_computer) + return (ReflexCallable[[], computer(args[0])], None) + + return type_computer + + +def nary_type_computer( + *types: GenericType, computer: Callable[..., GenericType] +) -> TypeComputer: + """Create a type computer for n-ary operations. + + Args: + types: The types to return when no arguments are provided. + computer: The function to compute the type. + + Returns: + The type computer. + """ + + def type_computer(*args: Var): + if len(args) != len(types): + return ( + types[len(args)], + functools.partial(type_computer, *args), + ) + return ( + ReflexCallable[[], computer(args)], + None, + ) + + return type_computer + + +T_LOGICAL = TypeVar("T_LOGICAL") +U_LOGICAL = TypeVar("U_LOGICAL") + + +@var_operation +def _and_operation( + a: Var[T_LOGICAL], b: Var[U_LOGICAL] +) -> CustomVarOperationReturn[Union[T_LOGICAL, U_LOGICAL]]: + """Perform a logical AND operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical AND operation. + """ + return var_operation_return( + js_expression=f"({a} && {b})", + type_computer=nary_type_computer( + ReflexCallable[[Any, Any], Any], + ReflexCallable[[Any], Any], + computer=lambda args: unionize( + args[0]._var_type, + args[1]._var_type, + ), + ), + ) + + +@var_operation +def _or_operation( + a: Var[T_LOGICAL], b: Var[U_LOGICAL] +) -> CustomVarOperationReturn[Union[T_LOGICAL, U_LOGICAL]]: + """Perform a logical OR operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result ocomputerf the logical OR operation. + """ + return var_operation_return( + js_expression=f"({a} || {b})", + type_computer=nary_type_computer( + ReflexCallable[[Any, Any], Any], + ReflexCallable[[Any], Any], + computer=lambda args: unionize( + args[0]._var_type, + args[1]._var_type, + ), + ), + ) diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py index a18df78d0..c58b75d77 100644 --- a/reflex/vars/datetime.py +++ b/reflex/vars/datetime.py @@ -4,10 +4,7 @@ from __future__ import annotations import dataclasses from datetime import date, datetime -from typing import Any, NoReturn, TypeVar, Union, overload - -from reflex.utils.exceptions import VarTypeError -from reflex.vars.number import BooleanVar +from typing import TypeVar, Union from .base import ( CustomVarOperationReturn, @@ -23,156 +20,11 @@ DATETIME_T = TypeVar("DATETIME_T", datetime, date) datetime_types = Union[datetime, date] -def raise_var_type_error(): - """Raise a VarTypeError. - - Raises: - VarTypeError: Cannot compare a datetime object with a non-datetime object. - """ - raise VarTypeError("Cannot compare a datetime object with a non-datetime object.") - - -class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)): - """A variable that holds a datetime or date object.""" - - @overload - def __lt__(self, other: datetime_types) -> BooleanVar: ... - - @overload - def __lt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __lt__(self, other: Any): - """Less than comparison. - - Args: - other: The other datetime to compare. - - Returns: - The result of the comparison. - """ - if not isinstance(other, DATETIME_TYPES): - raise_var_type_error() - return date_lt_operation(self, other) - - @overload - def __le__(self, other: datetime_types) -> BooleanVar: ... - - @overload - def __le__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __le__(self, other: Any): - """Less than or equal comparison. - - Args: - other: The other datetime to compare. - - Returns: - The result of the comparison. - """ - if not isinstance(other, DATETIME_TYPES): - raise_var_type_error() - return date_le_operation(self, other) - - @overload - def __gt__(self, other: datetime_types) -> BooleanVar: ... - - @overload - def __gt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __gt__(self, other: Any): - """Greater than comparison. - - Args: - other: The other datetime to compare. - - Returns: - The result of the comparison. - """ - if not isinstance(other, DATETIME_TYPES): - raise_var_type_error() - return date_gt_operation(self, other) - - @overload - def __ge__(self, other: datetime_types) -> BooleanVar: ... - - @overload - def __ge__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __ge__(self, other: Any): - """Greater than or equal comparison. - - Args: - other: The other datetime to compare. - - Returns: - The result of the comparison. - """ - if not isinstance(other, DATETIME_TYPES): - raise_var_type_error() - return date_ge_operation(self, other) - - -@var_operation -def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: - """Greater than comparison. - - Args: - lhs: The left-hand side of the operation. - rhs: The right-hand side of the operation. - - Returns: - The result of the operation. - """ - return date_compare_operation(rhs, lhs, strict=True) - - -@var_operation -def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: - """Less than comparison. - - Args: - lhs: The left-hand side of the operation. - rhs: The right-hand side of the operation. - - Returns: - The result of the operation. - """ - return date_compare_operation(lhs, rhs, strict=True) - - -@var_operation -def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: - """Less than or equal comparison. - - Args: - lhs: The left-hand side of the operation. - rhs: The right-hand side of the operation. - - Returns: - The result of the operation. - """ - return date_compare_operation(lhs, rhs) - - -@var_operation -def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: - """Greater than or equal comparison. - - Args: - lhs: The left-hand side of the operation. - rhs: The right-hand side of the operation. - - Returns: - The result of the operation. - """ - return date_compare_operation(rhs, lhs) - - def date_compare_operation( - lhs: DateTimeVar[DATETIME_T] | Any, - rhs: DateTimeVar[DATETIME_T] | Any, + lhs: Var[datetime_types], + rhs: Var[datetime_types], strict: bool = False, -) -> CustomVarOperationReturn: +) -> CustomVarOperationReturn[bool]: """Check if the value is less than the other value. Args: @@ -189,6 +41,84 @@ def date_compare_operation( ) +@var_operation +def date_gt_operation( + lhs: Var[datetime_types], + rhs: Var[datetime_types], +) -> CustomVarOperationReturn: + """Greater than comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(rhs, lhs, strict=True) + + +@var_operation +def date_lt_operation( + lhs: Var[datetime_types], + rhs: Var[datetime_types], +) -> CustomVarOperationReturn: + """Less than comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(lhs, rhs, strict=True) + + +@var_operation +def date_le_operation( + lhs: Var[datetime_types], rhs: Var[datetime_types] +) -> CustomVarOperationReturn: + """Less than or equal comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(lhs, rhs) + + +@var_operation +def date_ge_operation( + lhs: Var[datetime_types], rhs: Var[datetime_types] +) -> CustomVarOperationReturn: + """Greater than or equal comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(rhs, lhs) + + +class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)): + """A variable that holds a datetime or date object.""" + + __lt__ = date_lt_operation + + __le__ = date_le_operation + + __gt__ = date_gt_operation + + __ge__ = date_ge_operation + + @dataclasses.dataclass( eq=False, frozen=True, diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 505a69b4c..56c38a007 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -3,30 +3,54 @@ from __future__ import annotations import dataclasses +import inspect import sys -from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Mapping, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + Union, + overload, +) -from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar +from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar from reflex.utils import format -from reflex.utils.types import GenericType +from reflex.utils.exceptions import VarTypeError +from reflex.utils.types import GenericType, Unset, get_origin -from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock +from .base import ( + CachedVarOperation, + LiteralVar, + ReflexCallable, + TypeComputer, + Var, + VarData, + VarWithDefault, + cached_property_no_lock, + unwrap_reflex_callalbe, +) + +if TYPE_CHECKING: + from .number import BooleanVar, NumberVar + from .object import ObjectVar + from .sequence import ArrayVar, StringVar P = ParamSpec("P") +R = TypeVar("R") +R2 = TypeVar("R2") V1 = TypeVar("V1") V2 = TypeVar("V2") V3 = TypeVar("V3") V4 = TypeVar("V4") V5 = TypeVar("V5") V6 = TypeVar("V6") -R = TypeVar("R") - - -class ReflexCallable(Protocol[P, R]): - """Protocol for a callable.""" - - __call__: Callable[P, R] CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True) @@ -34,8 +58,32 @@ OTHER_CALLABLE_TYPE = TypeVar( "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True ) +MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping, covariant=True) +SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence, covariant=True) -class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): + +def type_is_reflex_callable(type_: Any) -> bool: + """Check if a type is a ReflexCallable. + + Args: + type_: The type to check. + + Returns: + True if the type is a ReflexCallable. + """ + return type_ is ReflexCallable or get_origin(type_) is ReflexCallable + + +K = TypeVar("K", covariant=True) +V = TypeVar("V", covariant=True) +T = TypeVar("T", covariant=True) + + +class FunctionVar( + Var[CALLABLE_TYPE], + default_type=ReflexCallable[Any, Any], + is_subclass=type_is_reflex_callable, +): """Base class for immutable function vars.""" @overload @@ -43,20 +91,39 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): @overload def partial( - self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]], + self: FunctionVar[ReflexCallable[Concatenate[VarWithDefault[V1], P], R]] + | FunctionVar[ReflexCallable[Concatenate[V1, P], R]], arg1: Union[V1, Var[V1]], ) -> FunctionVar[ReflexCallable[P, R]]: ... @overload def partial( - self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]], + self: FunctionVar[ + ReflexCallable[Concatenate[VarWithDefault[V1], VarWithDefault[V2], P], R] + ] + | FunctionVar[ReflexCallable[Concatenate[V1, VarWithDefault[V2], P], R]] + | FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]], arg1: Union[V1, Var[V1]], arg2: Union[V2, Var[V2]], ) -> FunctionVar[ReflexCallable[P, R]]: ... @overload def partial( - self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]], + self: FunctionVar[ + ReflexCallable[ + Concatenate[ + VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3], P + ], + R, + ] + ] + | FunctionVar[ + ReflexCallable[ + Concatenate[V1, VarWithDefault[V2], VarWithDefault[V3], P], R + ] + ] + | FunctionVar[ReflexCallable[Concatenate[V1, V2, VarWithDefault[V3], P], R]] + | FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]], arg1: Union[V1, Var[V1]], arg2: Union[V2, Var[V2]], arg3: Union[V3, Var[V3]], @@ -110,23 +177,879 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): The partially applied function. """ if not args: - return ArgsFunctionOperation.create((), self) + return self + + args = tuple(map(LiteralVar.create, args)) + + remaining_validators = self._pre_check(*args) + + partial_types, type_computer = self._partial_type(*args) + + if self.__call__ is self.partial: + # if the default behavior is partial, we should return a new partial function + return ArgsFunctionOperationBuilder.create( + (), + VarOperationCall.create( + self, + *args, + Var(_js_expr="...args"), + _var_type=self._return_type(*args), + ), + rest="args", + validators=remaining_validators, + type_computer=type_computer, + _var_type=partial_types, + ) return ArgsFunctionOperation.create( - ("...args",), - VarOperationCall.create(self, *args, Var(_js_expr="...args")), + (), + VarOperationCall.create( + self, *args, Var(_js_expr="...args"), _var_type=self._return_type(*args) + ), + rest="args", + validators=remaining_validators, + type_computer=type_computer, + _var_type=partial_types, ) + # THIS CODE IS GENERATED BY `_generate_overloads_for_function_var_call` function below. + + @overload + def call( + self: FunctionVar[ReflexCallable[[], bool]], + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[], int]], + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[], float]], + ) -> NumberVar[float]: ... + + @overload + def call( # pyright: ignore[reportOverlappingOverload] + self: FunctionVar[ReflexCallable[[], str]], + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[], SEQUENCE_TYPE]], + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[], MAPPING_TYPE]], + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[], R]], + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], bool]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], int]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], float]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], str]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], SEQUENCE_TYPE]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], MAPPING_TYPE]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], R]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], bool] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], int] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], float] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], str] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], SEQUENCE_TYPE] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], MAPPING_TYPE] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], R]], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], bool + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], int + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], float + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], str + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], + SEQUENCE_TYPE, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], + MAPPING_TYPE, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], R + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + VarWithDefault[V4], + ], + bool, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + VarWithDefault[V4], + ], + int, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + VarWithDefault[V4], + ], + float, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + VarWithDefault[V4], + ], + str, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + VarWithDefault[V4], + ], + SEQUENCE_TYPE, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + VarWithDefault[V4], + ], + MAPPING_TYPE, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [ + VarWithDefault[V1], + VarWithDefault[V2], + VarWithDefault[V3], + VarWithDefault[V4], + ], + R, + ] + ], + arg1: Union[V1, Var[V1], Unset] = Unset(), + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1], bool]], arg1: Union[V1, Var[V1]] + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1], int]], arg1: Union[V1, Var[V1]] + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1], float]], arg1: Union[V1, Var[V1]] + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1], str]], arg1: Union[V1, Var[V1]] + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1], SEQUENCE_TYPE]], arg1: Union[V1, Var[V1]] + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1], MAPPING_TYPE]], arg1: Union[V1, Var[V1]] + ) -> ObjectVar[MAPPING_TYPE]: ... + @overload def call( self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]] - ) -> VarOperationCall[[V1], R]: ... + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], bool]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], int]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], float]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], str]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], SEQUENCE_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], MAPPING_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, VarWithDefault[V2], VarWithDefault[V3]], bool] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, VarWithDefault[V2], VarWithDefault[V3]], int] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, VarWithDefault[V2], VarWithDefault[V3]], float] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, VarWithDefault[V2], VarWithDefault[V3]], str] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, VarWithDefault[V2], VarWithDefault[V3]], SEQUENCE_TYPE] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, VarWithDefault[V2], VarWithDefault[V3]], MAPPING_TYPE] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, VarWithDefault[V2], VarWithDefault[V3]], R] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, VarWithDefault[V2], VarWithDefault[V3], VarWithDefault[V4]], bool + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, VarWithDefault[V2], VarWithDefault[V3], VarWithDefault[V4]], int + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, VarWithDefault[V2], VarWithDefault[V3], VarWithDefault[V4]], float + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, VarWithDefault[V2], VarWithDefault[V3], VarWithDefault[V4]], str + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, VarWithDefault[V2], VarWithDefault[V3], VarWithDefault[V4]], + SEQUENCE_TYPE, + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, VarWithDefault[V2], VarWithDefault[V3], VarWithDefault[V4]], + MAPPING_TYPE, + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, VarWithDefault[V2], VarWithDefault[V3], VarWithDefault[V4]], R + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2], Unset] = Unset(), + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2], bool]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2], int]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2], float]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2], str]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2], SEQUENCE_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2], MAPPING_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def call( self: FunctionVar[ReflexCallable[[V1, V2], R]], arg1: Union[V1, Var[V1]], arg2: Union[V2, Var[V2]], - ) -> VarOperationCall[[V1, V2], R]: ... + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, VarWithDefault[V3]], bool]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, VarWithDefault[V3]], int]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, VarWithDefault[V3]], float]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, VarWithDefault[V3]], str]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, VarWithDefault[V3]], SEQUENCE_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, VarWithDefault[V3]], MAPPING_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, VarWithDefault[V3]], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, V2, VarWithDefault[V3], VarWithDefault[V4]], bool] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, V2, VarWithDefault[V3], VarWithDefault[V4]], int] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, V2, VarWithDefault[V3], VarWithDefault[V4]], float] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, V2, VarWithDefault[V3], VarWithDefault[V4]], str] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, V2, VarWithDefault[V3], VarWithDefault[V4]], SEQUENCE_TYPE + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[ + [V1, V2, VarWithDefault[V3], VarWithDefault[V4]], MAPPING_TYPE + ] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, V2, VarWithDefault[V3], VarWithDefault[V4]], R] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3], Unset] = Unset(), + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3], bool]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3], int]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3], float]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3], str]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3], SEQUENCE_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3], MAPPING_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def call( @@ -134,7 +1057,128 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): arg1: Union[V1, Var[V1]], arg2: Union[V2, Var[V2]], arg3: Union[V3, Var[V3]], - ) -> VarOperationCall[[V1, V2, V3], R]: ... + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, VarWithDefault[V4]], bool]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, VarWithDefault[V4]], int]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, VarWithDefault[V4]], float]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, VarWithDefault[V4]], str]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, V2, V3, VarWithDefault[V4]], SEQUENCE_TYPE] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ + ReflexCallable[[V1, V2, V3, VarWithDefault[V4]], MAPPING_TYPE] + ], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, VarWithDefault[V4]], R]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4], Unset] = Unset(), + ) -> Var[R]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], bool]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> BooleanVar: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], int]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> NumberVar[int]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], float]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> NumberVar[float]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], str]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> StringVar[str]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], SEQUENCE_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> ArrayVar[SEQUENCE_TYPE]: ... + + @overload + def call( + self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], MAPPING_TYPE]], + arg1: Union[V1, Var[V1]], + arg2: Union[V2, Var[V2]], + arg3: Union[V3, Var[V3]], + arg4: Union[V4, Var[V4]], + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def call( @@ -143,36 +1187,11 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): arg2: Union[V2, Var[V2]], arg3: Union[V3, Var[V3]], arg4: Union[V4, Var[V4]], - ) -> VarOperationCall[[V1, V2, V3, V4], R]: ... + ) -> Var[R]: ... + # Capture Any to allow for arbitrary number of arguments @overload - def call( - self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]], - arg1: Union[V1, Var[V1]], - arg2: Union[V2, Var[V2]], - arg3: Union[V3, Var[V3]], - arg4: Union[V4, Var[V4]], - arg5: Union[V5, Var[V5]], - ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ... - - @overload - def call( - self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]], - arg1: Union[V1, Var[V1]], - arg2: Union[V2, Var[V2]], - arg3: Union[V3, Var[V3]], - arg4: Union[V4, Var[V4]], - arg5: Union[V5, Var[V5]], - arg6: Union[V6, Var[V6]], - ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ... - - @overload - def call( - self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any - ) -> VarOperationCall[P, R]: ... - - @overload - def call(self, *args: Var | Any) -> Var: ... + def call(self: FunctionVar[NoReturn], *args: Var | Any) -> Var: ... def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload] """Call the function with the given arguments. @@ -182,12 +1201,253 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): Returns: The function call operation. + + Raises: + VarTypeError: If the number of arguments is invalid """ - return VarOperationCall.create(self, *args).guess_type() + required_arg_len = self._required_arg_len() + arg_len = self._arg_len() + if arg_len is not None: + if len(args) < required_arg_len: + raise VarTypeError( + f"Passed {len(args)} arguments, expected at least {required_arg_len} for {self!s}" + ) + if len(args) > arg_len: + raise VarTypeError( + f"Passed {len(args)} arguments, expected at most {arg_len} for {self!s}" + ) + args = tuple(map(LiteralVar.create, args)) + self._pre_check(*args) + return_type = self._return_type(*args) + if isinstance(self, (ArgsFunctionOperation, ArgsFunctionOperationBuilder)): + default_args = self._default_args() + max_allowed_arguments = ( + arg_len if arg_len is not None else len(args) + len(default_args) + ) + provided_argument_count = len(args) + + # we skip default args which we provided + default_args_provided = len(default_args) - ( + max_allowed_arguments - provided_argument_count + ) + + full_args = args + tuple(default_args[default_args_provided:]) + + if self._raw_js_function is not None: + return VarOperationCall.create( + FunctionStringVar.create( + self._raw_js_function, + _var_type=self._var_type, + _var_data=self._get_all_var_data(), + ), + *full_args, + _var_type=return_type, + ).guess_type() + if self._original_var_operation is not None: + return ExpressionCall.create( + self._original_var_operation, + *full_args, + _var_data=self._get_all_var_data(), + _var_type=return_type, + ).guess_type() + + return VarOperationCall.create(self, *args, _var_type=return_type).guess_type() + + def chain( + self: FunctionVar[ReflexCallable[P, R]], + other: FunctionVar[ReflexCallable[[R], R2]] + | FunctionVar[ReflexCallable[[R, VarWithDefault[Any]], R2]] + | FunctionVar[ + ReflexCallable[[R, VarWithDefault[Any], VarWithDefault[Any]], R2] + ], + ) -> FunctionVar[ReflexCallable[P, R2]]: + """Chain two functions together. + + Args: + other: The other function to chain. + + Returns: + The chained function. + """ + self_arg_type, self_return_type = unwrap_reflex_callalbe(self._var_type) + _, other_return_type = unwrap_reflex_callalbe(other._var_type) + + return ArgsFunctionOperationBuilder.create( + (), + VarOperationCall.create( + other, + VarOperationCall.create( + self, Var(_js_expr="...args"), _var_type=self_return_type + ), + _var_type=other_return_type, + ), + rest="arg", + _var_type=ReflexCallable[self_arg_type, other_return_type], # pyright: ignore [reportInvalidTypeArguments] + ) + + def _partial_type( + self, *args: Var | Any + ) -> Tuple[GenericType, Optional[TypeComputer]]: + """Override the type of the function call with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The overridden type of the function call. + """ + args_types, return_type = unwrap_reflex_callalbe(self._var_type) + if isinstance(args_types, tuple): + return ReflexCallable[[*args_types[len(args) :]], return_type], None + return ReflexCallable[..., return_type], None + + def _arg_len(self) -> int | None: + """Get the number of arguments the function takes. + + Returns: + The number of arguments the function takes. + """ + args_types, _ = unwrap_reflex_callalbe(self._var_type) + if isinstance(args_types, tuple): + return len(args_types) + return None + + def _required_arg_len(self) -> int: + """Get the number of required arguments the function takes. + + Returns: + The number of required arguments the function takes. + """ + args_types, _ = unwrap_reflex_callalbe(self._var_type) + if isinstance(args_types, tuple): + return sum( + 1 + for arg_type in args_types + if get_origin(arg_type) is not VarWithDefault + ) + return 0 + + def _default_args(self) -> list[Any]: + """Get the default arguments of the function. + + Returns: + The default arguments of the function. + """ + if isinstance(self, (ArgsFunctionOperation, ArgsFunctionOperationBuilder)): + return [ + arg.default + for arg in self._default_values + if not isinstance(arg, inspect.Parameter.empty) + ] + return [] + + def _return_type(self, *args: Var | Any) -> GenericType: + """Override the type of the function call with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The overridden type of the function call. + """ + partial_types, _ = self._partial_type(*args) + return unwrap_reflex_callalbe(partial_types)[1] + + def _pre_check( + self, *args: Var | Any + ) -> Tuple[Callable[[Any], Optional[str]], ...]: + """Check if the function can be called with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + True if the function can be called with the given arguments. + """ + return () + + @overload + def __get__(self, instance: None, owner: Any) -> FunctionVar[CALLABLE_TYPE]: ... + + @overload + def __get__( + self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]], + instance: Var[V1], + owner: Any, + ) -> FunctionVar[ReflexCallable[P, R]]: ... + + def __get__(self, instance: Any, owner: Any): + """Get the function var. + + Args: + instance: The instance of the class. + owner: The owner of the class. + + Returns: + The function var. + """ + if instance is None: + return self + return self.partial(instance) __call__ = call +@dataclasses.dataclass(frozen=True) +class ExpressionCall(CachedVarOperation, Var[R]): + """Class for expression calls.""" + + _original_var_operation: Callable = dataclasses.field(default=lambda *args: "") + _args: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return self._original_var_operation(*self._args) + + @cached_property_no_lock + def _cached_get_all_var_data(self) -> VarData | None: + """Get all the var data associated with the var. + + Returns: + All the var data associated with the var. + """ + return VarData.merge( + *[arg._get_all_var_data() for arg in self._args], + self._var_data, + ) + + @classmethod + def create( + cls, + _original_var_operation: Callable, + *args: Var | Any, + _var_type: GenericType = Any, + _var_data: VarData | None = None, + ) -> ExpressionCall: + """Create a new expression call. + + Args: + _original_var_operation: The original var operation. + *args: The arguments to call the expression with. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The expression call var. + """ + return ExpressionCall( + _js_expr="", + _var_type=_var_type, + _var_data=_var_data, + _original_var_operation=_original_var_operation, + _args=args, + ) + + class BuilderFunctionVar( FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any] ): @@ -231,7 +1491,9 @@ class FunctionStringVar(FunctionVar[CALLABLE_TYPE]): class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]): """Base class for immutable vars that are the result of a function call.""" - _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None) + _func: Optional[FunctionVar[ReflexCallable[..., R]]] = dataclasses.field( + default=None + ) _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) @cached_property_no_lock @@ -321,62 +1583,133 @@ class FunctionArgs: def format_args_function_operation( - args: FunctionArgs, return_expr: Var | Any, explicit_return: bool + self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, ) -> str: """Format an args function operation. Args: - args: The function arguments. - return_expr: The return expression. - explicit_return: Whether to use explicit return syntax. + self: The function operation. Returns: The formatted args function operation. """ arg_names_str = ", ".join( - [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args] - ) + (f", ...{args.rest}" if args.rest else "") + [ + (arg if isinstance(arg, str) else arg.to_javascript()) + + ( + f" = {default_value.default!s}" + if i < len(self._default_values) + and not isinstance( + (default_value := self._default_values[i]), inspect.Parameter.empty + ) + else "" + ) + for i, arg in enumerate(self._args.args) + ] + + ([f"...{self._args.rest}"] if self._args.rest else []) + ) - return_expr_str = str(LiteralVar.create(return_expr)) + return_expr_str = str(LiteralVar.create(self._return_expr)) # Wrap return expression in curly braces if explicit return syntax is used. return_expr_str_wrapped = ( - format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str + format.wrap(return_expr_str, "{", "}") + if self._explicit_return + else return_expr_str ) return f"(({arg_names_str}) => {return_expr_str_wrapped})" +def pre_check_args( + self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any +) -> Tuple[Callable[[Any], Optional[str]], ...]: + """Check if the function can be called with the given arguments. + + Args: + self: The function operation. + *args: The arguments to call the function with. + + Returns: + True if the function can be called with the given arguments. + + Raises: + VarTypeError: If the arguments are invalid. + """ + for i, (validator, arg) in enumerate(zip(self._validators, args, strict=False)): + if (validation_message := validator(arg)) is not None: + arg_name = self._args.args[i] if i < len(self._args.args) else None + if arg_name is not None: + raise VarTypeError( + f"Invalid argument {arg!s} provided to {arg_name} in {self._function_name or 'var operation'}. {validation_message}" + ) + raise VarTypeError( + f"Invalid argument {arg!s} provided to argument {i} in {self._function_name or 'var operation'}. {validation_message}" + ) + return self._validators[len(args) :] + + +def figure_partial_type( + self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, + *args: Var | Any, +) -> Tuple[GenericType, Optional[TypeComputer]]: + """Figure out the return type of the function. + + Args: + self: The function operation. + *args: The arguments to call the function with. + + Returns: + The return type of the function. + """ + return ( + self._type_computer(*args) + if self._type_computer is not None + else FunctionVar._partial_type(self, *args) + ) + + @dataclasses.dataclass( eq=False, frozen=True, slots=True, ) -class ArgsFunctionOperation(CachedVarOperation, FunctionVar): +class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]): """Base class for immutable function defined via arguments and return expression.""" _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) + _default_values: Tuple[VarWithDefault | inspect.Parameter.empty, ...] = ( + dataclasses.field(default_factory=tuple) + ) + _validators: Tuple[Callable[[Any], Optional[str]], ...] = dataclasses.field( + default_factory=tuple + ) _return_expr: Union[Var, Any] = dataclasses.field(default=None) + _function_name: str = dataclasses.field(default="") + _type_computer: Optional[TypeComputer] = dataclasses.field(default=None) _explicit_return: bool = dataclasses.field(default=False) + _raw_js_function: str | None = dataclasses.field(default=None) + _original_var_operation: Callable | None = dataclasses.field(default=None) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + _cached_var_name = cached_property_no_lock(format_args_function_operation) - Returns: - The name of the var. - """ - return format_args_function_operation( - self._args, self._return_expr, self._explicit_return - ) + _pre_check = pre_check_args + + _partial_type = figure_partial_type @classmethod def create( cls, args_names: Sequence[Union[str, DestructuredArg]], return_expr: Var | Any, + default_values: Sequence[VarWithDefault | inspect.Parameter.empty] = (), rest: str | None = None, + validators: Sequence[Callable[[Any], Optional[str]]] = (), + function_name: str = "", explicit_return: bool = False, + type_computer: Optional[TypeComputer] = None, + _raw_js_function: str | None = None, + _original_var_operation: Callable | None = None, _var_type: GenericType = Callable, _var_data: VarData | None = None, ): @@ -385,9 +1718,15 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): Args: args_names: The names of the arguments. return_expr: The return expression of the function. + default_values: The default values of the arguments. rest: The name of the rest argument. + validators: The validators for the arguments. + function_name: The name of the function. explicit_return: Whether to use explicit return syntax. - _var_type: The type of the Var. + type_computer: A function to compute the return type. + _raw_js_function: If provided, it will be used when the operation is being called with all of its arguments at once. + _original_var_operation: The original var operation, if any. + _var_type: The type of the var. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -399,8 +1738,14 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): _var_type=_var_type, _var_data=_var_data, _args=FunctionArgs(args=tuple(args_names), rest=rest), + _raw_js_function=_raw_js_function, + _original_var_operation=_original_var_operation, + _default_values=tuple(default_values), + _function_name=function_name, + _validators=tuple(validators), _return_expr=return_expr, _explicit_return=explicit_return, + _type_computer=type_computer, ) @@ -409,31 +1754,44 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): frozen=True, slots=True, ) -class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): +class ArgsFunctionOperationBuilder( + CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE] +): """Base class for immutable function defined via arguments and return expression with the builder pattern.""" _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) + _default_values: Tuple[VarWithDefault | inspect.Parameter.empty, ...] = ( + dataclasses.field(default_factory=tuple) + ) + _validators: Tuple[Callable[[Any], Optional[str]], ...] = dataclasses.field( + default_factory=tuple + ) _return_expr: Union[Var, Any] = dataclasses.field(default=None) + _function_name: str = dataclasses.field(default="") + _type_computer: Optional[TypeComputer] = dataclasses.field(default=None) _explicit_return: bool = dataclasses.field(default=False) + _raw_js_function: str | None = dataclasses.field(default=None) + _original_var_operation: Callable | None = dataclasses.field(default=None) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + _cached_var_name = cached_property_no_lock(format_args_function_operation) - Returns: - The name of the var. - """ - return format_args_function_operation( - self._args, self._return_expr, self._explicit_return - ) + _pre_check = pre_check_args + + _partial_type = figure_partial_type @classmethod def create( cls, args_names: Sequence[Union[str, DestructuredArg]], return_expr: Var | Any, + default_values: Sequence[VarWithDefault | inspect.Parameter.empty] = (), rest: str | None = None, + validators: Sequence[Callable[[Any], Optional[str]]] = (), + function_name: str = "", explicit_return: bool = False, + type_computer: Optional[TypeComputer] = None, + _raw_js_function: str | None = None, + _original_var_operation: Callable | None = None, _var_type: GenericType = Callable, _var_data: VarData | None = None, ): @@ -442,9 +1800,15 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): Args: args_names: The names of the arguments. return_expr: The return expression of the function. + default_values: The default values of the arguments. rest: The name of the rest argument. + validators: The validators for the arguments. + function_name: The name of the function. explicit_return: Whether to use explicit return syntax. - _var_type: The type of the Var. + type_computer: A function to compute the return type. + _raw_js_function: If provided, it will be used when the operation is being called with all of its arguments at once. + _original_var_operation: The original var operation, if any. + _var_type: The type of the var. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -456,8 +1820,14 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): _var_type=_var_type, _var_data=_var_data, _args=FunctionArgs(args=tuple(args_names), rest=rest), + _raw_js_function=_raw_js_function, + _original_var_operation=_original_var_operation, + _default_values=tuple(default_values), + _function_name=function_name, + _validators=tuple(validators), _return_expr=return_expr, _explicit_return=explicit_return, + _type_computer=type_computer, ) @@ -483,3 +1853,60 @@ else: "((__to_string) => __to_string.toString())", _var_type=ReflexCallable[Any, str], ) + + +def _generate_overloads_for_function_var_call(maximum_args: int = 4) -> str: + """Generate overloads for the function var call method. + + Args: + maximum_args: The maximum number of arguments to generate overloads for. + + Returns: + The generated overloads. + """ + overloads = [] + return_type_mapping = { + "bool": "BooleanVar", + "int": "NumberVar[int]", + "float": "NumberVar[float]", + "str": "StringVar[str]", + "SEQUENCE_TYPE": "ArrayVar[SEQUENCE_TYPE]", + "MAPPING_TYPE": "ObjectVar[MAPPING_TYPE]", + "R": "Var[R]", + } + for number_of_required_args in range(maximum_args + 1): + for number_of_optional_args in range( + maximum_args + 1 - number_of_required_args + ): + for return_type, return_type_var in return_type_mapping.items(): + required_args = [ + f"arg{j + 1}: Union[V{j + 1}, Var[V{j + 1}]]" + for j in range(number_of_required_args) + ] + optional_args = [ + f"arg{j + 1}: Union[V{j + 1}, Var[V{j + 1}], Unset] = Unset()" + for j in range( + number_of_required_args, + number_of_required_args + number_of_optional_args, + ) + ] + required_params = [f"V{j + 1}" for j in range(number_of_required_args)] + optional_params = [ + f"VarWithDefault[V{j + 1}]" + for j in range( + number_of_required_args, + number_of_required_args + number_of_optional_args, + ) + ] + function_type_hint = f"""FunctionVar[ReflexCallable[[{", ".join(required_params + optional_params)}], {return_type}]]""" + new_line = "\n" + overloads.append( + f""" + @overload + def call( + self: {function_type_hint}, + {"," + new_line + " ".join(required_args + optional_args)} + ) -> {return_type_var}: ... + """ + ) + return "\n".join(overloads) diff --git a/reflex/vars/number.py b/reflex/vars/number.py index 87f1760a6..6f08db1af 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -3,6 +3,7 @@ from __future__ import annotations import dataclasses +import functools import json import math from typing import ( @@ -10,21 +11,30 @@ from typing import ( Any, Callable, NoReturn, - Type, + Sequence, TypeVar, Union, + cast, overload, ) +from typing_extensions import Unpack + from reflex.constants.base import Dirs from reflex.utils.exceptions import PrimitiveUnserializableToJSONError, VarTypeError from reflex.utils.imports import ImportDict, ImportVar from .base import ( + VAR_TYPE, + CachedVarOperation, CustomVarOperationReturn, LiteralVar, + ReflexCallable, Var, VarData, + cached_property_no_lock, + nary_type_computer, + passthrough_unary_type_computer, unionize, var_operation, var_operation_return, @@ -33,6 +43,7 @@ from .base import ( NUMBER_T = TypeVar("NUMBER_T", int, float, bool) if TYPE_CHECKING: + from .function import FunctionVar from .sequence import ArrayVar @@ -56,13 +67,7 @@ def raise_unsupported_operand_types( class NumberVar(Var[NUMBER_T], python_types=(int, float)): """Base class for immutable number vars.""" - @overload - def __add__(self, other: number_types) -> NumberVar: ... - - @overload - def __add__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __add__(self, other: Any): + def __add__(self, other: number_types) -> NumberVar: """Add two numbers. Args: @@ -73,15 +78,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("+", (type(self), type(other))) - return number_add_operation(self, +other) + return number_add_operation(self, +other).guess_type() - @overload - def __radd__(self, other: number_types) -> NumberVar: ... - - @overload - def __radd__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __radd__(self, other: Any): + def __radd__(self, other: number_types) -> NumberVar: """Add two numbers. Args: @@ -92,15 +91,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("+", (type(other), type(self))) - return number_add_operation(+other, self) + return number_add_operation(+other, self).guess_type() - @overload - def __sub__(self, other: number_types) -> NumberVar: ... - - @overload - def __sub__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __sub__(self, other: Any): + def __sub__(self, other: number_types) -> NumberVar: """Subtract two numbers. Args: @@ -112,15 +105,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("-", (type(self), type(other))) - return number_subtract_operation(self, +other) + return number_subtract_operation(self, +other).guess_type() - @overload - def __rsub__(self, other: number_types) -> NumberVar: ... - - @overload - def __rsub__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __rsub__(self, other: Any): + def __rsub__(self, other: number_types) -> NumberVar: """Subtract two numbers. Args: @@ -132,7 +119,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("-", (type(other), type(self))) - return number_subtract_operation(+other, self) + return number_subtract_operation(+other, self).guess_type() def __abs__(self): """Get the absolute value of the number. @@ -167,7 +154,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("*", (type(self), type(other))) - return number_multiply_operation(self, +other) + return number_multiply_operation(self, +other).guess_type() @overload def __rmul__(self, other: number_types | boolean_types) -> NumberVar: ... @@ -194,15 +181,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("*", (type(other), type(self))) - return number_multiply_operation(+other, self) + return number_multiply_operation(+other, self).guess_type() - @overload - def __truediv__(self, other: number_types) -> NumberVar: ... - - @overload - def __truediv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __truediv__(self, other: Any): + def __truediv__(self, other: number_types) -> NumberVar: """Divide two numbers. Args: @@ -214,15 +195,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("/", (type(self), type(other))) - return number_true_division_operation(self, +other) + return number_true_division_operation(self, +other).guess_type() - @overload - def __rtruediv__(self, other: number_types) -> NumberVar: ... - - @overload - def __rtruediv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __rtruediv__(self, other: Any): + def __rtruediv__(self, other: number_types) -> NumberVar: """Divide two numbers. Args: @@ -234,15 +209,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("/", (type(other), type(self))) - return number_true_division_operation(+other, self) + return number_true_division_operation(+other, self).guess_type() - @overload - def __floordiv__(self, other: number_types) -> NumberVar: ... - - @overload - def __floordiv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __floordiv__(self, other: Any): + def __floordiv__(self, other: number_types) -> NumberVar: """Floor divide two numbers. Args: @@ -254,15 +223,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("//", (type(self), type(other))) - return number_floor_division_operation(self, +other) + return number_floor_division_operation(self, +other).guess_type() - @overload - def __rfloordiv__(self, other: number_types) -> NumberVar: ... - - @overload - def __rfloordiv__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __rfloordiv__(self, other: Any): + def __rfloordiv__(self, other: number_types) -> NumberVar: """Floor divide two numbers. Args: @@ -274,15 +237,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("//", (type(other), type(self))) - return number_floor_division_operation(+other, self) + return number_floor_division_operation(+other, self).guess_type() - @overload - def __mod__(self, other: number_types) -> NumberVar: ... - - @overload - def __mod__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __mod__(self, other: Any): + def __mod__(self, other: number_types) -> NumberVar: """Modulo two numbers. Args: @@ -294,15 +251,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("%", (type(self), type(other))) - return number_modulo_operation(self, +other) + return number_modulo_operation(self, +other).guess_type() - @overload - def __rmod__(self, other: number_types) -> NumberVar: ... - - @overload - def __rmod__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __rmod__(self, other: Any): + def __rmod__(self, other: number_types) -> NumberVar: """Modulo two numbers. Args: @@ -314,15 +265,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("%", (type(other), type(self))) - return number_modulo_operation(+other, self) + return number_modulo_operation(+other, self).guess_type() - @overload - def __pow__(self, other: number_types) -> NumberVar: ... - - @overload - def __pow__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __pow__(self, other: Any): + def __pow__(self, other: number_types) -> NumberVar: """Exponentiate two numbers. Args: @@ -334,15 +279,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("**", (type(self), type(other))) - return number_exponent_operation(self, +other) + return number_exponent_operation(self, +other).guess_type() - @overload - def __rpow__(self, other: number_types) -> NumberVar: ... - - @overload - def __rpow__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __rpow__(self, other: Any): + def __rpow__(self, other: number_types) -> NumberVar: """Exponentiate two numbers. Args: @@ -354,7 +293,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("**", (type(other), type(self))) - return number_exponent_operation(+other, self) + return number_exponent_operation(+other, self).guess_type() def __neg__(self): """Negate the number. @@ -362,7 +301,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): Returns: The number negation operation. """ - return number_negate_operation(self) + return number_negate_operation(self).guess_type() def __invert__(self): """Boolean NOT the number. @@ -370,7 +309,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): Returns: The boolean NOT operation. """ - return boolean_not_operation(self.bool()) + return boolean_not_operation(self.bool()).guess_type() def __pos__(self) -> NumberVar: """Positive the number. @@ -386,7 +325,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): Returns: The number round operation. """ - return number_round_operation(self) + return number_round_operation(self).guess_type() def __ceil__(self): """Ceil the number. @@ -394,7 +333,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): Returns: The number ceil operation. """ - return number_ceil_operation(self) + return number_ceil_operation(self).guess_type() def __floor__(self): """Floor the number. @@ -402,7 +341,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): Returns: The number floor operation. """ - return number_floor_operation(self) + return number_floor_operation(self).guess_type() def __trunc__(self): """Trunc the number. @@ -410,15 +349,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): Returns: The number trunc operation. """ - return number_trunc_operation(self) + return number_trunc_operation(self).guess_type() - @overload - def __lt__(self, other: number_types) -> BooleanVar: ... - - @overload - def __lt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __lt__(self, other: Any): + def __lt__(self, other: number_types) -> BooleanVar: """Less than comparison. Args: @@ -429,15 +362,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("<", (type(self), type(other))) - return less_than_operation(+self, +other) + return less_than_operation(+self, +other).guess_type() - @overload - def __le__(self, other: number_types) -> BooleanVar: ... - - @overload - def __le__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __le__(self, other: Any): + def __le__(self, other: number_types) -> BooleanVar: """Less than or equal comparison. Args: @@ -448,9 +375,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types("<=", (type(self), type(other))) - return less_than_or_equal_operation(+self, +other) + return less_than_or_equal_operation(+self, +other).guess_type() - def __eq__(self, other: Any): + def __eq__(self, other: Any) -> BooleanVar: """Equal comparison. Args: @@ -460,10 +387,10 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): The result of the comparison. """ if isinstance(other, NUMBER_TYPES): - return equal_operation(+self, +other) - return equal_operation(self, other) + return equal_operation(+self, +other).guess_type() + return equal_operation(self, other).guess_type() - def __ne__(self, other: Any): + def __ne__(self, other: Any) -> BooleanVar: """Not equal comparison. Args: @@ -473,16 +400,10 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): The result of the comparison. """ if isinstance(other, NUMBER_TYPES): - return not_equal_operation(+self, +other) - return not_equal_operation(self, other) + return not_equal_operation(+self, +other).guess_type() + return not_equal_operation(self, other).guess_type() - @overload - def __gt__(self, other: number_types) -> BooleanVar: ... - - @overload - def __gt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __gt__(self, other: Any): + def __gt__(self, other: number_types) -> BooleanVar: """Greater than comparison. Args: @@ -493,15 +414,9 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types(">", (type(self), type(other))) - return greater_than_operation(+self, +other) + return greater_than_operation(+self, +other).guess_type() - @overload - def __ge__(self, other: number_types) -> BooleanVar: ... - - @overload - def __ge__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __ge__(self, other: Any): + def __ge__(self, other: number_types) -> BooleanVar: """Greater than or equal comparison. Args: @@ -512,7 +427,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ if not isinstance(other, NUMBER_TYPES): raise_unsupported_operand_types(">=", (type(self), type(other))) - return greater_than_or_equal_operation(+self, +other) + return greater_than_or_equal_operation(+self, +other).guess_type() def _is_strict_float(self) -> bool: """Check if the number is a float. @@ -532,8 +447,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): def binary_number_operation( - func: Callable[[NumberVar, NumberVar], str], -) -> Callable[[number_types, number_types], NumberVar]: + func: Callable[[Var[int | float], Var[int | float]], str], +): """Decorator to create a binary number operation. Args: @@ -543,30 +458,37 @@ def binary_number_operation( The binary number operation. """ - @var_operation - def operation(lhs: NumberVar, rhs: NumberVar): + def operation( + lhs: Var[int | float], rhs: Var[int | float] + ) -> CustomVarOperationReturn[int | float]: + def type_computer(*args: Var): + if not args: + return ( + ReflexCallable[[int | float, int | float], int | float], + type_computer, + ) + if len(args) == 1: + return ( + ReflexCallable[[int | float], int | float], + functools.partial(type_computer, args[0]), + ) + return ( + ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)], + None, + ) + return var_operation_return( js_expression=func(lhs, rhs), - var_type=unionize(lhs._var_type, rhs._var_type), + type_computer=type_computer, ) - def wrapper(lhs: number_types, rhs: number_types) -> NumberVar: - """Create the binary number operation. + operation.__name__ = func.__name__ - Args: - lhs: The first number. - rhs: The second number. - - Returns: - The binary number operation. - """ - return operation(lhs, rhs) # pyright: ignore [reportReturnType, reportArgumentType] - - return wrapper + return var_operation(operation) @binary_number_operation -def number_add_operation(lhs: NumberVar, rhs: NumberVar): +def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]): """Add two numbers. Args: @@ -580,7 +502,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_subtract_operation(lhs: NumberVar, rhs: NumberVar): +def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]): """Subtract two numbers. Args: @@ -593,8 +515,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar): return f"({lhs} - {rhs})" +unary_operation_type_computer = passthrough_unary_type_computer( + ReflexCallable[[int | float], int | float] +) + + @var_operation -def number_abs_operation(value: NumberVar): +def number_abs_operation( + value: Var[int | float], +) -> CustomVarOperationReturn[int | float]: """Get the absolute value of the number. Args: @@ -604,12 +533,14 @@ def number_abs_operation(value: NumberVar): The number absolute operation. """ return var_operation_return( - js_expression=f"Math.abs({value})", var_type=value._var_type + js_expression=f"Math.abs({value})", + type_computer=unary_operation_type_computer, + _raw_js_function="Math.abs", ) @binary_number_operation -def number_multiply_operation(lhs: NumberVar, rhs: NumberVar): +def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]): """Multiply two numbers. Args: @@ -624,7 +555,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar): @var_operation def number_negate_operation( - value: NumberVar[NUMBER_T], + value: Var[NUMBER_T], ) -> CustomVarOperationReturn[NUMBER_T]: """Negate the number. @@ -634,11 +565,13 @@ def number_negate_operation( Returns: The number negation operation. """ - return var_operation_return(js_expression=f"-({value})", var_type=value._var_type) + return var_operation_return( + js_expression=f"-({value})", type_computer=unary_operation_type_computer + ) @binary_number_operation -def number_true_division_operation(lhs: NumberVar, rhs: NumberVar): +def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]): """Divide two numbers. Args: @@ -652,7 +585,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar): +def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]): """Floor divide two numbers. Args: @@ -666,7 +599,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_modulo_operation(lhs: NumberVar, rhs: NumberVar): +def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]): """Modulo two numbers. Args: @@ -680,7 +613,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_exponent_operation(lhs: NumberVar, rhs: NumberVar): +def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]): """Exponentiate two numbers. Args: @@ -694,7 +627,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar): @var_operation -def number_round_operation(value: NumberVar): +def number_round_operation(value: Var[int | float]): """Round the number. Args: @@ -707,7 +640,7 @@ def number_round_operation(value: NumberVar): @var_operation -def number_ceil_operation(value: NumberVar): +def number_ceil_operation(value: Var[int | float]): """Ceil the number. Args: @@ -720,7 +653,7 @@ def number_ceil_operation(value: NumberVar): @var_operation -def number_floor_operation(value: NumberVar): +def number_floor_operation(value: Var[int | float]): """Floor the number. Args: @@ -729,11 +662,15 @@ def number_floor_operation(value: NumberVar): Returns: The number floor operation. """ - return var_operation_return(js_expression=f"Math.floor({value})", var_type=int) + return var_operation_return( + js_expression=f"Math.floor({value})", + var_type=int, + _raw_js_function="Math.floor", + ) @var_operation -def number_trunc_operation(value: NumberVar): +def number_trunc_operation(value: Var[int | float]): """Trunc the number. Args: @@ -754,7 +691,7 @@ class BooleanVar(NumberVar[bool], python_types=bool): Returns: The boolean NOT operation. """ - return boolean_not_operation(self) + return boolean_not_operation(self).guess_type() def __int__(self): """Convert the boolean to an int. @@ -762,7 +699,7 @@ class BooleanVar(NumberVar[bool], python_types=bool): Returns: The boolean to int operation. """ - return boolean_to_number_operation(self) + return boolean_to_number_operation(self).guess_type() def __pos__(self): """Convert the boolean to an int. @@ -770,7 +707,7 @@ class BooleanVar(NumberVar[bool], python_types=bool): Returns: The boolean to int operation. """ - return boolean_to_number_operation(self) + return boolean_to_number_operation(self).guess_type() def bool(self) -> BooleanVar: """Boolean conversion. @@ -826,7 +763,7 @@ class BooleanVar(NumberVar[bool], python_types=bool): @var_operation -def boolean_to_number_operation(value: BooleanVar): +def boolean_to_number_operation(value: Var[bool]): """Convert the boolean to a number. Args: @@ -835,12 +772,14 @@ def boolean_to_number_operation(value: BooleanVar): Returns: The boolean to number operation. """ - return var_operation_return(js_expression=f"Number({value})", var_type=int) + return var_operation_return( + js_expression=f"Number({value})", var_type=int, _raw_js_function="Number" + ) def comparison_operator( func: Callable[[Var, Var], str], -) -> Callable[[Var | Any, Var | Any], BooleanVar]: +) -> FunctionVar[ReflexCallable[[Any, Any], bool]]: """Decorator to create a comparison operation. Args: @@ -850,26 +789,15 @@ def comparison_operator( The comparison operation. """ - @var_operation - def operation(lhs: Var, rhs: Var): + def operation(lhs: Var[Any], rhs: Var[Any]): return var_operation_return( js_expression=func(lhs, rhs), var_type=bool, ) - def wrapper(lhs: Var | Any, rhs: Var | Any) -> BooleanVar: - """Create the comparison operation. + operation.__name__ = func.__name__ - Args: - lhs: The first value. - rhs: The second value. - - Returns: - The comparison operation. - """ - return operation(lhs, rhs) - - return wrapper + return var_operation(operation) @comparison_operator @@ -957,7 +885,7 @@ def not_equal_operation(lhs: Var, rhs: Var): @var_operation -def boolean_not_operation(value: BooleanVar): +def boolean_not_operation(value: Var[bool]): """Boolean NOT the boolean. Args: @@ -1081,6 +1009,18 @@ _IS_TRUE_IMPORT: ImportDict = { f"$/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")], } +_AT_SLICE_IMPORT: ImportDict = { + f"$/{Dirs.STATE_PATH}": [ImportVar(tag="atSlice")], +} + +_AT_SLICE_OR_INDEX: ImportDict = { + f"$/{Dirs.STATE_PATH}": [ImportVar(tag="atSliceOrIndex")], +} + +_RANGE_IMPORT: ImportDict = { + f"$/{Dirs.UTILS}/helpers/range": [ImportVar(tag="range", is_default=True)], +} + @var_operation def boolify(value: Var): @@ -1096,16 +1036,17 @@ def boolify(value: Var): js_expression=f"isTrue({value})", var_type=bool, var_data=VarData(imports=_IS_TRUE_IMPORT), + _raw_js_function="isTrue", ) -T = TypeVar("T") -U = TypeVar("U") +T = TypeVar("T", bound=Any) +U = TypeVar("U", bound=Any) @var_operation def ternary_operation( - condition: BooleanVar, if_true: Var[T], if_false: Var[U] + condition: Var[bool], if_true: Var[T], if_false: Var[U] ) -> CustomVarOperationReturn[Union[T, U]]: """Create a ternary operation. @@ -1117,14 +1058,125 @@ def ternary_operation( Returns: The ternary operation. """ - type_value: Union[Type[T], Type[U]] = unionize( - if_true._var_type, if_false._var_type - ) value: CustomVarOperationReturn[Union[T, U]] = var_operation_return( js_expression=f"({condition} ? {if_true} : {if_false})", - var_type=type_value, + type_computer=nary_type_computer( + ReflexCallable[[bool, Any, Any], Any], + ReflexCallable[[Any, Any], Any], + ReflexCallable[[Any], Any], + computer=lambda args: unionize(args[1]._var_type, args[2]._var_type), + ), ) return value +TUPLE_ENDS_IN_VAR = tuple[Unpack[tuple[Var[Any], ...]], Var[VAR_TYPE]] + +TUPLE_ENDS_IN_VAR_RELAXED = tuple[ + Unpack[tuple[Var[Any] | Any, ...]], Var[VAR_TYPE] | VAR_TYPE +] + + +@dataclasses.dataclass( + eq=False, + frozen=True, + slots=True, +) +class MatchOperation(CachedVarOperation, Var[VAR_TYPE]): + """Base class for immutable match operations.""" + + _cond: Var[bool] = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(True) + ) + _cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field( + default_factory=tuple + ) + _default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType] + default_factory=lambda: Var.create(None) + ) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """Get the name of the var. + + Returns: + The name of the var. + """ + switch_code = f"(() => {{ switch (JSON.stringify({self._cond!s})) {{" + + for case in self._cases: + conditions = case[:-1] + return_value = case[-1] + + case_conditions = " ".join( + [f"case JSON.stringify({condition!s}):" for condition in conditions] + ) + case_code = f"{case_conditions} return ({return_value!s}); break;" + switch_code += case_code + + switch_code += f"default: return ({self._default!s}); break;" + switch_code += "};})()" + + return switch_code + + @cached_property_no_lock + def _cached_get_all_var_data(self) -> VarData | None: + """Get the VarData for the var. + + Returns: + The VarData for the var. + """ + return VarData.merge( + self._cond._get_all_var_data(), + *( + cond_or_return._get_all_var_data() + for case in self._cases + for cond_or_return in case + ), + self._default._get_all_var_data(), + self._var_data, + ) + + @classmethod + def create( + cls, + cond: Any, + cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]], + default: Var[VAR_TYPE] | VAR_TYPE, + _var_data: VarData | None = None, + _var_type: type[VAR_TYPE] | None = None, + ): + """Create the match operation. + + Args: + cond: The condition. + cases: The cases. + default: The default case. + _var_data: Additional hooks and imports associated with the Var. + _var_type: The type of the Var. + + Returns: + The match operation. + """ + cond = Var.create(cond) + cases = cast( + tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...], + tuple(tuple(Var.create(c) for c in case) for case in cases), + ) + + _default = cast(Var[VAR_TYPE], Var.create(default)) + var_type = _var_type or unionize( + *(case[-1]._var_type for case in cases), + _default._var_type, + ) + return cls( + _js_expr="", + _var_data=_var_data, + _var_type=var_type, + _cond=cond, + _cases=cases, + _default=_default, + ) + + NUMBER_TYPES = (int, float, NumberVar) diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 89479bbc4..c23fca1f6 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -10,6 +10,7 @@ from typing import ( List, Mapping, NoReturn, + Sequence, Tuple, Type, TypeVar, @@ -27,15 +28,19 @@ from reflex.utils.types import ( get_attribute_access_type, get_origin, safe_issubclass, + unionize, ) from .base import ( CachedVarOperation, LiteralVar, + ReflexCallable, Var, VarData, cached_property_no_lock, figure_out_type, + nary_type_computer, + unary_type_computer, var_operation, var_operation_return, ) @@ -69,9 +74,9 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): ) -> Type[VALUE_TYPE]: ... @overload - def _value_type(self) -> Type: ... + def _value_type(self) -> GenericType: ... - def _value_type(self) -> Type: + def _value_type(self) -> GenericType: """Get the type of the values of the object. Returns: @@ -83,18 +88,18 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else () return args[1] if args else Any # pyright: ignore [reportReturnType] - def keys(self) -> ArrayVar[List[str]]: + def keys(self) -> ArrayVar[Sequence[str]]: """Get the keys of the object. Returns: The keys of the object. """ - return object_keys_operation(self) + return object_keys_operation(self).guess_type() @overload def values( self: ObjectVar[Mapping[Any, VALUE_TYPE]], - ) -> ArrayVar[List[VALUE_TYPE]]: ... + ) -> ArrayVar[Sequence[VALUE_TYPE]]: ... @overload def values(self) -> ArrayVar: ... @@ -105,12 +110,12 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): Returns: The values of the object. """ - return object_values_operation(self) + return object_values_operation(self).guess_type() @overload def entries( self: ObjectVar[Mapping[Any, VALUE_TYPE]], - ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... + ) -> ArrayVar[Sequence[Tuple[str, VALUE_TYPE]]]: ... @overload def entries(self) -> ArrayVar: ... @@ -121,7 +126,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): Returns: The entries of the object. """ - return object_entries_operation(self) + return object_entries_operation(self).guess_type() items = entries @@ -167,15 +172,10 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): @overload def __getitem__( - self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, Sequence[ARRAY_INNER_TYPE]]] + | ObjectVar[Mapping[Any, List[ARRAY_INNER_TYPE]]], key: Var | Any, - ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... - - @overload - def __getitem__( - self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], - key: Var | Any, - ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... + ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ... @overload def __getitem__( @@ -227,15 +227,9 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): @overload def __getattr__( - self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]], + self: ObjectVar[Mapping[Any, Sequence[ARRAY_INNER_TYPE]]], name: str, - ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... - - @overload - def __getattr__( - self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], - name: str, - ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... + ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ... @overload def __getattr__( @@ -295,7 +289,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): Returns: The result of the check. """ - return object_has_own_property_operation(self, key) + return object_has_own_property_operation(self, key).guess_type() @dataclasses.dataclass( @@ -310,7 +304,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): default_factory=dict ) - def _key_type(self) -> Type: + def _key_type(self) -> GenericType: """Get the type of the keys of the object. Returns: @@ -319,7 +313,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): args_list = typing.get_args(self._var_type) return args_list[0] if args_list else Any # pyright: ignore [reportReturnType] - def _value_type(self) -> Type: + def _value_type(self) -> GenericType: """Get the type of the values of the object. Returns: @@ -416,7 +410,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): @var_operation -def object_keys_operation(value: ObjectVar): +def object_keys_operation(value: Var): """Get the keys of an object. Args: @@ -428,11 +422,12 @@ def object_keys_operation(value: ObjectVar): return var_operation_return( js_expression=f"Object.keys({value})", var_type=List[str], + _raw_js_function="Object.keys", ) @var_operation -def object_values_operation(value: ObjectVar): +def object_values_operation(value: Var): """Get the values of an object. Args: @@ -443,12 +438,17 @@ def object_values_operation(value: ObjectVar): """ return var_operation_return( js_expression=f"Object.values({value})", - var_type=List[value._value_type()], + type_computer=unary_type_computer( + ReflexCallable[[Any], List[Any]], + lambda x: List[x.to(ObjectVar)._value_type()], + ), + var_type=List[Any], + _raw_js_function="Object.values", ) @var_operation -def object_entries_operation(value: ObjectVar): +def object_entries_operation(value: Var): """Get the entries of an object. Args: @@ -457,14 +457,20 @@ def object_entries_operation(value: ObjectVar): Returns: The entries of the object. """ + value = value.to(ObjectVar) return var_operation_return( js_expression=f"Object.entries({value})", - var_type=List[Tuple[str, value._value_type()]], + type_computer=unary_type_computer( + ReflexCallable[[Any], List[Tuple[str, Any]]], + lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]], + ), + var_type=List[Tuple[str, Any]], + _raw_js_function="Object.entries", ) @var_operation -def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): +def object_merge_operation(lhs: Var, rhs: Var): """Merge two objects. Args: @@ -476,10 +482,15 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): """ return var_operation_return( js_expression=f"({{...{lhs}, ...{rhs}}})", - var_type=Mapping[ - Union[lhs._key_type(), rhs._key_type()], - Union[lhs._value_type(), rhs._value_type()], - ], + type_computer=nary_type_computer( + ReflexCallable[[Any, Any], Mapping[Any, Any]], + ReflexCallable[[Any], Mapping[Any, Any]], + computer=lambda args: Mapping[ + unionize(*[arg.to(ObjectVar)._key_type() for arg in args]), + unionize(*[arg.to(ObjectVar)._value_type() for arg in args]), + ], + ), + var_type=Mapping[Any, Any], ) @@ -536,7 +547,7 @@ class ObjectItemOperation(CachedVarOperation, Var): @var_operation -def object_has_own_property_operation(object: ObjectVar, key: Var): +def object_has_own_property_operation(object: Var, key: Var): """Check if an object has a key. Args: diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 0e7b082f9..b24590bf3 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -3,6 +3,7 @@ from __future__ import annotations import dataclasses +import functools import inspect import json import re @@ -10,42 +11,50 @@ import typing from typing import ( TYPE_CHECKING, Any, - Dict, + Callable, + ClassVar, List, - Literal, - NoReturn, Sequence, + Set, Tuple, Type, Union, - overload, + cast, ) -from typing_extensions import TypeVar +from typing_extensions import TypeAliasType, TypeVar from reflex import constants from reflex.constants.base import REFLEX_VAR_OPENING_TAG from reflex.constants.colors import Color -from reflex.utils.exceptions import VarTypeError +from reflex.utils.exceptions import UntypedVarError, VarTypeError from reflex.utils.types import GenericType, get_origin - -from .base import ( +from reflex.vars.base import ( CachedVarOperation, CustomVarOperationReturn, LiteralVar, + ReflexCallable, Var, VarData, + VarWithDefault, _global_vars, cached_property_no_lock, figure_out_type, get_python_literal, get_unique_variable_name, + nary_type_computer, + passthrough_unary_type_computer, unionize, + unwrap_reflex_callalbe, var_operation, var_operation_return, ) + from .number import ( - BooleanVar, + _AT_SLICE_IMPORT, + _AT_SLICE_OR_INDEX, + _IS_TRUE_IMPORT, + _RANGE_IMPORT, LiteralNumberVar, NumberVar, raise_unsupported_operand_types, @@ -53,355 +62,24 @@ from .number import ( ) if TYPE_CHECKING: - from .base import BASE_TYPE, DATACLASS_TYPE, SQLA_TYPE from .function import FunctionVar - from .object import ObjectVar STRING_TYPE = TypeVar("STRING_TYPE", default=str) +ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[Set, Tuple, Sequence]) +OTHER_ARRAY_VAR_TYPE = TypeVar( + "OTHER_ARRAY_VAR_TYPE", bound=Union[Set, Tuple, Sequence] +) +INNER_ARRAY_VAR = TypeVar("INNER_ARRAY_VAR", covariant=True) +ANOTHER_ARRAY_VAR = TypeVar("ANOTHER_ARRAY_VAR", covariant=True) -class StringVar(Var[STRING_TYPE], python_types=str): - """Base class for immutable string vars.""" - - @overload - def __add__(self, other: StringVar | str) -> ConcatVarOperation: ... - - @overload - def __add__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __add__(self, other: Any) -> ConcatVarOperation: - """Concatenate two strings. - - Args: - other: The other string. - - Returns: - The string concatenation operation. - """ - if not isinstance(other, (StringVar, str)): - raise_unsupported_operand_types("+", (type(self), type(other))) - - return ConcatVarOperation.create(self, other) - - @overload - def __radd__(self, other: StringVar | str) -> ConcatVarOperation: ... - - @overload - def __radd__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __radd__(self, other: Any) -> ConcatVarOperation: - """Concatenate two strings. - - Args: - other: The other string. - - Returns: - The string concatenation operation. - """ - if not isinstance(other, (StringVar, str)): - raise_unsupported_operand_types("+", (type(other), type(self))) - - return ConcatVarOperation.create(other, self) - - @overload - def __mul__(self, other: NumberVar | int) -> StringVar: ... - - @overload - def __mul__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __mul__(self, other: Any) -> StringVar: - """Multiply the sequence by a number or an integer. - - Args: - other: The number or integer to multiply the sequence by. - - Returns: - StringVar: The resulting sequence after multiplication. - """ - if not isinstance(other, (NumberVar, int)): - raise_unsupported_operand_types("*", (type(self), type(other))) - - return (self.split() * other).join() - - @overload - def __rmul__(self, other: NumberVar | int) -> StringVar: ... - - @overload - def __rmul__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __rmul__(self, other: Any) -> StringVar: - """Multiply the sequence by a number or an integer. - - Args: - other: The number or integer to multiply the sequence by. - - Returns: - StringVar: The resulting sequence after multiplication. - """ - if not isinstance(other, (NumberVar, int)): - raise_unsupported_operand_types("*", (type(other), type(self))) - - return (self.split() * other).join() - - @overload - def __getitem__(self, i: slice) -> StringVar: ... - - @overload - def __getitem__(self, i: int | NumberVar) -> StringVar: ... - - def __getitem__(self, i: Any) -> StringVar: - """Get a slice of the string. - - Args: - i: The slice. - - Returns: - The string slice operation. - """ - if isinstance(i, slice): - return self.split()[i].join() - if not isinstance(i, (int, NumberVar)) or ( - isinstance(i, NumberVar) and i._is_strict_float() - ): - raise_unsupported_operand_types("[]", (type(self), type(i))) - return string_item_operation(self, i) - - def length(self) -> NumberVar: - """Get the length of the string. - - Returns: - The string length operation. - """ - return self.split().length() - - def lower(self) -> StringVar: - """Convert the string to lowercase. - - Returns: - The string lower operation. - """ - return string_lower_operation(self) - - def upper(self) -> StringVar: - """Convert the string to uppercase. - - Returns: - The string upper operation. - """ - return string_upper_operation(self) - - def strip(self) -> StringVar: - """Strip the string. - - Returns: - The string strip operation. - """ - return string_strip_operation(self) - - def reversed(self) -> StringVar: - """Reverse the string. - - Returns: - The string reverse operation. - """ - return self.split().reverse().join() - - @overload - def contains( - self, other: StringVar | str, field: StringVar | str | None = None - ) -> BooleanVar: ... - - @overload - def contains( # pyright: ignore [reportOverlappingOverload] - self, other: NoReturn, field: StringVar | str | None = None - ) -> NoReturn: ... - - def contains(self, other: Any, field: Any = None) -> BooleanVar: - """Check if the string contains another string. - - Args: - other: The other string. - field: The field to check. - - Returns: - The string contains operation. - """ - if not isinstance(other, (StringVar, str)): - raise_unsupported_operand_types("contains", (type(self), type(other))) - if field is not None: - if not isinstance(field, (StringVar, str)): - raise_unsupported_operand_types("contains", (type(self), type(field))) - return string_contains_field_operation(self, other, field) - return string_contains_operation(self, other) - - @overload - def split(self, separator: StringVar | str = "") -> ArrayVar[List[str]]: ... - - @overload - def split(self, separator: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def split(self, separator: Any = "") -> ArrayVar[List[str]]: - """Split the string. - - Args: - separator: The separator. - - Returns: - The string split operation. - """ - if not isinstance(separator, (StringVar, str)): - raise_unsupported_operand_types("split", (type(self), type(separator))) - return string_split_operation(self, separator) - - @overload - def startswith(self, prefix: StringVar | str) -> BooleanVar: ... - - @overload - def startswith(self, prefix: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def startswith(self, prefix: Any) -> BooleanVar: - """Check if the string starts with a prefix. - - Args: - prefix: The prefix. - - Returns: - The string starts with operation. - """ - if not isinstance(prefix, (StringVar, str)): - raise_unsupported_operand_types("startswith", (type(self), type(prefix))) - return string_starts_with_operation(self, prefix) - - @overload - def endswith(self, suffix: StringVar | str) -> BooleanVar: ... - - @overload - def endswith(self, suffix: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def endswith(self, suffix: Any) -> BooleanVar: - """Check if the string ends with a suffix. - - Args: - suffix: The suffix. - - Returns: - The string ends with operation. - """ - if not isinstance(suffix, (StringVar, str)): - raise_unsupported_operand_types("endswith", (type(self), type(suffix))) - return string_ends_with_operation(self, suffix) - - @overload - def __lt__(self, other: StringVar | str) -> BooleanVar: ... - - @overload - def __lt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __lt__(self, other: Any): - """Check if the string is less than another string. - - Args: - other: The other string. - - Returns: - The string less than operation. - """ - if not isinstance(other, (StringVar, str)): - raise_unsupported_operand_types("<", (type(self), type(other))) - - return string_lt_operation(self, other) - - @overload - def __gt__(self, other: StringVar | str) -> BooleanVar: ... - - @overload - def __gt__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __gt__(self, other: Any): - """Check if the string is greater than another string. - - Args: - other: The other string. - - Returns: - The string greater than operation. - """ - if not isinstance(other, (StringVar, str)): - raise_unsupported_operand_types(">", (type(self), type(other))) - - return string_gt_operation(self, other) - - @overload - def __le__(self, other: StringVar | str) -> BooleanVar: ... - - @overload - def __le__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __le__(self, other: Any): - """Check if the string is less than or equal to another string. - - Args: - other: The other string. - - Returns: - The string less than or equal operation. - """ - if not isinstance(other, (StringVar, str)): - raise_unsupported_operand_types("<=", (type(self), type(other))) - - return string_le_operation(self, other) - - @overload - def __ge__(self, other: StringVar | str) -> BooleanVar: ... - - @overload - def __ge__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __ge__(self, other: Any): - """Check if the string is greater than or equal to another string. - - Args: - other: The other string. - - Returns: - The string greater than or equal operation. - """ - if not isinstance(other, (StringVar, str)): - raise_unsupported_operand_types(">=", (type(self), type(other))) - - return string_ge_operation(self, other) - - @overload - def replace( # pyright: ignore [reportOverlappingOverload] - self, search_value: StringVar | str, new_value: StringVar | str - ) -> StringVar: ... - - @overload - def replace( - self, search_value: Any, new_value: Any - ) -> CustomVarOperationReturn[StringVar]: ... - - def replace(self, search_value: Any, new_value: Any) -> StringVar: # pyright: ignore [reportInconsistentOverload] - """Replace a string with a value. - - Args: - search_value: The string to search. - new_value: The value to be replaced with. - - Returns: - The string replace operation. - """ - if not isinstance(search_value, (StringVar, str)): - raise_unsupported_operand_types("replace", (type(self), type(search_value))) - if not isinstance(new_value, (StringVar, str)): - raise_unsupported_operand_types("replace", (type(self), type(new_value))) - - return string_replace_operation(self, search_value, new_value) +KEY_TYPE = TypeVar("KEY_TYPE") +VALUE_TYPE = TypeVar("VALUE_TYPE") @var_operation -def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_lt_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is less than another string. Args: @@ -411,11 +89,11 @@ def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): Returns: The string less than operation. """ - return var_operation_return(js_expression=f"{lhs} < {rhs}", var_type=bool) + return var_operation_return(js_expression=f"({lhs} < {rhs})", var_type=bool) @var_operation -def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_gt_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is greater than another string. Args: @@ -425,11 +103,11 @@ def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): Returns: The string greater than operation. """ - return var_operation_return(js_expression=f"{lhs} > {rhs}", var_type=bool) + return var_operation_return(js_expression=f"({lhs} > {rhs})", var_type=bool) @var_operation -def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_le_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is less than or equal to another string. Args: @@ -439,11 +117,11 @@ def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): Returns: The string less than or equal operation. """ - return var_operation_return(js_expression=f"{lhs} <= {rhs}", var_type=bool) + return var_operation_return(js_expression=f"({lhs} <= {rhs})", var_type=bool) @var_operation -def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_ge_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is greater than or equal to another string. Args: @@ -453,11 +131,11 @@ def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): Returns: The string greater than or equal operation. """ - return var_operation_return(js_expression=f"{lhs} >= {rhs}", var_type=bool) + return var_operation_return(js_expression=f"({lhs} >= {rhs})", var_type=bool) @var_operation -def string_lower_operation(string: StringVar[Any]): +def string_lower_operation(string: Var[str]): """Convert a string to lowercase. Args: @@ -466,11 +144,15 @@ def string_lower_operation(string: StringVar[Any]): Returns: The lowercase string. """ - return var_operation_return(js_expression=f"{string}.toLowerCase()", var_type=str) + return var_operation_return( + js_expression=f"String.prototype.toLowerCase.apply({string})", + var_type=str, + _raw_js_function="String.prototype.toLowerCase.apply", + ) @var_operation -def string_upper_operation(string: StringVar[Any]): +def string_upper_operation(string: Var[str]): """Convert a string to uppercase. Args: @@ -479,11 +161,15 @@ def string_upper_operation(string: StringVar[Any]): Returns: The uppercase string. """ - return var_operation_return(js_expression=f"{string}.toUpperCase()", var_type=str) + return var_operation_return( + js_expression=f"String.prototype.toUpperCase.apply({string})", + var_type=str, + _raw_js_function="String.prototype.toUpperCase.apply", + ) @var_operation -def string_strip_operation(string: StringVar[Any]): +def string_strip_operation(string: Var[str]): """Strip a string. Args: @@ -492,31 +178,38 @@ def string_strip_operation(string: StringVar[Any]): Returns: The stripped string. """ - return var_operation_return(js_expression=f"{string}.trim()", var_type=str) + return var_operation_return( + js_expression=f"String.prototype.trim.apply({string})", + var_type=str, + _raw_js_function="String.prototype.trim.apply", + ) @var_operation def string_contains_field_operation( - haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str + haystack: Var[str], + needle: Var[str], ): """Check if a string contains another string. Args: haystack: The haystack. needle: The needle. - field: The field to check. Returns: The string contains operation. """ return var_operation_return( - js_expression=f"{haystack}.some(obj => obj[{field}] === {needle})", + js_expression=f"{haystack}.includes({needle})", var_type=bool, + var_data=VarData( + imports=_IS_TRUE_IMPORT, + ), ) @var_operation -def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str): +def string_contains_operation(haystack: Var[str], needle: Var[str]): """Check if a string contains another string. Args: @@ -532,9 +225,7 @@ def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | @var_operation -def string_starts_with_operation( - full_string: StringVar[Any], prefix: StringVar[Any] | str -): +def string_starts_with_operation(full_string: Var[str], prefix: Var[str]): """Check if a string starts with a prefix. Args: @@ -550,9 +241,7 @@ def string_starts_with_operation( @var_operation -def string_ends_with_operation( - full_string: StringVar[Any], suffix: StringVar[Any] | str -): +def string_ends_with_operation(full_string: Var[str], suffix: Var[str]): """Check if a string ends with a suffix. Args: @@ -568,7 +257,7 @@ def string_ends_with_operation( @var_operation -def string_item_operation(string: StringVar[Any], index: NumberVar | int): +def string_item_operation(string: Var[str], index: Var[int]): """Get an item from a string. Args: @@ -582,22 +271,61 @@ def string_item_operation(string: StringVar[Any], index: NumberVar | int): @var_operation -def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""): - """Join the elements of an array. +def string_slice_operation( + string: Var[str], slice: Var[slice] +) -> CustomVarOperationReturn[str]: + """Get a slice from a string. Args: - array: The array. - sep: The separator. + string: The string. + slice: The slice. Returns: - The joined elements. + The sliced string. """ - return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str) + return var_operation_return( + js_expression=f'atSlice({string}.split(""), {slice}).join("")', + type_computer=nary_type_computer( + ReflexCallable[[List[str], slice], str], + ReflexCallable[[slice], str], + computer=lambda args: str, + ), + var_data=VarData( + imports=_AT_SLICE_IMPORT, + ), + ) + + +@var_operation +def string_index_or_slice_operation( + string: Var[str], index_or_slice: Var[Union[int, slice]] +) -> CustomVarOperationReturn[Union[str, Sequence[str]]]: + """Get an item or slice from a string. + + Args: + string: The string. + index_or_slice: The index or slice. + + Returns: + The item or slice from the string. + """ + return var_operation_return( + js_expression=f"Array.prototype.join.apply(atSliceOrIndex({string}, {index_or_slice}), [''])", + _raw_js_function="atSliceOrIndex", + type_computer=nary_type_computer( + ReflexCallable[[List[str], Union[int, slice]], str], + ReflexCallable[[Union[int, slice]], str], + computer=lambda args: str, + ), + var_data=VarData( + imports=_AT_SLICE_OR_INDEX, + ), + ) @var_operation def string_replace_operation( - string: StringVar[Any], search_value: StringVar | str, new_value: StringVar | str + string: Var[str], search_value: Var[str], new_value: Var[str] ): """Replace a string with a value. @@ -615,6 +343,908 @@ def string_replace_operation( ) +@var_operation +def array_pluck_operation( + array: Var[Sequence[Any]], + field: Var[str], +) -> CustomVarOperationReturn[Sequence[Any]]: + """Pluck a field from an array of objects. + + Args: + array: The array to pluck from. + field: The field to pluck from the objects in the array. + + Returns: + The reversed array. + """ + return var_operation_return( + js_expression=f"Array.prototype.map.apply({array}, [e=>e?.[{field}]])", + var_type=List[Any], + ) + + +@var_operation +def array_join_operation( + array: Var[Sequence[Any]], sep: VarWithDefault[str] = VarWithDefault("") +): + """Join the elements of an array. + + Args: + array: The array. + sep: The separator. + + Returns: + The joined elements. + """ + return var_operation_return( + js_expression=f"Array.prototype.join.apply({array},[{sep}])", var_type=str + ) + + +@var_operation +def array_reverse_operation( + array: Var[Sequence[INNER_ARRAY_VAR]], +) -> CustomVarOperationReturn[Sequence[INNER_ARRAY_VAR]]: + """Reverse an array. + + Args: + array: The array to reverse. + + Returns: + The reversed array. + """ + return var_operation_return( + js_expression=f"{array}.slice().reverse()", + type_computer=passthrough_unary_type_computer(ReflexCallable[[List], List]), + ) + + +@var_operation +def array_lt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): + """Check if an array is less than another array. + + Args: + lhs: The left-hand side array. + rhs: The right-hand side array. + + Returns: + The array less than operation. + """ + return var_operation_return(js_expression=f"{lhs} < {rhs}", var_type=bool) + + +@var_operation +def array_gt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): + """Check if an array is greater than another array. + + Args: + lhs: The left-hand side array. + rhs: The right-hand side array. + + Returns: + The array greater than operation. + """ + return var_operation_return(js_expression=f"{lhs} > {rhs}", var_type=bool) + + +@var_operation +def array_le_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): + """Check if an array is less than or equal to another array. + + Args: + lhs: The left-hand side array. + rhs: The right-hand side array. + + Returns: + The array less than or equal operation. + """ + return var_operation_return(js_expression=f"{lhs} <= {rhs}", var_type=bool) + + +@var_operation +def array_ge_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): + """Check if an array is greater than or equal to another array. + + Args: + lhs: The left-hand side array. + rhs: The right-hand side array. + + Returns: + The array greater than or equal operation. + """ + return var_operation_return(js_expression=f"{lhs} >= {rhs}", var_type=bool) + + +@var_operation +def array_length_operation(array: Var[ARRAY_VAR_TYPE]): + """Get the length of an array. + + Args: + array: The array. + + Returns: + The length of the array. + """ + return var_operation_return( + js_expression=f"{array}.length", + var_type=int, + ) + + +@var_operation +def string_split_operation( + string: Var[STRING_TYPE], sep: VarWithDefault[STRING_TYPE] = VarWithDefault("") +): + """Split a string. + + Args: + string: The string to split. + sep: The separator. + + Returns: + The split string. + """ + return var_operation_return( + js_expression=f"isTrue({sep}) ? {string}.split({sep}) : [...{string}]", + var_type=Sequence[str], + var_data=VarData(imports=_IS_TRUE_IMPORT), + ) + + +def _element_type(array: Var, index: Var) -> Any: + array_args = typing.get_args(array._var_type) + + if ( + array_args + and isinstance(index, LiteralNumberVar) + and is_tuple_type(array._var_type) + ): + index_value = int(index._var_value) + return array_args[index_value % len(array_args)] + + return unionize(*(array_arg for array_arg in array_args if array_arg is not ...)) + + +@var_operation +def array_item_or_slice_operation( + array: Var[Sequence[INNER_ARRAY_VAR]], + index_or_slice: Var[Union[int, slice]], +) -> CustomVarOperationReturn[Union[INNER_ARRAY_VAR, Sequence[INNER_ARRAY_VAR]]]: + """Get an item or slice from an array. + + Args: + array: The array. + index_or_slice: The index or slice. + + Returns: + The item or slice from the array. + """ + return var_operation_return( + js_expression=f"atSliceOrIndex({array}, {index_or_slice})", + _raw_js_function="atSliceOrIndex", + type_computer=nary_type_computer( + ReflexCallable[[Sequence, Union[int, slice]], Any], + ReflexCallable[[Union[int, slice]], Any], + computer=lambda args: ( + args[0]._var_type + if args[1]._var_type is slice + else (_element_type(args[0], args[1])) + ), + ), + var_data=VarData( + imports=_AT_SLICE_OR_INDEX, + ), + ) + + +@var_operation +def array_slice_operation( + array: Var[Sequence[INNER_ARRAY_VAR]], + slice: Var[slice], +) -> CustomVarOperationReturn[Sequence[INNER_ARRAY_VAR]]: + """Get a slice from an array. + + Args: + array: The array. + slice: The slice. + + Returns: + The item or slice from the array. + """ + return var_operation_return( + js_expression=f"atSlice({array}, {slice})", + type_computer=nary_type_computer( + ReflexCallable[[List, slice], Any], + ReflexCallable[[slice], Any], + computer=lambda args: args[0]._var_type, + ), + var_data=VarData( + imports=_AT_SLICE_IMPORT, + ), + ) + + +@var_operation +def array_item_operation( + array: Var[Sequence[INNER_ARRAY_VAR]], index: Var[int] +) -> CustomVarOperationReturn[INNER_ARRAY_VAR]: + """Get an item from an array. + + Args: + array: The array. + index: The index of the item. + + Returns: + The item from the array. + """ + + def type_computer(*args): + if len(args) == 0: + return ( + ReflexCallable[[List[Any], int], Any], + functools.partial(type_computer, *args), + ) + + array = args[0] + array_args = typing.get_args(array._var_type) + + if len(args) == 1: + return ( + ReflexCallable[[int], unionize(*array_args)], + functools.partial(type_computer, *args), + ) + + index = args[1] + + if ( + array_args + and isinstance(index, LiteralNumberVar) + and is_tuple_type(array._var_type) + ): + index_value = int(index._var_value) + element_type = array_args[index_value % len(array_args)] + else: + element_type = unionize(*array_args) + + return (ReflexCallable[[], element_type], None) + + return var_operation_return( + js_expression=f"{array}.at({index})", + type_computer=type_computer, + ) + + +@var_operation +def array_range_operation( + e1: Var[int], + e2: VarWithDefault[int | None] = VarWithDefault(None), + step: VarWithDefault[int] = VarWithDefault(1), +) -> CustomVarOperationReturn[Sequence[int]]: + """Create a range of numbers. + + Args: + e1: The end of the range if e2 is not provided, otherwise the start of the range. + e2: The end of the range. + step: The step of the range. + + Returns: + The range of numbers. + """ + return var_operation_return( + js_expression=f"[...range({e1}, {e2}, {step})]", + var_type=List[int], + var_data=VarData( + imports=_RANGE_IMPORT, + ), + ) + + +@var_operation +def array_contains_field_operation( + haystack: Var[ARRAY_VAR_TYPE], + needle: Var[Any], + field: VarWithDefault[str] = VarWithDefault(""), +): + """Check if an array contains an element. + + Args: + haystack: The array to check. + needle: The element to check for. + field: The field to check. + + Returns: + The array contains operation. + """ + return var_operation_return( + js_expression=f"isTrue({field}) ? {haystack}.some(obj => obj[{field}] === {needle}) : {haystack}.some(obj => obj === {needle})", + var_type=bool, + var_data=VarData( + imports=_IS_TRUE_IMPORT, + ), + ) + + +@var_operation +def array_contains_operation(haystack: Var[ARRAY_VAR_TYPE], needle: Var): + """Check if an array contains an element. + + Args: + haystack: The array to check. + needle: The element to check for. + + Returns: + The array contains operation. + """ + return var_operation_return( + js_expression=f"{haystack}.includes({needle})", + var_type=bool, + ) + + +@var_operation +def repeat_array_operation( + array: Var[Sequence[INNER_ARRAY_VAR]], count: Var[int] +) -> CustomVarOperationReturn[Sequence[INNER_ARRAY_VAR]]: + """Repeat an array a number of times. + + Args: + array: The array to repeat. + count: The number of times to repeat the array. + + Returns: + The repeated array. + """ + + def type_computer(*args: Var): + if not args: + return ( + ReflexCallable[[List[Any], int], List[Any]], + type_computer, + ) + if len(args) == 1: + return ( + ReflexCallable[[int], args[0]._var_type], + functools.partial(type_computer, *args), + ) + return (ReflexCallable[[], args[0]._var_type], None) + + return var_operation_return( + js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})", + type_computer=type_computer, + ) + + +@var_operation +def repeat_string_operation( + string: Var[str], count: Var[int] +) -> CustomVarOperationReturn[str]: + """Repeat a string a number of times. + + Args: + string: The string to repeat. + count: The number of times to repeat the string. + + Returns: + The repeated string. + """ + return var_operation_return( + js_expression=f"{string}.repeat({count})", + var_type=str, + ) + + +if TYPE_CHECKING: + pass + + +@var_operation +def map_array_operation( + array: Var[Sequence[INNER_ARRAY_VAR]], + function: Var[ + ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR] + | ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR] + | ReflexCallable[[], ANOTHER_ARRAY_VAR] + ], +) -> CustomVarOperationReturn[Sequence[ANOTHER_ARRAY_VAR]]: + """Map a function over an array. + + Args: + array: The array. + function: The function to map. + + Returns: + The mapped array. + """ + + def type_computer(*args: Var): + if not args: + return ( + ReflexCallable[[List[Any], ReflexCallable], List[Any]], + type_computer, + ) + if len(args) == 1: + return ( + ReflexCallable[[ReflexCallable], List[Any]], + functools.partial(type_computer, *args), + ) + return (ReflexCallable[[], List[args[0]._var_type]], None) + + return var_operation_return( + js_expression=f"Array.prototype.map.apply({array}, [{function}])", + type_computer=nary_type_computer( + ReflexCallable[[List[Any], ReflexCallable], List[Any]], + ReflexCallable[[ReflexCallable], List[Any]], + computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]], + ), + ) + + +@var_operation +def array_concat_operation( + lhs: Var[Sequence[INNER_ARRAY_VAR]], rhs: Var[Sequence[ANOTHER_ARRAY_VAR]] +) -> CustomVarOperationReturn[Sequence[INNER_ARRAY_VAR | ANOTHER_ARRAY_VAR]]: + """Concatenate two arrays. + + Args: + lhs: The left-hand side array. + rhs: The right-hand side array. + + Returns: + The concatenated array. + """ + return var_operation_return( + js_expression=f"[...{lhs}, ...{rhs}]", + type_computer=nary_type_computer( + ReflexCallable[[List[Any], List[Any]], List[Any]], + ReflexCallable[[List[Any]], List[Any]], + computer=lambda args: unionize(args[0]._var_type, args[1]._var_type), + ), + ) + + +@var_operation +def string_concat_operation( + lhs: Var[str], rhs: Var[str] +) -> CustomVarOperationReturn[str]: + """Concatenate two strings. + + Args: + lhs: The left-hand side string. + rhs: The right-hand side string. + + Returns: + The concatenated string. + """ + return var_operation_return( + js_expression=f"{lhs} + {rhs}", + var_type=str, + ) + + +@var_operation +def reverse_string_concat_operation( + lhs: Var[str], rhs: Var[str] +) -> CustomVarOperationReturn[str]: + """Concatenate two strings in reverse order. + + Args: + lhs: The left-hand side string. + rhs: The right-hand side string. + + Returns: + The concatenated string. + """ + return var_operation_return( + js_expression=f"{rhs} + {lhs}", + var_type=str, + ) + + +class SliceVar(Var[slice], python_types=slice): + """Base class for immutable slice vars.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + slots=True, +) +class LiteralSliceVar(CachedVarOperation, LiteralVar, SliceVar): + """Base class for immutable literal slice vars.""" + + _var_value: slice = dataclasses.field(default_factory=lambda: slice(None)) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"[{LiteralVar.create(self._var_value.start)!s}, {LiteralVar.create(self._var_value.stop)!s}, {LiteralVar.create(self._var_value.step)!s}]" + + @cached_property_no_lock + def _cached_get_all_var_data(self) -> VarData | None: + """Get all the VarData asVarDatae Var. + + Returns: + The VarData associated with the Var. + """ + return VarData.merge( + *[ + var._get_all_var_data() + for var in [ + self._var_value.start, + self._var_value.stop, + self._var_value.step, + ] + if isinstance(var, Var) + ], + self._var_data, + ) + + @classmethod + def create( + cls, + value: slice, + _var_type: Type[slice] | None = None, + _var_data: VarData | None = None, + ) -> SliceVar: + """Create a var from a slice value. + + Args: + value: The value to create the var from. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _js_expr="", + _var_type=slice if _var_type is None else _var_type, + _var_data=_var_data, + _var_value=value, + ) + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash( + ( + self.__class__.__name__, + self._var_value.start, + self._var_value.stop, + self._var_value.step, + ) + ) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return json.dumps( + [self._var_value.start, self._var_value.stop, self._var_value.step] + ) + + +class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(Sequence, set)): + """Base class for immutable array vars.""" + + join = array_join_operation + + reverse = array_reverse_operation + + __add__ = array_concat_operation + + __getitem__ = array_item_or_slice_operation + + at = array_item_operation + + slice = array_slice_operation + + length = array_length_operation + + range: ClassVar[ + FunctionVar[ + ReflexCallable[ + [int, VarWithDefault[int | None], VarWithDefault[int]], Sequence[int] + ] + ] + ] = array_range_operation + + contains = array_contains_field_operation + + pluck = array_pluck_operation + + __rmul__ = __mul__ = repeat_array_operation + + __lt__ = array_lt_operation + + __gt__ = array_gt_operation + + __le__ = array_le_operation + + __ge__ = array_ge_operation + + def foreach( + self: ArrayVar[Sequence[INNER_ARRAY_VAR]], + fn: Callable[[Var[INNER_ARRAY_VAR], NumberVar[int]], ANOTHER_ARRAY_VAR] + | Callable[[Var[INNER_ARRAY_VAR]], ANOTHER_ARRAY_VAR] + | Callable[[], ANOTHER_ARRAY_VAR], + ) -> ArrayVar[Sequence[ANOTHER_ARRAY_VAR]]: + """Apply a function to each element of the array. + + Args: + fn: The function to apply. + + Returns: + The array after applying the function. + + Raises: + VarTypeError: If the function takes more than one argument. + TypeError: If the function is a ComponentState. + """ + from reflex.state import ComponentState + + from .function import ArgsFunctionOperation + + if not callable(fn): + raise_unsupported_operand_types("foreach", (type(self), type(fn))) + + # get the number of arguments of the function + required_num_args = len( + [ + p + for p in inspect.signature(fn).parameters.values() + if p.default == p.empty + ] + ) + if required_num_args > 2: + raise VarTypeError( + "The function passed to foreach should take at most two arguments." + ) + + num_args = len(inspect.signature(fn).parameters) + + if ( + hasattr(fn, "__qualname__") + and fn.__qualname__ == ComponentState.create.__qualname__ + ): + raise TypeError( + "Using a ComponentState as `render_fn` inside `rx.foreach` is not supported yet." + ) + + def invoke_fn(*args): + try: + return fn(*args) + except UntypedVarError as e: + raise UntypedVarError( + f"Could not foreach over var `{self!s}` without a type annotation. " + "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" + ) from e + + if num_args == 0: + fn_result = invoke_fn() + return_value = Var.create(fn_result) + simple_function_var: FunctionVar[ReflexCallable[[], ANOTHER_ARRAY_VAR]] = ( + ArgsFunctionOperation.create( + (), + return_value, + _var_type=ReflexCallable[[], return_value._var_type], + ) + ) + return map_array_operation(self, simple_function_var).guess_type() + + # generic number var + number_var = Var("").to(NumberVar, int) + + first_arg_type = self.__getitem__(number_var)._var_type + + arg_name = get_unique_variable_name() + + # get first argument type + first_arg = cast( + Var[Any], + Var( + _js_expr=arg_name, + _var_type=first_arg_type, + ).guess_type(), + ) + + if required_num_args < 2: + fn_result = invoke_fn(first_arg) + + return_value = Var.create(fn_result) + + function_var = cast( + Var[ReflexCallable[[INNER_ARRAY_VAR], ANOTHER_ARRAY_VAR]], + ArgsFunctionOperation.create( + (arg_name,), + return_value, + _var_type=ReflexCallable[[first_arg_type], return_value._var_type], + ), + ) + + return map_array_operation.call(self, function_var).guess_type() + + second_arg = cast( + NumberVar[int], + Var( + _js_expr=get_unique_variable_name(), + _var_type=int, + ).guess_type(), + ) + + fn_result = invoke_fn(first_arg, second_arg) + + return_value = Var.create(fn_result) + + function_var = cast( + Var[ReflexCallable[[INNER_ARRAY_VAR, int], ANOTHER_ARRAY_VAR]], + ArgsFunctionOperation.create( + (arg_name, second_arg._js_expr), + return_value, + _var_type=ReflexCallable[[first_arg_type, int], return_value._var_type], + ), + ) + + return map_array_operation.call(self, function_var).guess_type() + + +LIST_ELEMENT = TypeVar("LIST_ELEMENT", covariant=True) + +ARRAY_VAR_OF_LIST_ELEMENT = TypeAliasType( + "ARRAY_VAR_OF_LIST_ELEMENT", + Union[ + ArrayVar[Sequence[LIST_ELEMENT]], + ArrayVar[Set[LIST_ELEMENT]], + ], + type_params=(LIST_ELEMENT,), +) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + slots=True, +) +class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): + """Base class for immutable literal array vars.""" + + _var_value: Union[ + Sequence[Union[Var, Any]], + Set[Union[Var, Any]], + ] = dataclasses.field(default_factory=list) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "[" + + ", ".join( + [str(LiteralVar.create(element)) for element in self._var_value] + ) + + "]" + ) + + @cached_property_no_lock + def _cached_get_all_var_data(self) -> VarData | None: + """Get all the VarData associated with the Var. + + Returns: + The VarData associated with the Var. + """ + return VarData.merge( + *[ + LiteralVar.create(element)._get_all_var_data() + for element in self._var_value + ], + self._var_data, + ) + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._js_expr)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + + Raises: + TypeError: If the array elements are not of type LiteralVar. + """ + elements = [] + for element in self._var_value: + element_var = LiteralVar.create(element) + if not isinstance(element_var, LiteralVar): + raise TypeError( + f"Array elements must be of type LiteralVar, not {type(element_var)}" + ) + elements.append(element_var.json()) + + return "[" + ", ".join(elements) + "]" + + @classmethod + def create( + cls, + value: OTHER_ARRAY_VAR_TYPE, + _var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None, + _var_data: VarData | None = None, + ) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return LiteralArrayVar( + _js_expr="", + _var_type=figure_out_type(value) if _var_type is None else _var_type, + _var_data=_var_data, + _var_value=value, + ) + + +class StringVar(Var[STRING_TYPE], python_types=str): + """Base class for immutable string vars.""" + + __add__ = string_concat_operation + + __radd__ = reverse_string_concat_operation + + __getitem__ = string_index_or_slice_operation + + at = string_item_operation + + slice = string_slice_operation + + lower = string_lower_operation + + upper = string_upper_operation + + strip = string_strip_operation + + contains = string_contains_field_operation + + split = string_split_operation + + length = split.chain(array_length_operation) + + reversed = split.chain(array_reverse_operation).chain(array_join_operation) + + startswith = string_starts_with_operation + + __rmul__ = __mul__ = repeat_string_operation + + __lt__ = string_lt_operation + + __gt__ = string_gt_operation + + __le__ = string_le_operation + + __ge__ = string_ge_operation + + # Compile regex for finding reflex var tags. _decode_var_pattern_re = ( rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" @@ -733,7 +1363,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]): Returns: The hash of the var. """ - return hash((type(self).__name__, self._var_value)) + return hash((self.__class__.__name__, self._var_value)) def json(self) -> str: """Get the JSON representation of the var. @@ -809,7 +1439,7 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]): """Create a var from a string value. Args: - *value: The values to concatenate. + value: The values to concatenate. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -823,763 +1453,6 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]): ) -ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Sequence, covariant=True) -OTHER_ARRAY_VAR_TYPE = TypeVar("OTHER_ARRAY_VAR_TYPE", bound=Sequence) - -OTHER_TUPLE = TypeVar("OTHER_TUPLE") - -INNER_ARRAY_VAR = TypeVar("INNER_ARRAY_VAR") - -KEY_TYPE = TypeVar("KEY_TYPE") -VALUE_TYPE = TypeVar("VALUE_TYPE") - - -class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): - """Base class for immutable array vars.""" - - @overload - def join(self, sep: StringVar | str = "") -> StringVar: ... - - @overload - def join(self, sep: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def join(self, sep: Any = "") -> StringVar: - """Join the elements of the array. - - Args: - sep: The separator between elements. - - Returns: - The joined elements. - """ - if not isinstance(sep, (StringVar, str)): - raise_unsupported_operand_types("join", (type(self), type(sep))) - if ( - isinstance(self, LiteralArrayVar) - and ( - len( - args := [ - x - for x in self._var_value - if isinstance(x, (LiteralStringVar, str)) - ] - ) - == len(self._var_value) - ) - and isinstance(sep, (LiteralStringVar, str)) - ): - sep_str = sep._var_value if isinstance(sep, LiteralStringVar) else sep - return LiteralStringVar.create( - sep_str.join( - i._var_value if isinstance(i, LiteralStringVar) else i for i in args - ) - ) - return array_join_operation(self, sep) - - def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]: - """Reverse the array. - - Returns: - The reversed array. - """ - return array_reverse_operation(self) - - @overload - def __add__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> ArrayVar[ARRAY_VAR_TYPE]: ... - - @overload - def __add__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __add__(self, other: Any) -> ArrayVar[ARRAY_VAR_TYPE]: - """Concatenate two arrays. - - Parameters: - other: The other array to concatenate. - - Returns: - ArrayConcatOperation: The concatenation of the two arrays. - """ - if not isinstance(other, ArrayVar): - raise_unsupported_operand_types("+", (type(self), type(other))) - - return array_concat_operation(self, other) - - @overload - def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ... - - @overload - def __getitem__( - self: ( - ArrayVar[Tuple[int, OTHER_TUPLE]] - | ArrayVar[Tuple[float, OTHER_TUPLE]] - | ArrayVar[Tuple[int | float, OTHER_TUPLE]] - ), - i: Literal[0, -2], - ) -> NumberVar: ... - - @overload - def __getitem__( - self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] - ) -> BooleanVar: ... - - @overload - def __getitem__( - self: ( - ArrayVar[Tuple[Any, int]] - | ArrayVar[Tuple[Any, float]] - | ArrayVar[Tuple[Any, int | float]] - ), - i: Literal[1, -1], - ) -> NumberVar: ... - - @overload - def __getitem__( - self: ArrayVar[Tuple[str, Any]], i: Literal[0, -2] - ) -> StringVar: ... - - @overload - def __getitem__( - self: ArrayVar[Tuple[Any, str]], i: Literal[1, -1] - ) -> StringVar: ... - - @overload - def __getitem__( - self: ArrayVar[Tuple[bool, Any]], i: Literal[0, -2] - ) -> BooleanVar: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar - ) -> BooleanVar: ... - - @overload - def __getitem__( - self: ( - ARRAY_VAR_OF_LIST_ELEMENT[int] - | ARRAY_VAR_OF_LIST_ELEMENT[float] - | ARRAY_VAR_OF_LIST_ELEMENT[int | float] - ), - i: int | NumberVar, - ) -> NumberVar: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar - ) -> StringVar: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]], - i: int | NumberVar, - ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]], - i: int | NumberVar, - ) -> ArrayVar[Tuple[KEY_TYPE, VALUE_TYPE]]: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[INNER_ARRAY_VAR, ...]], - i: int | NumberVar, - ) -> ArrayVar[Tuple[INNER_ARRAY_VAR, ...]]: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[Dict[KEY_TYPE, VALUE_TYPE]], - i: int | NumberVar, - ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[BASE_TYPE], - i: int | NumberVar, - ) -> ObjectVar[BASE_TYPE]: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[SQLA_TYPE], - i: int | NumberVar, - ) -> ObjectVar[SQLA_TYPE]: ... - - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[DATACLASS_TYPE], - i: int | NumberVar, - ) -> ObjectVar[DATACLASS_TYPE]: ... - - @overload - def __getitem__(self, i: int | NumberVar) -> Var: ... - - def __getitem__(self, i: Any) -> ArrayVar[ARRAY_VAR_TYPE] | Var: - """Get a slice of the array. - - Args: - i: The slice. - - Returns: - The array slice operation. - """ - if isinstance(i, slice): - return ArraySliceOperation.create(self, i) - if not isinstance(i, (int, NumberVar)) or ( - isinstance(i, NumberVar) and i._is_strict_float() - ): - raise_unsupported_operand_types("[]", (type(self), type(i))) - return array_item_operation(self, i) - - def length(self) -> NumberVar[int]: - """Get the length of the array. - - Returns: - The length of the array. - """ - return array_length_operation(self) - - @overload - @classmethod - def range(cls, stop: int | NumberVar, /) -> ArrayVar[List[int]]: ... - - @overload - @classmethod - def range( - cls, - start: int | NumberVar, - end: int | NumberVar, - step: int | NumberVar = 1, - /, - ) -> ArrayVar[List[int]]: ... - - @overload - @classmethod - def range( - cls, - first_endpoint: int | NumberVar, - second_endpoint: int | NumberVar | None = None, - step: int | NumberVar | None = None, - ) -> ArrayVar[List[int]]: ... - - @classmethod - def range( - cls, - first_endpoint: int | NumberVar, - second_endpoint: int | NumberVar | None = None, - step: int | NumberVar | None = None, - ) -> ArrayVar[List[int]]: - """Create a range of numbers. - - Args: - first_endpoint: The end of the range if second_endpoint is not provided, otherwise the start of the range. - second_endpoint: The end of the range. - step: The step of the range. - - Returns: - The range of numbers. - """ - if any( - not isinstance(i, (int, NumberVar)) - for i in (first_endpoint, second_endpoint, step) - if i is not None - ): - raise_unsupported_operand_types( - "range", (type(first_endpoint), type(second_endpoint), type(step)) - ) - if second_endpoint is None: - start = 0 - end = first_endpoint - else: - start = first_endpoint - end = second_endpoint - - return array_range_operation(start, end, step or 1) - - @overload - def contains(self, other: Any) -> BooleanVar: ... - - @overload - def contains(self, other: Any, field: StringVar | str) -> BooleanVar: ... - - def contains(self, other: Any, field: Any = None) -> BooleanVar: - """Check if the array contains an element. - - Args: - other: The element to check for. - field: The field to check. - - Returns: - The array contains operation. - """ - if field is not None: - if not isinstance(field, (StringVar, str)): - raise_unsupported_operand_types("contains", (type(self), type(field))) - return array_contains_field_operation(self, other, field) - return array_contains_operation(self, other) - - def pluck(self, field: StringVar | str) -> ArrayVar: - """Pluck a field from the array. - - Args: - field: The field to pluck from the array. - - Returns: - The array pluck operation. - """ - return array_pluck_operation(self, field) - - @overload - def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ... - - @overload - def __mul__(self, other: NoReturn) -> NoReturn: ... # pyright: ignore [reportOverlappingOverload] - - def __mul__(self, other: Any) -> ArrayVar[ARRAY_VAR_TYPE]: - """Multiply the sequence by a number or integer. - - Parameters: - other: The number or integer to multiply the sequence by. - - Returns: - ArrayVar[ARRAY_VAR_TYPE]: The result of multiplying the sequence by the given number or integer. - """ - if not isinstance(other, (NumberVar, int)) or ( - isinstance(other, NumberVar) and other._is_strict_float() - ): - raise_unsupported_operand_types("*", (type(self), type(other))) - - return repeat_array_operation(self, other) - - __rmul__ = __mul__ - - @overload - def __lt__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> BooleanVar: ... - - @overload - def __lt__(self, other: list | tuple) -> BooleanVar: ... - - def __lt__(self, other: Any): - """Check if the array is less than another array. - - Args: - other: The other array. - - Returns: - The array less than operation. - """ - if not isinstance(other, (ArrayVar, list, tuple)): - raise_unsupported_operand_types("<", (type(self), type(other))) - - return array_lt_operation(self, other) - - @overload - def __gt__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> BooleanVar: ... - - @overload - def __gt__(self, other: list | tuple) -> BooleanVar: ... - - def __gt__(self, other: Any): - """Check if the array is greater than another array. - - Args: - other: The other array. - - Returns: - The array greater than operation. - """ - if not isinstance(other, (ArrayVar, list, tuple)): - raise_unsupported_operand_types(">", (type(self), type(other))) - - return array_gt_operation(self, other) - - @overload - def __le__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> BooleanVar: ... - - @overload - def __le__(self, other: list | tuple) -> BooleanVar: ... - - def __le__(self, other: Any): - """Check if the array is less than or equal to another array. - - Args: - other: The other array. - - Returns: - The array less than or equal operation. - """ - if not isinstance(other, (ArrayVar, list, tuple)): - raise_unsupported_operand_types("<=", (type(self), type(other))) - - return array_le_operation(self, other) - - @overload - def __ge__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> BooleanVar: ... - - @overload - def __ge__(self, other: list | tuple) -> BooleanVar: ... - - def __ge__(self, other: Any): - """Check if the array is greater than or equal to another array. - - Args: - other: The other array. - - Returns: - The array greater than or equal operation. - """ - if not isinstance(other, (ArrayVar, list, tuple)): - raise_unsupported_operand_types(">=", (type(self), type(other))) - - return array_ge_operation(self, other) - - def foreach(self, fn: Any): - """Apply a function to each element of the array. - - Args: - fn: The function to apply. - - Returns: - The array after applying the function. - - Raises: - VarTypeError: If the function takes more than one argument. - """ - from .function import ArgsFunctionOperation - - if not callable(fn): - raise_unsupported_operand_types("foreach", (type(self), type(fn))) - # get the number of arguments of the function - num_args = len(inspect.signature(fn).parameters) - if num_args > 1: - raise VarTypeError( - "The function passed to foreach should take at most one argument." - ) - - if num_args == 0: - return_value = fn() - function_var = ArgsFunctionOperation.create((), return_value) - else: - # generic number var - number_var = Var("").to(NumberVar, int) - - first_arg_type = self[number_var]._var_type - - arg_name = get_unique_variable_name() - - # get first argument type - first_arg = Var( - _js_expr=arg_name, - _var_type=first_arg_type, - ).guess_type() - - function_var = ArgsFunctionOperation.create( - (arg_name,), - Var.create(fn(first_arg)), - ) - - return map_array_operation(self, function_var) - - -LIST_ELEMENT = TypeVar("LIST_ELEMENT") - -ARRAY_VAR_OF_LIST_ELEMENT = ArrayVar[Sequence[LIST_ELEMENT]] - - -@dataclasses.dataclass( - eq=False, - frozen=True, - slots=True, -) -class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): - """Base class for immutable literal array vars.""" - - _var_value: Sequence[Union[Var, Any]] = dataclasses.field(default=()) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return ( - "[" - + ", ".join( - [str(LiteralVar.create(element)) for element in self._var_value] - ) - + "]" - ) - - @cached_property_no_lock - def _cached_get_all_var_data(self) -> VarData | None: - """Get all the VarData associated with the Var. - - Returns: - The VarData associated with the Var. - """ - return VarData.merge( - *[ - LiteralVar.create(element)._get_all_var_data() - for element in self._var_value - ], - self._var_data, - ) - - def __hash__(self) -> int: - """Get the hash of the var. - - Returns: - The hash of the var. - """ - return hash((self.__class__.__name__, self._js_expr)) - - def json(self) -> str: - """Get the JSON representation of the var. - - Returns: - The JSON representation of the var. - - Raises: - TypeError: If the array elements are not of type LiteralVar. - """ - elements = [] - for element in self._var_value: - element_var = LiteralVar.create(element) - if not isinstance(element_var, LiteralVar): - raise TypeError( - f"Array elements must be of type LiteralVar, not {type(element_var)}" - ) - elements.append(element_var.json()) - - return "[" + ", ".join(elements) + "]" - - @classmethod - def create( - cls, - value: OTHER_ARRAY_VAR_TYPE, - _var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None, - _var_data: VarData | None = None, - ) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]: - """Create a var from a string value. - - Args: - value: The value to create the var from. - _var_type: The type of the var. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return LiteralArrayVar( - _js_expr="", - _var_type=figure_out_type(value) if _var_type is None else _var_type, - _var_data=_var_data, - _var_value=value, - ) - - -@var_operation -def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""): - """Split a string. - - Args: - string: The string to split. - sep: The separator. - - Returns: - The split string. - """ - return var_operation_return( - js_expression=f"{string}.split({sep})", var_type=List[str] - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - slots=True, -) -class ArraySliceOperation(CachedVarOperation, ArrayVar): - """Base class for immutable string vars that are the result of a string slice operation.""" - - _array: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - _start: NumberVar | int = dataclasses.field(default_factory=lambda: 0) - _stop: NumberVar | int = dataclasses.field(default_factory=lambda: 0) - _step: NumberVar | int = dataclasses.field(default_factory=lambda: 1) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - - Raises: - ValueError: If the slice step is zero. - """ - start, end, step = self._start, self._stop, self._step - - normalized_start = ( - LiteralVar.create(start) if start is not None else Var(_js_expr="undefined") - ) - normalized_end = ( - LiteralVar.create(end) if end is not None else Var(_js_expr="undefined") - ) - if step is None: - return f"{self._array!s}.slice({normalized_start!s}, {normalized_end!s})" - if not isinstance(step, Var): - if step < 0: - actual_start = end + 1 if end is not None else 0 - actual_end = start + 1 if start is not None else self._array.length() - return str(self._array[actual_start:actual_end].reverse()[::-step]) - if step == 0: - raise ValueError("slice step cannot be zero") - return f"{self._array!s}.slice({normalized_start!s}, {normalized_end!s}).filter((_, i) => i % {step!s} === 0)" - - actual_start_reverse = end + 1 if end is not None else 0 - actual_end_reverse = start + 1 if start is not None else self._array.length() - - return f"{self.step!s} > 0 ? {self._array!s}.slice({normalized_start!s}, {normalized_end!s}).filter((_, i) => i % {step!s} === 0) : {self._array!s}.slice({actual_start_reverse!s}, {actual_end_reverse!s}).reverse().filter((_, i) => i % {-step!s} === 0)" - - @classmethod - def create( - cls, - array: ArrayVar, - slice: slice, - _var_data: VarData | None = None, - ) -> ArraySliceOperation: - """Create a var from a string value. - - Args: - array: The array. - slice: The slice. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _js_expr="", - _var_type=array._var_type, - _var_data=_var_data, - _array=array, - _start=slice.start, - _stop=slice.stop, - _step=slice.step, - ) - - -@var_operation -def array_pluck_operation( - array: ArrayVar[ARRAY_VAR_TYPE], - field: StringVar | str, -) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: - """Pluck a field from an array of objects. - - Args: - array: The array to pluck from. - field: The field to pluck from the objects in the array. - - Returns: - The reversed array. - """ - return var_operation_return( - js_expression=f"{array}.map(e=>e?.[{field}])", - var_type=array._var_type, - ) - - -@var_operation -def array_reverse_operation( - array: ArrayVar[ARRAY_VAR_TYPE], -) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: - """Reverse an array. - - Args: - array: The array to reverse. - - Returns: - The reversed array. - """ - return var_operation_return( - js_expression=f"{array}.slice().reverse()", - var_type=array._var_type, - ) - - -@var_operation -def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): - """Check if an array is less than another array. - - Args: - lhs: The left-hand side array. - rhs: The right-hand side array. - - Returns: - The array less than operation. - """ - return var_operation_return(js_expression=f"{lhs} < {rhs}", var_type=bool) - - -@var_operation -def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): - """Check if an array is greater than another array. - - Args: - lhs: The left-hand side array. - rhs: The right-hand side array. - - Returns: - The array greater than operation. - """ - return var_operation_return(js_expression=f"{lhs} > {rhs}", var_type=bool) - - -@var_operation -def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): - """Check if an array is less than or equal to another array. - - Args: - lhs: The left-hand side array. - rhs: The right-hand side array. - - Returns: - The array less than or equal operation. - """ - return var_operation_return(js_expression=f"{lhs} <= {rhs}", var_type=bool) - - -@var_operation -def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): - """Check if an array is greater than or equal to another array. - - Args: - lhs: The left-hand side array. - rhs: The right-hand side array. - - Returns: - The array greater than or equal operation. - """ - return var_operation_return(js_expression=f"{lhs} >= {rhs}", var_type=bool) - - -@var_operation -def array_length_operation(array: ArrayVar): - """Get the length of an array. - - Args: - array: The array. - - Returns: - The length of the array. - """ - return var_operation_return( - js_expression=f"{array}.length", - var_type=int, - ) - - def is_tuple_type(t: GenericType) -> bool: """Check if a type is a tuple type. @@ -1594,146 +1467,6 @@ def is_tuple_type(t: GenericType) -> bool: return get_origin(t) is tuple -@var_operation -def array_item_operation(array: ArrayVar, index: NumberVar | int): - """Get an item from an array. - - Args: - array: The array. - index: The index of the item. - - Returns: - The item from the array. - """ - args = typing.get_args(array._var_type) - if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type): - index_value = int(index._var_value) - element_type = args[index_value % len(args)] - else: - element_type = unionize(*args) - - return var_operation_return( - js_expression=f"{array!s}.at({index!s})", - var_type=element_type, - ) - - -@var_operation -def array_range_operation( - start: NumberVar | int, stop: NumberVar | int, step: NumberVar | int -): - """Create a range of numbers. - - Args: - start: The start of the range. - stop: The end of the range. - step: The step of the range. - - Returns: - The range of numbers. - """ - return var_operation_return( - js_expression=f"Array.from({{ length: Math.ceil(({stop!s} - {start!s}) / {step!s}) }}, (_, i) => {start!s} + i * {step!s})", - var_type=List[int], - ) - - -@var_operation -def array_contains_field_operation( - haystack: ArrayVar, needle: Any | Var, field: StringVar | str -): - """Check if an array contains an element. - - Args: - haystack: The array to check. - needle: The element to check for. - field: The field to check. - - Returns: - The array contains operation. - """ - return var_operation_return( - js_expression=f"{haystack}.some(obj => obj[{field}] === {needle})", - var_type=bool, - ) - - -@var_operation -def array_contains_operation( - haystack: ArrayVar, needle: Any | Var -) -> CustomVarOperationReturn[bool]: - """Check if an array contains an element. - - Args: - haystack: The array to check. - needle: The element to check for. - - Returns: - The array contains operation. - """ - return var_operation_return( - js_expression=f"{haystack}.includes({needle})", - var_type=bool, - ) - - -@var_operation -def repeat_array_operation( - array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int -) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: - """Repeat an array a number of times. - - Args: - array: The array to repeat. - count: The number of times to repeat the array. - - Returns: - The repeated array. - """ - return var_operation_return( - js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})", - var_type=array._var_type, - ) - - -@var_operation -def map_array_operation( - array: ArrayVar[ARRAY_VAR_TYPE], - function: FunctionVar, -) -> CustomVarOperationReturn[List[Any]]: - """Map a function over an array. - - Args: - array: The array. - function: The function to map. - - Returns: - The mapped array. - """ - return var_operation_return( - js_expression=f"{array}.map({function})", var_type=List[Any] - ) - - -@var_operation -def array_concat_operation( - lhs: ArrayVar[ARRAY_VAR_TYPE], rhs: ArrayVar[ARRAY_VAR_TYPE] -) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: - """Concatenate two arrays. - - Args: - lhs: The left-hand side array. - rhs: The right-hand side array. - - Returns: - The concatenated array. - """ - return var_operation_return( - js_expression=f"[...{lhs}, ...{rhs}]", - var_type=Union[lhs._var_type, rhs._var_type], # pyright: ignore [reportArgumentType] - ) - - class ColorVar(StringVar[Color], python_types=Color): """Base class for immutable color vars.""" @@ -1794,7 +1527,7 @@ class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar): Returns: The name of the var. """ - alpha = self._var_value.alpha + alpha = cast(Union[Var[bool], bool], self._var_value.alpha) alpha = ( ternary_operation( alpha, diff --git a/tests/integration/test_event_actions.py b/tests/integration/test_event_actions.py index 707410075..62600bf7a 100644 --- a/tests/integration/test_event_actions.py +++ b/tests/integration/test_event_actions.py @@ -28,9 +28,11 @@ def TestEventAction(): def on_click2(self): self.order.append("on_click2") + @rx.event def on_click_throttle(self): self.order.append("on_click_throttle") + @rx.event def on_click_debounce(self): self.order.append("on_click_debounce") diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index 24dd7df6a..fd0cad951 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -17,45 +17,68 @@ def LifespanApp(): import reflex as rx - lifespan_task_global = 0 - lifespan_context_global = 0 + def create_tasks(): + lifespan_task_global = 0 + lifespan_context_global = 0 - @asynccontextmanager - async def lifespan_context(app, inc: int = 1): - global lifespan_context_global - print(f"Lifespan context entered: {app}.") - lifespan_context_global += inc # pyright: ignore[reportUnboundVariable] - try: - yield - finally: - print("Lifespan context exited.") - lifespan_context_global += inc - - async def lifespan_task(inc: int = 1): - global lifespan_task_global - print("Lifespan global started.") - try: - while True: - lifespan_task_global += inc # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable] - await asyncio.sleep(0.1) - except asyncio.CancelledError as ce: - print(f"Lifespan global cancelled: {ce}.") - lifespan_task_global = 0 - - class LifespanState(rx.State): - interval: int = 100 - - @rx.var(cache=False) - def task_global(self) -> int: - return lifespan_task_global - - @rx.var(cache=False) - def context_global(self) -> int: + def lifespan_context_global_getter(): return lifespan_context_global - @rx.event - def tick(self, date): - pass + def lifespan_task_global_getter(): + return lifespan_task_global + + @asynccontextmanager + async def lifespan_context(app, inc: int = 1): + nonlocal lifespan_context_global + print(f"Lifespan context entered: {app}.") + lifespan_context_global += inc + try: + yield + finally: + print("Lifespan context exited.") + lifespan_context_global += inc + + async def lifespan_task(inc: int = 1): + nonlocal lifespan_task_global + print("Lifespan global started.") + try: + while True: + lifespan_task_global += inc + await asyncio.sleep(0.1) + except asyncio.CancelledError as ce: + print(f"Lifespan global cancelled: {ce}.") + lifespan_task_global = 0 + + class LifespanState(rx.State): + interval: int = 100 + + @rx.var(cache=False) + def task_global(self) -> int: + return lifespan_task_global + + @rx.var(cache=False) + def context_global(self) -> int: + return lifespan_context_global + + @rx.event + def tick(self, date): + pass + + return ( + lifespan_task, + lifespan_context, + LifespanState, + lifespan_task_global_getter, + lifespan_context_global_getter, + ) + + ( + lifespan_task, + lifespan_context, + LifespanState, + lifespan_task_global_getter, + lifespan_context_global_getter, + ) = create_tasks() def index(): return rx.vstack( @@ -113,13 +136,16 @@ async def test_lifespan(lifespan_app: AppHarness): task_global = driver.find_element(By.ID, "task_global") assert context_global.text == "2" - assert lifespan_app.app_module.lifespan_context_global == 2 + assert lifespan_app.app_module.lifespan_context_global_getter() == 2 original_task_global_text = task_global.text original_task_global_value = int(original_task_global_text) lifespan_app.poll_for_content(task_global, exp_not_equal=original_task_global_text) driver.find_element(By.ID, "toggle-tick").click() # avoid teardown errors - assert lifespan_app.app_module.lifespan_task_global > original_task_global_value + assert ( + lifespan_app.app_module.lifespan_task_global_getter() + > original_task_global_value + ) assert int(task_global.text) > original_task_global_value # Kill the backend @@ -129,5 +155,5 @@ async def test_lifespan(lifespan_app: AppHarness): lifespan_app.backend_thread.join() # Check that the lifespan tasks have been cancelled - assert lifespan_app.app_module.lifespan_task_global == 0 - assert lifespan_app.app_module.lifespan_context_global == 4 + assert lifespan_app.app_module.lifespan_task_global_getter() == 0 + assert lifespan_app.app_module.lifespan_context_global_getter() == 4 diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index e20b1cd6d..471382570 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -87,7 +87,7 @@ def UploadFile(): ), rx.box( rx.foreach( - rx.selected_files, + rx.selected_files(), lambda f: rx.text(f, as_="p"), ), id="selected_files", diff --git a/tests/integration/tests_playwright/test_appearance.py b/tests/integration/tests_playwright/test_appearance.py index d325b183f..0b1440ed1 100644 --- a/tests/integration/tests_playwright/test_appearance.py +++ b/tests/integration/tests_playwright/test_appearance.py @@ -61,7 +61,7 @@ def ColorToggleApp(): rx.icon(tag="moon", size=20), value="dark", ), - on_change=set_color_mode, + on_change=set_color_mode(), variant="classic", radius="large", value=color_mode, diff --git a/tests/units/components/core/test_banner.py b/tests/units/components/core/test_banner.py index e1498d12c..fc572e9a6 100644 --- a/tests/units/components/core/test_banner.py +++ b/tests/units/components/core/test_banner.py @@ -25,6 +25,7 @@ def test_connection_banner(): "react", "$/utils/context", "$/utils/state", + "@emotion/react", RadixThemesComponent().library or "", "$/env.json", ) @@ -43,6 +44,7 @@ def test_connection_modal(): "react", "$/utils/context", "$/utils/state", + "@emotion/react", RadixThemesComponent().library or "", "$/env.json", ) diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index ac073ed29..1fe1061f5 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -3,8 +3,7 @@ from typing import Any, Union import pytest -from reflex.components.base.fragment import Fragment -from reflex.components.core.cond import Cond, cond +from reflex.components.core.cond import cond from reflex.components.radix.themes.typography.text import Text from reflex.state import BaseState from reflex.utils.format import format_state_name @@ -40,32 +39,23 @@ def test_validate_cond(cond_state: BaseState): Args: cond_state: A fixture. """ - cond_component = cond( + first_component = Text.create("cond is True") + second_component = Text.create("cond is False") + cond_var = cond( cond_state.value, - Text.create("cond is True"), - Text.create("cond is False"), + first_component, + second_component, ) - cond_dict = cond_component.render() if type(cond_component) is Fragment else {} - assert cond_dict["name"] == "Fragment" - [condition] = cond_dict["children"] - assert condition["cond_state"] == f"isTrue({cond_state.get_full_name()}.value)" + assert isinstance(cond_var, Var) + assert ( + str(cond_var) + == f'({cond_state.value.bool()!s} ? (jsx(RadixThemesText, ({{ ["as"] : "p" }}), (jsx(Fragment, ({{ }}), "cond is True")))) : (jsx(RadixThemesText, ({{ ["as"] : "p" }}), (jsx(Fragment, ({{ }}), "cond is False")))))' + ) - # true value - true_value = condition["true_value"] - assert true_value["name"] == "Fragment" - - [true_value_text] = true_value["children"] - assert true_value_text["name"] == "RadixThemesText" - assert true_value_text["children"][0]["contents"] == '{"cond is True"}' - - # false value - false_value = condition["false_value"] - assert false_value["name"] == "Fragment" - - [false_value_text] = false_value["children"] - assert false_value_text["name"] == "RadixThemesText" - assert false_value_text["children"][0]["contents"] == '{"cond is False"}' + var_data = cond_var._get_all_var_data() + assert var_data is not None + assert var_data.components == (first_component, second_component) @pytest.mark.parametrize( @@ -99,22 +89,25 @@ def test_prop_cond(c1: Any, c2: Any): assert str(prop_cond) == f"(true ? {c1!s} : {c2!s})" -def test_cond_no_mix(): - """Test if cond can't mix components and props.""" - with pytest.raises(ValueError): - cond(True, LiteralVar.create("hello"), Text.create("world")) +def test_cond_mix(): + """Test if cond can mix components and props.""" + x = cond(True, LiteralVar.create("hello"), Text.create("world")) + assert isinstance(x, Var) + assert ( + str(x) + == '(true ? "hello" : (jsx(RadixThemesText, ({ ["as"] : "p" }), (jsx(Fragment, ({ }), "world")))))' + ) def test_cond_no_else(): """Test if cond can be used without else.""" # Components should support the use of cond without else comp = cond(True, Text.create("hello")) - assert isinstance(comp, Fragment) - comp = comp.children[0] - assert isinstance(comp, Cond) - assert comp.cond._decode() is True - assert comp.comp1.render() == Fragment.create(Text.create("hello")).render() # pyright: ignore [reportOptionalMemberAccess] - assert comp.comp2 == Fragment.create() + assert isinstance(comp, Var) + assert ( + str(comp) + == '(true ? (jsx(RadixThemesText, ({ ["as"] : "p" }), (jsx(Fragment, ({ }), "hello")))) : (jsx(Fragment, ({ }))))' + ) # Props do not support the use of cond without else with pytest.raises(ValueError): diff --git a/tests/units/components/core/test_foreach.py b/tests/units/components/core/test_foreach.py index 48fae85e8..c1659540b 100644 --- a/tests/units/components/core/test_foreach.py +++ b/tests/units/components/core/test_foreach.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set, Tuple, Union +from typing import Dict, List, Sequence, Set, Tuple, Union import pydantic.v1 import pytest @@ -6,16 +6,11 @@ import pytest from reflex import el from reflex.base import Base from reflex.components.component import Component -from reflex.components.core.foreach import ( - Foreach, - ForeachRenderError, - ForeachVarError, - foreach, -) +from reflex.components.core.foreach import ForeachVarError, foreach from reflex.components.radix.themes.layout.box import box from reflex.components.radix.themes.typography.text import text from reflex.state import BaseState, ComponentState -from reflex.vars.base import Var +from reflex.utils.exceptions import VarTypeError from reflex.vars.number import NumberVar from reflex.vars.sequence import ArrayVar @@ -125,7 +120,9 @@ def display_colors_set(color): return box(text(color)) -def display_nested_list_element(element: ArrayVar[List[str]], index: NumberVar[int]): +def display_nested_list_element( + element: ArrayVar[Sequence[str]], index: NumberVar[int] +): assert element._var_type == List[str] assert index._var_type is int return box(text(element[index])) @@ -139,143 +136,35 @@ def display_color_index_tuple(color): seen_index_vars = set() -@pytest.mark.parametrize( - "state_var, render_fn, render_dict", - [ - ( - ForEachState.colors_list, - display_color, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.colors_dict_list, - display_color_name, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_dict_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.colors_nested_dict_list, - display_shade, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_nested_dict_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.primary_color, - display_primary_colors, - { - "iterable_state": f"Object.entries({ForEachState.get_full_name()}.primary_color)", - "iterable_type": "list", - }, - ), - ( - ForEachState.color_with_shades, - display_color_with_shades, - { - "iterable_state": f"Object.entries({ForEachState.get_full_name()}.color_with_shades)", - "iterable_type": "list", - }, - ), - ( - ForEachState.nested_colors_with_shades, - display_nested_color_with_shades, - { - "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)", - "iterable_type": "list", - }, - ), - ( - ForEachState.nested_colors_with_shades, - display_nested_color_with_shades_v2, - { - "iterable_state": f"Object.entries({ForEachState.get_full_name()}.nested_colors_with_shades)", - "iterable_type": "list", - }, - ), - ( - ForEachState.color_tuple, - display_color_tuple, - { - "iterable_state": f"{ForEachState.get_full_name()}.color_tuple", - "iterable_type": "tuple", - }, - ), - ( - ForEachState.colors_set, - display_colors_set, - { - "iterable_state": f"{ForEachState.get_full_name()}.colors_set", - "iterable_type": "set", - }, - ), - ( - ForEachState.nested_colors_list, - lambda el, i: display_nested_list_element(el, i), - { - "iterable_state": f"{ForEachState.get_full_name()}.nested_colors_list", - "iterable_type": "list", - }, - ), - ( - ForEachState.color_index_tuple, - display_color_index_tuple, - { - "iterable_state": f"{ForEachState.get_full_name()}.color_index_tuple", - "iterable_type": "tuple", - }, - ), - ], -) -def test_foreach_render(state_var, render_fn, render_dict): - """Test that the foreach component renders without error. - - Args: - state_var: the state var. - render_fn: The render callable - render_dict: return dict on calling `component.render` - """ - component = Foreach.create(state_var, render_fn) - - rend = component.render() - assert rend["iterable_state"] == render_dict["iterable_state"] - assert rend["iterable_type"] == render_dict["iterable_type"] - - # Make sure the index vars are unique. - arg_index = rend["arg_index"] - assert isinstance(arg_index, Var) - assert arg_index._js_expr not in seen_index_vars - assert arg_index._var_type is int - seen_index_vars.add(arg_index._js_expr) - - def test_foreach_bad_annotations(): """Test that the foreach component raises a ForeachVarError if the iterable is of type Any.""" with pytest.raises(ForeachVarError): - Foreach.create( + foreach( ForEachState.bad_annotation_list, - lambda sublist: Foreach.create(sublist, lambda color: text(color)), + lambda sublist: foreach(sublist, lambda color: text(color)), ) def test_foreach_no_param_in_signature(): - """Test that the foreach component raises a ForeachRenderError if no parameters are passed.""" - with pytest.raises(ForeachRenderError): - Foreach.create( - ForEachState.colors_list, - lambda: text("color"), - ) + """Test that the foreach component DOES NOT raise an error if no parameters are passed.""" + foreach( + ForEachState.colors_list, + lambda: text("color"), + ) + + +def test_foreach_with_index(): + """Test that the foreach component works with an index.""" + foreach( + ForEachState.colors_list, + lambda color, index: text(color, index), + ) def test_foreach_too_many_params_in_signature(): """Test that the foreach component raises a ForeachRenderError if too many parameters are passed.""" - with pytest.raises(ForeachRenderError): - Foreach.create( + with pytest.raises(VarTypeError): + foreach( ForEachState.colors_list, lambda color, index, extra: text(color), ) @@ -290,13 +179,13 @@ def test_foreach_component_styles(): ) ) component._add_style_recursive({box: {"color": "red"}}) - assert 'css={({ ["color"] : "red" })}' in str(component) + assert '{ ["css"] : ({ ["color"] : "red" }) }' in str(component) def test_foreach_component_state(): """Test that using a component state to render in the foreach raises an error.""" with pytest.raises(TypeError): - Foreach.create( + foreach( ForEachState.colors_list, ComponentStateTest.create, ) @@ -304,7 +193,7 @@ def test_foreach_component_state(): def test_foreach_default_factory(): """Test that the default factory is called.""" - _ = Foreach.create( + _ = foreach( ForEachState.default_factory_list, lambda tag: text(tag.name), ) diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index 11602b77a..b374b0f41 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -1,10 +1,10 @@ -from typing import List, Mapping, Tuple +import re +from typing import Tuple import pytest import reflex as rx -from reflex.components.component import Component -from reflex.components.core.match import Match +from reflex.components.core.match import match from reflex.state import BaseState from reflex.utils.exceptions import MatchTypeError from reflex.vars.base import Var @@ -18,75 +18,6 @@ class MatchState(BaseState): string: str = "random string" -def test_match_components(): - """Test matching cases with return values as components.""" - match_case_tuples = ( - (1, rx.text("first value")), - (2, 3, rx.text("second value")), - ([1, 2], rx.text("third value")), - ("random", rx.text("fourth value")), - ({"foo": "bar"}, rx.text("fifth value")), - (MatchState.num + 1, rx.text("sixth value")), - rx.text("default value"), - ) - match_comp = Match.create(MatchState.value, *match_case_tuples) - - assert isinstance(match_comp, Component) - match_dict = match_comp.render() - assert match_dict["name"] == "Fragment" - - [match_child] = match_dict["children"] - - assert match_child["name"] == "match" - assert str(match_child["cond"]) == f"{MatchState.get_name()}.value" - - match_cases = match_child["match_cases"] - assert len(match_cases) == 6 - - assert match_cases[0][0]._js_expr == "1" - assert match_cases[0][0]._var_type is int - first_return_value_render = match_cases[0][1] - assert first_return_value_render["name"] == "RadixThemesText" - assert first_return_value_render["children"][0]["contents"] == '{"first value"}' - - assert match_cases[1][0]._js_expr == "2" - assert match_cases[1][0]._var_type is int - assert match_cases[1][1]._js_expr == "3" - assert match_cases[1][1]._var_type is int - second_return_value_render = match_cases[1][2] - assert second_return_value_render["name"] == "RadixThemesText" - assert second_return_value_render["children"][0]["contents"] == '{"second value"}' - - assert match_cases[2][0]._js_expr == "[1, 2]" - assert match_cases[2][0]._var_type == List[int] - third_return_value_render = match_cases[2][1] - assert third_return_value_render["name"] == "RadixThemesText" - assert third_return_value_render["children"][0]["contents"] == '{"third value"}' - - assert match_cases[3][0]._js_expr == '"random"' - assert match_cases[3][0]._var_type is str - fourth_return_value_render = match_cases[3][1] - assert fourth_return_value_render["name"] == "RadixThemesText" - assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}' - - assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })' - assert match_cases[4][0]._var_type == Mapping[str, str] - fifth_return_value_render = match_cases[4][1] - assert fifth_return_value_render["name"] == "RadixThemesText" - assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}' - - assert match_cases[5][0]._js_expr == f"({MatchState.get_name()}.num + 1)" - assert match_cases[5][0]._var_type is int - fifth_return_value_render = match_cases[5][1] - assert fifth_return_value_render["name"] == "RadixThemesText" - assert fifth_return_value_render["children"][0]["contents"] == '{"sixth value"}' - - default = match_child["default"] - - assert default["name"] == "RadixThemesText" - assert default["children"][0]["contents"] == '{"default value"}' - - @pytest.mark.parametrize( "cases, expected", [ @@ -137,7 +68,7 @@ def test_match_vars(cases, expected): cases: The match cases. expected: The expected var full name. """ - match_comp = Match.create(MatchState.value, *cases) + match_comp = match(MatchState.value, *cases) # pyright: ignore[reportCallIssue] assert isinstance(match_comp, Var) assert str(match_comp) == expected @@ -146,18 +77,14 @@ def test_match_on_component_without_default(): """Test that matching cases with return values as components returns a Fragment as the default case if not provided. """ - from reflex.components.base.fragment import Fragment - match_case_tuples = ( (1, rx.text("first value")), (2, 3, rx.text("second value")), ) - match_comp = Match.create(MatchState.value, *match_case_tuples) - assert isinstance(match_comp, Component) - default = match_comp.render()["children"][0]["default"] + match_comp = match(MatchState.value, *match_case_tuples) - assert isinstance(default, dict) and default["name"] == Fragment.__name__ + assert isinstance(match_comp, Var) def test_match_on_var_no_default(): @@ -172,7 +99,7 @@ def test_match_on_var_no_default(): ValueError, match="For cases with return types as Vars, a default case must be provided", ): - Match.create(MatchState.value, *match_case_tuples) + match(MatchState.value, *match_case_tuples) @pytest.mark.parametrize( @@ -205,7 +132,7 @@ def test_match_default_not_last_arg(match_case): ValueError, match="rx.match should have tuples of cases and a default case as the last argument.", ): - Match.create(MatchState.value, *match_case) + match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -235,7 +162,7 @@ def test_match_case_tuple_elements(match_case): ValueError, match="A case tuple should have at least a match case element and a return value.", ): - Match.create(MatchState.value, *match_case) + match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -251,8 +178,7 @@ def test_match_case_tuple_elements(match_case): (MatchState.num + 1, "black"), rx.text("default value"), ), - 'Match cases should have the same return types. Case 3 with return value `"red"` of type ' - " is not ", + "Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 3 is . Return type of case 4 is . Return type of case 5 is ", ), ( ( @@ -264,8 +190,7 @@ def test_match_case_tuple_elements(match_case): ([1, 2], rx.text("third value")), rx.text("default value"), ), - 'Match cases should have the same return types. Case 3 with return value ` {"first value"} ` ' - "of type is not ", + "Match cases should have the same return types. Expected return types to be of type Component or Var[Component]. Return type of case 0 is . Return type of case 1 is . Return type of case 2 is ", ), ], ) @@ -276,8 +201,8 @@ def test_match_different_return_types(cases: Tuple, error_msg: str): cases: The match cases. error_msg: Expected error message. """ - with pytest.raises(MatchTypeError, match=error_msg): - Match.create(MatchState.value, *cases) + with pytest.raises(MatchTypeError, match=re.escape(error_msg)): + match(MatchState.value, *cases) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -309,9 +234,9 @@ def test_match_multiple_default_cases(match_case): match_case: the cases to match. """ with pytest.raises(ValueError, match="rx.match can only have one default case."): - Match.create(MatchState.value, *match_case) + match(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] def test_match_no_cond(): with pytest.raises(ValueError): - _ = Match.create(None) + _ = match(None) # pyright: ignore[reportCallIssue] diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 8cffa6e0e..ca5a904ba 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -1457,13 +1457,10 @@ def test_instantiate_all_components(): # These components all have required arguments and cannot be trivially instantiated. untested_components = { "Card", - "Cond", "DebounceInput", - "Foreach", "FormControl", "Html", "Icon", - "Match", "Markdown", "MultiSelect", "Option", @@ -2156,14 +2153,11 @@ def test_add_style_foreach(): page = rx.vstack(rx.foreach(Var.range(3), lambda i: StyledComponent.create(i))) page._add_style_recursive(Style()) - # Expect only a single child of the foreach on the python side - assert len(page.children[0].children) == 1 - # Expect the style to be added to the child of the foreach - assert 'css={({ ["color"] : "red" })}' in str(page.children[0].children[0]) + assert '({ ["css"] : ({ ["color"] : "red" }) }),' in str(page.children[0]) # Expect only one instance of this CSS dict in the rendered page - assert str(page).count('css={({ ["color"] : "red" })}') == 1 + assert str(page).count('({ ["css"] : ({ ["color"] : "red" }) }),') == 1 class TriggerState(rx.State): diff --git a/tests/units/components/test_tag.py b/tests/units/components/test_tag.py index a69e40b8b..a83ebe41a 100644 --- a/tests/units/components/test_tag.py +++ b/tests/units/components/test_tag.py @@ -2,7 +2,7 @@ from typing import Dict, List import pytest -from reflex.components.tags import CondTag, Tag, tagless +from reflex.components.tags import Tag, tagless from reflex.vars.base import LiteralVar, Var @@ -105,29 +105,6 @@ def test_format_tag(tag: Tag, expected: Dict): assert prop_value.equals(LiteralVar.create(expected["props"][prop])) -def test_format_cond_tag(): - """Test that the cond tag dict is correct.""" - tag = CondTag( - true_value=dict(Tag(name="h1", contents="True content")), - false_value=dict(Tag(name="h2", contents="False content")), - cond=Var(_js_expr="logged_in", _var_type=bool), - ) - tag_dict = dict(tag) - cond, true_value, false_value = ( - tag_dict["cond"], - tag_dict["true_value"], - tag_dict["false_value"], - ) - assert cond._js_expr == "logged_in" - assert cond._var_type is bool - - assert true_value["name"] == "h1" - assert true_value["contents"] == "True content" - - assert false_value["name"] == "h2" - assert false_value["contents"] == "False content" - - def test_tagless_string_representation(): """Test that the string representation of a tagless is correct.""" tag = tagless.Tagless(contents="Hello world") diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 88cb36509..7c3a56d09 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -9,7 +9,7 @@ import unittest.mock import uuid from contextlib import nullcontext as does_not_raise from pathlib import Path -from typing import Generator, List, Tuple, Type +from typing import Generator, List, Tuple, Type, cast from unittest.mock import AsyncMock import pytest @@ -31,9 +31,8 @@ from reflex.app import ( ) from reflex.components import Component from reflex.components.base.fragment import Fragment -from reflex.components.core.cond import Cond from reflex.components.radix.themes.typography.text import Text -from reflex.event import Event +from reflex.event import Event, EventHandler from reflex.middleware import HydrateMiddleware from reflex.model import Model from reflex.state import ( @@ -916,7 +915,7 @@ class DynamicState(BaseState): """ return self.dynamic - on_load_internal = OnLoadInternalState.on_load_internal.fn # pyright: ignore [reportFunctionMemberAccess] + on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn def test_dynamic_arg_shadow( @@ -1189,7 +1188,7 @@ async def test_process_events(mocker, token: str): pass assert (await app.state_manager.get_state(event.substate_token)).value == 5 - assert app._postprocess.call_count == 6 # pyright: ignore [reportFunctionMemberAccess] + assert getattr(app._postprocess, "call_count", None) == 6 if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @@ -1227,10 +1226,6 @@ def test_overlay_component( assert app.overlay_component is not None generated_component = app._generate_component(app.overlay_component) assert isinstance(generated_component, OverlayFragment) - assert isinstance( - generated_component.children[0], - Cond, # ConnectionModal is a Cond under the hood - ) else: assert app.overlay_component is not None assert isinstance( @@ -1246,8 +1241,8 @@ def test_overlay_component( if exp_page_child is not None: assert len(page.children) == 3 - children_types = (type(child) for child in page.children) - assert exp_page_child in children_types # pyright: ignore [reportOperatorIssue] + children_types = [type(child) for child in page.children] + assert exp_page_child in children_types else: assert len(page.children) == 2 diff --git a/tests/units/test_event.py b/tests/units/test_event.py index df4f282cf..6be23e907 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -6,6 +6,7 @@ import reflex as rx from reflex.constants.compiler import Hooks, Imports from reflex.event import ( Event, + EventActionsMixin, EventChain, EventHandler, EventSpec, @@ -410,6 +411,7 @@ def test_event_actions(): def test_event_actions_on_state(): class EventActionState(BaseState): + @rx.event def handler(self): pass @@ -417,7 +419,8 @@ def test_event_actions_on_state(): assert isinstance(handler, EventHandler) assert not handler.event_actions - sp_handler = EventActionState.handler.stop_propagation # pyright: ignore [reportFunctionMemberAccess] + sp_handler = EventActionState.handler.stop_propagation + assert isinstance(sp_handler, EventActionsMixin) assert sp_handler.event_actions == {"stopPropagation": True} # should NOT affect other references to the handler assert not handler.event_actions diff --git a/tests/units/test_health_endpoint.py b/tests/units/test_health_endpoint.py index 5b3aedc00..abfa6cc62 100644 --- a/tests/units/test_health_endpoint.py +++ b/tests/units/test_health_endpoint.py @@ -122,9 +122,12 @@ async def test_health( # Call the async health function response = await health() - print(json.loads(response.body)) # pyright: ignore [reportArgumentType] + body = response.body + assert isinstance(body, bytes) + + print(json.loads(body)) print(expected_status) # Verify the response content and status code assert response.status_code == expected_code - assert json.loads(response.body) == expected_status # pyright: ignore [reportArgumentType] + assert json.loads(body) == expected_status diff --git a/tests/units/test_state.py b/tests/units/test_state.py index e0390c5ac..f633d6a49 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -18,6 +18,7 @@ from typing import ( Dict, List, Optional, + Sequence, Set, Tuple, Union, @@ -121,8 +122,8 @@ class TestState(BaseState): num2: float = 3.14 key: str map_key: str = "a" - array: List[float] = [1, 2, 3.14] - mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]} + array: rx.Field[List[float]] = rx.field([1, 2, 3.14]) + mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]}) obj: Object = Object() complex: Dict[int, Object] = {1: Object(), 2: Object()} fig: Figure = Figure() @@ -432,13 +433,16 @@ def test_default_setters(test_state): def test_class_indexing_with_vars(): """Test that we can index into a state var with another var.""" - prop = TestState.array[TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType] - assert str(prop) == f"{TestState.get_name()}.array.at({TestState.get_name()}.num1)" + prop = TestState.array[TestState.num1] + assert ( + str(prop) + == f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({TestState.get_name()}.array, ...args)))({TestState.get_name()}.num1))" + ) prop = TestState.mapping["a"][TestState.num1] # pyright: ignore [reportCallIssue, reportArgumentType] assert ( str(prop) - == f'{TestState.get_name()}.mapping["a"].at({TestState.get_name()}.num1)' + == f'(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({TestState.get_name()}.mapping["a"], ...args)))({TestState.get_name()}.num1))' ) prop = TestState.mapping[TestState.map_key] @@ -1358,6 +1362,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): class HandlerState(BaseState): x: int = 42 + @rx.event def handler(self): self.x = self.x + 1 @@ -1368,11 +1373,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): counter += 1 return counter + assert isinstance(HandlerState.handler, EventHandler) if use_partial: - HandlerState.handler = functools.partial(HandlerState.handler.fn) # pyright: ignore [reportFunctionMemberAccess] + partial_guy = functools.partial(HandlerState.handler.fn) + HandlerState.handler = partial_guy # pyright: ignore[reportAttributeAccessIssue] assert isinstance(HandlerState.handler, functools.partial) - else: - assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() assert ( @@ -2029,8 +2034,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): # ensure state update was emitted assert mock_app.event_namespace is not None - mock_app.event_namespace.emit.assert_called_once() # pyright: ignore [reportFunctionMemberAccess] - mcall = mock_app.event_namespace.emit.mock_calls[0] # pyright: ignore [reportFunctionMemberAccess] + mock_app.event_namespace.emit.assert_called_once() # pyright: ignore[reportFunctionMemberAccess] + mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None) + assert mock_calls is not None + assert isinstance(mock_calls, Sequence) + mcall = mock_calls[0] assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[1] == StateUpdate( delta={ @@ -2231,7 +2239,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit - first_ws_message = emit_mock.mock_calls[0].args[1] # pyright: ignore [reportFunctionMemberAccess] + mock_calls = getattr(emit_mock, "mock_calls", None) + assert mock_calls is not None + assert isinstance(mock_calls, Sequence) + + first_ws_message = mock_calls[0].args[1] assert ( first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router") is not None @@ -2246,7 +2258,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): events=[], final=True, ) - for call in emit_mock.mock_calls[1:5]: # pyright: ignore [reportFunctionMemberAccess] + for call in mock_calls[1:5]: assert call.args[1] == StateUpdate( delta={ BackgroundTaskState.get_full_name(): { @@ -2256,7 +2268,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): events=[], final=True, ) - assert emit_mock.mock_calls[-2].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess] + assert mock_calls[-2].args[1] == StateUpdate( delta={ BackgroundTaskState.get_full_name(): { "order": exp_order, @@ -2267,7 +2279,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): events=[], final=True, ) - assert emit_mock.mock_calls[-1].args[1] == StateUpdate( # pyright: ignore [reportFunctionMemberAccess] + assert mock_calls[-1].args[1] == StateUpdate( delta={ BackgroundTaskState.get_full_name(): { "computed_order": exp_order, diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 8fcd288e6..4d6840963 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -328,7 +328,10 @@ def test_basic_operations(TestObj): assert str(LiteralNumberVar.create(1) ** 2) == "(1 ** 2)" assert str(LiteralNumberVar.create(1) & v(2)) == "(1 && 2)" assert str(LiteralNumberVar.create(1) | v(2)) == "(1 || 2)" - assert str(LiteralArrayVar.create([1, 2, 3])[0]) == "[1, 2, 3].at(0)" + assert ( + str(LiteralArrayVar.create([1, 2, 3])[0]) + == "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3], ...args)))(0))" + ) assert ( str(LiteralObjectVar.create({"a": 1, "b": 2})["a"]) == '({ ["a"] : 1, ["b"] : 2 })["a"]' @@ -350,27 +353,33 @@ def test_basic_operations(TestObj): str(Var(_js_expr="foo").to(ObjectVar, TestObj)._var_set_state("state").bar) == 'state.foo["bar"]' ) - assert str(abs(LiteralNumberVar.create(1))) == "Math.abs(1)" - assert str(LiteralArrayVar.create([1, 2, 3]).length()) == "[1, 2, 3].length" + assert str(abs(LiteralNumberVar.create(1))) == "(Math.abs(1))" + assert ( + str(LiteralArrayVar.create([1, 2, 3]).length()) + == "(((...args) => (((_array) => _array.length)([1, 2, 3], ...args)))())" + ) assert ( str(LiteralArrayVar.create([1, 2]) + LiteralArrayVar.create([3, 4])) - == "[...[1, 2], ...[3, 4]]" + == "(((...args) => (((_lhs, _rhs) => [..._lhs, ..._rhs])([1, 2], ...args)))([3, 4]))" ) # Tests for reverse operation assert ( str(LiteralArrayVar.create([1, 2, 3]).reverse()) - == "[1, 2, 3].slice().reverse()" + == "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3], ...args)))())" ) assert ( str(LiteralArrayVar.create(["1", "2", "3"]).reverse()) - == '["1", "2", "3"].slice().reverse()' + == '(((...args) => (((_array) => _array.slice().reverse())(["1", "2", "3"], ...args)))())' ) assert ( str(Var(_js_expr="foo")._var_set_state("state").to(list).reverse()) - == "state.foo.slice().reverse()" + == "(((...args) => (((_array) => _array.slice().reverse())(state.foo, ...args)))())" + ) + assert ( + str(Var(_js_expr="foo").to(list).reverse()) + == "(((...args) => (((_array) => _array.slice().reverse())(foo, ...args)))())" ) - assert str(Var(_js_expr="foo").to(list).reverse()) == "foo.slice().reverse()" assert str(Var(_js_expr="foo", _var_type=str).js_type()) == "(typeof(foo))" @@ -395,14 +404,32 @@ def test_basic_operations(TestObj): ], ) def test_list_tuple_contains(var, expected): - assert str(var.contains(1)) == f"{expected}.includes(1)" - assert str(var.contains("1")) == f'{expected}.includes("1")' - assert str(var.contains(v(1))) == f"{expected}.includes(1)" - assert str(var.contains(v("1"))) == f'{expected}.includes("1")' + assert ( + str(var.contains(1)) + == f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(1))' + ) + assert ( + str(var.contains("1")) + == f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))("1"))' + ) + assert ( + str(var.contains(v(1))) + == f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(1))' + ) + assert ( + str(var.contains(v("1"))) + == f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))("1"))' + ) other_state_var = Var(_js_expr="other", _var_type=str)._var_set_state("state") other_var = Var(_js_expr="other", _var_type=str) - assert str(var.contains(other_state_var)) == f"{expected}.includes(state.other)" - assert str(var.contains(other_var)) == f"{expected}.includes(other)" + assert ( + str(var.contains(other_state_var)) + == f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(state.other))' + ) + assert ( + str(var.contains(other_var)) + == f'(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))({expected!s}, ...args)))(other))' + ) class Foo(rx.Base): @@ -446,15 +473,23 @@ def test_var_types(var, var_type): ], ) def test_str_contains(var, expected): - assert str(var.contains("1")) == f'{expected}.includes("1")' - assert str(var.contains(v("1"))) == f'{expected}.includes("1")' + assert ( + str(var.contains("1")) + == f'(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))("1"))' + ) + assert ( + str(var.contains(v("1"))) + == f'(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))("1"))' + ) other_state_var = Var(_js_expr="other")._var_set_state("state").to(str) other_var = Var(_js_expr="other").to(str) - assert str(var.contains(other_state_var)) == f"{expected}.includes(state.other)" - assert str(var.contains(other_var)) == f"{expected}.includes(other)" assert ( - str(var.contains("1", "hello")) - == f'{expected}.some(obj => obj["hello"] === "1")' + str(var.contains(other_state_var)) + == f"(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))(state.other))" + ) + assert ( + str(var.contains(other_var)) + == f"(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))({expected!s}, ...args)))(other))" ) @@ -467,16 +502,17 @@ def test_str_contains(var, expected): ], ) def test_dict_contains(var, expected): - assert str(var.contains(1)) == f"{expected}.hasOwnProperty(1)" - assert str(var.contains("1")) == f'{expected}.hasOwnProperty("1")' - assert str(var.contains(v(1))) == f"{expected}.hasOwnProperty(1)" - assert str(var.contains(v("1"))) == f'{expected}.hasOwnProperty("1")' + assert str(var.contains(1)) == f"{expected!s}.hasOwnProperty(1)" + assert str(var.contains("1")) == f'{expected!s}.hasOwnProperty("1")' + assert str(var.contains(v(1))) == f"{expected!s}.hasOwnProperty(1)" + assert str(var.contains(v("1"))) == f'{expected!s}.hasOwnProperty("1")' other_state_var = Var(_js_expr="other")._var_set_state("state").to(str) other_var = Var(_js_expr="other").to(str) assert ( - str(var.contains(other_state_var)) == f"{expected}.hasOwnProperty(state.other)" + str(var.contains(other_state_var)) + == f"{expected!s}.hasOwnProperty(state.other)" ) - assert str(var.contains(other_var)) == f"{expected}.hasOwnProperty(other)" + assert str(var.contains(other_var)) == f"{expected!s}.hasOwnProperty(other)" @pytest.mark.parametrize( @@ -484,7 +520,6 @@ def test_dict_contains(var, expected): [ Var(_js_expr="list", _var_type=List[int]).guess_type(), Var(_js_expr="tuple", _var_type=Tuple[int, int]).guess_type(), - Var(_js_expr="str", _var_type=str).guess_type(), ], ) def test_var_indexing_lists(var): @@ -494,11 +529,20 @@ def test_var_indexing_lists(var): var : The str, list or tuple base var. """ # Test basic indexing. - assert str(var[0]) == f"{var._js_expr}.at(0)" - assert str(var[1]) == f"{var._js_expr}.at(1)" + assert ( + str(var[0]) + == f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(0))" + ) + assert ( + str(var[1]) + == f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(1))" + ) # Test negative indexing. - assert str(var[-1]) == f"{var._js_expr}.at(-1)" + assert ( + str(var[-1]) + == f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))(-1))" + ) @pytest.mark.parametrize( @@ -532,11 +576,20 @@ def test_var_indexing_str(): assert str_var[0]._var_type is str # Test basic indexing. - assert str(str_var[0]) == "str.at(0)" - assert str(str_var[1]) == "str.at(1)" + assert ( + str(str_var[0]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(0))" + ) + assert ( + str(str_var[1]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(1))" + ) # Test negative indexing. - assert str(str_var[-1]) == "str.at(-1)" + assert ( + str(str_var[-1]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))(-1))" + ) @pytest.mark.parametrize( @@ -651,9 +704,18 @@ def test_var_list_slicing(var): Args: var : The str, list or tuple base var. """ - assert str(var[:1]) == f"{var._js_expr}.slice(undefined, 1)" - assert str(var[1:]) == f"{var._js_expr}.slice(1, undefined)" - assert str(var[:]) == f"{var._js_expr}.slice(undefined, undefined)" + assert ( + str(var[:1]) + == f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([null, 1, null]))" + ) + assert ( + str(var[1:]) + == f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([1, null, null]))" + ) + assert ( + str(var[:]) + == f"(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))({var!s}, ...args)))([null, null, null]))" + ) def test_str_var_slicing(): @@ -665,16 +727,40 @@ def test_str_var_slicing(): assert str_var[:1]._var_type is str # Test basic slicing. - assert str(str_var[:1]) == 'str.split("").slice(undefined, 1).join("")' - assert str(str_var[1:]) == 'str.split("").slice(1, undefined).join("")' - assert str(str_var[:]) == 'str.split("").slice(undefined, undefined).join("")' - assert str(str_var[1:2]) == 'str.split("").slice(1, 2).join("")' + assert ( + str(str_var[:1]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, 1, null]))" + ) + assert ( + str(str_var[1:]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([1, null, null]))" + ) + assert ( + str(str_var[:]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, null, null]))" + ) + assert ( + str(str_var[1:2]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([1, 2, null]))" + ) # Test negative slicing. - assert str(str_var[:-1]) == 'str.split("").slice(undefined, -1).join("")' - assert str(str_var[-1:]) == 'str.split("").slice(-1, undefined).join("")' - assert str(str_var[:-2]) == 'str.split("").slice(undefined, -2).join("")' - assert str(str_var[-2:]) == 'str.split("").slice(-2, undefined).join("")' + assert ( + str(str_var[:-1]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, -1, null]))" + ) + assert ( + str(str_var[-1:]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([-1, null, null]))" + ) + assert ( + str(str_var[:-2]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([null, -2, null]))" + ) + assert ( + str(str_var[-2:]) + == "(((...args) => (((_string, _index_or_slice) => Array.prototype.join.apply(atSliceOrIndex(_string, _index_or_slice), ['']))(str, ...args)))([-2, null, null]))" + ) def test_dict_indexing(): @@ -963,11 +1049,11 @@ def test_function_var(): def test_var_operation(): @var_operation - def add(a: Union[NumberVar, int], b: Union[NumberVar, int]): + def add(a: Var[int], b: Var[int]): return var_operation_return(js_expression=f"({a} + {b})", var_type=int) assert str(add(1, 2)) == "(1 + 2)" - assert str(add(a=4, b=-9)) == "(4 + -9)" + assert str(add(4, -9)) == "(4 + -9)" five = LiteralNumberVar.create(5) seven = add(2, five) @@ -978,13 +1064,29 @@ def test_var_operation(): def test_string_operations(): basic_string = LiteralStringVar.create("Hello, World!") - assert str(basic_string.length()) == '"Hello, World!".split("").length' - assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()' - assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()' - assert str(basic_string.strip()) == '"Hello, World!".trim()' - assert str(basic_string.contains("World")) == '"Hello, World!".includes("World")' assert ( - str(basic_string.split(" ").join(",")) == '"Hello, World!".split(" ").join(",")' + str(basic_string.length()) + == '(((...args) => (((...arg) => (((_array) => _array.length)((((_string, _sep = "") => isTrue(_sep) ? _string.split(_sep) : [..._string])(...args)))))("Hello, World!", ...args)))())' + ) + assert ( + str(basic_string.lower()) + == '(((...args) => (((_string) => String.prototype.toLowerCase.apply(_string))("Hello, World!", ...args)))())' + ) + assert ( + str(basic_string.upper()) + == '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))("Hello, World!", ...args)))())' + ) + assert ( + str(basic_string.strip()) + == '(((...args) => (((_string) => String.prototype.trim.apply(_string))("Hello, World!", ...args)))())' + ) + assert ( + str(basic_string.contains("World")) + == '(((...args) => (((_haystack, _needle) => _haystack.includes(_needle))("Hello, World!", ...args)))("World"))' + ) + assert ( + str(basic_string.split(" ").join(",")) + == '(((...args) => (((_array, _sep = "") => Array.prototype.join.apply(_array,[_sep]))((((...args) => (((_string, _sep = "") => isTrue(_sep) ? _string.split(_sep) : [..._string])("Hello, World!", ...args)))(" ")), ...args)))(","))' ) @@ -1004,14 +1106,14 @@ def test_all_number_operations(): assert ( str(even_more_complicated_number) - == "!(isTrue((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))" + == "!((isTrue(((Math.abs((Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))))))" ) assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)" - assert str(LiteralBooleanVar.create(False) < 5) == "(Number(false) < 5)" + assert str(LiteralBooleanVar.create(False) < 5) == "((Number(false)) < 5)" assert ( str(LiteralBooleanVar.create(False) < LiteralBooleanVar.create(True)) - == "(Number(false) < Number(true))" + == "((Number(false)) < (Number(true)))" ) @@ -1020,10 +1122,10 @@ def test_all_number_operations(): [ (Var.create(False), "false"), (Var.create(True), "true"), - (Var.create("false"), 'isTrue("false")'), - (Var.create([1, 2, 3]), "isTrue([1, 2, 3])"), - (Var.create({"a": 1, "b": 2}), 'isTrue(({ ["a"] : 1, ["b"] : 2 }))'), - (Var("mysterious_var"), "isTrue(mysterious_var)"), + (Var.create("false"), '(isTrue("false"))'), + (Var.create([1, 2, 3]), "(isTrue([1, 2, 3]))"), + (Var.create({"a": 1, "b": 2}), '(isTrue(({ ["a"] : 1, ["b"] : 2 })))'), + (Var("mysterious_var"), "(isTrue(mysterious_var))"), ], ) def test_boolify_operations(var, expected): @@ -1032,18 +1134,30 @@ def test_boolify_operations(var, expected): def test_index_operation(): array_var = LiteralArrayVar.create([1, 2, 3, 4, 5]) - assert str(array_var[0]) == "[1, 2, 3, 4, 5].at(0)" - assert str(array_var[1:2]) == "[1, 2, 3, 4, 5].slice(1, 2)" + assert ( + str(array_var[0]) + == "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))(0))" + ) + assert ( + str(array_var[1:2]) + == "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([1, 2, null]))" + ) assert ( str(array_var[1:4:2]) - == "[1, 2, 3, 4, 5].slice(1, 4).filter((_, i) => i % 2 === 0)" + == "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([1, 4, 2]))" ) assert ( str(array_var[::-1]) - == "[1, 2, 3, 4, 5].slice(0, [1, 2, 3, 4, 5].length).slice().reverse().slice(undefined, undefined).filter((_, i) => i % 1 === 0)" + == "(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))([null, null, -1]))" + ) + assert ( + str(array_var.reverse()) + == "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3, 4, 5], ...args)))())" + ) + assert ( + str(array_var[0].to(NumberVar) + 9) + == "((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))([1, 2, 3, 4, 5], ...args)))(0)) + 9)" ) - assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()" - assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)" @pytest.mark.parametrize( @@ -1065,40 +1179,37 @@ def test_inf_and_nan(var, expected_js): def test_array_operations(): array_var = LiteralArrayVar.create([1, 2, 3, 4, 5]) - assert str(array_var.length()) == "[1, 2, 3, 4, 5].length" - assert str(array_var.contains(3)) == "[1, 2, 3, 4, 5].includes(3)" - assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].slice().reverse()" assert ( - str(ArrayVar.range(10)) - == "Array.from({ length: Math.ceil((10 - 0) / 1) }, (_, i) => 0 + i * 1)" + str(array_var.length()) + == "(((...args) => (((_array) => _array.length)([1, 2, 3, 4, 5], ...args)))())" ) assert ( - str(ArrayVar.range(1, 10)) - == "Array.from({ length: Math.ceil((10 - 1) / 1) }, (_, i) => 1 + i * 1)" + str(array_var.contains(3)) + == '(((...args) => (((_haystack, _needle, _field = "") => isTrue(_field) ? _haystack.some(obj => obj[_field] === _needle) : _haystack.some(obj => obj === _needle))([1, 2, 3, 4, 5], ...args)))(3))' ) assert ( - str(ArrayVar.range(1, 10, 2)) - == "Array.from({ length: Math.ceil((10 - 1) / 2) }, (_, i) => 1 + i * 2)" - ) - assert ( - str(ArrayVar.range(1, 10, -1)) - == "Array.from({ length: Math.ceil((10 - 1) / -1) }, (_, i) => 1 + i * -1)" + str(array_var.reverse()) + == "(((...args) => (((_array) => _array.slice().reverse())([1, 2, 3, 4, 5], ...args)))())" ) + assert str(ArrayVar.range(10)) == "[...range(10, null, 1)]" + assert str(ArrayVar.range(1, 10)) == "[...range(1, 10, 1)]" + assert str(ArrayVar.range(1, 10, 2)) == "[...range(1, 10, 2)]" + assert str(ArrayVar.range(1, 10, -1)) == "[...range(1, 10, -1)]" def test_object_operations(): object_var = LiteralObjectVar.create({"a": 1, "b": 2, "c": 3}) assert ( - str(object_var.keys()) == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + str(object_var.keys()) == '(Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))' ) assert ( str(object_var.values()) - == 'Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + == '(Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))' ) assert ( str(object_var.entries()) - == 'Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + == '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })))' ) assert str(object_var.a) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]' assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]' @@ -1139,12 +1250,12 @@ def test_type_chains(): List[int], ) assert ( - str(object_var.keys()[0].upper()) - == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(0).toUpperCase()' + str(object_var.keys()[0].upper()) # pyright: ignore [reportAttributeAccessIssue] + == '(((...args) => (((_string) => String.prototype.toUpperCase.apply(_string))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(0)), ...args)))())' ) assert ( - str(object_var.entries()[1][1] - 1) - == '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(1).at(1) - 1)' + str(object_var.entries()[1][1] - 1) # pyright: ignore [reportCallIssue, reportOperatorIssue] + == '((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))), ...args)))(1)), ...args)))(1)) - 1)' ) assert ( str(object_var["c"] + object_var["b"]) # pyright: ignore [reportCallIssue, reportOperatorIssue] @@ -1153,10 +1264,14 @@ def test_type_chains(): def test_nested_dict(): - arr = LiteralArrayVar.create([{"bar": ["foo", "bar"]}], List[Dict[str, List[str]]]) + arr = Var.create([{"bar": ["foo", "bar"]}]).to(List[Dict[str, List[str]]]) + first_dict = arr.at(0) + bar_element = first_dict["bar"] + first_bar_element = bar_element[0] assert ( - str(arr[0]["bar"][0]) == '[({ ["bar"] : ["foo", "bar"] })].at(0)["bar"].at(0)' # pyright: ignore [reportIndexIssue] + str(first_bar_element) + == '(((...args) => (((_array, _index_or_slice) => atSliceOrIndex(_array, _index_or_slice))((((...args) => (((_array, _index) => _array.at(_index))([({ ["bar"] : ["foo", "bar"] })], ...args)))(0))["bar"], ...args)))(0))' # pyright: ignore [reportIndexIssue] ) @@ -1331,9 +1446,9 @@ def test_unsupported_types_for_reverse(var): Args: var: The base var. """ - with pytest.raises(TypeError) as err: + with pytest.raises(AttributeError) as err: var.reverse() - assert err.value.args[0] == "Cannot reverse non-list var." + assert err.value.args[0] == "'Var' object has no attribute 'reverse'" @pytest.mark.parametrize( @@ -1351,12 +1466,9 @@ def test_unsupported_types_for_contains(var: Var): Args: var: The base var. """ - with pytest.raises(TypeError) as err: + with pytest.raises(AttributeError) as err: assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue] - assert ( - err.value.args[0] - == f"Var of type {var._var_type} does not support contains check." - ) + assert err.value.args[0] == "'Var' object has no attribute 'contains'" @pytest.mark.parametrize( @@ -1376,7 +1488,7 @@ def test_unsupported_types_for_string_contains(other): assert Var(_js_expr="var").to(str).contains(other) assert ( err.value.args[0] - == f"Unsupported Operand type(s) for contains: ToStringOperation, {type(other).__name__}" + == f"Invalid argument other provided to argument 0 in var operation. Expected but got {other._var_type}." ) @@ -1608,17 +1720,12 @@ def test_valid_var_operations(operand1_var: Var, operand2_var, operators: List[s LiteralVar.create([10, 20]), LiteralVar.create("5"), [ - "+", "-", "/", "//", "*", "%", "**", - ">", - "<", - "<=", - ">=", "^", "<<", ">>", diff --git a/tests/units/utils/test_format.py b/tests/units/utils/test_format.py index 89197a03e..500e6341f 100644 --- a/tests/units/utils/test_format.py +++ b/tests/units/utils/test_format.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime import json -from typing import Any, List +from typing import Any import plotly.graph_objects as go import pytest @@ -98,60 +98,6 @@ def test_wrap(text: str, open: str, expected: str, check_first: bool, num: int): assert format.wrap(text, open, check_first=check_first, num=num) == expected -@pytest.mark.parametrize( - "string,expected_output", - [ - ("This is a random string", "This is a random string"), - ( - "This is a random string with `backticks`", - "This is a random string with \\`backticks\\`", - ), - ( - "This is a random string with `backticks`", - "This is a random string with \\`backticks\\`", - ), - ( - "This is a string with ${someValue[`string interpolation`]} unescaped", - "This is a string with ${someValue[`string interpolation`]} unescaped", - ), - ( - "This is a string with `backticks` and ${someValue[`string interpolation`]} unescaped", - "This is a string with \\`backticks\\` and ${someValue[`string interpolation`]} unescaped", - ), - ( - "This is a string with `backticks`, ${someValue[`the first string interpolation`]} and ${someValue[`the second`]}", - "This is a string with \\`backticks\\`, ${someValue[`the first string interpolation`]} and ${someValue[`the second`]}", - ), - ], -) -def test_escape_js_string(string, expected_output): - assert format._escape_js_string(string) == expected_output - - -@pytest.mark.parametrize( - "text,indent_level,expected", - [ - ("", 2, ""), - ("hello", 2, "hello"), - ("hello\nworld", 2, " hello\n world\n"), - ("hello\nworld", 4, " hello\n world\n"), - (" hello\n world", 2, " hello\n world\n"), - ], -) -def test_indent(text: str, indent_level: int, expected: str, windows_platform: bool): - """Test indenting a string. - - Args: - text: The text to indent. - indent_level: The number of spaces to indent by. - expected: The expected output string. - windows_platform: Whether the system is windows. - """ - assert format.indent(text, indent_level) == ( - expected.replace("\n", "\r\n") if windows_platform else expected - ) - - @pytest.mark.parametrize( "input,output", [ @@ -252,25 +198,6 @@ def test_to_kebab_case(input: str, output: str): assert format.to_kebab_case(input) == output -@pytest.mark.parametrize( - "input,output", - [ - ("", "{``}"), - ("hello", "{`hello`}"), - ("hello world", "{`hello world`}"), - ("hello=`world`", "{`hello=\\`world\\``}"), - ], -) -def test_format_string(input: str, output: str): - """Test formatting the input as JS string literal. - - Args: - input: the input string. - output: the output string. - """ - assert format.format_string(input) == output - - @pytest.mark.parametrize( "input,output", [ @@ -310,45 +237,6 @@ def test_format_route(route: str, format_case: bool, expected: bool): assert format.format_route(route, format_case=format_case) == expected -@pytest.mark.parametrize( - "condition, match_cases, default,expected", - [ - ( - "state__state.value", - [ - [LiteralVar.create(1), LiteralVar.create("red")], - [LiteralVar.create(2), LiteralVar.create(3), LiteralVar.create("blue")], - [TestState.mapping, TestState.num1], - [ - LiteralVar.create(f"{TestState.map_key}-key"), - LiteralVar.create("return-key"), - ], - ], - LiteralVar.create("yellow"), - '(() => { switch (JSON.stringify(state__state.value)) {case JSON.stringify(1): return ("red"); break;case JSON.stringify(2): case JSON.stringify(3): ' - f'return ("blue"); break;case JSON.stringify({TestState.get_full_name()}.mapping): return ' - f'({TestState.get_full_name()}.num1); break;case JSON.stringify(({TestState.get_full_name()}.map_key+"-key")): return ("return-key");' - ' break;default: return ("yellow"); break;};})()', - ) - ], -) -def test_format_match( - condition: str, - match_cases: List[List[Var]], - default: Var, - expected: str, -): - """Test formatting a match statement. - - Args: - condition: The condition to match. - match_cases: List of match cases to be matched. - default: Catchall case for the match statement. - expected: The expected string output. - """ - assert format.format_match(condition, match_cases, default) == expected - - @pytest.mark.parametrize( "prop,formatted", [ diff --git a/tests/units/utils/test_utils.py b/tests/units/utils/test_utils.py index 74dcf79b0..427671f11 100644 --- a/tests/units/utils/test_utils.py +++ b/tests/units/utils/test_utils.py @@ -2,7 +2,7 @@ import os import typing from functools import cached_property from pathlib import Path -from typing import Any, ClassVar, Dict, List, Literal, Type, Union +from typing import Any, ClassVar, Dict, List, Literal, Sequence, Tuple, Type, Union import pytest import typer @@ -109,10 +109,20 @@ def test_is_generic_alias(cls: type, expected: bool): (Dict[str, str], dict[str, str], True), (Dict[str, str], dict[str, Any], True), (Dict[str, Any], dict[str, Any], True), + (List[int], Sequence[int], True), + (List[str], Sequence[int], False), + (Tuple[int], Sequence[int], True), + (Tuple[int, str], Sequence[int], False), + (Tuple[int, ...], Sequence[int], True), + (str, Sequence[int], False), + (str, Sequence[str], True), ], ) def test_typehint_issubclass(subclass, superclass, expected): - assert types.typehint_issubclass(subclass, superclass) == expected + if expected: + assert types.typehint_issubclass(subclass, superclass) + else: + assert not types.typehint_issubclass(subclass, superclass) def test_validate_none_bun_path(mocker):