diff --git a/reflex/vars/function.py b/reflex/vars/function.py index c65b38f70..2139071b9 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -109,6 +109,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): Returns: The partially applied function. """ + self._pre_check(*args) if not args: return ArgsFunctionOperation.create((), self) return ArgsFunctionOperation.create( @@ -183,8 +184,20 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): Returns: The function call operation. """ + self._pre_check(*args) return VarOperationCall.create(self, *args).guess_type() + def _pre_check(self, *args: Var | Any) -> bool: + """Check if the function can be called with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + True if the function can be called with the given arguments. + """ + return True + __call__ = call @@ -354,6 +367,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): """Base class for immutable function defined via arguments and return expression.""" _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) + _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field( + default_factory=tuple + ) _return_expr: Union[Var, Any] = dataclasses.field(default=None) _explicit_return: bool = dataclasses.field(default=False) @@ -368,12 +384,27 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): self._args, self._return_expr, self._explicit_return ) + def _pre_check(self, *args: Var | Any) -> bool: + """Check if the function can be called with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + True if the function can be called with the given arguments. + """ + return all( + validator(arg) + for validator, arg in zip(self._validators, args, strict=False) + ) + @classmethod def create( cls, args_names: Sequence[Union[str, DestructuredArg]], return_expr: Var | Any, rest: str | None = None, + validators: Sequence[Callable[[Any], bool]] = (), explicit_return: bool = False, _var_type: GenericType = Callable, _var_data: VarData | None = None, @@ -395,6 +426,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): _var_type=_var_type, _var_data=_var_data, _args=FunctionArgs(args=tuple(args_names), rest=rest), + _validators=tuple(validators), _return_expr=return_expr, _explicit_return=explicit_return, )