[REF-3321] implement var operation decorator (#3698)
* implement var operation decorator
* use older syntax
* use cast and older syntax
* use something even simpler
* add some tests
* use old union tactics
* that's not how you do things
* implement arithmetic operations while we're at it
* add test
* even more operations
* can't use __bool__
* thanos snap
* forgot ruff
* use default factory
* dang it darglint
* i know i should have done that but
* convert values into literalvars
* make test pass
* use older union tactics
* add test to string var
* pright why do you hate me 🥺
This commit is contained in:
parent
0845d2ee76
commit
ede5cd1f2c
@ -1,18 +1,20 @@
|
||||
"""Experimental Immutable-Based Var System."""
|
||||
|
||||
from .base import ArrayVar as ArrayVar
|
||||
from .base import BooleanVar as BooleanVar
|
||||
from .base import ConcatVarOperation as ConcatVarOperation
|
||||
from .base import FunctionStringVar as FunctionStringVar
|
||||
from .base import FunctionVar as FunctionVar
|
||||
from .base import ImmutableVar as ImmutableVar
|
||||
from .base import LiteralArrayVar as LiteralArrayVar
|
||||
from .base import LiteralBooleanVar as LiteralBooleanVar
|
||||
from .base import LiteralNumberVar as LiteralNumberVar
|
||||
from .base import LiteralObjectVar as LiteralObjectVar
|
||||
from .base import LiteralStringVar as LiteralStringVar
|
||||
from .base import LiteralVar as LiteralVar
|
||||
from .base import NumberVar as NumberVar
|
||||
from .base import ObjectVar as ObjectVar
|
||||
from .base import StringVar as StringVar
|
||||
from .base import VarOperationCall as VarOperationCall
|
||||
from .base import var_operation as var_operation
|
||||
from .function import FunctionStringVar as FunctionStringVar
|
||||
from .function import FunctionVar as FunctionVar
|
||||
from .function import VarOperationCall as VarOperationCall
|
||||
from .number import BooleanVar as BooleanVar
|
||||
from .number import LiteralBooleanVar as LiteralBooleanVar
|
||||
from .number import LiteralNumberVar as LiteralNumberVar
|
||||
from .number import NumberVar as NumberVar
|
||||
from .sequence import ArrayJoinOperation as ArrayJoinOperation
|
||||
from .sequence import ArrayVar as ArrayVar
|
||||
from .sequence import ConcatVarOperation as ConcatVarOperation
|
||||
from .sequence import LiteralArrayVar as LiteralArrayVar
|
||||
from .sequence import LiteralStringVar as LiteralStringVar
|
||||
from .sequence import StringVar as StringVar
|
||||
|
@ -3,15 +3,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import re
|
||||
import functools
|
||||
import sys
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from reflex import constants
|
||||
from reflex.base import Base
|
||||
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
|
||||
from reflex.utils import serializers, types
|
||||
from reflex.utils.exceptions import VarTypeError
|
||||
from reflex.vars import (
|
||||
@ -80,11 +87,12 @@ class ImmutableVar(Var):
|
||||
"""Post-initialize the var."""
|
||||
# Decode any inline Var markup and apply it to the instance
|
||||
_var_data, _var_name = _decode_var_immutable(self._var_name)
|
||||
if _var_data:
|
||||
|
||||
if _var_data or _var_name != self._var_name:
|
||||
self.__init__(
|
||||
_var_name,
|
||||
self._var_type,
|
||||
ImmutableVarData.merge(self._var_data, _var_data),
|
||||
_var_name=_var_name,
|
||||
_var_type=self._var_type,
|
||||
_var_data=ImmutableVarData.merge(self._var_data, _var_data),
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
@ -255,232 +263,13 @@ class ImmutableVar(Var):
|
||||
_global_vars[hashed_var] = self
|
||||
|
||||
# Encode the _var_data into the formatted output for tracking purposes.
|
||||
return f"{REFLEX_VAR_OPENING_TAG}{hashed_var}{REFLEX_VAR_CLOSING_TAG}{self._var_name}"
|
||||
|
||||
|
||||
class StringVar(ImmutableVar):
|
||||
"""Base class for immutable string vars."""
|
||||
|
||||
|
||||
class NumberVar(ImmutableVar):
|
||||
"""Base class for immutable number vars."""
|
||||
|
||||
|
||||
class BooleanVar(ImmutableVar):
|
||||
"""Base class for immutable boolean vars."""
|
||||
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"
|
||||
|
||||
|
||||
class ObjectVar(ImmutableVar):
|
||||
"""Base class for immutable object vars."""
|
||||
|
||||
|
||||
class ArrayVar(ImmutableVar):
|
||||
"""Base class for immutable array vars."""
|
||||
|
||||
|
||||
class FunctionVar(ImmutableVar):
|
||||
"""Base class for immutable function vars."""
|
||||
|
||||
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
|
||||
"""Call the function with the given arguments.
|
||||
|
||||
Args:
|
||||
*args: The arguments to call the function with.
|
||||
|
||||
Returns:
|
||||
The function call operation.
|
||||
"""
|
||||
return ArgsFunctionOperation(
|
||||
("...args",),
|
||||
VarOperationCall(self, *args, ImmutableVar.create_safe("...args")),
|
||||
)
|
||||
|
||||
def call(self, *args: Var | Any) -> VarOperationCall:
|
||||
"""Call the function with the given arguments.
|
||||
|
||||
Args:
|
||||
*args: The arguments to call the function with.
|
||||
|
||||
Returns:
|
||||
The function call operation.
|
||||
"""
|
||||
return VarOperationCall(self, *args)
|
||||
|
||||
|
||||
class FunctionStringVar(FunctionVar):
|
||||
"""Base class for immutable function vars from a string."""
|
||||
|
||||
def __init__(self, func: str, _var_data: VarData | None = None) -> None:
|
||||
"""Initialize the function var.
|
||||
|
||||
Args:
|
||||
func: The function to call.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(FunctionVar, self).__init__(
|
||||
_var_name=func,
|
||||
_var_type=Callable,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class VarOperationCall(ImmutableVar):
|
||||
"""Base class for immutable vars that are the result of a function call."""
|
||||
|
||||
_func: Optional[FunctionVar] = dataclasses.field(default=None)
|
||||
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
|
||||
|
||||
def __init__(
|
||||
self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
|
||||
):
|
||||
"""Initialize the function call var.
|
||||
|
||||
Args:
|
||||
func: The function to call.
|
||||
*args: The arguments to call the function with.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(VarOperationCall, self).__init__(
|
||||
_var_name="",
|
||||
_var_type=Callable,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "_func", func)
|
||||
object.__setattr__(self, "_args", args)
|
||||
object.__delattr__(self, "_var_name")
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The attribute of the var.
|
||||
"""
|
||||
if name == "_var_name":
|
||||
return self._cached_var_name
|
||||
return super(type(self), self).__getattr__(name)
|
||||
|
||||
@cached_property
|
||||
def _cached_var_name(self) -> str:
|
||||
"""The name of the var.
|
||||
|
||||
Returns:
|
||||
The name of the var.
|
||||
"""
|
||||
return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
|
||||
|
||||
@cached_property
|
||||
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Get all VarData associated with the Var.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return ImmutableVarData.merge(
|
||||
self._func._get_all_var_data() if self._func is not None else None,
|
||||
*[var._get_all_var_data() for var in self._args],
|
||||
self._var_data,
|
||||
)
|
||||
|
||||
def _get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Wrapper method for cached property.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return self._cached_get_all_var_data
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialize the var."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class ArgsFunctionOperation(FunctionVar):
|
||||
"""Base class for immutable function defined via arguments and return expression."""
|
||||
|
||||
_args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
||||
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args_names: Tuple[str, ...],
|
||||
return_expr: Var | Any,
|
||||
_var_data: VarData | None = None,
|
||||
) -> None:
|
||||
"""Initialize the function with arguments var.
|
||||
|
||||
Args:
|
||||
args_names: The names of the arguments.
|
||||
return_expr: The return expression of the function.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(ArgsFunctionOperation, self).__init__(
|
||||
_var_name=f"",
|
||||
_var_type=Callable,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "_args_names", args_names)
|
||||
object.__setattr__(self, "_return_expr", return_expr)
|
||||
object.__delattr__(self, "_var_name")
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The attribute of the var.
|
||||
"""
|
||||
if name == "_var_name":
|
||||
return self._cached_var_name
|
||||
return super(type(self), self).__getattr__(name)
|
||||
|
||||
@cached_property
|
||||
def _cached_var_name(self) -> str:
|
||||
"""The name of the var.
|
||||
|
||||
Returns:
|
||||
The name of the var.
|
||||
"""
|
||||
return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
|
||||
|
||||
@cached_property
|
||||
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Get all VarData associated with the Var.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return ImmutableVarData.merge(
|
||||
self._return_expr._get_all_var_data(),
|
||||
self._var_data,
|
||||
)
|
||||
|
||||
def _get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Wrapper method for cached property.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return self._cached_get_all_var_data
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialize the var."""
|
||||
|
||||
|
||||
class LiteralVar(ImmutableVar):
|
||||
"""Base class for immutable literal vars."""
|
||||
|
||||
@ -515,9 +304,22 @@ class LiteralVar(ImmutableVar):
|
||||
value.dict(), _var_type=type(value), _var_data=_var_data
|
||||
)
|
||||
|
||||
from .number import LiteralBooleanVar, LiteralNumberVar
|
||||
from .sequence import LiteralArrayVar, LiteralStringVar
|
||||
|
||||
if isinstance(value, str):
|
||||
return LiteralStringVar.create(value, _var_data=_var_data)
|
||||
|
||||
type_mapping = {
|
||||
int: LiteralNumberVar,
|
||||
float: LiteralNumberVar,
|
||||
bool: LiteralBooleanVar,
|
||||
dict: LiteralObjectVar,
|
||||
list: LiteralArrayVar,
|
||||
tuple: LiteralArrayVar,
|
||||
set: LiteralArrayVar,
|
||||
}
|
||||
|
||||
constructor = type_mapping.get(type(value))
|
||||
|
||||
if constructor is None:
|
||||
@ -529,256 +331,6 @@ class LiteralVar(ImmutableVar):
|
||||
"""Post-initialize the var."""
|
||||
|
||||
|
||||
# Compile regex for finding reflex var tags.
|
||||
_decode_var_pattern_re = (
|
||||
rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
|
||||
)
|
||||
_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class LiteralStringVar(LiteralVar):
|
||||
"""Base class for immutable literal string vars."""
|
||||
|
||||
_var_value: str = dataclasses.field(default="")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_var_value: str,
|
||||
_var_data: VarData | None = None,
|
||||
):
|
||||
"""Initialize the string var.
|
||||
|
||||
Args:
|
||||
_var_value: The value of the var.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(LiteralStringVar, self).__init__(
|
||||
_var_name=f'"{_var_value}"',
|
||||
_var_type=str,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "_var_value", _var_value)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
value: str,
|
||||
_var_data: VarData | None = None,
|
||||
) -> LiteralStringVar | ConcatVarOperation:
|
||||
"""Create a var from a string value.
|
||||
|
||||
Args:
|
||||
value: The value to create the var from.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
|
||||
Returns:
|
||||
The var.
|
||||
"""
|
||||
if REFLEX_VAR_OPENING_TAG in value:
|
||||
strings_and_vals: list[Var | str] = []
|
||||
offset = 0
|
||||
|
||||
# Initialize some methods for reading json.
|
||||
var_data_config = VarData().__config__
|
||||
|
||||
def json_loads(s):
|
||||
try:
|
||||
return var_data_config.json_loads(s)
|
||||
except json.decoder.JSONDecodeError:
|
||||
return var_data_config.json_loads(
|
||||
var_data_config.json_loads(f'"{s}"')
|
||||
)
|
||||
|
||||
# Find all tags.
|
||||
while m := _decode_var_pattern.search(value):
|
||||
start, end = m.span()
|
||||
if start > 0:
|
||||
strings_and_vals.append(value[:start])
|
||||
|
||||
serialized_data = m.group(1)
|
||||
|
||||
if serialized_data[1:].isnumeric():
|
||||
# This is a global immutable var.
|
||||
var = _global_vars[int(serialized_data)]
|
||||
strings_and_vals.append(var)
|
||||
value = value[(end + len(var._var_name)) :]
|
||||
else:
|
||||
data = json_loads(serialized_data)
|
||||
string_length = data.pop("string_length", None)
|
||||
var_data = VarData.parse_obj(data)
|
||||
|
||||
# Use string length to compute positions of interpolations.
|
||||
if string_length is not None:
|
||||
realstart = start + offset
|
||||
var_data.interpolations = [
|
||||
(realstart, realstart + string_length)
|
||||
]
|
||||
strings_and_vals.append(
|
||||
ImmutableVar.create_safe(
|
||||
value[end : (end + string_length)], _var_data=var_data
|
||||
)
|
||||
)
|
||||
value = value[(end + string_length) :]
|
||||
|
||||
offset += end - start
|
||||
|
||||
if value:
|
||||
strings_and_vals.append(value)
|
||||
|
||||
return ConcatVarOperation(*strings_and_vals, _var_data=_var_data)
|
||||
|
||||
return LiteralStringVar(
|
||||
value,
|
||||
_var_data=_var_data,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class ConcatVarOperation(StringVar):
|
||||
"""Representing a concatenation of literal string vars."""
|
||||
|
||||
_var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple)
|
||||
|
||||
def __init__(self, *value: Var | str, _var_data: VarData | None = None):
|
||||
"""Initialize the operation of concatenating literal string vars.
|
||||
|
||||
Args:
|
||||
value: The values to concatenate.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(ConcatVarOperation, self).__init__(
|
||||
_var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
|
||||
)
|
||||
object.__setattr__(self, "_var_value", value)
|
||||
object.__delattr__(self, "_var_name")
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The attribute of the var.
|
||||
"""
|
||||
if name == "_var_name":
|
||||
return self._cached_var_name
|
||||
return super(type(self), self).__getattr__(name)
|
||||
|
||||
@cached_property
|
||||
def _cached_var_name(self) -> str:
|
||||
"""The name of the var.
|
||||
|
||||
Returns:
|
||||
The name of the var.
|
||||
"""
|
||||
return (
|
||||
"("
|
||||
+ "+".join(
|
||||
[
|
||||
str(element) if isinstance(element, Var) else f'"{element}"'
|
||||
for element in self._var_value
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Get all VarData associated with the Var.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return ImmutableVarData.merge(
|
||||
*[
|
||||
var._get_all_var_data()
|
||||
for var in self._var_value
|
||||
if isinstance(var, Var)
|
||||
],
|
||||
self._var_data,
|
||||
)
|
||||
|
||||
def _get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Wrapper method for cached property.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return self._cached_get_all_var_data
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialize the var."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class LiteralBooleanVar(LiteralVar):
|
||||
"""Base class for immutable literal boolean vars."""
|
||||
|
||||
_var_value: bool = dataclasses.field(default=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_var_value: bool,
|
||||
_var_data: VarData | None = None,
|
||||
):
|
||||
"""Initialize the boolean var.
|
||||
|
||||
Args:
|
||||
_var_value: The value of the var.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(LiteralBooleanVar, self).__init__(
|
||||
_var_name="true" if _var_value else "false",
|
||||
_var_type=bool,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "_var_value", _var_value)
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class LiteralNumberVar(LiteralVar):
|
||||
"""Base class for immutable literal number vars."""
|
||||
|
||||
_var_value: float | int = dataclasses.field(default=0)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_var_value: float | int,
|
||||
_var_data: VarData | None = None,
|
||||
):
|
||||
"""Initialize the number var.
|
||||
|
||||
Args:
|
||||
_var_value: The value of the var.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(LiteralNumberVar, self).__init__(
|
||||
_var_name=str(_var_value),
|
||||
_var_type=type(_var_value),
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "_var_value", _var_value)
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
@ -828,7 +380,7 @@ class LiteralObjectVar(LiteralVar):
|
||||
return self._cached_var_name
|
||||
return super(type(self), self).__getattr__(name)
|
||||
|
||||
@cached_property
|
||||
@functools.cached_property
|
||||
def _cached_var_name(self) -> str:
|
||||
"""The name of the var.
|
||||
|
||||
@ -846,8 +398,8 @@ class LiteralObjectVar(LiteralVar):
|
||||
+ " }"
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _get_all_var_data(self) -> ImmutableVarData | None:
|
||||
@functools.cached_property
|
||||
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Get all VarData associated with the Var.
|
||||
|
||||
Returns:
|
||||
@ -867,89 +419,59 @@ class LiteralObjectVar(LiteralVar):
|
||||
self._var_data,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class LiteralArrayVar(LiteralVar):
|
||||
"""Base class for immutable literal array vars."""
|
||||
|
||||
_var_value: Union[
|
||||
List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...]
|
||||
] = dataclasses.field(default_factory=list)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any],
|
||||
_var_data: VarData | None = None,
|
||||
):
|
||||
"""Initialize the array var.
|
||||
|
||||
Args:
|
||||
_var_value: The value of the var.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(LiteralArrayVar, self).__init__(
|
||||
_var_name="",
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
_var_type=list,
|
||||
)
|
||||
object.__setattr__(self, "_var_value", _var_value)
|
||||
object.__delattr__(self, "_var_name")
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The attribute of the var.
|
||||
"""
|
||||
if name == "_var_name":
|
||||
return self._cached_var_name
|
||||
return super(type(self), self).__getattr__(name)
|
||||
|
||||
@cached_property
|
||||
def _cached_var_name(self) -> str:
|
||||
"""The name of the var.
|
||||
|
||||
Returns:
|
||||
The name of the var.
|
||||
"""
|
||||
return (
|
||||
"["
|
||||
+ ", ".join(
|
||||
[str(LiteralVar.create(element)) for element in self._var_value]
|
||||
)
|
||||
+ "]"
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Get all VarData associated with the Var.
|
||||
"""Wrapper method for cached property.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return ImmutableVarData.merge(
|
||||
*[
|
||||
var._get_all_var_data()
|
||||
for var in self._var_value
|
||||
if isinstance(var, Var)
|
||||
],
|
||||
self._var_data,
|
||||
)
|
||||
return self._cached_get_all_var_data
|
||||
|
||||
|
||||
type_mapping = {
|
||||
int: LiteralNumberVar,
|
||||
float: LiteralNumberVar,
|
||||
bool: LiteralBooleanVar,
|
||||
dict: LiteralObjectVar,
|
||||
list: LiteralArrayVar,
|
||||
tuple: LiteralArrayVar,
|
||||
set: LiteralArrayVar,
|
||||
}
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T", bound=ImmutableVar)
|
||||
|
||||
|
||||
def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P, T]]:
|
||||
"""Decorator for creating a var operation.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@var_operation(output=NumberVar)
|
||||
def add(a: NumberVar, b: NumberVar):
|
||||
return f"({a} + {b})"
|
||||
```
|
||||
|
||||
Args:
|
||||
output: The output type of the operation.
|
||||
|
||||
Returns:
|
||||
The decorator.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, str], output=output):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
args_vars = [
|
||||
LiteralVar.create(arg) if not isinstance(arg, Var) else arg
|
||||
for arg in args
|
||||
]
|
||||
kwargs_vars = {
|
||||
key: LiteralVar.create(value) if not isinstance(value, Var) else value
|
||||
for key, value in kwargs.items()
|
||||
}
|
||||
return output(
|
||||
_var_name=func(*args_vars, **kwargs_vars), # type: ignore
|
||||
_var_data=VarData.merge(
|
||||
*[arg._get_all_var_data() for arg in args if isinstance(arg, Var)],
|
||||
*[
|
||||
arg._get_all_var_data()
|
||||
for arg in kwargs.values()
|
||||
if isinstance(arg, Var)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
214
reflex/experimental/vars/function.py
Normal file
214
reflex/experimental/vars/function.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""Immutable function vars."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
from reflex.experimental.vars.base import ImmutableVar, LiteralVar
|
||||
from reflex.vars import ImmutableVarData, Var, VarData
|
||||
|
||||
|
||||
class FunctionVar(ImmutableVar):
|
||||
"""Base class for immutable function vars."""
|
||||
|
||||
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
|
||||
"""Call the function with the given arguments.
|
||||
|
||||
Args:
|
||||
*args: The arguments to call the function with.
|
||||
|
||||
Returns:
|
||||
The function call operation.
|
||||
"""
|
||||
return ArgsFunctionOperation(
|
||||
("...args",),
|
||||
VarOperationCall(self, *args, ImmutableVar.create_safe("...args")),
|
||||
)
|
||||
|
||||
def call(self, *args: Var | Any) -> VarOperationCall:
|
||||
"""Call the function with the given arguments.
|
||||
|
||||
Args:
|
||||
*args: The arguments to call the function with.
|
||||
|
||||
Returns:
|
||||
The function call operation.
|
||||
"""
|
||||
return VarOperationCall(self, *args)
|
||||
|
||||
|
||||
class FunctionStringVar(FunctionVar):
|
||||
"""Base class for immutable function vars from a string."""
|
||||
|
||||
def __init__(self, func: str, _var_data: VarData | None = None) -> None:
|
||||
"""Initialize the function var.
|
||||
|
||||
Args:
|
||||
func: The function to call.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(FunctionVar, self).__init__(
|
||||
_var_name=func,
|
||||
_var_type=Callable,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class VarOperationCall(ImmutableVar):
|
||||
"""Base class for immutable vars that are the result of a function call."""
|
||||
|
||||
_func: Optional[FunctionVar] = dataclasses.field(default=None)
|
||||
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
|
||||
|
||||
def __init__(
|
||||
self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
|
||||
):
|
||||
"""Initialize the function call var.
|
||||
|
||||
Args:
|
||||
func: The function to call.
|
||||
*args: The arguments to call the function with.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(VarOperationCall, self).__init__(
|
||||
_var_name="",
|
||||
_var_type=Any,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "_func", func)
|
||||
object.__setattr__(self, "_args", args)
|
||||
object.__delattr__(self, "_var_name")
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The attribute of the var.
|
||||
"""
|
||||
if name == "_var_name":
|
||||
return self._cached_var_name
|
||||
return super(type(self), self).__getattr__(name)
|
||||
|
||||
@cached_property
|
||||
def _cached_var_name(self) -> str:
|
||||
"""The name of the var.
|
||||
|
||||
Returns:
|
||||
The name of the var.
|
||||
"""
|
||||
return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
|
||||
|
||||
@cached_property
|
||||
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Get all VarData associated with the Var.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return ImmutableVarData.merge(
|
||||
self._func._get_all_var_data() if self._func is not None else None,
|
||||
*[var._get_all_var_data() for var in self._args],
|
||||
self._var_data,
|
||||
)
|
||||
|
||||
def _get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Wrapper method for cached property.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return self._cached_get_all_var_data
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialize the var."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
)
|
||||
class ArgsFunctionOperation(FunctionVar):
|
||||
"""Base class for immutable function defined via arguments and return expression."""
|
||||
|
||||
_args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
|
||||
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args_names: Tuple[str, ...],
|
||||
return_expr: Var | Any,
|
||||
_var_data: VarData | None = None,
|
||||
) -> None:
|
||||
"""Initialize the function with arguments var.
|
||||
|
||||
Args:
|
||||
args_names: The names of the arguments.
|
||||
return_expr: The return expression of the function.
|
||||
_var_data: Additional hooks and imports associated with the Var.
|
||||
"""
|
||||
super(ArgsFunctionOperation, self).__init__(
|
||||
_var_name=f"",
|
||||
_var_type=Callable,
|
||||
_var_data=ImmutableVarData.merge(_var_data),
|
||||
)
|
||||
object.__setattr__(self, "_args_names", args_names)
|
||||
object.__setattr__(self, "_return_expr", return_expr)
|
||||
object.__delattr__(self, "_var_name")
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The attribute of the var.
|
||||
"""
|
||||
if name == "_var_name":
|
||||
return self._cached_var_name
|
||||
return super(type(self), self).__getattr__(name)
|
||||
|
||||
@cached_property
|
||||
def _cached_var_name(self) -> str:
|
||||
"""The name of the var.
|
||||
|
||||
Returns:
|
||||
The name of the var.
|
||||
"""
|
||||
return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
|
||||
|
||||
@cached_property
|
||||
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Get all VarData associated with the Var.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return ImmutableVarData.merge(
|
||||
self._return_expr._get_all_var_data(),
|
||||
self._var_data,
|
||||
)
|
||||
|
||||
def _get_all_var_data(self) -> ImmutableVarData | None:
|
||||
"""Wrapper method for cached property.
|
||||
|
||||
Returns:
|
||||
The VarData of the components and all of its children.
|
||||
"""
|
||||
return self._cached_get_all_var_data
|
||||
|
||||
def __post_init__(self):
|
||||
"""Post-initialize the var."""
|
1295
reflex/experimental/vars/number.py
Normal file
1295
reflex/experimental/vars/number.py
Normal file
File diff suppressed because it is too large
Load Diff
1039
reflex/experimental/vars/sequence.py
Normal file
1039
reflex/experimental/vars/sequence.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -379,7 +379,9 @@ def _decode_var_immutable(value: str) -> tuple[ImmutableVarData | None, str]:
|
||||
|
||||
serialized_data = m.group(1)
|
||||
|
||||
if serialized_data[1:].isnumeric():
|
||||
if serialized_data.isnumeric() or (
|
||||
serialized_data[0] == "-" and serialized_data[1:].isnumeric()
|
||||
):
|
||||
# This is a global immutable var.
|
||||
var = _global_vars[int(serialized_data)]
|
||||
var_data = var._var_data
|
||||
@ -473,7 +475,9 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
|
||||
|
||||
serialized_data = m.group(1)
|
||||
|
||||
if serialized_data[1:].isnumeric():
|
||||
if serialized_data.isnumeric() or (
|
||||
serialized_data[0] == "-" and serialized_data[1:].isnumeric()
|
||||
):
|
||||
# This is a global immutable var.
|
||||
var = _global_vars[int(serialized_data)]
|
||||
var_data = var._var_data
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import math
|
||||
import typing
|
||||
from typing import Dict, List, Set, Tuple, Union
|
||||
|
||||
@ -8,13 +9,17 @@ from pandas import DataFrame
|
||||
from reflex.base import Base
|
||||
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
|
||||
from reflex.experimental.vars.base import (
|
||||
ArgsFunctionOperation,
|
||||
ConcatVarOperation,
|
||||
FunctionStringVar,
|
||||
ImmutableVar,
|
||||
LiteralStringVar,
|
||||
LiteralVar,
|
||||
var_operation,
|
||||
)
|
||||
from reflex.experimental.vars.function import ArgsFunctionOperation, FunctionStringVar
|
||||
from reflex.experimental.vars.number import (
|
||||
LiteralBooleanVar,
|
||||
LiteralNumberVar,
|
||||
NumberVar,
|
||||
)
|
||||
from reflex.experimental.vars.sequence import ConcatVarOperation, LiteralStringVar
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils.imports import ImportVar
|
||||
from reflex.vars import (
|
||||
@ -913,6 +918,60 @@ def test_function_var():
|
||||
)
|
||||
|
||||
|
||||
def test_var_operation():
|
||||
@var_operation(output=NumberVar)
|
||||
def add(a: Union[NumberVar, int], b: Union[NumberVar, int]) -> str:
|
||||
return f"({a} + {b})"
|
||||
|
||||
assert str(add(1, 2)) == "(1 + 2)"
|
||||
assert str(add(a=4, b=-9)) == "(4 + -9)"
|
||||
|
||||
five = LiteralNumberVar(5)
|
||||
seven = add(2, five)
|
||||
|
||||
assert isinstance(seven, NumberVar)
|
||||
|
||||
|
||||
def test_string_operations():
|
||||
basic_string = LiteralStringVar.create("Hello, World!")
|
||||
|
||||
assert str(basic_string.length()) == '"Hello, World!".length'
|
||||
assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()'
|
||||
assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()'
|
||||
assert str(basic_string.strip()) == '"Hello, World!".trim()'
|
||||
assert str(basic_string.contains("World")) == '"Hello, World!".includes("World")'
|
||||
assert (
|
||||
str(basic_string.split(" ").join(",")) == '"Hello, World!".split(" ").join(",")'
|
||||
)
|
||||
|
||||
|
||||
def test_all_number_operations():
|
||||
starting_number = LiteralNumberVar(-5.4)
|
||||
|
||||
complicated_number = (((-(starting_number + 1)) * 2 / 3) // 2 % 3) ** 2
|
||||
|
||||
assert (
|
||||
str(complicated_number)
|
||||
== "((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)"
|
||||
)
|
||||
|
||||
even_more_complicated_number = ~(
|
||||
abs(math.floor(complicated_number)) | 2 & 3 & round(complicated_number)
|
||||
)
|
||||
|
||||
assert (
|
||||
str(even_more_complicated_number)
|
||||
== "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) != 0) || (true && (Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)) != 0))))"
|
||||
)
|
||||
|
||||
assert str(LiteralNumberVar(5) > False) == "(5 > 0)"
|
||||
assert str(LiteralBooleanVar(False) < 5) == "((false ? 1 : 0) < 5)"
|
||||
assert (
|
||||
str(LiteralBooleanVar(False) < LiteralBooleanVar(True))
|
||||
== "((false ? 1 : 0) < (true ? 1 : 0))"
|
||||
)
|
||||
|
||||
|
||||
def test_retrival():
|
||||
var_without_data = ImmutableVar.create("test")
|
||||
assert var_without_data is not None
|
||||
|
Loading…
Reference in New Issue
Block a user