From 1c400043c646ba723fbfb87564fcb459d36fd75f Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 30 Jul 2024 15:38:32 -0700 Subject: [PATCH] [REF-3328] Implement __getitem__ for ArrayVar (#3705) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * half of the way there * implement __getitem__ for array * add some tests * add fixes to pyright * fix default factory * implement array operations * format code * fix pyright issue * give up * add object operations * add test for merge * pyright 🥺 * use str isntead of _var_name Co-authored-by: Masen Furer * wrong var_type * make to much nicer * add subclass checking * enhance types * use builtin list type * improve typing even more * i'm awaiting october * use even better typing * add hash, json, and guess type method * fix pyright issues * add a test and fix lots of errors * fix pyright once again * add type inference to list --------- Co-authored-by: Masen Furer --- reflex/experimental/vars/__init__.py | 4 +- reflex/experimental/vars/base.py | 254 +++--- reflex/experimental/vars/function.py | 78 +- reflex/experimental/vars/number.py | 165 +++- reflex/experimental/vars/object.py | 627 +++++++++++++++ reflex/experimental/vars/sequence.py | 1061 ++++++++++++++++++++------ reflex/vars.py | 8 + reflex/vars.pyi | 1 + tests/test_var.py | 97 ++- 9 files changed, 1971 insertions(+), 324 deletions(-) create mode 100644 reflex/experimental/vars/object.py diff --git a/reflex/experimental/vars/__init__.py b/reflex/experimental/vars/__init__.py index 945cf25fc..8fa5196ff 100644 --- a/reflex/experimental/vars/__init__.py +++ b/reflex/experimental/vars/__init__.py @@ -1,9 +1,7 @@ """Experimental Immutable-Based Var System.""" from .base import ImmutableVar as ImmutableVar -from .base import LiteralObjectVar as LiteralObjectVar from .base import LiteralVar as LiteralVar -from .base import ObjectVar as ObjectVar from .base import var_operation as var_operation from .function import FunctionStringVar as FunctionStringVar from .function import FunctionVar as FunctionVar @@ -12,6 +10,8 @@ from .number import BooleanVar as BooleanVar from .number import LiteralBooleanVar as LiteralBooleanVar from .number import LiteralNumberVar as LiteralNumberVar from .number import NumberVar as NumberVar +from .object import LiteralObjectVar as LiteralObjectVar +from .object import ObjectVar as ObjectVar from .sequence import ArrayJoinOperation as ArrayJoinOperation from .sequence import ArrayVar as ArrayVar from .sequence import ConcatVarOperation as ConcatVarOperation diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index 55b5673bd..dadcc38bd 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -4,18 +4,19 @@ from __future__ import annotations import dataclasses import functools +import inspect import sys from typing import ( + TYPE_CHECKING, Any, Callable, - Dict, Optional, Type, TypeVar, - Union, + overload, ) -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, get_origin from reflex import constants from reflex.base import Base @@ -30,6 +31,17 @@ from reflex.vars import ( _global_vars, ) +if TYPE_CHECKING: + from .function import FunctionVar, ToFunctionOperation + from .number import ( + BooleanVar, + NumberVar, + ToBooleanVarOperation, + ToNumberVarOperation, + ) + from .object import ObjectVar, ToObjectOperation + from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation + @dataclasses.dataclass( eq=False, @@ -43,7 +55,7 @@ class ImmutableVar(Var): _var_name: str = dataclasses.field() # The type of the var. - _var_type: Type = dataclasses.field(default=Any) + _var_type: types.GenericType = dataclasses.field(default=Any) # Extra metadata associated with the Var _var_data: Optional[ImmutableVarData] = dataclasses.field(default=None) @@ -265,9 +277,138 @@ class ImmutableVar(Var): # Encode the _var_data into the formatted output for tracking purposes. return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}" + @overload + def to( + self, output: Type[NumberVar], var_type: type[int] | type[float] = float + ) -> ToNumberVarOperation: ... -class ObjectVar(ImmutableVar): - """Base class for immutable object vars.""" + @overload + def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ... + + @overload + def to( + self, + output: Type[ArrayVar], + var_type: type[list] | type[tuple] | type[set] = list, + ) -> ToArrayOperation: ... + + @overload + def to(self, output: Type[StringVar]) -> ToStringOperation: ... + + @overload + def to( + self, output: Type[ObjectVar], var_type: types.GenericType = dict + ) -> ToObjectOperation: ... + + @overload + def to( + self, output: Type[FunctionVar], var_type: Type[Callable] = Callable + ) -> ToFunctionOperation: ... + + @overload + def to( + self, output: Type[OUTPUT], var_type: types.GenericType | None = None + ) -> OUTPUT: ... + + def to( + self, output: Type[OUTPUT], var_type: types.GenericType | None = None + ) -> Var: + """Convert the var to a different type. + + Args: + output: The output type. + var_type: The type of the var. + + Raises: + TypeError: If the var_type is not a supported type for the output. + + Returns: + The converted var. + """ + from .number import ( + BooleanVar, + NumberVar, + ToBooleanVarOperation, + ToNumberVarOperation, + ) + + fixed_type = ( + var_type + if var_type is None or inspect.isclass(var_type) + else get_origin(var_type) + ) + + if issubclass(output, NumberVar): + if fixed_type is not None and not issubclass(fixed_type, (int, float)): + raise TypeError( + f"Unsupported type {var_type} for NumberVar. Must be int or float." + ) + return ToNumberVarOperation(self, var_type or float) + if issubclass(output, BooleanVar): + return ToBooleanVarOperation(self) + + from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation + + if issubclass(output, ArrayVar): + if fixed_type is not None and not issubclass( + fixed_type, (list, tuple, set) + ): + raise TypeError( + f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set." + ) + return ToArrayOperation(self, var_type or list) + if issubclass(output, StringVar): + return ToStringOperation(self) + + from .object import ObjectVar, ToObjectOperation + + if issubclass(output, ObjectVar): + return ToObjectOperation(self, var_type or dict) + + from .function import FunctionVar, ToFunctionOperation + + if issubclass(output, FunctionVar): + if fixed_type is not None and not issubclass(fixed_type, Callable): + raise TypeError( + f"Unsupported type {var_type} for FunctionVar. Must be Callable." + ) + return ToFunctionOperation(self, var_type or Callable) + + return output( + _var_name=self._var_name, + _var_type=self._var_type if var_type is None else var_type, + _var_data=self._var_data, + ) + + def guess_type(self) -> ImmutableVar: + """Guess the type of the var. + + Returns: + The guessed type. + """ + from .number import NumberVar + from .object import ObjectVar + from .sequence import ArrayVar, StringVar + + if self._var_type is Any: + return self + + var_type = self._var_type + + fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type) + + if issubclass(fixed_type, (int, float)): + return self.to(NumberVar, var_type) + if issubclass(fixed_type, dict): + return self.to(ObjectVar, var_type) + if issubclass(fixed_type, (list, tuple, set)): + return self.to(ArrayVar, var_type) + if issubclass(fixed_type, str): + return self.to(StringVar) + return self + + +OUTPUT = TypeVar("OUTPUT", bound=ImmutableVar) class LiteralVar(ImmutableVar): @@ -299,6 +440,8 @@ class LiteralVar(ImmutableVar): if value is None: return ImmutableVar.create_safe("null", _var_data=_var_data) + from .object import LiteralObjectVar + if isinstance(value, Base): return LiteralObjectVar( value.dict(), _var_type=type(value), _var_data=_var_data @@ -330,102 +473,15 @@ class LiteralVar(ImmutableVar): def __post_init__(self): """Post-initialize the var.""" + def json(self) -> str: + """Serialize the var to a JSON string. -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LiteralObjectVar(LiteralVar): - """Base class for immutable literal object vars.""" - - _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( - default_factory=dict - ) - - def __init__( - self, - _var_value: dict[Var | Any, Var | Any], - _var_type: Type = dict, - _var_data: VarData | None = None, - ): - """Initialize the object var. - - Args: - _var_value: The value of the var. - _var_data: Additional hooks and imports associated with the Var. + Raises: + NotImplementedError: If the method is not implemented. """ - super(LiteralObjectVar, self).__init__( - _var_name="", - _var_type=_var_type, - _var_data=ImmutableVarData.merge(_var_data), + raise NotImplementedError( + "LiteralVar subclasses must implement the json method." ) - object.__setattr__( - self, - "_var_value", - _var_value, - ) - object.__delattr__(self, "_var_name") - - def __getattr__(self, name): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute of the var. - """ - if name == "_var_name": - return self._cached_var_name - return super(type(self), self).__getattr__(name) - - @functools.cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return ( - "{ " - + ", ".join( - [ - f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}" - for key, value in self._var_value.items() - ] - ) - + " }" - ) - - @functools.cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - *[ - value._get_all_var_data() - for key, value in self._var_value - if isinstance(value, Var) - ], - *[ - key._get_all_var_data() - for key, value in self._var_value - if isinstance(key, Var) - ], - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - """Wrapper method for cached property. - - Returns: - The VarData of the components and all of its children. - """ - return self._cached_get_all_var_data P = ParamSpec("P") diff --git a/reflex/experimental/vars/function.py b/reflex/experimental/vars/function.py index f1cf83886..adce1329d 100644 --- a/reflex/experimental/vars/function.py +++ b/reflex/experimental/vars/function.py @@ -5,7 +5,7 @@ from __future__ import annotations import dataclasses import sys from functools import cached_property -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Type, Union from reflex.experimental.vars.base import ImmutableVar, LiteralVar from reflex.vars import ImmutableVarData, Var, VarData @@ -212,3 +212,79 @@ class ArgsFunctionOperation(FunctionVar): def __post_init__(self): """Post-initialize the var.""" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToFunctionOperation(FunctionVar): + """Base class of converting a var to a function.""" + + _original_var: Var = dataclasses.field( + default_factory=lambda: LiteralVar.create(None) + ) + + def __init__( + self, + original_var: Var, + _var_type: Type[Callable] = Callable, + _var_data: VarData | None = None, + ) -> None: + """Initialize the function with arguments var. + + Args: + original_var: The original var to convert to a function. + _var_type: The type of the function. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToFunctionOperation, self).__init__( + _var_name=f"", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_var", original_var) + object.__delattr__(self, "_var_name") + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self._original_var) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_var._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data diff --git a/reflex/experimental/vars/number.py b/reflex/experimental/vars/number.py index 6b74bc336..c83c5c4d2 100644 --- a/reflex/experimental/vars/number.py +++ b/reflex/experimental/vars/number.py @@ -3,6 +3,7 @@ from __future__ import annotations import dataclasses +import json import sys from functools import cached_property from typing import Any, Union @@ -1253,6 +1254,22 @@ class LiteralBooleanVar(LiteralVar, BooleanVar): ) object.__setattr__(self, "_var_value", _var_value) + def __hash__(self) -> int: + """Hash the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_value)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return "true" if self._var_value else "false" + @dataclasses.dataclass( eq=False, @@ -1288,8 +1305,154 @@ class LiteralNumberVar(LiteralVar, NumberVar): Returns: The hash of the var. """ - return hash(self._var_value) + return hash((self.__class__.__name__, self._var_value)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return json.dumps(self._var_value) number_types = Union[NumberVar, LiteralNumberVar, int, float] boolean_types = Union[BooleanVar, LiteralBooleanVar, bool] + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToNumberVarOperation(NumberVar): + """Base class for immutable number vars that are the result of a number operation.""" + + _original_value: Var = dataclasses.field( + default_factory=lambda: LiteralNumberVar(0) + ) + + def __init__( + self, + _original_value: Var, + _var_type: type[int] | type[float] = float, + _var_data: VarData | None = None, + ): + """Initialize the number var. + + Args: + _original_value: The original value. + _var_type: The type of the Var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToNumberVarOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_value", _original_value) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self._original_value) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToNumberVarOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_value._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToBooleanVarOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of a boolean operation.""" + + _original_value: Var = dataclasses.field( + default_factory=lambda: LiteralBooleanVar(False) + ) + + def __init__( + self, + _original_value: Var, + _var_data: VarData | None = None, + ): + """Initialize the boolean var. + + Args: + _original_value: The original value. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToBooleanVarOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_value", _original_value) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self._original_value) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToBooleanVarOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_value._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data diff --git a/reflex/experimental/vars/object.py b/reflex/experimental/vars/object.py new file mode 100644 index 000000000..4522473c7 --- /dev/null +++ b/reflex/experimental/vars/object.py @@ -0,0 +1,627 @@ +"""Classes for immutable object vars.""" + +from __future__ import annotations + +import dataclasses +import sys +import typing +from functools import cached_property +from typing import Any, Dict, List, Tuple, Type, Union + +from reflex.experimental.vars.base import ImmutableVar, LiteralVar +from reflex.experimental.vars.sequence import ArrayVar, unionize +from reflex.vars import ImmutableVarData, Var, VarData + + +class ObjectVar(ImmutableVar): + """Base class for immutable object vars.""" + + def _key_type(self) -> Type: + """Get the type of the keys of the object. + + Returns: + The type of the keys of the object. + """ + return ImmutableVar + + def _value_type(self) -> Type: + """Get the type of the values of the object. + + Returns: + The type of the values of the object. + """ + return ImmutableVar + + def keys(self) -> ObjectKeysOperation: + """Get the keys of the object. + + Returns: + The keys of the object. + """ + return ObjectKeysOperation(self) + + def values(self) -> ObjectValuesOperation: + """Get the values of the object. + + Returns: + The values of the object. + """ + return ObjectValuesOperation(self) + + def entries(self) -> ObjectEntriesOperation: + """Get the entries of the object. + + Returns: + The entries of the object. + """ + return ObjectEntriesOperation(self) + + def merge(self, other: ObjectVar) -> ObjectMergeOperation: + """Merge two objects. + + Args: + other: The other object to merge. + + Returns: + The merged object. + """ + return ObjectMergeOperation(self, other) + + def __getitem__(self, key: Var | Any) -> ImmutableVar: + """Get an item from the object. + + Args: + key: The key to get from the object. + + Returns: + The item from the object. + """ + return ObjectItemOperation(self, key).guess_type() + + def __getattr__(self, name) -> ObjectItemOperation: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + return ObjectItemOperation(self, name) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralObjectVar(LiteralVar, ObjectVar): + """Base class for immutable literal object vars.""" + + _var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( + default_factory=dict + ) + + def __init__( + self, + _var_value: dict[Var | Any, Var | Any], + _var_type: Type | None = None, + _var_data: VarData | None = None, + ): + """Initialize the object var. + + Args: + _var_value: The value of the var. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(LiteralObjectVar, self).__init__( + _var_name="", + _var_type=( + Dict[ + unionize(*map(type, _var_value.keys())), + unionize(*map(type, _var_value.values())), + ] + if _var_type is None + else _var_type + ), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "_var_value", + _var_value, + ) + object.__delattr__(self, "_var_name") + + def _key_type(self) -> Type: + """Get the type of the keys of the object. + + Returns: + The type of the keys of the object. + """ + args_list = typing.get_args(self._var_type) + return args_list[0] if args_list else Any + + def _value_type(self) -> Type: + """Get the type of the values of the object. + + Returns: + The type of the values of the object. + """ + args_list = typing.get_args(self._var_type) + return args_list[1] if args_list else Any + + def __getattr__(self, name): + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the var. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return ( + "({ " + + ", ".join( + [ + f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}" + for key, value in self._var_value.items() + ] + ) + + " })" + ) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[ + value._get_all_var_data() + for key, value in self._var_value + if isinstance(value, Var) + ], + *[ + key._get_all_var_data() + for key, value in self._var_value + if isinstance(key, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def json(self) -> str: + """Get the JSON representation of the object. + + Returns: + The JSON representation of the object. + """ + return ( + "{" + + ", ".join( + [ + f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}" + for key, value in self._var_value.items() + ] + ) + + "}" + ) + + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_name)) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ObjectToArrayOperation(ArrayVar): + """Base class for object to array operations.""" + + value: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + + def __init__( + self, + _var_value: ObjectVar, + _var_type: Type = list, + _var_data: VarData | None = None, + ): + """Initialize the object to array operation. + + Args: + _var_value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectToArrayOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "value", _var_value) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Raises: + NotImplementedError: Must implement _cached_var_name. + """ + raise NotImplementedError( + "ObjectToArrayOperation must implement _cached_var_name" + ) + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.value._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +class ObjectKeysOperation(ObjectToArrayOperation): + """Operation to get the keys of an object.""" + + def __init__( + self, + value: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object keys operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectKeysOperation, self).__init__( + value, List[value._key_type()], _var_data + ) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.keys({self.value._var_name})" + + +class ObjectValuesOperation(ObjectToArrayOperation): + """Operation to get the values of an object.""" + + def __init__( + self, + value: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object values operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectValuesOperation, self).__init__( + value, List[value._value_type()], _var_data + ) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.values({self.value._var_name})" + + +class ObjectEntriesOperation(ObjectToArrayOperation): + """Operation to get the entries of an object.""" + + def __init__( + self, + value: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object entries operation. + + Args: + value: The value of the operation. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectEntriesOperation, self).__init__( + value, List[Tuple[value._key_type(), value._value_type()]], _var_data + ) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.entries({self.value._var_name})" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ObjectMergeOperation(ObjectVar): + """Operation to merge two objects.""" + + left: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + right: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + + def __init__( + self, + left: ObjectVar, + right: ObjectVar, + _var_data: VarData | None = None, + ): + """Initialize the object merge operation. + + Args: + left: The left object to merge. + right: The right object to merge. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectMergeOperation, self).__init__( + _var_name="", + _var_type=left._var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "left", left) + object.__setattr__(self, "right", right) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"Object.assign({self.left._var_name}, {self.right._var_name})" + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.left._get_all_var_data(), + self.right._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ObjectItemOperation(ImmutableVar): + """Operation to get an item from an object.""" + + value: ObjectVar = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + + def __init__( + self, + value: ObjectVar, + key: Var | Any, + _var_data: VarData | None = None, + ): + """Initialize the object item operation. + + Args: + value: The value of the operation. + key: The key to get from the object. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ObjectItemOperation, self).__init__( + _var_name="", + _var_type=value._value_type(), + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "value", value) + object.__setattr__( + self, "key", key if isinstance(key, Var) else LiteralVar.create(key) + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return f"{str(self.value)}[{str(self.key)}]" + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.value._get_all_var_data(), + self.key._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToObjectOperation(ObjectVar): + """Operation to convert a var to an object.""" + + _original_var: Var = dataclasses.field(default_factory=lambda: LiteralObjectVar({})) + + def __init__( + self, + _original_var: Var, + _var_type: Type = dict, + _var_data: VarData | None = None, + ): + """Initialize the to object operation. + + Args: + _original_var: The original var to convert. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the operation. + """ + super(ToObjectOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "_original_var", _original_var) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the operation. + + Returns: + The name of the operation. + """ + return str(self._original_var) + + def __getattr__(self, name): + """Get an attribute of the operation. + + Args: + name: The name of the attribute. + + Returns: + The attribute of the operation. + """ + if name == "_var_name": + return self._cached_var_name + return super(type(self), self).__getattr__(name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the operation. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self._original_var._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data diff --git a/reflex/experimental/vars/sequence.py b/reflex/experimental/vars/sequence.py index c0e8bb9d7..8db1300ec 100644 --- a/reflex/experimental/vars/sequence.py +++ b/reflex/experimental/vars/sequence.py @@ -4,11 +4,15 @@ from __future__ import annotations import dataclasses import functools +import inspect import json import re import sys +import typing from functools import cached_property -from typing import Any, List, Set, Tuple, Union +from typing import Any, List, Set, Tuple, Type, Union, overload + +from typing_extensions import get_origin from reflex import constants from reflex.constants.base import REFLEX_VAR_OPENING_TAG @@ -16,7 +20,13 @@ from reflex.experimental.vars.base import ( ImmutableVar, LiteralVar, ) -from reflex.experimental.vars.number import BooleanVar, NotEqualOperation, NumberVar +from reflex.experimental.vars.number import ( + BooleanVar, + LiteralNumberVar, + NotEqualOperation, + NumberVar, +) +from reflex.utils.types import GenericType from reflex.vars import ImmutableVarData, Var, VarData, _global_vars @@ -67,7 +77,15 @@ class StringVar(ImmutableVar): """ return ConcatVarOperation(*[self for _ in range(other)]) - def __getitem__(self, i: slice | int) -> StringSliceOperation | StringItemOperation: + @overload + def __getitem__(self, i: slice) -> ArrayJoinOperation: ... + + @overload + def __getitem__(self, i: int | NumberVar) -> StringItemOperation: ... + + def __getitem__( + self, i: slice | int | NumberVar + ) -> ArrayJoinOperation | StringItemOperation: """Get a slice of the string. Args: @@ -77,16 +95,16 @@ class StringVar(ImmutableVar): The string slice operation. """ if isinstance(i, slice): - return StringSliceOperation(self, i) + return self.split()[i].join() return StringItemOperation(self, i) - def length(self) -> StringLengthOperation: + def length(self) -> NumberVar: """Get the length of the string. Returns: The string length operation. """ - return StringLengthOperation(self) + return self.split().length() def lower(self) -> StringLowerOperation: """Convert the string to lowercase. @@ -120,13 +138,13 @@ class StringVar(ImmutableVar): """ return NotEqualOperation(self.length(), 0) - def reversed(self) -> StringReverseOperation: + def reversed(self) -> ArrayJoinOperation: """Reverse the string. Returns: The string reverse operation. """ - return StringReverseOperation(self) + return self.split().reverse().join() def contains(self, other: StringVar | str) -> StringContainsOperation: """Check if the string contains another string. @@ -151,85 +169,6 @@ class StringVar(ImmutableVar): return StringSplitOperation(self, separator) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringToNumberOperation(NumberVar): - """Base class for immutable number vars that are the result of a string to number operation.""" - - a: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - - def __init__(self, a: StringVar | str, _var_data: VarData | None = None): - """Initialize the string to number operation var. - - Args: - a: The string. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringToNumberOperation, self).__init__( - _var_name="", - _var_type=float, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) - object.__delattr__(self, "_var_name") - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError( - "StringToNumberOperation must implement _cached_var_name" - ) - - def __getattr__(self, name: str) -> Any: - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute value. - """ - if name == "_var_name": - return self._cached_var_name - getattr(super(StringToNumberOperation, self), name) - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) - - def _get_all_var_data(self) -> ImmutableVarData | None: - return self._cached_get_all_var_data - - -class StringLengthOperation(StringToNumberOperation): - """Base class for immutable number vars that are the result of a string length operation.""" - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self.a)}.length" - - @dataclasses.dataclass( eq=False, frozen=True, @@ -338,19 +277,6 @@ class StringStripOperation(StringToStringOperation): return f"{str(self.a)}.trim()" -class StringReverseOperation(StringToStringOperation): - """Base class for immutable string vars that are the result of a string reverse operation.""" - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self.a)}.split('').reverse().join('')" - - @dataclasses.dataclass( eq=False, frozen=True, @@ -426,112 +352,6 @@ class StringContainsOperation(BooleanVar): return self._cached_get_all_var_data -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringSliceOperation(StringVar): - """Base class for immutable string vars that are the result of a string slice operation.""" - - a: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) - - def __init__( - self, a: StringVar | str, _slice: slice, _var_data: VarData | None = None - ): - """Initialize the string slice operation var. - - Args: - a: The string. - _slice: The slice. - _var_data: Additional hooks and imports associated with the Var. - """ - super(StringSliceOperation, self).__init__( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) - ) - object.__setattr__(self, "_slice", _slice) - object.__delattr__(self, "_var_name") - - @cached_property - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - - Raises: - ValueError: If the slice step is zero. - """ - start, end, step = self._slice.start, self._slice.stop, self._slice.step - - if step is not None and step < 0: - actual_start = end + 1 if end is not None else 0 - actual_end = start + 1 if start is not None else self.a.length() - return str( - StringSliceOperation( - StringReverseOperation( - StringSliceOperation(self.a, slice(actual_start, actual_end)) - ), - slice(None, None, -step), - ) - ) - - start = ( - LiteralVar.create(start) - if start is not None - else ImmutableVar.create_safe("undefined") - ) - end = ( - LiteralVar.create(end) - if end is not None - else ImmutableVar.create_safe("undefined") - ) - - if step is None: - return f"{str(self.a)}.slice({str(start)}, {str(end)})" - if step == 0: - raise ValueError("slice step cannot be zero") - return f"{str(self.a)}.slice({str(start)}, {str(end)}).split('').filter((_, i) => i % {str(step)} === 0).join('')" - - def __getattr__(self, name: str) -> Any: - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute value. - """ - if name == "_var_name": - return self._cached_var_name - getattr(super(StringSliceOperation, self), name) - - @cached_property - def _cached_get_all_var_data(self) -> ImmutableVarData | None: - """Get all VarData associated with the Var. - - Returns: - The VarData of the components and all of its children. - """ - return ImmutableVarData.merge( - self.a._get_all_var_data(), - self.start._get_all_var_data(), - self.end._get_all_var_data(), - self._var_data, - ) - - def _get_all_var_data(self) -> ImmutableVarData | None: - return self._cached_get_all_var_data - - @dataclasses.dataclass( eq=False, frozen=True, @@ -543,9 +363,11 @@ class StringItemOperation(StringVar): a: StringVar = dataclasses.field( default_factory=lambda: LiteralStringVar.create("") ) - i: int = dataclasses.field(default=0) + i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) - def __init__(self, a: StringVar | str, i: int, _var_data: VarData | None = None): + def __init__( + self, a: StringVar | str, i: int | NumberVar, _var_data: VarData | None = None + ): """Initialize the string item operation var. Args: @@ -561,7 +383,7 @@ class StringItemOperation(StringVar): object.__setattr__( self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a) ) - object.__setattr__(self, "i", i) + object.__setattr__(self, "i", i if isinstance(i, Var) else LiteralNumberVar(i)) object.__delattr__(self, "_var_name") @cached_property @@ -593,7 +415,9 @@ class StringItemOperation(StringVar): Returns: The VarData of the components and all of its children. """ - return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data + ) def _get_all_var_data(self) -> ImmutableVarData | None: return self._cached_get_all_var_data @@ -608,7 +432,7 @@ class ArrayJoinOperation(StringVar): ) def __init__( - self, a: ArrayVar | list, b: StringVar | str, _var_data: VarData | None = None + self, a: ArrayVar, b: StringVar | str, _var_data: VarData | None = None ): """Initialize the array join operation var. @@ -622,9 +446,7 @@ class ArrayJoinOperation(StringVar): _var_type=str, _var_data=ImmutableVarData.merge(_var_data), ) - object.__setattr__( - self, "a", a if isinstance(a, Var) else LiteralArrayVar.create(a) - ) + object.__setattr__(self, "a", a) object.__setattr__( self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b) ) @@ -777,6 +599,22 @@ class LiteralStringVar(LiteralVar, StringVar): _var_data=_var_data, ) + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_value)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return json.dumps(self._var_value) + @dataclasses.dataclass( eq=False, @@ -879,6 +717,94 @@ class ArrayVar(ImmutableVar): return ArrayJoinOperation(self, sep) + def reverse(self) -> ArrayReverseOperation: + """Reverse the array. + + Returns: + The reversed array. + """ + return ArrayReverseOperation(self) + + @overload + def __getitem__(self, i: slice) -> ArraySliceOperation: ... + + @overload + def __getitem__(self, i: int | NumberVar) -> ImmutableVar: ... + + def __getitem__( + self, i: slice | int | NumberVar + ) -> ArraySliceOperation | ImmutableVar: + """Get a slice of the array. + + Args: + i: The slice. + + Returns: + The array slice operation. + """ + if isinstance(i, slice): + return ArraySliceOperation(self, i) + return ArrayItemOperation(self, i).guess_type() + + def length(self) -> NumberVar: + """Get the length of the array. + + Returns: + The length of the array. + """ + return ArrayLengthOperation(self) + + @overload + @classmethod + def range(cls, stop: int | NumberVar, /) -> RangeOperation: ... + + @overload + @classmethod + def range( + cls, + start: int | NumberVar, + end: int | NumberVar, + step: int | NumberVar = 1, + /, + ) -> RangeOperation: ... + + @classmethod + def range( + cls, + first_endpoint: int | NumberVar, + second_endpoint: int | NumberVar | None = None, + step: int | NumberVar | None = None, + ) -> RangeOperation: + """Create a range of numbers. + + Args: + first_endpoint: The end of the range if second_endpoint is not provided, otherwise the start of the range. + second_endpoint: The end of the range. + step: The step of the range. + + Returns: + The range of numbers. + """ + if second_endpoint is None: + start = 0 + end = first_endpoint + else: + start = first_endpoint + end = second_endpoint + + return RangeOperation(start, end, step or 1) + + def contains(self, other: Any) -> ArrayContainsOperation: + """Check if the array contains an element. + + Args: + other: The element to check for. + + Returns: + The array contains operation. + """ + return ArrayContainsOperation(self, other) + @dataclasses.dataclass( eq=False, @@ -894,19 +820,25 @@ class LiteralArrayVar(LiteralVar, ArrayVar): def __init__( self, - _var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any], + _var_value: list[Var | Any] | tuple[Var | Any, ...] | set[Var | Any], + _var_type: type[list] | type[tuple] | type[set] | None = None, _var_data: VarData | None = None, ): """Initialize the array var. Args: _var_value: The value of the var. + _var_type: The type of the var. _var_data: Additional hooks and imports associated with the Var. """ super(LiteralArrayVar, self).__init__( _var_name="", _var_data=ImmutableVarData.merge(_var_data), - _var_type=list, + _var_type=( + List[unionize(*map(type, _var_value))] + if _var_type is None + else _var_type + ), ) object.__setattr__(self, "_var_value", _var_value) object.__delattr__(self, "_var_name") @@ -963,6 +895,28 @@ class LiteralArrayVar(LiteralVar, ArrayVar): """ return self._cached_get_all_var_data + def __hash__(self) -> int: + """Get the hash of the var. + + Returns: + The hash of the var. + """ + return hash((self.__class__.__name__, self._var_name)) + + def json(self) -> str: + """Get the JSON representation of the var. + + Returns: + The JSON representation of the var. + """ + return ( + "[" + + ", ".join( + [LiteralVar.create(element).json() for element in self._var_value] + ) + + "]" + ) + @dataclasses.dataclass( eq=False, @@ -991,7 +945,7 @@ class StringSplitOperation(ArrayVar): """ super(StringSplitOperation, self).__init__( _var_name="", - _var_type=list, + _var_type=List[str], _var_data=ImmutableVarData.merge(_var_data), ) object.__setattr__( @@ -1037,3 +991,676 @@ class StringSplitOperation(ArrayVar): def _get_all_var_data(self) -> ImmutableVarData | None: return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayToArrayOperation(ArrayVar): + """Base class for immutable array vars that are the result of an array to array operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + + def __init__(self, a: ArrayVar, _var_data: VarData | None = None): + """Initialize the array to array operation var. + + Args: + a: The string. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayToArrayOperation, self).__init__( + _var_name="", + _var_type=a._var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "ArrayToArrayOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayToArrayOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data() if isinstance(self.a, Var) else None, + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArraySliceOperation(ArrayVar): + """Base class for immutable string vars that are the result of a string slice operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) + + def __init__(self, a: ArrayVar, _slice: slice, _var_data: VarData | None = None): + """Initialize the string slice operation var. + + Args: + a: The string. + _slice: The slice. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArraySliceOperation, self).__init__( + _var_name="", + _var_type=a._var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "_slice", _slice) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + + Raises: + ValueError: If the slice step is zero. + """ + start, end, step = self._slice.start, self._slice.stop, self._slice.step + + normalized_start = ( + LiteralVar.create(start) + if start is not None + else ImmutableVar.create_safe("undefined") + ) + normalized_end = ( + LiteralVar.create(end) + if end is not None + else ImmutableVar.create_safe("undefined") + ) + if step is None: + return ( + f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})" + ) + if not isinstance(step, Var): + if step < 0: + actual_start = end + 1 if end is not None else 0 + actual_end = start + 1 if start is not None else self.a.length() + return str( + ArraySliceOperation( + ArrayReverseOperation( + ArraySliceOperation(self.a, slice(actual_start, actual_end)) + ), + slice(None, None, -step), + ) + ) + if step == 0: + raise ValueError("slice step cannot be zero") + return f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0)" + + actual_start_reverse = end + 1 if end is not None else 0 + actual_end_reverse = start + 1 if start is not None else self.a.length() + + return f"{str(self.step)} > 0 ? {str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0) : {str(self.a)}.slice({str(actual_start_reverse)}, {str(actual_end_reverse)}).reverse().filter((_, i) => i % {str(-step)} === 0)" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArraySliceOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), + *[ + slice_value._get_all_var_data() + for slice_value in ( + self._slice.start, + self._slice.stop, + self._slice.step, + ) + if slice_value is not None and isinstance(slice_value, Var) + ], + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class ArrayReverseOperation(ArrayToArrayOperation): + """Base class for immutable string vars that are the result of a string reverse operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.reverse()" + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayToNumberOperation(NumberVar): + """Base class for immutable number vars that are the result of an array to number operation.""" + + a: ArrayVar = dataclasses.field( + default_factory=lambda: LiteralArrayVar([]), + ) + + def __init__(self, a: ArrayVar, _var_data: VarData | None = None): + """Initialize the string to number operation var. + + Args: + a: The array. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayToNumberOperation, self).__init__( + _var_name="", + _var_type=int, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a)) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Raises: + NotImplementedError: Must be implemented by subclasses. + """ + raise NotImplementedError( + "StringToNumberOperation must implement _cached_var_name" + ) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayToNumberOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +class ArrayLengthOperation(ArrayToNumberOperation): + """Base class for immutable number vars that are the result of an array length operation.""" + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.length" + + +def unionize(*args: Type) -> Type: + """Unionize the types. + + Args: + args: The types to unionize. + + Returns: + The unionized types. + """ + if not args: + return Any + first, *rest = args + if not rest: + return first + return Union[first, unionize(*rest)] + + +def is_tuple_type(t: GenericType) -> bool: + """Check if a type is a tuple type. + + Args: + t: The type to check. + + Returns: + Whether the type is a tuple type. + """ + if inspect.isclass(t): + return issubclass(t, tuple) + return get_origin(t) is tuple + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayItemOperation(ImmutableVar): + """Base class for immutable array vars that are the result of an array item operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + + def __init__( + self, + a: ArrayVar, + i: NumberVar | int, + _var_data: VarData | None = None, + ): + """Initialize the array item operation var. + + Args: + a: The array. + i: The index. + _var_data: Additional hooks and imports associated with the Var. + """ + args = typing.get_args(a._var_type) + if args and isinstance(i, int) and is_tuple_type(a._var_type): + element_type = args[i % len(args)] + else: + element_type = unionize(*args) + super(ArrayItemOperation, self).__init__( + _var_name="", + _var_type=element_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a)) + object.__setattr__( + self, + "i", + i if isinstance(i, Var) else LiteralNumberVar(i), + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.at({str(self.i)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayItemOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class RangeOperation(ArrayVar): + """Base class for immutable array vars that are the result of a range operation.""" + + start: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + end: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0)) + step: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(1)) + + def __init__( + self, + start: NumberVar | int, + end: NumberVar | int, + step: NumberVar | int, + _var_data: VarData | None = None, + ): + """Initialize the range operation var. + + Args: + start: The start of the range. + end: The end of the range. + step: The step of the range. + _var_data: Additional hooks and imports associated with the Var. + """ + super(RangeOperation, self).__init__( + _var_name="", + _var_type=List[int], + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "start", + start if isinstance(start, Var) else LiteralNumberVar(start), + ) + object.__setattr__( + self, + "end", + end if isinstance(end, Var) else LiteralNumberVar(end), + ) + object.__setattr__( + self, + "step", + step if isinstance(step, Var) else LiteralNumberVar(step), + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + start, end, step = self.start, self.end, self.step + return f"Array.from({{ length: ({str(end)} - {str(start)}) / {str(step)} }}, (_, i) => {str(start)} + i * {str(step)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(RangeOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.start._get_all_var_data(), + self.end._get_all_var_data(), + self.step._get_all_var_data(), + self._var_data, + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ArrayContainsOperation(BooleanVar): + """Base class for immutable boolean vars that are the result of an array contains operation.""" + + a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + b: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + + def __init__(self, a: ArrayVar, b: Any | Var, _var_data: VarData | None = None): + """Initialize the array contains operation var. + + Args: + a: The array. + b: The element to check for. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ArrayContainsOperation, self).__init__( + _var_name="", + _var_type=bool, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__(self, "a", a) + object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b)) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return f"{str(self.a)}.includes({str(self.b)})" + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ArrayContainsOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToStringOperation(StringVar): + """Base class for immutable string vars that are the result of a to string operation.""" + + original_var: Var = dataclasses.field( + default_factory=lambda: LiteralStringVar.create("") + ) + + def __init__(self, original_var: Var, _var_data: VarData | None = None): + """Initialize the to string operation var. + + Args: + original_var: The original var. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToStringOperation, self).__init__( + _var_name="", + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "original_var", + original_var, + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self.original_var) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToStringOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.original_var._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ToArrayOperation(ArrayVar): + """Base class for immutable array vars that are the result of a to array operation.""" + + original_var: Var = dataclasses.field(default_factory=lambda: LiteralArrayVar([])) + + def __init__( + self, + original_var: Var, + _var_type: type[list] | type[set] | type[tuple] = list, + _var_data: VarData | None = None, + ): + """Initialize the to array operation var. + + Args: + original_var: The original var. + _var_type: The type of the array. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ToArrayOperation, self).__init__( + _var_name="", + _var_type=_var_type, + _var_data=ImmutableVarData.merge(_var_data), + ) + object.__setattr__( + self, + "original_var", + original_var, + ) + object.__delattr__(self, "_var_name") + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return str(self.original_var) + + def __getattr__(self, name: str) -> Any: + """Get an attribute of the var. + + Args: + name: The name of the attribute. + + Returns: + The attribute value. + """ + if name == "_var_name": + return self._cached_var_name + getattr(super(ToArrayOperation, self), name) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + self.original_var._get_all_var_data(), self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._cached_get_all_var_data diff --git a/reflex/vars.py b/reflex/vars.py index 00f02804c..c5a66a20d 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -1997,6 +1997,14 @@ class Var: """ return self._var_data + def json(self) -> str: + """Serialize the var to a JSON string. + + Raises: + NotImplementedError: If the method is not implemented. + """ + raise NotImplementedError("Var subclasses must implement the json method.") + @property def _var_name_unwrapped(self) -> str: """Get the var str without wrapping in curly braces. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 4aa6afc33..77a878086 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -151,6 +151,7 @@ class Var: def _var_full_name(self) -> str: ... def _var_set_state(self, state: Type[BaseState] | str) -> Any: ... def _get_all_var_data(self) -> VarData: ... + def json(self) -> str: ... @dataclass(eq=False) class BaseVar(Var): diff --git a/tests/test_var.py b/tests/test_var.py index 761375464..66599db72 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -19,7 +19,13 @@ from reflex.experimental.vars.number import ( LiteralNumberVar, NumberVar, ) -from reflex.experimental.vars.sequence import ConcatVarOperation, LiteralStringVar +from reflex.experimental.vars.object import LiteralObjectVar +from reflex.experimental.vars.sequence import ( + ArrayVar, + ConcatVarOperation, + LiteralArrayVar, + LiteralStringVar, +) from reflex.state import BaseState from reflex.utils.imports import ImportVar from reflex.vars import ( @@ -881,7 +887,7 @@ def test_literal_var(): ) assert ( str(complicated_var) - == '[{ ["a"] : 1, ["b"] : 2, ["c"] : { ["d"] : 3, ["e"] : 4 } }, [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]' + == '[({ ["a"] : 1, ["b"] : 2, ["c"] : ({ ["d"] : 3, ["e"] : 4 }) }), [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]' ) @@ -898,7 +904,7 @@ def test_function_var(): ) assert ( str(manual_addition_func.call(1, 2)) - == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))' + == '(((a, b) => (({ ["args"] : [a, b], ["result"] : a + b })))(1, 2))' ) increment_func = addition_func(1) @@ -935,7 +941,7 @@ def test_var_operation(): def test_string_operations(): basic_string = LiteralStringVar.create("Hello, World!") - assert str(basic_string.length()) == '"Hello, World!".length' + assert str(basic_string.length()) == '"Hello, World!".split("").length' assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()' assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()' assert str(basic_string.strip()) == '"Hello, World!".trim()' @@ -972,6 +978,89 @@ def test_all_number_operations(): ) +def test_index_operation(): + array_var = LiteralArrayVar([1, 2, 3, 4, 5]) + assert str(array_var[0]) == "[1, 2, 3, 4, 5].at(0)" + assert str(array_var[1:2]) == "[1, 2, 3, 4, 5].slice(1, 2)" + assert ( + str(array_var[1:4:2]) + == "[1, 2, 3, 4, 5].slice(1, 4).filter((_, i) => i % 2 === 0)" + ) + assert ( + str(array_var[::-1]) + == "[1, 2, 3, 4, 5].slice(0, [1, 2, 3, 4, 5].length).reverse().slice(undefined, undefined).filter((_, i) => i % 1 === 0)" + ) + assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].reverse()" + assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)" + + +def test_array_operations(): + array_var = LiteralArrayVar.create([1, 2, 3, 4, 5]) + + assert str(array_var.length()) == "[1, 2, 3, 4, 5].length" + assert str(array_var.contains(3)) == "[1, 2, 3, 4, 5].includes(3)" + assert str(array_var.reverse()) == "[1, 2, 3, 4, 5].reverse()" + assert ( + str(ArrayVar.range(10)) + == "Array.from({ length: (10 - 0) / 1 }, (_, i) => 0 + i * 1)" + ) + assert ( + str(ArrayVar.range(1, 10)) + == "Array.from({ length: (10 - 1) / 1 }, (_, i) => 1 + i * 1)" + ) + assert ( + str(ArrayVar.range(1, 10, 2)) + == "Array.from({ length: (10 - 1) / 2 }, (_, i) => 1 + i * 2)" + ) + assert ( + str(ArrayVar.range(1, 10, -1)) + == "Array.from({ length: (10 - 1) / -1 }, (_, i) => 1 + i * -1)" + ) + + +def test_object_operations(): + object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3}) + + assert ( + str(object_var.keys()) == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + ) + assert ( + str(object_var.values()) + == 'Object.values(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + ) + assert ( + str(object_var.entries()) + == 'Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }))' + ) + assert str(object_var.a) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]' + assert str(object_var["a"]) == '({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["a"]' + assert ( + str(object_var.merge(LiteralObjectVar({"c": 4, "d": 5}))) + == 'Object.assign(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 }), ({ ["c"] : 4, ["d"] : 5 }))' + ) + + +def test_type_chains(): + object_var = LiteralObjectVar({"a": 1, "b": 2, "c": 3}) + assert object_var._var_type is Dict[str, int] + assert (object_var.keys()._var_type, object_var.values()._var_type) == ( + List[str], + List[int], + ) + assert ( + str(object_var.keys()[0].upper()) # type: ignore + == 'Object.keys(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(0).toUpperCase()' + ) + assert ( + str(object_var.entries()[1][1] - 1) # type: ignore + == '(Object.entries(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })).at(1).at(1) - 1)' + ) + assert ( + str(object_var["c"] + object_var["b"]) # type: ignore + == '(({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["c"] + ({ ["a"] : 1, ["b"] : 2, ["c"] : 3 })["b"])' + ) + + def test_retrival(): var_without_data = ImmutableVar.create("test") assert var_without_data is not None