From c8a7ee52bf8989942a2fb92e71d00e3ac6b8d864 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 30 Oct 2024 11:11:03 -0700 Subject: [PATCH] add type validation for state setattr (#4265) * add type validation for state setattr * add type to check to state setattr * add type validation to computed vars --- reflex/state.py | 17 ++++++++++++-- reflex/utils/types.py | 52 ++++++++++++++++++++++++++++++++++++++++++- reflex/vars/base.py | 48 ++++++++++++++++++++++++++++++--------- 3 files changed, 103 insertions(+), 14 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index 2704d58f2..7bdbcdc2b 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -91,7 +91,7 @@ from reflex.utils.exceptions import ( ) from reflex.utils.exec import is_testing_env from reflex.utils.serializers import serializer -from reflex.utils.types import get_origin, override +from reflex.utils.types import _isinstance, get_origin, override from reflex.vars import VarData if TYPE_CHECKING: @@ -636,7 +636,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): def computed_var_func(state: Self): result = f(state) - if not isinstance(result, of_type): + if not _isinstance(result, of_type): console.warn( f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. " "You can specify expected type with `of_type` argument." @@ -1274,6 +1274,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): f"All state variables must be declared before they can be set." ) + fields = self.get_fields() + + if name in fields and not _isinstance( + value, (field_type := fields[name].outer_type_) + ): + console.deprecate( + "mismatched-type-assignment", + f"Tried to assign value {value} of type {type(value)} to field {type(self).__name__}.{name} of type {field_type}." + " This might lead to unexpected behavior.", + "0.6.5", + "0.7.0", + ) + # Set the attribute. super().__setattr__(name, value) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 3d7992011..baedcc5a0 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -510,16 +510,66 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None) raise TypeError(f"Invalid type for issubclass: {cls_base}") from te -def _isinstance(obj: Any, cls: GenericType) -> bool: +def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool: """Check if an object is an instance of a class. Args: obj: The object to check. cls: The class to check against. + nested: Whether the check is nested. Returns: Whether the object is an instance of the class. """ + if cls is Any: + return True + + if cls is None or cls is type(None): + return obj is None + + if is_literal(cls): + return obj in get_args(cls) + + if is_union(cls): + return any(_isinstance(obj, arg) for arg in get_args(cls)) + + origin = get_origin(cls) + + if origin is None: + # cls is a simple class + return isinstance(obj, cls) + + args = get_args(cls) + + if not args: + # cls is a simple generic class + return isinstance(obj, origin) + + if nested and args: + if origin is list: + return isinstance(obj, list) and all( + _isinstance(item, args[0]) for item in obj + ) + if origin is tuple: + if args[-1] is Ellipsis: + return isinstance(obj, tuple) and all( + _isinstance(item, args[0]) for item in obj + ) + return ( + isinstance(obj, tuple) + and len(obj) == len(args) + and all(_isinstance(item, arg) for item, arg in zip(obj, args)) + ) + if origin is dict: + return isinstance(obj, dict) and all( + _isinstance(key, args[0]) and _isinstance(value, args[1]) + for key, value in obj.items() + ) + if origin is set: + return isinstance(obj, set) and all( + _isinstance(item, args[0]) for item in obj + ) + return isinstance(obj, get_base_class(cls)) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2f26e9170..78862aa17 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -63,7 +63,14 @@ from reflex.utils.imports import ( ParsedImportDict, parse_imports, ) -from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize +from reflex.utils.types import ( + GenericType, + Self, + _isinstance, + get_origin, + has_args, + unionize, +) if TYPE_CHECKING: from reflex.state import BaseState @@ -1833,6 +1840,14 @@ class ComputedVar(Var[RETURN_TYPE]): "return", Any ) + if hint is Any: + console.deprecate( + "untyped-computed-var", + "ComputedVar should have a return type annotation.", + "0.6.5", + "0.7.0", + ) + kwargs.setdefault("_js_expr", fget.__name__) kwargs.setdefault("_var_type", hint) @@ -2026,17 +2041,28 @@ class ComputedVar(Var[RETURN_TYPE]): ) if not self._cache: - return self.fget(instance) + value = self.fget(instance) + else: + # handle caching + if not hasattr(instance, self._cache_attr) or self.needs_update(instance): + # Set cache attr on state instance. + setattr(instance, self._cache_attr, self.fget(instance)) + # Ensure the computed var gets serialized to redis. + instance._was_touched = True + # Set the last updated timestamp on the state instance. + setattr(instance, self._last_updated_attr, datetime.datetime.now()) + value = getattr(instance, self._cache_attr) - # handle caching - if not hasattr(instance, self._cache_attr) or self.needs_update(instance): - # Set cache attr on state instance. - setattr(instance, self._cache_attr, self.fget(instance)) - # Ensure the computed var gets serialized to redis. - instance._was_touched = True - # Set the last updated timestamp on the state instance. - setattr(instance, self._last_updated_attr, datetime.datetime.now()) - return getattr(instance, self._cache_attr) + if not _isinstance(value, self._var_type): + console.deprecate( + "mismatched-computed-var-return", + f"Computed var {type(instance).__name__}.{self._js_expr} returned value of type {type(value)}, " + f"expected {self._var_type}. This might cause unexpected behavior.", + "0.6.5", + "0.7.0", + ) + + return value def _deps( self,