[REF-3225] implement __format__ for immutable vars (#3617)

* implement format for immutable vars

* add some basic test

* make reference only after formatting

* win over pyright

* hopefully now pyright doesn't hate me

* forgot some _var_data

* i don't know how imports work

* use f_string var and remove assignments from pyi file

* override post_init to not break immutability

* add create_safe and test for it
This commit is contained in:
Khaleel Al-Adhami 2024-07-11 14:19:38 -04:00 committed by GitHub
parent 3039b54a75
commit d4d077818c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 124 additions and 10 deletions

View File

@ -6,9 +6,10 @@ import dataclasses
import sys
from typing import Any, Optional, Type
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.utils import serializers, types
from reflex.utils.exceptions import VarTypeError
from reflex.vars import Var, VarData, _extract_var_data
from reflex.vars import Var, VarData, _decode_var, _extract_var_data, _global_vars
@dataclasses.dataclass(
@ -55,6 +56,15 @@ class ImmutableVar(Var):
"""
return False
def __post_init__(self):
"""Post-initialize the var."""
# Decode any inline Var markup and apply it to the instance
_var_data, _var_name = _decode_var(self._var_name)
if _var_data:
self.__init__(
_var_name, self._var_type, VarData.merge(self._var_data, _var_data)
)
def _replace(self, merge_var_data=None, **kwargs: Any):
"""Make a copy of this Var with updated fields.
@ -156,3 +166,45 @@ class ImmutableVar(Var):
_var_type=type_,
_var_data=_var_data,
)
@classmethod
def create_safe(
cls,
value: Any,
_var_is_local: bool | None = None,
_var_is_string: bool | None = None,
_var_data: VarData | None = None,
) -> Var:
"""Create a var from a value, asserting that it is not None.
Args:
value: The value to create the var from.
_var_is_local: Whether the var is local. Deprecated.
_var_is_string: Whether the var is a string literal. Deprecated.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
var = cls.create(
value,
_var_is_local=_var_is_local,
_var_is_string=_var_is_string,
_var_data=_var_data,
)
assert var is not None
return var
def __format__(self, format_spec: str) -> str:
"""Format the var into a Javascript equivalent to an f-string.
Args:
format_spec: The format specifier (Ignored for now).
Returns:
The formatted var.
"""
_global_vars[hash(self)] = self
# Encode the _var_data into the formatted output for tracking purposes.
return f"{REFLEX_VAR_OPENING_TAG}{hash(self)}{REFLEX_VAR_CLOSING_TAG}{self._var_name}"

View File

@ -262,6 +262,9 @@ _decode_var_pattern_re = (
)
_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
# Defined global immutable vars.
_global_vars: Dict[int, Var] = {}
def _decode_var(value: str) -> tuple[VarData | None, str]:
"""Decode the state name from a formatted var.
@ -294,17 +297,32 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
start, end = m.span()
value = value[:start] + value[end:]
# Read the JSON, pull out the string length, parse the rest as VarData.
data = json_loads(m.group(1))
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)
serialized_data = m.group(1)
# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [(realstart, realstart + string_length)]
if serialized_data[1:].isnumeric():
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._var_data
var_datas.append(var_data)
if var_data is not None:
realstart = start + offset
var_data.interpolations = [
(realstart, realstart + len(var._var_name))
]
var_datas.append(var_data)
else:
# Read the JSON, pull out the string length, parse the rest as VarData.
data = json_loads(serialized_data)
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)
# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [(realstart, realstart + string_length)]
var_datas.append(var_data)
offset += end - start
return VarData.merge(*var_datas) if var_datas else None, value

View File

@ -35,6 +35,9 @@ USED_VARIABLES: Incomplete
def get_unique_variable_name() -> str: ...
def _encode_var(value: Var) -> str: ...
_global_vars: Dict[int, Var]
def _decode_var(value: str) -> tuple[VarData, str]: ...
def _extract_var_data(value: Iterable) -> list[VarData | None]: ...

View File

@ -6,11 +6,15 @@ import pytest
from pandas import DataFrame
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.experimental.vars.base import ImmutableVar
from reflex.state import BaseState
from reflex.utils.imports import ImportVar
from reflex.vars import (
BaseVar,
ComputedVar,
Var,
VarData,
computed_var,
)
@ -849,6 +853,43 @@ def test_state_with_initial_computed_var(
assert runtime_dict[var_name] == expected_runtime
def test_retrival():
var_without_data = ImmutableVar.create("test")
assert var_without_data is not None
original_var_data = VarData(
state="Test",
imports={"react": [ImportVar(tag="useRef")]},
hooks={"const state = useContext(StateContexts.state)": None},
)
var_with_data = var_without_data._replace(merge_var_data=original_var_data)
f_string = f"foo{var_with_data}bar"
assert REFLEX_VAR_OPENING_TAG in f_string
assert REFLEX_VAR_CLOSING_TAG in f_string
result_var_data = Var.create_safe(f_string)._var_data
result_immutable_var_data = ImmutableVar.create_safe(f_string)._var_data
assert result_var_data is not None and result_immutable_var_data is not None
assert (
result_var_data.state
== result_immutable_var_data.state
== original_var_data.state
)
assert (
result_var_data.imports
== result_immutable_var_data.imports
== original_var_data.imports
)
assert (
result_var_data.hooks
== result_immutable_var_data.hooks
== original_var_data.hooks
)
@pytest.mark.parametrize(
"out, expected",
[