increase nested type checking for component var types

This commit is contained in:
Khaleel Al-Adhami 2025-02-04 18:56:45 -08:00
parent af9a914ecc
commit a1b31b4baf
2 changed files with 14 additions and 12 deletions

View File

@ -192,10 +192,10 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
Whether the object satisfies the type hint. Whether the object satisfies the type hint.
""" """
if isinstance(obj, LiteralVar): if isinstance(obj, LiteralVar):
return types._isinstance(obj._var_value, type_hint) return types._isinstance(obj._var_value, type_hint, nested=1)
if isinstance(obj, Var): if isinstance(obj, Var):
return types._issubclass(obj._var_type, type_hint) return types._issubclass(obj._var_type, type_hint)
return types._isinstance(obj, type_hint) return types._isinstance(obj, type_hint, nested=1)
class Component(BaseComponent, ABC): class Component(BaseComponent, ABC):

View File

@ -551,13 +551,13 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
return required_keys.issubset(required_keys) return required_keys.issubset(required_keys)
def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool: def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
"""Check if an object is an instance of a class. """Check if an object is an instance of a class.
Args: Args:
obj: The object to check. obj: The object to check.
cls: The class to check against. cls: The class to check against.
nested: Whether the check is nested. nested: How many levels deep to check.
Returns: Returns:
Whether the object is an instance of the class. Whether the object is an instance of the class.
@ -572,7 +572,7 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
return obj in get_args(cls) return obj in get_args(cls)
if is_union(cls): if is_union(cls):
return any(_isinstance(obj, arg) for arg in get_args(cls)) return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls))
origin = get_origin(cls) origin = get_origin(cls)
@ -596,38 +596,40 @@ def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
# cls is a simple generic class # cls is a simple generic class
return isinstance(obj, origin) return isinstance(obj, origin)
if nested and args: if nested > 0 and args:
if origin is list: if origin is list:
return isinstance(obj, list) and all( return isinstance(obj, list) and all(
_isinstance(item, args[0]) for item in obj _isinstance(item, args[0], nested=nested - 1) for item in obj
) )
if origin is tuple: if origin is tuple:
if args[-1] is Ellipsis: if args[-1] is Ellipsis:
return isinstance(obj, tuple) and all( return isinstance(obj, tuple) and all(
_isinstance(item, args[0]) for item in obj _isinstance(item, args[0], nested=nested - 1) for item in obj
) )
return ( return (
isinstance(obj, tuple) isinstance(obj, tuple)
and len(obj) == len(args) and len(obj) == len(args)
and all( and all(
_isinstance(item, arg) for item, arg in zip(obj, args, strict=True) _isinstance(item, arg, nested=nested - 1)
for item, arg in zip(obj, args, strict=True)
) )
) )
if origin in (dict, Breakpoints): if origin in (dict, Breakpoints):
return isinstance(obj, dict) and all( return isinstance(obj, dict) and all(
_isinstance(key, args[0]) and _isinstance(value, args[1]) _isinstance(key, args[0], nested=nested - 1)
and _isinstance(value, args[1], nested=nested - 1)
for key, value in obj.items() for key, value in obj.items()
) )
if origin is set: if origin is set:
return isinstance(obj, set) and all( return isinstance(obj, set) and all(
_isinstance(item, args[0]) for item in obj _isinstance(item, args[0], nested=nested - 1) for item in obj
) )
if args: if args:
from reflex.vars import Field from reflex.vars import Field
if origin is Field: if origin is Field:
return _isinstance(obj, args[0]) return _isinstance(obj, args[0], nested=nested)
return isinstance(obj, get_base_class(cls)) return isinstance(obj, get_base_class(cls))