From 458cbfac598b1aba5dc1a0bd75b1f924b66b5524 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 17 Jul 2024 17:01:27 -0700 Subject: [PATCH] [REF-3228] implement LiteralStringVar and format/retrieval mechanism (#3669) * implement LiteralStringVar and format/retrieval mechanism * use create safe * add cached properties to ConcatVarOperation * fix caches * also include self * fix inconsistencies in typings * use default factory not default * add missing docstring * experiment with immutable var data * solve pydantic issues * add sorted function * missing docs * forgot ellipses * give up on frozen * dang it darglint * fix string serialization bugs and remove unused code * add returns statement * whitespace moment * add simple test for string concat * export ConcatVarOperation --- reflex/experimental/vars/__init__.py | 3 + reflex/experimental/vars/base.py | 234 ++++++++++++++++++++++++++- reflex/utils/imports.py | 81 ++++++++-- reflex/vars.py | 191 +++++++++++++++++++++- reflex/vars.pyi | 22 ++- tests/test_var.py | 63 +++++++- 6 files changed, 564 insertions(+), 30 deletions(-) diff --git a/reflex/experimental/vars/__init__.py b/reflex/experimental/vars/__init__.py index 3327e9119..98fa802d3 100644 --- a/reflex/experimental/vars/__init__.py +++ b/reflex/experimental/vars/__init__.py @@ -2,8 +2,11 @@ from .base import ArrayVar as ArrayVar from .base import BooleanVar as BooleanVar +from .base import ConcatVarOperation as ConcatVarOperation from .base import FunctionVar as FunctionVar from .base import ImmutableVar as ImmutableVar +from .base import LiteralStringVar as LiteralStringVar +from .base import LiteralVar as LiteralVar from .base import NumberVar as NumberVar from .base import ObjectVar as ObjectVar from .base import StringVar as StringVar diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index f1b2288b0..258f8d6c3 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -3,13 +3,24 @@ from __future__ import annotations import dataclasses +import json +import re import sys +from functools import cached_property from typing import Any, Optional, Type +from reflex import constants from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.utils import serializers, types from reflex.utils.exceptions import VarTypeError -from reflex.vars import Var, VarData, _decode_var, _extract_var_data, _global_vars +from reflex.vars import ( + ImmutableVarData, + Var, + VarData, + _decode_var_immutable, + _extract_var_data, + _global_vars, +) @dataclasses.dataclass( @@ -27,7 +38,15 @@ class ImmutableVar(Var): _var_type: Type = dataclasses.field(default=Any) # Extra metadata associated with the Var - _var_data: Optional[VarData] = dataclasses.field(default=None) + _var_data: Optional[ImmutableVarData] = dataclasses.field(default=None) + + def __str__(self) -> str: + """String representation of the var. Guaranteed to be a valid Javascript expression. + + Returns: + The name of the var. + """ + return self._var_name @property def _var_is_local(self) -> bool: @@ -59,12 +78,25 @@ class ImmutableVar(Var): def __post_init__(self): """Post-initialize the var.""" # Decode any inline Var markup and apply it to the instance - _var_data, _var_name = _decode_var(self._var_name) + _var_data, _var_name = _decode_var_immutable(self._var_name) if _var_data: self.__init__( - _var_name, self._var_type, VarData.merge(self._var_data, _var_data) + _var_name, + self._var_type, + ImmutableVarData.merge(self._var_data, _var_data), ) + def __hash__(self) -> int: + """Define a hash function for the var. + + Returns: + The hash of the var. + """ + return hash((self._var_name, self._var_type, self._var_data)) + + def _get_all_var_data(self) -> ImmutableVarData | None: + return self._var_data + def _replace(self, merge_var_data=None, **kwargs: Any): """Make a copy of this Var with updated fields. @@ -96,7 +128,7 @@ class ImmutableVar(Var): field_values = dict( _var_name=kwargs.pop("_var_name", self._var_name), _var_type=kwargs.pop("_var_type", self._var_type), - _var_data=VarData.merge( + _var_data=ImmutableVarData.merge( kwargs.get("_var_data", self._var_data), merge_var_data ), ) @@ -109,7 +141,7 @@ class ImmutableVar(Var): _var_is_local: bool | None = None, _var_is_string: bool | None = None, _var_data: VarData | None = None, - ) -> Var | None: + ) -> ImmutableVar | Var | None: """Create a var from a value. Args: @@ -164,7 +196,15 @@ class ImmutableVar(Var): return cls( _var_name=name, _var_type=type_, - _var_data=_var_data, + _var_data=( + ImmutableVarData( + state=_var_data.state, + imports=_var_data.imports, + hooks=_var_data.hooks, + ) + if _var_data + else None + ), ) @classmethod @@ -174,7 +214,7 @@ class ImmutableVar(Var): _var_is_local: bool | None = None, _var_is_string: bool | None = None, _var_data: VarData | None = None, - ) -> Var: + ) -> Var | ImmutableVar: """Create a var from a value, asserting that it is not None. Args: @@ -234,3 +274,181 @@ class ArrayVar(ImmutableVar): class FunctionVar(ImmutableVar): """Base class for immutable function vars.""" + + +class LiteralVar(ImmutableVar): + """Base class for immutable literal vars.""" + + def __post_init__(self): + """Post-initialize the var.""" + + +# Compile regex for finding reflex var tags. +_decode_var_pattern_re = ( + rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" +) +_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class LiteralStringVar(LiteralVar): + """Base class for immutable literal string vars.""" + + _var_value: Optional[str] = dataclasses.field(default=None) + + @classmethod + def create( + cls, + value: str, + _var_data: VarData | None = None, + ) -> LiteralStringVar | ConcatVarOperation: + """Create a var from a string value. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + if REFLEX_VAR_OPENING_TAG in value: + strings_and_vals: list[Var] = [] + offset = 0 + + # Initialize some methods for reading json. + var_data_config = VarData().__config__ + + def json_loads(s): + try: + return var_data_config.json_loads(s) + except json.decoder.JSONDecodeError: + return var_data_config.json_loads( + var_data_config.json_loads(f'"{s}"') + ) + + # Find all tags. + while m := _decode_var_pattern.search(value): + start, end = m.span() + if start > 0: + strings_and_vals.append(LiteralStringVar.create(value[:start])) + + serialized_data = m.group(1) + + if serialized_data[1:].isnumeric(): + # This is a global immutable var. + var = _global_vars[int(serialized_data)] + strings_and_vals.append(var) + value = value[(end + len(var._var_name)) :] + else: + data = json_loads(serialized_data) + string_length = data.pop("string_length", None) + var_data = VarData.parse_obj(data) + + # Use string length to compute positions of interpolations. + if string_length is not None: + realstart = start + offset + var_data.interpolations = [ + (realstart, realstart + string_length) + ] + strings_and_vals.append( + ImmutableVar.create_safe( + value[end : (end + string_length)], _var_data=var_data + ) + ) + value = value[(end + string_length) :] + + offset += end - start + + if value: + strings_and_vals.append(LiteralStringVar.create(value)) + + return ConcatVarOperation.create( + tuple(strings_and_vals), _var_data=_var_data + ) + + return cls( + _var_value=value, + _var_name=f'"{value}"', + _var_type=str, + _var_data=ImmutableVarData.merge(_var_data), + ) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class ConcatVarOperation(StringVar): + """Representing a concatenation of literal string vars.""" + + _var_value: tuple[Var, ...] = dataclasses.field(default_factory=tuple) + + def __init__(self, _var_value: tuple[Var, ...], _var_data: VarData | None = None): + """Initialize the operation of concatenating literal string vars. + + Args: + _var_value: The list of vars to concatenate. + _var_data: Additional hooks and imports associated with the Var. + """ + super(ConcatVarOperation, self).__init__( + _var_name="", _var_data=ImmutableVarData.merge(_var_data), _var_type=str + ) + object.__setattr__(self, "_var_value", _var_value) + object.__setattr__(self, "_var_name", self._cached_var_name) + + @cached_property + def _cached_var_name(self) -> str: + """The name of the var. + + Returns: + The name of the var. + """ + return "+".join([str(element) for element in self._var_value]) + + @cached_property + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get all VarData associated with the Var. + + Returns: + The VarData of the components and all of its children. + """ + return ImmutableVarData.merge( + *[var._get_all_var_data() for var in self._var_value], self._var_data + ) + + def _get_all_var_data(self) -> ImmutableVarData | None: + """Wrapper method for cached property. + + Returns: + The VarData of the components and all of its children. + """ + return self._cached_get_all_var_data + + def __post_init__(self): + """Post-initialize the var.""" + pass + + @classmethod + def create( + cls, + value: tuple[Var, ...], + _var_data: VarData | None = None, + ) -> ConcatVarOperation: + """Create a var from a tuple of values. + + Args: + value: The value to create the var from. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + return ConcatVarOperation( + _var_value=value, + _var_data=_var_data, + ) diff --git a/reflex/utils/imports.py b/reflex/utils/imports.py index 397c305ff..d58c2bf3f 100644 --- a/reflex/utils/imports.py +++ b/reflex/utils/imports.py @@ -3,12 +3,14 @@ from __future__ import annotations from collections import defaultdict -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from reflex.base import Base -def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict: +def merge_imports( + *imports: ImportDict | ParsedImportDict | ImmutableParsedImportDict, +) -> ParsedImportDict: """Merge multiple import dicts together. Args: @@ -19,7 +21,9 @@ def merge_imports(*imports: ImportDict | ParsedImportDict) -> ParsedImportDict: """ all_imports = defaultdict(list) for import_dict in imports: - for lib, fields in import_dict.items(): + for lib, fields in ( + import_dict if isinstance(import_dict, tuple) else import_dict.items() + ): all_imports[lib].extend(fields) return all_imports @@ -48,7 +52,9 @@ def parse_imports(imports: ImportDict | ParsedImportDict) -> ParsedImportDict: } -def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict: +def collapse_imports( + imports: ParsedImportDict | ImmutableParsedImportDict, +) -> ParsedImportDict: """Remove all duplicate ImportVar within an ImportDict. Args: @@ -58,8 +64,14 @@ def collapse_imports(imports: ParsedImportDict) -> ParsedImportDict: The collapsed import dict. """ return { - lib: list(set(import_vars)) if isinstance(import_vars, list) else import_vars - for lib, import_vars in imports.items() + lib: ( + list(set(import_vars)) + if isinstance(import_vars, list) + else list(import_vars) + ) + for lib, import_vars in ( + imports if isinstance(imports, tuple) else imports.items() + ) } @@ -99,11 +111,61 @@ class ImportVar(Base): else: return self.tag or "" - def __hash__(self) -> int: - """Define a hash function for the import var. + def __lt__(self, other: ImportVar) -> bool: + """Compare two ImportVar objects. + + Args: + other: The other ImportVar object to compare. Returns: - The hash of the var. + Whether this ImportVar object is less than the other. + """ + return ( + self.tag, + self.is_default, + self.alias, + self.install, + self.render, + self.transpile, + ) < ( + other.tag, + other.is_default, + other.alias, + other.install, + other.render, + other.transpile, + ) + + def __eq__(self, other: ImportVar) -> bool: + """Check if two ImportVar objects are equal. + + Args: + other: The other ImportVar object to compare. + + Returns: + Whether the two ImportVar objects are equal. + """ + return ( + self.tag, + self.is_default, + self.alias, + self.install, + self.render, + self.transpile, + ) == ( + other.tag, + other.is_default, + other.alias, + other.install, + other.render, + other.transpile, + ) + + def __hash__(self) -> int: + """Hash the ImportVar object. + + Returns: + The hash of the ImportVar object. """ return hash( ( @@ -120,3 +182,4 @@ class ImportVar(Base): ImportTypes = Union[str, ImportVar, List[Union[str, ImportVar]], List[ImportVar]] ImportDict = Dict[str, ImportTypes] ParsedImportDict = Dict[str, List[ImportVar]] +ImmutableParsedImportDict = Tuple[Tuple[str, Tuple[ImportVar, ...]], ...] diff --git a/reflex/vars.py b/reflex/vars.py index 8d93f99c0..c6ad4eed5 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -45,6 +45,7 @@ from reflex.utils.exceptions import ( # This module used to export ImportVar itself, so we still import it for export here from reflex.utils.imports import ( + ImmutableParsedImportDict, ImportDict, ImportVar, ParsedImportDict, @@ -154,7 +155,7 @@ class VarData(Base): super().__init__(**kwargs) @classmethod - def merge(cls, *others: VarData | None) -> VarData | None: + def merge(cls, *others: ImmutableVarData | VarData | None) -> VarData | None: """Merge multiple var data objects. Args: @@ -172,8 +173,14 @@ class VarData(Base): continue state = state or var_data.state _imports = imports.merge_imports(_imports, var_data.imports) - hooks.update(var_data.hooks) - interpolations += var_data.interpolations + hooks.update( + var_data.hooks + if isinstance(var_data.hooks, dict) + else {k: None for k in var_data.hooks} + ) + interpolations += ( + var_data.interpolations if isinstance(var_data, VarData) else [] + ) return ( cls( @@ -231,6 +238,173 @@ class VarData(Base): } +@dataclasses.dataclass( + eq=True, + frozen=True, +) +class ImmutableVarData: + """Metadata associated with a Var.""" + + # The name of the enclosing state. + state: str = dataclasses.field(default="") + + # Imports needed to render this var + imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple) + + # Hooks that need to be present in the component to render this var + hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple) + + def __init__( + self, + state: str = "", + imports: ImportDict | ParsedImportDict | None = None, + hooks: dict[str, None] | None = None, + ): + """Initialize the var data. + + Args: + state: The name of the enclosing state. + imports: Imports needed to render this var. + hooks: Hooks that need to be present in the component to render this var. + """ + immutable_imports: ImmutableParsedImportDict = tuple( + sorted( + ((k, tuple(sorted(v))) for k, v in parse_imports(imports or {}).items()) + ) + ) + object.__setattr__(self, "state", state) + object.__setattr__(self, "imports", immutable_imports) + object.__setattr__(self, "hooks", tuple(hooks or {})) + + @classmethod + def merge( + cls, *others: ImmutableVarData | VarData | None + ) -> ImmutableVarData | None: + """Merge multiple var data objects. + + Args: + *others: The var data objects to merge. + + Returns: + The merged var data object. + """ + state = "" + _imports = {} + hooks = {} + for var_data in others: + if var_data is None: + continue + state = state or var_data.state + _imports = imports.merge_imports(_imports, var_data.imports) + hooks.update( + var_data.hooks + if isinstance(var_data.hooks, dict) + else {k: None for k in var_data.hooks} + ) + + return ( + ImmutableVarData( + state=state, + imports=_imports, + hooks=hooks, + ) + or None + ) + + def __bool__(self) -> bool: + """Check if the var data is non-empty. + + Returns: + True if any field is set to a non-default value. + """ + return bool(self.state or self.imports or self.hooks) + + def __eq__(self, other: Any) -> bool: + """Check if two var data objects are equal. + + Args: + other: The other var data object to compare. + + Returns: + True if all fields are equal and collapsed imports are equal. + """ + if not isinstance(other, (ImmutableVarData, VarData)): + return False + + # Don't compare interpolations - that's added in by the decoder, and + # not part of the vardata itself. + return ( + self.state == other.state + and self.hooks + == ( + other.hooks + if isinstance(other, ImmutableVarData) + else tuple(other.hooks.keys()) + ) + and imports.collapse_imports(self.imports) + == imports.collapse_imports(other.imports) + ) + + +def _decode_var_immutable(value: str) -> tuple[ImmutableVarData | None, str]: + """Decode the state name from a formatted var. + + Args: + value: The value to extract the state name from. + + Returns: + The extracted state name and the value without the state name. + """ + var_datas = [] + if isinstance(value, str): + # fast path if there is no encoded VarData + if constants.REFLEX_VAR_OPENING_TAG not in value: + return None, value + + offset = 0 + + # Initialize some methods for reading json. + var_data_config = VarData().__config__ + + def json_loads(s): + try: + return var_data_config.json_loads(s) + except json.decoder.JSONDecodeError: + return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"')) + + # Find all tags. + while m := _decode_var_pattern.search(value): + start, end = m.span() + value = value[:start] + value[end:] + + serialized_data = m.group(1) + + if serialized_data[1:].isnumeric(): + # This is a global immutable var. + var = _global_vars[int(serialized_data)] + var_data = var._var_data + + if var_data is not None: + realstart = start + offset + + var_datas.append(var_data) + else: + # Read the JSON, pull out the string length, parse the rest as VarData. + data = json_loads(serialized_data) + string_length = data.pop("string_length", None) + var_data = VarData.parse_obj(data) + + # Use string length to compute positions of interpolations. + if string_length is not None: + realstart = start + offset + var_data.interpolations = [(realstart, realstart + string_length)] + + var_datas.append(var_data) + offset += end - start + + return ImmutableVarData.merge(*var_datas) if var_datas else None, value + + def _encode_var(value: Var) -> str: """Encode the state name into a formatted var. @@ -306,9 +480,6 @@ def _decode_var(value: str) -> tuple[VarData | None, str]: if var_data is not None: realstart = start + offset - var_data.interpolations = [ - (realstart, realstart + len(var._var_name)) - ] var_datas.append(var_data) else: @@ -1814,6 +1985,14 @@ class Var: """ return self._var_data.state if self._var_data else "" + def _get_all_var_data(self) -> VarData | None: + """Get all the var data. + + Returns: + The var data. + """ + return self._var_data + @property def _var_name_unwrapped(self) -> str: """Get the var str without wrapping in curly braces. diff --git a/reflex/vars.pyi b/reflex/vars.pyi index e9b5d041d..4aa6afc33 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -29,7 +29,7 @@ from reflex.state import State as State from reflex.utils import console as console from reflex.utils import format as format from reflex.utils import types as types -from reflex.utils.imports import ImportDict, ParsedImportDict +from reflex.utils.imports import ImmutableParsedImportDict, ImportDict, ParsedImportDict USED_VARIABLES: Incomplete @@ -47,7 +47,24 @@ class VarData(Base): hooks: Dict[str, None] = {} interpolations: List[Tuple[int, int]] = [] @classmethod - def merge(cls, *others: VarData | None) -> VarData | None: ... + def merge(cls, *others: ImmutableVarData | VarData | None) -> VarData | None: ... + +class ImmutableVarData: + state: str = "" + imports: ImmutableParsedImportDict = tuple() + hooks: Tuple[str, ...] = tuple() + def __init__( + self, + state: str = "", + imports: ImportDict | ParsedImportDict | None = None, + hooks: dict[str, None] | None = None, + ) -> None: ... + @classmethod + def merge( + cls, *others: ImmutableVarData | VarData | None + ) -> ImmutableVarData | None: ... + +def _decode_var_immutable(value: str) -> tuple[ImmutableVarData, str]: ... class Var: _var_name: str @@ -133,6 +150,7 @@ class Var: @property def _var_full_name(self) -> str: ... def _var_set_state(self, state: Type[BaseState] | str) -> Any: ... + def _get_all_var_data(self) -> VarData: ... @dataclass(eq=False) class BaseVar(Var): diff --git a/tests/test_var.py b/tests/test_var.py index 5284bf98d..78b3a2160 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -7,12 +7,17 @@ from pandas import DataFrame from reflex.base import Base from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG -from reflex.experimental.vars.base import ImmutableVar +from reflex.experimental.vars.base import ( + ConcatVarOperation, + ImmutableVar, + LiteralStringVar, +) from reflex.state import BaseState from reflex.utils.imports import ImportVar from reflex.vars import ( BaseVar, ComputedVar, + ImmutableVarData, Var, VarData, computed_var, @@ -880,13 +885,61 @@ def test_retrival(): ) assert ( result_var_data.imports - == result_immutable_var_data.imports + == ( + result_immutable_var_data.imports + if isinstance(result_immutable_var_data.imports, dict) + else { + k: list(v) + for k, v in result_immutable_var_data.imports + if k in original_var_data.imports + } + ) == original_var_data.imports ) assert ( - result_var_data.hooks - == result_immutable_var_data.hooks - == original_var_data.hooks + list(result_var_data.hooks.keys()) + == ( + list(result_immutable_var_data.hooks.keys()) + if isinstance(result_immutable_var_data.hooks, dict) + else list(result_immutable_var_data.hooks) + ) + == list(original_var_data.hooks.keys()) + ) + + +def test_fstring_concat(): + original_var_with_data = Var.create_safe( + "imagination", _var_data=VarData(state="fear") + ) + + immutable_var_with_data = ImmutableVar.create_safe( + "consequences", + _var_data=VarData( + imports={ + "react": [ImportVar(tag="useRef")], + "utils": [ImportVar(tag="useEffect")], + } + ), + ) + + f_string = f"foo{original_var_with_data}bar{immutable_var_with_data}baz" + + string_concat = LiteralStringVar.create( + f_string, + _var_data=VarData( + hooks={"const state = useContext(StateContexts.state)": None} + ), + ) + + assert str(string_concat) == '"foo"+imagination+"bar"+consequences+"baz"' + assert isinstance(string_concat, ConcatVarOperation) + assert string_concat._get_all_var_data() == ImmutableVarData( + state="fear", + imports={ + "react": [ImportVar(tag="useRef")], + "utils": [ImportVar(tag="useEffect")], + }, + hooks={"const state = useContext(StateContexts.state)": None}, )