make object var handle all mapping instead of just dict (#4602)

* make object var handle all mapping instead of just dict

* unbreak ci

* get it right pyright

* create generic variable for field

* add support for typeddict (to some degree)

* import from extensions
This commit is contained in:
Khaleel Al-Adhami 2025-01-21 13:14:02 -08:00 committed by GitHub
parent abaaa22adb
commit bea266b8ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 100 additions and 64 deletions

View File

@ -829,6 +829,22 @@ StateBases = get_base_class(StateVar)
StateIterBases = get_base_class(StateIterVar) StateIterBases = get_base_class(StateIterVar)
def safe_issubclass(cls: Type, cls_check: Type | Tuple[Type, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
Args:
cls: The class to check.
cls_check: The class to check against.
Returns:
Whether the class is a subclass of the other class.
"""
try:
return issubclass(cls, cls_check)
except TypeError:
return False
def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool: def typehint_issubclass(possible_subclass: Any, possible_superclass: Any) -> bool:
"""Check if a type hint is a subclass of another type hint. """Check if a type hint is a subclass of another type hint.

View File

@ -26,6 +26,7 @@ from typing import (
Iterable, Iterable,
List, List,
Literal, Literal,
Mapping,
NoReturn, NoReturn,
Optional, Optional,
Set, Set,
@ -64,6 +65,7 @@ from reflex.utils.types import (
_isinstance, _isinstance,
get_origin, get_origin,
has_args, has_args,
safe_issubclass,
unionize, unionize,
) )
@ -127,7 +129,7 @@ class VarData:
state: str = "", state: str = "",
field_name: str = "", field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None, imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, VarData | None] | None = None, hooks: Mapping[str, VarData | None] | None = None,
deps: list[Var] | None = None, deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None, position: Hooks.HookPosition | None = None,
): ):
@ -643,8 +645,8 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
def to( def to(
self, self,
output: type[dict], output: type[Mapping],
) -> ObjectVar[dict]: ... ) -> ObjectVar[Mapping]: ...
@overload @overload
def to( def to(
@ -686,7 +688,9 @@ class Var(Generic[VAR_TYPE]):
# If the first argument is a python type, we map it to the corresponding Var type. # If the first argument is a python type, we map it to the corresponding Var type.
for var_subclass in _var_subclasses[::-1]: for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types: if fixed_output_type in var_subclass.python_types or safe_issubclass(
fixed_output_type, var_subclass.python_types
):
return self.to(var_subclass.var_subclass, output) return self.to(var_subclass.var_subclass, output)
if fixed_output_type is None: if fixed_output_type is None:
@ -820,7 +824,7 @@ class Var(Generic[VAR_TYPE]):
return False return False
if issubclass(type_, list): if issubclass(type_, list):
return [] return []
if issubclass(type_, dict): if issubclass(type_, Mapping):
return {} return {}
if issubclass(type_, tuple): if issubclass(type_, tuple):
return () return ()
@ -1026,7 +1030,7 @@ class Var(Generic[VAR_TYPE]):
f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")] f"$/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
} }
), ),
).to(ObjectVar, Dict[str, str]) ).to(ObjectVar, Mapping[str, str])
return refs[LiteralVar.create(str(self))] return refs[LiteralVar.create(str(self))]
@deprecated("Use `.js_type()` instead.") @deprecated("Use `.js_type()` instead.")
@ -1373,7 +1377,7 @@ class LiteralVar(Var):
serialized_value = serializers.serialize(value) serialized_value = serializers.serialize(value)
if serialized_value is not None: if serialized_value is not None:
if isinstance(serialized_value, dict): if isinstance(serialized_value, Mapping):
return LiteralObjectVar.create( return LiteralObjectVar.create(
serialized_value, serialized_value,
_var_type=type(value), _var_type=type(value),
@ -1498,7 +1502,7 @@ def var_operation(
) -> Callable[P, ArrayVar[LIST_T]]: ... ) -> Callable[P, ArrayVar[LIST_T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Mapping)
@overload @overload
@ -1573,8 +1577,8 @@ def figure_out_type(value: Any) -> types.GenericType:
return Set[unionize(*(figure_out_type(v) for v in value))] return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple): if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...] return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
if isinstance(value, dict): if isinstance(value, Mapping):
return Dict[ return Mapping[
unionize(*(figure_out_type(k) for k in value)), unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())), unionize(*(figure_out_type(v) for v in value.values())),
] ]
@ -2002,10 +2006,10 @@ class ComputedVar(Var[RETURN_TYPE]):
@overload @overload
def __get__( def __get__(
self: ComputedVar[dict[DICT_KEY, DICT_VAL]], self: ComputedVar[Mapping[DICT_KEY, DICT_VAL]],
instance: None, instance: None,
owner: Type, owner: Type,
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ... ) -> ObjectVar[Mapping[DICT_KEY, DICT_VAL]]: ...
@overload @overload
def __get__( def __get__(
@ -2915,11 +2919,14 @@ V = TypeVar("V")
BASE_TYPE = TypeVar("BASE_TYPE", bound=Base) BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)
FIELD_TYPE = TypeVar("FIELD_TYPE")
MAPPING_TYPE = TypeVar("MAPPING_TYPE", bound=Mapping)
class Field(Generic[T]):
class Field(Generic[FIELD_TYPE]):
"""Shadow class for Var to allow for type hinting in the IDE.""" """Shadow class for Var to allow for type hinting in the IDE."""
def __set__(self, instance, value: T): def __set__(self, instance, value: FIELD_TYPE):
"""Set the Var. """Set the Var.
Args: Args:
@ -2931,7 +2938,9 @@ class Field(Generic[T]):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
@overload @overload
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... def __get__(
self: Field[int] | Field[float] | Field[int | float], instance: None, owner
) -> NumberVar: ...
@overload @overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ... def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
@ -2948,8 +2957,8 @@ class Field(Generic[T]):
@overload @overload
def __get__( def __get__(
self: Field[Dict[str, V]], instance: None, owner self: Field[MAPPING_TYPE], instance: None, owner
) -> ObjectVar[Dict[str, V]]: ... ) -> ObjectVar[MAPPING_TYPE]: ...
@overload @overload
def __get__( def __get__(
@ -2957,10 +2966,10 @@ class Field(Generic[T]):
) -> ObjectVar[BASE_TYPE]: ... ) -> ObjectVar[BASE_TYPE]: ...
@overload @overload
def __get__(self, instance: None, owner) -> Var[T]: ... def __get__(self, instance: None, owner) -> Var[FIELD_TYPE]: ...
@overload @overload
def __get__(self, instance, owner) -> T: ... def __get__(self, instance, owner) -> FIELD_TYPE: ...
def __get__(self, instance, owner): # type: ignore def __get__(self, instance, owner): # type: ignore
"""Get the Var. """Get the Var.
@ -2971,7 +2980,7 @@ class Field(Generic[T]):
""" """
def field(value: T) -> Field[T]: def field(value: FIELD_TYPE) -> Field[FIELD_TYPE]:
"""Create a Field with a value. """Create a Field with a value.
Args: Args:

View File

@ -8,8 +8,8 @@ import typing
from inspect import isclass from inspect import isclass
from typing import ( from typing import (
Any, Any,
Dict,
List, List,
Mapping,
NoReturn, NoReturn,
Tuple, Tuple,
Type, Type,
@ -19,6 +19,8 @@ from typing import (
overload, overload,
) )
from typing_extensions import is_typeddict
from reflex.utils import types from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
@ -36,7 +38,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE") OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
KEY_TYPE = TypeVar("KEY_TYPE") KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -46,7 +48,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE") OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
class ObjectVar(Var[OBJECT_TYPE], python_types=dict): class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
"""Base class for immutable object vars.""" """Base class for immutable object vars."""
def _key_type(self) -> Type: def _key_type(self) -> Type:
@ -59,7 +61,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload @overload
def _value_type( def _value_type(
self: ObjectVar[Dict[Any, VALUE_TYPE]], self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> Type[VALUE_TYPE]: ... ) -> Type[VALUE_TYPE]: ...
@overload @overload
@ -74,7 +76,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
fixed_type = get_origin(self._var_type) or self._var_type fixed_type = get_origin(self._var_type) or self._var_type
if not isclass(fixed_type): if not isclass(fixed_type):
return Any return Any
args = get_args(self._var_type) if issubclass(fixed_type, dict) else () args = get_args(self._var_type) if issubclass(fixed_type, Mapping) else ()
return args[1] if args else Any return args[1] if args else Any
def keys(self) -> ArrayVar[List[str]]: def keys(self) -> ArrayVar[List[str]]:
@ -87,7 +89,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload @overload
def values( def values(
self: ObjectVar[Dict[Any, VALUE_TYPE]], self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[VALUE_TYPE]]: ... ) -> ArrayVar[List[VALUE_TYPE]]: ...
@overload @overload
@ -103,7 +105,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload @overload
def entries( def entries(
self: ObjectVar[Dict[Any, VALUE_TYPE]], self: ObjectVar[Mapping[Any, VALUE_TYPE]],
) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ... ) -> ArrayVar[List[Tuple[str, VALUE_TYPE]]]: ...
@overload @overload
@ -133,49 +135,55 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any # NoReturn is used here to catch when key value is Any
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, NoReturn]], self: ObjectVar[Mapping[Any, NoReturn]],
key: Var | Any, key: Var | Any,
) -> Var: ... ) -> Var: ...
@overload
def __getitem__(
self: (ObjectVar[Mapping[Any, bool]]),
key: Var | Any,
) -> BooleanVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ( self: (
ObjectVar[Dict[Any, int]] ObjectVar[Mapping[Any, int]]
| ObjectVar[Dict[Any, float]] | ObjectVar[Mapping[Any, float]]
| ObjectVar[Dict[Any, int | float]] | ObjectVar[Mapping[Any, int | float]]
), ),
key: Var | Any, key: Var | Any,
) -> NumberVar: ... ) -> NumberVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, str]], self: ObjectVar[Mapping[Any, str]],
key: Var | Any, key: Var | Any,
) -> StringVar: ... ) -> StringVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
key: Var | Any, key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any, key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
key: Var | Any, key: Var | Any,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
key: Var | Any, key: Var | Any,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
def __getitem__(self, key: Var | Any) -> Var: def __getitem__(self, key: Var | Any) -> Var:
"""Get an item from the object. """Get an item from the object.
@ -195,49 +203,49 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
# NoReturn is used here to catch when key value is Any # NoReturn is used here to catch when key value is Any
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, NoReturn]], self: ObjectVar[Mapping[Any, NoReturn]],
name: str, name: str,
) -> Var: ... ) -> Var: ...
@overload @overload
def __getattr__( def __getattr__(
self: ( self: (
ObjectVar[Dict[Any, int]] ObjectVar[Mapping[Any, int]]
| ObjectVar[Dict[Any, float]] | ObjectVar[Mapping[Any, float]]
| ObjectVar[Dict[Any, int | float]] | ObjectVar[Mapping[Any, int | float]]
), ),
name: str, name: str,
) -> NumberVar: ... ) -> NumberVar: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, str]], self: ObjectVar[Mapping[Any, str]],
name: str, name: str,
) -> StringVar: ... ) -> StringVar: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, list[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, list[ARRAY_INNER_TYPE]]],
name: str, name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, set[ARRAY_INNER_TYPE]]], self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str, name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, tuple[ARRAY_INNER_TYPE, ...]]], self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
name: str, name: str,
) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ... ) -> ArrayVar[tuple[ARRAY_INNER_TYPE, ...]]: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Dict[Any, dict[OTHER_KEY_TYPE, VALUE_TYPE]]], self: ObjectVar[Mapping[Any, Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]],
name: str, name: str,
) -> ObjectVar[dict[OTHER_KEY_TYPE, VALUE_TYPE]]: ... ) -> ObjectVar[Mapping[OTHER_KEY_TYPE, VALUE_TYPE]]: ...
@overload @overload
def __getattr__( def __getattr__(
@ -266,8 +274,11 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
var_type = get_args(var_type)[0] var_type = get_args(var_type)[0]
fixed_type = var_type if isclass(var_type) else get_origin(var_type) fixed_type = var_type if isclass(var_type) else get_origin(var_type)
if (isclass(fixed_type) and not issubclass(fixed_type, dict)) or (
fixed_type in types.UnionTypes if (
(isclass(fixed_type) and not issubclass(fixed_type, Mapping))
or (fixed_type in types.UnionTypes)
or is_typeddict(fixed_type)
): ):
attribute_type = get_attribute_access_type(var_type, name) attribute_type = get_attribute_access_type(var_type, name)
if attribute_type is None: if attribute_type is None:
@ -299,7 +310,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars.""" """Base class for immutable literal object vars."""
_var_value: Dict[Union[Var, Any], Union[Var, Any]] = dataclasses.field( _var_value: Mapping[Union[Var, Any], Union[Var, Any]] = dataclasses.field(
default_factory=dict default_factory=dict
) )
@ -383,7 +394,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@classmethod @classmethod
def create( def create(
cls, cls,
_var_value: dict, _var_value: Mapping,
_var_type: Type[OBJECT_TYPE] | None = None, _var_type: Type[OBJECT_TYPE] | None = None,
_var_data: VarData | None = None, _var_data: VarData | None = None,
) -> LiteralObjectVar[OBJECT_TYPE]: ) -> LiteralObjectVar[OBJECT_TYPE]:
@ -466,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
""" """
return var_operation_return( return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})", js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Dict[ var_type=Mapping[
Union[lhs._key_type(), rhs._key_type()], Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()], Union[lhs._value_type(), rhs._value_type()],
], ],

