Do not auto-determine generic args if already supplied (#4148)
* add failing test for figure_out_type * do not auto-determine generic args if already supplied * move has_args to utils.types, add tests for it
This commit is contained in:
parent
736b2a6ea9
commit
0889276e24
@ -220,6 +220,27 @@ def is_literal(cls: GenericType) -> bool:
|
|||||||
return get_origin(cls) is Literal
|
return get_origin(cls) is Literal
|
||||||
|
|
||||||
|
|
||||||
|
def has_args(cls) -> bool:
|
||||||
|
"""Check if the class has generic parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: The class to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether the class has generic
|
||||||
|
"""
|
||||||
|
if get_args(cls):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check if the class inherits from a generic class (using __orig_bases__)
|
||||||
|
if hasattr(cls, "__orig_bases__"):
|
||||||
|
for base in cls.__orig_bases__:
|
||||||
|
if get_args(base):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_optional(cls: GenericType) -> bool:
|
def is_optional(cls: GenericType) -> bool:
|
||||||
"""Check if a class is an Optional.
|
"""Check if a class is an Optional.
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ from reflex.utils.imports import (
|
|||||||
ParsedImportDict,
|
ParsedImportDict,
|
||||||
parse_imports,
|
parse_imports,
|
||||||
)
|
)
|
||||||
from reflex.utils.types import GenericType, Self, get_origin
|
from reflex.utils.types import GenericType, Self, get_origin, has_args
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from reflex.state import BaseState
|
from reflex.state import BaseState
|
||||||
@ -1266,6 +1266,11 @@ def figure_out_type(value: Any) -> types.GenericType:
|
|||||||
Returns:
|
Returns:
|
||||||
The type of the value.
|
The type of the value.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(value, Var):
|
||||||
|
return value._var_type
|
||||||
|
type_ = type(value)
|
||||||
|
if has_args(type_):
|
||||||
|
return type_
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
return List[unionize(*(figure_out_type(v) for v in value))]
|
return List[unionize(*(figure_out_type(v) for v in value))]
|
||||||
if isinstance(value, set):
|
if isinstance(value, set):
|
||||||
@ -1277,8 +1282,6 @@ def figure_out_type(value: Any) -> types.GenericType:
|
|||||||
unionize(*(figure_out_type(k) for k in value)),
|
unionize(*(figure_out_type(k) for k in value)),
|
||||||
unionize(*(figure_out_type(v) for v in value.values())),
|
unionize(*(figure_out_type(v) for v in value.values())),
|
||||||
]
|
]
|
||||||
if isinstance(value, Var):
|
|
||||||
return value._var_type
|
|
||||||
return type(value)
|
return type(value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Literal, Tuple, Union
|
from typing import Any, Dict, List, Literal, Tuple, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -45,3 +45,48 @@ def test_issubclass(
|
|||||||
cls: types.GenericType, cls_check: types.GenericType, expected: bool
|
cls: types.GenericType, cls_check: types.GenericType, expected: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
assert types._issubclass(cls, cls_check) == expected
|
assert types._issubclass(cls, cls_check) == expected
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDict(dict[str, str]):
|
||||||
|
"""A custom dict with generic arguments."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChildCustomDict(CustomDict):
|
||||||
|
"""A child of CustomDict."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GenericDict(dict):
|
||||||
|
"""A generic dict with no generic arguments."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChildGenericDict(GenericDict):
|
||||||
|
"""A child of GenericDict."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"cls,expected",
|
||||||
|
[
|
||||||
|
(int, False),
|
||||||
|
(str, False),
|
||||||
|
(float, False),
|
||||||
|
(Tuple[int], True),
|
||||||
|
(List[int], True),
|
||||||
|
(Union[int, str], True),
|
||||||
|
(Union[str, int], True),
|
||||||
|
(Dict[str, int], True),
|
||||||
|
(CustomDict, True),
|
||||||
|
(ChildCustomDict, True),
|
||||||
|
(GenericDict, False),
|
||||||
|
(ChildGenericDict, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_has_args(cls, expected: bool) -> None:
|
||||||
|
assert types.has_args(cls) == expected
|
||||||
|
@ -5,6 +5,30 @@ import pytest
|
|||||||
from reflex.vars.base import figure_out_type
|
from reflex.vars.base import figure_out_type
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDict(dict[str, str]):
|
||||||
|
"""A custom dict with generic arguments."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChildCustomDict(CustomDict):
|
||||||
|
"""A child of CustomDict."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GenericDict(dict):
|
||||||
|
"""A generic dict with no generic arguments."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChildGenericDict(GenericDict):
|
||||||
|
"""A child of GenericDict."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("value", "expected"),
|
("value", "expected"),
|
||||||
[
|
[
|
||||||
@ -15,6 +39,10 @@ from reflex.vars.base import figure_out_type
|
|||||||
([1, 2.0, "a"], List[Union[int, float, str]]),
|
([1, 2.0, "a"], List[Union[int, float, str]]),
|
||||||
({"a": 1, "b": 2}, Dict[str, int]),
|
({"a": 1, "b": 2}, Dict[str, int]),
|
||||||
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
|
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]),
|
||||||
|
(CustomDict(), CustomDict),
|
||||||
|
(ChildCustomDict(), ChildCustomDict),
|
||||||
|
(GenericDict({1: 1}), Dict[int, int]),
|
||||||
|
(ChildGenericDict({1: 1}), Dict[int, int]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_figure_out_type(value, expected):
|
def test_figure_out_type(value, expected):
|
||||||
|
Loading…
Reference in New Issue
Block a user