Fix fstrings being escaped improperly (#2571)
This commit is contained in:
parent
fccb73ee70
commit
e729a315f8
@ -188,6 +188,35 @@ def to_kebab_case(text: str) -> str:
|
|||||||
return to_snake_case(text).replace("_", "-")
|
return to_snake_case(text).replace("_", "-")
|
||||||
|
|
||||||
|
|
||||||
|
def _escape_js_string(string: str) -> str:
|
||||||
|
"""Escape the string for use as a JS string literal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
string: The string to escape.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The escaped string.
|
||||||
|
"""
|
||||||
|
# Escape backticks.
|
||||||
|
string = string.replace(r"\`", "`")
|
||||||
|
string = string.replace("`", r"\`")
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_js_string(string: str) -> str:
|
||||||
|
"""Wrap string so it looks like {`string`}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
string: The string to wrap.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The wrapped string.
|
||||||
|
"""
|
||||||
|
string = wrap(string, "`")
|
||||||
|
string = wrap(string, "{")
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
def format_string(string: str) -> str:
|
def format_string(string: str) -> str:
|
||||||
"""Format the given string as a JS string literal..
|
"""Format the given string as a JS string literal..
|
||||||
|
|
||||||
@ -197,15 +226,33 @@ def format_string(string: str) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
The formatted string.
|
The formatted string.
|
||||||
"""
|
"""
|
||||||
# Escape backticks.
|
return _wrap_js_string(_escape_js_string(string))
|
||||||
string = string.replace(r"\`", "`")
|
|
||||||
string = string.replace("`", r"\`")
|
|
||||||
|
|
||||||
# Wrap the string so it looks like {`string`}.
|
|
||||||
string = wrap(string, "`")
|
|
||||||
string = wrap(string, "{")
|
|
||||||
|
|
||||||
return string
|
def format_f_string_prop(prop: BaseVar) -> str:
|
||||||
|
"""Format the string in a given prop as an f-string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prop: The prop to format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The formatted string.
|
||||||
|
"""
|
||||||
|
s = prop._var_full_name
|
||||||
|
var_data = prop._var_data
|
||||||
|
interps = var_data.interpolations if var_data else []
|
||||||
|
parts: List[str] = []
|
||||||
|
|
||||||
|
if interps:
|
||||||
|
for i, (start, end) in enumerate(interps):
|
||||||
|
prev_end = interps[i - 1][1] if i > 0 else 0
|
||||||
|
parts.append(_escape_js_string(s[prev_end:start]))
|
||||||
|
parts.append(s[start:end])
|
||||||
|
parts.append(_escape_js_string(s[interps[-1][1] :]))
|
||||||
|
else:
|
||||||
|
parts.append(_escape_js_string(s))
|
||||||
|
|
||||||
|
return _wrap_js_string("".join(parts))
|
||||||
|
|
||||||
|
|
||||||
def format_var(var: Var) -> str:
|
def format_var(var: Var) -> str:
|
||||||
@ -345,7 +392,9 @@ def format_prop(
|
|||||||
if isinstance(prop, Var):
|
if isinstance(prop, Var):
|
||||||
if not prop._var_is_local or prop._var_is_string:
|
if not prop._var_is_local or prop._var_is_string:
|
||||||
return str(prop)
|
return str(prop)
|
||||||
if types._issubclass(prop._var_type, str):
|
if isinstance(prop, BaseVar) and types._issubclass(prop._var_type, str):
|
||||||
|
if prop._var_data and prop._var_data.interpolations:
|
||||||
|
return format_f_string_prop(prop)
|
||||||
return format_string(prop._var_full_name)
|
return format_string(prop._var_full_name)
|
||||||
prop = prop._var_full_name
|
prop = prop._var_full_name
|
||||||
|
|
||||||
|
@ -31,8 +31,6 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
import pydantic
|
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.base import Base
|
from reflex.base import Base
|
||||||
from reflex.utils import console, format, imports, serializers, types
|
from reflex.utils import console, format, imports, serializers, types
|
||||||
@ -122,6 +120,11 @@ class VarData(Base):
|
|||||||
# Hooks that need to be present in the component to render this var
|
# Hooks that need to be present in the component to render this var
|
||||||
hooks: Set[str] = set()
|
hooks: Set[str] = set()
|
||||||
|
|
||||||
|
# Positions of interpolated strings. This is used by the decoder to figure
|
||||||
|
# out where the interpolations are and only escape the non-interpolated
|
||||||
|
# segments.
|
||||||
|
interpolations: List[Tuple[int, int]] = []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def merge(cls, *others: VarData | None) -> VarData | None:
|
def merge(cls, *others: VarData | None) -> VarData | None:
|
||||||
"""Merge multiple var data objects.
|
"""Merge multiple var data objects.
|
||||||
@ -135,17 +138,21 @@ class VarData(Base):
|
|||||||
state = ""
|
state = ""
|
||||||
_imports = {}
|
_imports = {}
|
||||||
hooks = set()
|
hooks = set()
|
||||||
|
interpolations = []
|
||||||
for var_data in others:
|
for var_data in others:
|
||||||
if var_data is None:
|
if var_data is None:
|
||||||
continue
|
continue
|
||||||
state = state or var_data.state
|
state = state or var_data.state
|
||||||
_imports = imports.merge_imports(_imports, var_data.imports)
|
_imports = imports.merge_imports(_imports, var_data.imports)
|
||||||
hooks.update(var_data.hooks)
|
hooks.update(var_data.hooks)
|
||||||
|
interpolations += var_data.interpolations
|
||||||
|
|
||||||
return (
|
return (
|
||||||
cls(
|
cls(
|
||||||
state=state,
|
state=state,
|
||||||
imports=_imports,
|
imports=_imports,
|
||||||
hooks=hooks,
|
hooks=hooks,
|
||||||
|
interpolations=interpolations,
|
||||||
)
|
)
|
||||||
or None
|
or None
|
||||||
)
|
)
|
||||||
@ -156,7 +163,7 @@ class VarData(Base):
|
|||||||
Returns:
|
Returns:
|
||||||
True if any field is set to a non-default value.
|
True if any field is set to a non-default value.
|
||||||
"""
|
"""
|
||||||
return bool(self.state or self.imports or self.hooks)
|
return bool(self.state or self.imports or self.hooks or self.interpolations)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Check if two var data objects are equal.
|
"""Check if two var data objects are equal.
|
||||||
@ -169,6 +176,9 @@ class VarData(Base):
|
|||||||
"""
|
"""
|
||||||
if not isinstance(other, VarData):
|
if not isinstance(other, VarData):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Don't compare interpolations - that's added in by the decoder, and
|
||||||
|
# not part of the vardata itself.
|
||||||
return (
|
return (
|
||||||
self.state == other.state
|
self.state == other.state
|
||||||
and self.hooks == other.hooks
|
and self.hooks == other.hooks
|
||||||
@ -184,6 +194,7 @@ class VarData(Base):
|
|||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"state": self.state,
|
"state": self.state,
|
||||||
|
"interpolations": list(self.interpolations),
|
||||||
"imports": {
|
"imports": {
|
||||||
lib: [import_var.dict() for import_var in import_vars]
|
lib: [import_var.dict() for import_var in import_vars]
|
||||||
for lib, import_vars in self.imports.items()
|
for lib, import_vars in self.imports.items()
|
||||||
@ -202,10 +213,18 @@ def _encode_var(value: Var) -> str:
|
|||||||
The encoded var.
|
The encoded var.
|
||||||
"""
|
"""
|
||||||
if value._var_data:
|
if value._var_data:
|
||||||
|
from reflex.utils.serializers import serialize
|
||||||
|
|
||||||
|
final_value = str(value)
|
||||||
|
data = value._var_data.dict()
|
||||||
|
data["string_length"] = len(final_value)
|
||||||
|
data_json = value._var_data.__config__.json_dumps(data, default=serialize)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}"
|
f"{constants.REFLEX_VAR_OPENING_TAG}{data_json}{constants.REFLEX_VAR_CLOSING_TAG}"
|
||||||
+ str(value)
|
+ final_value
|
||||||
)
|
)
|
||||||
|
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
@ -220,21 +239,40 @@ def _decode_var(value: str) -> tuple[VarData | None, str]:
|
|||||||
"""
|
"""
|
||||||
var_datas = []
|
var_datas = []
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
# Extract the state name from a formatted var
|
offset = 0
|
||||||
while m := re.match(
|
|
||||||
pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)",
|
# Initialize some methods for reading json.
|
||||||
string=value,
|
var_data_config = VarData().__config__
|
||||||
flags=re.DOTALL, # Ensure . matches newline characters.
|
|
||||||
):
|
def json_loads(s):
|
||||||
value = m.group(1) + m.group(3)
|
|
||||||
try:
|
try:
|
||||||
var_datas.append(VarData.parse_raw(m.group(2)))
|
return var_data_config.json_loads(s)
|
||||||
except pydantic.ValidationError:
|
except json.decoder.JSONDecodeError:
|
||||||
# If the VarData is invalid, it was probably json-encoded twice...
|
return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"'))
|
||||||
var_datas.append(VarData.parse_raw(json.loads(f'"{m.group(2)}"')))
|
|
||||||
if var_datas:
|
# Compile regex for finding reflex var tags.
|
||||||
return VarData.merge(*var_datas), value
|
pattern_re = rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"
|
||||||
return None, value
|
pattern = re.compile(pattern_re, flags=re.DOTALL)
|
||||||
|
|
||||||
|
# Find all tags.
|
||||||
|
while m := pattern.search(value):
|
||||||
|
start, end = m.span()
|
||||||
|
value = value[:start] + value[end:]
|
||||||
|
|
||||||
|
# Read the JSON, pull out the string length, parse the rest as VarData.
|
||||||
|
data = json_loads(m.group(1))
|
||||||
|
string_length = data.pop("string_length", None)
|
||||||
|
var_data = VarData.parse_obj(data)
|
||||||
|
|
||||||
|
# Use string length to compute positions of interpolations.
|
||||||
|
if string_length is not None:
|
||||||
|
realstart = start + offset
|
||||||
|
var_data.interpolations = [(realstart, realstart + string_length)]
|
||||||
|
|
||||||
|
var_datas.append(var_data)
|
||||||
|
offset += end - start
|
||||||
|
|
||||||
|
return VarData.merge(*var_datas) if var_datas else None, value
|
||||||
|
|
||||||
|
|
||||||
def _extract_var_data(value: Iterable) -> list[VarData | None]:
|
def _extract_var_data(value: Iterable) -> list[VarData | None]:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
""" Generated with stubgen from mypy, then manually edited, do not regen."""
|
""" Generated with stubgen from mypy, then manually edited, do not regen."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from _typeshed import Incomplete
|
from _typeshed import Incomplete
|
||||||
@ -17,6 +18,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
overload,
|
overload,
|
||||||
@ -34,6 +36,7 @@ class VarData(Base):
|
|||||||
state: str
|
state: str
|
||||||
imports: dict[str, set[ImportVar]]
|
imports: dict[str, set[ImportVar]]
|
||||||
hooks: set[str]
|
hooks: set[str]
|
||||||
|
interpolations: List[Tuple[int, int]]
|
||||||
@classmethod
|
@classmethod
|
||||||
def merge(cls, *others: VarData | None) -> VarData | None: ...
|
def merge(cls, *others: VarData | None) -> VarData | None: ...
|
||||||
|
|
||||||
|
@ -27,6 +27,12 @@ def cond_state(request):
|
|||||||
return CondState
|
return CondState
|
||||||
|
|
||||||
|
|
||||||
|
def test_f_string_cond_interpolation():
|
||||||
|
# make sure backticks inside interpolation don't get escaped
|
||||||
|
var = Var.create(f"x {cond(True, 'a', 'b')}")
|
||||||
|
assert str(var) == "x ${isTrue(true) ? `a` : `b`}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"cond_state",
|
"cond_state",
|
||||||
[
|
[
|
||||||
|
@ -687,7 +687,10 @@ def test_stateful_banner():
|
|||||||
|
|
||||||
TEST_VAR = Var.create_safe("test")._replace(
|
TEST_VAR = Var.create_safe("test")._replace(
|
||||||
merge_var_data=VarData(
|
merge_var_data=VarData(
|
||||||
hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test"
|
hooks={"useTest"},
|
||||||
|
imports={"test": {ImportVar(tag="test")}},
|
||||||
|
state="Test",
|
||||||
|
interpolations=[],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
|
FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar")
|
||||||
|
@ -718,7 +718,7 @@ def test_computed_var_with_annotation_error(request, fixture, full_name):
|
|||||||
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
|
(f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"),
|
||||||
(
|
(
|
||||||
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
|
f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}",
|
||||||
'testing f-string with $<reflex.Var>{"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}</reflex.Var>{state.myvar}',
|
'testing f-string with $<reflex.Var>{"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}</reflex.Var>{state.myvar}',
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
|
f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",
|
||||||
|
Loading…
Reference in New Issue
Block a user