diff --git a/reflex/experimental/vars/base.py b/reflex/experimental/vars/base.py index 52bce9161..3dc1f0d7c 100644 --- a/reflex/experimental/vars/base.py +++ b/reflex/experimental/vars/base.py @@ -6,9 +6,10 @@ import dataclasses import sys from typing import Any, Optional, Type +from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.utils import serializers, types from reflex.utils.exceptions import VarTypeError -from reflex.vars import Var, VarData, _extract_var_data +from reflex.vars import Var, VarData, _decode_var, _extract_var_data, _global_vars @dataclasses.dataclass( @@ -55,6 +56,15 @@ class ImmutableVar(Var): """ return False + def __post_init__(self): + """Post-initialize the var.""" + # Decode any inline Var markup and apply it to the instance + _var_data, _var_name = _decode_var(self._var_name) + if _var_data: + self.__init__( + _var_name, self._var_type, VarData.merge(self._var_data, _var_data) + ) + def _replace(self, merge_var_data=None, **kwargs: Any): """Make a copy of this Var with updated fields. @@ -156,3 +166,45 @@ class ImmutableVar(Var): _var_type=type_, _var_data=_var_data, ) + + @classmethod + def create_safe( + cls, + value: Any, + _var_is_local: bool | None = None, + _var_is_string: bool | None = None, + _var_data: VarData | None = None, + ) -> Var: + """Create a var from a value, asserting that it is not None. + + Args: + value: The value to create the var from. + _var_is_local: Whether the var is local. Deprecated. + _var_is_string: Whether the var is a string literal. Deprecated. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The var. + """ + var = cls.create( + value, + _var_is_local=_var_is_local, + _var_is_string=_var_is_string, + _var_data=_var_data, + ) + assert var is not None + return var + + def __format__(self, format_spec: str) -> str: + """Format the var into a Javascript equivalent to an f-string. + + Args: + format_spec: The format specifier (Ignored for now). + + Returns: + The formatted var. + """ + _global_vars[hash(self)] = self + + # Encode the _var_data into the formatted output for tracking purposes. + return f"{REFLEX_VAR_OPENING_TAG}{hash(self)}{REFLEX_VAR_CLOSING_TAG}{self._var_name}" diff --git a/reflex/vars.py b/reflex/vars.py index f226d88bd..8d93f99c0 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -262,6 +262,9 @@ _decode_var_pattern_re = ( ) _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) +# Defined global immutable vars. +_global_vars: Dict[int, Var] = {} + def _decode_var(value: str) -> tuple[VarData | None, str]: """Decode the state name from a formatted var. @@ -294,17 +297,32 @@ def _decode_var(value: str) -> tuple[VarData | None, str]: start, end = m.span() value = value[:start] + value[end:] - # Read the JSON, pull out the string length, parse the rest as VarData. - data = json_loads(m.group(1)) - string_length = data.pop("string_length", None) - var_data = VarData.parse_obj(data) + serialized_data = m.group(1) - # Use string length to compute positions of interpolations. - if string_length is not None: - realstart = start + offset - var_data.interpolations = [(realstart, realstart + string_length)] + if serialized_data[1:].isnumeric(): + # This is a global immutable var. + var = _global_vars[int(serialized_data)] + var_data = var._var_data - var_datas.append(var_data) + if var_data is not None: + realstart = start + offset + var_data.interpolations = [ + (realstart, realstart + len(var._var_name)) + ] + + var_datas.append(var_data) + else: + # Read the JSON, pull out the string length, parse the rest as VarData. + data = json_loads(serialized_data) + string_length = data.pop("string_length", None) + var_data = VarData.parse_obj(data) + + # Use string length to compute positions of interpolations. + if string_length is not None: + realstart = start + offset + var_data.interpolations = [(realstart, realstart + string_length)] + + var_datas.append(var_data) offset += end - start return VarData.merge(*var_datas) if var_datas else None, value diff --git a/reflex/vars.pyi b/reflex/vars.pyi index 9c9a7315a..129d7888d 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -35,6 +35,9 @@ USED_VARIABLES: Incomplete def get_unique_variable_name() -> str: ... def _encode_var(value: Var) -> str: ... + +_global_vars: Dict[int, Var] + def _decode_var(value: str) -> tuple[VarData, str]: ... def _extract_var_data(value: Iterable) -> list[VarData | None]: ... diff --git a/tests/test_var.py b/tests/test_var.py index a96b331bd..5284bf98d 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -6,11 +6,15 @@ import pytest from pandas import DataFrame from reflex.base import Base +from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG +from reflex.experimental.vars.base import ImmutableVar from reflex.state import BaseState +from reflex.utils.imports import ImportVar from reflex.vars import ( BaseVar, ComputedVar, Var, + VarData, computed_var, ) @@ -849,6 +853,43 @@ def test_state_with_initial_computed_var( assert runtime_dict[var_name] == expected_runtime +def test_retrival(): + var_without_data = ImmutableVar.create("test") + assert var_without_data is not None + + original_var_data = VarData( + state="Test", + imports={"react": [ImportVar(tag="useRef")]}, + hooks={"const state = useContext(StateContexts.state)": None}, + ) + + var_with_data = var_without_data._replace(merge_var_data=original_var_data) + + f_string = f"foo{var_with_data}bar" + + assert REFLEX_VAR_OPENING_TAG in f_string + assert REFLEX_VAR_CLOSING_TAG in f_string + + result_var_data = Var.create_safe(f_string)._var_data + result_immutable_var_data = ImmutableVar.create_safe(f_string)._var_data + assert result_var_data is not None and result_immutable_var_data is not None + assert ( + result_var_data.state + == result_immutable_var_data.state + == original_var_data.state + ) + assert ( + result_var_data.imports + == result_immutable_var_data.imports + == original_var_data.imports + ) + assert ( + result_var_data.hooks + == result_immutable_var_data.hooks + == original_var_data.hooks + ) + + @pytest.mark.parametrize( "out, expected", [