diff --git a/reflex/vars/base.py b/reflex/vars/base.py index fe7be726c..d0d14a825 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1067,6 +1067,10 @@ class LiteralVar(Var): _var_type=type(value), _var_data=_var_data, ) + if isinstance(serialized_value, str): + return LiteralStringVar.create( + serialized_value, _var_type=type(value), _var_data=_var_data + ) return LiteralVar.create(serialized_value, _var_data=_var_data) if dataclasses.is_dataclass(value) and not isinstance(value, type): diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 192a969e5..ca8967d33 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -553,12 +553,14 @@ class LiteralStringVar(LiteralVar, StringVar): def create( cls, value: str, + _var_type: GenericType | None = str, _var_data: VarData | None = None, ) -> StringVar: """Create a var from a string value. Args: value: The value to create the var from. + _var_type: The type of the var. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -591,18 +593,27 @@ class LiteralStringVar(LiteralVar, StringVar): filtered_strings_and_vals = [ s for s in strings_and_vals if isinstance(s, Var) or s ] - if len(filtered_strings_and_vals) == 1: - return LiteralVar.create(filtered_strings_and_vals[0]).to(StringVar) + only_string = filtered_strings_and_vals[0] + if isinstance(only_string, str): + return LiteralVar.create(only_string).to(StringVar, _var_type) + else: + return only_string.to(StringVar, only_string._var_type) - return ConcatVarOperation.create( + concat_result = ConcatVarOperation.create( *filtered_strings_and_vals, _var_data=_var_data, ) + return ( + concat_result + if _var_type is str + else concat_result.to(StringVar, _var_type) + ) + return LiteralStringVar( _js_expr=json.dumps(value), - _var_type=str, + _var_type=_var_type, _var_data=_var_data, _var_value=value, ) diff --git a/tests/components/core/test_colors.py b/tests/components/core/test_colors.py index f12c01372..a6175d56a 100644 --- a/tests/components/core/test_colors.py +++ b/tests/components/core/test_colors.py @@ -1,3 +1,5 @@ +from typing import Type, Union + import pytest import reflex as rx @@ -22,44 +24,45 @@ def create_color_var(color): @pytest.mark.parametrize( - "color, expected", + "color, expected, expected_type", [ - (create_color_var(rx.color("mint")), '"var(--mint-7)"'), - (create_color_var(rx.color("mint", 3)), '"var(--mint-3)"'), - (create_color_var(rx.color("mint", 3, True)), '"var(--mint-a3)"'), + (create_color_var(rx.color("mint")), '"var(--mint-7)"', Color), + (create_color_var(rx.color("mint", 3)), '"var(--mint-3)"', Color), + (create_color_var(rx.color("mint", 3, True)), '"var(--mint-a3)"', Color), ( create_color_var(rx.color(ColorState.color, ColorState.shade)), # type: ignore f'("var(--"+{str(color_state_name)}.color+"-"+{str(color_state_name)}.shade+")")', + Color, ), ( create_color_var(rx.color(f"{ColorState.color}", f"{ColorState.shade}")), # type: ignore f'("var(--"+{str(color_state_name)}.color+"-"+{str(color_state_name)}.shade+")")', + Color, ), ( create_color_var( rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}") # type: ignore ), f'("var(--"+{str(color_state_name)}.color_part+"ato-"+{str(color_state_name)}.shade+")")', + Color, ), ( create_color_var(f'{rx.color(ColorState.color, f"{ColorState.shade}")}'), # type: ignore f'("var(--"+{str(color_state_name)}.color+"-"+{str(color_state_name)}.shade+")")', + str, ), ( create_color_var( f'{rx.color(f"{ColorState.color}", f"{ColorState.shade}")}' # type: ignore ), f'("var(--"+{str(color_state_name)}.color+"-"+{str(color_state_name)}.shade+")")', + str, ), ], ) -def test_color(color, expected): - assert color._var_type is str +def test_color(color, expected, expected_type: Union[Type[str], Type[Color]]): + assert color._var_type is expected_type assert str(color) == expected - if color._var_type == Color: - assert str(color) == f"{{`{expected}`}}" - else: - assert str(color) == expected @pytest.mark.parametrize(