diff --git a/reflex/utils/format.py b/reflex/utils/format.py index aaa371c2a..53b55d0eb 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -188,6 +188,35 @@ def to_kebab_case(text: str) -> str: return to_snake_case(text).replace("_", "-") +def _escape_js_string(string: str) -> str: + """Escape the string for use as a JS string literal. + + Args: + string: The string to escape. + + Returns: + The escaped string. + """ + # Escape backticks. + string = string.replace(r"\`", "`") + string = string.replace("`", r"\`") + return string + + +def _wrap_js_string(string: str) -> str: + """Wrap string so it looks like {`string`}. + + Args: + string: The string to wrap. + + Returns: + The wrapped string. + """ + string = wrap(string, "`") + string = wrap(string, "{") + return string + + def format_string(string: str) -> str: """Format the given string as a JS string literal.. @@ -197,15 +226,33 @@ def format_string(string: str) -> str: Returns: The formatted string. """ - # Escape backticks. - string = string.replace(r"\`", "`") - string = string.replace("`", r"\`") + return _wrap_js_string(_escape_js_string(string)) - # Wrap the string so it looks like {`string`}. - string = wrap(string, "`") - string = wrap(string, "{") - return string +def format_f_string_prop(prop: BaseVar) -> str: + """Format the string in a given prop as an f-string. + + Args: + prop: The prop to format. + + Returns: + The formatted string. + """ + s = prop._var_full_name + var_data = prop._var_data + interps = var_data.interpolations if var_data else [] + parts: List[str] = [] + + if interps: + for i, (start, end) in enumerate(interps): + prev_end = interps[i - 1][1] if i > 0 else 0 + parts.append(_escape_js_string(s[prev_end:start])) + parts.append(s[start:end]) + parts.append(_escape_js_string(s[interps[-1][1] :])) + else: + parts.append(_escape_js_string(s)) + + return _wrap_js_string("".join(parts)) def format_var(var: Var) -> str: @@ -345,7 +392,9 @@ def format_prop( if isinstance(prop, Var): if not prop._var_is_local or prop._var_is_string: return str(prop) - if types._issubclass(prop._var_type, str): + if isinstance(prop, BaseVar) and types._issubclass(prop._var_type, str): + if prop._var_data and prop._var_data.interpolations: + return format_f_string_prop(prop) return format_string(prop._var_full_name) prop = prop._var_full_name diff --git a/reflex/vars.py b/reflex/vars.py index 4e605f9d6..83cdbcc7e 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -31,8 +31,6 @@ from typing import ( get_type_hints, ) -import pydantic - from reflex import constants from reflex.base import Base from reflex.utils import console, format, imports, serializers, types @@ -122,6 +120,11 @@ class VarData(Base): # Hooks that need to be present in the component to render this var hooks: Set[str] = set() + # Positions of interpolated strings. This is used by the decoder to figure + # out where the interpolations are and only escape the non-interpolated + # segments. + interpolations: List[Tuple[int, int]] = [] + @classmethod def merge(cls, *others: VarData | None) -> VarData | None: """Merge multiple var data objects. @@ -135,17 +138,21 @@ class VarData(Base): state = "" _imports = {} hooks = set() + interpolations = [] for var_data in others: if var_data is None: continue state = state or var_data.state _imports = imports.merge_imports(_imports, var_data.imports) hooks.update(var_data.hooks) + interpolations += var_data.interpolations + return ( cls( state=state, imports=_imports, hooks=hooks, + interpolations=interpolations, ) or None ) @@ -156,7 +163,7 @@ class VarData(Base): Returns: True if any field is set to a non-default value. """ - return bool(self.state or self.imports or self.hooks) + return bool(self.state or self.imports or self.hooks or self.interpolations) def __eq__(self, other: Any) -> bool: """Check if two var data objects are equal. @@ -169,6 +176,9 @@ class VarData(Base): """ if not isinstance(other, VarData): return False + + # Don't compare interpolations - that's added in by the decoder, and + # not part of the vardata itself. return ( self.state == other.state and self.hooks == other.hooks @@ -184,6 +194,7 @@ class VarData(Base): """ return { "state": self.state, + "interpolations": list(self.interpolations), "imports": { lib: [import_var.dict() for import_var in import_vars] for lib, import_vars in self.imports.items() @@ -202,10 +213,18 @@ def _encode_var(value: Var) -> str: The encoded var. """ if value._var_data: + from reflex.utils.serializers import serialize + + final_value = str(value) + data = value._var_data.dict() + data["string_length"] = len(final_value) + data_json = value._var_data.__config__.json_dumps(data, default=serialize) + return ( - f"{constants.REFLEX_VAR_OPENING_TAG}{value._var_data.json()}{constants.REFLEX_VAR_CLOSING_TAG}" - + str(value) + f"{constants.REFLEX_VAR_OPENING_TAG}{data_json}{constants.REFLEX_VAR_CLOSING_TAG}" + + final_value ) + return str(value) @@ -220,21 +239,40 @@ def _decode_var(value: str) -> tuple[VarData | None, str]: """ var_datas = [] if isinstance(value, str): - # Extract the state name from a formatted var - while m := re.match( - pattern=rf"(.*){constants.REFLEX_VAR_OPENING_TAG}(.*){constants.REFLEX_VAR_CLOSING_TAG}(.*)", - string=value, - flags=re.DOTALL, # Ensure . matches newline characters. - ): - value = m.group(1) + m.group(3) + offset = 0 + + # Initialize some methods for reading json. + var_data_config = VarData().__config__ + + def json_loads(s): try: - var_datas.append(VarData.parse_raw(m.group(2))) - except pydantic.ValidationError: - # If the VarData is invalid, it was probably json-encoded twice... - var_datas.append(VarData.parse_raw(json.loads(f'"{m.group(2)}"'))) - if var_datas: - return VarData.merge(*var_datas), value - return None, value + return var_data_config.json_loads(s) + except json.decoder.JSONDecodeError: + return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"')) + + # Compile regex for finding reflex var tags. + pattern_re = rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" + pattern = re.compile(pattern_re, flags=re.DOTALL) + + # Find all tags. + while m := pattern.search(value): + start, end = m.span() + value = value[:start] + value[end:] + + # Read the JSON, pull out the string length, parse the rest as VarData. + data = json_loads(m.group(1)) + string_length = data.pop("string_length", None) + var_data = VarData.parse_obj(data) + + # Use string length to compute positions of interpolations. + if string_length is not None: + realstart = start + offset + var_data.interpolations = [(realstart, realstart + string_length)] + + var_datas.append(var_data) + offset += end - start + + return VarData.merge(*var_datas) if var_datas else None, value def _extract_var_data(value: Iterable) -> list[VarData | None]: diff --git a/reflex/vars.pyi b/reflex/vars.pyi index fc5d7e100..80f4ad25f 100644 --- a/reflex/vars.pyi +++ b/reflex/vars.pyi @@ -1,4 +1,5 @@ """ Generated with stubgen from mypy, then manually edited, do not regen.""" +from __future__ import annotations from dataclasses import dataclass from _typeshed import Incomplete @@ -17,6 +18,7 @@ from typing import ( List, Optional, Set, + Tuple, Type, Union, overload, @@ -34,6 +36,7 @@ class VarData(Base): state: str imports: dict[str, set[ImportVar]] hooks: set[str] + interpolations: List[Tuple[int, int]] @classmethod def merge(cls, *others: VarData | None) -> VarData | None: ... diff --git a/tests/components/layout/test_cond.py b/tests/components/layout/test_cond.py index d60d65085..56a3e156d 100644 --- a/tests/components/layout/test_cond.py +++ b/tests/components/layout/test_cond.py @@ -27,6 +27,12 @@ def cond_state(request): return CondState +def test_f_string_cond_interpolation(): + # make sure backticks inside interpolation don't get escaped + var = Var.create(f"x {cond(True, 'a', 'b')}") + assert str(var) == "x ${isTrue(true) ? `a` : `b`}" + + @pytest.mark.parametrize( "cond_state", [ diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 63ec6b31e..04acfca4e 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -687,7 +687,10 @@ def test_stateful_banner(): TEST_VAR = Var.create_safe("test")._replace( merge_var_data=VarData( - hooks={"useTest"}, imports={"test": {ImportVar(tag="test")}}, state="Test" + hooks={"useTest"}, + imports={"test": {ImportVar(tag="test")}}, + state="Test", + interpolations=[], ) ) FORMATTED_TEST_VAR = Var.create(f"foo{TEST_VAR}bar") diff --git a/tests/test_var.py b/tests/test_var.py index 3f2f9c4fa..3d7d23d89 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -718,7 +718,7 @@ def test_computed_var_with_annotation_error(request, fixture, full_name): (f"{BaseVar(_var_name='var', _var_type=str)}", "${var}"), ( f"testing f-string with {BaseVar(_var_name='myvar', _var_type=int)._var_set_state('state')}", - 'testing f-string with ${"state": "state", "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"]}{state.myvar}', + 'testing f-string with ${"state": "state", "interpolations": [], "imports": {"/utils/context": [{"tag": "StateContexts", "is_default": false, "alias": null, "install": true, "render": true}], "react": [{"tag": "useContext", "is_default": false, "alias": null, "install": true, "render": true}]}, "hooks": ["const state = useContext(StateContexts.state)"], "string_length": 13}{state.myvar}', ), ( f"testing local f-string {BaseVar(_var_name='x', _var_is_local=True, _var_type=str)}",