diff --git a/reflex/components/component.py b/reflex/components/component.py index e89da6900..7ee9b0d3a 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -55,6 +55,7 @@ from reflex.utils.imports import ( ) from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var +from reflex.vars.sequence import LiteralArrayVar class BaseComponent(Base, ABC): @@ -496,7 +497,12 @@ class Component(BaseComponent, ABC): # Convert class_name to str if it's list class_name = kwargs.get("class_name", "") if isinstance(class_name, (List, tuple)): - kwargs["class_name"] = " ".join(class_name) + if any(isinstance(c, Var) for c in class_name): + kwargs["class_name"] = LiteralArrayVar.create( + class_name, _var_type=List[str] + ).join(" ") + else: + kwargs["class_name"] = " ".join(class_name) # Construct the component. super().__init__(*args, **kwargs) diff --git a/reflex/components/radix/themes/layout/stack.py b/reflex/components/radix/themes/layout/stack.py index cb513cbfb..94bba4fb6 100644 --- a/reflex/components/radix/themes/layout/stack.py +++ b/reflex/components/radix/themes/layout/stack.py @@ -33,7 +33,7 @@ class Stack(Flex): """ # Apply the default classname given_class_name = props.pop("class_name", []) - if isinstance(given_class_name, str): + if not isinstance(given_class_name, list): given_class_name = [given_class_name] props["class_name"] = ["rx-Stack", *given_class_name] diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 6145c980c..15c7411a6 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -592,6 +592,29 @@ class LiteralStringVar(LiteralVar, StringVar): else: return only_string.to(StringVar, only_string._var_type) + if len( + literal_strings := [ + s + for s in filtered_strings_and_vals + if isinstance(s, (str, LiteralStringVar)) + ] + ) == len(filtered_strings_and_vals): + return LiteralStringVar.create( + "".join( + s._var_value if isinstance(s, LiteralStringVar) else s + for s in literal_strings + ), + _var_type=_var_type, + _var_data=VarData.merge( + _var_data, + *( + s._get_all_var_data() + for s in filtered_strings_and_vals + if isinstance(s, Var) + ), + ), + ) + concat_result = ConcatVarOperation.create( *filtered_strings_and_vals, _var_data=_var_data, @@ -736,6 +759,26 @@ class ArrayVar(Var[ARRAY_VAR_TYPE]): """ if not isinstance(sep, (StringVar, str)): raise_unsupported_operand_types("join", (type(self), type(sep))) + if ( + isinstance(self, LiteralArrayVar) + and ( + len( + args := [ + x + for x in self._var_value + if isinstance(x, (LiteralStringVar, str)) + ] + ) + == len(self._var_value) + ) + and isinstance(sep, (LiteralStringVar, str)) + ): + sep_str = sep._var_value if isinstance(sep, LiteralStringVar) else sep + return LiteralStringVar.create( + sep_str.join( + i._var_value if isinstance(i, LiteralStringVar) else i for i in args + ) + ) return array_join_operation(self, sep) def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]: diff --git a/tests/components/test_component.py b/tests/components/test_component.py index 4dda81896..73d3f611b 100644 --- a/tests/components/test_component.py +++ b/tests/components/test_component.py @@ -1288,6 +1288,16 @@ class EventState(rx.State): [FORMATTED_TEST_VAR], id="fstring-class_name", ), + pytest.param( + rx.fragment(class_name=f"foo{TEST_VAR}bar other-class"), + [LiteralVar.create(f"{FORMATTED_TEST_VAR} other-class")], + id="fstring-dual-class_name", + ), + pytest.param( + rx.fragment(class_name=[TEST_VAR, "other-class"]), + [LiteralVar.create([TEST_VAR, "other-class"]).join(" ")], + id="fstring-dual-class_name", + ), pytest.param( rx.fragment(special_props=[TEST_VAR]), [TEST_VAR],