Compare commits

...

1 Commits

Author SHA1 Message Date
Khaleel Al-Adhami
24a5309156 don't treat vars as their types for setting state fields 2025-02-22 09:23:53 -08:00
5 changed files with 58 additions and 18 deletions

View File

@ -192,7 +192,7 @@ def satisfies_type_hint(obj: Any, type_hint: Any) -> bool:
Returns:
Whether the object satisfies the type hint.
"""
return types._isinstance(obj, type_hint, nested=1)
return types._isinstance(obj, type_hint, nested=1, treat_var_as_type=True)
def _components_from(

View File

@ -6,7 +6,7 @@ import dataclasses
from typing import Any, Dict, List, Mapping, Optional, Sequence
from reflex.event import EventChain
from reflex.utils import format, types
from reflex.utils import format
from reflex.vars.base import LiteralVar, Var
@ -103,9 +103,9 @@ class Tag:
{
format.to_camel_case(name, treat_hyphens_as_underscores=False): (
prop
if types._isinstance(prop, (EventChain, Mapping))
if isinstance(prop, (EventChain, Mapping))
else LiteralVar.create(prop)
) # rx.color is always a string
)
for name, prop in kwargs.items()
if self.is_valid_prop(prop)
}

View File

@ -695,7 +695,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def computed_var_func(state: Self):
result = f(state)
if not _isinstance(result, of_type):
if not _isinstance(result, of_type, nested=1, treat_var_as_type=False):
console.warn(
f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. "
"You can specify expected type with `of_type` argument."
@ -1356,7 +1356,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
field_type = _unwrap_field_type(field.outer_type_)
if field.allow_none and not is_optional(field_type):
field_type = Union[field_type, None]
if not _isinstance(value, field_type):
if not _isinstance(value, field_type, nested=1, treat_var_as_type=False):
console.error(
f"Expected field '{type(self).__name__}.{name}' to receive type '{field_type}',"
f" but got '{value}' of type '{type(value)}'."

View File

@ -510,13 +510,16 @@ def does_obj_satisfy_typed_dict(obj: Any, cls: GenericType) -> bool:
return required_keys.issubset(required_keys)
def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
def _isinstance(
obj: Any, cls: GenericType, *, nested: int = 0, treat_var_as_type: bool = True
) -> bool:
"""Check if an object is an instance of a class.
Args:
obj: The object to check.
cls: The class to check against.
nested: How many levels deep to check.
treat_var_as_type: Whether to treat Var as the type it represents, i.e. _var_type.
Returns:
Whether the object is an instance of the class.
@ -529,15 +532,20 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
if cls is Var:
return isinstance(obj, Var)
if isinstance(obj, LiteralVar):
return _isinstance(obj._var_value, cls, nested=nested)
return treat_var_as_type and _isinstance(
obj._var_value, cls, nested=nested, treat_var_as_type=True
)
if isinstance(obj, Var):
return _issubclass(obj._var_type, cls)
return treat_var_as_type and _issubclass(obj._var_type, cls)
if cls is None or cls is type(None):
return obj is None
if cls and is_union(cls):
return any(_isinstance(obj, arg, nested=nested) for arg in get_args(cls))
return any(
_isinstance(obj, arg, nested=nested, treat_var_as_type=treat_var_as_type)
for arg in get_args(cls)
)
if is_literal(cls):
return obj in get_args(cls)
@ -567,37 +575,69 @@ def _isinstance(obj: Any, cls: GenericType, nested: int = 0) -> bool:
if nested > 0 and args:
if origin is list:
return isinstance(obj, list) and all(
_isinstance(item, args[0], nested=nested - 1) for item in obj
_isinstance(
item,
args[0],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item in obj
)
if origin is tuple:
if args[-1] is Ellipsis:
return isinstance(obj, tuple) and all(
_isinstance(item, args[0], nested=nested - 1) for item in obj
_isinstance(
item,
args[0],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item in obj
)
return (
isinstance(obj, tuple)
and len(obj) == len(args)
and all(
_isinstance(item, arg, nested=nested - 1)
_isinstance(
item,
arg,
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item, arg in zip(obj, args, strict=True)
)
)
if origin in (dict, Mapping, Breakpoints):
return isinstance(obj, Mapping) and all(
_isinstance(key, args[0], nested=nested - 1)
and _isinstance(value, args[1], nested=nested - 1)
_isinstance(
key, args[0], nested=nested - 1, treat_var_as_type=treat_var_as_type
)
and _isinstance(
value,
args[1],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for key, value in obj.items()
)
if origin is set:
return isinstance(obj, set) and all(
_isinstance(item, args[0], nested=nested - 1) for item in obj
_isinstance(
item,
args[0],
nested=nested - 1,
treat_var_as_type=treat_var_as_type,
)
for item in obj
)
if args:
from reflex.vars import Field
if origin is Field:
return _isinstance(obj, args[0], nested=nested)
return _isinstance(
obj, args[0], nested=nested, treat_var_as_type=treat_var_as_type
)
return isinstance(obj, get_base_class(cls))

View File

@ -2290,7 +2290,7 @@ class ComputedVar(Var[RETURN_TYPE]):
return value
def _check_deprecated_return_type(self, instance: BaseState, value: Any) -> None:
if not _isinstance(value, self._var_type):
if not _isinstance(value, self._var_type, nested=1, treat_var_as_type=False):
console.error(
f"Computed var '{type(instance).__name__}.{self._js_expr}' must return"
f" type '{self._var_type}', got '{type(value)}'."