[REF-3228] implement LiteralStringVar and format/retrieval mechanism (#3669)

* implement LiteralStringVar and format/retrieval mechanism

* use create safe

* add cached properties to ConcatVarOperation

* fix caches

* also include self

* fix inconsistencies in typings

* use default factory not default

* add missing docstring

* experiment with immutable var data

* solve pydantic issues

* add sorted function

* missing docs

* forgot ellipses

* give up on frozen

* dang it darglint

* fix string serialization bugs and remove unused code

* add returns statement

* whitespace moment

* add simple test for string concat

* export ConcatVarOperation
This commit is contained in:
Khaleel Al-Adhami 2024-07-17 17:01:27 -07:00 committed by GitHub
parent 94c4c2f29f
commit 458cbfac59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 564 additions and 30 deletions

View File

@ -2,8 +2,11 @@
from .base import ArrayVar as ArrayVar
from .base import BooleanVar as BooleanVar
from .base import ConcatVarOperation as ConcatVarOperation
from .base import FunctionVar as FunctionVar
from .base import ImmutableVar as ImmutableVar
from .base import LiteralStringVar as LiteralStringVar
from .base import LiteralVar as LiteralVar
from .base import NumberVar as NumberVar
from .base import ObjectVar as ObjectVar
from .base import StringVar as StringVar

View File

@ -3,13 +3,24 @@
from __future__ import annotations
import dataclasses
import json
import re
import sys
from functools import cached_property
from typing import Any, Optional, Type
from reflex import constants
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.utils import serializers, types
from reflex.utils.exceptions import VarTypeError
from reflex.vars import Var, VarData, _decode_var, _extract_var_data, _global_vars
from reflex.vars import (
ImmutableVarData,
Var,
VarData,
_decode_var_immutable,
_extract_var_data,
_global_vars,
)
@dataclasses.dataclass(
@ -27,7 +38,15 @@ class ImmutableVar(Var):
_var_type: Type = dataclasses.field(default=Any)
# Extra metadata associated with the Var
_var_data: Optional[VarData] = dataclasses.field(default=None)
_var_data: Optional[ImmutableVarData] = dataclasses.field(default=None)
def __str__(self) -> str:
"""String representation of the var. Guaranteed to be a valid Javascript expression.
Returns:
The name of the var.
"""
return self._var_name
@property
def _var_is_local(self) -> bool:
@ -59,12 +78,25 @@ class ImmutableVar(Var):
def __post_init__(self):
"""Post-initialize the var."""
# Decode any inline Var markup and apply it to the instance
_var_data, _var_name = _decode_var(self._var_name)
_var_data, _var_name = _decode_var_immutable(self._var_name)
if _var_data:
self.__init__(
_var_name, self._var_type, VarData.merge(self._var_data, _var_data)
_var_name,
self._var_type,
ImmutableVarData.merge(self._var_data, _var_data),
)
def __hash__(self) -> int:
"""Define a hash function for the var.
Returns:
The hash of the var.
"""
return hash((self._var_name, self._var_type, self._var_data))
def _get_all_var_data(self) -> ImmutableVarData | None:
return self._var_data
def _replace(self, merge_var_data=None, **kwargs: Any):
"""Make a copy of this Var with updated fields.
@ -96,7 +128,7 @@ class ImmutableVar(Var):
field_values = dict(
_var_name=kwargs.pop("_var_name", self._var_name),
_var_type=kwargs.pop("_var_type", self._var_type),
_var_data=VarData.merge(
_var_data=ImmutableVarData.merge(
kwargs.get("_var_data", self._var_data), merge_var_data
),
)
@ -109,7 +141,7 @@ class ImmutableVar(Var):
_var_is_local: bool | None = None,
_var_is_string: bool | None = None,
_var_data: VarData | None = None,
) -> Var | None:
) -> ImmutableVar | Var | None:
"""Create a var from a value.
Args:
@ -164,7 +196,15 @@ class ImmutableVar(Var):
return cls(
_var_name=name,
_var_type=type_,
_var_data=_var_data,
_var_data=(
ImmutableVarData(
state=_var_data.state,
imports=_var_data.imports,
hooks=_var_data.hooks,
)
if _var_data
else None
),
)
@classmethod
@ -174,7 +214,7 @@ class ImmutableVar(Var):
_var_is_local: bool | None = None,
_var_is_string: bool | None = None,
_var_data: VarData | None = None,
) -> Var:
) -> Var | ImmutableVar:
"""Create a var from a value, asserting that it is not None.
Args:
@ -234,3 +274,181 @@ class ArrayVar(ImmutableVar):
class FunctionVar(ImmutableVar):
"""Base class for immutable function vars."""
class LiteralVar(ImmutableVar):
"""Base class for immutable literal vars."""
def __post_init__(self):
"""Post-initialize the var."""
# Compile regex for finding reflex var tags.
_decode_var_pattern_re = (
rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
)
_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralStringVar(LiteralVar):
"""Base class for immutable literal string vars."""
_var_value: Optional[str] = dataclasses.field(default=None)
@classmethod
def create(
cls,
value: str,
_var_data: VarData | None = None,
) -> LiteralStringVar | ConcatVarOperation:
"""Create a var from a string value.
Args:
value: The value to create the var from.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
if REFLEX_VAR_OPENING_TAG in value:
strings_and_vals: list[Var] = []
offset = 0
# Initialize some methods for reading json.
var_data_config = VarData().__config__
def json_loads(s):
try:
return var_data_config.json_loads(s)
except json.decoder.JSONDecodeError:
return var_data_config.json_loads(
var_data_config.json_loads(f'"{s}"')
)
# Find all tags.
while m := _decode_var_pattern.search(value):
start, end = m.span()
if start > 0:
strings_and_vals.append(LiteralStringVar.create(value[:start]))
serialized_data = m.group(1)
if serialized_data[1:].isnumeric():
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
strings_and_vals.append(var)
value = value[(end + len(var._var_name)) :]
else:
data = json_loads(serialized_data)
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)
# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [
(realstart, realstart + string_length)
]
strings_and_vals.append(
ImmutableVar.create_safe(
value[end : (end + string_length)], _var_data=var_data
)
)
value = value[(end + string_length) :]
offset += end - start
if value:
strings_and_vals.append(LiteralStringVar.create(value))
return ConcatVarOperation.create(
tuple(strings_and_vals), _var_data=_var_data
)
return cls(
_var_value=value,
_var_name=f'"{value}"',
_var_type=str,
_var_data=ImmutableVarData.merge(_var_data),
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ConcatVarOperation(StringVar):
"""Representing a concatenation of literal string vars."""
_var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple)
def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None):
"""Initialize the operation of concatenating literal string vars.
Args:
_var_value: The list of vars to concatenate.
_var_data: Additional hooks and imports associated with the Var.
"""
super(ConcatVarOperation, self).__init__(
_var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str
)
object.__setattr__(self, "_var_value", _var_value)
object.__setattr__(self, "_var_name", self._cached_var_name)
@cached_property
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return "+".join([str(element) for element in self._var_value])
@cached_property
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return ImmutableVarData.merge(
*[var._get_all_var_data() for var in self._var_value], self._var_data
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Wrapper method for cached property.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
def __post_init__(self):
"""Post-initialize the var."""
pass
@classmethod
def create(
cls,
value: tuple[Var, ...],
_var_data: VarData | None = None,
) -> ConcatVarOperation:
"""Create a var from a tuple of values.
Args:
value: The value to create the var from.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
return ConcatVarOperation(
_var_value=value,
_var_data=_var_data,
)

View File

@ -3,12 +3,14 @@
from __future__ import annotations
from collections import defaultdict
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
from reflex.base import Base
def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
def merge_imports(
*imports: ImportDict | ParsedImportDict | ImmutableParsedImportDict,
) -> ParsedImportDict:
"""Merge multiple import dicts together.
Args:
@ -19,7 +21,9 @@ def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
"""
all_imports = defaultdict(list)
for import_dict in imports:
for lib, fields in import_dict.items():
for lib, fields in (
import_dict if isinstance(import_dict, tuple) else import_dict.items()
):
all_imports[lib].extend(fields)
return all_imports
@ -48,7 +52,9 @@ def parse_imports(imports: ImportDict | ParsedImportDict) -> ParsedImportDict:
}
def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict:
def collapse_imports(
imports: ParsedImportDict | ImmutableParsedImportDict,
) -> ParsedImportDict:
"""Remove all duplicate ImportVar within an ImportDict.
Args:
@ -58,8 +64,14 @@ def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict:
The collapsed import dict.
"""
return {
lib: list(set(import_vars)) if isinstance(import_vars, list) else import_vars
for lib, import_vars in imports.items()
lib: (
list(set(import_vars))
if isinstance(import_vars, list)
else list(import_vars)
)
for lib, import_vars in (
imports if isinstance(imports, tuple) else imports.items()
)
}
@ -99,11 +111,61 @@ class ImportVar(Base):
else:
return self.tag or ""
def __hash__(self) -> int:
"""Define a hash function for the import var.
def __lt__(self, other: ImportVar) -> bool:
"""Compare two ImportVar objects.
Args:
other: The other ImportVar object to compare.
Returns:
The hash of the var.
Whether this ImportVar object is less than the other.
"""
return (
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
) < (
other.tag,
other.is_default,
other.alias,
other.install,
other.render,
other.transpile,
)
def __eq__(self, other: ImportVar) -> bool:
"""Check if two ImportVar objects are equal.
Args:
other: The other ImportVar object to compare.
Returns:
Whether the two ImportVar objects are equal.
"""
return (
self.tag,
self.is_default,
self.alias,
self.install,
self.render,
self.transpile,
) == (
other.tag,
other.is_default,
other.alias,
other.install,
other.render,
other.transpile,
)
def __hash__(self) -> int:
"""Hash the ImportVar object.
Returns:
The hash of the ImportVar object.
"""
return hash(
(
@ -120,3 +182,4 @@ class ImportVar(Base):
ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]]
ImportDict = Dict[str, ImportTypes]
ParsedImportDict = Dict[str, List[ImportVar]]
ImmutableParsedImportDict = Tuple[Tuple[str, Tuple[ImportVar, ...]], ...]

View File

@ -45,6 +45,7 @@ from reflex.utils.exceptions import (
# This module used to export ImportVar itself, so we still import it for export here
from reflex.utils.imports import (
ImmutableParsedImportDict,
ImportDict,
ImportVar,
ParsedImportDict,
@ -154,7 +155,7 @@ class VarData(Base):
super().__init__(**kwargs)
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
def merge(cls, *others: ImmutableVarData | VarData | None) -> VarData | None:
"""Merge multiple var data objects.
Args:
@ -172,8 +173,14 @@ class VarData(Base):
continue
state = state or var_data.state
_imports = imports.merge_imports(_imports, var_data.imports)
hooks.update(var_data.hooks)
interpolations += var_data.interpolations
hooks.update(
var_data.hooks
if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks}
)
interpolations += (
var_data.interpolations if isinstance(var_data, VarData) else []
)
return (
cls(
@ -231,6 +238,173 @@ class VarData(Base):
}
@dataclasses.dataclass(
eq=True,
frozen=True,
)
class ImmutableVarData:
"""Metadata associated with a Var."""
# The name of the enclosing state.
state: str = dataclasses.field(default="")
# Imports needed to render this var
imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple)
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
def __init__(
self,
state: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
):
"""Initialize the var data.
Args:
state: The name of the enclosing state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
((k, tuple(sorted(v))) for k, v in parse_imports(imports or {}).items())
)
)
object.__setattr__(self, "state", state)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
@classmethod
def merge(
cls, *others: ImmutableVarData | VarData | None
) -> ImmutableVarData | None:
"""Merge multiple var data objects.
Args:
*others: The var data objects to merge.
Returns:
The merged var data object.
"""
state = ""
_imports = {}
hooks = {}
for var_data in others:
if var_data is None:
continue
state = state or var_data.state
_imports = imports.merge_imports(_imports, var_data.imports)
hooks.update(
var_data.hooks
if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks}
)
return (
ImmutableVarData(
state=state,
imports=_imports,
hooks=hooks,
)
or None
)
def __bool__(self) -> bool:
"""Check if the var data is non-empty.
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks)
def __eq__(self, other: Any) -> bool:
"""Check if two var data objects are equal.
Args:
other: The other var data object to compare.
Returns:
True if all fields are equal and collapsed imports are equal.
"""
if not isinstance(other, (ImmutableVarData, VarData)):
return False
# Don't compare interpolations - that's added in by the decoder, and
# not part of the vardata itself.
return (
self.state == other.state
and self.hooks
== (
other.hooks
if isinstance(other, ImmutableVarData)
else tuple(other.hooks.keys())
)
and imports.collapse_imports(self.imports)
== imports.collapse_imports(other.imports)
)
def _decode_var_immutable(value: str) -> tuple[ImmutableVarData | None, str]:
"""Decode the state name from a formatted var.
Args:
value: The value to extract the state name from.
Returns:
The extracted state name and the value without the state name.
"""
var_datas = []
if isinstance(value, str):
# fast path if there is no encoded VarData
if constants.REFLEX_VAR_OPENING_TAG not in value:
return None, value
offset = 0
# Initialize some methods for reading json.
var_data_config = VarData().__config__
def json_loads(s):
try:
return var_data_config.json_loads(s)
except json.decoder.JSONDecodeError:
return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"'))
# Find all tags.
while m := _decode_var_pattern.search(value):
start, end = m.span()
value = value[:start] + value[end:]
serialized_data = m.group(1)
if serialized_data[1:].isnumeric():
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._var_data
if var_data is not None:
realstart = start + offset
var_datas.append(var_data)
else:
# Read the JSON, pull out the string length, parse the rest as VarData.
data = json_loads(serialized_data)
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)
# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [(realstart, realstart + string_length)]
var_datas.append(var_data)
offset += end - start
return ImmutableVarData.merge(*var_datas) if var_datas else None, value
def _encode_var(value: Var) -> str:
"""Encode the state name into a formatted var.
@ -306,9 +480,6 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
if var_data is not None:
realstart = start + offset
var_data.interpolations = [
(realstart, realstart + len(var._var_name))
]
var_datas.append(var_data)
else:
@ -1814,6 +1985,14 @@ class Var:
"""
return self._var_data.state if self._var_data else ""
def _get_all_var_data(self) -> VarData | None:
"""Get all the var data.
Returns:
The var data.
"""
return self._var_data
@property
def _var_name_unwrapped(self) -> str:
"""Get the var str without wrapping in curly braces.

