diff --git a/reflex/__init__.py b/reflex/__init__.py index 9364f02a2..d029486a2 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -324,7 +324,8 @@ _MAPPING: dict = { "style": ["Style", "toggle_color_mode"], "utils.imports": ["ImportVar"], "utils.serializers": ["serializer"], - "vars": ["cached_var", "Var"], + "vars": ["Var"], + "ivars.base": ["cached_var"], } _SUBMODULES: set[str] = { diff --git a/reflex/__init__.pyi b/reflex/__init__.pyi index 94103a1d0..5d3fea2bb 100644 --- a/reflex/__init__.pyi +++ b/reflex/__init__.pyi @@ -175,6 +175,7 @@ from .event import stop_propagation as stop_propagation from .event import upload_files as upload_files from .event import window_alert as window_alert from .experimental import _x as _x +from .ivars.base import cached_var as cached_var from .middleware import Middleware as Middleware from .middleware import middleware as middleware from .model import Model as Model @@ -191,7 +192,6 @@ from .style import toggle_color_mode as toggle_color_mode from .utils.imports import ImportVar as ImportVar from .utils.serializers import serializer as serializer from .vars import Var as Var -from .vars import cached_var as cached_var del compat RADIX_THEMES_MAPPING: dict diff --git a/reflex/app.py b/reflex/app.py index 90c29fefa..ca438fd2d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -442,7 +442,11 @@ class App(MiddlewareMixin, LifespanMixin, Base): raise except TypeError as e: message = str(e) - if "BaseVar" in message or "ComputedVar" in message: + if ( + "BaseVar" in message + or "ComputedVar" in message + or "ImmutableComputedVar" in message + ): raise VarOperationTypeError( "You may be trying to use an invalid Python function on a state var. " "When referencing a var inside your render code, only limited var operations are supported. " diff --git a/reflex/components/base/bare.py b/reflex/components/base/bare.py index 8cc83b83e..dca9a9287 100644 --- a/reflex/components/base/bare.py +++ b/reflex/components/base/bare.py @@ -28,7 +28,7 @@ class Bare(Component): """ if isinstance(contents, ImmutableVar): return cls(contents=contents) - if isinstance(contents, Var) and contents._var_data: + if isinstance(contents, Var) and contents._get_all_var_data(): contents = contents.to(str) else: contents = str(contents) if contents is not None else "" diff --git a/reflex/components/component.py b/reflex/components/component.py index 159c6e6ca..af583a735 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1121,7 +1121,8 @@ class Component(BaseComponent, ABC): for child in self.children: if not isinstance(child, Component): continue - vars.extend(child._get_vars(include_children=include_children)) + child_vars = child._get_vars(include_children=include_children) + vars.extend(child_vars) return vars @@ -1326,13 +1327,13 @@ class Component(BaseComponent, ABC): other_imports = [] user_hooks = self._get_hooks() - if ( - user_hooks is not None - and isinstance(user_hooks, Var) - and user_hooks._var_data is not None - and user_hooks._var_data.imports - ): - other_imports.append(user_hooks._var_data.imports) + user_hooks_data = ( + VarData.merge(user_hooks._get_all_var_data()) + if user_hooks is not None and isinstance(user_hooks, Var) + else None + ) + if user_hooks_data is not None: + other_imports.append(user_hooks_data.imports) other_imports.extend( hook_imports for hook_imports in self._get_added_hooks().values() ) @@ -1830,9 +1831,11 @@ class CustomComponent(Component): Returns: Each var referenced by the component (props, styles, event handlers). """ - return super()._get_vars(include_children=include_children) + [ - prop for prop in self.props.values() if isinstance(prop, Var) - ] + return ( + super()._get_vars(include_children=include_children) + + [prop for prop in self.props.values() if isinstance(prop, Var)] + + self.get_component(self)._get_vars(include_children=include_children) + ) @lru_cache(maxsize=None) # noqa def get_component(self) -> Component: diff --git a/reflex/components/core/banner.py b/reflex/components/core/banner.py index acdab19c5..0b79375b5 100644 --- a/reflex/components/core/banner.py +++ b/reflex/components/core/banner.py @@ -153,7 +153,7 @@ class ConnectionToaster(Toaster): }} """ ), - LiteralArrayVar([connect_errors]), + LiteralArrayVar.create([connect_errors]), ), ] diff --git a/reflex/components/core/banner.pyi b/reflex/components/core/banner.pyi index b9b6d506f..23cd1c641 100644 --- a/reflex/components/core/banner.pyi +++ b/reflex/components/core/banner.pyi @@ -29,7 +29,7 @@ class WebsocketTargetURL(Bare): def create( # type: ignore cls, *children, - contents: Optional[Union[Var[str], str]] = None, + contents: Optional[Union[Var[Any], Any]] = None, style: Optional[Style] = None, key: Optional[Any] = None, id: Optional[Any] = None, diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index f00f7a150..8d8f93c7e 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -104,7 +104,7 @@ class Cond(MemoizationLeaf): The import dict for the component. """ cond_imports: dict[str, str | ImportVar | list[str | ImportVar]] = getattr( - self.cond._var_data, "imports", {} + VarData.merge(self.cond._get_all_var_data()), "imports", {} ) return {**cond_imports, **_IS_TRUE_IMPORT} @@ -135,6 +135,8 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | ImmutableVar: Raises: ValueError: If the arguments are invalid. """ + if isinstance(condition, Var) and not isinstance(condition, ImmutableVar): + condition._var_is_local = True # Convert the condition to a Var. cond_var = LiteralVar.create(condition) assert cond_var is not None, "The condition must be set." @@ -161,8 +163,8 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | ImmutableVar: c2 = create_var(c2) # Create the conditional var. - return TernaryOperator( - condition=cond_var, + return TernaryOperator.create( + condition=cond_var.to(bool), if_true=c1, if_false=c2, _var_data=VarData(imports=_IS_TRUE_IMPORT), diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index 557844c81..9725c1754 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -11,7 +11,7 @@ from reflex.style import Style from reflex.utils import format, types from reflex.utils.exceptions import MatchTypeError from reflex.utils.imports import ImportDict -from reflex.vars import ImmutableVarData, Var +from reflex.vars import ImmutableVarData, Var, VarData class Match(MemoizationLeaf): @@ -264,7 +264,7 @@ class Match(MemoizationLeaf): Returns: The import dict. """ - return getattr(self.cond._var_data, "imports", {}) + return getattr(VarData.merge(self.cond._get_all_var_data()), "imports", {}) match = Match.create diff --git a/reflex/components/core/upload.py b/reflex/components/core/upload.py index f45ba276e..854444ae3 100644 --- a/reflex/components/core/upload.py +++ b/reflex/components/core/upload.py @@ -22,7 +22,7 @@ from reflex.event import ( from reflex.ivars.base import ImmutableCallableVar, ImmutableVar from reflex.ivars.sequence import LiteralStringVar from reflex.utils.imports import ImportVar -from reflex.vars import Var, VarData +from reflex.vars import ImmutableVarData, Var, VarData DEFAULT_UPLOAD_ID: str = "default" @@ -61,7 +61,7 @@ def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar: return ImmutableVar( _var_name=var_name, _var_type=EventChain, - _var_data=VarData.merge( + _var_data=ImmutableVarData.merge( upload_files_context_var_data, id_var._get_all_var_data() ), ) @@ -81,7 +81,7 @@ def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar: return ImmutableVar( _var_name=f"(filesById[{str(id_var)}] ? filesById[{str(id_var)}].map((f) => (f.path || f.name)) : [])", _var_type=List[str], - _var_data=VarData.merge( + _var_data=ImmutableVarData.merge( upload_files_context_var_data, id_var._get_all_var_data() ), ).guess_type() diff --git a/reflex/components/core/upload.pyi b/reflex/components/core/upload.pyi index 54bb351de..d0b89a898 100644 --- a/reflex/components/core/upload.pyi +++ b/reflex/components/core/upload.pyi @@ -12,16 +12,17 @@ from reflex.event import ( EventHandler, EventSpec, ) +from reflex.ivars.base import ImmutableCallableVar, ImmutableVar from reflex.style import Style -from reflex.vars import BaseVar, CallableVar, Var, VarData +from reflex.vars import BaseVar, Var, VarData DEFAULT_UPLOAD_ID: str upload_files_context_var_data: VarData -@CallableVar -def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ... -@CallableVar -def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> BaseVar: ... +@ImmutableCallableVar +def upload_file(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar: ... +@ImmutableCallableVar +def selected_files(id_: str = DEFAULT_UPLOAD_ID) -> ImmutableVar: ... @CallableEventSpec def clear_selected_files(id_: str = DEFAULT_UPLOAD_ID) -> EventSpec: ... def cancel_upload(upload_id: str) -> EventSpec: ... diff --git a/reflex/components/datadisplay/code.py b/reflex/components/datadisplay/code.py index 56ea1fd9a..e2d82c610 100644 --- a/reflex/components/datadisplay/code.py +++ b/reflex/components/datadisplay/code.py @@ -390,7 +390,7 @@ class CodeBlock(Component): The import dict. """ imports_: ImportDict = {} - themes = re.findall(r"`(.*?)`", self.theme._var_name) + themes = re.findall(r'"(.*?)"', self.theme._var_name) if not themes: themes = [self.theme._var_name] @@ -509,11 +509,8 @@ class CodeBlock(Component): style=ImmutableVar.create( format.to_camel_case(f"{predicate}{qmark}{value.replace('`', '')}"), ) - ).remove_props("theme", "code") - if self.code is not None: - out.special_props.add( - Var.create_safe(f"children={str(self.code)}", _var_is_string=False) - ) + ).remove_props("theme", "code").add_props(children=self.code) + return out @staticmethod diff --git a/reflex/components/el/elements/forms.py b/reflex/components/el/elements/forms.py index 0ea9aefd4..de767ccb2 100644 --- a/reflex/components/el/elements/forms.py +++ b/reflex/components/el/elements/forms.py @@ -14,7 +14,7 @@ from reflex.event import EventChain, EventHandler from reflex.ivars.base import ImmutableVar from reflex.utils.format import format_event_chain from reflex.utils.imports import ImportDict -from reflex.vars import Var +from reflex.vars import Var, VarData from .base import BaseHTML @@ -218,7 +218,7 @@ class Form(BaseHTML): f"getRefValues({str(ref_var)})", _var_is_local=False, _var_is_string=False, - _var_data=ref_var._var_data, + _var_data=VarData.merge(ref_var._get_all_var_data()), ) else: ref_var = Var.create_safe(ref, _var_is_string=False).as_ref() @@ -226,7 +226,7 @@ class Form(BaseHTML): f"getRefValue({str(ref_var)})", _var_is_local=False, _var_is_string=False, - _var_data=ref_var._var_data, + _var_data=VarData.merge(ref_var._get_all_var_data()), ) return form_refs @@ -632,7 +632,7 @@ class Textarea(BaseHTML): f"(e) => enterKeySubmitOnKeyDown(e, {self.enter_key_submit._var_name_unwrapped})", _var_is_local=False, _var_is_string=False, - _var_data=self.enter_key_submit._var_data, + _var_data=VarData.merge(self.enter_key_submit._get_all_var_data()), ) ) if self.auto_height is not None: @@ -641,7 +641,7 @@ class Textarea(BaseHTML): f"(e) => autoHeightOnInput(e, {self.auto_height._var_name_unwrapped})", _var_is_local=False, _var_is_string=False, - _var_data=self.auto_height._var_data, + _var_data=VarData.merge(self.auto_height._get_all_var_data()), ) ) return tag diff --git a/reflex/components/el/elements/forms.pyi b/reflex/components/el/elements/forms.pyi index 31a56e9be..4d03e9c5e 100644 --- a/reflex/components/el/elements/forms.pyi +++ b/reflex/components/el/elements/forms.pyi @@ -9,13 +9,14 @@ from jinja2 import Environment from reflex.components.el.element import Element from reflex.event import EventHandler, EventSpec +from reflex.ivars.base import ImmutableVar from reflex.style import Style from reflex.utils.imports import ImportDict from reflex.vars import BaseVar, Var from .base import BaseHTML -FORM_DATA = Var.create("form_data", _var_is_string=False) +FORM_DATA = ImmutableVar.create("form_data") HANDLE_SUBMIT_JS_JINJA2 = Environment().from_string( "\n const handleSubmit_{{ handle_submit_unique_name }} = useCallback((ev) => {\n const $form = ev.target\n ev.preventDefault()\n const {{ form_data }} = {...Object.fromEntries(new FormData($form).entries()), ...{{ field_ref_mapping }}}\n\n {{ on_submit_event_chain }}\n\n if ({{ reset_on_submit }}) {\n $form.reset()\n }\n })\n " ) diff --git a/reflex/components/el/elements/metadata.py b/reflex/components/el/elements/metadata.py index c19612abe..9a4d18b73 100644 --- a/reflex/components/el/elements/metadata.py +++ b/reflex/components/el/elements/metadata.py @@ -29,24 +29,49 @@ class Link(BaseHTML): # noqa: E742 tag = "link" + # Specifies the CORS settings for the linked resource cross_origin: Var[Union[str, int, bool]] + + # Specifies the URL of the linked document/resource href: Var[Union[str, int, bool]] + + # Specifies the language of the text in the linked document href_lang: Var[Union[str, int, bool]] + + # Allows a browser to check the fetched link for integrity integrity: Var[Union[str, int, bool]] + + # Specifies on what device the linked document will be displayed media: Var[Union[str, int, bool]] + + # Specifies the referrer policy of the linked document referrer_policy: Var[Union[str, int, bool]] + + # Specifies the relationship between the current document and the linked one rel: Var[Union[str, int, bool]] + + # Specifies the sizes of icons for visual media sizes: Var[Union[str, int, bool]] + + # Specifies the MIME type of the linked document type: Var[Union[str, int, bool]] class Meta(BaseHTML): # Inherits common attributes from BaseHTML """Display the meta element.""" - tag = "meta" + tag = "meta" # The HTML tag for this element is + + # Specifies the character encoding for the HTML document char_set: Var[Union[str, int, bool]] + + # Defines the content of the metadata content: Var[Union[str, int, bool]] + + # Provides an HTTP header for the information/value of the content attribute http_equiv: Var[Union[str, int, bool]] + + # Specifies a name for the metadata name: Var[Union[str, int, bool]] diff --git a/reflex/components/el/elements/metadata.pyi b/reflex/components/el/elements/metadata.pyi index d4d68adb6..e08c1d723 100644 --- a/reflex/components/el/elements/metadata.pyi +++ b/reflex/components/el/elements/metadata.pyi @@ -346,6 +346,15 @@ class Link(BaseHTML): Args: *children: The children of the component. + cross_origin: Specifies the CORS settings for the linked resource + href: Specifies the URL of the linked document/resource + href_lang: Specifies the language of the text in the linked document + integrity: Allows a browser to check the fetched link for integrity + media: Specifies on what device the linked document will be displayed + referrer_policy: Specifies the referrer policy of the linked document + rel: Specifies the relationship between the current document and the linked one + sizes: Specifies the sizes of icons for visual media + type: Specifies the MIME type of the linked document access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. @@ -466,6 +475,10 @@ class Meta(BaseHTML): Args: *children: The children of the component. + char_set: Specifies the character encoding for the HTML document + content: Defines the content of the metadata + http_equiv: Provides an HTTP header for the information/value of the content attribute + name: Specifies a name for the metadata access_key: Provides a hint for generating a keyboard shortcut for the current element. auto_capitalize: Controls whether and how text input is automatically capitalized as it is entered/edited by the user. content_editable: Indicates whether the element's content is editable. diff --git a/reflex/components/gridjs/datatable.py b/reflex/components/gridjs/datatable.py index 6d856cf45..075c08d59 100644 --- a/reflex/components/gridjs/datatable.py +++ b/reflex/components/gridjs/datatable.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Union from reflex.components.component import Component from reflex.components.tags import Tag +from reflex.ivars.base import ImmutableComputedVar from reflex.utils import types from reflex.utils.imports import ImportDict from reflex.utils.serializers import serialize @@ -65,14 +66,17 @@ class DataTable(Gridjs): # The annotation should be provided if data is a computed var. We need this to know how to # render pandas dataframes. - if isinstance(data, ComputedVar) and data._var_type == Any: + if ( + isinstance(data, (ComputedVar, ImmutableComputedVar)) + and data._var_type == Any + ): raise ValueError( "Annotation of the computed var assigned to the data field should be provided." ) if ( columns is not None - and isinstance(columns, ComputedVar) + and isinstance(columns, (ComputedVar, ImmutableComputedVar)) and columns._var_type == Any ): raise ValueError( diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 536b43930..1d9f94ea1 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -17,30 +17,26 @@ from reflex.components.radix.themes.typography.heading import Heading from reflex.components.radix.themes.typography.link import Link from reflex.components.radix.themes.typography.text import Text from reflex.components.tags.tag import Tag -from reflex.ivars.base import LiteralVar +from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.utils import types from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var # Special vars used in the component map. -_CHILDREN = Var.create_safe("children", _var_is_local=False, _var_is_string=False) -_PROPS = Var.create_safe("...props", _var_is_local=False, _var_is_string=False) -_MOCK_ARG = Var.create_safe("", _var_is_string=False) +_CHILDREN = ImmutableVar.create_safe("children") +_PROPS = ImmutableVar.create_safe("...props") +_MOCK_ARG = ImmutableVar.create_safe("") # Special remark plugins. -_REMARK_MATH = Var.create_safe("remarkMath", _var_is_local=False, _var_is_string=False) -_REMARK_GFM = Var.create_safe("remarkGfm", _var_is_local=False, _var_is_string=False) -_REMARK_UNWRAP_IMAGES = Var.create_safe( - "remarkUnwrapImages", _var_is_local=False, _var_is_string=False -) -_REMARK_PLUGINS = Var.create_safe([_REMARK_MATH, _REMARK_GFM, _REMARK_UNWRAP_IMAGES]) +_REMARK_MATH = ImmutableVar.create_safe("remarkMath") +_REMARK_GFM = ImmutableVar.create_safe("remarkGfm") +_REMARK_UNWRAP_IMAGES = ImmutableVar.create_safe("remarkUnwrapImages") +_REMARK_PLUGINS = LiteralVar.create([_REMARK_MATH, _REMARK_GFM, _REMARK_UNWRAP_IMAGES]) # Special rehype plugins. -_REHYPE_KATEX = Var.create_safe( - "rehypeKatex", _var_is_local=False, _var_is_string=False -) -_REHYPE_RAW = Var.create_safe("rehypeRaw", _var_is_local=False, _var_is_string=False) -_REHYPE_PLUGINS = Var.create_safe([_REHYPE_KATEX, _REHYPE_RAW]) +_REHYPE_KATEX = ImmutableVar.create_safe("rehypeKatex") +_REHYPE_RAW = ImmutableVar.create_safe("rehypeRaw") +_REHYPE_PLUGINS = LiteralVar.create([_REHYPE_KATEX, _REHYPE_RAW]) # These tags do NOT get props passed to them NO_PROPS_TAGS = ("ul", "ol", "li") @@ -209,10 +205,11 @@ class Markdown(Component): children_prop = props.pop("children", None) if children_prop is not None: special_props.add( - Var.create_safe(f"children={str(children_prop)}", _var_is_string=False) + Var.create_safe( + f"children={{{str(children_prop)}}}", _var_is_string=False + ) ) children = [] - # Get the component. component = self.component_map[tag](*children, **props).set( special_props=special_props @@ -238,7 +235,7 @@ class Markdown(Component): The formatted component map. """ components = { - tag: f"{{({{node, {_CHILDREN._var_name}, {_PROPS._var_name}}}) => {self.format_component(tag)}}}" + tag: f"{{({{node, {_CHILDREN._var_name}, {_PROPS._var_name}}}) => ({self.format_component(tag)})}}" for tag in self.component_map } @@ -261,7 +258,7 @@ class Markdown(Component): return inline ? ( {self.format_component("code")} ) : ( - {self.format_component("codeblock", language=Var.create_safe("language", _var_is_local=False, _var_is_string=False))} + {self.format_component("codeblock", language=ImmutableVar.create_safe("language"))} ); }}}}""".replace("\n", " ") @@ -288,7 +285,7 @@ class Markdown(Component): function {self._get_component_map_name()} () {{ {formatted_hooks} return ( - {str(LiteralVar.create(self.format_component_map()))} + {str(ImmutableVar.create_safe(self.format_component_map()))} ) }} """ @@ -300,14 +297,10 @@ class Markdown(Component): .add_props( remark_plugins=_REMARK_PLUGINS, rehype_plugins=_REHYPE_PLUGINS, + components=ImmutableVar.create_safe( + f"{self._get_component_map_name()}()" + ), ) .remove_props("componentMap", "componentMapHash") ) - tag.special_props.add( - Var.create_safe( - f"components={{{self._get_component_map_name()}()}}", - _var_is_local=True, - _var_is_string=False, - ), - ) return tag diff --git a/reflex/components/markdown/markdown.pyi b/reflex/components/markdown/markdown.pyi index e0eb43454..f4443bf3d 100644 --- a/reflex/components/markdown/markdown.pyi +++ b/reflex/components/markdown/markdown.pyi @@ -8,24 +8,21 @@ from typing import Any, Callable, Dict, Optional, Union, overload from reflex.components.component import Component from reflex.event import EventHandler, EventSpec +from reflex.ivars.base import ImmutableVar, LiteralVar from reflex.style import Style from reflex.utils.imports import ImportDict from reflex.vars import BaseVar, Var -_CHILDREN = Var.create_safe("children", _var_is_local=False, _var_is_string=False) -_PROPS = Var.create_safe("...props", _var_is_local=False, _var_is_string=False) -_MOCK_ARG = Var.create_safe("", _var_is_string=False) -_REMARK_MATH = Var.create_safe("remarkMath", _var_is_local=False, _var_is_string=False) -_REMARK_GFM = Var.create_safe("remarkGfm", _var_is_local=False, _var_is_string=False) -_REMARK_UNWRAP_IMAGES = Var.create_safe( - "remarkUnwrapImages", _var_is_local=False, _var_is_string=False -) -_REMARK_PLUGINS = Var.create_safe([_REMARK_MATH, _REMARK_GFM, _REMARK_UNWRAP_IMAGES]) -_REHYPE_KATEX = Var.create_safe( - "rehypeKatex", _var_is_local=False, _var_is_string=False -) -_REHYPE_RAW = Var.create_safe("rehypeRaw", _var_is_local=False, _var_is_string=False) -_REHYPE_PLUGINS = Var.create_safe([_REHYPE_KATEX, _REHYPE_RAW]) +_CHILDREN = ImmutableVar.create_safe("children") +_PROPS = ImmutableVar.create_safe("...props") +_MOCK_ARG = ImmutableVar.create_safe("") +_REMARK_MATH = ImmutableVar.create_safe("remarkMath") +_REMARK_GFM = ImmutableVar.create_safe("remarkGfm") +_REMARK_UNWRAP_IMAGES = ImmutableVar.create_safe("remarkUnwrapImages") +_REMARK_PLUGINS = LiteralVar.create([_REMARK_MATH, _REMARK_GFM, _REMARK_UNWRAP_IMAGES]) +_REHYPE_KATEX = ImmutableVar.create_safe("rehypeKatex") +_REHYPE_RAW = ImmutableVar.create_safe("rehypeRaw") +_REHYPE_PLUGINS = LiteralVar.create([_REHYPE_KATEX, _REHYPE_RAW]) NO_PROPS_TAGS = ("ul", "ol", "li") @lru_cache diff --git a/reflex/components/radix/themes/color_mode.py b/reflex/components/radix/themes/color_mode.py index f0ef477cc..511a2ea31 100644 --- a/reflex/components/radix/themes/color_mode.py +++ b/reflex/components/radix/themes/color_mode.py @@ -206,5 +206,5 @@ class ColorModeNamespace(ImmutableVar): color_mode = color_mode_var_and_namespace = ColorModeNamespace( _var_name=color_mode._var_name, _var_type=color_mode._var_type, - _var_data=color_mode._var_data, + _var_data=color_mode.get_default_value(), ) diff --git a/reflex/components/radix/themes/color_mode.pyi b/reflex/components/radix/themes/color_mode.pyi index 474c4cf36..dcbc645cb 100644 --- a/reflex/components/radix/themes/color_mode.pyi +++ b/reflex/components/radix/themes/color_mode.pyi @@ -3,7 +3,6 @@ # ------------------- DO NOT EDIT ---------------------- # This file was generated by `reflex/utils/pyi_generator.py`! # ------------------------------------------------------ -import dataclasses from typing import Any, Callable, Dict, Literal, Optional, Union, get_args, overload from reflex.components.component import BaseComponent @@ -12,6 +11,7 @@ from reflex.components.core.cond import Cond from reflex.components.lucide.icon import Icon from reflex.components.radix.themes.components.switch import Switch from reflex.event import EventHandler, EventSpec +from reflex.ivars.base import ImmutableVar from reflex.style import ( Style, color_mode, @@ -533,11 +533,13 @@ class ColorModeSwitch(Switch): """ ... -class ColorModeNamespace(BaseVar): +class ColorModeNamespace(ImmutableVar): icon = staticmethod(ColorModeIcon.create) button = staticmethod(ColorModeIconButton.create) switch = staticmethod(ColorModeSwitch.create) color_mode = color_mode_var_and_namespace = ColorModeNamespace( - **dataclasses.asdict(color_mode) + _var_name=color_mode._var_name, + _var_type=color_mode._var_type, + _var_data=color_mode.get_default_value(), ) diff --git a/reflex/components/radix/themes/components/checkbox.py b/reflex/components/radix/themes/components/checkbox.py index 4751c4995..f191ce613 100644 --- a/reflex/components/radix/themes/components/checkbox.py +++ b/reflex/components/radix/themes/components/checkbox.py @@ -7,6 +7,7 @@ from reflex.components.core.breakpoints import Responsive from reflex.components.radix.themes.layout.flex import Flex from reflex.components.radix.themes.typography.text import Text from reflex.event import EventHandler +from reflex.ivars.base import LiteralVar from reflex.vars import Var from ..base import ( @@ -115,9 +116,7 @@ class HighLevelCheckbox(RadixThemesComponent): on_change: EventHandler[lambda e0: [e0]] @classmethod - def create( - cls, text: Var[str] = Var.create_safe("", _var_is_string=True), **props - ) -> Component: + def create(cls, text: Var[str] = LiteralVar.create(""), **props) -> Component: """Create a checkbox with a label. Args: diff --git a/reflex/components/radix/themes/components/radio_group.py b/reflex/components/radix/themes/components/radio_group.py index b16a73a33..049e47b29 100644 --- a/reflex/components/radix/themes/components/radio_group.py +++ b/reflex/components/radix/themes/components/radio_group.py @@ -11,7 +11,6 @@ from reflex.components.radix.themes.layout.flex import Flex from reflex.components.radix.themes.typography.text import Text from reflex.event import EventHandler from reflex.ivars.base import ImmutableVar, LiteralVar -from reflex.ivars.function import JSON_STRINGIFY from reflex.ivars.sequence import StringVar from reflex.vars import Var @@ -30,14 +29,10 @@ class RadioGroupRoot(RadixThemesComponent): tag = "RadioGroup.Root" # The size of the radio group: "1" | "2" | "3" - size: Var[Responsive[Literal["1", "2", "3"]]] = Var.create_safe( - "2", _var_is_string=True - ) + size: Var[Responsive[Literal["1", "2", "3"]]] = LiteralVar.create("2") # The variant of the radio group - variant: Var[Literal["classic", "surface", "soft"]] = Var.create_safe( - "classic", _var_is_string=True - ) + variant: Var[Literal["classic", "surface", "soft"]] = LiteralVar.create("classic") # The color of the radio group color_scheme: Var[LiteralAccentColor] @@ -89,20 +84,16 @@ class HighLevelRadioGroup(RadixThemesComponent): items: Var[List[str]] # The direction of the radio group. - direction: Var[LiteralFlexDirection] = Var.create_safe( - "column", _var_is_string=True - ) + direction: Var[LiteralFlexDirection] = LiteralVar.create("column") # The gap between the items of the radio group. - spacing: Var[LiteralSpacing] = Var.create_safe("2", _var_is_string=True) + spacing: Var[LiteralSpacing] = LiteralVar.create("2") # The size of the radio group. - size: Var[Literal["1", "2", "3"]] = Var.create_safe("2", _var_is_string=True) + size: Var[Literal["1", "2", "3"]] = LiteralVar.create("2") # The variant of the radio group - variant: Var[Literal["classic", "surface", "soft"]] = Var.create_safe( - "classic", _var_is_string=True - ) + variant: Var[Literal["classic", "surface", "soft"]] = LiteralVar.create("classic") # The color of the radio group color_scheme: Var[LiteralAccentColor] @@ -159,13 +150,13 @@ class HighLevelRadioGroup(RadixThemesComponent): ): default_value = LiteralVar.create(default_value) # type: ignore else: - default_value = JSON_STRINGIFY.call(ImmutableVar.create(default_value)) + default_value = ImmutableVar.create_safe(default_value).to_string() def radio_group_item(value: Var) -> Component: item_value = rx.cond( value._type() == "string", value, - JSON_STRINGIFY.call(value), + value.to_string(), ).to(StringVar) return Text.create( diff --git a/reflex/components/radix/themes/components/separator.py b/reflex/components/radix/themes/components/separator.py index fbc445878..81f83194b 100644 --- a/reflex/components/radix/themes/components/separator.py +++ b/reflex/components/radix/themes/components/separator.py @@ -3,6 +3,7 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive +from reflex.ivars.base import LiteralVar from reflex.vars import Var from ..base import ( @@ -19,9 +20,7 @@ class Separator(RadixThemesComponent): tag = "Separator" # The size of the select: "1" | "2" | "3" | "4" - size: Var[Responsive[LiteralSeperatorSize]] = Var.create_safe( - "4", _var_is_string=True - ) + size: Var[Responsive[LiteralSeperatorSize]] = LiteralVar.create("4") # The color of the select color_scheme: Var[LiteralAccentColor] diff --git a/reflex/components/radix/themes/layout/container.py b/reflex/components/radix/themes/layout/container.py index c0f1c7dc8..4ed18031d 100644 --- a/reflex/components/radix/themes/layout/container.py +++ b/reflex/components/radix/themes/layout/container.py @@ -6,6 +6,7 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.ivars.base import LiteralVar from reflex.style import STACK_CHILDREN_FULL_WIDTH from reflex.vars import Var @@ -23,9 +24,7 @@ class Container(elements.Div, RadixThemesComponent): tag = "Container" # The size of the container: "1" - "4" (default "3") - size: Var[Responsive[LiteralContainerSize]] = Var.create_safe( - "3", _var_is_string=True - ) + size: Var[Responsive[LiteralContainerSize]] = LiteralVar.create("3") @classmethod def create( diff --git a/reflex/components/radix/themes/layout/section.py b/reflex/components/radix/themes/layout/section.py index d9b27bdf9..a3e58be86 100644 --- a/reflex/components/radix/themes/layout/section.py +++ b/reflex/components/radix/themes/layout/section.py @@ -6,6 +6,7 @@ from typing import Literal from reflex.components.core.breakpoints import Responsive from reflex.components.el import elements +from reflex.ivars.base import LiteralVar from reflex.vars import Var from ..base import RadixThemesComponent @@ -19,9 +20,7 @@ class Section(elements.Section, RadixThemesComponent): tag = "Section" # The size of the section: "1" - "3" (default "2") - size: Var[Responsive[LiteralSectionSize]] = Var.create_safe( - "2", _var_is_string=True - ) + size: Var[Responsive[LiteralSectionSize]] = LiteralVar.create("2") section = Section.create diff --git a/reflex/components/recharts/cartesian.py b/reflex/components/recharts/cartesian.py index 710fef19b..2e68d5f23 100644 --- a/reflex/components/recharts/cartesian.py +++ b/reflex/components/recharts/cartesian.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Union from reflex.constants import EventTriggers from reflex.constants.colors import Color from reflex.event import EventHandler +from reflex.ivars.base import LiteralVar from reflex.vars import Var from .recharts import ( @@ -86,7 +87,7 @@ class Axis(Recharts): tick_count: Var[int] # If set false, no axis tick lines will be drawn. - tick_line: Var[bool] = Var.create_safe(False) + tick_line: Var[bool] = LiteralVar.create(False) # The length of tick line. tick_size: Var[int] @@ -95,7 +96,7 @@ class Axis(Recharts): min_tick_gap: Var[int] # The stroke color of axis - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 9)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 9)) # The text anchor of axis text_anchor: Var[str] # 'start', 'middle', 'end' @@ -136,7 +137,7 @@ class XAxis(Axis): x_axis_id: Var[Union[str, int]] # Ensures that all datapoints within a chart contribute to its domain calculation, even when they are hidden - include_hidden: Var[bool] = Var.create_safe(False) + include_hidden: Var[bool] = LiteralVar.create(False) class YAxis(Axis): @@ -187,10 +188,10 @@ class Brush(Recharts): alias = "RechartsBrush" # Stroke color - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 9)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 9)) # The fill color of brush. - fill: Var[Union[str, Color]] = Var.create_safe(Color("gray", 2)) + fill: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 2)) # The key of data displayed in the axis. data_key: Var[Union[str, int]] @@ -290,22 +291,22 @@ class Area(Cartesian): alias = "RechartsArea" # The color of the line stroke. - stroke: Var[Union[str, Color]] = Var.create_safe(Color("accent", 9)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 9)) # The width of the line stroke. - stroke_width: Var[int] = Var.create_safe(1) + stroke_width: Var[int] = LiteralVar.create(1) # The color of the area fill. - fill: Var[Union[str, Color]] = Var.create_safe(Color("accent", 5)) + fill: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 5)) # The interpolation type of area. And customized interpolation function can be set to type. 'basis' | 'basisClosed' | 'basisOpen' | 'bumpX' | 'bumpY' | 'bump' | 'linear' | 'linearClosed' | 'natural' | 'monotoneX' | 'monotoneY' | 'monotone' | 'step' | 'stepBefore' | 'stepAfter' | - type_: Var[LiteralAreaType] = Var.create_safe("monotone", _var_is_string=True) + type_: Var[LiteralAreaType] = LiteralVar.create("monotone") # If false set, dots will not be drawn. If true set, dots will be drawn which have the props calculated internally. dot: Var[Union[bool, Dict[str, Any]]] # The dot is shown when user enter an area chart and this chart has tooltip. If false set, no active dot will not be drawn. If true set, active dot will be drawn which have the props calculated internally. - active_dot: Var[Union[bool, Dict[str, Any]]] = Var.create_safe( + active_dot: Var[Union[bool, Dict[str, Any]]] = LiteralVar.create( { "stroke": Color("accent", 2), "fill": Color("accent", 10), @@ -342,7 +343,7 @@ class Bar(Cartesian): stroke_width: Var[int] # The width of the line stroke. - fill: Var[Union[str, Color]] = Var.create_safe(Color("accent", 9)) + fill: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 9)) # If false set, background of bars will not be drawn. If true set, background of bars will be drawn which have the props calculated internally. background: Var[bool] @@ -403,13 +404,13 @@ class Line(Cartesian): type_: Var[LiteralAreaType] # The color of the line stroke. - stroke: Var[Union[str, Color]] = Var.create_safe(Color("accent", 9)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 9)) # The width of the line stroke. stroke_width: Var[int] # The dot is shown when mouse enter a line chart and this chart has tooltip. If false set, no active dot will not be drawn. If true set, active dot will be drawn which have the props calculated internally. - dot: Var[Union[bool, Dict[str, Any]]] = Var.create_safe( + dot: Var[Union[bool, Dict[str, Any]]] = LiteralVar.create( { "stroke": Color("accent", 10), "fill": Color("accent", 4), @@ -417,7 +418,7 @@ class Line(Cartesian): ) # The dot is shown when user enter an area chart and this chart has tooltip. If false set, no active dot will not be drawn. If true set, active dot will be drawn which have the props calculated internally. - active_dot: Var[Union[bool, Dict[str, Any]]] = Var.create_safe( + active_dot: Var[Union[bool, Dict[str, Any]]] = LiteralVar.create( { "stroke": Color("accent", 2), "fill": Color("accent", 10), @@ -475,7 +476,7 @@ class Scatter(Recharts): line_type: Var[LiteralLineType] # The fill - fill: Var[Union[str, Color]] = Var.create_safe(Color("accent", 9)) + fill: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 9)) # the name name: Var[Union[str, int]] @@ -552,7 +553,7 @@ class Funnel(Recharts): animation_easing: Var[LiteralAnimationEasing] # stroke color - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 3)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 3)) # Valid children components _valid_children: List[str] = ["LabelList", "Cell"] @@ -605,7 +606,7 @@ class ErrorBar(Recharts): width: Var[int] # The stroke color of error bar. - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 8)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 8)) # The stroke width of error bar. stroke_width: Var[int] @@ -795,7 +796,7 @@ class CartesianGrid(Grid): stroke_dasharray: Var[str] # the stroke color of grid - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 7)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 7)) class CartesianAxis(Grid): diff --git a/reflex/components/recharts/charts.py b/reflex/components/recharts/charts.py index 0f52545d0..37b865b1d 100644 --- a/reflex/components/recharts/charts.py +++ b/reflex/components/recharts/charts.py @@ -9,6 +9,7 @@ from reflex.components.recharts.general import ResponsiveContainer from reflex.constants import EventTriggers from reflex.constants.colors import Color from reflex.event import EventHandler +from reflex.ivars.base import LiteralVar from reflex.vars import Var from .recharts import ( @@ -156,10 +157,10 @@ class BarChart(CategoricalChartBase): alias = "RechartsBarChart" # The gap between two bar categories, which can be a percent value or a fixed value. Percentage | Number - bar_category_gap: Var[Union[str, int]] = Var.create_safe("10%", _var_is_string=True) # type: ignore + bar_category_gap: Var[Union[str, int]] = LiteralVar.create("10%") # The gap between two bars in the same category, which can be a percent value or a fixed value. Percentage | Number - bar_gap: Var[Union[str, int]] = Var.create_safe(4) # type: ignore + bar_gap: Var[Union[str, int]] = LiteralVar.create(4) # type: ignore # The width of all the bars in the chart. Number bar_size: Var[int] diff --git a/reflex/components/recharts/general.py b/reflex/components/recharts/general.py index 613e6fbf0..4f81ea833 100644 --- a/reflex/components/recharts/general.py +++ b/reflex/components/recharts/general.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Union from reflex.components.component import MemoizationLeaf from reflex.constants.colors import Color from reflex.event import EventHandler +from reflex.ivars.base import LiteralVar from reflex.vars import Var from .recharts import ( @@ -139,7 +140,7 @@ class GraphingTooltip(Recharts): filter_null: Var[bool] # If set false, no cursor will be drawn when tooltip is active. - cursor: Var[Union[Dict[str, Any], bool]] = Var.create_safe( + cursor: Var[Union[Dict[str, Any], bool]] = LiteralVar.create( { "strokeWidth": 1, "fill": Color("gray", 3), @@ -150,7 +151,7 @@ class GraphingTooltip(Recharts): view_box: Var[Dict[str, Any]] # The style of default tooltip content item which is a li element. DEFAULT: {} - item_style: Var[Dict[str, Any]] = Var.create_safe( + item_style: Var[Dict[str, Any]] = LiteralVar.create( { "color": Color("gray", 12), } @@ -159,7 +160,7 @@ class GraphingTooltip(Recharts): # The style of tooltip wrapper which is a dom element. DEFAULT: {} wrapper_style: Var[Dict[str, Any]] # The style of tooltip content which is a dom element. DEFAULT: {} - content_style: Var[Dict[str, Any]] = Var.create_safe( + content_style: Var[Dict[str, Any]] = LiteralVar.create( { "background": Color("gray", 1), "borderColor": Color("gray", 4), @@ -168,10 +169,10 @@ class GraphingTooltip(Recharts): ) # The style of default tooltip label which is a p element. DEFAULT: {} - label_style: Var[Dict[str, Any]] = Var.create_safe({"color": Color("gray", 11)}) + label_style: Var[Dict[str, Any]] = LiteralVar.create({"color": Color("gray", 11)}) # This option allows the tooltip to extend beyond the viewBox of the chart itself. DEFAULT: { x: false, y: false } - allow_escape_view_box: Var[Dict[str, bool]] = Var.create_safe( + allow_escape_view_box: Var[Dict[str, bool]] = LiteralVar.create( {"x": False, "y": False} ) @@ -231,10 +232,10 @@ class LabelList(Recharts): offset: Var[int] # The fill color of each label - fill: Var[Union[str, Color]] = Var.create_safe(Color("gray", 10)) + fill: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 10)) # The stroke color of each label - stroke: Var[Union[str, Color]] = Var.create_safe("none", _var_is_string=True) + stroke: Var[Union[str, Color]] = LiteralVar.create("none") responsive_container = ResponsiveContainer.create diff --git a/reflex/components/recharts/polar.py b/reflex/components/recharts/polar.py index 64fa00ecd..76347352b 100644 --- a/reflex/components/recharts/polar.py +++ b/reflex/components/recharts/polar.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Union from reflex.constants import EventTriggers from reflex.constants.colors import Color from reflex.event import EventHandler +from reflex.ivars.base import LiteralVar from reflex.vars import Var from .recharts import ( @@ -72,10 +73,10 @@ class Pie(Recharts): _valid_children: List[str] = ["Cell", "LabelList"] # Stoke color - stroke: Var[Union[str, Color]] = Var.create_safe(Color("accent", 9)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 9)) # Fill color - fill: Var[Union[str, Color]] = Var.create_safe(Color("accent", 3)) + fill: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 3)) def get_event_triggers(self) -> dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. @@ -110,13 +111,13 @@ class Radar(Recharts): dot: Var[bool] # Stoke color - stroke: Var[Union[str, Color]] = Var.create_safe(Color("accent", 9)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("accent", 9)) # Fill color - fill: Var[str] = Var.create_safe(Color("accent", 3)) + fill: Var[str] = LiteralVar.create(Color("accent", 3)) # opacity - fill_opacity: Var[float] = Var.create_safe(0.6) + fill_opacity: Var[float] = LiteralVar.create(0.6) # The type of icon in legend. If set to 'none', no legend item will be rendered. legend_type: Var[str] @@ -218,7 +219,7 @@ class PolarAngleAxis(Recharts): axis_line_type: Var[str] # If false set, tick lines will not be drawn. If true set, tick lines will be drawn which have the props calculated internally. If object set, tick lines will be drawn which have the props mergered by the internal calculated props and the option. - tick_line: Var[Union[bool, Dict[str, Any]]] = Var.create_safe(False) + tick_line: Var[Union[bool, Dict[str, Any]]] = LiteralVar.create(False) # The width or height of tick. tick: Var[Union[int, str]] @@ -230,7 +231,7 @@ class PolarAngleAxis(Recharts): orient: Var[str] # The stroke color of axis - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 10)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 10)) # Allow the axis has duplicated categorys or not when the type of axis is "category". allow_duplicated_category: Var[bool] @@ -292,7 +293,7 @@ class PolarGrid(Recharts): grid_type: Var[LiteralGridType] # The stroke color of grid - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 10)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 10)) # Valid children components _valid_children: List[str] = ["RadarChart", "RadiarBarChart"] @@ -342,10 +343,10 @@ class PolarRadiusAxis(Recharts): _valid_children: List[str] = ["Label"] # The domain of the polar radius axis, specifying the minimum and maximum values. - domain: Var[List[int]] = Var.create_safe([0, 250]) + domain: Var[List[int]] = LiteralVar.create([0, 250]) # The stroke color of axis - stroke: Var[Union[str, Color]] = Var.create_safe(Color("gray", 10)) + stroke: Var[Union[str, Color]] = LiteralVar.create(Color("gray", 10)) def get_event_triggers(self) -> dict[str, Union[Var, Any]]: """Get the event triggers that pass the component's value to the handler. diff --git a/reflex/components/sonner/toast.py b/reflex/components/sonner/toast.py index d4df31e82..97175ec0e 100644 --- a/reflex/components/sonner/toast.py +++ b/reflex/components/sonner/toast.py @@ -12,6 +12,7 @@ from reflex.event import ( EventSpec, call_script, ) +from reflex.ivars.base import LiteralVar from reflex.style import Style, resolved_color_mode from reflex.utils import format from reflex.utils.imports import ImportVar @@ -171,21 +172,19 @@ class Toaster(Component): theme: Var[str] = resolved_color_mode # whether to show rich colors - rich_colors: Var[bool] = Var.create_safe(True) + rich_colors: Var[bool] = LiteralVar.create(True) # whether to expand the toast - expand: Var[bool] = Var.create_safe(True) + expand: Var[bool] = LiteralVar.create(True) # the number of toasts that are currently visible visible_toasts: Var[int] # the position of the toast - position: Var[LiteralPosition] = Var.create_safe( - "bottom-right", _var_is_string=True - ) + position: Var[LiteralPosition] = LiteralVar.create("bottom-right") # whether to show the close button - close_button: Var[bool] = Var.create_safe(False) + close_button: Var[bool] = LiteralVar.create(False) # offset of the toast offset: Var[str] @@ -330,7 +329,7 @@ class Toaster(Component): if isinstance(id, Var): dismiss = f"{toast_ref}.dismiss({id._var_name_unwrapped})" - dismiss_var_data = id._var_data + dismiss_var_data = id._get_all_var_data() elif isinstance(id, str): dismiss = f"{toast_ref}.dismiss('{id}')" else: @@ -339,7 +338,7 @@ class Toaster(Component): dismiss, _var_is_string=False, _var_is_local=True, - _var_data=dismiss_var_data, + _var_data=VarData.merge(dismiss_var_data), ) return call_script(dismiss_action) diff --git a/reflex/event.py b/reflex/event.py index 1c43334a1..e4b600eb9 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -714,15 +714,11 @@ def download( url = "data:text/plain," + urllib.parse.quote(data) elif isinstance(data, Var): # Need to check on the frontend if the Var already looks like a data: URI. - is_data_url = data._replace( - _var_name=( - f"typeof {data._var_full_name} == 'string' && " - f"{data._var_full_name}.startsWith('data:')" - ), - _var_type=bool, - _var_is_string=False, - _var_full_name_needs_state_prefix=False, + + is_data_url = (data._type() == "string") & ( + data.to(str).startswith("data:") ) + # If it's a data: URI, use it as is, otherwise convert the Var to JSON in a data: URI. url = cond( # type: ignore is_data_url, diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index 55dad65ce..5bbc5a3d1 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -2,10 +2,15 @@ from __future__ import annotations +import contextlib import dataclasses +import datetime +import dis import functools import inspect +import json import sys +from types import CodeType, FunctionType from typing import ( TYPE_CHECKING, Any, @@ -20,17 +25,22 @@ from typing import ( Type, TypeVar, Union, + cast, get_args, overload, + override, ) -from typing_extensions import ParamSpec, get_origin +from typing_extensions import ParamSpec, get_origin, get_type_hints from reflex import constants from reflex.base import Base +from reflex.constants.colors import Color from reflex.utils import console, imports, serializers, types -from reflex.utils.exceptions import VarTypeError +from reflex.utils.exceptions import VarDependencyError, VarTypeError, VarValueError +from reflex.utils.format import format_state_name from reflex.vars import ( + ComputedVar, ImmutableVarData, Var, VarData, @@ -320,11 +330,15 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): @overload def to( - self, output: Type[OUTPUT], var_type: types.GenericType | None = None + self, + output: Type[OUTPUT] | types.GenericType, + var_type: types.GenericType | None = None, ) -> OUTPUT: ... def to( - self, output: Type[OUTPUT], var_type: types.GenericType | None = None + self, + output: Type[OUTPUT] | types.GenericType, + var_type: types.GenericType | None = None, ) -> Var: """Convert the var to a different type. @@ -338,12 +352,15 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: The converted var. """ + from .function import FunctionVar, ToFunctionOperation from .number import ( BooleanVar, NumberVar, ToBooleanVarOperation, ToNumberVarOperation, ) + from .object import ObjectVar, ToObjectOperation + from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation fixed_type = ( var_type @@ -351,16 +368,28 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): else get_origin(var_type) ) + fixed_output_type = output if inspect.isclass(output) else get_origin(output) + + if fixed_output_type is dict: + return self.to(ObjectVar, output) + if fixed_output_type in (list, tuple, set): + return self.to(ArrayVar, output) + if fixed_output_type in (int, float): + return self.to(NumberVar, output) + if fixed_output_type is str: + return self.to(StringVar, output) + if fixed_output_type is bool: + return self.to(BooleanVar, output) + if issubclass(output, NumberVar): if fixed_type is not None and not issubclass(fixed_type, (int, float)): raise TypeError( f"Unsupported type {var_type} for NumberVar. Must be int or float." ) - return ToNumberVarOperation(self, var_type or float) - if issubclass(output, BooleanVar): - return ToBooleanVarOperation(self) + return ToNumberVarOperation.create(self, var_type or float) - from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation + if issubclass(output, BooleanVar): + return ToBooleanVarOperation.create(self) if issubclass(output, ArrayVar): if fixed_type is not None and not issubclass( @@ -369,28 +398,30 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): raise TypeError( f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set." ) - return ToArrayOperation(self, var_type or list) + return ToArrayOperation.create(self, var_type or list) + if issubclass(output, StringVar): - return ToStringOperation(self) + return ToStringOperation.create(self) - from .object import ObjectVar, ToObjectOperation - - if issubclass(output, ObjectVar): - return ToObjectOperation(self, var_type or dict) - - from .function import FunctionVar, ToFunctionOperation + if issubclass(output, (ObjectVar, Base)): + return ToObjectOperation.create(self, var_type or dict) if issubclass(output, FunctionVar): # if fixed_type is not None and not issubclass(fixed_type, Callable): # raise TypeError( # f"Unsupported type {var_type} for FunctionVar. Must be Callable." # ) - return ToFunctionOperation(self, var_type or Callable) + return ToFunctionOperation.create(self, var_type or Callable) - return output( - _var_name=self._var_name, - _var_type=self._var_type if var_type is None else var_type, - _var_data=self._var_data, + if not issubclass(output, Var) and var_type is None: + return dataclasses.replace( + self, + _var_type=output, + ) + + return dataclasses.replace( + self, + _var_type=var_type, ) def guess_type(self) -> ImmutableVar: @@ -413,6 +444,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): if fixed_type is Union: return self + if not inspect.isclass(fixed_type): + raise TypeError(f"Unsupported type {var_type} for guess_type.") + if issubclass(fixed_type, (int, float)): return self.to(NumberVar, var_type) if issubclass(fixed_type, dict): @@ -477,12 +511,12 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: The name of the setter function. """ - setter = constants.SETTER_PREFIX + self._var_name + var_name_parts = self._var_name.split(".") + setter = constants.SETTER_PREFIX + var_name_parts[-1] if self._var_data is None: return setter if not include_state or self._var_data.state == "": return setter - print("get_setter_name", self._var_data.state, setter) return ".".join((self._var_data.state, setter)) def get_setter(self) -> Callable[[BaseState, Any], None]: @@ -491,6 +525,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A function that that creates a setter for the var. """ + actual_name = self._var_name.split(".")[-1] def setter(state: BaseState, value: Any): """Get the setter for the var. @@ -502,13 +537,13 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): if self._var_type in [int, float]: try: value = self._var_type(value) - setattr(state, self._var_name, value) + setattr(state, actual_name, value) except ValueError: console.debug( f"{type(state).__name__}.{self._var_name}: Failed conversion of {value} to '{self._var_type.__name__}'. Value not set.", ) else: - setattr(state, self._var_name, value) + setattr(state, actual_name, value) setter.__qualname__ = self.get_setter_name() @@ -525,7 +560,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import EqualOperation - return EqualOperation(self, other) + return EqualOperation.create(self, other) def __ne__(self, other: Var | Any) -> BooleanVar: """Check if the current object is not equal to the given object. @@ -538,7 +573,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import EqualOperation - return ~EqualOperation(self, other) + return ~EqualOperation.create(self, other) def __gt__(self, other: Var | Any) -> BooleanVar: """Compare the current instance with another variable and return a BooleanVar representing the result of the greater than operation. @@ -551,7 +586,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import GreaterThanOperation - return GreaterThanOperation(self, other) + return GreaterThanOperation.create(self, other) def __ge__(self, other: Var | Any) -> BooleanVar: """Check if the value of this variable is greater than or equal to the value of another variable or object. @@ -564,7 +599,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import GreaterThanOrEqualOperation - return GreaterThanOrEqualOperation(self, other) + return GreaterThanOrEqualOperation.create(self, other) def __lt__(self, other: Var | Any) -> BooleanVar: """Compare the current instance with another variable using the less than (<) operator. @@ -577,7 +612,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import LessThanOperation - return LessThanOperation(self, other) + return LessThanOperation.create(self, other) def __le__(self, other: Var | Any) -> BooleanVar: """Compare if the current instance is less than or equal to the given value. @@ -590,7 +625,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import LessThanOrEqualOperation - return LessThanOrEqualOperation(self, other) + return LessThanOrEqualOperation.create(self, other) def bool(self) -> BooleanVar: """Convert the var to a boolean. @@ -600,7 +635,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import ToBooleanVarOperation - return ToBooleanVarOperation(self) + return ToBooleanVarOperation.create(self) def __and__(self, other: Var | Any) -> ImmutableVar: """Perform a logical AND operation on the current instance and another variable. @@ -611,7 +646,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical AND operation. """ - return AndOperation(self, other) + return AndOperation.create(self, other) def __rand__(self, other: Var | Any) -> ImmutableVar: """Perform a logical AND operation on the current instance and another variable. @@ -622,7 +657,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical AND operation. """ - return AndOperation(other, self) + return AndOperation.create(other, self) def __or__(self, other: Var | Any) -> ImmutableVar: """Perform a logical OR operation on the current instance and another variable. @@ -633,7 +668,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical OR operation. """ - return OrOperation(self, other) + return OrOperation.create(self, other) def __ror__(self, other: Var | Any) -> ImmutableVar: """Perform a logical OR operation on the current instance and another variable. @@ -644,7 +679,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical OR operation. """ - return OrOperation(other, self) + return OrOperation.create(other, self) def __invert__(self) -> BooleanVar: """Perform a logical NOT operation on the current instance. @@ -654,7 +689,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): """ from .number import BooleanNotOperation - return BooleanNotOperation(self.bool()) + return BooleanNotOperation.create(self.bool()) def to_string(self) -> ImmutableVar: """Convert the var to a string. @@ -663,8 +698,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): The string var. """ from .function import JSON_STRINGIFY + from .sequence import StringVar - return JSON_STRINGIFY.call(self) + return JSON_STRINGIFY.call(self).to(StringVar) def as_ref(self) -> ImmutableVar: """Get a reference to the var. @@ -732,35 +768,97 @@ class LiteralVar(ImmutableVar): if value is None: return ImmutableVar.create_safe("null", _var_data=_var_data) + from reflex.event import EventChain, EventSpec + from reflex.utils.format import get_event_handler_parts + + from .function import ArgsFunctionOperation, FunctionStringVar from .object import LiteralObjectVar + if isinstance(value, EventSpec): + event_name = LiteralVar.create( + ".".join(get_event_handler_parts(value.handler)) + ) + event_args = LiteralVar.create({name: value for name, value in value.args}) + event_client_name = LiteralVar.create(value.client_handler_name) + return FunctionStringVar("Event").call( + event_name, event_args, event_client_name + ) + + if isinstance(value, EventChain): + sig = inspect.signature(value.args_spec) # type: ignore + if sig.parameters: + arg_def = tuple((f"_{p}" for p in sig.parameters)) + arg_def_expr = LiteralVar.create( + [ImmutableVar.create_safe(arg) for arg in arg_def] + ) + else: + # add a default argument for addEvents if none were specified in value.args_spec + # used to trigger the preventDefault() on the event. + arg_def = ("...args",) + arg_def_expr = ImmutableVar.create_safe("args") + + return ArgsFunctionOperation.create( + arg_def, + FunctionStringVar.create("addEvents").call( + LiteralVar.create( + [LiteralVar.create(event) for event in value.events] + ), + arg_def_expr, + LiteralVar.create(value.event_actions), + ), + ) + + from plotly.graph_objects import Figure, layout + from plotly.io import to_json + + if isinstance(value, Figure): + return LiteralObjectVar.create( + json.loads(to_json(value)), _var_type=Figure, _var_data=_var_data + ) + + if isinstance(value, layout.Template): + return LiteralObjectVar.create( + { + "data": json.loads(to_json(value.data)), + "layout": json.loads(to_json(value.layout)), + }, + _var_type=layout.Template, + _var_data=_var_data, + ) + if isinstance(value, Base): - return LiteralObjectVar( + return LiteralObjectVar.create( value.dict(), _var_type=type(value), _var_data=_var_data ) if isinstance(value, dict): - return LiteralObjectVar(value, _var_data=_var_data) + return LiteralObjectVar.create(value, _var_data=_var_data) - from .number import LiteralBooleanVar, LiteralNumberVar from .sequence import LiteralArrayVar, LiteralStringVar if isinstance(value, str): return LiteralStringVar.create(value, _var_data=_var_data) + if isinstance(value, Color): + return LiteralStringVar.create(f"{value}", _var_data=_var_data) + + from .number import LiteralBooleanVar, LiteralNumberVar + type_mapping = { - int: LiteralNumberVar, - float: LiteralNumberVar, - bool: LiteralBooleanVar, - list: LiteralArrayVar, - tuple: LiteralArrayVar, - set: LiteralArrayVar, + int: LiteralNumberVar.create, + float: LiteralNumberVar.create, + bool: LiteralBooleanVar.create, + list: LiteralArrayVar.create, + tuple: LiteralArrayVar.create, + set: LiteralArrayVar.create, } constructor = type_mapping.get(type(value)) if constructor is None: - raise TypeError(f"Unsupported type {type(value)} for LiteralVar.") + raise TypeError( + f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}." + ) return constructor(value, _var_data=_var_data) @@ -881,27 +979,8 @@ class AndOperation(ImmutableVar): # The second var. _var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - def __init__( - self, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None - ): - """Initialize the AndOperation. - - Args: - var1: The first var. - var2: The second var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(AndOperation, self).__init__( - _var_name="", - _var_type=Union[var1._var_type, var2._var_type], - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "_var1", var1 if isinstance(var1, Var) else LiteralVar.create(var1) - ) - object.__setattr__( - self, "_var2", var2 if isinstance(var2, Var) else LiteralVar.create(var2) - ) + def __post_init__(self): + """Post-initialize the AndOperation.""" object.__delattr__(self, "_var_name") @functools.cached_property @@ -955,6 +1034,29 @@ class AndOperation(ImmutableVar): """ return hash((self.__class__.__name__, self._var1, self._var2)) + @classmethod + def create( + cls, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None + ) -> AndOperation: + """Create an AndOperation. + + Args: + var1: The first var. + var2: The second var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The AndOperation. + """ + var1, var2 = map(LiteralVar.create, (var1, var2)) + return AndOperation( + _var_name="", + _var_type=unionize(var1._var_type, var2._var_type), + _var_data=ImmutableVarData.merge(_var_data), + _var1=var1, + _var2=var2, + ) + @dataclasses.dataclass( eq=False, @@ -970,27 +1072,8 @@ class OrOperation(ImmutableVar): # The second var. _var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - def __init__( - self, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None - ): - """Initialize the OrOperation. - - Args: - var1: The first var. - var2: The second var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(OrOperation, self).__init__( - _var_name="", - _var_type=Union[var1._var_type, var2._var_type], - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "_var1", var1 if isinstance(var1, Var) else LiteralVar.create(var1) - ) - object.__setattr__( - self, "_var2", var2 if isinstance(var2, Var) else LiteralVar.create(var2) - ) + def __post_init__(self): + """Post-initialize the OrOperation.""" object.__delattr__(self, "_var_name") @functools.cached_property @@ -1044,6 +1127,29 @@ class OrOperation(ImmutableVar): """ return hash((self.__class__.__name__, self._var1, self._var2)) + @classmethod + def create( + cls, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None + ) -> OrOperation: + """Create an OrOperation. + + Args: + var1: The first var. + var2: The second var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The OrOperation. + """ + var1, var2 = map(LiteralVar.create, (var1, var2)) + return OrOperation( + _var_name="", + _var_type=unionize(var1._var_type, var2._var_type), + _var_data=ImmutableVarData.merge(_var_data), + _var1=var1, + _var2=var2, + ) + @dataclasses.dataclass( eq=False, @@ -1057,14 +1163,14 @@ class ImmutableCallableVar(ImmutableVar): API with functions that return a family of Var. """ - fn: Callable[..., ImmutableVar] = dataclasses.field( - default_factory=lambda: lambda: LiteralVar.create(None) + fn: Callable[..., Var] = dataclasses.field( + default_factory=lambda: lambda: ImmutableVar(_var_name="undefined") ) - original_var: ImmutableVar = dataclasses.field( - default_factory=lambda: LiteralVar.create(None) + original_var: Var = dataclasses.field( + default_factory=lambda: ImmutableVar(_var_name="undefined") ) - def __init__(self, fn: Callable[..., ImmutableVar]): + def __init__(self, fn: Callable[..., Var]): """Initialize a CallableVar. Args: @@ -1074,12 +1180,12 @@ class ImmutableCallableVar(ImmutableVar): super(ImmutableCallableVar, self).__init__( _var_name=original_var._var_name, _var_type=original_var._var_type, - _var_data=original_var._var_data, + _var_data=ImmutableVarData.merge(original_var._var_data), ) object.__setattr__(self, "fn", fn) object.__setattr__(self, "original_var", original_var) - def __call__(self, *args, **kwargs) -> ImmutableVar: + def __call__(self, *args, **kwargs) -> Var: """Call the decorated function. Args: @@ -1098,3 +1204,433 @@ class ImmutableCallableVar(ImmutableVar): The hash of the object. """ return hash((self.__class__.__name__, self.original_var)) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ImmutableComputedVar(ImmutableVar): + """A field with computed getters.""" + + # Whether to track dependencies and cache computed values + _cache: bool = dataclasses.field(default=False) + + # Whether the computed var is a backend var + _backend: bool = dataclasses.field(default=False) + + # The initial value of the computed var + _initial_value: Any | types.Unset = dataclasses.field(default=types.Unset()) + + # Explicit var dependencies to track + _static_deps: set[str] = dataclasses.field(default_factory=set) + + # Whether var dependencies should be auto-determined + _auto_deps: bool = dataclasses.field(default=True) + + # Interval at which the computed var should be updated + _update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None) + + _fget: Callable[[BaseState], Any] = dataclasses.field( + default_factory=lambda: lambda _: None + ) + + def __init__( + self, + fget: Callable[[BaseState], Any], + initial_value: Any | types.Unset = types.Unset(), + cache: bool = False, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[int, datetime.timedelta]] = None, + backend: bool | None = None, + **kwargs, + ): + """Initialize a ComputedVar. + + Args: + fget: The getter function. + initial_value: The initial value of the computed var. + cache: Whether to cache the computed value. + deps: Explicit var dependencies to track. + auto_deps: Whether var dependencies should be auto-determined. + interval: Interval at which the computed var should be updated. + backend: Whether the computed var is a backend var. + **kwargs: additional attributes to set on the instance + + Raises: + TypeError: If the computed var dependencies are not Var instances or var names. + """ + hints = get_type_hints(fget) + hint = hints.get("return", Any) + + kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__) + kwargs["_var_type"] = kwargs.pop("_var_type", hint) + + super(ImmutableComputedVar, self).__init__( + _var_name=kwargs.pop("_var_name"), + _var_type=kwargs.pop("_var_type"), + _var_data=ImmutableVarData.merge(kwargs.pop("_var_data", None)), + ) + + if backend is None: + backend = fget.__name__.startswith("_") + + object.__setattr__(self, "_backend", backend) + object.__setattr__(self, "_initial_value", initial_value) + object.__setattr__(self, "_cache", cache) + + if isinstance(interval, int): + interval = datetime.timedelta(seconds=interval) + + object.__setattr__(self, "_update_interval", interval) + + if deps is None: + deps = [] + else: + for dep in deps: + if isinstance(dep, Var): + continue + if isinstance(dep, str) and dep != "": + continue + raise TypeError( + "ComputedVar dependencies must be Var instances or var names (non-empty strings)." + ) + object.__setattr__( + self, + "_static_deps", + {dep._var_name if isinstance(dep, Var) else dep for dep in deps}, + ) + object.__setattr__(self, "_auto_deps", auto_deps) + + object.__setattr__(self, "_fget", fget) + + @override + def _replace(self, merge_var_data=None, **kwargs: Any) -> ImmutableComputedVar: + """Replace the attributes of the ComputedVar. + + Args: + merge_var_data: VarData to merge into the existing VarData. + **kwargs: Var fields to update. + + Returns: + The new ComputedVar instance. + + Raises: + TypeError: If kwargs contains keys that are not allowed. + """ + field_values = dict( + fget=kwargs.pop("fget", self._fget), + initial_value=kwargs.pop("initial_value", self._initial_value), + cache=kwargs.pop("cache", self._cache), + deps=kwargs.pop("deps", self._static_deps), + auto_deps=kwargs.pop("auto_deps", self._auto_deps), + interval=kwargs.pop("interval", self._update_interval), + backend=kwargs.pop("backend", self._backend), + _var_name=kwargs.pop("_var_name", self._var_name), + _var_type=kwargs.pop("_var_type", self._var_type), + _var_is_local=kwargs.pop("_var_is_local", self._var_is_local), + _var_is_string=kwargs.pop("_var_is_string", self._var_is_string), + _var_full_name_needs_state_prefix=kwargs.pop( + "_var_full_name_needs_state_prefix", + self._var_full_name_needs_state_prefix, + ), + _var_data=kwargs.pop( + "_var_data", VarData.merge(self._var_data, merge_var_data) + ), + ) + + if kwargs: + unexpected_kwargs = ", ".join(kwargs.keys()) + raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}") + + return ImmutableComputedVar(**field_values) + + @property + def _cache_attr(self) -> str: + """Get the attribute used to cache the value on the instance. + + Returns: + An attribute name. + """ + return f"__cached_{self._var_name}" + + @property + def _last_updated_attr(self) -> str: + """Get the attribute used to store the last updated timestamp. + + Returns: + An attribute name. + """ + return f"__last_updated_{self._var_name}" + + def needs_update(self, instance: BaseState) -> bool: + """Check if the computed var needs to be updated. + + Args: + instance: The state instance that the computed var is attached to. + + Returns: + True if the computed var needs to be updated, False otherwise. + """ + if self._update_interval is None: + return False + last_updated = getattr(instance, self._last_updated_attr, None) + if last_updated is None: + return True + return datetime.datetime.now() - last_updated > self._update_interval + + def __get__(self, instance: BaseState | None, owner): + """Get the ComputedVar value. + + If the value is already cached on the instance, return the cached value. + + Args: + instance: the instance of the class accessing this computed var. + owner: the class that this descriptor is attached to. + + Returns: + The value of the var for the given instance. + """ + if instance is None: + return self._replace( + _var_name=format_state_name(owner.get_full_name()) + + "." + + self._var_name, + merge_var_data=ImmutableVarData.from_state(owner), + ).guess_type() + + if not self._cache: + return self.fget(instance) + + # handle caching + if not hasattr(instance, self._cache_attr) or self.needs_update(instance): + # Set cache attr on state instance. + setattr(instance, self._cache_attr, self.fget(instance)) + # Ensure the computed var gets serialized to redis. + instance._was_touched = True + # Set the last updated timestamp on the state instance. + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + return getattr(instance, self._cache_attr) + + def _deps( + self, + objclass: Type, + obj: FunctionType | CodeType | None = None, + self_name: Optional[str] = None, + ) -> set[str]: + """Determine var dependencies of this ComputedVar. + + Save references to attributes accessed on "self". Recursively called + when the function makes a method call on "self" or define comprehensions + or nested functions that may reference "self". + + Args: + objclass: the class obj this ComputedVar is attached to. + obj: the object to disassemble (defaults to the fget function). + self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions. + + Returns: + A set of variable names accessed by the given obj. + + Raises: + VarValueError: if the function references the get_state, parent_state, or substates attributes + (cannot track deps in a related state, only implicitly via parent state). + """ + if not self._auto_deps: + return self._static_deps + d = self._static_deps.copy() + if obj is None: + fget = self._fget + if fget is not None: + obj = cast(FunctionType, fget) + else: + return set() + with contextlib.suppress(AttributeError): + # unbox functools.partial + obj = cast(FunctionType, obj.func) # type: ignore + with contextlib.suppress(AttributeError): + # unbox EventHandler + obj = cast(FunctionType, obj.fn) # type: ignore + + if self_name is None and isinstance(obj, FunctionType): + try: + # the first argument to the function is the name of "self" arg + self_name = obj.__code__.co_varnames[0] + except (AttributeError, IndexError): + self_name = None + if self_name is None: + # cannot reference attributes on self if method takes no args + return set() + + invalid_names = ["get_state", "parent_state", "substates", "get_substate"] + self_is_top_of_stack = False + for instruction in dis.get_instructions(obj): + if ( + instruction.opname in ("LOAD_FAST", "LOAD_DEREF") + and instruction.argval == self_name + ): + # bytecode loaded the class instance to the top of stack, next load instruction + # is referencing an attribute on self + self_is_top_of_stack = True + continue + if self_is_top_of_stack and instruction.opname in ( + "LOAD_ATTR", + "LOAD_METHOD", + ): + try: + ref_obj = getattr(objclass, instruction.argval) + except Exception: + ref_obj = None + if instruction.argval in invalid_names: + raise VarValueError( + f"Cached var {self._var_full_name} cannot access arbitrary state via `{instruction.argval}`." + ) + if callable(ref_obj): + # recurse into callable attributes + d.update( + self._deps( + objclass=objclass, + obj=ref_obj, + ) + ) + # recurse into property fget functions + elif isinstance(ref_obj, property) and not isinstance( + ref_obj, ImmutableComputedVar + ): + d.update( + self._deps( + objclass=objclass, + obj=ref_obj.fget, # type: ignore + ) + ) + elif ( + instruction.argval in objclass.backend_vars + or instruction.argval in objclass.vars + ): + # var access + d.add(instruction.argval) + elif instruction.opname == "LOAD_CONST" and isinstance( + instruction.argval, CodeType + ): + # recurse into nested functions / comprehensions, which can reference + # instance attributes from the outer scope + d.update( + self._deps( + objclass=objclass, + obj=instruction.argval, + self_name=self_name, + ) + ) + self_is_top_of_stack = False + return d + + def mark_dirty(self, instance) -> None: + """Mark this ComputedVar as dirty. + + Args: + instance: the state instance that needs to recompute the value. + """ + with contextlib.suppress(AttributeError): + delattr(instance, self._cache_attr) + + def _determine_var_type(self) -> Type: + """Get the type of the var. + + Returns: + The type of the var. + """ + hints = get_type_hints(self._fget) + if "return" in hints: + return hints["return"] + return Any + + @property + def __class__(self) -> Type: + """Get the class of the var. + + Returns: + The class of the var. + """ + return ComputedVar + + @property + def fget(self) -> Callable[[BaseState], Any]: + """Get the getter function. + + Returns: + The getter function. + """ + return self._fget + + +def immutable_computed_var( + fget: Callable[[BaseState], Any] | None = None, + initial_value: Any | types.Unset = types.Unset(), + cache: bool = False, + deps: Optional[List[Union[str, Var]]] = None, + auto_deps: bool = True, + interval: Optional[Union[datetime.timedelta, int]] = None, + backend: bool | None = None, + _deprecated_cached_var: bool = False, + **kwargs, +) -> ( + ImmutableComputedVar | Callable[[Callable[[BaseState], Any]], ImmutableComputedVar] +): + """A ComputedVar decorator with or without kwargs. + + Args: + fget: The getter function. + initial_value: The initial value of the computed var. + cache: Whether to cache the computed value. + deps: Explicit var dependencies to track. + auto_deps: Whether var dependencies should be auto-determined. + interval: Interval at which the computed var should be updated. + backend: Whether the computed var is a backend var. + _deprecated_cached_var: Indicate usage of deprecated cached_var partial function. + **kwargs: additional attributes to set on the instance + + Returns: + A ComputedVar instance. + + Raises: + ValueError: If caching is disabled and an update interval is set. + VarDependencyError: If user supplies dependencies without caching. + """ + if _deprecated_cached_var: + console.deprecate( + feature_name="cached_var", + reason=("Use @rx.var(cache=True) instead of @rx.cached_var."), + deprecation_version="0.5.6", + removal_version="0.6.0", + ) + + if cache is False and interval is not None: + raise ValueError("Cannot set update interval without caching.") + + if cache is False and (deps is not None or auto_deps is False): + raise VarDependencyError("Cannot track dependencies without caching.") + + if fget is not None: + return ImmutableComputedVar(fget, cache=cache) + + def wrapper(fget: Callable[[BaseState], Any]) -> ImmutableComputedVar: + return ImmutableComputedVar( + fget, + initial_value=initial_value, + cache=cache, + deps=deps, + auto_deps=auto_deps, + interval=interval, + backend=backend, + **kwargs, + ) + + return wrapper + + +# Partial function of computed_var with cache=True +cached_var = functools.partial( + immutable_computed_var, cache=True, _deprecated_cached_var=True +) diff --git a/reflex/ivars/function.py b/reflex/ivars/function.py index 9033b886e..b16902b7c 100644 --- a/reflex/ivars/function.py +++ b/reflex/ivars/function.py @@ -7,6 +7,7 @@ import sys from functools import cached_property from typing import Any, Callable, Optional, Tuple, Type, Union +from reflex.utils.types import GenericType from reflex.vars import ImmutableVarData, Var, VarData from .base import ImmutableVar, LiteralVar @@ -24,9 +25,9 @@ class FunctionVar(ImmutableVar[Callable]): Returns: The function call operation. """ - return ArgsFunctionOperation( + return ArgsFunctionOperation.create( ("...args",), - VarOperationCall(self, *args, ImmutableVar.create_safe("...args")), + VarOperationCall.create(self, *args, ImmutableVar.create_safe("...args")), ) def call(self, *args: Var | Any) -> VarOperationCall: @@ -38,22 +39,31 @@ class FunctionVar(ImmutableVar[Callable]): Returns: The function call operation. """ - return VarOperationCall(self, *args) + return VarOperationCall.create(self, *args) class FunctionStringVar(FunctionVar): """Base class for immutable function vars from a string.""" - def __init__(self, func: str, _var_data: VarData | None = None) -> None: - """Initialize the function var. + @classmethod + def create( + cls, + func: str, + _var_type: Type[Callable] = Callable, + _var_data: VarData | None = None, + ) -> FunctionStringVar: + """Create a new function var from a string. Args: func: The function to call. _var_data: Additional hooks and imports associated with the Var. + + Returns: + The function var. """ - super(FunctionVar, self).__init__( + return cls( _var_name=func, - _var_type=Callable, + _var_type=_var_type, _var_data=ImmutableVarData.merge(_var_data), ) @@ -69,25 +79,6 @@ class VarOperationCall(ImmutableVar): _func: Optional[FunctionVar] = dataclasses.field(default=None) _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) - def __init__( - self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None - ): - """Initialize the function call var. - - Args: - func: The function to call. - *args: The arguments to call the function with. - _var_data: Additional hooks and imports associated with the Var. - """ - super(VarOperationCall, self).__init__( - _var_name="", - _var_type=Any, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_func", func) - object.__setattr__(self, "_args", args) - object.__delattr__(self, "_var_name") - def __getattr__(self, name): """Get an attribute of the var. @@ -133,7 +124,7 @@ class VarOperationCall(ImmutableVar): def __post_init__(self): """Post-initialize the var.""" - pass + object.__delattr__(self, "_var_name") def __hash__(self): """Hash the var. @@ -143,6 +134,32 @@ class VarOperationCall(ImmutableVar): """ return hash((self.__class__.__name__, self._func, self._args)) + @classmethod + def create( + cls, + func: FunctionVar, + *args: Var | Any, + _var_type: GenericType = Any, + _var_data: VarData | None = None, + ) -> VarOperationCall: + """Create a new function call var. + + Args: + func: The function to call. + *args: The arguments to call the function with. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The function call var. + """ + return cls( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + _func=func, + _args=args, + ) + @dataclasses.dataclass( eq=False, @@ -155,28 +172,6 @@ class ArgsFunctionOperation(FunctionVar): _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) _return_expr: Union[Var, Any] = dataclasses.field(default=None) - def __init__( - self, - args_names: Tuple[str, ...], - return_expr: Var | Any, - _var_data: VarData | None = None, - ) -> None: - """Initialize the function with arguments var. - - Args: - args_names: The names of the arguments. - return_expr: The return expression of the function. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArgsFunctionOperation, self).__init__( - _var_name=f"", - _var_type=Callable, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_args_names", args_names) - object.__setattr__(self, "_return_expr", return_expr) - object.__delattr__(self, "_var_name") - def __getattr__(self, name): """Get an attribute of the var. @@ -221,6 +216,7 @@ class ArgsFunctionOperation(FunctionVar): def __post_init__(self): """Post-initialize the var.""" + object.__delattr__(self, "_var_name") def __hash__(self): """Hash the var. @@ -230,6 +226,32 @@ class ArgsFunctionOperation(FunctionVar): """ return hash((self.__class__.__name__, self._args_names, self._return_expr)) + @classmethod + def create( + cls, + args_names: Tuple[str, ...], + return_expr: Var | Any, + _var_type: GenericType = Callable, + _var_data: VarData | None = None, + ) -> ArgsFunctionOperation: + """Create a new function var. + + Args: + args_names: The names of the arguments. + return_expr: The return expression of the function. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The function var. + """ + return cls( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + _args_names=args_names, + _return_expr=return_expr, + ) + @dataclasses.dataclass( eq=False, @@ -243,25 +265,8 @@ class ToFunctionOperation(FunctionVar): default_factory=lambda: LiteralVar.create(None) ) - def __init__( - self, - original_var: Var, - _var_type: Type[Callable] = Callable, - _var_data: VarData | None = None, - ) -> None: - """Initialize the function with arguments var. - - Args: - original_var: The original var to convert to a function. - _var_type: The type of the function. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ToFunctionOperation, self).__init__( - _var_name=f"", - _var_type=_var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_original_var", original_var) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") def __getattr__(self, name): @@ -314,5 +319,29 @@ class ToFunctionOperation(FunctionVar): """ return hash((self.__class__.__name__, self._original_var)) + @classmethod + def create( + cls, + original_var: Var, + _var_type: GenericType = Callable, + _var_data: VarData | None = None, + ) -> ToFunctionOperation: + """Create a new function var. -JSON_STRINGIFY = FunctionStringVar("JSON.stringify") + Args: + original_var: The original var to convert to a function. + _var_type: The type of the function. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The function var. + """ + return cls( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + _original_var=original_var, + ) + + +JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify") diff --git a/reflex/ivars/number.py b/reflex/ivars/number.py index ff33c1779..8f764cfce 100644 --- a/reflex/ivars/number.py +++ b/reflex/ivars/number.py @@ -8,12 +8,12 @@ import sys from functools import cached_property from typing import Any, Union -from reflex.utils.types import GenericType from reflex.vars import ImmutableVarData, Var, VarData from .base import ( ImmutableVar, LiteralVar, + unionize, ) @@ -29,7 +29,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number addition operation. """ - return NumberAddOperation(self, +other) + return NumberAddOperation.create(self, +other) def __radd__(self, other: number_types | boolean_types) -> NumberAddOperation: """Add two numbers. @@ -40,7 +40,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number addition operation. """ - return NumberAddOperation(+other, self) + return NumberAddOperation.create(+other, self) def __sub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: """Subtract two numbers. @@ -51,7 +51,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number subtraction operation. """ - return NumberSubtractOperation(self, +other) + return NumberSubtractOperation.create(self, +other) def __rsub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: """Subtract two numbers. @@ -62,7 +62,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number subtraction operation. """ - return NumberSubtractOperation(+other, self) + return NumberSubtractOperation.create(+other, self) def __abs__(self) -> NumberAbsoluteOperation: """Get the absolute value of the number. @@ -70,7 +70,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number absolute operation. """ - return NumberAbsoluteOperation(self) + return NumberAbsoluteOperation.create(self) def __mul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: """Multiply two numbers. @@ -81,7 +81,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number multiplication operation. """ - return NumberMultiplyOperation(self, +other) + return NumberMultiplyOperation.create(self, +other) def __rmul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: """Multiply two numbers. @@ -92,7 +92,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number multiplication operation. """ - return NumberMultiplyOperation(+other, self) + return NumberMultiplyOperation.create(+other, self) def __truediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: """Divide two numbers. @@ -103,7 +103,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number true division operation. """ - return NumberTrueDivision(self, +other) + return NumberTrueDivision.create(self, +other) def __rtruediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: """Divide two numbers. @@ -114,7 +114,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number true division operation. """ - return NumberTrueDivision(+other, self) + return NumberTrueDivision.create(+other, self) def __floordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: """Floor divide two numbers. @@ -125,7 +125,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number floor division operation. """ - return NumberFloorDivision(self, +other) + return NumberFloorDivision.create(self, +other) def __rfloordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: """Floor divide two numbers. @@ -136,7 +136,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number floor division operation. """ - return NumberFloorDivision(+other, self) + return NumberFloorDivision.create(+other, self) def __mod__(self, other: number_types | boolean_types) -> NumberModuloOperation: """Modulo two numbers. @@ -147,7 +147,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number modulo operation. """ - return NumberModuloOperation(self, +other) + return NumberModuloOperation.create(self, +other) def __rmod__(self, other: number_types | boolean_types) -> NumberModuloOperation: """Modulo two numbers. @@ -158,7 +158,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number modulo operation. """ - return NumberModuloOperation(+other, self) + return NumberModuloOperation.create(+other, self) def __pow__(self, other: number_types | boolean_types) -> NumberExponentOperation: """Exponentiate two numbers. @@ -169,7 +169,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number exponent operation. """ - return NumberExponentOperation(self, +other) + return NumberExponentOperation.create(self, +other) def __rpow__(self, other: number_types | boolean_types) -> NumberExponentOperation: """Exponentiate two numbers. @@ -180,7 +180,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number exponent operation. """ - return NumberExponentOperation(+other, self) + return NumberExponentOperation.create(+other, self) def __neg__(self) -> NumberNegateOperation: """Negate the number. @@ -188,7 +188,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number negation operation. """ - return NumberNegateOperation(self) + return NumberNegateOperation.create(self) def __invert__(self) -> BooleanNotOperation: """Boolean NOT the number. @@ -196,7 +196,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The boolean NOT operation. """ - return BooleanNotOperation(self.bool()) + return BooleanNotOperation.create(self.bool()) def __pos__(self) -> NumberVar: """Positive the number. @@ -212,7 +212,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number round operation. """ - return NumberRoundOperation(self) + return NumberRoundOperation.create(self) def __ceil__(self) -> NumberCeilOperation: """Ceil the number. @@ -220,7 +220,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number ceil operation. """ - return NumberCeilOperation(self) + return NumberCeilOperation.create(self) def __floor__(self) -> NumberFloorOperation: """Floor the number. @@ -228,7 +228,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number floor operation. """ - return NumberFloorOperation(self) + return NumberFloorOperation.create(self) def __trunc__(self) -> NumberTruncOperation: """Trunc the number. @@ -236,7 +236,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number trunc operation. """ - return NumberTruncOperation(self) + return NumberTruncOperation.create(self) def __lt__(self, other: Any) -> LessThanOperation: """Less than comparison. @@ -248,8 +248,8 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return LessThanOperation(self, +other) - return LessThanOperation(self, other) + return LessThanOperation.create(self, +other) + return LessThanOperation.create(self, other) def __le__(self, other: Any) -> LessThanOrEqualOperation: """Less than or equal comparison. @@ -261,8 +261,8 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return LessThanOrEqualOperation(self, +other) - return LessThanOrEqualOperation(self, other) + return LessThanOrEqualOperation.create(self, +other) + return LessThanOrEqualOperation.create(self, other) def __eq__(self, other: Any) -> EqualOperation: """Equal comparison. @@ -274,8 +274,8 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return EqualOperation(self, +other) - return EqualOperation(self, other) + return EqualOperation.create(self, +other) + return EqualOperation.create(self, other) def __ne__(self, other: Any) -> NotEqualOperation: """Not equal comparison. @@ -287,8 +287,8 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return NotEqualOperation(self, +other) - return NotEqualOperation(self, other) + return NotEqualOperation.create(self, +other) + return NotEqualOperation.create(self, other) def __gt__(self, other: Any) -> GreaterThanOperation: """Greater than comparison. @@ -300,8 +300,8 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return GreaterThanOperation(self, +other) - return GreaterThanOperation(self, other) + return GreaterThanOperation.create(self, +other) + return GreaterThanOperation.create(self, other) def __ge__(self, other: Any) -> GreaterThanOrEqualOperation: """Greater than or equal comparison. @@ -313,8 +313,8 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return GreaterThanOrEqualOperation(self, +other) - return GreaterThanOrEqualOperation(self, other) + return GreaterThanOrEqualOperation.create(self, +other) + return GreaterThanOrEqualOperation.create(self, other) def bool(self) -> NotEqualOperation: """Boolean conversion. @@ -322,7 +322,7 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The boolean value of the number. """ - return NotEqualOperation(self, 0) + return NotEqualOperation.create(self, 0) @dataclasses.dataclass( @@ -333,29 +333,15 @@ class NumberVar(ImmutableVar[Union[int, float]]): class BinaryNumberOperation(NumberVar): """Base class for immutable number vars that are the result of a binary operation.""" - a: number_types = dataclasses.field(default=0) - b: number_types = dataclasses.field(default=0) + _lhs: NumberVar = dataclasses.field( + default_factory=lambda: LiteralNumberVar.create(0) + ) + _rhs: NumberVar = dataclasses.field( + default_factory=lambda: LiteralNumberVar.create(0) + ) - def __init__( - self, - a: number_types, - b: number_types, - _var_data: VarData | None = None, - ): - """Initialize the binary number operation var. - - Args: - a: The first number. - b: The second number. - _var_data: Additional hooks and imports associated with the Var. - """ - super(BinaryNumberOperation, self).__init__( - _var_name="", - _var_type=float, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) - object.__setattr__(self, "b", b) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -389,8 +375,12 @@ class BinaryNumberOperation(NumberVar): Returns: The VarData of the components and all of its children. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) + first_value = ( + self._lhs if isinstance(self._lhs, Var) else LiteralNumberVar(self._lhs) + ) + second_value = ( + self._rhs if isinstance(self._rhs, Var) else LiteralNumberVar(self._rhs) + ) return ImmutableVarData.merge( first_value._get_all_var_data(), second_value._get_all_var_data(), @@ -406,7 +396,33 @@ class BinaryNumberOperation(NumberVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._lhs, self._rhs)) + + @classmethod + def create( + cls, lhs: number_types, rhs: number_types, _var_data: VarData | None = None + ): + """Create the binary number operation var. + + Args: + lhs: The first number. + rhs: The second number. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The binary number operation var. + """ + _lhs, _rhs = map( + lambda v: LiteralNumberVar.create(v) if not isinstance(v, NumberVar) else v, + (lhs, rhs), + ) + return cls( + _var_name="", + _var_type=unionize(_lhs._var_type, _rhs._var_type), + _var_data=ImmutableVarData.merge(_var_data), + _lhs=_lhs, + _rhs=_rhs, + ) @dataclasses.dataclass( @@ -417,25 +433,12 @@ class BinaryNumberOperation(NumberVar): class UnaryNumberOperation(NumberVar): """Base class for immutable number vars that are the result of a unary operation.""" - a: number_types = dataclasses.field(default=0) + _value: NumberVar = dataclasses.field( + default_factory=lambda: LiteralNumberVar.create(0) + ) - def __init__( - self, - a: number_types, - _var_data: VarData | None = None, - ): - """Initialize the unary number operation var. - - Args: - a: The number. - _var_data: Additional hooks and imports associated with the Var. - """ - super(UnaryNumberOperation, self).__init__( - _var_name="", - _var_type=float, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -469,7 +472,11 @@ class UnaryNumberOperation(NumberVar): Returns: The VarData of the components and all of its children. """ - value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) + value = ( + self._value + if isinstance(self._value, Var) + else LiteralNumberVar(self._value) + ) return ImmutableVarData.merge(value._get_all_var_data(), self._var_data) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -481,7 +488,25 @@ class UnaryNumberOperation(NumberVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a)) + return hash((self.__class__.__name__, self._value)) + + @classmethod + def create(cls, value: NumberVar, _var_data: VarData | None = None): + """Create the unary number operation var. + + Args: + value: The number. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The unary number operation var. + """ + return cls( + _var_name="", + _var_type=value._var_type, + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) class NumberAddOperation(BinaryNumberOperation): @@ -494,9 +519,7 @@ class NumberAddOperation(BinaryNumberOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) - return f"({str(first_value)} + {str(second_value)})" + return f"({str(self._lhs)} + {str(self._rhs)})" class NumberSubtractOperation(BinaryNumberOperation): @@ -509,9 +532,7 @@ class NumberSubtractOperation(BinaryNumberOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) - return f"({str(first_value)} - {str(second_value)})" + return f"({str(self._lhs)} - {str(self._rhs)})" class NumberAbsoluteOperation(UnaryNumberOperation): @@ -524,8 +545,7 @@ class NumberAbsoluteOperation(UnaryNumberOperation): Returns: The name of the var. """ - value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - return f"Math.abs({str(value)})" + return f"Math.abs({str(self._value)})" class NumberMultiplyOperation(BinaryNumberOperation): @@ -538,9 +558,7 @@ class NumberMultiplyOperation(BinaryNumberOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) - return f"({str(first_value)} * {str(second_value)})" + return f"({str(self._lhs)} * {str(self._rhs)})" class NumberNegateOperation(UnaryNumberOperation): @@ -553,8 +571,7 @@ class NumberNegateOperation(UnaryNumberOperation): Returns: The name of the var. """ - value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - return f"-({str(value)})" + return f"-({str(self._value)})" class NumberTrueDivision(BinaryNumberOperation): @@ -567,9 +584,7 @@ class NumberTrueDivision(BinaryNumberOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) - return f"({str(first_value)} / {str(second_value)})" + return f"({str(self._lhs)} / {str(self._rhs)})" class NumberFloorDivision(BinaryNumberOperation): @@ -582,9 +597,7 @@ class NumberFloorDivision(BinaryNumberOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) - return f"Math.floor({str(first_value)} / {str(second_value)})" + return f"Math.floor({str(self._lhs)} / {str(self._rhs)})" class NumberModuloOperation(BinaryNumberOperation): @@ -597,9 +610,7 @@ class NumberModuloOperation(BinaryNumberOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) - return f"({str(first_value)} % {str(second_value)})" + return f"({str(self._lhs)} % {str(self._rhs)})" class NumberExponentOperation(BinaryNumberOperation): @@ -612,9 +623,7 @@ class NumberExponentOperation(BinaryNumberOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralNumberVar(self.b) - return f"({str(first_value)} ** {str(second_value)})" + return f"({str(self._lhs)} ** {str(self._rhs)})" class NumberRoundOperation(UnaryNumberOperation): @@ -627,8 +636,7 @@ class NumberRoundOperation(UnaryNumberOperation): Returns: The name of the var. """ - value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - return f"Math.round({str(value)})" + return f"Math.round({str(self._value)})" class NumberCeilOperation(UnaryNumberOperation): @@ -641,8 +649,7 @@ class NumberCeilOperation(UnaryNumberOperation): Returns: The name of the var. """ - value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - return f"Math.ceil({str(value)})" + return f"Math.ceil({str(self._value)})" class NumberFloorOperation(UnaryNumberOperation): @@ -655,8 +662,7 @@ class NumberFloorOperation(UnaryNumberOperation): Returns: The name of the var. """ - value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - return f"Math.floor({str(value)})" + return f"Math.floor({str(self._value)})" class NumberTruncOperation(UnaryNumberOperation): @@ -669,8 +675,7 @@ class NumberTruncOperation(UnaryNumberOperation): Returns: The name of the var. """ - value = self.a if isinstance(self.a, Var) else LiteralNumberVar(self.a) - return f"Math.trunc({str(value)})" + return f"Math.trunc({str(self._value)})" class BooleanVar(ImmutableVar[bool]): @@ -682,7 +687,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The boolean NOT operation. """ - return BooleanNotOperation(self) + return BooleanNotOperation.create(self) def __int__(self) -> BooleanToIntOperation: """Convert the boolean to an int. @@ -690,7 +695,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The boolean to int operation. """ - return BooleanToIntOperation(self) + return BooleanToIntOperation.create(self) def __pos__(self) -> BooleanToIntOperation: """Convert the boolean to an int. @@ -698,7 +703,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The boolean to int operation. """ - return BooleanToIntOperation(self) + return BooleanToIntOperation.create(self) def bool(self) -> BooleanVar: """Boolean conversion. @@ -717,7 +722,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return LessThanOperation(+self, +other) + return LessThanOperation.create(+self, +other) def __le__(self, other: boolean_types | number_types) -> LessThanOrEqualOperation: """Less than or equal comparison. @@ -728,7 +733,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return LessThanOrEqualOperation(+self, +other) + return LessThanOrEqualOperation.create(+self, +other) def __eq__(self, other: boolean_types | number_types) -> EqualOperation: """Equal comparison. @@ -739,7 +744,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return EqualOperation(+self, +other) + return EqualOperation.create(+self, +other) def __ne__(self, other: boolean_types | number_types) -> NotEqualOperation: """Not equal comparison. @@ -750,7 +755,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return NotEqualOperation(+self, +other) + return NotEqualOperation.create(+self, +other) def __gt__(self, other: boolean_types | number_types) -> GreaterThanOperation: """Greater than comparison. @@ -761,7 +766,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return GreaterThanOperation(+self, +other) + return GreaterThanOperation.create(+self, +other) def __ge__( self, other: boolean_types | number_types @@ -774,7 +779,7 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return GreaterThanOrEqualOperation(+self, +other) + return GreaterThanOrEqualOperation.create(+self, +other) @dataclasses.dataclass( @@ -785,25 +790,12 @@ class BooleanVar(ImmutableVar[bool]): class BooleanToIntOperation(NumberVar): """Base class for immutable number vars that are the result of a boolean to int operation.""" - a: boolean_types = dataclasses.field(default=False) + _value: BooleanVar = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(False) + ) - def __init__( - self, - a: boolean_types, - _var_data: VarData | None = None, - ): - """Initialize the boolean to int operation var. - - Args: - a: The boolean. - _var_data: Additional hooks and imports associated with the Var. - """ - super(BooleanToIntOperation, self).__init__( - _var_name="", - _var_type=int, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -813,7 +805,7 @@ class BooleanToIntOperation(NumberVar): Returns: The name of the var. """ - return f"({str(self.a)} ? 1 : 0)" + return f"({str(self._value)} ? 1 : 0)" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -836,7 +828,7 @@ class BooleanToIntOperation(NumberVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._value._get_all_var_data() if isinstance(self._value, Var) else None, self._var_data, ) @@ -849,7 +841,25 @@ class BooleanToIntOperation(NumberVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a)) + return hash((self.__class__.__name__, self._value)) + + @classmethod + def create(cls, value: BooleanVar, _var_data: VarData | None = None): + """Create the boolean to int operation var. + + Args: + value: The boolean. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The boolean to int operation var. + """ + return cls( + _var_name="", + _var_type=int, + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) @dataclasses.dataclass( @@ -860,29 +870,15 @@ class BooleanToIntOperation(NumberVar): class ComparisonOperation(BooleanVar): """Base class for immutable boolean vars that are the result of a comparison operation.""" - a: Var = dataclasses.field(default_factory=lambda: LiteralBooleanVar(True)) - b: Var = dataclasses.field(default_factory=lambda: LiteralBooleanVar(True)) + _lhs: Var = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(False) + ) + _rhs: Var = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(False) + ) - def __init__( - self, - a: Var | Any, - b: Var | Any, - _var_data: VarData | None = None, - ): - """Initialize the comparison operation var. - - Args: - a: The first value. - b: The second value. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ComparisonOperation, self).__init__( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralVar.create(a)) - object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b)) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -914,10 +910,8 @@ class ComparisonOperation(BooleanVar): Returns: The VarData of the components and all of its children. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) return ImmutableVarData.merge( - first_value._get_all_var_data(), second_value._get_all_var_data() + self._lhs._get_all_var_data(), self._rhs._get_all_var_data() ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -929,7 +923,28 @@ class ComparisonOperation(BooleanVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._lhs, self._rhs)) + + @classmethod + def create(cls, lhs: Var | Any, rhs: Var | Any, _var_data: VarData | None = None): + """Create the comparison operation var. + + Args: + lhs: The first value. + rhs: The second value. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The comparison operation var. + """ + lhs, rhs = map(LiteralVar.create, (lhs, rhs)) + return cls( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _lhs=lhs, + _rhs=rhs, + ) class GreaterThanOperation(ComparisonOperation): @@ -942,9 +957,7 @@ class GreaterThanOperation(ComparisonOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) - return f"({str(first_value)} > {str(second_value)})" + return f"({str(self._lhs)} > {str(self._rhs)})" class GreaterThanOrEqualOperation(ComparisonOperation): @@ -957,9 +970,7 @@ class GreaterThanOrEqualOperation(ComparisonOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) - return f"({str(first_value)} >= {str(second_value)})" + return f"({str(self._lhs)} >= {str(self._rhs)})" class LessThanOperation(ComparisonOperation): @@ -972,9 +983,7 @@ class LessThanOperation(ComparisonOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) - return f"({str(first_value)} < {str(second_value)})" + return f"({str(self._lhs)} < {str(self._rhs)})" class LessThanOrEqualOperation(ComparisonOperation): @@ -987,9 +996,7 @@ class LessThanOrEqualOperation(ComparisonOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) - return f"({str(first_value)} <= {str(second_value)})" + return f"({str(self._lhs)} <= {str(self._rhs)})" class EqualOperation(ComparisonOperation): @@ -1002,9 +1009,7 @@ class EqualOperation(ComparisonOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) - return f"({str(first_value)} === {str(second_value)})" + return f"({str(self._lhs)} === {str(self._rhs)})" class NotEqualOperation(ComparisonOperation): @@ -1017,9 +1022,7 @@ class NotEqualOperation(ComparisonOperation): Returns: The name of the var. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) - return f"({str(first_value)} != {str(second_value)})" + return f"({str(self._lhs)} != {str(self._rhs)})" @dataclasses.dataclass( @@ -1030,26 +1033,15 @@ class NotEqualOperation(ComparisonOperation): class LogicalOperation(BooleanVar): """Base class for immutable boolean vars that are the result of a logical operation.""" - a: boolean_types = dataclasses.field(default=False) - b: boolean_types = dataclasses.field(default=False) + _lhs: BooleanVar = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(False) + ) + _rhs: BooleanVar = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(False) + ) - def __init__( - self, a: boolean_types, b: boolean_types, _var_data: VarData | None = None - ): - """Initialize the logical operation var. - - Args: - a: The first value. - b: The second value. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LogicalOperation, self).__init__( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) - object.__setattr__(self, "b", b) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -1081,10 +1073,8 @@ class LogicalOperation(BooleanVar): Returns: The VarData of the components and all of its children. """ - first_value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - second_value = self.b if isinstance(self.b, Var) else LiteralVar.create(self.b) return ImmutableVarData.merge( - first_value._get_all_var_data(), second_value._get_all_var_data() + self._lhs._get_all_var_data(), self._rhs._get_all_var_data() ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -1096,27 +1086,51 @@ class LogicalOperation(BooleanVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._lhs, self._rhs)) - -class BooleanNotOperation(BooleanVar): - """Base class for immutable boolean vars that are the result of a logical NOT operation.""" - - a: boolean_types = dataclasses.field() - - def __init__(self, a: boolean_types, _var_data: VarData | None = None): - """Initialize the logical NOT operation var. + @classmethod + def create( + cls, lhs: boolean_types, rhs: boolean_types, _var_data: VarData | None = None + ): + """Create the logical operation var. Args: - a: The value. + lhs: The first boolean. + rhs: The second boolean. _var_data: Additional hooks and imports associated with the Var. + + Returns: + The logical operation var. """ - super(BooleanNotOperation, self).__init__( + lhs, rhs = map( + lambda v: ( + LiteralBooleanVar.create(v) if not isinstance(v, BooleanVar) else v + ), + (lhs, rhs), + ) + return cls( _var_name="", _var_type=bool, _var_data=ImmutableVarData.merge(_var_data), + _lhs=lhs, + _rhs=rhs, ) - object.__setattr__(self, "a", a) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class BooleanNotOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a logical NOT operation.""" + + _value: BooleanVar = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(False) + ) + + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -1126,8 +1140,7 @@ class BooleanNotOperation(BooleanVar): Returns: The name of the var. """ - value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - return f"!({str(value)})" + return f"!({str(self._value)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1149,8 +1162,7 @@ class BooleanNotOperation(BooleanVar): Returns: The VarData of the components and all of its children. """ - value = self.a if isinstance(self.a, Var) else LiteralVar.create(self.a) - return ImmutableVarData.merge(value._get_all_var_data()) + return ImmutableVarData.merge(self._value._get_all_var_data()) def _get_all_var_data(self) -> ImmutableVarData | None: return self._cached_get_all_var_data @@ -1161,7 +1173,26 @@ class BooleanNotOperation(BooleanVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a)) + return hash((self.__class__.__name__, self._value)) + + @classmethod + def create(cls, value: boolean_types, _var_data: VarData | None = None): + """Create the logical NOT operation var. + + Args: + value: The value. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The logical NOT operation var. + """ + value = value if isinstance(value, Var) else LiteralBooleanVar.create(value) + return cls( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) @dataclasses.dataclass( @@ -1174,24 +1205,6 @@ class LiteralBooleanVar(LiteralVar, BooleanVar): _var_value: bool = dataclasses.field(default=False) - def __init__( - self, - _var_value: bool, - _var_data: VarData | None = None, - ): - """Initialize the boolean var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralBooleanVar, self).__init__( - _var_name="true" if _var_value else "false", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - def __hash__(self) -> int: """Hash the var. @@ -1208,6 +1221,24 @@ class LiteralBooleanVar(LiteralVar, BooleanVar): """ return "true" if self._var_value else "false" + @classmethod + def create(cls, value: bool, _var_data: VarData | None = None): + """Create the boolean var. + + Args: + value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The boolean var. + """ + return cls( + _var_name="true" if value else "false", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _var_value=value, + ) + @dataclasses.dataclass( eq=False, @@ -1219,24 +1250,6 @@ class LiteralNumberVar(LiteralVar, NumberVar): _var_value: float | int = dataclasses.field(default=0) - def __init__( - self, - _var_value: float | int, - _var_data: VarData | None = None, - ): - """Initialize the number var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralNumberVar, self).__init__( - _var_name=str(_var_value), - _var_type=type(_var_value), - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - def __hash__(self) -> int: """Hash the var. @@ -1253,9 +1266,27 @@ class LiteralNumberVar(LiteralVar, NumberVar): """ return json.dumps(self._var_value) + @classmethod + def create(cls, value: float | int, _var_data: VarData | None = None): + """Create the number var. -number_types = Union[NumberVar, LiteralNumberVar, int, float] -boolean_types = Union[BooleanVar, LiteralBooleanVar, bool] + Args: + value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The number var. + """ + return cls( + _var_name=str(value), + _var_type=type(value), + _var_data=ImmutableVarData.merge(_var_data), + _var_value=value, + ) + + +number_types = Union[NumberVar, int, float] +boolean_types = Union[BooleanVar, bool] @dataclasses.dataclass( @@ -1267,28 +1298,11 @@ class ToNumberVarOperation(NumberVar): """Base class for immutable number vars that are the result of a number operation.""" _original_value: Var = dataclasses.field( - default_factory=lambda: LiteralNumberVar(0) + default_factory=lambda: LiteralNumberVar.create(0) ) - def __init__( - self, - _original_value: Var, - _var_type: type[int] | type[float] = float, - _var_data: VarData | None = None, - ): - """Initialize the number var. - - Args: - _original_value: The original value. - _var_type: The type of the Var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ToNumberVarOperation, self).__init__( - _var_name="", - _var_type=_var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_original_value", _original_value) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -1335,6 +1349,30 @@ class ToNumberVarOperation(NumberVar): """ return hash((self.__class__.__name__, self._original_value)) + @classmethod + def create( + cls, + value: Var, + _var_type: type[int] | type[float] = float, + _var_data: VarData | None = None, + ): + """Create the number var. + + Args: + value: The value of the var. + _var_type: The type of the Var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The number var. + """ + return cls( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + _original_value=value, + ) + @dataclasses.dataclass( eq=False, @@ -1345,28 +1383,9 @@ class ToBooleanVarOperation(BooleanVar): """Base class for immutable boolean vars that are the result of a boolean operation.""" _original_value: Var = dataclasses.field( - default_factory=lambda: LiteralBooleanVar(False) + default_factory=lambda: LiteralBooleanVar.create(False) ) - def __init__( - self, - _original_value: Var, - _var_data: VarData | None = None, - ): - """Initialize the boolean var. - - Args: - _original_value: The original value. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ToBooleanVarOperation, self).__init__( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_original_value", _original_value) - object.__delattr__(self, "_var_name") - @cached_property def _cached_var_name(self) -> str: """The name of the var. @@ -1411,46 +1430,52 @@ class ToBooleanVarOperation(BooleanVar): """ return hash((self.__class__.__name__, self._original_value)) + def __post_init__(self): + """Post initialization.""" + object.__delattr__(self, "_var_name") + @classmethod + def create( + cls, + value: Var, + _var_data: VarData | None = None, + ): + """Create the boolean var. + + Args: + value: The value of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The boolean var. + """ + return cls( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _original_value=value, + ) + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) class TernaryOperator(ImmutableVar): """Base class for immutable vars that are the result of a ternary operation.""" - condition: Var = dataclasses.field(default_factory=lambda: LiteralBooleanVar(False)) - if_true: Var = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) - if_false: Var = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + _condition: BooleanVar = dataclasses.field( + default_factory=lambda: LiteralBooleanVar.create(False) + ) + _if_true: Var = dataclasses.field( + default_factory=lambda: LiteralNumberVar.create(0) + ) + _if_false: Var = dataclasses.field( + default_factory=lambda: LiteralNumberVar.create(0) + ) - def __init__( - self, - condition: Var | Any, - if_true: Var | Any, - if_false: Var | Any, - _var_type: GenericType | None = None, - _var_data: VarData | None = None, - ): - """Initialize the ternary operation var. - - Args: - condition: The condition. - if_true: The value if the condition is true. - if_false: The value if the condition is false. - _var_data: Additional hooks and imports associated with the Var. - """ - condition = ( - condition if isinstance(condition, Var) else LiteralVar.create(condition) - ) - if_true = if_true if isinstance(if_true, Var) else LiteralVar.create(if_true) - if_false = ( - if_false if isinstance(if_false, Var) else LiteralVar.create(if_false) - ) - - super(TernaryOperator, self).__init__( - _var_name="", - _var_type=_var_type or Union[if_true._var_type, if_false._var_type], - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "condition", condition) - object.__setattr__(self, "if_true", if_true) - object.__setattr__(self, "if_false", if_false) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -1460,7 +1485,9 @@ class TernaryOperator(ImmutableVar): Returns: The name of the var. """ - return f"({str(self.condition)} ? {str(self.if_true)} : {str(self.if_false)})" + return ( + f"({str(self._condition)} ? {str(self._if_true)} : {str(self._if_false)})" + ) def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1483,9 +1510,9 @@ class TernaryOperator(ImmutableVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.condition._get_all_var_data(), - self.if_true._get_all_var_data(), - self.if_false._get_all_var_data(), + self._condition._get_all_var_data(), + self._if_true._get_all_var_data(), + self._if_false._get_all_var_data(), self._var_data, ) @@ -1499,5 +1526,42 @@ class TernaryOperator(ImmutableVar): int: The hash value of the object. """ return hash( - (self.__class__.__name__, self.condition, self.if_true, self.if_false) + (self.__class__.__name__, self._condition, self._if_true, self._if_false) + ) + + @classmethod + def create( + cls, + condition: boolean_types, + if_true: Var | Any, + if_false: Var | Any, + _var_data: VarData | None = None, + ): + """Create the ternary operation var. + + Args: + condition: The condition. + if_true: The value if the condition is true. + if_false: The value if the condition is false. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The ternary operation var. + """ + condition = ( + condition + if isinstance(condition, Var) + else LiteralBooleanVar.create(condition) + ) + _if_true, _if_false = map( + lambda v: (LiteralVar.create(v) if not isinstance(v, Var) else v), + (if_true, if_false), + ) + return TernaryOperator( + _var_name="", + _var_type=unionize(_if_true._var_type, _if_false._var_type), + _var_data=ImmutableVarData.merge(_var_data), + _condition=condition, + _if_true=_if_true, + _if_false=_if_false, ) diff --git a/reflex/ivars/object.py b/reflex/ivars/object.py index 8a5dcb756..f2b07c754 100644 --- a/reflex/ivars/object.py +++ b/reflex/ivars/object.py @@ -34,7 +34,7 @@ from .base import ( from .number import BooleanVar, NumberVar from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE") +OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -79,7 +79,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The keys of the object. """ - return ObjectKeysOperation(self) + return ObjectKeysOperation.create(self) @overload def values( @@ -95,7 +95,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The values of the object. """ - return ObjectValuesOperation(self) + return ObjectValuesOperation.create(self) @overload def entries( @@ -111,7 +111,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The entries of the object. """ - return ObjectEntriesOperation(self) + return ObjectEntriesOperation.create(self) def merge(self, other: ObjectVar) -> ObjectMergeOperation: """Merge two objects. @@ -122,7 +122,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The merged object. """ - return ObjectMergeOperation(self, other) + return ObjectMergeOperation.create(self, other) # NoReturn is used here to catch when key value is Any @overload @@ -180,7 +180,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The item from the object. """ - return ObjectItemOperation(self, key).guess_type() + return ObjectItemOperation.create(self, key).guess_type() # NoReturn is used here to catch when key value is Any @overload @@ -253,9 +253,9 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated " f"wrongly." ) - return ObjectItemOperation(self, name, attribute_type).guess_type() + return ObjectItemOperation.create(self, name, attribute_type).guess_type() else: - return ObjectItemOperation(self, name).guess_type() + return ObjectItemOperation.create(self, name).guess_type() def contains(self, key: Var | Any) -> BooleanVar: """Check if the object contains a key. @@ -266,7 +266,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The result of the check. """ - return ObjectHasOwnProperty(self, key) + return ObjectHasOwnProperty.create(self, key) @dataclasses.dataclass( @@ -281,29 +281,8 @@ class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]): default_factory=dict ) - def __init__( - self: LiteralObjectVar[OBJECT_TYPE], - _var_value: OBJECT_TYPE, - _var_type: Type[OBJECT_TYPE] | None = None, - _var_data: VarData | None = None, - ): - """Initialize the object var. - - Args: - _var_value: The value of the var. - _var_type: The type of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralObjectVar, self).__init__( - _var_name="", - _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type), - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, - "_var_value", - _var_value, - ) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") def _key_type(self) -> Type: @@ -409,6 +388,30 @@ class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]): """ return hash((self.__class__.__name__, self._var_name)) + @classmethod + def create( + cls, + _var_value: OBJECT_TYPE, + _var_type: GenericType | None = None, + _var_data: VarData | None = None, + ) -> LiteralObjectVar[OBJECT_TYPE]: + """Create the literal object var. + + Args: + _var_value: The value of the var. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The literal object var. + """ + return LiteralObjectVar( + _var_name="", + _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type), + _var_data=ImmutableVarData.merge(_var_data), + _var_value=_var_value, + ) + @dataclasses.dataclass( eq=False, @@ -418,26 +421,12 @@ class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]): class ObjectToArrayOperation(ArrayVar): """Base class for object to array operations.""" - value: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + _value: ObjectVar = dataclasses.field( + default_factory=lambda: LiteralObjectVar.create({}) + ) - def __init__( - self, - _var_value: ObjectVar, - _var_type: Type = list, - _var_data: VarData | None = None, - ): - """Initialize the object to array operation. - - Args: - _var_value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ObjectToArrayOperation, self).__init__( - _var_name="", - _var_type=_var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "value", _var_value) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -472,7 +461,7 @@ class ObjectToArrayOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.value._get_all_var_data(), + self._value._get_all_var_data(), self._var_data, ) @@ -490,26 +479,37 @@ class ObjectToArrayOperation(ArrayVar): Returns: The hash of the operation. """ - return hash((self.__class__.__name__, self.value)) + return hash((self.__class__.__name__, self._value)) + + @classmethod + def create( + cls, + value: ObjectVar, + _var_type: GenericType | None = None, + _var_data: VarData | None = None, + ) -> ObjectToArrayOperation: + """Create the object to array operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The object to array operation. + """ + return cls( + _var_name="", + _var_type=list if _var_type is None else _var_type, + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) class ObjectKeysOperation(ObjectToArrayOperation): """Operation to get the keys of an object.""" - def __init__( - self, - value: ObjectVar, - _var_data: VarData | None = None, - ): - """Initialize the object keys operation. - - Args: - value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ObjectKeysOperation, self).__init__( - value, List[value._key_type()], _var_data - ) + # value, List[value._key_type()], _var_data + # ) @cached_property def _cached_var_name(self) -> str: @@ -518,27 +518,34 @@ class ObjectKeysOperation(ObjectToArrayOperation): Returns: The name of the operation. """ - return f"Object.keys({self.value._var_name})" + return f"Object.keys({str(self._value)})" + + @classmethod + def create( + cls, + value: ObjectVar, + _var_data: VarData | None = None, + ) -> ObjectKeysOperation: + """Create the object keys operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The object keys operation. + """ + return cls( + _var_name="", + _var_type=List[str], + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) class ObjectValuesOperation(ObjectToArrayOperation): """Operation to get the values of an object.""" - def __init__( - self, - value: ObjectVar, - _var_data: VarData | None = None, - ): - """Initialize the object values operation. - - Args: - value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ObjectValuesOperation, self).__init__( - value, List[value._value_type()], _var_data - ) - @cached_property def _cached_var_name(self) -> str: """The name of the operation. @@ -546,27 +553,34 @@ class ObjectValuesOperation(ObjectToArrayOperation): Returns: The name of the operation. """ - return f"Object.values({self.value._var_name})" + return f"Object.values({self._value._var_name})" + + @classmethod + def create( + cls, + value: ObjectVar, + _var_data: VarData | None = None, + ) -> ObjectValuesOperation: + """Create the object values operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The object values operation. + """ + return cls( + _var_name="", + _var_type=List[value._value_type()], + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) class ObjectEntriesOperation(ObjectToArrayOperation): """Operation to get the entries of an object.""" - def __init__( - self, - value: ObjectVar, - _var_data: VarData | None = None, - ): - """Initialize the object entries operation. - - Args: - value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ObjectEntriesOperation, self).__init__( - value, List[Tuple[value._key_type(), value._value_type()]], _var_data - ) - @cached_property def _cached_var_name(self) -> str: """The name of the operation. @@ -574,7 +588,29 @@ class ObjectEntriesOperation(ObjectToArrayOperation): Returns: The name of the operation. """ - return f"Object.entries({self.value._var_name})" + return f"Object.entries({self._value._var_name})" + + @classmethod + def create( + cls, + value: ObjectVar, + _var_data: VarData | None = None, + ) -> ObjectEntriesOperation: + """Create the object entries operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The object entries operation. + """ + return cls( + _var_name="", + _var_type=List[Tuple[str, value._value_type()]], + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) @dataclasses.dataclass( @@ -585,30 +621,12 @@ class ObjectEntriesOperation(ObjectToArrayOperation): class ObjectMergeOperation(ObjectVar): """Operation to merge two objects.""" - left: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) - right: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) - - def __init__( - self, - left: ObjectVar, - right: ObjectVar, - _var_data: VarData | None = None, - ): - """Initialize the object merge operation. - - Args: - left: The left object to merge. - right: The right object to merge. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ObjectMergeOperation, self).__init__( - _var_name="", - _var_type=left._var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "left", left) - object.__setattr__(self, "right", right) - object.__delattr__(self, "_var_name") + _lhs: ObjectVar = dataclasses.field( + default_factory=lambda: LiteralObjectVar.create({}) + ) + _rhs: ObjectVar = dataclasses.field( + default_factory=lambda: LiteralObjectVar.create({}) + ) @cached_property def _cached_var_name(self) -> str: @@ -617,7 +635,7 @@ class ObjectMergeOperation(ObjectVar): Returns: The name of the operation. """ - return f"Object.assign({self.left._var_name}, {self.right._var_name})" + return f"Object.assign({self._lhs._var_name}, {self._rhs._var_name})" def __getattr__(self, name): """Get an attribute of the operation. @@ -640,8 +658,8 @@ class ObjectMergeOperation(ObjectVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.left._get_all_var_data(), - self.right._get_all_var_data(), + self._lhs._get_all_var_data(), + self._rhs._get_all_var_data(), self._var_data, ) @@ -659,7 +677,33 @@ class ObjectMergeOperation(ObjectVar): Returns: The hash of the operation. """ - return hash((self.__class__.__name__, self.left, self.right)) + return hash((self.__class__.__name__, self._lhs, self._rhs)) + + @classmethod + def create( + cls, + lhs: ObjectVar, + rhs: ObjectVar, + _var_data: VarData | None = None, + ) -> ObjectMergeOperation: + """Create the object merge operation. + + Args: + lhs: The left object to merge. + rhs: The right object to merge. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The object merge operation. + """ + # TODO: Figure out how to merge the types + return cls( + _var_name="", + _var_type=lhs._var_type, + _var_data=ImmutableVarData.merge(_var_data), + _lhs=lhs, + _rhs=rhs, + ) @dataclasses.dataclass( @@ -670,33 +714,10 @@ class ObjectMergeOperation(ObjectVar): class ObjectItemOperation(ImmutableVar): """Operation to get an item from an object.""" - value: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) - key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - - def __init__( - self, - value: ObjectVar, - key: Var | Any, - _var_type: GenericType | None = None, - _var_data: VarData | None = None, - ): - """Initialize the object item operation. - - Args: - value: The value of the operation. - key: The key to get from the object. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ObjectItemOperation, self).__init__( - _var_name="", - _var_type=value._value_type() if _var_type is None else _var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "value", value) - object.__setattr__( - self, "key", key if isinstance(key, Var) else LiteralVar.create(key) - ) - object.__delattr__(self, "_var_name") + _object: ObjectVar = dataclasses.field( + default_factory=lambda: LiteralObjectVar.create({}) + ) + _key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) @cached_property def _cached_var_name(self) -> str: @@ -705,7 +726,7 @@ class ObjectItemOperation(ImmutableVar): Returns: The name of the operation. """ - return f"{str(self.value)}[{str(self.key)}]" + return f"{str(self._object)}[{str(self._key)}]" def __getattr__(self, name): """Get an attribute of the operation. @@ -728,8 +749,8 @@ class ObjectItemOperation(ImmutableVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.value._get_all_var_data(), - self.key._get_all_var_data(), + self._object._get_all_var_data(), + self._key._get_all_var_data(), self._var_data, ) @@ -747,7 +768,38 @@ class ObjectItemOperation(ImmutableVar): Returns: The hash of the operation. """ - return hash((self.__class__.__name__, self.value, self.key)) + return hash((self.__class__.__name__, self._object, self._key)) + + def __post_init__(self): + """Post initialization.""" + object.__delattr__(self, "_var_name") + + @classmethod + def create( + cls, + object: ObjectVar, + key: Var | Any, + _var_type: GenericType | None = None, + _var_data: VarData | None = None, + ) -> ObjectItemOperation: + """Create the object item operation. + + Args: + object: The object to get the item from. + key: The key to get from the object. + _var_type: The type of the item. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The object item operation. + """ + return cls( + _var_name="", + _var_type=object._value_type() if _var_type is None else _var_type, + _var_data=ImmutableVarData.merge(_var_data), + _object=object, + _key=key if isinstance(key, Var) else LiteralVar.create(key), + ) @dataclasses.dataclass( @@ -758,28 +810,9 @@ class ObjectItemOperation(ImmutableVar): class ToObjectOperation(ObjectVar): """Operation to convert a var to an object.""" - _original_var: Var = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) - - def __init__( - self, - _original_var: Var, - _var_type: Type = dict, - _var_data: VarData | None = None, - ): - """Initialize the to object operation. - - Args: - _original_var: The original var to convert. - _var_type: The type of the var. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ToObjectOperation, self).__init__( - _var_name="", - _var_type=_var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_original_var", _original_var) - object.__delattr__(self, "_var_name") + _original_var: Var = dataclasses.field( + default_factory=lambda: LiteralObjectVar.create({}) + ) @cached_property def _cached_var_name(self) -> str: @@ -831,6 +864,34 @@ class ToObjectOperation(ObjectVar): """ return hash((self.__class__.__name__, self._original_var)) + def __post_init__(self): + """Post initialization.""" + object.__delattr__(self, "_var_name") + + @classmethod + def create( + cls, + original_var: Var, + _var_type: GenericType | None = None, + _var_data: VarData | None = None, + ) -> ToObjectOperation: + """Create the to object operation. + + Args: + original_var: The original var to convert. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The to object operation. + """ + return cls( + _var_name="", + _var_type=dict if _var_type is None else _var_type, + _var_data=ImmutableVarData.merge(_var_data), + _original_var=original_var, + ) + @dataclasses.dataclass( eq=False, @@ -840,30 +901,13 @@ class ToObjectOperation(ObjectVar): class ObjectHasOwnProperty(BooleanVar): """Operation to check if an object has a property.""" - value: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) - key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + _object: ObjectVar = dataclasses.field( + default_factory=lambda: LiteralObjectVar.create({}) + ) + _key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - def __init__( - self, - value: ObjectVar, - key: Var | Any, - _var_data: VarData | None = None, - ): - """Initialize the object has own property operation. - - Args: - value: The value of the operation. - key: The key to check. - _var_data: Additional hooks and imports associated with the operation. - """ - super(ObjectHasOwnProperty, self).__init__( - _var_name="", - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "value", value) - object.__setattr__( - self, "key", key if isinstance(key, Var) else LiteralVar.create(key) - ) + def __post_init__(self): + """Post initialization.""" object.__delattr__(self, "_var_name") @cached_property @@ -873,7 +917,7 @@ class ObjectHasOwnProperty(BooleanVar): Returns: The name of the operation. """ - return f"{str(self.value)}.hasOwnProperty({str(self.key)})" + return f"{str(self._object)}.hasOwnProperty({str(self._key)})" def __getattr__(self, name): """Get an attribute of the operation. @@ -896,8 +940,8 @@ class ObjectHasOwnProperty(BooleanVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.value._get_all_var_data(), - self.key._get_all_var_data(), + self._object._get_all_var_data(), + self._key._get_all_var_data(), self._var_data, ) @@ -915,4 +959,29 @@ class ObjectHasOwnProperty(BooleanVar): Returns: The hash of the operation. """ - return hash((self.__class__.__name__, self.value, self.key)) + return hash((self.__class__.__name__, self._object, self._key)) + + @classmethod + def create( + cls, + object: ObjectVar, + key: Var | Any, + _var_data: VarData | None = None, + ) -> ObjectHasOwnProperty: + """Create the object has own property operation. + + Args: + object: The object to check. + key: The key to check. + _var_data: Additional hooks and imports associated with the operation. + + Returns: + The object has own property operation. + """ + return cls( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _object=object, + _key=key if isinstance(key, Var) else LiteralVar.create(key), + ) diff --git a/reflex/ivars/sequence.py b/reflex/ivars/sequence.py index cae939ed6..8d2e659a8 100644 --- a/reflex/ivars/sequence.py +++ b/reflex/ivars/sequence.py @@ -18,6 +18,7 @@ from typing import ( Literal, Set, Tuple, + Type, TypeVar, Union, overload, @@ -59,7 +60,7 @@ class StringVar(ImmutableVar[str]): Returns: The string concatenation operation. """ - return ConcatVarOperation(self, other) + return ConcatVarOperation.create(self, other) def __radd__(self, other: StringVar | str) -> ConcatVarOperation: """Concatenate two strings. @@ -70,7 +71,7 @@ class StringVar(ImmutableVar[str]): Returns: The string concatenation operation. """ - return ConcatVarOperation(other, self) + return ConcatVarOperation.create(other, self) def __mul__(self, other: NumberVar | int) -> StringVar: """Multiply the sequence by a number or an integer. @@ -113,7 +114,7 @@ class StringVar(ImmutableVar[str]): """ if isinstance(i, slice): return self.split()[i].join() - return StringItemOperation(self, i) + return StringItemOperation.create(self, i) def length(self) -> NumberVar: """Get the length of the string. @@ -123,29 +124,29 @@ class StringVar(ImmutableVar[str]): """ return self.split().length() - def lower(self) -> StringLowerOperation: + def lower(self) -> StringVar: """Convert the string to lowercase. Returns: The string lower operation. """ - return StringLowerOperation(self) + return StringLowerOperation.create(self) - def upper(self) -> StringUpperOperation: + def upper(self) -> StringVar: """Convert the string to uppercase. Returns: The string upper operation. """ - return StringUpperOperation(self) + return StringUpperOperation.create(self) - def strip(self) -> StringStripOperation: + def strip(self) -> StringVar: """Strip the string. Returns: The string strip operation. """ - return StringStripOperation(self) + return StringStripOperation.create(self) def bool(self) -> NotEqualOperation: """Boolean conversion. @@ -153,7 +154,7 @@ class StringVar(ImmutableVar[str]): Returns: The boolean value of the string. """ - return NotEqualOperation(self.length(), 0) + return NotEqualOperation.create(self.length(), 0) def reversed(self) -> ArrayJoinOperation: """Reverse the string. @@ -172,7 +173,7 @@ class StringVar(ImmutableVar[str]): Returns: The string contains operation. """ - return StringContainsOperation(self, other) + return StringContainsOperation.create(self, other) def split(self, separator: StringVar | str = "") -> StringSplitOperation: """Split the string. @@ -183,7 +184,18 @@ class StringVar(ImmutableVar[str]): Returns: The string split operation. """ - return StringSplitOperation(self, separator) + return StringSplitOperation.create(self, separator) + + def startswith(self, prefix: StringVar | str) -> StringStartsWithOperation: + """Check if the string starts with a prefix. + + Args: + prefix: The prefix. + + Returns: + The string starts with operation. + """ + return StringStartsWithOperation.create(self, prefix) @dataclasses.dataclass( @@ -194,25 +206,12 @@ class StringVar(ImmutableVar[str]): class StringToStringOperation(StringVar): """Base class for immutable string vars that are the result of a string to string operation.""" - a: StringVar = dataclasses.field( + _value: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - def __init__(self, a: StringVar | str, _var_data: VarData | None = None): - """Initialize the string to string operation var. - - Args: - a: The string. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringToStringOperation, self).__init__( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -247,7 +246,7 @@ class StringToStringOperation(StringVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._value._get_all_var_data() if isinstance(self._value, Var) else None, self._var_data, ) @@ -260,7 +259,29 @@ class StringToStringOperation(StringVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a)) + return hash((self.__class__.__name__, self._value)) + + @classmethod + def create( + cls, + value: StringVar, + _var_data: VarData | None = None, + ) -> StringVar: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) class StringLowerOperation(StringToStringOperation): @@ -273,7 +294,7 @@ class StringLowerOperation(StringToStringOperation): Returns: The name of the var. """ - return f"{str(self.a)}.toLowerCase()" + return f"{str(self._value)}.toLowerCase()" class StringUpperOperation(StringToStringOperation): @@ -286,7 +307,7 @@ class StringUpperOperation(StringToStringOperation): Returns: The name of the var. """ - return f"{str(self.a)}.toUpperCase()" + return f"{str(self._value)}.toUpperCase()" class StringStripOperation(StringToStringOperation): @@ -299,7 +320,7 @@ class StringStripOperation(StringToStringOperation): Returns: The name of the var. """ - return f"{str(self.a)}.trim()" + return f"{str(self._value)}.trim()" @dataclasses.dataclass( @@ -310,34 +331,15 @@ class StringStripOperation(StringToStringOperation): class StringContainsOperation(BooleanVar): """Base class for immutable boolean vars that are the result of a string contains operation.""" - a: StringVar = dataclasses.field( + _haystack: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - b: StringVar = dataclasses.field( + _needle: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - def __init__( - self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None - ): - """Initialize the string contains operation var. - - Args: - a: The first string. - b: The second string. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringContainsOperation, self).__init__( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) - object.__setattr__( - self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -347,7 +349,7 @@ class StringContainsOperation(BooleanVar): Returns: The name of the var. """ - return f"{str(self.a)}.includes({str(self.b)})" + return f"{str(self._haystack)}.includes({str(self._needle)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -370,7 +372,9 @@ class StringContainsOperation(BooleanVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + self._haystack._get_all_var_data(), + self._needle._get_all_var_data(), + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -382,7 +386,135 @@ class StringContainsOperation(BooleanVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._haystack, self._needle)) + + @classmethod + def create( + cls, + haystack: StringVar | str, + needle: StringVar | str, + _var_data: VarData | None = None, + ) -> StringContainsOperation: + """Create a var from a string value. + + Args: + haystack: The haystack. + needle: The needle. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _haystack=( + haystack + if isinstance(haystack, Var) + else LiteralStringVar.create(haystack) + ), + _needle=( + needle if isinstance(needle, Var) else LiteralStringVar.create(needle) + ), + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class StringStartsWithOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a string starts with operation.""" + + _full_string: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + _prefix: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __post_init__(self): + """Post-initialize the var.""" + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self._full_string)}.startsWith({str(self._prefix)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(StringStartsWithOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._full_string._get_all_var_data(), + self._prefix._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + def __hash__(self) -> int: + """Calculate the hash value of the object. + + Returns: + int: The hash value of the object. + """ + return hash((self.__class__.__name__, self._full_string, self._prefix)) + + @classmethod + def create( + cls, + full_string: StringVar | str, + prefix: StringVar | str, + _var_data: VarData | None = None, + ) -> StringStartsWithOperation: + """Create a var from a string value. + + Args: + full_string: The full string. + prefix: The prefix. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _full_string=( + full_string + if isinstance(full_string, Var) + else LiteralStringVar.create(full_string) + ), + _prefix=( + prefix if isinstance(prefix, Var) else LiteralStringVar.create(prefix) + ), + ) @dataclasses.dataclass( @@ -393,31 +525,10 @@ class StringContainsOperation(BooleanVar): class StringItemOperation(StringVar): """Base class for immutable string vars that are the result of a string item operation.""" - a: StringVar = dataclasses.field( + _string: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) - - def __init__( - self, a: StringVar | str, i: int | NumberVar, _var_data: VarData | None = None - ): - """Initialize the string item operation var. - - Args: - a: The string. - i: The index. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringItemOperation, self).__init__( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) - object.__setattr__(self, "i", i if isinstance(i, Var) else LiteralNumberVar(i)) - object.__delattr__(self, "_var_name") + _index: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar.create(0)) @cached_property def _cached_var_name(self) -> str: @@ -426,7 +537,7 @@ class StringItemOperation(StringVar): Returns: The name of the var. """ - return f"{str(self.a)}.at({str(self.i)})" + return f"{str(self._string)}.at({str(self._index)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -449,7 +560,9 @@ class StringItemOperation(StringVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data + self._string._get_all_var_data(), + self._index._get_all_var_data(), + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -461,36 +574,58 @@ class StringItemOperation(StringVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a, self.i)) + return hash((self.__class__.__name__, self._string, self._index)) + def __post_init__(self): + """Post-initialize the var.""" + object.__delattr__(self, "_var_name") -class ArrayJoinOperation(StringVar): - """Base class for immutable string vars that are the result of an array join operation.""" - - a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) - b: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - - def __init__( - self, a: ArrayVar, b: StringVar | str, _var_data: VarData | None = None - ): - """Initialize the array join operation var. + @classmethod + def create( + cls, + string: StringVar | str, + index: NumberVar | int, + _var_data: VarData | None = None, + ) -> StringItemOperation: + """Create a var from a string value. Args: - a: The array. - b: The separator. + string: The string. + index: The index. _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. """ - super(ArrayJoinOperation, self).__init__( + return cls( _var_name="", _var_type=str, _var_data=ImmutableVarData.merge(_var_data), + _string=( + string if isinstance(string, Var) else LiteralStringVar.create(string) + ), + _index=( + index if isinstance(index, Var) else LiteralNumberVar.create(index) + ), ) - object.__setattr__(self, "a", a) - object.__setattr__( - self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) - ) + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayJoinOperation(StringVar): + """Base class for immutable string vars that are the result of an array join operation.""" + + _array: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) + _sep: StringVar = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -500,7 +635,7 @@ class ArrayJoinOperation(StringVar): Returns: The name of the var. """ - return f"{str(self.a)}.join({str(self.b)})" + return f"{str(self._array)}.join({str(self._sep)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -523,7 +658,9 @@ class ArrayJoinOperation(StringVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + self._array._get_all_var_data(), + self._sep._get_all_var_data(), + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -535,7 +672,32 @@ class ArrayJoinOperation(StringVar): Returns: int: The hash value of the object. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._array, self._sep)) + + @classmethod + def create( + cls, + array: ArrayVar, + sep: StringVar | str = "", + _var_data: VarData | None = None, + ) -> ArrayJoinOperation: + """Create a var from a string value. + + Args: + array: The array. + sep: The separator. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + _array=array, + _sep=sep if isinstance(sep, Var) else LiteralStringVar.create(sep), + ) # Compile regex for finding reflex var tags. @@ -555,24 +717,6 @@ class LiteralStringVar(LiteralVar, StringVar): _var_value: str = dataclasses.field(default="") - def __init__( - self, - _var_value: str, - _var_data: VarData | None = None, - ): - """Initialize the string var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralStringVar, self).__init__( - _var_name=f'"{_var_value}"', - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "_var_value", _var_value) - @classmethod def create( cls, @@ -606,8 +750,8 @@ class LiteralStringVar(LiteralVar, StringVar): # Find all tags while m := _decode_var_pattern.search(value): start, end = m.span() - if start > 0: - strings_and_vals.append(value[:start]) + + strings_and_vals.append(value[:start]) serialized_data = m.group(1) @@ -645,14 +789,18 @@ class LiteralStringVar(LiteralVar, StringVar): offset += end - start - if value: - strings_and_vals.append(value) + strings_and_vals.append(value) - return ConcatVarOperation(*strings_and_vals, _var_data=_var_data) + return ConcatVarOperation.create( + *filter(lambda s: isinstance(s, Var) or s, strings_and_vals), + _var_data=_var_data, + ) return LiteralStringVar( - value, - _var_data=_var_data, + _var_name=json.dumps(value), + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + _var_value=value, ) def __hash__(self) -> int: @@ -680,20 +828,7 @@ class LiteralStringVar(LiteralVar, StringVar): class ConcatVarOperation(StringVar): """Representing a concatenation of literal string vars.""" - _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple) - - def __init__(self, *value: Var | str, _var_data: VarData | None = None): - """Initialize the operation of concatenating literal string vars. - - Args: - value: The values to concatenate. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ConcatVarOperation, self).__init__( - _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str - ) - object.__setattr__(self, "_var_value", value) - object.__delattr__(self, "_var_name") + _var_value: Tuple[Var, ...] = dataclasses.field(default_factory=tuple) def __getattr__(self, name): """Get an attribute of the var. @@ -715,16 +850,7 @@ class ConcatVarOperation(StringVar): Returns: The name of the var. """ - return ( - "(" - + "+".join( - [ - str(element) if isinstance(element, Var) else f'"{element}"' - for element in self._var_value - ] - ) - + ")" - ) + return "(" + "+".join([str(element) for element in self._var_value]) + ")" @cached_property def _cached_get_all_var_data(self) -> ImmutableVarData | None: @@ -752,7 +878,7 @@ class ConcatVarOperation(StringVar): def __post_init__(self): """Post-initialize the var.""" - pass + object.__delattr__(self, "_var_name") def __hash__(self) -> int: """Get the hash of the var. @@ -762,6 +888,28 @@ class ConcatVarOperation(StringVar): """ return hash((self.__class__.__name__, *self._var_value)) + @classmethod + def create( + cls, + *value: Var | str, + _var_data: VarData | None = None, + ) -> ConcatVarOperation: + """Create a var from a string value. + + Args: + value: The values to concatenate. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + _var_value=tuple(map(LiteralVar.create, value)), + ) + ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set]) @@ -785,7 +933,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The joined elements. """ - return ArrayJoinOperation(self, sep) + return ArrayJoinOperation.create(self, sep) def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]: """Reverse the array. @@ -793,7 +941,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The reversed array. """ - return ArrayReverseOperation(self) + return ArrayReverseOperation.create(self) def __add__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> ArrayConcatOperation: """Concatenate two arrays. @@ -804,7 +952,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: ArrayConcatOperation: The concatenation of the two arrays. """ - return ArrayConcatOperation(self, other) + return ArrayConcatOperation.create(self, other) @overload def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ... @@ -908,8 +1056,8 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): The array slice operation. """ if isinstance(i, slice): - return ArraySliceOperation(self, i) - return ArrayItemOperation(self, i).guess_type() + return ArraySliceOperation.create(self, i) + return ArrayItemOperation.create(self, i).guess_type() def length(self) -> NumberVar: """Get the length of the array. @@ -917,7 +1065,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The length of the array. """ - return ArrayLengthOperation(self) + return ArrayLengthOperation.create(self) @overload @classmethod @@ -957,7 +1105,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): start = first_endpoint end = second_endpoint - return RangeOperation(start, end, step or 1) + return RangeOperation.create(start, end, step or 1) def contains(self, other: Any) -> BooleanVar: """Check if the array contains an element. @@ -968,7 +1116,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The array contains operation. """ - return ArrayContainsOperation(self, other) + return ArrayContainsOperation.create(self, other) def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: """Multiply the sequence by a number or integer. @@ -979,7 +1127,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: ArrayVar[ARRAY_VAR_TYPE]: The result of multiplying the sequence by the given number or integer. """ - return ArrayRepeatOperation(self, other) + return ArrayRepeatOperation.create(self, other) def __rmul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: """Multiply the sequence by a number or integer. @@ -990,7 +1138,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: ArrayVar[ARRAY_VAR_TYPE]: The result of multiplying the sequence by the given number or integer. """ - return ArrayRepeatOperation(self, other) + return ArrayRepeatOperation.create(self, other) LIST_ELEMENT = TypeVar("LIST_ELEMENT") @@ -1014,27 +1162,6 @@ class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...] ] = dataclasses.field(default_factory=list) - def __init__( - self: LiteralArrayVar[ARRAY_VAR_TYPE], - _var_value: ARRAY_VAR_TYPE, - _var_type: type[ARRAY_VAR_TYPE] | None = None, - _var_data: VarData | None = None, - ): - """Initialize the array var. - - Args: - _var_value: The value of the var. - _var_type: The type of the var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(LiteralArrayVar, self).__init__( - _var_name="", - _var_data=ImmutableVarData.merge(_var_data), - _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type), - ) - object.__setattr__(self, "_var_value", _var_value) - object.__delattr__(self, "_var_name") - def __getattr__(self, name): """Get an attribute of the var. @@ -1109,6 +1236,33 @@ class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): + "]" ) + def __post_init__(self): + """Post-initialize the var.""" + object.__delattr__(self, "_var_name") + + @classmethod + def create( + cls, + value: ARRAY_VAR_TYPE, + _var_type: Type[ARRAY_VAR_TYPE] | None = None, + _var_data: VarData | None = None, + ) -> LiteralArrayVar[ARRAY_VAR_TYPE]: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=figure_out_type(value) if _var_type is None else _var_type, + _var_data=ImmutableVarData.merge(_var_data), + _var_value=value, + ) + @dataclasses.dataclass( eq=False, @@ -1118,34 +1272,15 @@ class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): class StringSplitOperation(ArrayVar): """Base class for immutable array vars that are the result of a string split operation.""" - a: StringVar = dataclasses.field( + _string: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - b: StringVar = dataclasses.field( + _sep: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - def __init__( - self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None - ): - """Initialize the string split operation var. - - Args: - a: The string. - b: The separator. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringSplitOperation, self).__init__( - _var_name="", - _var_type=List[str], - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) - object.__setattr__( - self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1155,7 +1290,7 @@ class StringSplitOperation(ArrayVar): Returns: The name of the var. """ - return f"{str(self.a)}.split({str(self.b)})" + return f"{str(self._string)}.split({str(self._sep)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1178,7 +1313,9 @@ class StringSplitOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + self._string._get_all_var_data(), + self._sep._get_all_var_data(), + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -1190,7 +1327,34 @@ class StringSplitOperation(ArrayVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._string, self._sep)) + + @classmethod + def create( + cls, + string: StringVar | str, + sep: StringVar | str, + _var_data: VarData | None = None, + ) -> StringSplitOperation: + """Create a var from a string value. + + Args: + string: The string. + sep: The separator. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=List[str], + _var_data=ImmutableVarData.merge(_var_data), + _string=( + string if isinstance(string, Var) else LiteralStringVar.create(string) + ), + _sep=(sep if isinstance(sep, Var) else LiteralStringVar.create(sep)), + ) @dataclasses.dataclass( @@ -1201,21 +1365,12 @@ class StringSplitOperation(ArrayVar): class ArrayToArrayOperation(ArrayVar): """Base class for immutable array vars that are the result of an array to array operation.""" - a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + _value: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) - def __init__(self, a: ArrayVar, _var_data: VarData | None = None): - """Initialize the array to array operation var. - - Args: - a: The string. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArrayToArrayOperation, self).__init__( - _var_name="", - _var_type=a._var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1250,7 +1405,7 @@ class ArrayToArrayOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._value._get_all_var_data() if isinstance(self._value, Var) else None, self._var_data, ) @@ -1263,7 +1418,29 @@ class ArrayToArrayOperation(ArrayVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a)) + return hash((self.__class__.__name__, self._value)) + + @classmethod + def create( + cls, + value: ArrayVar, + _var_data: VarData | None = None, + ) -> ArrayToArrayOperation: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=value._var_type, + _var_data=ImmutableVarData.merge(_var_data), + _value=value, + ) @dataclasses.dataclass( @@ -1274,24 +1451,13 @@ class ArrayToArrayOperation(ArrayVar): class ArraySliceOperation(ArrayVar): """Base class for immutable string vars that are the result of a string slice operation.""" - a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + _array: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) - def __init__(self, a: ArrayVar, _slice: slice, _var_data: VarData | None = None): - """Initialize the string slice operation var. - - Args: - a: The string. - _slice: The slice. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArraySliceOperation, self).__init__( - _var_name="", - _var_type=a._var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) - object.__setattr__(self, "_slice", _slice) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1317,29 +1483,29 @@ class ArraySliceOperation(ArrayVar): else ImmutableVar.create_safe("undefined") ) if step is None: - return ( - f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})" - ) + return f"{str(self._array)}.slice({str(normalized_start)}, {str(normalized_end)})" if not isinstance(step, Var): if step < 0: actual_start = end + 1 if end is not None else 0 - actual_end = start + 1 if start is not None else self.a.length() + actual_end = start + 1 if start is not None else self._array.length() return str( - ArraySliceOperation( - ArrayReverseOperation( - ArraySliceOperation(self.a, slice(actual_start, actual_end)) + ArraySliceOperation.create( + ArrayReverseOperation.create( + ArraySliceOperation.create( + self._array, slice(actual_start, actual_end) + ) ), slice(None, None, -step), ) ) if step == 0: raise ValueError("slice step cannot be zero") - return f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0)" + return f"{str(self._array)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0)" actual_start_reverse = end + 1 if end is not None else 0 - actual_end_reverse = start + 1 if start is not None else self.a.length() + actual_end_reverse = start + 1 if start is not None else self._array.length() - return f"{str(self.step)} > 0 ? {str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0) : {str(self.a)}.slice({str(actual_start_reverse)}, {str(actual_end_reverse)}).reverse().filter((_, i) => i % {str(-step)} === 0)" + return f"{str(self.step)} > 0 ? {str(self._array)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0) : {str(self._array)}.slice({str(actual_start_reverse)}, {str(actual_end_reverse)}).reverse().filter((_, i) => i % {str(-step)} === 0)" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1362,7 +1528,7 @@ class ArraySliceOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), + self._array._get_all_var_data(), *[ slice_value._get_all_var_data() for slice_value in ( @@ -1384,7 +1550,32 @@ class ArraySliceOperation(ArrayVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a, self._slice)) + return hash((self.__class__.__name__, self._array, self._slice)) + + @classmethod + def create( + cls, + array: ArrayVar, + slice: slice, + _var_data: VarData | None = None, + ) -> ArraySliceOperation: + """Create a var from a string value. + + Args: + array: The array. + slice: The slice. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=array._var_type, + _var_data=ImmutableVarData.merge(_var_data), + _array=array, + _slice=slice, + ) class ArrayReverseOperation(ArrayToArrayOperation): @@ -1397,7 +1588,7 @@ class ArrayReverseOperation(ArrayToArrayOperation): Returns: The name of the var. """ - return f"{str(self.a)}.slice().reverse()" + return f"{str(self._value)}.slice().reverse()" @dataclasses.dataclass( @@ -1408,23 +1599,12 @@ class ArrayReverseOperation(ArrayToArrayOperation): class ArrayToNumberOperation(NumberVar): """Base class for immutable number vars that are the result of an array to number operation.""" - a: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar([]), + _array: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]), ) - def __init__(self, a: ArrayVar, _var_data: VarData | None = None): - """Initialize the string to number operation var. - - Args: - a: The array. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArrayToNumberOperation, self).__init__( - _var_name="", - _var_type=int, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a)) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1458,7 +1638,7 @@ class ArrayToNumberOperation(NumberVar): Returns: The VarData of the components and all of its children. """ - return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + return ImmutableVarData.merge(self._array._get_all_var_data(), self._var_data) def _get_all_var_data(self) -> ImmutableVarData | None: return self._cached_get_all_var_data @@ -1469,7 +1649,29 @@ class ArrayToNumberOperation(NumberVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a)) + return hash((self.__class__.__name__, self._array)) + + @classmethod + def create( + cls, + array: ArrayVar, + _var_data: VarData | None = None, + ) -> ArrayToNumberOperation: + """Create a var from a string value. + + Args: + array: The array. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=int, + _var_data=ImmutableVarData.merge(_var_data), + _array=array, + ) class ArrayLengthOperation(ArrayToNumberOperation): @@ -1482,7 +1684,7 @@ class ArrayLengthOperation(ArrayToNumberOperation): Returns: The name of the var. """ - return f"{str(self.a)}.length" + return f"{str(self._array)}.length" def is_tuple_type(t: GenericType) -> bool: @@ -1507,38 +1709,13 @@ def is_tuple_type(t: GenericType) -> bool: class ArrayItemOperation(ImmutableVar): """Base class for immutable array vars that are the result of an array item operation.""" - a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) - i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + _array: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) + _index: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar.create(0)) - def __init__( - self, - a: ArrayVar, - i: NumberVar | int, - _var_data: VarData | None = None, - ): - """Initialize the array item operation var. - - Args: - a: The array. - i: The index. - _var_data: Additional hooks and imports associated with the Var. - """ - args = typing.get_args(a._var_type) - if args and isinstance(i, int) and is_tuple_type(a._var_type): - element_type = args[i % len(args)] - else: - element_type = unionize(*args) - super(ArrayItemOperation, self).__init__( - _var_name="", - _var_type=element_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a)) - object.__setattr__( - self, - "i", - i if isinstance(i, Var) else LiteralNumberVar(i), - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1548,7 +1725,7 @@ class ArrayItemOperation(ImmutableVar): Returns: The name of the var. """ - return f"{str(self.a)}.at({str(self.i)})" + return f"{str(self._array)}.at({str(self._index)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1571,7 +1748,9 @@ class ArrayItemOperation(ImmutableVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data + self._array._get_all_var_data(), + self._index._get_all_var_data(), + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -1583,7 +1762,39 @@ class ArrayItemOperation(ImmutableVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a, self.i)) + return hash((self.__class__.__name__, self._array, self._index)) + + @classmethod + def create( + cls, + array: ArrayVar, + index: NumberVar | int, + _var_type: GenericType | None = None, + _var_data: VarData | None = None, + ) -> ArrayItemOperation: + """Create a var from a string value. + + Args: + array: The array. + index: The index. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + args = typing.get_args(array._var_type) + if args and isinstance(index, int) and is_tuple_type(array._var_type): + element_type = args[index % len(args)] + else: + element_type = unionize(*args) + + return cls( + _var_name="", + _var_type=element_type if _var_type is None else _var_type, + _var_data=ImmutableVarData.merge(_var_data), + _array=array, + _index=index if isinstance(index, Var) else LiteralNumberVar.create(index), + ) @dataclasses.dataclass( @@ -1594,45 +1805,12 @@ class ArrayItemOperation(ImmutableVar): class RangeOperation(ArrayVar): """Base class for immutable array vars that are the result of a range operation.""" - start: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) - end: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) - step: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(1)) + _start: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar.create(0)) + _stop: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar.create(0)) + _step: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar.create(1)) - def __init__( - self, - start: NumberVar | int, - end: NumberVar | int, - step: NumberVar | int, - _var_data: VarData | None = None, - ): - """Initialize the range operation var. - - Args: - start: The start of the range. - end: The end of the range. - step: The step of the range. - _var_data: Additional hooks and imports associated with the Var. - """ - super(RangeOperation, self).__init__( - _var_name="", - _var_type=List[int], - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, - "start", - start if isinstance(start, Var) else LiteralNumberVar(start), - ) - object.__setattr__( - self, - "end", - end if isinstance(end, Var) else LiteralNumberVar(end), - ) - object.__setattr__( - self, - "step", - step if isinstance(step, Var) else LiteralNumberVar(step), - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1642,7 +1820,7 @@ class RangeOperation(ArrayVar): Returns: The name of the var. """ - start, end, step = self.start, self.end, self.step + start, end, step = self._start, self._stop, self._step return f"Array.from({{ length: ({str(end)} - {str(start)}) / {str(step)} }}, (_, i) => {str(start)} + i * {str(step)})" def __getattr__(self, name: str) -> Any: @@ -1666,9 +1844,9 @@ class RangeOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.start._get_all_var_data(), - self.end._get_all_var_data(), - self.step._get_all_var_data(), + self._start._get_all_var_data(), + self._stop._get_all_var_data(), + self._step._get_all_var_data(), self._var_data, ) @@ -1681,7 +1859,35 @@ class RangeOperation(ArrayVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.start, self.end, self.step)) + return hash((self.__class__.__name__, self._start, self._stop, self._step)) + + @classmethod + def create( + cls, + start: NumberVar | int, + stop: NumberVar | int, + step: NumberVar | int, + _var_data: VarData | None = None, + ) -> RangeOperation: + """Create a var from a string value. + + Args: + start: The start of the range. + stop: The end of the range. + step: The step of the range. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=List[int], + _var_data=ImmutableVarData.merge(_var_data), + _start=start if isinstance(start, Var) else LiteralNumberVar.create(start), + _stop=stop if isinstance(stop, Var) else LiteralNumberVar.create(stop), + _step=step if isinstance(step, Var) else LiteralNumberVar.create(step), + ) @dataclasses.dataclass( @@ -1692,24 +1898,13 @@ class RangeOperation(ArrayVar): class ArrayContainsOperation(BooleanVar): """Base class for immutable boolean vars that are the result of an array contains operation.""" - a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) - b: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + _haystack: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) + _needle: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - def __init__(self, a: ArrayVar, b: Any | Var, _var_data: VarData | None = None): - """Initialize the array contains operation var. - - Args: - a: The array. - b: The element to check for. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArrayContainsOperation, self).__init__( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) - object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b)) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1719,7 +1914,7 @@ class ArrayContainsOperation(BooleanVar): Returns: The name of the var. """ - return f"{str(self.a)}.includes({str(self.b)})" + return f"{str(self._haystack)}.includes({str(self._needle)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1742,7 +1937,9 @@ class ArrayContainsOperation(BooleanVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + self._haystack._get_all_var_data(), + self._needle._get_all_var_data(), + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -1754,7 +1951,32 @@ class ArrayContainsOperation(BooleanVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._haystack, self._needle)) + + @classmethod + def create( + cls, + haystack: ArrayVar, + needle: Any | Var, + _var_data: VarData | None = None, + ) -> ArrayContainsOperation: + """Create a var from a string value. + + Args: + haystack: The array. + needle: The element to check for. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + _haystack=haystack, + _needle=needle if isinstance(needle, Var) else LiteralVar.create(needle), + ) @dataclasses.dataclass( @@ -1765,27 +1987,12 @@ class ArrayContainsOperation(BooleanVar): class ToStringOperation(StringVar): """Base class for immutable string vars that are the result of a to string operation.""" - original_var: Var = dataclasses.field( + _original_var: Var = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - def __init__(self, original_var: Var, _var_data: VarData | None = None): - """Initialize the to string operation var. - - Args: - original_var: The original var. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ToStringOperation, self).__init__( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, - "original_var", - original_var, - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1795,7 +2002,7 @@ class ToStringOperation(StringVar): Returns: The name of the var. """ - return str(self.original_var) + return str(self._original_var) def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1818,7 +2025,7 @@ class ToStringOperation(StringVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.original_var._get_all_var_data(), self._var_data + self._original_var._get_all_var_data(), self._var_data ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -1830,7 +2037,29 @@ class ToStringOperation(StringVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.original_var)) + return hash((self.__class__.__name__, self._original_var)) + + @classmethod + def create( + cls, + original_var: Var, + _var_data: VarData | None = None, + ) -> ToStringOperation: + """Create a var from a string value. + + Args: + original_var: The original var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + _original_var=original_var, + ) @dataclasses.dataclass( @@ -1841,31 +2070,12 @@ class ToStringOperation(StringVar): class ToArrayOperation(ArrayVar): """Base class for immutable array vars that are the result of a to array operation.""" - original_var: Var = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + _original_var: Var = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) - def __init__( - self, - original_var: Var, - _var_type: type[list] | type[set] | type[tuple] = list, - _var_data: VarData | None = None, - ): - """Initialize the to array operation var. - - Args: - original_var: The original var. - _var_type: The type of the array. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ToArrayOperation, self).__init__( - _var_name="", - _var_type=_var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, - "original_var", - original_var, - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1875,7 +2085,7 @@ class ToArrayOperation(ArrayVar): Returns: The name of the var. """ - return str(self.original_var) + return str(self._original_var) def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1898,7 +2108,7 @@ class ToArrayOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.original_var._get_all_var_data(), self._var_data + self._original_var._get_all_var_data(), self._var_data ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -1910,7 +2120,30 @@ class ToArrayOperation(ArrayVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.original_var)) + return hash((self.__class__.__name__, self._original_var)) + + @classmethod + def create( + cls, + original_var: Var, + _var_type: type[list] | type[set] | type[tuple] | None = None, + _var_data: VarData | None = None, + ) -> ToArrayOperation: + """Create a var from a string value. + + Args: + original_var: The original var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=list if _var_type is None else _var_type, + _var_data=ImmutableVarData.merge(_var_data), + _original_var=original_var, + ) @dataclasses.dataclass( @@ -1921,30 +2154,13 @@ class ToArrayOperation(ArrayVar): class ArrayRepeatOperation(ArrayVar): """Base class for immutable array vars that are the result of an array repeat operation.""" - a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) - n: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + _array: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) + _count: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar.create(0)) - def __init__( - self, a: ArrayVar, n: NumberVar | int, _var_data: VarData | None = None - ): - """Initialize the array repeat operation var. - - Args: - a: The array. - n: The number of times to repeat the array. - _var_data: Additional hooks and imports associated with the Var. - """ - super(ArrayRepeatOperation, self).__init__( - _var_name="", - _var_type=a._var_type, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) - object.__setattr__( - self, - "n", - n if isinstance(n, Var) else LiteralNumberVar(n), - ) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -1954,7 +2170,7 @@ class ArrayRepeatOperation(ArrayVar): Returns: The name of the var. """ - return f"Array.from({{ length: {str(self.n)} }}).flatMap(() => {str(self.a)})" + return f"Array.from({{ length: {str(self._count)} }}).flatMap(() => {str(self._array)})" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -1977,7 +2193,9 @@ class ArrayRepeatOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.n._get_all_var_data(), self._var_data + self._array._get_all_var_data(), + self._count._get_all_var_data(), + self._var_data, ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -1989,7 +2207,32 @@ class ArrayRepeatOperation(ArrayVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a, self.n)) + return hash((self.__class__.__name__, self._array, self._count)) + + @classmethod + def create( + cls, + array: ArrayVar, + count: NumberVar | int, + _var_data: VarData | None = None, + ) -> ArrayRepeatOperation: + """Create a var from a string value. + + Args: + array: The array. + count: The number of times to repeat the array. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return cls( + _var_name="", + _var_type=array._var_type, + _var_data=ImmutableVarData.merge(_var_data), + _array=array, + _count=count if isinstance(count, Var) else LiteralNumberVar.create(count), + ) @dataclasses.dataclass( @@ -2000,25 +2243,15 @@ class ArrayRepeatOperation(ArrayVar): class ArrayConcatOperation(ArrayVar): """Base class for immutable array vars that are the result of an array concat operation.""" - a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) - b: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + _lhs: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) + _rhs: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar.create([]) + ) - def __init__(self, a: ArrayVar, b: ArrayVar, _var_data: VarData | None = None): - """Initialize the array concat operation var. - - Args: - a: The first array. - b: The second array. - _var_data: Additional hooks and imports associated with the Var. - """ - # TODO: Figure out how to merge the types of a and b - super(ArrayConcatOperation, self).__init__( - _var_name="", - _var_type=Union[a._var_type, b._var_type], - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__(self, "a", a) - object.__setattr__(self, "b", b) + def __post_init__(self): + """Post-initialize the var.""" object.__delattr__(self, "_var_name") @cached_property @@ -2028,7 +2261,7 @@ class ArrayConcatOperation(ArrayVar): Returns: The name of the var. """ - return f"[...{str(self.a)}, ...{str(self.b)}]" + return f"[...{str(self._lhs)}, ...{str(self._rhs)}]" def __getattr__(self, name: str) -> Any: """Get an attribute of the var. @@ -2051,7 +2284,7 @@ class ArrayConcatOperation(ArrayVar): The VarData of the components and all of its children. """ return ImmutableVarData.merge( - self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + self._lhs._get_all_var_data(), self._rhs._get_all_var_data(), self._var_data ) def _get_all_var_data(self) -> ImmutableVarData | None: @@ -2063,4 +2296,30 @@ class ArrayConcatOperation(ArrayVar): Returns: The hash of the var. """ - return hash((self.__class__.__name__, self.a, self.b)) + return hash((self.__class__.__name__, self._lhs, self._rhs)) + + @classmethod + def create( + cls, + lhs: ArrayVar, + rhs: ArrayVar, + _var_data: VarData | None = None, + ) -> ArrayConcatOperation: + """Create a var from a string value. + + Args: + lhs: The left-hand side array. + rhs: The right-hand side array. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + # TODO: Figure out how to merge the types of a and b + return cls( + _var_name="", + _var_type=Union[lhs._var_type, rhs._var_type], + _var_data=ImmutableVarData.merge(_var_data), + _lhs=lhs, + _rhs=rhs, + ) diff --git a/reflex/state.py b/reflex/state.py index 2d90902d4..1ec4eb6e9 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -32,7 +32,7 @@ import dill from sqlalchemy.orm import DeclarativeBase from reflex.config import get_config -from reflex.ivars.base import ImmutableVar +from reflex.ivars.base import ImmutableComputedVar, ImmutableVar, immutable_computed_var try: import pydantic.v1 as pydantic @@ -60,7 +60,6 @@ from reflex.vars import ( ComputedVar, ImmutableVarData, Var, - computed_var, ) if TYPE_CHECKING: @@ -68,7 +67,7 @@ if TYPE_CHECKING: Delta = Dict[str, Any] -var = computed_var +var = immutable_computed_var # If the state is this large, it's considered a performance issue. @@ -307,7 +306,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): base_vars: ClassVar[Dict[str, ImmutableVar]] = {} # The computed vars of the class. - computed_vars: ClassVar[Dict[str, ComputedVar]] = {} + computed_vars: ClassVar[Dict[str, Union[ComputedVar, ImmutableComputedVar]]] = {} # Vars inherited by the parent state. inherited_vars: ClassVar[Dict[str, Var]] = {} @@ -420,7 +419,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): return f"{self.__class__.__name__}({self.dict()})" @classmethod - def _get_computed_vars(cls) -> list[ComputedVar]: + def _get_computed_vars(cls) -> list[Union[ComputedVar, ImmutableComputedVar]]: """Helper function to get all computed vars of a instance. Returns: @@ -430,7 +429,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): v for mixin in cls._mixins() + [cls] for v in mixin.__dict__.values() - if isinstance(v, ComputedVar) + if isinstance(v, (ComputedVar, ImmutableComputedVar)) ] @classmethod @@ -534,7 +533,10 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for f in cls.get_fields().values() if f.name not in cls.get_skip_vars() } - cls.computed_vars = {v._var_name: v._var_set_state(cls) for v in computed_vars} + cls.computed_vars = { + v._var_name: v._replace(merge_var_data=ImmutableVarData.from_state(cls)) + for v in computed_vars + } cls.vars = { **cls.inherited_vars, **cls.base_vars, @@ -555,12 +557,12 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): for mixin in cls._mixins(): for name, value in mixin.__dict__.items(): - if isinstance(value, ComputedVar): + if isinstance(value, (ComputedVar, ImmutableComputedVar)): fget = cls._copy_fn(value.fget) - newcv = value._replace(fget=fget) + newcv = value._replace( + fget=fget, _var_data=ImmutableVarData.from_state(cls) + ) # cleanup refs to mixin cls in var_data - newcv._var_data = None - newcv._var_set_state(cls) setattr(cls, name, newcv) cls.computed_vars[newcv._var_name] = newcv cls.vars[newcv._var_name] = newcv @@ -897,8 +899,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) # create the variable based on name and type - var = ImmutableVar(_var_name=name, _var_type=type_).guess_type() - var._var_set_state(cls) + var = ImmutableVar( + _var_name=name, _var_type=type_, _var_data=ImmutableVarData.from_state(cls) + ).guess_type() # add the pydantic field dynamically (must be done before _init_var) cls.add_field(var, default_value) @@ -983,7 +986,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): and not types.is_optional(prop._var_type) ): # Ensure frontend uses null coalescing when accessing. - prop._var_type = Optional[prop._var_type] + object.__setattr__(prop, "_var_type", Optional[prop._var_type]) @staticmethod def _get_base_functions() -> dict[str, FunctionType]: @@ -1783,7 +1786,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Include initial computed vars. prop_name: ( cv._initial_value - if isinstance(cv, ComputedVar) + if isinstance(cv, (ComputedVar, ImmutableComputedVar)) and not isinstance(cv._initial_value, types.Unset) else self.get_value(getattr(self, prop_name)) ) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 01b7cb712..6a016c5d6 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -9,8 +9,6 @@ import re from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union from reflex import constants -from reflex.ivars.base import ImmutableVar -from reflex.ivars.function import FunctionVar from reflex.utils import exceptions, types from reflex.vars import BaseVar, Var @@ -274,8 +272,10 @@ def format_f_string_prop(prop: BaseVar) -> str: Returns: The formatted string. """ + from reflex.ivars.base import VarData + s = prop._var_full_name - var_data = prop._var_data + var_data = VarData.merge(prop._get_all_var_data()) interps = var_data.interpolations if var_data else [] parts: List[str] = [] @@ -423,6 +423,7 @@ def format_prop( # import here to avoid circular import. from reflex.event import EventChain from reflex.utils import serializers + from reflex.vars import VarData try: # Handle var props. @@ -430,7 +431,8 @@ def format_prop( if not prop._var_is_local or prop._var_is_string: return str(prop) if isinstance(prop, BaseVar) and types._issubclass(prop._var_type, str): - if prop._var_data and prop._var_data.interpolations: + var_data = VarData.merge(prop._get_all_var_data()) + if var_data and var_data.interpolations: return format_f_string_prop(prop) return format_string(prop._var_full_name) prop = prop._var_full_name @@ -485,17 +487,38 @@ def format_props(*single_props, **key_value_props) -> list[str]: The formatted props list. """ # Format all the props. - from reflex.ivars.base import ImmutableVar + from reflex.ivars.base import ImmutableVar, LiteralVar + + # print( + # *[ + # f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}" + # for name, prop in sorted(key_value_props.items()) + # if prop is not None + # ], + # sep="\n", + # ) + + # if single_props: + # print("single_props", single_props) return [ ( - f"{name}={{{format_prop(prop)}}}" - if isinstance(prop, ImmutableVar) - else f"{name}={format_prop(prop)}" + f"{name}={format_prop(prop)}" + if isinstance(prop, Var) and not isinstance(prop, ImmutableVar) + else ( + f"{name}={{{format_prop(prop if isinstance(prop, Var) else LiteralVar.create(prop))}}}" + ) ) for name, prop in sorted(key_value_props.items()) if prop is not None - ] + [str(prop) for prop in single_props] + ] + [ + ( + str(prop) + if isinstance(prop, Var) and not isinstance(prop, ImmutableVar) + else f"{{{str(LiteralVar.create(prop))}}}" + ) + for prop in single_props + ] def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: @@ -510,13 +533,13 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: # Get the class that defines the event handler. parts = handler.fn.__qualname__.split(".") - # If there's no enclosing class, just return the function name. - if len(parts) == 1: - return ("", parts[-1]) - # Get the state full name state_full_name = handler.state_full_name + # If there's no enclosing class, just return the function name. + if not state_full_name: + return ("", parts[-1]) + # Get the function name name = parts[-1] @@ -655,6 +678,7 @@ def format_queue_events( call_event_fn, call_event_handler, ) + from reflex.ivars.base import FunctionVar, ImmutableVar if not events: return ImmutableVar("(() => null)").to(FunctionVar, EventChain) @@ -944,6 +968,8 @@ def format_data_editor_cell(cell: Any): Returns: The formatted cell. """ + from reflex.ivars.base import ImmutableVar + return { "kind": ImmutableVar.create("GridCellKind.Text"), "data": cell, diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 3bb5eae35..53528e25e 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -491,10 +491,18 @@ def is_backend_base_variable(name: str, cls: Type) -> bool: return False if callable(value): return False + from reflex.ivars.base import ImmutableComputedVar from reflex.vars import ComputedVar if isinstance( - value, (types.FunctionType, property, cached_property, ComputedVar) + value, + ( + types.FunctionType, + property, + cached_property, + ComputedVar, + ImmutableComputedVar, + ), ): return False diff --git a/reflex/vars.py b/reflex/vars.py index 23f2fe6a7..72441a86a 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -892,6 +892,7 @@ class Var: Raises: VarTypeError: If the var is not indexable. """ + print(repr(self)) from reflex.utils import format # Indexing is only supported for strings, lists, tuples, dicts, and dataframes. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 69041d563..8c923f03e 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -154,6 +154,7 @@ class Var: def _var_set_state(self, state: Type[BaseState] | str) -> Any: ... def _get_all_var_data(self) -> VarData | ImmutableVarData: ... def json(self) -> str: ... + def _type(self) -> Var: ... @dataclass(eq=False) class BaseVar(Var):