[REF-3321] implement var operation decorator (#3698)

* implement var operation decorator

* use older syntax

* use cast and older syntax

* use something even simpler

* add some tests

* use old union tactics

* that's not how you do things

* implement arithmetic operations while we're at it

* add test

* even more operations

* can't use __bool__

* thanos snap

* forgot ruff

* use default factory

* dang it darglint

* i know i should have done that but

* convert values into literalvars

* make test pass

* use older union tactics

* add test to string var

* pright why do you hate me 🥺
This commit is contained in:
Khaleel Al-Adhami 2024-07-25 09:34:14 -07:00 committed by GitHub
parent 0845d2ee76
commit ede5cd1f2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 2714 additions and 579 deletions

View File

@ -1,18 +1,20 @@
"""Experimental Immutable-Based Var System."""
from .base import ArrayVar as ArrayVar
from .base import BooleanVar as BooleanVar
from .base import ConcatVarOperation as ConcatVarOperation
from .base import FunctionStringVar as FunctionStringVar
from .base import FunctionVar as FunctionVar
from .base import ImmutableVar as ImmutableVar
from .base import LiteralArrayVar as LiteralArrayVar
from .base import LiteralBooleanVar as LiteralBooleanVar
from .base import LiteralNumberVar as LiteralNumberVar
from .base import LiteralObjectVar as LiteralObjectVar
from .base import LiteralStringVar as LiteralStringVar
from .base import LiteralVar as LiteralVar
from .base import NumberVar as NumberVar
from .base import ObjectVar as ObjectVar
from .base import StringVar as StringVar
from .base import VarOperationCall as VarOperationCall
from .base import var_operation as var_operation
from .function import FunctionStringVar as FunctionStringVar
from .function import FunctionVar as FunctionVar
from .function import VarOperationCall as VarOperationCall
from .number import BooleanVar as BooleanVar
from .number import LiteralBooleanVar as LiteralBooleanVar
from .number import LiteralNumberVar as LiteralNumberVar
from .number import NumberVar as NumberVar
from .sequence import ArrayJoinOperation as ArrayJoinOperation
from .sequence import ArrayVar as ArrayVar
from .sequence import ConcatVarOperation as ConcatVarOperation
from .sequence import LiteralArrayVar as LiteralArrayVar
from .sequence import LiteralStringVar as LiteralStringVar
from .sequence import StringVar as StringVar

View File

@ -3,15 +3,22 @@
from __future__ import annotations
import dataclasses
import json
import re
import functools
import sys
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
Optional,
Type,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
from reflex import constants
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.utils import serializers, types
from reflex.utils.exceptions import VarTypeError
from reflex.vars import (
@ -80,11 +87,12 @@ class ImmutableVar(Var):
"""Post-initialize the var."""
# Decode any inline Var markup and apply it to the instance
_var_data, _var_name = _decode_var_immutable(self._var_name)
if _var_data:
if _var_data or _var_name != self._var_name:
self.__init__(
_var_name,
self._var_type,
ImmutableVarData.merge(self._var_data, _var_data),
_var_name=_var_name,
_var_type=self._var_type,
_var_data=ImmutableVarData.merge(self._var_data, _var_data),
)
def __hash__(self) -> int:
@ -255,232 +263,13 @@ class ImmutableVar(Var):
_global_vars[hashed_var] = self
# Encode the _var_data into the formatted output for tracking purposes.
return f"{REFLEX_VAR_OPENING_TAG}{hashed_var}{REFLEX_VAR_CLOSING_TAG}{self._var_name}"
class StringVar(ImmutableVar):
"""Base class for immutable string vars."""
class NumberVar(ImmutableVar):
"""Base class for immutable number vars."""
class BooleanVar(ImmutableVar):
"""Base class for immutable boolean vars."""
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"
class ObjectVar(ImmutableVar):
"""Base class for immutable object vars."""
class ArrayVar(ImmutableVar):
"""Base class for immutable array vars."""
class FunctionVar(ImmutableVar):
"""Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
"""Call the function with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The function call operation.
"""
return ArgsFunctionOperation(
("...args",),
VarOperationCall(self, *args, ImmutableVar.create_safe("...args")),
)
def call(self, *args: Var | Any) -> VarOperationCall:
"""Call the function with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The function call operation.
"""
return VarOperationCall(self, *args)
class FunctionStringVar(FunctionVar):
"""Base class for immutable function vars from a string."""
def __init__(self, func: str, _var_data: VarData | None = None) -> None:
"""Initialize the function var.
Args:
func: The function to call.
_var_data: Additional hooks and imports associated with the Var.
"""
super(FunctionVar, self).__init__(
_var_name=func,
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class VarOperationCall(ImmutableVar):
"""Base class for immutable vars that are the result of a function call."""
_func: Optional[FunctionVar] = dataclasses.field(default=None)
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
def __init__(
self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
):
"""Initialize the function call var.
Args:
func: The function to call.
*args: The arguments to call the function with.
_var_data: Additional hooks and imports associated with the Var.
"""
super(VarOperationCall, self).__init__(
_var_name="",
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_func", func)
object.__setattr__(self, "_args", args)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._func._get_all_var_data() if self._func is not None else None,
*[var._get_all_var_data() for var in self._args],
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
pass
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperation(FunctionVar):
"""Base class for immutable function defined via arguments and return expression."""
_args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
def __init__(
self,
args_names: Tuple[str, ...],
return_expr: Var | Any,
_var_data: VarData | None = None,
) -> None:
"""Initialize the function with arguments var.
Args:
args_names: The names of the arguments.
return_expr: The return expression of the function.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ArgsFunctionOperation, self).__init__(
_var_name=f"",
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_args_names", args_names)
object.__setattr__(self, "_return_expr", return_expr)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._return_expr._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
class LiteralVar(ImmutableVar):
"""Base class for immutable literal vars."""
@ -515,9 +304,22 @@ class LiteralVar(ImmutableVar):
value.dict(), _var_type=type(value), _var_data=_var_data
)
from .number import LiteralBooleanVar, LiteralNumberVar
from .sequence import LiteralArrayVar, LiteralStringVar
if isinstance(value, str):
return LiteralStringVar.create(value, _var_data=_var_data)
type_mapping = {
int: LiteralNumberVar,
float: LiteralNumberVar,
bool: LiteralBooleanVar,
dict: LiteralObjectVar,
list: LiteralArrayVar,
tuple: LiteralArrayVar,
set: LiteralArrayVar,
}
constructor = type_mapping.get(type(value))
if constructor is None:
@ -529,256 +331,6 @@ class LiteralVar(ImmutableVar):
"""Post-initialize the var."""
# Compile regex for finding reflex var tags.
_decode_var_pattern_re = (
rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
)
_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralStringVar(LiteralVar):
"""Base class for immutable literal string vars."""
_var_value: str = dataclasses.field(default="")
def __init__(
self,
_var_value: str,
_var_data: VarData | None = None,
):
"""Initialize the string var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralStringVar, self).__init__(
_var_name=f'"{_var_value}"',
_var_type=str,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_var_value", _var_value)
@classmethod
def create(
cls,
value: str,
_var_data: VarData | None = None,
) -> LiteralStringVar | ConcatVarOperation:
"""Create a var from a string value.
Args:
value: The value to create the var from.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
if REFLEX_VAR_OPENING_TAG in value:
strings_and_vals: list[Var | str] = []
offset = 0
# Initialize some methods for reading json.
var_data_config = VarData().__config__
def json_loads(s):
try:
return var_data_config.json_loads(s)
except json.decoder.JSONDecodeError:
return var_data_config.json_loads(
var_data_config.json_loads(f'"{s}"')
)
# Find all tags.
while m := _decode_var_pattern.search(value):
start, end = m.span()
if start > 0:
strings_and_vals.append(value[:start])
serialized_data = m.group(1)
if serialized_data[1:].isnumeric():
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
strings_and_vals.append(var)
value = value[(end + len(var._var_name)) :]
else:
data = json_loads(serialized_data)
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)
# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [
(realstart, realstart + string_length)
]
strings_and_vals.append(
ImmutableVar.create_safe(
value[end : (end + string_length)], _var_data=var_data
)
)
value = value[(end + string_length) :]
offset += end - start
if value:
strings_and_vals.append(value)
return ConcatVarOperation(*strings_and_vals, _var_data=_var_data)
return LiteralStringVar(
value,
_var_data=_var_data,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ConcatVarOperation(StringVar):
"""Representing a concatenation of literal string vars."""
_var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple)
def __init__(self, *value: Var | str, _var_data: VarData | None = None):
"""Initialize the operation of concatenating literal string vars.
Args:
value: The values to concatenate.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ConcatVarOperation, self).__init__(
_var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
)
object.__setattr__(self, "_var_value", value)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return (
"("
+ "+".join(
[
str(element) if isinstance(element, Var) else f'"{element}"'
for element in self._var_value
]
)
+ ")"
)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[
var._get_all_var_data()
for var in self._var_value
if isinstance(var, Var)
],
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
pass
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralBooleanVar(LiteralVar):
"""Base class for immutable literal boolean vars."""
_var_value: bool = dataclasses.field(default=False)
def __init__(
self,
_var_value: bool,
_var_data: VarData | None = None,
):
"""Initialize the boolean var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralBooleanVar, self).__init__(
_var_name="true" if _var_value else "false",
_var_type=bool,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_var_value", _var_value)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralNumberVar(LiteralVar):
"""Base class for immutable literal number vars."""
_var_value: float | int = dataclasses.field(default=0)
def __init__(
self,
_var_value: float | int,
_var_data: VarData | None = None,
):
"""Initialize the number var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralNumberVar, self).__init__(
_var_name=str(_var_value),
_var_type=type(_var_value),
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_var_value", _var_value)
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -828,7 +380,7 @@ class LiteralObjectVar(LiteralVar):
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
@functools.cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
@ -846,8 +398,8 @@ class LiteralObjectVar(LiteralVar):
+ " }"
)
@cached_property
def _get_all_var_data(self) -> ImmutableVarData | None:
@functools.cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
@ -867,89 +419,59 @@ class LiteralObjectVar(LiteralVar):
self._var_data,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralArrayVar(LiteralVar):
"""Base class for immutable literal array vars."""
_var_value: Union[
List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...]
] = dataclasses.field(default_factory=list)
def __init__(
self,
_var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any],
_var_data: VarData | None = None,
):
"""Initialize the array var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralArrayVar, self).__init__(
_var_name="",
_var_data=ImmutableVarData.merge(_var_data),
_var_type=list,
)
object.__setattr__(self, "_var_value", _var_value)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return (
"["
+ ", ".join(
[str(LiteralVar.create(element)) for element in self._var_value]
)
+ "]"
)
@cached_property
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[
var._get_all_var_data()
for var in self._var_value
if isinstance(var, Var)
],
self._var_data,
)
return self._cached_get_all_var_data
type_mapping = {
int: LiteralNumberVar,
float: LiteralNumberVar,
bool: LiteralBooleanVar,
dict: LiteralObjectVar,
list: LiteralArrayVar,
tuple: LiteralArrayVar,
set: LiteralArrayVar,
}
P = ParamSpec("P")
T = TypeVar("T", bound=ImmutableVar)
def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P, T]]:
"""Decorator for creating a var operation.
Example:
```python
@var_operation(output=NumberVar)
def add(a: NumberVar, b: NumberVar):
return f"({a} + {b})"
```
Args:
output: The output type of the operation.
Returns:
The decorator.
"""
def decorator(func: Callable[P, str], output=output):
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
args_vars = [
LiteralVar.create(arg) if not isinstance(arg, Var) else arg
for arg in args
]
kwargs_vars = {
key: LiteralVar.create(value) if not isinstance(value, Var) else value
for key, value in kwargs.items()
}
return output(
_var_name=func(*args_vars, **kwargs_vars), # type: ignore
_var_data=VarData.merge(
*[arg._get_all_var_data() for arg in args if isinstance(arg, Var)],
*[
arg._get_all_var_data()
for arg in kwargs.values()
if isinstance(arg, Var)
],
),
)
return wrapper
return decorator