View File

@ -29,7 +29,7 @@ from reflex.state import State as State
from reflex.utils import console as console
from reflex.utils import format as format
from reflex.utils import types as types
from reflex.utils.imports import ImportDict, ParsedImportDict
from reflex.utils.imports import ImmutableParsedImportDict, ImportDict, ParsedImportDict
USED_VARIABLES: Incomplete
@ -47,7 +47,24 @@ class VarData(Base):
hooks: Dict[str, None] = {}
interpolations: List[Tuple[int, int]] = []
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None: ...
def merge(cls, *others: ImmutableVarData | VarData | None) -> VarData | None: ...
class ImmutableVarData:
state: str = ""
imports: ImmutableParsedImportDict = tuple()
hooks: Tuple[str, ...] = tuple()
def __init__(
self,
state: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
) -> None: ...
@classmethod
def merge(
cls, *others: ImmutableVarData | VarData | None
) -> ImmutableVarData | None: ...
def _decode_var_immutable(value: str) -> tuple[ImmutableVarData, str]: ...
class Var:
_var_name: str
@ -133,6 +150,7 @@ class Var:
@property
def _var_full_name(self) -> str: ...
def _var_set_state(self, state: Type[BaseState] | str) -> Any: ...
def _get_all_var_data(self) -> VarData: ...
@dataclass(eq=False)
class BaseVar(Var):

