diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 58a73e025..f84dfbdd5 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -1257,6 +1257,27 @@ def unionize(*args: Type) -> Type: return Union[unionize(*first_half), unionize(*second_half)] +def has_args(cls) -> bool: + """Check if the class has generic parameters. + + Args: + cls: The class to check. + + Returns: + Whether the class has generic + """ + if get_args(cls): + return True + + # Check if the class inherits from a generic class (using __orig_bases__) + if hasattr(cls, "__orig_bases__"): + for base in cls.__orig_bases__: + if get_args(base): + return True + + return False + + def figure_out_type(value: Any) -> types.GenericType: """Figure out the type of the value. @@ -1266,6 +1287,11 @@ def figure_out_type(value: Any) -> types.GenericType: Returns: The type of the value. """ + if isinstance(value, Var): + return value._var_type + type_ = type(value) + if has_args(type_): + return type_ if isinstance(value, list): return List[unionize(*(figure_out_type(v) for v in value))] if isinstance(value, set): @@ -1277,8 +1303,6 @@ def figure_out_type(value: Any) -> types.GenericType: unionize(*(figure_out_type(k) for k in value)), unionize(*(figure_out_type(v) for v in value.values())), ] - if isinstance(value, Var): - return value._var_type return type(value) diff --git a/tests/units/vars/test_base.py b/tests/units/vars/test_base.py index 5f01dab43..68bc0c38e 100644 --- a/tests/units/vars/test_base.py +++ b/tests/units/vars/test_base.py @@ -6,7 +6,25 @@ from reflex.vars.base import figure_out_type class CustomDict(dict[str, str]): - """A custom dict.""" + """A custom dict with generic arguments.""" + + pass + + +class ChildCustomDict(CustomDict): + """A child of CustomDict.""" + + pass + + +class GenericDict(dict): + """A generic dict with no generic arguments.""" + + pass + + +class ChildGenericDict(GenericDict): + """A child of GenericDict.""" pass @@ -22,6 +40,9 @@ class CustomDict(dict[str, str]): ({"a": 1, "b": 2}, Dict[str, int]), ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), (CustomDict(), CustomDict), + (ChildCustomDict(), ChildCustomDict), + (GenericDict({1: 1}), Dict[int, int]), + (ChildGenericDict({1: 1}), Dict[int, int]), ], ) def test_figure_out_type(value, expected):