From e729a315f8d5d4b970bdf89180fc56578fa694ce Mon Sep 17 00:00:00 2001
From: invrainbow <77120437+invrainbow@users.noreply.github.com>
Date: Tue, 13 Feb 2024 14:32:44 -0800
Subject: [PATCH] Fix fstrings being escaped improperly (#2571)
---
reflex/utils/format.py | 65 +++++++++++++++++++++---
reflex/vars.py | 76 +++++++++++++++++++++-------
reflex/vars.pyi | 3 ++
tests/components/layout/test_cond.py | 6 +++
tests/components/test_component.py | 5 +-
tests/test_var.py | 2 +-
6 files changed, 128 insertions(+), 29 deletions(-)
diff --git a/reflex/utils/format.py b/reflex/utils/format.py
index aaa371c2a..53b55d0eb 100644
--- a/reflex/utils/format.py
+++ b/reflex/utils/format.py
@@ -188,6 +188,35 @@ def to_kebab_case(text: str) -> str:
return to_snake_case(text).replace("_", "-")
+def _escape_js_string(string: str) -> str:
+ """Escape backticks so the string can sit inside a JS template literal.
+
+ Args:
+ string: The string to escape.
+
+ Returns:
+ The string with its backticks backslash-escaped.
+ """
+ # Undo any existing \` first so already-escaped backticks are not escaped twice.
+ string = string.replace(r"\`", "`")
+ string = string.replace("`", r"\`")
+ return string
+
+
+def _wrap_js_string(string: str) -> str:
+ """Wrap string so it looks like {`string`}.
+
+ Args:
+ string: The string to wrap; it is not escaped here.
+
+ Returns:
+ The string wrapped in backticks and then in braces.
+ """
+ string = wrap(string, "`")
+ string = wrap(string, "{")
+ return string
+
+
def format_string(string: str) -> str:
"""Format the given string as a JS string literal..
@@ -197,15 +226,33 @@ def format_string(string: str) -> str:
Returns:
The formatted string.
"""
- # Escape backticks.
- string = string.replace(r"\`", "`")
- string = string.replace("`", r"\`")
+ return _wrap_js_string(_escape_js_string(string))
- # Wrap the string so it looks like {`string`}.
- string = wrap(string, "`")
- string = wrap(string, "{")
- return string
+def format_f_string_prop(prop: BaseVar) -> str:
+ """Format a string prop as an f-string, escaping only its literal text.
+
+ Args:
+ prop: The prop to format; its var data lists the interpolated spans.
+
+ Returns:
+ The formatted string, wrapped like {`...`}.
+ """
+ s = prop._var_full_name
+ var_data = prop._var_data
+ interps = var_data.interpolations if var_data else []
+ parts: List[str] = []
+ # Escape the text between interpolated spans; emit the spans themselves verbatim.
+ if interps:
+ for i, (start, end) in enumerate(interps):
+ prev_end = interps[i - 1][1] if i > 0 else 0
+ parts.append(_escape_js_string(s[prev_end:start]))
+ parts.append(s[start:end])
+ parts.append(_escape_js_string(s[interps[-1][1] :]))
+ else:
+ parts.append(_escape_js_string(s))
+ # NOTE(review): assumes interps are sorted and non-overlapping - TODO confirm.
+ return _wrap_js_string("".join(parts))
def format_var(var: Var) -> str:
@@ -345,7 +392,9 @@ def format_prop(
if isinstance(prop, Var):
if not prop._var_is_local or prop._var_is_string:
return str(prop)
- if types._issubclass(prop._var_type, str):
+ if isinstance(prop, BaseVar) and types._issubclass(prop._var_type, str):
+ if prop._var_data and prop._var_data.interpolations:
+ return format_f_string_prop(prop)
return format_string(prop._var_full_name)
prop = prop._var_full_name
diff --git a/reflex/vars.py b/reflex/vars.py
index 4e605f9d6..83cdbcc7e 100644
--- a/reflex/vars.py
+++ b/reflex/vars.py
@@ -31,8 +31,6 @@ from typing import (
get_type_hints,
)
-import pydantic
-
from reflex import constants
from reflex.base import Base
from reflex.utils import console, format, imports, serializers, types
@@ -122,6 +120,11 @@ class VarData(Base):
# Hooks that need to be present in the component to render this var
hooks: Set[str] = set()
+ # Positions of interpolated vars within the string, as (start, end) pairs.
+ # They are computed by the decoder and used when formatting so that only
+ # the non-interpolated segments get escaped.
+ interpolations: List[Tuple[int, int]] = []
+
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
@@ -135,17 +138,21 @@ class VarData(Base):
state = ""
_imports = {}
hooks = set()
+ interpolations = []
for var_data in others:
if var_data is None:
continue
state = state or var_data.state
_imports = imports.merge_imports(_imports, var_data.imports)
hooks.update(var_data.hooks)
+ interpolations += var_data.interpolations
+
return (
cls(
state=state,
imports=_imports,
hooks=hooks,
+ interpolations=interpolations,
)
or None
)
@@ -156,7 +163,7 @@ class VarData(Base):
Returns:
True if any field is set to a non-default value.
"""
- return bool(self.state or self.imports or self.hooks)
+ return bool(self.state or self.imports or self.hooks or self.interpolations)
def __eq__(self, other: Any) -> bool:
"""Check if two var data objects are equal.
@@ -169,6 +176,9 @@ class VarData(Base):
"""
if not isinstance(other, VarData):
return False
+
+ # Don't compare interpolations - they are added by the decoder and are
+ # not part of the var data itself.
return (
self.state == other.state
and self.hooks == other.hooks
@@ -184,6 +194,7 @@ class VarData(Base):
"""
return {
"state": self.state,
+ "interpolations": list(self.interpolations),
"imports": {
lib: [import_var.dict() for import_var in import_vars]
for lib, import_vars in self.imports.items()
@@ -202,10 +213,18 @@ def _encode_var(value: Var) -> str:
The encoded var.
"""
if value._var_data:
+ from reflex.utils.serializers import serialize
+
+ final_value = str(value)
+ data = value._var_data.dict()
+ data["string_length"] = len(final_value)
+ data_json = value._var_data.__config__.json_dumps(data, default=serialize)
+
return (
- f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}"
- + str(value)
+ f"{constants.REFLEX_VAR_OPENING_TAG}{data_json}{constants.REFLEX_VAR_CLOSING_TAG}"
+ + final_value
)
+
return str(value)
@@ -220,21 +239,40 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
"""
var_datas = []
if isinstance(value, str):
- # Extract the state name from a formatted var
- while m := re.match(
- pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)",
- string=value,
- flags=re.DOTALL, # Ensure . matches newline characters.
- ):
- value = m.group(1) + m.group(3)
+ offset = 0
+
+ # Initialize some methods for reading json.
+ var_data_config = VarData().__config__
+
+ def json_loads(s):
try:
- var_datas.append(VarData.parse_raw(m.group(2)))
- except pydantic.ValidationError:
- # If the VarData is invalid, it was probably json-encoded twice...
- var_datas.append(VarData.parse_raw(json.loads(f'"{m.group(2)}"')))
- if var_datas:
- return VarData.merge(*var_datas), value
- return None, value
+ return var_data_config.json_loads(s)
+ except json.decoder.JSONDecodeError:
+ return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"'))
+
+ # Compile regex for finding reflex var tags.
+ pattern_re = rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
+ pattern = re.compile(pattern_re, flags=re.DOTALL)
+
+ # Find all tags.
+ while m := pattern.search(value):
+ start, end = m.span()
+ value = value[:start] + value[end:]
+
+ # Read the JSON, pull out the string length, parse the rest as VarData.
+ data = json_loads(m.group(1))
+ string_length = data.pop("string_length", None)
+ var_data = VarData.parse_obj(data)
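+ # Use string length to compute positions of interpolations.
+ # NOTE(review): `start` already indexes the tag-stripped value; adding `offset` looks off for later tags - confirm.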
+ if string_length is not None:
+ realstart = start + offset
+ var_data.interpolations = [(realstart, realstart + string_length)]
+
+ var_datas.append(var_data)
+ offset += end - start
+
+ return VarData.merge(*var_datas) if var_datas else None, value
def _extract_var_data(value: Iterable) -> list[VarData | None]:
diff --git a/reflex/vars.pyi b/reflex/vars.pyi
index fc5d7e100..80f4ad25f 100644
--- a/reflex/vars.pyi
+++ b/reflex/vars.pyi
@@ -1,4 +1,5 @@
""" Generated with stubgen from mypy, then manually edited, do not regen."""
+from __future__ import annotations
from dataclasses import dataclass
from _typeshed import Incomplete
@@ -17,6 +18,7 @@ from typing import (
List,
Optional,
Set,
+ Tuple,
Type,
Union,
overload,
@@ -34,6 +36,7 @@ class VarData(Base):
state: str
imports: dict[str, set[ImportVar]]
hooks: set[str]
+ interpolations: List[Tuple[int, int]]
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None: ...
diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py
index d60d65085..56a3e156d 100644
--- a/tests/components/layout/test_cond.py
+++ b/tests/components/layout/test_cond.py
@@ -27,6 +27,12 @@ def cond_state(request):
return CondState
+def test_f_string_cond_interpolation():
+ # Backticks that come from an interpolated cond() must not be escaped in str(var).
+ var = Var.create(f"x {cond(True, 'a', 'b')}")
+ assert str(var) == "x ${isTrue(true) ? `a` : `b`}"
+
+
@pytest.mark.parametrize(
"cond_state",
[
diff --git a/tests/components/test_component.py b/tests/components/test_component.py
index 63ec6b31e..04acfca4e 100644
--- a/tests/components/test_component.py
+++ b/tests/components/test_component.py
@@ -687,7 +687,10 @@ def test_stateful_banner():
TEST_VAR = Var.create_safe("test")._replace(
merge_var_data=VarData(
- hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
+ hooks={"useTest"},
+ imports={"test": {ImportVar(tag="test")}},
+ state="Test",
+ interpolations=[],
)
)
FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
diff --git a/tests/test_var.py b/tests/test_var.py
index 3f2f9c4fa..3d7d23d89 100644
--- a/tests/test_var.py
+++ b/tests/test_var.py
@@ -718,7 +718,7 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
(
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
- 'testing f-string with ${"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}{state.myvar}',
+ 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}{state.myvar}',
),
(
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",