diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 9ace92b98..fcc12bc51 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -1,7 +1,8 @@ """Create a list of components from an iterable.""" + from __future__ import annotations -from typing import Any, Dict, Optional, overload +from typing import Any, Dict, Optional, Union, overload from reflex.components.base.fragment import Fragment from reflex.components.component import BaseComponent, Component, MemoizationLeaf @@ -10,7 +11,7 @@ from reflex.constants import Dirs from reflex.constants.colors import Color from reflex.style import LIGHT_COLOR_MODE, color_mode from reflex.utils import format, imports -from reflex.vars import BaseVar, Var, VarData +from reflex.vars import Var, VarData _IS_TRUE_IMPORT = { f"/{Dirs.STATE_PATH}": [imports.ImportVar(tag="isTrue")], @@ -171,6 +172,11 @@ def cond(condition: Any, c1: Any, c2: Any = None): c2 = create_var(c2) var_datas.extend([c1._var_data, c2._var_data]) + c1_type = c1._var_type if isinstance(c1, Var) else type(c1) + c2_type = c2._var_type if isinstance(c2, Var) else type(c2) + + var_type = c1_type if c1_type == c2_type else Union[c1_type, c2_type] + # Create the conditional var. return cond_var._replace( _var_name=format.format_cond( @@ -179,7 +185,7 @@ def cond(condition: Any, c1: Any, c2: Any = None): false_value=c2, is_prop=True, ), - _var_type=c1._var_type if isinstance(c1, BaseVar) else type(c1), + _var_type=var_type, _var_is_local=False, _var_full_name_needs_state_prefix=False, merge_var_data=VarData.merge(*var_datas), diff --git a/tests/components/core/test_cond.py b/tests/components/core/test_cond.py index a7604fb9a..4bfa902af 100644 --- a/tests/components/core/test_cond.py +++ b/tests/components/core/test_cond.py @@ -1,13 +1,14 @@ import json -from typing import Any +from typing import Any, Union import pytest from reflex.components.base.fragment import Fragment from reflex.components.core.cond import Cond, cond from reflex.components.radix.themes.typography.text import Text -from reflex.state import BaseState -from reflex.vars import Var +from reflex.state import BaseState, State +from reflex.utils.format import format_state_name +from reflex.vars import BaseVar, Var, computed_var @pytest.fixture @@ -118,3 +119,29 @@ def test_cond_no_else(): # Props do not support the use of cond without else with pytest.raises(ValueError): cond(True, "hello") # type: ignore + + +def test_cond_computed_var(): + """Test if cond works with computed vars.""" + + class CondStateComputed(State): + @computed_var + def computed_int(self) -> int: + return 0 + + @computed_var + def computed_str(self) -> str: + return "a string" + + comp = cond(True, CondStateComputed.computed_int, CondStateComputed.computed_str) + + # TODO: shouln't this be a ComputedVar? + assert isinstance(comp, BaseVar) + + state_name = format_state_name(CondStateComputed.get_full_name()) + assert ( + str(comp) + == f"{{isTrue(true) ? {state_name}.computed_int : {state_name}.computed_str}}" + ) + + assert comp._var_type == Union[int, str]