add type validation for state setattr (#4265)

* add type validation for state setattr

* add type to check to state setattr

* add type validation to computed vars
This commit is contained in:
Khaleel Al-Adhami 2024-10-30 11:11:03 -07:00 committed by GitHub
parent 4260a0cfc3
commit c8a7ee52bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 103 additions and 14 deletions

View File

@ -91,7 +91,7 @@ from reflex.utils.exceptions import (
)
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import serializer
from reflex.utils.types import get_origin, override
from reflex.utils.types import _isinstance, get_origin, override
from reflex.vars import VarData
if TYPE_CHECKING:
@ -636,7 +636,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
def computed_var_func(state: Self):
result = f(state)
if not isinstance(result, of_type):
if not _isinstance(result, of_type):
console.warn(
f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. "
"You can specify expected type with `of_type` argument."
@ -1274,6 +1274,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
f"All state variables must be declared before they can be set."
)
fields = self.get_fields()
if name in fields and not _isinstance(
value, (field_type := fields[name].outer_type_)
):
console.deprecate(
"mismatched-type-assignment",
f"Tried to assign value {value} of type {type(value)} to field {type(self).__name__}.{name} of type {field_type}."
" This might lead to unexpected behavior.",
"0.6.5",
"0.7.0",
)
# Set the attribute.
super().__setattr__(name, value)

View File

@ -510,16 +510,66 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None)
raise TypeError(f"Invalid type for issubclass: {cls_base}") from te
def _isinstance(obj: Any, cls: GenericType) -> bool:
def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
"""Check if an object is an instance of a class.
Args:
obj: The object to check.
cls: The class to check against.
nested: Whether the check is nested.
Returns:
Whether the object is an instance of the class.
"""
if cls is Any:
return True
if cls is None or cls is type(None):
return obj is None
if is_literal(cls):
return obj in get_args(cls)
if is_union(cls):
return any(_isinstance(obj, arg) for arg in get_args(cls))
origin = get_origin(cls)
if origin is None:
# cls is a simple class
return isinstance(obj, cls)
args = get_args(cls)
if not args:
# cls is a simple generic class
return isinstance(obj, origin)
if nested and args:
if origin is list:
return isinstance(obj, list) and all(
_isinstance(item, args[0]) for item in obj
)
if origin is tuple:
if args[-1] is Ellipsis:
return isinstance(obj, tuple) and all(
_isinstance(item, args[0]) for item in obj
)
return (
isinstance(obj, tuple)
and len(obj) == len(args)
and all(_isinstance(item, arg) for item, arg in zip(obj, args))
)
if origin is dict:
return isinstance(obj, dict) and all(
_isinstance(key, args[0]) and _isinstance(value, args[1])
for key, value in obj.items()
)
if origin is set:
return isinstance(obj, set) and all(
_isinstance(item, args[0]) for item in obj
)
return isinstance(obj, get_base_class(cls))

View File

@ -63,7 +63,14 @@ from reflex.utils.imports import (
ParsedImportDict,
parse_imports,
)
from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
from reflex.utils.types import (
GenericType,
Self,
_isinstance,
get_origin,
has_args,
unionize,
)
if TYPE_CHECKING:
from reflex.state import BaseState
@ -1833,6 +1840,14 @@ class ComputedVar(Var[RETURN_TYPE]):
"return", Any
)
if hint is Any:
console.deprecate(
"untyped-computed-var",
"ComputedVar should have a return type annotation.",
"0.6.5",
"0.7.0",
)
kwargs.setdefault("_js_expr", fget.__name__)
kwargs.setdefault("_var_type", hint)
@ -2026,17 +2041,28 @@ class ComputedVar(Var[RETURN_TYPE]):
)
if not self._cache:
return self.fget(instance)
value = self.fget(instance)
else:
# handle caching
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
value = getattr(instance, self._cache_attr)
# handle caching
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
return getattr(instance, self._cache_attr)
if not _isinstance(value, self._var_type):
console.deprecate(
"mismatched-computed-var-return",
f"Computed var {type(instance).__name__}.{self._js_expr} returned value of type {type(value)}, "
f"expected {self._var_type}. This might cause unexpected behavior.",
"0.6.5",
"0.7.0",
)
return value
def _deps(
self,