diff --git a/reflex/utils/types.py b/reflex/utils/types.py index dec0f4eaf..390b5426e 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -195,6 +195,8 @@ def unionize(*args: GenericType) -> Type: return Any if len(args) == 1: return args[0] + if sys.version_info >= (3, 11): + return Union[*args] # type: ignore # We are bisecting the args list here to avoid hitting the recursion limit # In Python versions >= 3.11, we can simply do `return Union[*args]` midpoint = len(args) // 2 diff --git a/tests/units/test_var.py b/tests/units/test_var.py index c02acefe6..86170e5d5 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -398,6 +398,31 @@ def test_list_tuple_contains(var, expected): assert str(var.contains(other_var)) == f"{expected}.includes(other)" +class Foo(rx.Base): + bar: int + baz: str + + +class Bar(rx.Base): + bar: str + baz: str + foo: int + + +@pytest.mark.parametrize( + ("var", "var_type"), + [ + (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar), + (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]), + (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().bar, Union[int, str]), + (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str), + (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo, Union[int, None]), + ], +) +def test_var_types(var, var_type): + assert var._var_type == var_type + + @pytest.mark.parametrize( "var, expected", [