View File

@ -987,7 +987,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
raise_unsupported_operand_types("[]", (type(self), type(i))) raise_unsupported_operand_types("[]", (type(self), type(i)))
return array_item_operation(self, i) return array_item_operation(self, i)
def length(self) -> NumberVar: def length(self) -> NumberVar[int]:
"""Get the length of the array. """Get the length of the array.
Returns: Returns:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Tuple from typing import List, Mapping, Tuple
import pytest import pytest
@ -67,7 +67,7 @@ def test_match_components():
assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}' assert fourth_return_value_render["children"][0]["contents"] == '{"fourth value"}'
assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })' assert match_cases[4][0]._js_expr == '({ ["foo"] : "bar" })'
assert match_cases[4][0]._var_type == Dict[str, str] assert match_cases[4][0]._var_type == Mapping[str, str]
fifth_return_value_render = match_cases[4][1].render() fifth_return_value_render = match_cases[4][1].render()
assert fifth_return_value_render["name"] == "RadixThemesText" assert fifth_return_value_render["name"] == "RadixThemesText"
assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}' assert fifth_return_value_render["children"][0]["contents"] == '{"fifth value"}'

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict from typing import Any, Mapping
import pytest import pytest
@ -379,7 +379,7 @@ class StyleState(rx.State):
{ {
"css": Var( "css": Var(
_js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})' _js_expr=f'({{ ["color"] : ("dark"+{StyleState.color}) }})'
).to(Dict[str, str]) ).to(Mapping[str, str])
}, },
), ),
( (

View File

@ -2,7 +2,7 @@ import json
import math import math
import sys import sys
import typing import typing
from typing import Dict, List, Optional, Set, Tuple, Union, cast from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
import pytest import pytest
from pandas import DataFrame from pandas import DataFrame
@ -270,7 +270,7 @@ def test_get_setter(prop: Var, expected):
([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])), ([1, 2, 3], Var(_js_expr="[1, 2, 3]", _var_type=List[int])),
( (
{"a": 1, "b": 2}, {"a": 1, "b": 2},
Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Dict[str, int]), Var(_js_expr='({ ["a"] : 1, ["b"] : 2 })', _var_type=Mapping[str, int]),
), ),
], ],
) )

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union from typing import List, Mapping, Union
import pytest import pytest
@ -37,12 +37,12 @@ class ChildGenericDict(GenericDict):
("a", str), ("a", str),
([1, 2, 3], List[int]), ([1, 2, 3], List[int]),
([1, 2.0, "a"], List[Union[int, float, str]]), ([1, 2.0, "a"], List[Union[int, float, str]]),
({"a": 1, "b": 2}, Dict[str, int]), ({"a": 1, "b": 2}, Mapping[str, int]),
({"a": 1, 2: "b"}, Dict[Union[int, str], Union[str, int]]), ({"a": 1, 2: "b"}, Mapping[Union[int, str], Union[str, int]]),
(CustomDict(), CustomDict), (CustomDict(), CustomDict),
(ChildCustomDict(), ChildCustomDict), (ChildCustomDict(), ChildCustomDict),
(GenericDict({1: 1}), Dict[int, int]), (GenericDict({1: 1}), Mapping[int, int]),
(ChildGenericDict({1: 1}), Dict[int, int]), (ChildGenericDict({1: 1}), Mapping[int, int]),
], ],
) )
def test_figure_out_type(value, expected): def test_figure_out_type(value, expected):