use less boilerplate

This commit is contained in:
Khaleel Al-Adhami 2024-08-15 16:05:10 -07:00
parent 1603e0606d
commit 9e1a35f514
20 changed files with 218 additions and 1718 deletions

View File

@ -325,7 +325,6 @@ _MAPPING: dict = {
"utils.imports": ["ImportVar"], "utils.imports": ["ImportVar"],
"utils.serializers": ["serializer"], "utils.serializers": ["serializer"],
"vars": ["Var"], "vars": ["Var"],
"ivars.base": ["cached_var"],
} }
_SUBMODULES: set[str] = { _SUBMODULES: set[str] = {

View File

@ -175,7 +175,6 @@ from .event import stop_propagation as stop_propagation
from .event import upload_files as upload_files from .event import upload_files as upload_files
from .event import window_alert as window_alert from .event import window_alert as window_alert
from .experimental import _x as _x from .experimental import _x as _x
from .ivars.base import cached_var as cached_var
from .middleware import Middleware as Middleware from .middleware import Middleware as Middleware
from .middleware import middleware as middleware from .middleware import middleware as middleware
from .model import Model as Model from .model import Model as Model

View File

@ -442,11 +442,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
raise raise
except TypeError as e: except TypeError as e:
message = str(e) message = str(e)
if ( if "Var" in message:
"BaseVar" in message
or "ComputedVar" in message
or "ImmutableComputedVar" in message
):
raise VarOperationTypeError( raise VarOperationTypeError(
"You may be trying to use an invalid Python function on a state var. " "You may be trying to use an invalid Python function on a state var. "
"When referencing a var inside your render code, only limited var operations are supported. " "When referencing a var inside your render code, only limited var operations are supported. "

View File

@ -7,7 +7,6 @@ from typing import Any, Optional
from reflex.components.chakra import ChakraComponent, LiteralImageLoading from reflex.components.chakra import ChakraComponent, LiteralImageLoading
from reflex.components.component import Component from reflex.components.component import Component
from reflex.event import EventHandler from reflex.event import EventHandler
from reflex.ivars.base import LiteralVar
from reflex.vars import Var from reflex.vars import Var
@ -69,7 +68,4 @@ class Image(ChakraComponent):
Returns: Returns:
The Image component. The Image component.
""" """
src = props.get("src", None)
if src is not None and not isinstance(src, (Var)):
props["src"] = LiteralVar.create(value=src)
return super().create(*children, **props) return super().create(*children, **props)

View File

@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, Optional, Union, overload
from reflex.components.chakra import ChakraComponent from reflex.components.chakra import ChakraComponent
from reflex.components.chakra.navigation.link import Link from reflex.components.chakra.navigation.link import Link
from reflex.components.component import Component
from reflex.event import EventHandler, EventSpec from reflex.event import EventHandler, EventSpec
from reflex.style import Style from reflex.style import Style
from reflex.vars import Var from reflex.vars import Var
@ -239,7 +240,7 @@ class BreadcrumbLink(Link):
rel: Optional[Union[Var[str], str]] = None, rel: Optional[Union[Var[str], str]] = None,
href: Optional[Union[Var[str], str]] = None, href: Optional[Union[Var[str], str]] = None,
text: Optional[Union[Var[str], str]] = None, text: Optional[Union[Var[str], str]] = None,
as_: Optional[Union[Var[str], str]] = None, as_: Optional[Union[Var[Component], Component]] = None,
is_external: Optional[Union[Var[bool], bool]] = None, is_external: Optional[Union[Var[bool], bool]] = None,
style: Optional[Style] = None, style: Optional[Style] = None,
key: Optional[Any] = None, key: Optional[Any] = None,

View File

@ -25,7 +25,7 @@ class Link(ChakraComponent):
text: Var[str] text: Var[str]
# What the link renders to. # What the link renders to.
as_: Var[str] = ImmutableVar.create_safe("NextLink").to(str) as_: Var[Component] = ImmutableVar(_var_name="NextLink", _var_type=Component)
# If true, the link will open in new tab. # If true, the link will open in new tab.
is_external: Var[bool] is_external: Var[bool]

View File

@ -6,6 +6,7 @@
from typing import Any, Callable, Dict, Optional, Union, overload from typing import Any, Callable, Dict, Optional, Union, overload
from reflex.components.chakra import ChakraComponent from reflex.components.chakra import ChakraComponent
from reflex.components.component import Component
from reflex.components.next.link import NextLink from reflex.components.next.link import NextLink
from reflex.event import EventHandler, EventSpec from reflex.event import EventHandler, EventSpec
from reflex.style import Style from reflex.style import Style
@ -24,7 +25,7 @@ class Link(ChakraComponent):
rel: Optional[Union[Var[str], str]] = None, rel: Optional[Union[Var[str], str]] = None,
href: Optional[Union[Var[str], str]] = None, href: Optional[Union[Var[str], str]] = None,
text: Optional[Union[Var[str], str]] = None, text: Optional[Union[Var[str], str]] = None,
as_: Optional[Union[Var[str], str]] = None, as_: Optional[Union[Var[Component], Component]] = None,
is_external: Optional[Union[Var[bool], bool]] = None, is_external: Optional[Union[Var[bool], bool]] = None,
style: Optional[Style] = None, style: Optional[Style] = None,
key: Optional[Any] = None, key: Optional[Any] = None,

View File

@ -22,7 +22,7 @@ from reflex.ivars.base import ImmutableVar, LiteralVar
from reflex.ivars.function import FunctionStringVar from reflex.ivars.function import FunctionStringVar
from reflex.ivars.number import BooleanVar from reflex.ivars.number import BooleanVar
from reflex.ivars.sequence import LiteralArrayVar from reflex.ivars.sequence import LiteralArrayVar
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import ImportVar
from reflex.vars import ImmutableVarData, Var, VarData from reflex.vars import ImmutableVarData, Var, VarData
connect_error_var_data: VarData = VarData( # type: ignore connect_error_var_data: VarData = VarData( # type: ignore
@ -56,20 +56,9 @@ has_too_many_connection_errors: Var = ImmutableVar.create_safe(
).to(BooleanVar) ).to(BooleanVar)
class WebsocketTargetURL: class WebsocketTargetURL(ImmutableVar):
"""A component that renders the websocket target URL.""" """A component that renders the websocket target URL."""
def add_imports(self) -> ImportDict:
"""Add imports for the websocket target URL component.
Returns:
The import dict.
"""
return {
f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")],
"/env.json": [ImportVar(tag="env", is_default=True)],
}
@classmethod @classmethod
def create(cls) -> ImmutableVar: def create(cls) -> ImmutableVar:
"""Create a websocket target URL component. """Create a websocket target URL component.
@ -85,6 +74,7 @@ class WebsocketTargetURL:
f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")], f"/{Dirs.STATE_PATH}": [ImportVar(tag="getBackendURL")],
}, },
), ),
_var_type=WebsocketTargetURL,
) )

