diff --git a/reflex/vars/function.py b/reflex/vars/function.py index 2139071b9..d719e5ced 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overlo from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar from reflex.utils import format +from reflex.utils.exceptions import VarTypeError from reflex.utils.types import GenericType from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock @@ -109,12 +110,22 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): Returns: The partially applied function. """ - self._pre_check(*args) if not args: - return ArgsFunctionOperation.create((), self) + return self + remaining_validators = self._pre_check(*args) + if self.__call__ is self.partial: + # if the default behavior is partial, we should return a new partial function + return ArgsFunctionOperationBuilder.create( + (), + VarOperationCall.create(self, *args, Var(_js_expr="...args")), + rest="args", + validators=remaining_validators, + ) return ArgsFunctionOperation.create( - ("...args",), + (), VarOperationCall.create(self, *args, Var(_js_expr="...args")), + rest="args", + validators=remaining_validators, ) @overload @@ -187,7 +198,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): self._pre_check(*args) return VarOperationCall.create(self, *args).guess_type() - def _pre_check(self, *args: Var | Any) -> bool: + def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: """Check if the function can be called with the given arguments. Args: @@ -196,7 +207,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): Returns: True if the function can be called with the given arguments. """ - return True + return tuple() __call__ = call @@ -346,7 +357,8 @@ def format_args_function_operation( """ arg_names_str = ", ".join( [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args] - ) + (f", ...{args.rest}" if args.rest else "") + + ([f"...{args.rest}"] if args.rest else []) + ) return_expr_str = str(LiteralVar.create(return_expr)) @@ -371,6 +383,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): default_factory=tuple ) _return_expr: Union[Var, Any] = dataclasses.field(default=None) + _function_name: str = dataclasses.field(default="") _explicit_return: bool = dataclasses.field(default=False) @cached_property_no_lock @@ -384,7 +397,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): self._args, self._return_expr, self._explicit_return ) - def _pre_check(self, *args: Var | Any) -> bool: + def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: """Check if the function can be called with the given arguments. Args: @@ -393,10 +406,17 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): Returns: True if the function can be called with the given arguments. """ - return all( - validator(arg) - for validator, arg in zip(self._validators, args, strict=False) - ) + for i, (validator, arg) in enumerate(zip(self._validators, args)): + if not validator(arg): + arg_name = self._args.args[i] if i < len(self._args.args) else None + if arg_name is not None: + raise VarTypeError( + f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}" + ) + raise VarTypeError( + f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}" + ) + return self._validators[len(args) :] @classmethod def create( @@ -405,6 +425,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): return_expr: Var | Any, rest: str | None = None, validators: Sequence[Callable[[Any], bool]] = (), + function_name: str = "", explicit_return: bool = False, _var_type: GenericType = Callable, _var_data: VarData | None = None, @@ -415,6 +436,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): args_names: The names of the arguments. return_expr: The return expression of the function. rest: The name of the rest argument. + validators: The validators for the arguments. + function_name: The name of the function. explicit_return: Whether to use explicit return syntax. _var_data: Additional hooks and imports associated with the Var. @@ -426,6 +449,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): _var_type=_var_type, _var_data=_var_data, _args=FunctionArgs(args=tuple(args_names), rest=rest), + _function_name=function_name, _validators=tuple(validators), _return_expr=return_expr, _explicit_return=explicit_return, @@ -441,7 +465,11 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): """Base class for immutable function defined via arguments and return expression with the builder pattern.""" _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) + _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field( + default_factory=tuple + ) _return_expr: Union[Var, Any] = dataclasses.field(default=None) + _function_name: str = dataclasses.field(default="") _explicit_return: bool = dataclasses.field(default=False) @cached_property_no_lock @@ -455,12 +483,35 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): self._args, self._return_expr, self._explicit_return ) + def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: + """Check if the function can be called with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + True if the function can be called with the given arguments. + """ + for i, (validator, arg) in enumerate(zip(self._validators, args)): + if not validator(arg): + arg_name = self._args.args[i] if i < len(self._args.args) else None + if arg_name is not None: + raise VarTypeError( + f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}" + ) + raise VarTypeError( + f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}" + ) + return self._validators[len(args) :] + @classmethod def create( cls, args_names: Sequence[Union[str, DestructuredArg]], return_expr: Var | Any, rest: str | None = None, + validators: Sequence[Callable[[Any], bool]] = (), + function_name: str = "", explicit_return: bool = False, _var_type: GenericType = Callable, _var_data: VarData | None = None, @@ -471,6 +522,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): args_names: The names of the arguments. return_expr: The return expression of the function. rest: The name of the rest argument. + validators: The validators for the arguments. + function_name: The name of the function. explicit_return: Whether to use explicit return syntax. _var_data: Additional hooks and imports associated with the Var. @@ -482,6 +535,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): _var_type=_var_type, _var_data=_var_data, _args=FunctionArgs(args=tuple(args_names), rest=rest), + _function_name=function_name, + _validators=tuple(validators), _return_expr=return_expr, _explicit_return=explicit_return, )