This commit is contained in:
Khaleel Al-Adhami 2025-02-22 17:24:27 +00:00 committed by GitHub
commit 7051f3e557
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 58 additions and 18 deletions

View File

@ -192,7 +192,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
Returns: Returns:
Whether the object satisfies the type hint. Whether the object satisfies the type hint.
""" """
return types._isinstance(obj, type_hint, nested=1) return types._isinstance(obj, type_hint, nested=1, treat_var_as_type=True)
def _components_from( def _components_from(

View File

@ -6,7 +6,7 @@ import dataclasses
from typing import Any, Dict, List, Mapping, Optional, Sequence from typing import Any, Dict, List, Mapping, Optional, Sequence
from reflex.event import EventChain from reflex.event import EventChain
from reflex.utils import format, types from reflex.utils import format
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
@ -103,9 +103,9 @@ class Tag:
{ {
format.to_camel_case(name, treat_hyphens_as_underscores=False): ( format.to_camel_case(name, treat_hyphens_as_underscores=False): (
prop prop
if types._isinstance(prop, (EventChain, Mapping)) if isinstance(prop, (EventChain, Mapping))
else LiteralVar.create(prop) else LiteralVar.create(prop)
) # rx.color is always a string )
for name, prop in kwargs.items() for name, prop in kwargs.items()
if self.is_valid_prop(prop) if self.is_valid_prop(prop)
} }

View File

@ -695,7 +695,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def computed_var_func(state: Self): def computed_var_func(state: Self):
result = f(state) result = f(state)
if not _isinstance(result, of_type): if not _isinstance(result, of_type, nested=1, treat_var_as_type=False):
console.warn( console.warn(
f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. " f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. "
"You can specify expected type with `of_type` argument." "You can specify expected type with `of_type` argument."
@ -1356,7 +1356,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
field_type = _unwrap_field_type(field.outer_type_) field_type = _unwrap_field_type(field.outer_type_)
if field.allow_none and not is_optional(field_type): if field.allow_none and not is_optional(field_type):
field_type = Union[field_type, None] field_type = Union[field_type, None]
if not _isinstance(value, field_type): if not _isinstance(value, field_type, nested=1, treat_var_as_type=False):
console.error( console.error(
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}'," f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
f" but got '{value}' of type '{type(value)}'." f" but got '{value}' of type '{type(value)}'."

View File

@ -510,13 +510,16 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
return required_keys.issubset(required_keys) return required_keys.issubset(required_keys)
def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool: def _isinstance(
obj: Any, cls: GenericType, *, nested: int = 0, treat_var_as_type: bool = True
) -> bool:
"""Check if an object is an instance of a class. """Check if an object is an instance of a class.
Args: Args:
obj: The object to check. obj: The object to check.
cls: The class to check against. cls: The class to check against.
nested: How many levels deep to check. nested: How many levels deep to check.
treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type.
Returns: Returns:
Whether the object is an instance of the class. Whether the object is an instance of the class.
@ -529,15 +532,20 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
if cls is Var: if cls is Var:
return isinstance(obj, Var) return isinstance(obj, Var)
if isinstance(obj, LiteralVar): if isinstance(obj, LiteralVar):
return _isinstance(obj._var_value, cls, nested=nested) return treat_var_as_type and _isinstance(
obj._var_value, cls, nested=nested, treat_var_as_type=True
)
if isinstance(obj, Var): if isinstance(obj, Var):
return _issubclass(obj._var_type, cls) return treat_var_as_type and _issubclass(obj._var_type, cls)
if cls is None or cls is type(None): if cls is None or cls is type(None):
return obj is None return obj is None
if cls and is_union(cls): if cls and is_union(cls):
return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls)) return any(
_isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type)
for arg in get_args(cls)
)
if is_literal(cls): if is_literal(cls):
return obj in get_args(cls) return obj in get_args(cls)
@ -567,37 +575,69 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
if nested > 0 and args: if nested > 0 and args:
if origin is list: if origin is list:
return isinstance(obj, list) and all( return isinstance(obj, list) and all(
_isinstance(item, args[0], nested=nested - 1) for item in obj _isinstance(
item,
args[0],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item in obj
) )
if origin is tuple: if origin is tuple:
if args[-1] is Ellipsis: if args[-1] is Ellipsis:
return isinstance(obj, tuple) and all( return isinstance(obj, tuple) and all(
_isinstance(item, args[0], nested=nested - 1) for item in obj _isinstance(
item,
args[0],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item in obj
) )
return ( return (
isinstance(obj, tuple) isinstance(obj, tuple)
and len(obj) == len(args) and len(obj) == len(args)
and all( and all(
_isinstance(item, arg, nested=nested - 1) _isinstance(
item,
arg,
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item, arg in zip(obj, args, strict=True) for item, arg in zip(obj, args, strict=True)
) )
) )
if origin in (dict, Mapping, Breakpoints): if origin in (dict, Mapping, Breakpoints):
return isinstance(obj, Mapping) and all( return isinstance(obj, Mapping) and all(
_isinstance(key, args[0], nested=nested - 1) _isinstance(
and _isinstance(value, args[1], nested=nested - 1) key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type
)
and _isinstance(
value,
args[1],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for key, value in obj.items() for key, value in obj.items()
) )
if origin is set: if origin is set:
return isinstance(obj, set) and all( return isinstance(obj, set) and all(
_isinstance(item, args[0], nested=nested - 1) for item in obj _isinstance(
item,
args[0],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item in obj
) )
if args: if args:
from reflex.vars import Field from reflex.vars import Field
if origin is Field: if origin is Field:
return _isinstance(obj, args[0], nested=nested) return _isinstance(
obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type
)
return isinstance(obj, get_base_class(cls)) return isinstance(obj, get_base_class(cls))

View File

@ -2290,7 +2290,7 @@ class ComputedVar(Var[RETURN_TYPE]):
return value return value
def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None: def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
if not _isinstance(value, self._var_type): if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False):
console.error( console.error(
f"Computed var '{type(instance).__name__}.{self._js_expr}' must return" f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
f" type '{self._var_type}', got '{type(value)}'." f" type '{self._var_type}', got '{type(value)}'."