add type validation for state setattr (#4265)
* add type validation for state setattr * add type to check to state setattr * add type validation to computed vars
This commit is contained in:
parent
4260a0cfc3
commit
c8a7ee52bf
@ -91,7 +91,7 @@ from reflex.utils.exceptions import (
|
|||||||
)
|
)
|
||||||
from reflex.utils.exec import is_testing_env
|
from reflex.utils.exec import is_testing_env
|
||||||
from reflex.utils.serializers import serializer
|
from reflex.utils.serializers import serializer
|
||||||
from reflex.utils.types import get_origin, override
|
from reflex.utils.types import _isinstance, get_origin, override
|
||||||
from reflex.vars import VarData
|
from reflex.vars import VarData
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -636,7 +636,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
def computed_var_func(state: Self):
|
def computed_var_func(state: Self):
|
||||||
result = f(state)
|
result = f(state)
|
||||||
|
|
||||||
if not isinstance(result, of_type):
|
if not _isinstance(result, of_type):
|
||||||
console.warn(
|
console.warn(
|
||||||
f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. "
|
f"Inline ComputedVar {f} expected type {of_type}, got {type(result)}. "
|
||||||
"You can specify expected type with `of_type` argument."
|
"You can specify expected type with `of_type` argument."
|
||||||
@ -1274,6 +1274,19 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
f"All state variables must be declared before they can be set."
|
f"All state variables must be declared before they can be set."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fields = self.get_fields()
|
||||||
|
|
||||||
|
if name in fields and not _isinstance(
|
||||||
|
value, (field_type := fields[name].outer_type_)
|
||||||
|
):
|
||||||
|
console.deprecate(
|
||||||
|
"mismatched-type-assignment",
|
||||||
|
f"Tried to assign value {value} of type {type(value)} to field {type(self).__name__}.{name} of type {field_type}."
|
||||||
|
" This might lead to unexpected behavior.",
|
||||||
|
"0.6.5",
|
||||||
|
"0.7.0",
|
||||||
|
)
|
||||||
|
|
||||||
# Set the attribute.
|
# Set the attribute.
|
||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
|
@ -510,16 +510,66 @@ def _issubclass(cls: GenericType, cls_check: GenericType, instance: Any = None)
|
|||||||
raise TypeError(f"Invalid type for issubclass: {cls_base}") from te
|
raise TypeError(f"Invalid type for issubclass: {cls_base}") from te
|
||||||
|
|
||||||
|
|
||||||
def _isinstance(obj: Any, cls: GenericType) -> bool:
|
def _isinstance(obj: Any, cls: GenericType, nested: bool = False) -> bool:
|
||||||
"""Check if an object is an instance of a class.
|
"""Check if an object is an instance of a class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
obj: The object to check.
|
obj: The object to check.
|
||||||
cls: The class to check against.
|
cls: The class to check against.
|
||||||
|
nested: Whether the check is nested.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Whether the object is an instance of the class.
|
Whether the object is an instance of the class.
|
||||||
"""
|
"""
|
||||||
|
if cls is Any:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if cls is None or cls is type(None):
|
||||||
|
return obj is None
|
||||||
|
|
||||||
|
if is_literal(cls):
|
||||||
|
return obj in get_args(cls)
|
||||||
|
|
||||||
|
if is_union(cls):
|
||||||
|
return any(_isinstance(obj, arg) for arg in get_args(cls))
|
||||||
|
|
||||||
|
origin = get_origin(cls)
|
||||||
|
|
||||||
|
if origin is None:
|
||||||
|
# cls is a simple class
|
||||||
|
return isinstance(obj, cls)
|
||||||
|
|
||||||
|
args = get_args(cls)
|
||||||
|
|
||||||
|
if not args:
|
||||||
|
# cls is a simple generic class
|
||||||
|
return isinstance(obj, origin)
|
||||||
|
|
||||||
|
if nested and args:
|
||||||
|
if origin is list:
|
||||||
|
return isinstance(obj, list) and all(
|
||||||
|
_isinstance(item, args[0]) for item in obj
|
||||||
|
)
|
||||||
|
if origin is tuple:
|
||||||
|
if args[-1] is Ellipsis:
|
||||||
|
return isinstance(obj, tuple) and all(
|
||||||
|
_isinstance(item, args[0]) for item in obj
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
isinstance(obj, tuple)
|
||||||
|
and len(obj) == len(args)
|
||||||
|
and all(_isinstance(item, arg) for item, arg in zip(obj, args))
|
||||||
|
)
|
||||||
|
if origin is dict:
|
||||||
|
return isinstance(obj, dict) and all(
|
||||||
|
_isinstance(key, args[0]) and _isinstance(value, args[1])
|
||||||
|
for key, value in obj.items()
|
||||||
|
)
|
||||||
|
if origin is set:
|
||||||
|
return isinstance(obj, set) and all(
|
||||||
|
_isinstance(item, args[0]) for item in obj
|
||||||
|
)
|
||||||
|
|
||||||
return isinstance(obj, get_base_class(cls))
|
return isinstance(obj, get_base_class(cls))
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,7 +63,14 @@ from reflex.utils.imports import (
|
|||||||
ParsedImportDict,
|
ParsedImportDict,
|
||||||
parse_imports,
|
parse_imports,
|
||||||
)
|
)
|
||||||
from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
|
from reflex.utils.types import (
|
||||||
|
GenericType,
|
||||||
|
Self,
|
||||||
|
_isinstance,
|
||||||
|
get_origin,
|
||||||
|
has_args,
|
||||||
|
unionize,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from reflex.state import BaseState
|
from reflex.state import BaseState
|
||||||
@ -1833,6 +1840,14 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
"return", Any
|
"return", Any
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if hint is Any:
|
||||||
|
console.deprecate(
|
||||||
|
"untyped-computed-var",
|
||||||
|
"ComputedVar should have a return type annotation.",
|
||||||
|
"0.6.5",
|
||||||
|
"0.7.0",
|
||||||
|
)
|
||||||
|
|
||||||
kwargs.setdefault("_js_expr", fget.__name__)
|
kwargs.setdefault("_js_expr", fget.__name__)
|
||||||
kwargs.setdefault("_var_type", hint)
|
kwargs.setdefault("_var_type", hint)
|
||||||
|
|
||||||
@ -2026,17 +2041,28 @@ class ComputedVar(Var[RETURN_TYPE]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not self._cache:
|
if not self._cache:
|
||||||
return self.fget(instance)
|
value = self.fget(instance)
|
||||||
|
else:
|
||||||
|
# handle caching
|
||||||
|
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
|
||||||
|
# Set cache attr on state instance.
|
||||||
|
setattr(instance, self._cache_attr, self.fget(instance))
|
||||||
|
# Ensure the computed var gets serialized to redis.
|
||||||
|
instance._was_touched = True
|
||||||
|
# Set the last updated timestamp on the state instance.
|
||||||
|
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
||||||
|
value = getattr(instance, self._cache_attr)
|
||||||
|
|
||||||
# handle caching
|
if not _isinstance(value, self._var_type):
|
||||||
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
|
console.deprecate(
|
||||||
# Set cache attr on state instance.
|
"mismatched-computed-var-return",
|
||||||
setattr(instance, self._cache_attr, self.fget(instance))
|
f"Computed var {type(instance).__name__}.{self._js_expr} returned value of type {type(value)}, "
|
||||||
# Ensure the computed var gets serialized to redis.
|
f"expected {self._var_type}. This might cause unexpected behavior.",
|
||||||
instance._was_touched = True
|
"0.6.5",
|
||||||
# Set the last updated timestamp on the state instance.
|
"0.7.0",
|
||||||
setattr(instance, self._last_updated_attr, datetime.datetime.now())
|
)
|
||||||
return getattr(instance, self._cache_attr)
|
|
||||||
|
return value
|
||||||
|
|
||||||
def _deps(
|
def _deps(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user