reflex/reflex/ivars/sequence.py
2024-08-05 15:20:21 -07:00

1939 lines
57 KiB
Python

"""Collection of string classes and utilities."""
from __future__ import annotations
import dataclasses
import functools
import inspect
import json
import re
import sys
import typing
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Set,
Tuple,
TypeVar,
Union,
overload,
)
from typing_extensions import get_origin
from reflex import constants
from reflex.constants.base import REFLEX_VAR_OPENING_TAG
from reflex.utils.types import GenericType
from reflex.vars import ImmutableVarData, Var, VarData, _global_vars
from .base import (
ImmutableVar,
LiteralVar,
figure_out_type,
unionize,
)
from .number import (
BooleanVar,
LiteralNumberVar,
NotEqualOperation,
NumberVar,
)
if TYPE_CHECKING:
from .object import ObjectVar
class StringVar(ImmutableVar[str]):
    """Immutable var whose value is a string."""

    def __add__(self, other: StringVar | str) -> ConcatVarOperation:
        """Build the concatenation ``self + other``.

        Args:
            other: The string placed after this one.

        Returns:
            The string concatenation operation.
        """
        left, right = self, other
        return ConcatVarOperation(left, right)

    def __radd__(self, other: StringVar | str) -> ConcatVarOperation:
        """Build the concatenation ``other + self``.

        Args:
            other: The string placed before this one.

        Returns:
            The string concatenation operation.
        """
        left, right = other, self
        return ConcatVarOperation(left, right)

    def __mul__(self, other: NumberVar | int) -> StringVar:
        """Repeat the string.

        The string is split with the default separator, the resulting array
        is multiplied by ``other`` and joined back together.

        Args:
            other: The number of repetitions.

        Returns:
            The repeated string.
        """
        characters = self.split()
        repeated = characters * other
        return repeated.join()

    def __rmul__(self, other: NumberVar | int) -> StringVar:
        """Repeat the string when it is the right-hand operand.

        Args:
            other: The number of repetitions.

        Returns:
            The repeated string.
        """
        characters = self.split()
        repeated = characters * other
        return repeated.join()

    @overload
    def __getitem__(self, i: slice) -> ArrayJoinOperation: ...

    @overload
    def __getitem__(self, i: int | NumberVar) -> StringItemOperation: ...

    def __getitem__(
        self, i: slice | int | NumberVar
    ) -> ArrayJoinOperation | StringItemOperation:
        """Index or slice the string.

        Args:
            i: A slice, or the index of a single item.

        Returns:
            A join of the sliced split array for a slice, otherwise a string
            item operation.
        """
        if not isinstance(i, slice):
            return StringItemOperation(self, i)
        characters = self.split()
        return characters[i].join()

    def length(self) -> NumberVar:
        """Get the length of the string.

        Returns:
            The length of the array obtained by splitting the string.
        """
        characters = self.split()
        return characters.length()

    def lower(self) -> StringLowerOperation:
        """Convert the string to lowercase.

        Returns:
            The string lower operation.
        """
        return StringLowerOperation(self)

    def upper(self) -> StringUpperOperation:
        """Convert the string to uppercase.

        Returns:
            The string upper operation.
        """
        return StringUpperOperation(self)

    def strip(self) -> StringStripOperation:
        """Strip the string.

        Returns:
            The string strip operation.
        """
        return StringStripOperation(self)

    def bool(self) -> NotEqualOperation:
        """Convert the string to a boolean.

        Returns:
            An operation comparing the string's length against 0 for
            inequality.
        """
        size = self.length()
        return NotEqualOperation(size, 0)

    def reversed(self) -> ArrayJoinOperation:
        """Reverse the string.

        Returns:
            The join of the reversed split array.
        """
        characters = self.split()
        return characters.reverse().join()

    def contains(self, other: StringVar | str) -> StringContainsOperation:
        """Check whether the string contains another string.

        Args:
            other: The string to look for.

        Returns:
            The string contains operation.
        """
        haystack, needle = self, other
        return StringContainsOperation(haystack, needle)

    def split(self, separator: StringVar | str = "") -> StringSplitOperation:
        """Split the string into an array.

        Args:
            separator: The separator to split on.

        Returns:
            The string split operation.
        """
        return StringSplitOperation(self, separator)
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class StringToStringOperation(StringVar):
    """Base class for immutable string vars that are the result of a string to string operation.

    Subclasses provide the rendered expression through ``_cached_var_name``.
    """

    # The operand string var.
    a: StringVar = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )

    def __init__(self, a: StringVar | str, _var_data: VarData | None = None):
        """Initialize the string to string operation var.

        Args:
            a: The string.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(StringToStringOperation, self).__init__(
            _var_name="",
            _var_type=str,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # Plain strings are wrapped so that ``a`` is always a Var.
        object.__setattr__(
            self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a)
        )
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError(
            "StringToStringOperation must implement _cached_var_name"
        )

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(StringToStringOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data() if isinstance(self.a, Var) else None,
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
class StringLowerOperation(StringToStringOperation):
    """String var holding the lowercase form of another string var."""

    @cached_property
    def _cached_var_name(self) -> str:
        """Build the expression for this var.

        Returns:
            The operand's expression followed by ``.toLowerCase()``.
        """
        operand = str(self.a)
        return operand + ".toLowerCase()"
class StringUpperOperation(StringToStringOperation):
    """String var holding the uppercase form of another string var."""

    @cached_property
    def _cached_var_name(self) -> str:
        """Build the expression for this var.

        Returns:
            The operand's expression followed by ``.toUpperCase()``.
        """
        operand = str(self.a)
        return operand + ".toUpperCase()"
class StringStripOperation(StringToStringOperation):
    """String var holding the stripped form of another string var."""

    @cached_property
    def _cached_var_name(self) -> str:
        """Build the expression for this var.

        Returns:
            The operand's expression followed by ``.trim()``.
        """
        operand = str(self.a)
        return operand + ".trim()"
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class StringContainsOperation(BooleanVar):
    """Base class for immutable boolean vars that are the result of a string contains operation."""

    # The string that is searched.
    a: StringVar = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )
    # The string that is searched for.
    b: StringVar = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )

    def __init__(
        self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None
    ):
        """Initialize the string contains operation var.

        Args:
            a: The first string.
            b: The second string.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(StringContainsOperation, self).__init__(
            _var_name="",
            _var_type=bool,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # Plain strings are wrapped so that both operands are always Vars.
        object.__setattr__(
            self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a)
        )
        object.__setattr__(
            self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b)
        )
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        return f"{str(self.a)}.includes({str(self.b)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(StringContainsOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class StringItemOperation(StringVar):
    """Base class for immutable string vars that are the result of a string item operation."""

    # The string that is indexed.
    a: StringVar = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )
    # The index of the item.
    i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0))

    def __init__(
        self, a: StringVar | str, i: int | NumberVar, _var_data: VarData | None = None
    ):
        """Initialize the string item operation var.

        Args:
            a: The string.
            i: The index.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(StringItemOperation, self).__init__(
            _var_name="",
            _var_type=str,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # Plain Python values are wrapped so that both operands are Vars.
        object.__setattr__(
            self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a)
        )
        object.__setattr__(self, "i", i if isinstance(i, Var) else LiteralNumberVar(i))
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        return f"{str(self.a)}.at({str(self.i)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(StringItemOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
# NOTE(review): unlike the sibling operation classes in this file, this class
# carries no ``@dataclasses.dataclass(...)`` decorator even though it declares
# ``dataclasses.field`` attributes -- confirm whether that is intentional.
class ArrayJoinOperation(StringVar):
    """Base class for immutable string vars that are the result of an array join operation."""

    # The array whose elements are joined.
    a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
    # The separator placed between elements.
    b: StringVar = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )

    def __init__(
        self, a: ArrayVar, b: StringVar | str, _var_data: VarData | None = None
    ):
        """Initialize the array join operation var.

        Args:
            a: The array.
            b: The separator.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ArrayJoinOperation, self).__init__(
            _var_name="",
            _var_type=str,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "a", a)
        # A plain string separator is wrapped so that ``b`` is always a Var.
        object.__setattr__(
            self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b)
        )
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        return f"{str(self.a)}.join({str(self.b)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(ArrayJoinOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
# Regex locating encoded reflex var tags; group 1 lazily captures whatever
# sits between the opening and the closing tag.
_decode_var_pattern_re = "{}(.*?){}".format(
    constants.REFLEX_VAR_OPENING_TAG, constants.REFLEX_VAR_CLOSING_TAG
)
# DOTALL lets the captured payload span newlines.
_decode_var_pattern = re.compile(_decode_var_pattern_re, re.DOTALL)
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralStringVar(LiteralVar, StringVar):
    """Base class for immutable literal string vars."""

    # The Python string this var represents.
    _var_value: str = dataclasses.field(default="")

    def __init__(
        self,
        _var_value: str,
        _var_data: VarData | None = None,
    ):
        """Initialize the string var.

        Args:
            _var_value: The value of the var.
            _var_data: Additional hooks and imports associated with the Var.
        """
        # NOTE(review): the value is wrapped in double quotes without any
        # escaping, whereas ``json()`` uses ``json.dumps`` -- confirm whether
        # values containing quotes or newlines are expected here.
        super(LiteralStringVar, self).__init__(
            _var_name=f'"{_var_value}"',
            _var_type=str,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "_var_value", _var_value)

    @classmethod
    def create(
        cls,
        value: str,
        _var_data: VarData | None = None,
    ) -> LiteralStringVar | ConcatVarOperation:
        """Create a var from a string value.

        A value without any encoded var tag becomes a plain literal. Otherwise
        the value is cut into literal pieces and vars, which are combined into
        a concatenation.

        Args:
            value: The value to create the var from.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The var.
        """
        if REFLEX_VAR_OPENING_TAG not in value:
            return LiteralStringVar(
                value,
                _var_data=_var_data,
            )

        strings_and_vals: list[Var | str] = []
        offset = 0

        # Initialize some methods for reading json.
        var_data_config = VarData().__config__

        def json_loads(s):
            try:
                return var_data_config.json_loads(s)
            except json.decoder.JSONDecodeError:
                # Retry treating the payload as the body of a JSON string
                # that itself holds JSON.
                return var_data_config.json_loads(
                    var_data_config.json_loads(f'"{s}"')
                )

        # Find all tags; ``value`` is consumed from the left on every pass.
        while m := _decode_var_pattern.search(value):
            start, end = m.span()
            if start > 0:
                strings_and_vals.append(value[:start])

            serialized_data = m.group(1)

            if serialized_data.isnumeric() or (
                serialized_data.startswith("-") and serialized_data[1:].isnumeric()
            ):
                # This is a global immutable var.
                var = _global_vars[int(serialized_data)]
                strings_and_vals.append(var)
                value = value[(end + len(var._var_name)) :]
            else:
                data = json_loads(serialized_data)
                string_length = data.pop("string_length", None)
                var_data = VarData.parse_obj(data)

                if string_length is None:
                    # Without a length the tagged content cannot be located.
                    # Drop the tag itself so the search always advances; the
                    # previous code left ``value`` untouched here and would
                    # match the same tag forever.
                    value = value[end:]
                else:
                    # Use string length to compute positions of interpolations.
                    realstart = start + offset
                    var_data.interpolations = [(realstart, realstart + string_length)]
                    var_content = value[end : (end + string_length)]
                    previous = strings_and_vals[-1] if strings_and_vals else None
                    # Content wrapped in braces and preceded by a literal "$"
                    # has the "$" dropped and the braces replaced by
                    # parentheses. Only a literal piece can carry that "$";
                    # a preceding Var must be left untouched.
                    if (
                        var_content.startswith("{")
                        and var_content.endswith("}")
                        and isinstance(previous, str)
                        and previous.endswith("$")
                    ):
                        strings_and_vals[-1] = previous[:-1]
                        var_content = "(" + var_content[1:-1] + ")"
                    strings_and_vals.append(
                        ImmutableVar.create_safe(var_content, _var_data=var_data)
                    )
                    value = value[(end + string_length) :]

            offset += end - start

        # Whatever follows the last tag is a trailing literal piece.
        if value:
            strings_and_vals.append(value)

        return ConcatVarOperation(*strings_and_vals, _var_data=_var_data)

    def __hash__(self) -> int:
        """Get the hash of the var.

        Returns:
            The hash of the var.
        """
        return hash((self.__class__.__name__, self._var_value))

    def json(self) -> str:
        """Get the JSON representation of the var.

        Returns:
            The JSON representation of the var.
        """
        return json.dumps(self._var_value)
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ConcatVarOperation(StringVar):
    """Representing a concatenation of literal string vars."""

    # The pieces to concatenate, in order; each is a Var or a plain string.
    _var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple)

    def __init__(self, *value: Var | str, _var_data: VarData | None = None):
        """Initialize the operation of concatenating literal string vars.

        Args:
            value: The values to concatenate.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ConcatVarOperation, self).__init__(
            _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
        )
        object.__setattr__(self, "_var_value", value)
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute of the var.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Name the class explicitly: ``super(type(self), self)`` would resolve
        # back to this very method on a subclass instance and recurse forever.
        return super(ConcatVarOperation, self).__getattr__(name)

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        # Vars contribute their own expression; plain strings are wrapped in
        # double quotes. The pieces are joined with "+" inside parentheses.
        return (
            "("
            + "+".join(
                [
                    str(element) if isinstance(element, Var) else f'"{element}"'
                    for element in self._var_value
                ]
            )
            + ")"
        )

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            *[
                var._get_all_var_data()
                for var in self._var_value
                if isinstance(var, Var)
            ],
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data

    def __post_init__(self):
        """Post-initialize the var."""
        pass
# Type of the Python sequence (list, tuple or set) that an ArrayVar wraps.
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set])
# Stands for the other slot of a two-element tuple in the ArrayVar
# ``__getitem__`` overloads.
OTHER_TUPLE = TypeVar("OTHER_TUPLE")
# Element type of a nested sequence in the ArrayVar ``__getitem__`` overloads.
INNER_ARRAY_VAR = TypeVar("INNER_ARRAY_VAR")
# Key and value types of a dict element in the ArrayVar ``__getitem__``
# overloads.
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]):
    """Base class for immutable array vars."""

    def join(self, sep: StringVar | str = "") -> ArrayJoinOperation:
        """Join the elements of the array.

        Args:
            sep: The separator between elements.

        Returns:
            The joined elements.
        """
        return ArrayJoinOperation(self, sep)

    def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]:
        """Reverse the array.

        Returns:
            The reversed array.
        """
        return ArrayReverseOperation(self)

    def __add__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> ArrayConcatOperation:
        """Concatenate two arrays.

        Args:
            other: The other array to concatenate.

        Returns:
            The concatenation of the two arrays.
        """
        return ArrayConcatOperation(self, other)

    # The overloads below only refine the static return type of indexing
    # based on the declared element types; the runtime behaviour lives in
    # the final, non-overloaded ``__getitem__``.
    @overload
    def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ...

    @overload
    def __getitem__(
        self: (
            ArrayVar[Tuple[int, OTHER_TUPLE]]
            | ArrayVar[Tuple[float, OTHER_TUPLE]]
            | ArrayVar[Tuple[int | float, OTHER_TUPLE]]
        ),
        i: Literal[0, -2],
    ) -> NumberVar: ...

    @overload
    def __getitem__(
        self: (
            ArrayVar[Tuple[OTHER_TUPLE, int]]
            | ArrayVar[Tuple[OTHER_TUPLE, float]]
            | ArrayVar[Tuple[OTHER_TUPLE, int | float]]
        ),
        i: Literal[1, -1],
    ) -> NumberVar: ...

    @overload
    def __getitem__(
        self: ArrayVar[Tuple[str, OTHER_TUPLE]], i: Literal[0, -2]
    ) -> StringVar: ...

    @overload
    def __getitem__(
        self: ArrayVar[Tuple[OTHER_TUPLE, str]], i: Literal[1, -1]
    ) -> StringVar: ...

    @overload
    def __getitem__(
        self: ArrayVar[Tuple[bool, OTHER_TUPLE]], i: Literal[0, -2]
    ) -> BooleanVar: ...

    @overload
    def __getitem__(
        self: ArrayVar[Tuple[OTHER_TUPLE, bool]], i: Literal[1, -1]
    ) -> BooleanVar: ...

    @overload
    def __getitem__(
        self: (
            ARRAY_VAR_OF_LIST_ELEMENT[int]
            | ARRAY_VAR_OF_LIST_ELEMENT[float]
            | ARRAY_VAR_OF_LIST_ELEMENT[int | float]
        ),
        i: int | NumberVar,
    ) -> NumberVar: ...

    @overload
    def __getitem__(
        self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
    ) -> StringVar: ...

    @overload
    def __getitem__(
        self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
    ) -> BooleanVar: ...

    @overload
    def __getitem__(
        self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
        i: int | NumberVar,
    ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...

    @overload
    def __getitem__(
        self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
        i: int | NumberVar,
    ) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...

    @overload
    def __getitem__(
        self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[INNER_ARRAY_VAR, ...]],
        i: int | NumberVar,
    ) -> ArrayVar[Tuple[INNER_ARRAY_VAR, ...]]: ...

    @overload
    def __getitem__(
        self: ARRAY_VAR_OF_LIST_ELEMENT[Dict[KEY_TYPE, VALUE_TYPE]],
        i: int | NumberVar,
    ) -> ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]: ...

    @overload
    def __getitem__(self, i: int | NumberVar) -> ImmutableVar: ...

    def __getitem__(
        self, i: slice | int | NumberVar
    ) -> ArrayVar[ARRAY_VAR_TYPE] | ImmutableVar:
        """Get a slice or a single item of the array.

        Args:
            i: The slice, or the index of the item.

        Returns:
            The array slice operation for a slice, otherwise the array item
            operation after ``guess_type()``.
        """
        if isinstance(i, slice):
            return ArraySliceOperation(self, i)
        return ArrayItemOperation(self, i).guess_type()

    def length(self) -> NumberVar:
        """Get the length of the array.

        Returns:
            The length of the array.
        """
        return ArrayLengthOperation(self)

    @overload
    @classmethod
    def range(cls, stop: int | NumberVar, /) -> ArrayVar[List[int]]: ...

    @overload
    @classmethod
    def range(
        cls,
        start: int | NumberVar,
        end: int | NumberVar,
        step: int | NumberVar = 1,
        /,
    ) -> ArrayVar[List[int]]: ...

    @classmethod
    def range(
        cls,
        first_endpoint: int | NumberVar,
        second_endpoint: int | NumberVar | None = None,
        step: int | NumberVar | None = None,
    ) -> ArrayVar[List[int]]:
        """Create a range of numbers.

        Args:
            first_endpoint: The end of the range if second_endpoint is not provided, otherwise the start of the range.
            second_endpoint: The end of the range.
            step: The step of the range. Defaults to 1 when omitted.

        Returns:
            The range of numbers.
        """
        if second_endpoint is None:
            start = 0
            end = first_endpoint
        else:
            start = first_endpoint
            end = second_endpoint

        # Compare against None instead of relying on truthiness: ``step or 1``
        # would evaluate the truth value of a NumberVar and would silently
        # replace an explicit 0 with 1.
        if step is None:
            step = 1

        return RangeOperation(start, end, step)

    def contains(self, other: Any) -> BooleanVar:
        """Check if the array contains an element.

        Args:
            other: The element to check for.

        Returns:
            The array contains operation.
        """
        return ArrayContainsOperation(self, other)

    def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]:
        """Multiply the sequence by a number or integer.

        Args:
            other: The number or integer to multiply the sequence by.

        Returns:
            The result of multiplying the sequence by the given number or integer.
        """
        return ArrayRepeatOperation(self, other)

    def __rmul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]:
        """Multiply the sequence by a number or integer.

        Args:
            other: The number or integer to multiply the sequence by.

        Returns:
            The result of multiplying the sequence by the given number or integer.
        """
        return ArrayRepeatOperation(self, other)
# Element type shared by the three sequence flavours below.
LIST_ELEMENT = TypeVar("LIST_ELEMENT")
# An ArrayVar over a list, a set or a variable-length tuple of LIST_ELEMENT;
# used as the ``self`` type in the ArrayVar ``__getitem__`` overloads.
ARRAY_VAR_OF_LIST_ELEMENT = Union[
    ArrayVar[List[LIST_ELEMENT]],
    ArrayVar[Set[LIST_ELEMENT]],
    ArrayVar[Tuple[LIST_ELEMENT, ...]],
]
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralArrayVar(LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
    """Base class for immutable literal array vars."""

    # The Python sequence this var represents; elements may themselves be Vars.
    _var_value: Union[
        List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...]
    ] = dataclasses.field(default_factory=list)

    def __init__(
        self: LiteralArrayVar[ARRAY_VAR_TYPE],
        _var_value: ARRAY_VAR_TYPE,
        _var_type: type[ARRAY_VAR_TYPE] | None = None,
        _var_data: VarData | None = None,
    ):
        """Initialize the array var.

        Args:
            _var_value: The value of the var.
            _var_type: The type of the var. Derived from the value with
                ``figure_out_type`` when omitted.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(LiteralArrayVar, self).__init__(
            _var_name="",
            _var_data=ImmutableVarData.merge(_var_data),
            _var_type=(figure_out_type(_var_value) if _var_type is None else _var_type),
        )
        object.__setattr__(self, "_var_value", _var_value)
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute of the var.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Name the class explicitly: ``super(type(self), self)`` would resolve
        # back to this very method on a subclass instance and recurse forever.
        return super(LiteralArrayVar, self).__getattr__(name)

    @functools.cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        # Every element is rendered through ``LiteralVar.create``; the result
        # is a bracketed, comma-separated list.
        return (
            "["
            + ", ".join(
                [str(LiteralVar.create(element)) for element in self._var_value]
            )
            + "]"
        )

    @functools.cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            *[
                var._get_all_var_data()
                for var in self._var_value
                if isinstance(var, Var)
            ],
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data

    def __hash__(self) -> int:
        """Get the hash of the var.

        Returns:
            The hash of the var.
        """
        return hash((self.__class__.__name__, self._var_name))

    def json(self) -> str:
        """Get the JSON representation of the var.

        Returns:
            The JSON representation of the var.
        """
        return (
            "["
            + ", ".join(
                [LiteralVar.create(element).json() for element in self._var_value]
            )
            + "]"
        )
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class StringSplitOperation(ArrayVar):
    """Base class for immutable array vars that are the result of a string split operation."""

    # The string that is split.
    a: StringVar = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )
    # The separator to split on.
    b: StringVar = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )

    def __init__(
        self, a: StringVar | str, b: StringVar | str, _var_data: VarData | None = None
    ):
        """Initialize the string split operation var.

        Args:
            a: The string.
            b: The separator.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(StringSplitOperation, self).__init__(
            _var_name="",
            _var_type=List[str],
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # Plain strings are wrapped so that both operands are always Vars.
        object.__setattr__(
            self, "a", a if isinstance(a, Var) else LiteralStringVar.create(a)
        )
        object.__setattr__(
            self, "b", b if isinstance(b, Var) else LiteralStringVar.create(b)
        )
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        return f"{str(self.a)}.split({str(self.b)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(StringSplitOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArrayToArrayOperation(ArrayVar):
    """Base class for immutable array vars that are the result of an array to array operation.

    Subclasses provide the rendered expression through ``_cached_var_name``.
    """

    # The operand array var.
    a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))

    def __init__(self, a: ArrayVar, _var_data: VarData | None = None):
        """Initialize the array to array operation var.

        Args:
            a: The array.
            _var_data: Additional hooks and imports associated with the Var.
        """
        # The result keeps the var type of the operand array.
        super(ArrayToArrayOperation, self).__init__(
            _var_name="",
            _var_type=a._var_type,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "a", a)
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError(
            "ArrayToArrayOperation must implement _cached_var_name"
        )

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(ArrayToArrayOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data() if isinstance(self.a, Var) else None,
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArraySliceOperation(ArrayVar):
    """Base class for immutable array vars that are the result of an array slice operation."""

    # The array that is sliced.
    a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
    # The slice; start, stop and step may each be None, a plain value or a Var.
    _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None))

    def __init__(self, a: ArrayVar, _slice: slice, _var_data: VarData | None = None):
        """Initialize the array slice operation var.

        Args:
            a: The array.
            _slice: The slice.
            _var_data: Additional hooks and imports associated with the Var.
        """
        # The result keeps the var type of the sliced array.
        super(ArraySliceOperation, self).__init__(
            _var_name="",
            _var_type=a._var_type,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "a", a)
        object.__setattr__(self, "_slice", _slice)
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.

        Raises:
            ValueError: If the slice step is zero.
        """
        start, end, step = self._slice.start, self._slice.stop, self._slice.step

        # A missing bound is rendered as ``undefined``.
        normalized_start = (
            LiteralVar.create(start)
            if start is not None
            else ImmutableVar.create_safe("undefined")
        )
        normalized_end = (
            LiteralVar.create(end)
            if end is not None
            else ImmutableVar.create_safe("undefined")
        )

        if step is None:
            return (
                f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})"
            )

        if not isinstance(step, Var):
            if step < 0:
                # A negative literal step is expressed as: take the mirrored
                # forward range, reverse it, then apply the positive step.
                actual_start = end + 1 if end is not None else 0
                actual_end = start + 1 if start is not None else self.a.length()
                return str(
                    ArraySliceOperation(
                        ArrayReverseOperation(
                            ArraySliceOperation(self.a, slice(actual_start, actual_end))
                        ),
                        slice(None, None, -step),
                    )
                )
            if step == 0:
                raise ValueError("slice step cannot be zero")
            return (
                f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})"
                f".filter((_, i) => i % {str(step)} === 0)"
            )

        # The step is a Var, so its sign is only known at runtime: emit a
        # conditional expression covering both directions.
        actual_start_reverse = end + 1 if end is not None else 0
        actual_end_reverse = start + 1 if start is not None else self.a.length()

        # The condition must render the local ``step``; this class has no
        # ``step`` attribute, so ``self.step`` never produced the step value.
        return (
            f"{str(step)} > 0 ? "
            f"{str(self.a)}.slice({str(normalized_start)}, {str(normalized_end)})"
            f".filter((_, i) => i % {str(step)} === 0) : "
            f"{str(self.a)}.slice({str(actual_start_reverse)}, {str(actual_end_reverse)})"
            f".reverse().filter((_, i) => i % {str(-step)} === 0)"
        )

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(ArraySliceOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(),
            *[
                slice_value._get_all_var_data()
                for slice_value in (
                    self._slice.start,
                    self._slice.stop,
                    self._slice.step,
                )
                if slice_value is not None and isinstance(slice_value, Var)
            ],
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
class ArrayReverseOperation(ArrayToArrayOperation):
    """Array var holding the reversed form of another array var."""

    @cached_property
    def _cached_var_name(self) -> str:
        """Build the expression for this var.

        Returns:
            The operand's expression followed by ``.slice().reverse()``.
        """
        operand = str(self.a)
        return operand + ".slice().reverse()"
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArrayToNumberOperation(NumberVar):
    """Base class for immutable number vars that are the result of an array to number operation.

    Subclasses provide the rendered expression through ``_cached_var_name``.
    """

    # The operand array var.
    a: ArrayVar = dataclasses.field(
        default_factory=lambda: LiteralArrayVar([]),
    )

    def __init__(self, a: ArrayVar, _var_data: VarData | None = None):
        """Initialize the array to number operation var.

        Args:
            a: The array.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ArrayToNumberOperation, self).__init__(
            _var_name="",
            _var_type=int,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # A plain Python sequence is wrapped so that ``a`` is always a Var.
        object.__setattr__(self, "a", a if isinstance(a, Var) else LiteralArrayVar(a))
        # Remove the placeholder name so lookups of ``_var_name`` reach
        # ``__getattr__`` and are served from ``_cached_var_name``.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError(
            "ArrayToNumberOperation must implement _cached_var_name"
        )

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # The result must be returned; otherwise every unknown attribute
        # silently evaluates to None instead of resolving or raising.
        return getattr(super(ArrayToNumberOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(self.a._get_all_var_data(), self._var_data)

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Wrapper method for cached property.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
class ArrayLengthOperation(ArrayToNumberOperation):
    """Immutable number var that is the result of taking the length of an array var."""

    @cached_property
    def _cached_var_name(self) -> str:
        """Build the JavaScript expression for the array length.

        Returns:
            The name of the var.
        """
        array_expr = str(self.a)
        return array_expr + ".length"
def is_tuple_type(t: GenericType) -> bool:
    """Check if a type is a tuple type.

    Args:
        t: The type to check.

    Returns:
        Whether the type is a tuple type.
    """
    # Resolve parameterized generics first. On Python 3.9/3.10 a builtin
    # alias such as ``tuple[int, str]`` passes ``inspect.isclass`` but makes
    # ``issubclass`` raise TypeError, so it must not reach the class branch.
    origin = get_origin(t)
    if origin is not None:
        return origin is tuple
    return inspect.isclass(t) and issubclass(t, tuple)
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArrayItemOperation(ImmutableVar):
    """Base class for immutable array vars that are the result of an array item operation."""

    a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
    i: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0))

    def __init__(
        self,
        a: ArrayVar,
        i: NumberVar | int,
        _var_data: VarData | None = None,
    ):
        """Initialize the array item operation var.

        Args:
            a: The array.
            i: The index.
            _var_data: Additional hooks and imports associated with the Var.
        """
        # Normalize the array before reading `_var_type`, so a non-Var input
        # is wrapped instead of failing on the attribute access.
        array = a if isinstance(a, Var) else LiteralArrayVar(a)
        args = typing.get_args(array._var_type)
        if (
            len(args) == 2
            and args[1] is Ellipsis
            and is_tuple_type(array._var_type)
        ):
            # Tuple[T, ...] is a homogeneous tuple: every item is a T and the
            # Ellipsis is only a marker, not an element type.
            element_type = args[0]
        elif args and isinstance(i, int) and is_tuple_type(array._var_type):
            # Fixed-shape tuple with a literal index: pick that position's type.
            element_type = args[i % len(args)]
        else:
            element_type = unionize(*args)
        super(ArrayItemOperation, self).__init__(
            _var_name="",
            _var_type=element_type,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # The dataclass is frozen, so fields are written via object.__setattr__.
        object.__setattr__(self, "a", array)
        object.__setattr__(
            self,
            "i",
            i if isinstance(i, Var) else LiteralNumberVar(i),
        )
        # Remove the placeholder name so that `_var_name` lookups fall through
        # to __getattr__, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        return f"{str(self.a)}.at({str(self.i)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Only invoked when normal attribute lookup fails.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Return the parent's value; without `return` the result was discarded
        # and the method fell through to an implicit None.
        return getattr(super(ArrayItemOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.i._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The cached merged VarData.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class RangeOperation(ArrayVar):
    """Base class for immutable array vars that are the result of a range operation."""

    start: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0))
    end: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0))
    step: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(1))

    def __init__(
        self,
        start: NumberVar | int,
        end: NumberVar | int,
        step: NumberVar | int,
        _var_data: VarData | None = None,
    ):
        """Initialize the range operation var.

        Args:
            start: The start of the range.
            end: The end of the range.
            step: The step of the range.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(RangeOperation, self).__init__(
            _var_name="",
            _var_type=List[int],
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # The dataclass is frozen, so fields are written via object.__setattr__.
        object.__setattr__(
            self,
            "start",
            start if isinstance(start, Var) else LiteralNumberVar(start),
        )
        object.__setattr__(
            self,
            "end",
            end if isinstance(end, Var) else LiteralNumberVar(end),
        )
        object.__setattr__(
            self,
            "step",
            step if isinstance(step, Var) else LiteralNumberVar(step),
        )
        # Remove the placeholder name so that `_var_name` lookups fall through
        # to __getattr__, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        start, end, step = str(self.start), str(self.end), str(self.step)
        # Array.from truncates a fractional `length`, which would drop the last
        # element whenever (end - start) is not a multiple of step, e.g.
        # range(0, 5, 2). Math.ceil gives the element count Python's range has.
        return (
            f"Array.from({{ length: Math.ceil(({end} - {start}) / {step}) }}, "
            f"(_, i) => {start} + i * {step})"
        )

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Only invoked when normal attribute lookup fails.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Return the parent's value; without `return` the result was discarded
        # and the method fell through to an implicit None.
        return getattr(super(RangeOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.start._get_all_var_data(),
            self.end._get_all_var_data(),
            self.step._get_all_var_data(),
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The cached merged VarData.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArrayContainsOperation(BooleanVar):
    """Base class for immutable boolean vars that are the result of an array contains operation."""

    a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
    b: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))

    def __init__(self, a: ArrayVar, b: Any | Var, _var_data: VarData | None = None):
        """Initialize the array contains operation var.

        Args:
            a: The array.
            b: The element to check for.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ArrayContainsOperation, self).__init__(
            _var_name="",
            _var_type=bool,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # The dataclass is frozen, so fields are written via object.__setattr__.
        object.__setattr__(self, "a", a)
        object.__setattr__(self, "b", b if isinstance(b, Var) else LiteralVar.create(b))
        # Remove the placeholder name so that `_var_name` lookups fall through
        # to __getattr__, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        return f"{str(self.a)}.includes({str(self.b)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Only invoked when normal attribute lookup fails.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Return the parent's value; without `return` the result was discarded
        # and the method fell through to an implicit None.
        return getattr(super(ArrayContainsOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The cached merged VarData.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToStringOperation(StringVar):
    """Base class for immutable string vars that are the result of a to string operation."""

    original_var: Var = dataclasses.field(
        default_factory=lambda: LiteralStringVar.create("")
    )

    def __init__(self, original_var: Var, _var_data: VarData | None = None):
        """Initialize the to string operation var.

        Args:
            original_var: The original var.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ToStringOperation, self).__init__(
            _var_name="",
            _var_type=str,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # The dataclass is frozen, so fields are written via object.__setattr__.
        object.__setattr__(
            self,
            "original_var",
            original_var,
        )
        # Remove the placeholder name so that `_var_name` lookups fall through
        # to __getattr__, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        # The expression is reused unchanged; only the var's declared type differs.
        return str(self.original_var)

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Only invoked when normal attribute lookup fails.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Return the parent's value; without `return` the result was discarded
        # and the method fell through to an implicit None.
        return getattr(super(ToStringOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.original_var._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The cached merged VarData.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToArrayOperation(ArrayVar):
    """Base class for immutable array vars that are the result of a to array operation."""

    original_var: Var = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))

    def __init__(
        self,
        original_var: Var,
        _var_type: type[list] | type[set] | type[tuple] = list,
        _var_data: VarData | None = None,
    ):
        """Initialize the to array operation var.

        Args:
            original_var: The original var.
            _var_type: The type of the array.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ToArrayOperation, self).__init__(
            _var_name="",
            _var_type=_var_type,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # The dataclass is frozen, so fields are written via object.__setattr__.
        object.__setattr__(
            self,
            "original_var",
            original_var,
        )
        # Remove the placeholder name so that `_var_name` lookups fall through
        # to __getattr__, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        # The expression is reused unchanged; only the var's declared type differs.
        return str(self.original_var)

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Only invoked when normal attribute lookup fails.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Return the parent's value; without `return` the result was discarded
        # and the method fell through to an implicit None.
        return getattr(super(ToArrayOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.original_var._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The cached merged VarData.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArrayRepeatOperation(ArrayVar):
    """Base class for immutable array vars that are the result of an array repeat operation."""

    a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
    n: NumberVar = dataclasses.field(default_factory=lambda: LiteralNumberVar(0))

    def __init__(
        self, a: ArrayVar, n: NumberVar | int, _var_data: VarData | None = None
    ):
        """Initialize the array repeat operation var.

        Args:
            a: The array.
            n: The number of times to repeat the array.
            _var_data: Additional hooks and imports associated with the Var.
        """
        super(ArrayRepeatOperation, self).__init__(
            _var_name="",
            _var_type=a._var_type,
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # The dataclass is frozen, so fields are written via object.__setattr__.
        object.__setattr__(self, "a", a)
        object.__setattr__(
            self,
            "n",
            n if isinstance(n, Var) else LiteralNumberVar(n),
        )
        # Remove the placeholder name so that `_var_name` lookups fall through
        # to __getattr__, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        # Build an n-slot array and flatMap each slot to the source array,
        # which concatenates n copies of it.
        return f"Array.from({{ length: {str(self.n)} }}).flatMap(() => {str(self.a)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Only invoked when normal attribute lookup fails.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Return the parent's value; without `return` the result was discarded
        # and the method fell through to an implicit None.
        return getattr(super(ArrayRepeatOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.n._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The cached merged VarData.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArrayConcatOperation(ArrayVar):
    """Base class for immutable array vars that are the result of an array concat operation."""

    a: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))
    b: ArrayVar = dataclasses.field(default_factory=lambda: LiteralArrayVar([]))

    def __init__(self, a: ArrayVar, b: ArrayVar, _var_data: VarData | None = None):
        """Initialize the array concat operation var.

        Args:
            a: The first array.
            b: The second array.
            _var_data: Additional hooks and imports associated with the Var.
        """
        # TODO: Figure out how to merge the types of a and b
        super(ArrayConcatOperation, self).__init__(
            _var_name="",
            _var_type=List[ARRAY_VAR_TYPE],
            _var_data=ImmutableVarData.merge(_var_data),
        )
        # The dataclass is frozen, so fields are written via object.__setattr__.
        object.__setattr__(self, "a", a)
        object.__setattr__(self, "b", b)
        # Remove the placeholder name so that `_var_name` lookups fall through
        # to __getattr__, which serves the lazily computed name.
        object.__delattr__(self, "_var_name")

    @cached_property
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The name of the var.
        """
        # Spread both operands into a new JavaScript array literal.
        return f"[...{str(self.a)}, ...{str(self.b)}]"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Only invoked when normal attribute lookup fails.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute value.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Return the parent's value; without `return` the result was discarded
        # and the method fell through to an implicit None.
        return getattr(super(ArrayConcatOperation, self), name)

    @cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return ImmutableVarData.merge(
            self.a._get_all_var_data(), self.b._get_all_var_data(), self._var_data
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The cached merged VarData.
        """
        return self._cached_get_all_var_data