increase nested type checking for component var types (#4756)

* increase nested type checking for component var types

* handle mapping as dict in type hint

* fix weird cases of using _isinstance instead of isinstance

* test out nested=0

* move union below

* don't use _instance for simple unions
This commit is contained in:
Khaleel Al-Adhami 2025-02-06 10:09:40 -08:00 committed by GitHub
parent 9d23271c14
commit ab558ce172
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 34 deletions

View File

@ -179,6 +179,7 @@ ComponentStyle = Dict[
Union[str, Type[BaseComponent], Callable, ComponentNamespace], Any Union[str, Type[BaseComponent], Callable, ComponentNamespace], Any
] ]
ComponentChild = Union[types.PrimitiveType, Var, BaseComponent] ComponentChild = Union[types.PrimitiveType, Var, BaseComponent]
ComponentChildTypes = (*types.PrimitiveTypes, Var, BaseComponent)
def satisfies_type_hint(obj: Any, type_hint: Any) -> bool: def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
@ -191,11 +192,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
Returns: Returns:
Whether the object satisfies the type hint. Whether the object satisfies the type hint.
""" """
if isinstance(obj, LiteralVar): return types._isinstance(obj, type_hint, nested=1)
return types._isinstance(obj._var_value, type_hint)
if isinstance(obj, Var):
return types._issubclass(obj._var_type, type_hint)
return types._isinstance(obj, type_hint)
class Component(BaseComponent, ABC): class Component(BaseComponent, ABC):
@ -712,8 +709,8 @@ class Component(BaseComponent, ABC):
validate_children(child) validate_children(child)
# Make sure the child is a valid type. # Make sure the child is a valid type.
if isinstance(child, dict) or not types._isinstance( if isinstance(child, dict) or not isinstance(
child, ComponentChild child, ComponentChildTypes
): ):
raise ChildrenTypeError(component=cls.__name__, child=child) raise ChildrenTypeError(component=cls.__name__, child=child)
@ -1771,9 +1768,7 @@ class CustomComponent(Component):
return [ return [
Var( Var(
_js_expr=name, _js_expr=name,
_var_type=( _var_type=(prop._var_type if isinstance(prop, Var) else type(prop)),
prop._var_type if types._isinstance(prop, Var) else type(prop)
),
).guess_type() ).guess_type()
for name, prop in self.props.items() for name, prop in self.props.items()
] ]

View File

