do not auto-determine generic args if already supplied
This commit is contained in:
parent
e56c7baa39
commit
bbd34ab592
@ -1257,6 +1257,27 @@ def unionize(*args: Type) -> Type:
|
||||
return Union[unionize(*first_half), unionize(*second_half)]
|
||||
|
||||
|
||||
def has_args(cls) -> bool:
|
||||
"""Check if the class has generic parameters.
|
||||
|
||||
Args:
|
||||
cls: The class to check.
|
||||
|
||||
Returns:
|
||||
Whether the class has generic
|
||||
"""
|
||||
if get_args(cls):
|
||||
return True
|
||||
|
||||
# Check if the class inherits from a generic class (using __orig_bases__)
|
||||
if hasattr(cls, "__orig_bases__"):
|
||||
for base in cls.__orig_bases__:
|
||||
if get_args(base):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def figure_out_type(value: Any) -> types.GenericType:
|
||||
"""Figure out the type of the value.
|
||||
|
||||
@ -1266,6 +1287,11 @@ def figure_out_type(value: Any) -> types.GenericType:
|
||||
Returns:
|
||||
The type of the value.
|
||||
"""
|
||||
if isinstance(value, Var):
|
||||
return value._var_type
|
||||
type_ = type(value)
|
||||
if has_args(type_):
|
||||
return type_
|
||||
if isinstance(value, list):
|
||||
return List[unionize(*(figure_out_type(v) for v in value))]
|
||||
if isinstance(value, set):
|
||||
@ -1277,8 +1303,6 @@ def figure_out_type(value: Any) -> types.GenericType:
|
||||
unionize(*(figure_out_type(k) for k in value)),
|
||||
unionize(*(figure_out_type(v) for v in value.values())),
|
||||
]
|
||||
if isinstance(value, Var):
|
||||
return value._var_type
|
||||
return type(value)
|
||||
|
||||
|
||||
|
@ -6,7 +6,25 @@ from reflex.vars.base import figure_out_type
|
||||
|
||||
|
||||
class CustomDict(dict[str, str]):
|
||||
"""A custom dict."""
|
||||
"""A custom dict with generic arguments."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChildCustomDict(CustomDict):
|
||||
"""A child of CustomDict."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class GenericDict(dict):
|
||||
"""A generic dict with no generic arguments."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChildGenericDict(GenericDict):
|
||||
"""A child of GenericDict."""
|
||||
|
||||
pass
|
||||
|
||||
@ -22,6 +40,9 @@ class CustomDict(dict[str, str]):
|
||||
({"a": 1, "b": 2}, Dict[str, int]),
|
||||
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
|
||||
(CustomDict(), CustomDict),
|
||||
(ChildCustomDict(), ChildCustomDict),
|
||||
(GenericDict({1: 1}), Dict[int, int]),
|
||||
(ChildGenericDict({1: 1}), Dict[int, int]),
|
||||
],
|
||||
)
|
||||
def test_figure_out_type(value, expected):
|
||||
|
Loading…
Reference in New Issue
Block a user