Fix fstrings being escaped improperly (#2571)

This commit is contained in:
invrainbow 2024-02-13 14:32:44 -08:00 committed by GitHub
parent fccb73ee70
commit e729a315f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 128 additions and 29 deletions

View File

@ -188,6 +188,35 @@ def to_kebab_case(text: str) -> str:
return to_snake_case(text).replace("_", "-") return to_snake_case(text).replace("_", "-")
def _escape_js_string(string: str) -> str:
"""Escape the string for use as a JS string literal.
Args:
string: The string to escape.
Returns:
The escaped string.
"""
# Escape backticks.
string = string.replace(r"\`", "`")
string = string.replace("`", r"\`")
return string
def _wrap_js_string(string: str) -> str:
    """Wrap a string in backticks and then braces, so it looks like {`string`}.

    Args:
        string: The text to wrap.

    Returns:
        The text after wrapping with backticks and then with braces.
    """
    # Order matters: the backticks go on first, the braces around them.
    backticked = wrap(string, "`")
    return wrap(backticked, "{")
def format_string(string: str) -> str: def format_string(string: str) -> str:
"""Format the given string as a JS string literal.. """Format the given string as a JS string literal..
@ -197,15 +226,33 @@ def format_string(string: str) -> str:
Returns: Returns:
The formatted string. The formatted string.
""" """
# Escape backticks. return _wrap_js_string(_escape_js_string(string))
string = string.replace(r"\`", "`")
string = string.replace("`", r"\`")
# Wrap the string so it looks like {`string`}.
string = wrap(string, "`")
string = wrap(string, "{")
return string def format_f_string_prop(prop: BaseVar) -> str:
"""Format the string in a given prop as an f-string.
Args:
prop: The prop to format.
Returns:
The formatted string.
"""
s = prop._var_full_name
var_data = prop._var_data
interps = var_data.interpolations if var_data else []
parts: List[str] = []
if interps:
for i, (start, end) in enumerate(interps):
prev_end = interps[i - 1][1] if i > 0 else 0
parts.append(_escape_js_string(s[prev_end:start]))
parts.append(s[start:end])
parts.append(_escape_js_string(s[interps[-1][1] :]))
else:
parts.append(_escape_js_string(s))
return _wrap_js_string("".join(parts))
def format_var(var: Var) -> str: def format_var(var: Var) -> str:
@ -345,7 +392,9 @@ def format_prop(
if isinstance(prop, Var): if isinstance(prop, Var):
if not prop._var_is_local or prop._var_is_string: if not prop._var_is_local or prop._var_is_string:
return str(prop) return str(prop)
if types._issubclass(prop._var_type, str): if isinstance(prop, BaseVar) and types._issubclass(prop._var_type, str):
if prop._var_data and prop._var_data.interpolations:
return format_f_string_prop(prop)
return format_string(prop._var_full_name) return format_string(prop._var_full_name)
prop = prop._var_full_name prop = prop._var_full_name

View File

@ -31,8 +31,6 @@ from typing import (
get_type_hints, get_type_hints,
) )
import pydantic
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
from reflex.utils import console, format, imports, serializers, types from reflex.utils import console, format, imports, serializers, types
@ -122,6 +120,11 @@ class VarData(Base):
# Hooks that need to be present in the component to render this var # Hooks that need to be present in the component to render this var
hooks: Set[str] = set() hooks: Set[str] = set()
# Positions of interpolated strings. These are filled in by the decoder so
# that the formatter can tell where the interpolations are and escape only
# the non-interpolated segments.
interpolations: List[Tuple[int, int]] = []
@classmethod @classmethod
def merge(cls, *others: VarData | None) -> VarData | None: def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects. """Merge multiple var data objects.
@ -135,17 +138,21 @@ class VarData(Base):
state = "" state = ""
_imports = {} _imports = {}
hooks = set() hooks = set()
interpolations = []
for var_data in others: for var_data in others:
if var_data is None: if var_data is None:
continue continue
state = state or var_data.state state = state or var_data.state
_imports = imports.merge_imports(_imports, var_data.imports) _imports = imports.merge_imports(_imports, var_data.imports)
hooks.update(var_data.hooks) hooks.update(var_data.hooks)
interpolations += var_data.interpolations
return ( return (
cls( cls(
state=state, state=state,
imports=_imports, imports=_imports,
hooks=hooks, hooks=hooks,
interpolations=interpolations,
) )
or None or None
) )
@ -156,7 +163,7 @@ class VarData(Base):
Returns: Returns:
True if any field is set to a non-default value. True if any field is set to a non-default value.
""" """
return bool(self.state or self.imports or self.hooks) return bool(self.state or self.imports or self.hooks or self.interpolations)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Check if two var data objects are equal. """Check if two var data objects are equal.
@ -169,6 +176,9 @@ class VarData(Base):
""" """
if not isinstance(other, VarData): if not isinstance(other, VarData):
return False return False
# Don't compare interpolations - that's added in by the decoder, and
# not part of the vardata itself.
return ( return (
self.state == other.state self.state == other.state
and self.hooks == other.hooks and self.hooks == other.hooks
@ -184,6 +194,7 @@ class VarData(Base):
""" """
return { return {
"state": self.state, "state": self.state,
"interpolations": list(self.interpolations),
"imports": { "imports": {
lib: [import_var.dict() for import_var in import_vars] lib: [import_var.dict() for import_var in import_vars]
for lib, import_vars in self.imports.items() for lib, import_vars in self.imports.items()
@ -202,10 +213,18 @@ def _encode_var(value: Var) -> str:
The encoded var. The encoded var.
""" """
if value._var_data: if value._var_data:
from reflex.utils.serializers import serialize
final_value = str(value)
data = value._var_data.dict()
data["string_length"] = len(final_value)
data_json = value._var_data.__config__.json_dumps(data, default=serialize)
return ( return (
f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}" f"{constants.REFLEX_VAR_OPENING_TAG}{data_json}{constants.REFLEX_VAR_CLOSING_TAG}"
+ str(value) + final_value
) )
return str(value) return str(value)
@ -220,21 +239,40 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
""" """
var_datas = [] var_datas = []
if isinstance(value, str): if isinstance(value, str):
# Extract the state name from a formatted var offset = 0
while m := re.match(
pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)", # Initialize some methods for reading json.
string=value, var_data_config = VarData().__config__
flags=re.DOTALL, # Ensure . matches newline characters.
): def json_loads(s):
value = m.group(1) + m.group(3)
try: try:
var_datas.append(VarData.parse_raw(m.group(2))) return var_data_config.json_loads(s)
except pydantic.ValidationError: except json.decoder.JSONDecodeError:
# If the VarData is invalid, it was probably json-encoded twice... return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"'))
var_datas.append(VarData.parse_raw(json.loads(f'"{m.group(2)}"')))
if var_datas: # Compile regex for finding reflex var tags.
return VarData.merge(*var_datas), value pattern_re = rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
return None, value pattern = re.compile(pattern_re, flags=re.DOTALL)
# Find all tags.
while m := pattern.search(value):
start, end = m.span()
value = value[:start] + value[end:]
# Read the JSON, pull out the string length, parse the rest as VarData.
data = json_loads(m.group(1))
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)
# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [(realstart, realstart + string_length)]
var_datas.append(var_data)
offset += end - start
return VarData.merge(*var_datas) if var_datas else None, value
def _extract_var_data(value: Iterable) -> list[VarData | None]: def _extract_var_data(value: Iterable) -> list[VarData | None]:

View File

@ -1,4 +1,5 @@
""" Generated with stubgen from mypy, then manually edited, do not regen.""" """ Generated with stubgen from mypy, then manually edited, do not regen."""
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from _typeshed import Incomplete from _typeshed import Incomplete
@ -17,6 +18,7 @@ from typing import (
List, List,
Optional, Optional,
Set, Set,
Tuple,
Type, Type,
Union, Union,
overload, overload,
@ -34,6 +36,7 @@ class VarData(Base):
state: str state: str
imports: dict[str, set[ImportVar]] imports: dict[str, set[ImportVar]]
hooks: set[str] hooks: set[str]
interpolations: List[Tuple[int, int]]
@classmethod @classmethod
def merge(cls, *others: VarData | None) -> VarData | None: ... def merge(cls, *others: VarData | None) -> VarData | None: ...

View File

@ -27,6 +27,12 @@ def cond_state(request):
return CondState return CondState
def test_f_string_cond_interpolation():
    """Backticks emitted inside an interpolated cond must not be escaped."""
    rendered = str(Var.create(f"x {cond(True, 'a', 'b')}"))
    assert rendered == "x ${isTrue(true) ? `a` : `b`}"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cond_state", "cond_state",
[ [

View File

@ -687,7 +687,10 @@ def test_stateful_banner():
TEST_VAR = Var.create_safe("test")._replace( TEST_VAR = Var.create_safe("test")._replace(
merge_var_data=VarData( merge_var_data=VarData(
hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test" hooks={"useTest"},
imports={"test": {ImportVar(tag="test")}},
state="Test",
interpolations=[],
) )
) )
FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar") FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")

View File

@ -718,7 +718,7 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
( (
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}', 'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}</reflex.Var>{state.myvar}',
), ),
( (
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}", f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",