diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 0cd939548..887928e3c 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1199,10 +1199,13 @@ def unionize(*args: Type) -> Type: """ if not args: return Any - first, *rest = args - if not rest: - return first - return Union[first, unionize(*rest)] + if len(args) == 1: + return args[0] + # We are bisecting the args list here to avoid hitting the recursion limit + # In Python versions >= 3.11, we can simply do `return Union[*args]` + midpoint = len(args) // 2 + first_half, second_half = args[:midpoint], args[midpoint:] + return Union[unionize(*first_half), unionize(*second_half)] def figure_out_type(value: Any) -> types.GenericType: diff --git a/tests/vars/test_base.py b/tests/vars/test_base.py new file mode 100644 index 000000000..f83d79373 --- /dev/null +++ b/tests/vars/test_base.py @@ -0,0 +1,21 @@ +from typing import Dict, List, Union + +import pytest + +from reflex.vars.base import figure_out_type + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (1, int), + (1.0, float), + ("a", str), + ([1, 2, 3], List[int]), + ([1, 2.0, "a"], List[Union[int, float, str]]), + ({"a": 1, "b": 2}, Dict[str, int]), + ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), + ], +) +def test_figure_out_type(value, expected): + assert figure_out_type(value) == expected