From ab558ce17285a30ccf88d479277f2b2ebea2760c Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 6 Feb 2025 10:09:40 -0800 Subject: [PATCH] increase nested type checking for component var types (#4756) * increase nested type checking for component var types * handle mapping as dict in type hint * fix weird cases of using _isinstance instead of isinstance * test out nested=0 * move union below * don't use _instance for simple unions --- reflex/components/component.py | 15 +++------ reflex/components/core/match.py | 8 ++--- reflex/components/markdown/markdown.py | 5 ++- reflex/components/tags/tag.py | 4 +-- reflex/utils/types.py | 42 +++++++++++++++++--------- 5 files changed, 40 insertions(+), 34 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 6d1264f4d..6e4c6c37f 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -179,6 +179,7 @@ ComponentStyle = Dict[ Union[str, Type[BaseComponent], Callable, ComponentNamespace], Any ] ComponentChild = Union[types.PrimitiveType, Var, BaseComponent] +ComponentChildTypes = (*types.PrimitiveTypes, Var, BaseComponent) def satisfies_type_hint(obj: Any, type_hint: Any) -> bool: @@ -191,11 +192,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool: Returns: Whether the object satisfies the type hint. """ - if isinstance(obj, LiteralVar): - return types._isinstance(obj._var_value, type_hint) - if isinstance(obj, Var): - return types._issubclass(obj._var_type, type_hint) - return types._isinstance(obj, type_hint) + return types._isinstance(obj, type_hint, nested=1) class Component(BaseComponent, ABC): @@ -712,8 +709,8 @@ class Component(BaseComponent, ABC): validate_children(child) # Make sure the child is a valid type. - if isinstance(child, dict) or not types._isinstance( - child, ComponentChild + if isinstance(child, dict) or not isinstance( + child, ComponentChildTypes ): raise ChildrenTypeError(component=cls.__name__, child=child) @@ -1771,9 +1768,7 @@ class CustomComponent(Component): return [ Var( _js_expr=name, - _var_type=( - prop._var_type if types._isinstance(prop, Var) else type(prop) - ), + _var_type=(prop._var_type if isinstance(prop, Var) else type(prop)), ).guess_type() for name, prop in self.props.items() ] diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index 5c31669a1..2d936544a 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -178,9 +178,9 @@ class Match(MemoizationLeaf): first_case_return = match_cases[0][-1] return_type = type(first_case_return) - if types._isinstance(first_case_return, BaseComponent): + if isinstance(first_case_return, BaseComponent): return_type = BaseComponent - elif types._isinstance(first_case_return, Var): + elif isinstance(first_case_return, Var): return_type = Var for index, case in enumerate(match_cases): @@ -228,8 +228,8 @@ class Match(MemoizationLeaf): # Validate the match cases (as well as the default case) to have Var return types. if any( - case for case in match_cases if not types._isinstance(case[-1], Var) - ) or not types._isinstance(default, Var): + case for case in match_cases if not isinstance(case[-1], Var) + ) or not isinstance(default, Var): raise ValueError("Return types of match cases should be Vars.") return Var( diff --git a/reflex/components/markdown/markdown.py b/reflex/components/markdown/markdown.py index 91d34ea9b..51d3dd3dd 100644 --- a/reflex/components/markdown/markdown.py +++ b/reflex/components/markdown/markdown.py @@ -6,11 +6,10 @@ import dataclasses import textwrap from functools import lru_cache from hashlib import md5 -from typing import Any, Callable, Dict, Sequence, Union +from typing import Any, Callable, Dict, Sequence from reflex.components.component import BaseComponent, Component, CustomComponent from reflex.components.tags.tag import Tag -from reflex.utils import types from reflex.utils.imports import ImportDict, ImportVar from reflex.vars.base import LiteralVar, Var, VarData from reflex.vars.function import ARRAY_ISARRAY, ArgsFunctionOperation, DestructuredArg @@ -169,7 +168,7 @@ class Markdown(Component): Returns: The markdown component. """ - if len(children) != 1 or not types._isinstance(children[0], Union[str, Var]): + if len(children) != 1 or not isinstance(children[0], (str, Var)): raise ValueError( "Markdown component must have exactly one child containing the markdown source." ) diff --git a/reflex/components/tags/tag.py b/reflex/components/tags/tag.py index 983726e56..515d9e05f 100644 --- a/reflex/components/tags/tag.py +++ b/reflex/components/tags/tag.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence from reflex.event import EventChain from reflex.utils import format, types @@ -103,7 +103,7 @@ class Tag: { format.to_camel_case(name, allow_hyphens=True): ( prop - if types._isinstance(prop, Union[EventChain, dict]) + if types._isinstance(prop, (EventChain, Mapping)) else LiteralVar.create(prop) ) # rx.color is always a string for name, prop in kwargs.items() diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 58fec8f3b..b432319e0 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -95,6 +95,7 @@ GenericType = Union[Type, _GenericAlias] # Valid state var types. JSONType = {str, int, float, bool} PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple] +PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple) StateVar = Union[PrimitiveType, Base, None] StateIterVar = Union[list, set, tuple] @@ -551,13 +552,13 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool: return required_keys.issubset(required_keys) -def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool: +def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: """Check if an object is an instance of a class. Args: obj: The object to check. cls: The class to check against. - nested: Whether the check is nested. + nested: How many levels deep to check. Returns: Whether the object is an instance of the class. @@ -565,15 +566,24 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool: if cls is Any: return True + from reflex.vars import LiteralVar, Var + + if cls is Var: + return isinstance(obj, Var) + if isinstance(obj, LiteralVar): + return _isinstance(obj._var_value, cls, nested=nested) + if isinstance(obj, Var): + return _issubclass(obj._var_type, cls) + if cls is None or cls is type(None): return obj is None + if cls and is_union(cls): + return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls)) + if is_literal(cls): return obj in get_args(cls) - if is_union(cls): - return any(_isinstance(obj, arg) for arg in get_args(cls)) - origin = get_origin(cls) if origin is None: @@ -596,38 +606,40 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool: # cls is a simple generic class return isinstance(obj, origin) - if nested and args: + if nested > 0 and args: if origin is list: return isinstance(obj, list) and all( - _isinstance(item, args[0]) for item in obj + _isinstance(item, args[0], nested=nested - 1) for item in obj ) if origin is tuple: if args[-1] is Ellipsis: return isinstance(obj, tuple) and all( - _isinstance(item, args[0]) for item in obj + _isinstance(item, args[0], nested=nested - 1) for item in obj ) return ( isinstance(obj, tuple) and len(obj) == len(args) and all( - _isinstance(item, arg) for item, arg in zip(obj, args, strict=True) + _isinstance(item, arg, nested=nested - 1) + for item, arg in zip(obj, args, strict=True) ) ) - if origin in (dict, Breakpoints): - return isinstance(obj, dict) and all( - _isinstance(key, args[0]) and _isinstance(value, args[1]) + if origin in (dict, Mapping, Breakpoints): + return isinstance(obj, Mapping) and all( + _isinstance(key, args[0], nested=nested - 1) + and _isinstance(value, args[1], nested=nested - 1) for key, value in obj.items() ) if origin is set: return isinstance(obj, set) and all( - _isinstance(item, args[0]) for item in obj + _isinstance(item, args[0], nested=nested - 1) for item in obj ) if args: from reflex.vars import Field if origin is Field: - return _isinstance(obj, args[0]) + return _isinstance(obj, args[0], nested=nested) return isinstance(obj, get_base_class(cls)) @@ -749,7 +761,7 @@ def check_prop_in_allowed_types(prop: Any, allowed_types: Iterable) -> bool: """ from reflex.vars import Var - type_ = prop._var_type if _isinstance(prop, Var) else type(prop) + type_ = prop._var_type if isinstance(prop, Var) else type(prop) return type_ in allowed_types