increase nested type checking for component var types (#4756)

* increase nested type checking for component var types

* handle mapping as dict in type hint

* fix weird cases of using _isinstance instead of isinstance

* test out nested=0

* move union below

* don't use _instance for simple unions
This commit is contained in:
Khaleel Al-Adhami 2025-02-06 10:09:40 -08:00 committed by GitHub
parent 9d23271c14
commit ab558ce172
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 34 deletions

View File

@ -179,6 +179,7 @@ ComponentStyle = Dict[
Union[str, Type[BaseComponent], Callable, ComponentNamespace], Any
]
ComponentChild = Union[types.PrimitiveType, Var, BaseComponent]
ComponentChildTypes = (*types.PrimitiveTypes, Var, BaseComponent)
def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
@ -191,11 +192,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
Returns:
Whether the object satisfies the type hint.
"""
if isinstance(obj, LiteralVar):
return types._isinstance(obj._var_value, type_hint)
if isinstance(obj, Var):
return types._issubclass(obj._var_type, type_hint)
return types._isinstance(obj, type_hint)
return types._isinstance(obj, type_hint, nested=1)
class Component(BaseComponent, ABC):
@ -712,8 +709,8 @@ class Component(BaseComponent, ABC):
validate_children(child)
# Make sure the child is a valid type.
if isinstance(child, dict) or not types._isinstance(
child, ComponentChild
if isinstance(child, dict) or not isinstance(
child, ComponentChildTypes
):
raise ChildrenTypeError(component=cls.__name__, child=child)
@ -1771,9 +1768,7 @@ class CustomComponent(Component):
return [
Var(
_js_expr=name,
_var_type=(
prop._var_type if types._isinstance(prop, Var) else type(prop)
),
_var_type=(prop._var_type if isinstance(prop, Var) else type(prop)),
).guess_type()
for name, prop in self.props.items()
]

View File

@ -178,9 +178,9 @@ class Match(MemoizationLeaf):
first_case_return = match_cases[0][-1]
return_type = type(first_case_return)
if types._isinstance(first_case_return, BaseComponent):
if isinstance(first_case_return, BaseComponent):
return_type = BaseComponent
elif types._isinstance(first_case_return, Var):
elif isinstance(first_case_return, Var):
return_type = Var
for index, case in enumerate(match_cases):
@ -228,8 +228,8 @@ class Match(MemoizationLeaf):
# Validate the match cases (as well as the default case) to have Var return types.
if any(
case for case in match_cases if not types._isinstance(case[-1], Var)
) or not types._isinstance(default, Var):
case for case in match_cases if not isinstance(case[-1], Var)
) or not isinstance(default, Var):
raise ValueError("Return types of match cases should be Vars.")
return Var(

View File

@ -6,11 +6,10 @@ import dataclasses
import textwrap
from functools import lru_cache
from hashlib import md5
from typing import Any, Callable, Dict, Sequence, Union
from typing import Any, Callable, Dict, Sequence
from reflex.components.component import BaseComponent, Component, CustomComponent
from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars.base import LiteralVar, Var, VarData
from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg
@ -169,7 +168,7 @@ class Markdown(Component):
Returns:
The markdown component.
"""
if len(children) != 1 or not types._isinstance(children[0], Union[str, Var]):
if len(children) != 1 or not isinstance(children[0], (str, Var)):
raise ValueError(
"Markdown component must have exactly one child containing the markdown source."
)

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence
from reflex.event import EventChain
from reflex.utils import format, types
@ -103,7 +103,7 @@ class Tag:
{
format.to_camel_case(name, allow_hyphens=True): (
prop
if types._isinstance(prop, Union[EventChain, dict])
if types._isinstance(prop, (EventChain, Mapping))
else LiteralVar.create(prop)
) # rx.color is always a string
for name, prop in kwargs.items()

View File

@ -95,6 +95,7 @@ GenericType = Union[Type, _GenericAlias]
# Valid state var types.
JSONType = {str, int, float, bool}
PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple)
StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple]
@ -551,13 +552,13 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
return required_keys.issubset(required_keys)
def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
"""Check if an object is an instance of a class.
Args:
obj: The object to check.
cls: The class to check against.
nested: Whether the check is nested.
nested: How many levels deep to check.
Returns:
Whether the object is an instance of the class.
@ -565,15 +566,24 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
if cls is Any:
return True
from reflex.vars import LiteralVar, Var
if cls is Var:
return isinstance(obj, Var)
if isinstance(obj, LiteralVar):
return _isinstance(obj._var_value, cls, nested=nested)
if isinstance(obj, Var):
return _issubclass(obj._var_type, cls)
if cls is None or cls is type(None):
return obj is None
if cls and is_union(cls):
return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls))
if is_literal(cls):
return obj in get_args(cls)
if is_union(cls):
return any(_isinstance(obj, arg) for arg in get_args(cls))
origin = get_origin(cls)
if origin is None:
@ -596,38 +606,40 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
# cls is a simple generic class
return isinstance(obj, origin)
if nested and args:
if nested > 0 and args:
if origin is list:
return isinstance(obj, list) and all(
_isinstance(item, args[0]) for item in obj
_isinstance(item, args[0], nested=nested - 1) for item in obj
)
if origin is tuple:
if args[-1] is Ellipsis:
return isinstance(obj, tuple) and all(
_isinstance(item, args[0]) for item in obj
_isinstance(item, args[0], nested=nested - 1) for item in obj
)
return (
isinstance(obj, tuple)
and len(obj) == len(args)
and all(
_isinstance(item, arg) for item, arg in zip(obj, args, strict=True)
_isinstance(item, arg, nested=nested - 1)
for item, arg in zip(obj, args, strict=True)
)
)
if origin in (dict, Breakpoints):
return isinstance(obj, dict) and all(
_isinstance(key, args[0]) and _isinstance(value, args[1])
if origin in (dict, Mapping, Breakpoints):
return isinstance(obj, Mapping) and all(
_isinstance(key, args[0], nested=nested - 1)
and _isinstance(value, args[1], nested=nested - 1)
for key, value in obj.items()
)
if origin is set:
return isinstance(obj, set) and all(
_isinstance(item, args[0]) for item in obj
_isinstance(item, args[0], nested=nested - 1) for item in obj
)
if args:
from reflex.vars import Field
if origin is Field:
return _isinstance(obj, args[0])
return _isinstance(obj, args[0], nested=nested)
return isinstance(obj, get_base_class(cls))
@ -749,7 +761,7 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
"""
from reflex.vars import Var
type_ = prop._var_type if _isinstance(prop, Var) else type(prop)
type_ = prop._var_type if isinstance(prop, Var) else type(prop)
return type_ in allowed_types