do not auto-determine generic args if already supplied

This commit is contained in:
Benedikt Bartscher 2024-10-10 19:35:37 +02:00
parent e56c7baa39
commit bbd34ab592
No known key found for this signature in database
2 changed files with 48 additions and 3 deletions

View File

@ -1257,6 +1257,27 @@ def unionize(*args: Type) -> Type:
return Union[unionize(*first_half), unionize(*second_half)] return Union[unionize(*first_half), unionize(*second_half)]
def has_args(cls) -> bool:
"""Check if the class has generic parameters.
Args:
cls: The class to check.
Returns:
Whether the class has generic
"""
if get_args(cls):
return True
# Check if the class inherits from a generic class (using __orig_bases__)
if hasattr(cls, "__orig_bases__"):
for base in cls.__orig_bases__:
if get_args(base):
return True
return False
def figure_out_type(value: Any) -> types.GenericType: def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value. """Figure out the type of the value.
@ -1266,6 +1287,11 @@ def figure_out_type(value: Any) -> types.GenericType:
Returns: Returns:
The type of the value. The type of the value.
""" """
if isinstance(value, Var):
return value._var_type
type_ = type(value)
if has_args(type_):
return type_
if isinstance(value, list): if isinstance(value, list):
return List[unionize(*(figure_out_type(v) for v in value))] return List[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, set): if isinstance(value, set):
@ -1277,8 +1303,6 @@ def figure_out_type(value: Any) -> types.GenericType:
unionize(*(figure_out_type(k) for k in value)), unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())), unionize(*(figure_out_type(v) for v in value.values())),
] ]
if isinstance(value, Var):
return value._var_type
return type(value) return type(value)

View File

@ -6,7 +6,25 @@ from reflex.vars.base import figure_out_type
class CustomDict(dict[str, str]): class CustomDict(dict[str, str]):
"""A custom dict.""" """A custom dict with generic arguments."""
pass
class ChildCustomDict(CustomDict):
"""A child of CustomDict."""
pass
class GenericDict(dict):
"""A generic dict with no generic arguments."""
pass
class ChildGenericDict(GenericDict):
"""A child of GenericDict."""
pass pass
@ -22,6 +40,9 @@ class CustomDict(dict[str, str]):
({"a": 1, "b": 2}, Dict[str, int]), ({"a": 1, "b": 2}, Dict[str, int]),
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), ({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
(CustomDict(), CustomDict), (CustomDict(), CustomDict),
(ChildCustomDict(), ChildCustomDict),
(GenericDict({1: 1}), Dict[int, int]),
(ChildGenericDict({1: 1}), Dict[int, int]),
], ],
) )
def test_figure_out_type(value, expected): def test_figure_out_type(value, expected):