[REF-3227] implement more literal vars (#3687)

* implement more literal vars

* fix super issue

* pyright has a bug i think

* oh we changed that

* fix docs

* literalize vars recursively

* do what masen told me :D

* use dynamic keys

* forgot .create

* adjust _var_value

* dang it darglint

* add test for serializing literal vars into js exprs

* fix silly mistake

* add  handling for var and none

* use create safe

* is none bruh

* implement function vars and do various modification

* fix None issue

* clear a lot of creates that did nothing

* add tests to function vars

* added simple fix smh

* use fconcat to make an even more complicated test
This commit is contained in:
Khaleel Al-Adhami 2024-07-22 12:45:23 -07:00 committed by GitHub
parent 9666244a87
commit ea016314b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 594 additions and 32 deletions

View File

@ -3,10 +3,16 @@
from .base import ArrayVar as ArrayVar
from .base import BooleanVar as BooleanVar
from .base import ConcatVarOperation as ConcatVarOperation
from .base import FunctionStringVar as FunctionStringVar
from .base import FunctionVar as FunctionVar
from .base import ImmutableVar as ImmutableVar
from .base import LiteralArrayVar as LiteralArrayVar
from .base import LiteralBooleanVar as LiteralBooleanVar
from .base import LiteralNumberVar as LiteralNumberVar
from .base import LiteralObjectVar as LiteralObjectVar
from .base import LiteralStringVar as LiteralStringVar
from .base import LiteralVar as LiteralVar
from .base import NumberVar as NumberVar
from .base import ObjectVar as ObjectVar
from .base import StringVar as StringVar
from .base import VarOperationCall as VarOperationCall

View File

@ -7,9 +7,10 @@ import json
import re
import sys
from functools import cached_property
from typing import Any, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from reflex import constants
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.utils import serializers, types
from reflex.utils.exceptions import VarTypeError
@ -95,6 +96,11 @@ class ImmutableVar(Var):
return hash((self._var_name, self._var_type, self._var_data))
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return self._var_data
def _replace(self, merge_var_data=None, **kwargs: Any):
@ -275,10 +281,250 @@ class ArrayVar(ImmutableVar):
class FunctionVar(ImmutableVar):
"""Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
"""Call the function with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The function call operation.
"""
return ArgsFunctionOperation(
("...args",),
VarOperationCall(self, *args, ImmutableVar.create_safe("...args")),
)
def call(self, *args: Var | Any) -> VarOperationCall:
"""Call the function with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The function call operation.
"""
return VarOperationCall(self, *args)
class FunctionStringVar(FunctionVar):
"""Base class for immutable function vars from a string."""
def __init__(self, func: str, _var_data: VarData | None = None) -> None:
"""Initialize the function var.
Args:
func: The function to call.
_var_data: Additional hooks and imports associated with the Var.
"""
super(FunctionVar, self).__init__(
_var_name=func,
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class VarOperationCall(ImmutableVar):
"""Base class for immutable vars that are the result of a function call."""
_func: Optional[FunctionVar] = dataclasses.field(default=None)
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
def __init__(
self, func: FunctionVar, *args: Var | Any, _var_data: VarData | None = None
):
"""Initialize the function call var.
Args:
func: The function to call.
*args: The arguments to call the function with.
_var_data: Additional hooks and imports associated with the Var.
"""
super(VarOperationCall, self).__init__(
_var_name="",
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_func", func)
object.__setattr__(self, "_args", args)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return f"({str(self._func)}({', '.join([str(LiteralVar.create(arg)) for arg in self._args])}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._func._get_all_var_data() if self._func is not None else None,
*[var._get_all_var_data() for var in self._args],
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
pass
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperation(FunctionVar):
"""Base class for immutable function defined via arguments and return expression."""
_args_names: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
def __init__(
self,
args_names: Tuple[str, ...],
return_expr: Var | Any,
_var_data: VarData | None = None,
) -> None:
"""Initialize the function with arguments var.
Args:
args_names: The names of the arguments.
return_expr: The return expression of the function.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ArgsFunctionOperation, self).__init__(
_var_name=f"",
_var_type=Callable,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_args_names", args_names)
object.__setattr__(self, "_return_expr", return_expr)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return f"(({', '.join(self._args_names)}) => ({str(LiteralVar.create(self._return_expr))}))"
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
self._return_expr._get_all_var_data(),
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
class LiteralVar(ImmutableVar):
"""Base class for immutable literal vars."""
@classmethod
def create(
cls,
value: Any,
_var_data: VarData | None = None,
) -> Var:
"""Create a var from a value.
Args:
value: The value to create the var from.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
Raises:
TypeError: If the value is not a supported type for LiteralVar.
"""
if isinstance(value, Var):
if _var_data is None:
return value
return value._replace(merge_var_data=_var_data)
if value is None:
return ImmutableVar.create_safe("null", _var_data=_var_data)
if isinstance(value, Base):
return LiteralObjectVar(
value.dict(), _var_type=type(value), _var_data=_var_data
)
if isinstance(value, str):
return LiteralStringVar.create(value, _var_data=_var_data)
constructor = type_mapping.get(type(value))
if constructor is None:
raise TypeError(f"Unsupported type {type(value)} for LiteralVar.")
return constructor(value, _var_data=_var_data)
def __post_init__(self):
"""Post-initialize the var."""
@ -298,7 +544,25 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
class LiteralStringVar(LiteralVar):
"""Base class for immutable literal string vars."""
_var_value: Optional[str] = dataclasses.field(default=None)
_var_value: str = dataclasses.field(default="")
def __init__(
self,
_var_value: str,
_var_data: VarData | None = None,
):
"""Initialize the string var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralStringVar, self).__init__(
_var_name=f'"{_var_value}"',
_var_type=str,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_var_value", _var_value)
@classmethod
def create(
@ -316,7 +580,7 @@ class LiteralStringVar(LiteralVar):
The var.
"""
if REFLEX_VAR_OPENING_TAG in value:
strings_and_vals: list[Var] = []
strings_and_vals: list[Var | str] = []
offset = 0
# Initialize some methods for reading json.
@ -334,7 +598,7 @@ class LiteralStringVar(LiteralVar):
while m := _decode_var_pattern.search(value):
start, end = m.span()
if start > 0:
strings_and_vals.append(LiteralStringVar.create(value[:start]))
strings_and_vals.append(value[:start])
serialized_data = m.group(1)
@ -364,17 +628,13 @@ class LiteralStringVar(LiteralVar):
offset += end - start
if value:
strings_and_vals.append(LiteralStringVar.create(value))
strings_and_vals.append(value)
return ConcatVarOperation.create(
tuple(strings_and_vals), _var_data=_var_data
)
return ConcatVarOperation(*strings_and_vals, _var_data=_var_data)
return cls(
_var_value=value,
_var_name=f'"{value}"',
_var_type=str,
_var_data=ImmutableVarData.merge(_var_data),
return LiteralStringVar(
value,
_var_data=_var_data,
)
@ -386,20 +646,33 @@ class LiteralStringVar(LiteralVar):
class ConcatVarOperation(StringVar):
"""Representing a concatenation of literal string vars."""
_var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple)
_var_value: Tuple[Union[Var, str], ...] = dataclasses.field(default_factory=tuple)
def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None):
def __init__(self, *value: Var | str, _var_data: VarData | None = None):
"""Initialize the operation of concatenating literal string vars.
Args:
_var_value: The list of vars to concatenate.
value: The values to concatenate.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ConcatVarOperation, self).__init__(
_var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
)
object.__setattr__(self, "_var_value", _var_value)
object.__setattr__(self, "_var_name", self._cached_var_name)
object.__setattr__(self, "_var_value", value)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
@ -408,7 +681,16 @@ class ConcatVarOperation(StringVar):
Returns:
The name of the var.
"""
return "+".join([str(element) for element in self._var_value])
return (
"("
+ "+".join(
[
str(element) if isinstance(element, Var) else f'"{element}"'
for element in self._var_value
]
)
+ ")"
)
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
@ -418,7 +700,12 @@ class ConcatVarOperation(StringVar):
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[var._get_all_var_data() for var in self._var_value], self._var_data
*[
var._get_all_var_data()
for var in self._var_value
if isinstance(var, Var)
],
self._var_data,
)
def _get_all_var_data(self) -> ImmutableVarData | None:
@ -433,22 +720,236 @@ class ConcatVarOperation(StringVar):
"""Post-initialize the var."""
pass
@classmethod
def create(
cls,
value: tuple[Var, ...],
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralBooleanVar(LiteralVar):
"""Base class for immutable literal boolean vars."""
_var_value: bool = dataclasses.field(default=False)
def __init__(
self,
_var_value: bool,
_var_data: VarData | None = None,
) -> ConcatVarOperation:
"""Create a var from a tuple of values.
):
"""Initialize the boolean var.
Args:
value: The value to create the var from.
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralBooleanVar, self).__init__(
_var_name="true" if _var_value else "false",
_var_type=bool,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_var_value", _var_value)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralNumberVar(LiteralVar):
"""Base class for immutable literal number vars."""
_var_value: float | int = dataclasses.field(default=0)
def __init__(
self,
_var_value: float | int,
_var_data: VarData | None = None,
):
"""Initialize the number var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralNumberVar, self).__init__(
_var_name=str(_var_value),
_var_type=type(_var_value),
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(self, "_var_value", _var_value)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralObjectVar(LiteralVar):
"""Base class for immutable literal object vars."""
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict
)
def __init__(
self,
_var_value: dict[Var | Any, Var | Any],
_var_type: Type = dict,
_var_data: VarData | None = None,
):
"""Initialize the object var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralObjectVar, self).__init__(
_var_name="",
_var_type=_var_type,
_var_data=ImmutableVarData.merge(_var_data),
)
object.__setattr__(
self,
"_var_value",
_var_value,
)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The var.
The attribute of the var.
"""
return ConcatVarOperation(
_var_value=value,
_var_data=_var_data,
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return (
"{ "
+ ", ".join(
[
f"[{str(LiteralVar.create(key))}] : {str(LiteralVar.create(value))}"
for key, value in self._var_value.items()
]
)
+ " }"
)
@cached_property
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[
value._get_all_var_data()
for key, value in self._var_value
if isinstance(value, Var)
],
*[
key._get_all_var_data()
for key, value in self._var_value
if isinstance(key, Var)
],
self._var_data,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralArrayVar(LiteralVar):
"""Base class for immutable literal array vars."""
_var_value: Union[
List[Union[Var, Any]], Set[Union[Var, Any]], Tuple[Union[Var, Any], ...]
] = dataclasses.field(default_factory=list)
def __init__(
self,
_var_value: list[Var | Any] | tuple[Var | Any] | set[Var | Any],
_var_data: VarData | None = None,
):
"""Initialize the array var.
Args:
_var_value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
"""
super(LiteralArrayVar, self).__init__(
_var_name="",
_var_data=ImmutableVarData.merge(_var_data),
_var_type=list,
)
object.__setattr__(self, "_var_value", _var_value)
object.__delattr__(self, "_var_name")
def __getattr__(self, name):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_var_name":
return self._cached_var_name
return super(type(self), self).__getattr__(name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return (
"["
+ ", ".join(
[str(LiteralVar.create(element)) for element in self._var_value]
)
+ "]"
)
@cached_property
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[
var._get_all_var_data()
for var in self._var_value
if isinstance(var, Var)
],
self._var_data,
)
type_mapping = {
int: LiteralNumberVar,
float: LiteralNumberVar,
bool: LiteralBooleanVar,
dict: LiteralObjectVar,
list: LiteralArrayVar,
tuple: LiteralArrayVar,
set: LiteralArrayVar,
}

View File

@ -8,9 +8,12 @@ from pandas import DataFrame
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.experimental.vars.base import (
ArgsFunctionOperation,
ConcatVarOperation,
FunctionStringVar,
ImmutableVar,
LiteralStringVar,
LiteralVar,
)
from reflex.state import BaseState
from reflex.utils.imports import ImportVar
@ -858,6 +861,58 @@ def test_state_with_initial_computed_var(
assert runtime_dict[var_name] == expected_runtime
def test_literal_var():
complicated_var = LiteralVar.create(
[
{"a": 1, "b": 2, "c": {"d": 3, "e": 4}},
[1, 2, 3, 4],
9,
"string",
True,
False,
None,
set([1, 2, 3]),
]
)
assert (
str(complicated_var)
== '[{ ["a"] : 1, ["b"] : 2, ["c"] : { ["d"] : 3, ["e"] : 4 } }, [1, 2, 3, 4], 9, "string", true, false, null, [1, 2, 3]]'
)
def test_function_var():
addition_func = FunctionStringVar("((a, b) => a + b)")
assert str(addition_func.call(1, 2)) == "(((a, b) => a + b)(1, 2))"
manual_addition_func = ArgsFunctionOperation(
("a", "b"),
{
"args": [ImmutableVar.create_safe("a"), ImmutableVar.create_safe("b")],
"result": ImmutableVar.create_safe("a + b"),
},
)
assert (
str(manual_addition_func.call(1, 2))
== '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
)
increment_func = addition_func(1)
assert (
str(increment_func.call(2))
== "(((...args) => ((((a, b) => a + b)(1, ...args))))(2))"
)
create_hello_statement = ArgsFunctionOperation(
("name",), f"Hello, {ImmutableVar.create_safe('name')}!"
)
first_name = LiteralStringVar("Steven")
last_name = LiteralStringVar("Universe")
assert (
str(create_hello_statement.call(f"{first_name} {last_name}"))
== '(((name) => (("Hello, "+name+"!")))(("Steven"+" "+"Universe")))'
)
def test_retrival():
var_without_data = ImmutableVar.create("test")
assert var_without_data is not None
@ -931,7 +986,7 @@ def test_fstring_concat():
),
)
assert str(string_concat) == '"foo"+imagination+"bar"+consequences+"baz"'
assert str(string_concat) == '("foo"+imagination+"bar"+consequences+"baz")'
assert isinstance(string_concat, ConcatVarOperation)
assert string_concat._get_all_var_data() == ImmutableVarData(
state="fear",