View File

@ -12,7 +12,7 @@ from reflex.components.sonner.toast import Toaster, ToastProps
from reflex.event import EventHandler, EventSpec from reflex.event import EventHandler, EventSpec
from reflex.ivars.base import ImmutableVar from reflex.ivars.base import ImmutableVar
from reflex.style import Style from reflex.style import Style
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import ImportVar
from reflex.vars import Var, VarData from reflex.vars import Var, VarData
connect_error_var_data: VarData connect_error_var_data: VarData
@ -22,8 +22,7 @@ connection_errors_count: Var
has_connection_errors: Var has_connection_errors: Var
has_too_many_connection_errors: Var has_too_many_connection_errors: Var
class WebsocketTargetURL: class WebsocketTargetURL(ImmutableVar):
def add_imports(self) -> ImportDict: ...
@classmethod @classmethod
def create(cls) -> ImmutableVar: ... # type: ignore def create(cls) -> ImmutableVar: ... # type: ignore

View File

@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
import re
from typing import Dict, Literal, Optional, Union from typing import Dict, Literal, Optional, Union
from typing_extensions import get_args
from reflex.components.component import Component from reflex.components.component import Component
from reflex.components.core.cond import color_mode_cond from reflex.components.core.cond import color_mode_cond
from reflex.components.lucide.icon import Icon from reflex.components.lucide.icon import Icon
@ -350,6 +351,20 @@ LiteralCodeLanguage = Literal[
] ]
def replace_quotes_with_camel_case(value: str) -> str:
"""Replaces quotes in the given string with camel case format.
Args:
value (str): The string to be processed.
Returns:
str: The processed string with quotes replaced by camel case.
"""
for theme in get_args(LiteralCodeBlockTheme):
value = value.replace(f'"{theme}"', format.to_camel_case(theme))
return value
class CodeBlock(Component): class CodeBlock(Component):
"""A code block.""" """A code block."""
@ -393,20 +408,13 @@ class CodeBlock(Component):
themeString = str(self.theme) themeString = str(self.theme)
themes = re.findall(r'"(.*?)"', themeString) selected_themes = []
if not themes:
themes = [themeString]
if "oneLight" in themeString: for possibleTheme in get_args(LiteralCodeBlockTheme):
themes.append("light") if possibleTheme in themeString:
if "oneDark" in themeString: selected_themes.append(possibleTheme)
themes.append("dark")
if "one-light" in themeString:
themes.append("light")
if "one-dark" in themeString:
themes.append("dark")
themes = sorted(set(themes)) selected_themes = sorted(set(selected_themes))
imports_.update( imports_.update(
{ {
@ -417,7 +425,7 @@ class CodeBlock(Component):
install=False, install=False,
) )
] ]
for theme in themes for theme in selected_themes
} }
) )
@ -523,12 +531,14 @@ class CodeBlock(Component):
def _render(self): def _render(self):
out = super()._render() out = super()._render()
predicate, qmark, value = self.theme._var_name.partition("?")
out.add_props( theme = self.theme._replace(
style=ImmutableVar.create( _var_name=replace_quotes_with_camel_case(str(self.theme))
format.to_camel_case(f"{predicate}{qmark}{value.replace('`', '')}"), )
out.add_props(style=theme).remove_props("theme", "code").add_props(
children=self.code
) )
).remove_props("theme", "code").add_props(children=self.code)
return out return out

View File

@ -341,6 +341,8 @@ LiteralCodeLanguage = Literal[
"zig", "zig",
] ]
def replace_quotes_with_camel_case(value: str) -> str: ...
class CodeBlock(Component): class CodeBlock(Component):
def add_imports(self) -> ImportDict: ... def add_imports(self) -> ImportDict: ...
@overload @overload

View File

@ -3,7 +3,6 @@
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from reflex.event import EventHandler from reflex.event import EventHandler
from reflex.ivars.base import LiteralVar
from reflex.utils import types from reflex.utils import types
from reflex.vars import Var from reflex.vars import Var
@ -103,8 +102,4 @@ class Image(NextComponent):
# mysteriously, following `sizes` prop is needed to avoid blury images. # mysteriously, following `sizes` prop is needed to avoid blury images.
props["sizes"] = "100vw" props["sizes"] = "100vw"
src = props.get("src", None)
if src is not None and not isinstance(src, (Var)):
props["src"] = LiteralVar.create(src)
return super().create(*children, **props) return super().create(*children, **props)

View File