View File

@ -7,12 +7,17 @@ from pandas import DataFrame
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.experimental.vars.base import ImmutableVar
from reflex.experimental.vars.base import (
ConcatVarOperation,
ImmutableVar,
LiteralStringVar,
)
from reflex.state import BaseState
from reflex.utils.imports import ImportVar
from reflex.vars import (
BaseVar,
ComputedVar,
ImmutableVarData,
Var,
VarData,
computed_var,
@ -880,13 +885,61 @@ def test_retrival():
)
assert (
result_var_data.imports
== result_immutable_var_data.imports
== (
result_immutable_var_data.imports
if isinstance(result_immutable_var_data.imports, dict)
else {
k: list(v)
for k, v in result_immutable_var_data.imports
if k in original_var_data.imports
}
)
== original_var_data.imports
)
assert (
result_var_data.hooks
== result_immutable_var_data.hooks
== original_var_data.hooks
list(result_var_data.hooks.keys())
== (
list(result_immutable_var_data.hooks.keys())
if isinstance(result_immutable_var_data.hooks, dict)
else list(result_immutable_var_data.hooks)
)
== list(original_var_data.hooks.keys())
)
def test_fstring_concat():
original_var_with_data = Var.create_safe(
"imagination", _var_data=VarData(state="fear")
)
immutable_var_with_data = ImmutableVar.create_safe(
"consequences",
_var_data=VarData(
imports={
"react": [ImportVar(tag="useRef")],
"utils": [ImportVar(tag="useEffect")],
}
),
)
f_string = f"foo{original_var_with_data}bar{immutable_var_with_data}baz"
string_concat = LiteralStringVar.create(
f_string,
_var_data=VarData(
hooks={"const state = useContext(StateContexts.state)": None}
),
)
assert str(string_concat) == '"foo"+imagination+"bar"+consequences+"baz"'
assert isinstance(string_concat, ConcatVarOperation)
assert string_concat._get_all_var_data() == ImmutableVarData(
state="fear",
imports={
"react": [ImportVar(tag="useRef")],
"utils": [ImportVar(tag="useEffect")],
},
hooks={"const state = useContext(StateContexts.state)": None},
)