View File

@ -0,0 +1,214 @@
"""Immutable function vars."""
from __future__ import annotations
import dataclasses
import sys
from functools import cached_property
from typing import Any, Callable, Optional, Tuple, Union
from reflex.experimental.vars.base import ImmutableVar, LiteralVar
from reflex.vars import ImmutableVarData, Var, VarData
class FunctionVar(ImmutableVar):
"""Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
"""Call the function with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The function call operation.
"""
return ArgsFunctionOperation(
("...args",),
VarOperationCall(self, *args, ImmutableVar.create_safe("...args")),
)
def call(self, *args: Var | Any) -> VarOperationCall:
"""Call the function with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The function call operation.
"""
return VarOperationCall(self, *args)
class FunctionStringVar(FunctionVar):
"""Base class for immutable function vars from a string."""
def __init__(self, func: str, _var_data: VarData | None = None) -> None:
"""Initialize the function var.
Args:
func: The function to call.
_var_data: Additional hooks and imports associated with the Var.
"""
super(FunctionVar, self).__init__(
_var_name=func,
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class VarOperationCall(ImmutableVar):
"""Base class for immutable vars that are the result of a function call."""
_func: Optional[FunctionVar] = dataclasses.field(default=None)
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
def __init__(
self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
):
"""Initialize the function call var.
Args:
func: The function to call.
*args: The arguments to call the function with.
_var_data: Additional hooks and imports associated with the Var.
"""
super(VarOperationCall, self).__init__(
_var_name="",
_var_type=Any,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_func", func)
object.__setattr__(self, "_args", args)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._func._get_all_var_data() if self._func is not None else None,
*[var._get_all_var_data() for var in self._args],
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
pass
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperation(FunctionVar):
"""Base class for immutable function defined via arguments and return expression."""
_args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
def __init__(
self,
args_names: Tuple[str, ...],
return_expr: Var | Any,
_var_data: VarData | None = None,
) -> None:
"""Initialize the function with arguments var.
Args:
args_names: The names of the arguments.
return_expr: The return expression of the function.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ArgsFunctionOperation, self).__init__(
_var_name=f"",
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_args_names", args_names)
object.__setattr__(self, "_return_expr", return_expr)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._return_expr._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -379,7 +379,9 @@ def _decode_var_immutable(value: str) -> tuple[ImmutableVarData | None, str]:
serialized_data = m.group(1)
if serialized_data[1:].isnumeric():
if serialized_data.isnumeric() or (
serialized_data[0] == "-" and serialized_data[1:].isnumeric()
):
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._var_data
@ -473,7 +475,9 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
serialized_data = m.group(1)
if serialized_data[1:].isnumeric():
if serialized_data.isnumeric() or (
serialized_data[0] == "-" and serialized_data[1:].isnumeric()
):
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._var_data

View File

@ -1,4 +1,5 @@
import json
import math
import typing
from typing import Dict, List, Set, Tuple, Union
@ -8,13 +9,17 @@ from pandas import DataFrame
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.experimental.vars.base import (
ArgsFunctionOperation,
ConcatVarOperation,
FunctionStringVar,
ImmutableVar,
LiteralStringVar,
LiteralVar,
var_operation,
)
from reflex.experimental.vars.function import ArgsFunctionOperation, FunctionStringVar
from reflex.experimental.vars.number import (
LiteralBooleanVar,
LiteralNumberVar,
NumberVar,
)
from reflex.experimental.vars.sequence import ConcatVarOperation, LiteralStringVar
from reflex.state import BaseState
from reflex.utils.imports import ImportVar
from reflex.vars import (
@ -913,6 +918,60 @@ def test_function_var():
)
def test_var_operation():
@var_operation(output=NumberVar)
def add(a: Union[NumberVar, int], b: Union[NumberVar, int]) -> str:
return f"({a} + {b})"
assert str(add(1, 2)) == "(1 + 2)"
assert str(add(a=4, b=-9)) == "(4 + -9)"
five = LiteralNumberVar(5)
seven = add(2, five)
assert isinstance(seven, NumberVar)
def test_string_operations():
basic_string = LiteralStringVar.create("Hello, World!")
assert str(basic_string.length()) == '"Hello, World!".length'
assert str(basic_string.lower()) == '"Hello, World!".toLowerCase()'
assert str(basic_string.upper()) == '"Hello, World!".toUpperCase()'
assert str(basic_string.strip()) == '"Hello, World!".trim()'
assert str(basic_string.contains("World")) == '"Hello, World!".includes("World")'
assert (
str(basic_string.split(" ").join(",")) == '"Hello, World!".split(" ").join(",")'
)
def test_all_number_operations():
starting_number = LiteralNumberVar(-5.4)
complicated_number = (((-(starting_number + 1)) * 2 / 3) // 2 % 3) ** 2
assert (
str(complicated_number)
== "((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)"
)
even_more_complicated_number = ~(
abs(math.floor(complicated_number)) | 2 & 3 & round(complicated_number)
)
assert (
str(even_more_complicated_number)
== "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) != 0) || (true && (Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)) != 0))))"
)
assert str(LiteralNumberVar(5) > False) == "(5 > 0)"
assert str(LiteralBooleanVar(False) < 5) == "((false ? 1 : 0) < 5)"
assert (
str(LiteralBooleanVar(False) < LiteralBooleanVar(True))
== "((false ? 1 : 0) < (true ? 1 : 0))"
)
def test_retrival():
var_without_data = ImmutableVar.create("test")
assert var_without_data is not None