@ -174,14 +174,13 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
"The _var_full_name_needs_state_prefix argument is not supported for ImmutableVar." "The _var_full_name_needs_state_prefix argument is not supported for ImmutableVar."
) )
field_values = dict( return dataclasses.replace(
_var_name=kwargs.pop("_var_name", self._var_name), self,
_var_type=kwargs.pop("_var_type", self._var_type),
_var_data=ImmutableVarData.merge( _var_data=ImmutableVarData.merge(
kwargs.get("_var_data", self._var_data), merge_var_data kwargs.get("_var_data", self._var_data), merge_var_data
), ),
**kwargs,
) )
return dataclasses.replace(self, **field_values)
@classmethod @classmethod
def create( def create(
@ -366,6 +365,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
fixed_output_type = get_origin(output) or output fixed_output_type = get_origin(output) or output
# If the first argument is a python type, we map it to the corresponding Var type.
if fixed_output_type is dict: if fixed_output_type is dict:
return self.to(ObjectVar, output) return self.to(ObjectVar, output)
if fixed_output_type in (list, tuple, set): if fixed_output_type in (list, tuple, set):
@ -409,17 +409,22 @@ class ImmutableVar(Var, Generic[VAR_TYPE]):
# ) # )
return ToFunctionOperation.create(self, var_type or Callable) return ToFunctionOperation.create(self, var_type or Callable)
if not issubclass(output, Var) and var_type is None: # If we can't determine the first argument, we just replace the _var_type.
if not issubclass(output, Var) or var_type is None:
return dataclasses.replace( return dataclasses.replace(
self, self,
_var_type=output, _var_type=output,
) )
# We couldn't determine the output type to be any other Var type, so we replace the _var_type.
if var_type is not None:
return dataclasses.replace( return dataclasses.replace(
self, self,
_var_type=var_type, _var_type=var_type,
) )
return self
def guess_type(self) -> ImmutableVar: def guess_type(self) -> ImmutableVar:
"""Guesses the type of the variable based on its `_var_type` attribute. """Guesses the type of the variable based on its `_var_type` attribute.
@ -1005,33 +1010,13 @@ def figure_out_type(value: Any) -> types.GenericType:
return type(value) return type(value)
@dataclasses.dataclass( class CachedVarOperation:
eq=False, """Base class for cached var operations to lower boilerplate code."""
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class AndOperation(ImmutableVar):
"""Class for the logical AND operation."""
# The first var.
_var1: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
# The second var.
_var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
def __post_init__(self): def __post_init__(self):
"""Post-initialize the AndOperation.""" """Post-initialize the CachedVarOperation."""
object.__delattr__(self, "_var_name") object.__delattr__(self, "_var_name")
@functools.cached_property
def _cached_var_name(self) -> str:
"""Get the cached var name.
Returns:
The cached var name.
"""
return f"({str(self._var1)} && {str(self._var2)})"
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var. """Get an attribute of the var.
@ -1043,19 +1028,13 @@ class AndOperation(ImmutableVar):
""" """
if name == "_var_name": if name == "_var_name":
return self._cached_var_name return self._cached_var_name
return getattr(super(type(self), self), name)
@functools.cached_property parent_classes = inspect.getmro(self.__class__)
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get the cached VarData.
Returns: print(repr(self), parent_classes, name)
The cached VarData.
""" return parent_classes[parent_classes.index(CachedVarOperation) + 1].__getattr__( # type: ignore
return ImmutableVarData.merge( self, name
self._var1._get_all_var_data(),
self._var2._get_all_var_data(),
self._var_data,
) )
def _get_all_var_data(self) -> ImmutableVarData | None: def _get_all_var_data(self) -> ImmutableVarData | None:
@ -1066,6 +1045,67 @@ class AndOperation(ImmutableVar):
""" """
return self._cached_get_all_var_data return self._cached_get_all_var_data
@functools.cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get the cached VarData.
Returns:
The cached VarData.
"""
return ImmutableVarData.merge(
*map(
lambda value: (
value._get_all_var_data() if isinstance(value, Var) else None
),
map(
lambda field: getattr(self, field.name),
dataclasses.fields(self), # type: ignore
),
),
self._var_data,
)
def __hash__(self) -> int:
"""Calculate the hash of the object.
Returns:
The hash of the object.
"""
return hash(
(
self.__class__.__name__,
*[
getattr(self, field.name)
for field in dataclasses.fields(self) # type: ignore
if field.name not in ["_var_name", "_var_data", "_var_type"]
],
)
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class AndOperation(CachedVarOperation, ImmutableVar):
"""Class for the logical AND operation."""
# The first var.
_var1: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
# The second var.
_var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
@functools.cached_property
def _cached_var_name(self) -> str:
"""Get the cached var name.
Returns:
The cached var name.
"""
return f"({str(self._var1)} && {str(self._var2)})"
def __hash__(self) -> int: def __hash__(self) -> int:
"""Calculates the hash value of the object. """Calculates the hash value of the object.
@ -1103,7 +1143,7 @@ class AndOperation(ImmutableVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class OrOperation(ImmutableVar): class OrOperation(CachedVarOperation, ImmutableVar):
"""Class for the logical OR operation.""" """Class for the logical OR operation."""
# The first var. # The first var.
@ -1112,10 +1152,6 @@ class OrOperation(ImmutableVar):
# The second var. # The second var.
_var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) _var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
def __post_init__(self):
"""Post-initialize the OrOperation."""
object.__delattr__(self, "_var_name")
@functools.cached_property @functools.cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""Get the cached var name. """Get the cached var name.
@ -1125,40 +1161,6 @@ class OrOperation(ImmutableVar):
""" """
return f"({str(self._var1)} || {str(self._var2)})" return f"({str(self._var1)} || {str(self._var2)})"
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute.
"""
if name == "_var_name":
return self._cached_var_name
return getattr(super(type(self), self), name)
@functools.cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get the cached VarData.
Returns:
The cached VarData.
"""
return ImmutableVarData.merge(
self._var1._get_all_var_data(),
self._var2._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __hash__(self) -> int: def __hash__(self) -> int:
"""Calculates the hash value for the object. """Calculates the hash value for the object.
@ -1378,12 +1380,6 @@ class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]):
backend=kwargs.pop("backend", self._backend), backend=kwargs.pop("backend", self._backend),
_var_name=kwargs.pop("_var_name", self._var_name), _var_name=kwargs.pop("_var_name", self._var_name),
_var_type=kwargs.pop("_var_type", self._var_type), _var_type=kwargs.pop("_var_type", self._var_type),
_var_is_local=kwargs.pop("_var_is_local", self._var_is_local),
_var_is_string=kwargs.pop("_var_is_string", self._var_is_string),
_var_full_name_needs_state_prefix=kwargs.pop(
"_var_full_name_needs_state_prefix",
self._var_full_name_needs_state_prefix,
),
_var_data=kwargs.pop( _var_data=kwargs.pop(
"_var_data", VarData.merge(self._var_data, merge_var_data) "_var_data", VarData.merge(self._var_data, merge_var_data)
), ),
@ -1676,7 +1672,6 @@ def immutable_computed_var(
auto_deps: bool = True, auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None, interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None, backend: bool | None = None,
_deprecated_cached_var: bool = False,
**kwargs, **kwargs,
) -> Callable[ ) -> Callable[
[Callable[[BASE_STATE], RETURN_TYPE]], ImmutableComputedVar[RETURN_TYPE] [Callable[[BASE_STATE], RETURN_TYPE]], ImmutableComputedVar[RETURN_TYPE]
@ -1692,7 +1687,6 @@ def immutable_computed_var(
auto_deps: bool = True, auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None, interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None, backend: bool | None = None,
_deprecated_cached_var: bool = False,
**kwargs, **kwargs,
) -> ImmutableComputedVar[RETURN_TYPE]: ... ) -> ImmutableComputedVar[RETURN_TYPE]: ...
@ -1705,7 +1699,6 @@ def immutable_computed_var(
auto_deps: bool = True, auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None, interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None, backend: bool | None = None,
_deprecated_cached_var: bool = False,
**kwargs, **kwargs,
) -> ( ) -> (
ImmutableComputedVar | Callable[[Callable[[BASE_STATE], Any]], ImmutableComputedVar] ImmutableComputedVar | Callable[[Callable[[BASE_STATE], Any]], ImmutableComputedVar]
@ -1720,7 +1713,6 @@ def immutable_computed_var(
auto_deps: Whether var dependencies should be auto-determined. auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated. interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var. backend: Whether the computed var is a backend var.
_deprecated_cached_var: Indicate usage of deprecated cached_var partial function.
**kwargs: additional attributes to set on the instance **kwargs: additional attributes to set on the instance
Returns: Returns:
@ -1730,14 +1722,6 @@ def immutable_computed_var(
ValueError: If caching is disabled and an update interval is set. ValueError: If caching is disabled and an update interval is set.
VarDependencyError: If user supplies dependencies without caching. VarDependencyError: If user supplies dependencies without caching.
""" """
if _deprecated_cached_var:
console.deprecate(
feature_name="cached_var",
reason=("Use @rx.var(cache=True) instead of @rx.cached_var."),
deprecation_version="0.5.6",
removal_version="0.6.0",
)
if cache is False and interval is not None: if cache is False and interval is not None:
raise ValueError("Cannot set update interval without caching.") raise ValueError("Cannot set update interval without caching.")
@ -1760,9 +1744,3 @@ def immutable_computed_var(
) )
return wrapper return wrapper
# Partial function of computed_var with cache=True
cached_var = functools.partial(
immutable_computed_var, cache=True, _deprecated_cached_var=True
)

View File

@ -10,7 +10,7 @@ from typing import Any, Callable, Optional, Tuple, Type, Union
from reflex.utils.types import GenericType from reflex.utils.types import GenericType
from reflex.vars import ImmutableVarData, Var, VarData from reflex.vars import ImmutableVarData, Var, VarData
from .base import ImmutableVar, LiteralVar from .base import CachedVarOperation, ImmutableVar, LiteralVar
class FunctionVar(ImmutableVar[Callable]): class FunctionVar(ImmutableVar[Callable]):
@ -73,25 +73,12 @@ class FunctionStringVar(FunctionVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class VarOperationCall(ImmutableVar): class VarOperationCall(CachedVarOperation, ImmutableVar):
"""Base class for immutable vars that are the result of a function call.""" """Base class for immutable vars that are the result of a function call."""
_func: Optional[FunctionVar] = dataclasses.field(default=None) _func: Optional[FunctionVar] = dataclasses.field(default=None)
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -101,39 +88,6 @@ class VarOperationCall(ImmutableVar):
""" """
return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))" return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._func._get_all_var_data() if self._func is not None else None,
*[var._get_all_var_data() for var in self._args],
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
object.__delattr__(self, "_var_name")
def __hash__(self):
"""Hash the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._func, self._args))
@classmethod @classmethod
def create( def create(
cls, cls,
@ -166,25 +120,12 @@ class VarOperationCall(ImmutableVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ArgsFunctionOperation(FunctionVar): class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
"""Base class for immutable function defined via arguments and return expression.""" """Base class for immutable function defined via arguments and return expression."""
_args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple) _args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
_return_expr: Union[Var, Any] = dataclasses.field(default=None) _return_expr: Union[Var, Any] = dataclasses.field(default=None)
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -194,38 +135,6 @@ class ArgsFunctionOperation(FunctionVar):
""" """
return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))" return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._return_expr._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
object.__delattr__(self, "_var_name")
def __hash__(self):
"""Hash the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._args_names, self._return_expr))
@classmethod @classmethod
def create( def create(
cls, cls,
@ -258,30 +167,13 @@ class ArgsFunctionOperation(FunctionVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ToFunctionOperation(FunctionVar): class ToFunctionOperation(CachedVarOperation, FunctionVar):
"""Base class of converting a var to a function.""" """Base class of converting a var to a function."""
_original_var: Var = dataclasses.field( _original_var: Var = dataclasses.field(
default_factory=lambda: LiteralVar.create(None) default_factory=lambda: LiteralVar.create(None)
) )
def __post_init__(self):
"""Post-initialize the var."""
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -291,34 +183,6 @@ class ToFunctionOperation(FunctionVar):
""" """
return str(self._original_var) return str(self._original_var)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_var._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __hash__(self):
"""Hash the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._original_var))
@classmethod @classmethod
def create( def create(
cls, cls,

View File

@ -11,6 +11,7 @@ from typing import Any, Union
from reflex.vars import ImmutableVarData, Var, VarData from reflex.vars import ImmutableVarData, Var, VarData
from .base import ( from .base import (
CachedVarOperation,
ImmutableVar, ImmutableVar,
LiteralVar, LiteralVar,
unionize, unionize,
@ -330,7 +331,7 @@ class NumberVar(ImmutableVar[Union[int, float]]):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class BinaryNumberOperation(NumberVar): class BinaryNumberOperation(CachedVarOperation, NumberVar):
"""Base class for immutable number vars that are the result of a binary operation.""" """Base class for immutable number vars that are the result of a binary operation."""
_lhs: NumberVar = dataclasses.field( _lhs: NumberVar = dataclasses.field(
@ -340,10 +341,6 @@ class BinaryNumberOperation(NumberVar):
default_factory=lambda: LiteralNumberVar.create(0) default_factory=lambda: LiteralNumberVar.create(0)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -355,49 +352,6 @@ class BinaryNumberOperation(NumberVar):
"BinaryNumberOperation must implement _cached_var_name" "BinaryNumberOperation must implement _cached_var_name"
) )
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(BinaryNumberOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
first_value = (
self._lhs if isinstance(self._lhs, Var) else LiteralNumberVar(self._lhs)
)
second_value = (
self._rhs if isinstance(self._rhs, Var) else LiteralNumberVar(self._rhs)
)
return ImmutableVarData.merge(
first_value._get_all_var_data(),
second_value._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._lhs, self._rhs))
@classmethod @classmethod
def create( def create(
cls, lhs: number_types, rhs: number_types, _var_data: VarData | None = None cls, lhs: number_types, rhs: number_types, _var_data: VarData | None = None
@ -430,17 +384,13 @@ class BinaryNumberOperation(NumberVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class UnaryNumberOperation(NumberVar): class UnaryNumberOperation(CachedVarOperation, NumberVar):
"""Base class for immutable number vars that are the result of a unary operation.""" """Base class for immutable number vars that are the result of a unary operation."""
_value: NumberVar = dataclasses.field( _value: NumberVar = dataclasses.field(
default_factory=lambda: LiteralNumberVar.create(0) default_factory=lambda: LiteralNumberVar.create(0)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -452,44 +402,6 @@ class UnaryNumberOperation(NumberVar):
"UnaryNumberOperation must implement _cached_var_name" "UnaryNumberOperation must implement _cached_var_name"
) )
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(UnaryNumberOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
value = (
self._value
if isinstance(self._value, Var)
else LiteralNumberVar(self._value)
)
return ImmutableVarData.merge(value._get_all_var_data(), self._var_data)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._value))
@classmethod @classmethod
def create(cls, value: NumberVar, _var_data: VarData | None = None): def create(cls, value: NumberVar, _var_data: VarData | None = None):
"""Create the unary number operation var. """Create the unary number operation var.
@ -787,17 +699,13 @@ class BooleanVar(ImmutableVar[bool]):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class BooleanToIntOperation(NumberVar): class BooleanToIntOperation(CachedVarOperation, NumberVar):
"""Base class for immutable number vars that are the result of a boolean to int operation.""" """Base class for immutable number vars that are the result of a boolean to int operation."""
_value: BooleanVar = dataclasses.field( _value: BooleanVar = dataclasses.field(
default_factory=lambda: LiteralBooleanVar.create(False) default_factory=lambda: LiteralBooleanVar.create(False)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -807,42 +715,6 @@ class BooleanToIntOperation(NumberVar):
""" """
return f"({str(self._value)} ? 1 : 0)" return f"({str(self._value)} ? 1 : 0)"
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(BooleanToIntOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._value._get_all_var_data() if isinstance(self._value, Var) else None,
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._value))
@classmethod @classmethod
def create(cls, value: BooleanVar, _var_data: VarData | None = None): def create(cls, value: BooleanVar, _var_data: VarData | None = None):
"""Create the boolean to int operation var. """Create the boolean to int operation var.
@ -867,7 +739,7 @@ class BooleanToIntOperation(NumberVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ComparisonOperation(BooleanVar): class ComparisonOperation(CachedVarOperation, BooleanVar):
"""Base class for immutable boolean vars that are the result of a comparison operation.""" """Base class for immutable boolean vars that are the result of a comparison operation."""
_lhs: Var = dataclasses.field( _lhs: Var = dataclasses.field(
@ -877,10 +749,6 @@ class ComparisonOperation(BooleanVar):
default_factory=lambda: LiteralBooleanVar.create(False) default_factory=lambda: LiteralBooleanVar.create(False)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -890,41 +758,6 @@ class ComparisonOperation(BooleanVar):
""" """
raise NotImplementedError("ComparisonOperation must implement _cached_var_name") raise NotImplementedError("ComparisonOperation must implement _cached_var_name")
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(ComparisonOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._lhs._get_all_var_data(), self._rhs._get_all_var_data()
)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._lhs, self._rhs))
@classmethod @classmethod
def create(cls, lhs: Var | Any, rhs: Var | Any, _var_data: VarData | None = None): def create(cls, lhs: Var | Any, rhs: Var | Any, _var_data: VarData | None = None):
"""Create the comparison operation var. """Create the comparison operation var.
@ -1022,7 +855,7 @@ class NotEqualOperation(ComparisonOperation):
Returns: Returns:
The name of the var. The name of the var.
""" """
return f"({str(self._lhs)} != {str(self._rhs)})" return f"({str(self._lhs)} !== {str(self._rhs)})"
@dataclasses.dataclass( @dataclasses.dataclass(
@ -1030,7 +863,7 @@ class NotEqualOperation(ComparisonOperation):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class LogicalOperation(BooleanVar): class LogicalOperation(CachedVarOperation, BooleanVar):
"""Base class for immutable boolean vars that are the result of a logical operation.""" """Base class for immutable boolean vars that are the result of a logical operation."""
_lhs: BooleanVar = dataclasses.field( _lhs: BooleanVar = dataclasses.field(
@ -1040,10 +873,6 @@ class LogicalOperation(BooleanVar):
default_factory=lambda: LiteralBooleanVar.create(False) default_factory=lambda: LiteralBooleanVar.create(False)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -1053,41 +882,6 @@ class LogicalOperation(BooleanVar):
""" """
raise NotImplementedError("LogicalOperation must implement _cached_var_name") raise NotImplementedError("LogicalOperation must implement _cached_var_name")
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(LogicalOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._lhs._get_all_var_data(), self._rhs._get_all_var_data()
)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._lhs, self._rhs))
@classmethod @classmethod
def create( def create(
cls, lhs: boolean_types, rhs: boolean_types, _var_data: VarData | None = None cls, lhs: boolean_types, rhs: boolean_types, _var_data: VarData | None = None
@ -1122,17 +916,13 @@ class LogicalOperation(BooleanVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class BooleanNotOperation(BooleanVar): class BooleanNotOperation(CachedVarOperation, BooleanVar):
"""Base class for immutable boolean vars that are the result of a logical NOT operation.""" """Base class for immutable boolean vars that are the result of a logical NOT operation."""
_value: BooleanVar = dataclasses.field( _value: BooleanVar = dataclasses.field(
default_factory=lambda: LiteralBooleanVar.create(False) default_factory=lambda: LiteralBooleanVar.create(False)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -1142,39 +932,6 @@ class BooleanNotOperation(BooleanVar):
""" """
return f"!({str(self._value)})" return f"!({str(self._value)})"
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(BooleanNotOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(self._value._get_all_var_data())
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._value))
@classmethod @classmethod
def create(cls, value: boolean_types, _var_data: VarData | None = None): def create(cls, value: boolean_types, _var_data: VarData | None = None):
"""Create the logical NOT operation var. """Create the logical NOT operation var.
@ -1205,14 +962,6 @@ class LiteralBooleanVar(LiteralVar, BooleanVar):
_var_value: bool = dataclasses.field(default=False) _var_value: bool = dataclasses.field(default=False)
def __hash__(self) -> int:
"""Hash the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._var_value))
def json(self) -> str: def json(self) -> str:
"""Get the JSON representation of the var. """Get the JSON representation of the var.
@ -1221,6 +970,14 @@ class LiteralBooleanVar(LiteralVar, BooleanVar):
""" """
return "true" if self._var_value else "false" return "true" if self._var_value else "false"
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
@classmethod @classmethod
def create(cls, value: bool, _var_data: VarData | None = None): def create(cls, value: bool, _var_data: VarData | None = None):
"""Create the boolean var. """Create the boolean var.
@ -1250,14 +1007,6 @@ class LiteralNumberVar(LiteralVar, NumberVar):
_var_value: float | int = dataclasses.field(default=0) _var_value: float | int = dataclasses.field(default=0)
def __hash__(self) -> int:
"""Hash the var.
Returns:
The hash of the var.
"""
return hash((self.__class__.__name__, self._var_value))
def json(self) -> str: def json(self) -> str:
"""Get the JSON representation of the var. """Get the JSON representation of the var.
@ -1266,6 +1015,14 @@ class LiteralNumberVar(LiteralVar, NumberVar):
""" """
return json.dumps(self._var_value) return json.dumps(self._var_value)
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
@classmethod @classmethod
def create(cls, value: float | int, _var_data: VarData | None = None): def create(cls, value: float | int, _var_data: VarData | None = None):
"""Create the number var. """Create the number var.
@ -1294,17 +1051,13 @@ boolean_types = Union[BooleanVar, bool]
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ToNumberVarOperation(NumberVar): class ToNumberVarOperation(CachedVarOperation, NumberVar):
"""Base class for immutable number vars that are the result of a number operation.""" """Base class for immutable number vars that are the result of a number operation."""
_original_value: Var = dataclasses.field( _original_value: Var = dataclasses.field(
default_factory=lambda: LiteralNumberVar.create(0) default_factory=lambda: LiteralNumberVar.create(0)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -1314,41 +1067,6 @@ class ToNumberVarOperation(NumberVar):
""" """
return str(self._original_value) return str(self._original_value)
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(ToNumberVarOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_value._get_all_var_data(), self._var_data
)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._original_value))
@classmethod @classmethod
def create( def create(
cls, cls,
@ -1379,7 +1097,7 @@ class ToNumberVarOperation(NumberVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ToBooleanVarOperation(BooleanVar): class ToBooleanVarOperation(CachedVarOperation, BooleanVar):
"""Base class for immutable boolean vars that are the result of a boolean operation.""" """Base class for immutable boolean vars that are the result of a boolean operation."""
_original_value: Var = dataclasses.field( _original_value: Var = dataclasses.field(
@ -1395,45 +1113,6 @@ class ToBooleanVarOperation(BooleanVar):
""" """
return f"Boolean({str(self._original_value)})" return f"Boolean({str(self._original_value)})"
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(ToBooleanVarOperation, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_value._get_all_var_data(), self._var_data
)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._original_value))
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@classmethod @classmethod
def create( def create(
cls, cls,
@ -1462,7 +1141,7 @@ class ToBooleanVarOperation(BooleanVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class TernaryOperator(ImmutableVar): class TernaryOperator(CachedVarOperation, ImmutableVar):
"""Base class for immutable vars that are the result of a ternary operation.""" """Base class for immutable vars that are the result of a ternary operation."""
_condition: BooleanVar = dataclasses.field( _condition: BooleanVar = dataclasses.field(
@ -1475,10 +1154,6 @@ class TernaryOperator(ImmutableVar):
default_factory=lambda: LiteralNumberVar.create(0) default_factory=lambda: LiteralNumberVar.create(0)
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -1490,46 +1165,6 @@ class TernaryOperator(ImmutableVar):
f"({str(self._condition)} ? {str(self._if_true)} : {str(self._if_false)})" f"({str(self._condition)} ? {str(self._if_true)} : {str(self._if_false)})"
) )
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute value.
"""
if name == "_var_name":
return self._cached_var_name
getattr(super(TernaryOperator, self), name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._condition._get_all_var_data(),
self._if_true._get_all_var_data(),
self._if_false._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash(
(self.__class__.__name__, self._condition, self._if_true, self._if_false)
)
@classmethod @classmethod
def create( def create(
cls, cls,

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import functools
import sys import sys
import typing import typing
from functools import cached_property from functools import cached_property
@ -28,6 +29,7 @@ from reflex.utils.types import GenericType, get_attribute_access_type
from reflex.vars import ImmutableVarData, Var, VarData from reflex.vars import ImmutableVarData, Var, VarData
from .base import ( from .base import (
CachedVarOperation,
ImmutableVar, ImmutableVar,
LiteralVar, LiteralVar,
figure_out_type, figure_out_type,
@ -255,7 +257,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]):
attribute_type = get_attribute_access_type(var_type, name) attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None: if attribute_type is None:
raise VarAttributeError( raise VarAttributeError(
f"The State var `{self._var_name}` has no attribute '{name}' or may have been annotated " f"The State var `{str(self)}` has no attribute '{name}' or may have been annotated "
f"wrongly." f"wrongly."
) )
return ObjectItemOperation.create(self, name, attribute_type).guess_type() return ObjectItemOperation.create(self, name, attribute_type).guess_type()
@ -279,17 +281,13 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]): class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars.""" """Base class for immutable literal object vars."""
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict default_factory=dict
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
def _key_type(self) -> Type: def _key_type(self) -> Type:
"""Get the type of the keys of the object. """Get the type of the keys of the object.
@ -308,19 +306,6 @@ class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]):
args_list = typing.get_args(self._var_type) args_list = typing.get_args(self._var_type)
return args_list[1] if args_list else Any return args_list[1] if args_list else Any
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the var. """The name of the var.
@ -339,30 +324,6 @@ class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]):
+ " })" + " })"
) )
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[
LiteralVar.create(value)._get_all_var_data()
for value in self._var_value.values()
],
*[LiteralVar.create(key)._get_all_var_data() for key in self._var_value],
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def json(self) -> str: def json(self) -> str:
"""Get the JSON representation of the object. """Get the JSON representation of the object.
@ -388,6 +349,22 @@ class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]):
""" """
return hash((self.__class__.__name__, self._var_name)) return hash((self.__class__.__name__, self._var_name))
@functools.cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all the var data.
Returns:
The var data.
"""
return ImmutableVarData.merge(
*[LiteralVar.create(var)._get_all_var_data() for var in self._var_value],
*[
LiteralVar.create(var)._get_all_var_data()
for var in self._var_value.values()
],
self._var_data,
)
@classmethod @classmethod
def create( def create(
cls, cls,
@ -418,17 +395,13 @@ class LiteralObjectVar(LiteralVar, ObjectVar[OBJECT_TYPE]):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ObjectToArrayOperation(ArrayVar): class ObjectToArrayOperation(CachedVarOperation, ArrayVar):
"""Base class for object to array operations.""" """Base class for object to array operations."""
_value: ObjectVar = dataclasses.field( _value: ObjectVar = dataclasses.field(
default_factory=lambda: LiteralObjectVar.create({}) default_factory=lambda: LiteralObjectVar.create({})
) )
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the operation. """The name of the operation.
@ -440,47 +413,6 @@ class ObjectToArrayOperation(ArrayVar):
"ObjectToArrayOperation must implement _cached_var_name" "ObjectToArrayOperation must implement _cached_var_name"
) )
def __getattr__(self, name):
"""Get an attribute of the operation.
Args:
name: The name of the attribute.
Returns:
The attribute of the operation.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the operation.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._value._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Get the hash of the operation.
Returns:
The hash of the operation.
"""
return hash((self.__class__.__name__, self._value))
@classmethod @classmethod
def create( def create(
cls, cls,
@ -508,9 +440,6 @@ class ObjectToArrayOperation(ArrayVar):
class ObjectKeysOperation(ObjectToArrayOperation): class ObjectKeysOperation(ObjectToArrayOperation):
"""Operation to get the keys of an object.""" """Operation to get the keys of an object."""
# value, List[value._key_type()], _var_data
# )
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the operation. """The name of the operation.
@ -553,7 +482,7 @@ class ObjectValuesOperation(ObjectToArrayOperation):
Returns: Returns:
The name of the operation. The name of the operation.
""" """
return f"Object.values({self._value._var_name})" return f"Object.values({str(self._value)})"
@classmethod @classmethod
def create( def create(
@ -588,7 +517,7 @@ class ObjectEntriesOperation(ObjectToArrayOperation):
Returns: Returns:
The name of the operation. The name of the operation.
""" """
return f"Object.entries({self._value._var_name})" return f"Object.entries({str(self._value)})"
@classmethod @classmethod
def create( def create(
@ -618,7 +547,7 @@ class ObjectEntriesOperation(ObjectToArrayOperation):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ObjectMergeOperation(ObjectVar): class ObjectMergeOperation(CachedVarOperation, ObjectVar):
"""Operation to merge two objects.""" """Operation to merge two objects."""
_lhs: ObjectVar = dataclasses.field( _lhs: ObjectVar = dataclasses.field(
@ -635,53 +564,7 @@ class ObjectMergeOperation(ObjectVar):
Returns: Returns:
The name of the operation. The name of the operation.
""" """
return f"Object.assign({self._lhs._var_name}, {self._rhs._var_name})" return f"({{...{str(self._lhs)}, ...{str(self._rhs)}}})"
def __getattr__(self, name):
"""Get an attribute of the operation.
Args:
name: The name of the attribute.
Returns:
The attribute of the operation.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the operation.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._lhs._get_all_var_data(),
self._rhs._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Get the hash of the operation.
Returns:
The hash of the operation.
"""
return hash((self.__class__.__name__, self._lhs, self._rhs))
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@classmethod @classmethod
def create( def create(
@ -715,7 +598,7 @@ class ObjectMergeOperation(ObjectVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ObjectItemOperation(ImmutableVar): class ObjectItemOperation(CachedVarOperation, ImmutableVar):
"""Operation to get an item from an object.""" """Operation to get an item from an object."""
_object: ObjectVar = dataclasses.field( _object: ObjectVar = dataclasses.field(
@ -734,52 +617,6 @@ class ObjectItemOperation(ImmutableVar):
return f"{str(self._object)}?.[{str(self._key)}]" return f"{str(self._object)}?.[{str(self._key)}]"
return f"{str(self._object)}[{str(self._key)}]" return f"{str(self._object)}[{str(self._key)}]"
def __getattr__(self, name):
"""Get an attribute of the operation.
Args:
name: The name of the attribute.
Returns:
The attribute of the operation.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the operation.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._object._get_all_var_data(),
self._key._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Get the hash of the operation.
Returns:
The hash of the operation.
"""
return hash((self.__class__.__name__, self._object, self._key))
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@classmethod @classmethod
def create( def create(
cls, cls,
@ -813,7 +650,7 @@ class ObjectItemOperation(ImmutableVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ToObjectOperation(ObjectVar): class ToObjectOperation(CachedVarOperation, ObjectVar):
"""Operation to convert a var to an object.""" """Operation to convert a var to an object."""
_original_var: Var = dataclasses.field( _original_var: Var = dataclasses.field(
@ -829,51 +666,6 @@ class ToObjectOperation(ObjectVar):
""" """
return str(self._original_var) return str(self._original_var)
def __getattr__(self, name):
"""Get an attribute of the operation.
Args:
name: The name of the attribute.
Returns:
The attribute of the operation.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the operation.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._original_var._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Get the hash of the operation.
Returns:
The hash of the operation.
"""
return hash((self.__class__.__name__, self._original_var))
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@classmethod @classmethod
def create( def create(
cls, cls,
@ -904,7 +696,7 @@ class ToObjectOperation(ObjectVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ObjectHasOwnProperty(BooleanVar): class ObjectHasOwnProperty(CachedVarOperation, BooleanVar):
"""Operation to check if an object has a property.""" """Operation to check if an object has a property."""
_object: ObjectVar = dataclasses.field( _object: ObjectVar = dataclasses.field(
@ -912,10 +704,6 @@ class ObjectHasOwnProperty(BooleanVar):
) )
_key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) _key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_var_name")
@cached_property @cached_property
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
"""The name of the operation. """The name of the operation.
@ -925,48 +713,6 @@ class ObjectHasOwnProperty(BooleanVar):
""" """
return f"{str(self._object)}.hasOwnProperty({str(self._key)})" return f"{str(self._object)}.hasOwnProperty({str(self._key)})"
def __getattr__(self, name):
"""Get an attribute of the operation.
Args:
name: The name of the attribute.
Returns:
The attribute of the operation.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the operation.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._object._get_all_var_data(),
self._key._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __hash__(self) -> int:
"""Get the hash of the operation.
Returns:
The hash of the operation.
"""
return hash((self.__class__.__name__, self._object, self._key))
@classmethod @classmethod
def create( def create(
cls, cls,

File diff suppressed because it is too large Load Diff

View File

@ -905,7 +905,6 @@ class Var:
Raises: Raises:
VarTypeError: If the var is not indexable. VarTypeError: If the var is not indexable.
""" """
print(repr(self))
from reflex.utils import format from reflex.utils import format
# Indexing is only supported for strings, lists, tuples, dicts, and dataframes. # Indexing is only supported for strings, lists, tuples, dicts, and dataframes.
@ -1058,7 +1057,6 @@ class Var:
return self._replace( return self._replace(
_var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}", _var_name=f"{self._var_name}{'?' if is_optional else ''}.{name}",
_var_type=type_, _var_type=type_,
_var_is_string=False,
) )
if name in REPLACED_NAMES: if name in REPLACED_NAMES:

View File

@ -1073,7 +1073,7 @@ TEST_VAR = LiteralVar.create("test")._replace(
) )
) )
FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar") FORMATTED_TEST_VAR = LiteralVar.create(f"foo{TEST_VAR}bar")
STYLE_VAR = TEST_VAR._replace(_var_name="style", _var_is_local=False) STYLE_VAR = TEST_VAR._replace(_var_name="style")
EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain) EVENT_CHAIN_VAR = TEST_VAR._replace(_var_type=EventChain)
ARG_VAR = Var.create("arg") ARG_VAR = Var.create("arg")
@ -1299,8 +1299,8 @@ def test_get_vars(component, exp_vars):
comp_vars, comp_vars,
sorted(exp_vars, key=lambda v: v._var_name), sorted(exp_vars, key=lambda v: v._var_name),
): ):
print(str(comp_var), str(exp_var)) # print(str(comp_var), str(exp_var))
print(comp_var._get_all_var_data(), exp_var._get_all_var_data()) # print(comp_var._get_all_var_data(), exp_var._get_all_var_data())
assert comp_var.equals(exp_var) assert comp_var.equals(exp_var)

View File

@ -1036,7 +1036,7 @@ def test_object_operations():
assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]' assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]'
assert ( assert (
str(object_var.merge(LiteralObjectVar.create({"c": 4, "d": 5}))) str(object_var.merge(LiteralObjectVar.create({"c": 4, "d": 5})))
== 'Object.assign(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }), ({ ["c"] : 4, ["d"] : 5 }))' == '({...({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }), ...({ ["c"] : 4, ["d"] : 5 })})'
) )