@ -178,9 +178,9 @@ class Match(MemoizationLeaf):
first_case_return = match_cases[0][-1] first_case_return = match_cases[0][-1]
return_type = type(first_case_return) return_type = type(first_case_return)
if types._isinstance(first_case_return, BaseComponent): if isinstance(first_case_return, BaseComponent):
return_type = BaseComponent return_type = BaseComponent
elif types._isinstance(first_case_return, Var): elif isinstance(first_case_return, Var):
return_type = Var return_type = Var
for index, case in enumerate(match_cases): for index, case in enumerate(match_cases):
@ -228,8 +228,8 @@ class Match(MemoizationLeaf):
# Validate the match cases (as well as the default case) to have Var return types. # Validate the match cases (as well as the default case) to have Var return types.
if any( if any(
case for case in match_cases if not types._isinstance(case[-1], Var) case for case in match_cases if not isinstance(case[-1], Var)
) or not types._isinstance(default, Var): ) or not isinstance(default, Var):
raise ValueError("Return types of match cases should be Vars.") raise ValueError("Return types of match cases should be Vars.")
return Var( return Var(

View File

@ -6,11 +6,10 @@ import dataclasses
import textwrap import textwrap
from functools import lru_cache from functools import lru_cache
from hashlib import md5 from hashlib import md5
from typing import Any, Callable, Dict, Sequence, Union from typing import Any, Callable, Dict, Sequence
from reflex.components.component import BaseComponent, Component, CustomComponent from reflex.components.component import BaseComponent, Component, CustomComponent
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars.base import LiteralVar, Var, VarData from reflex.vars.base import LiteralVar, Var, VarData
from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg
@ -169,7 +168,7 @@ class Markdown(Component):
Returns: Returns:
The markdown component. The markdown component.
""" """
if len(children) != 1 or not types._isinstance(children[0], Union[str, Var]): if len(children) != 1 or not isinstance(children[0], (str, Var)):
raise ValueError( raise ValueError(
"Markdown component must have exactly one child containing the markdown source." "Markdown component must have exactly one child containing the markdown source."
) )

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Mapping, Optional, Sequence
from reflex.event import EventChain from reflex.event import EventChain
from reflex.utils import format, types from reflex.utils import format, types
@ -103,7 +103,7 @@ class Tag:
{ {
format.to_camel_case(name, allow_hyphens=True): ( format.to_camel_case(name, allow_hyphens=True): (
prop prop
if types._isinstance(prop, Union[EventChain, dict]) if types._isinstance(prop, (EventChain, Mapping))
else LiteralVar.create(prop) else LiteralVar.create(prop)
) # rx.color is always a string ) # rx.color is always a string
for name, prop in kwargs.items() for name, prop in kwargs.items()

View File

@ -95,6 +95,7 @@ GenericType = Union[Type, _GenericAlias]
# Valid state var types. # Valid state var types.
JSONType = {str, int, float, bool} JSONType = {str, int, float, bool}
PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple] PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple)
StateVar = Union[PrimitiveType, Base, None] StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple] StateIterVar = Union[list, set, tuple]
@ -551,13 +552,13 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
return required_keys.issubset(required_keys) return required_keys.issubset(required_keys)
def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool: def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
"""Check if an object is an instance of a class. """Check if an object is an instance of a class.
Args: Args:
obj: The object to check. obj: The object to check.
cls: The class to check against. cls: The class to check against.
nested: Whether the check is nested. nested: How many levels deep to check.
Returns: Returns:
Whether the object is an instance of the class. Whether the object is an instance of the class.
@ -565,15 +566,24 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
if cls is Any: if cls is Any:
return True return True
from reflex.vars import LiteralVar, Var
if cls is Var:
return isinstance(obj, Var)
if isinstance(obj, LiteralVar):
return _isinstance(obj._var_value, cls, nested=nested)
if isinstance(obj, Var):
return _issubclass(obj._var_type, cls)
if cls is None or cls is type(None): if cls is None or cls is type(None):
return obj is None return obj is None
if cls and is_union(cls):
return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls))
if is_literal(cls): if is_literal(cls):
return obj in get_args(cls) return obj in get_args(cls)
if is_union(cls):
return any(_isinstance(obj, arg) for arg in get_args(cls))
origin = get_origin(cls) origin = get_origin(cls)
if origin is None: if origin is None:
@ -596,38 +606,40 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
# cls is a simple generic class # cls is a simple generic class
return isinstance(obj, origin) return isinstance(obj, origin)
if nested and args: if nested > 0 and args:
if origin is list: if origin is list:
return isinstance(obj, list) and all( return isinstance(obj, list) and all(
_isinstance(item, args[0]) for item in obj _isinstance(item, args[0], nested=nested - 1) for item in obj
) )
if origin is tuple: if origin is tuple:
if args[-1] is Ellipsis: if args[-1] is Ellipsis:
return isinstance(obj, tuple) and all( return isinstance(obj, tuple) and all(
_isinstance(item, args[0]) for item in obj _isinstance(item, args[0], nested=nested - 1) for item in obj
) )
return ( return (
isinstance(obj, tuple) isinstance(obj, tuple)
and len(obj) == len(args) and len(obj) == len(args)
and all( and all(
_isinstance(item, arg) for item, arg in zip(obj, args, strict=True) _isinstance(item, arg, nested=nested - 1)
for item, arg in zip(obj, args, strict=True)
) )
) )
if origin in (dict, Breakpoints): if origin in (dict, Mapping, Breakpoints):
return isinstance(obj, dict) and all( return isinstance(obj, Mapping) and all(
_isinstance(key, args[0]) and _isinstance(value, args[1]) _isinstance(key, args[0], nested=nested - 1)
and _isinstance(value, args[1], nested=nested - 1)
for key, value in obj.items() for key, value in obj.items()
) )
if origin is set: if origin is set:
return isinstance(obj, set) and all( return isinstance(obj, set) and all(
_isinstance(item, args[0]) for item in obj _isinstance(item, args[0], nested=nested - 1) for item in obj
) )
if args: if args:
from reflex.vars import Field from reflex.vars import Field
if origin is Field: if origin is Field:
return _isinstance(obj, args[0]) return _isinstance(obj, args[0], nested=nested)
return isinstance(obj, get_base_class(cls)) return isinstance(obj, get_base_class(cls))
@ -749,7 +761,7 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool:
""" """
from reflex.vars import Var from reflex.vars import Var
type_ = prop._var_type if _isinstance(prop, Var) else type(prop) type_ = prop._var_type if isinstance(prop, Var) else type(prop)
return type_ in allowed_types return type_ in allowed_types