add validation

This commit is contained in:
Khaleel Al-Adhami 2024-11-13 13:22:01 -08:00
parent f9b24fe5bd
commit 05bd41c040

View File

@ -109,6 +109,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns: Returns:
The partially applied function. The partially applied function.
""" """
self._pre_check(*args)
if not args: if not args:
return ArgsFunctionOperation.create((), self) return ArgsFunctionOperation.create((), self)
return ArgsFunctionOperation.create( return ArgsFunctionOperation.create(
@ -183,8 +184,20 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns: Returns:
The function call operation. The function call operation.
""" """
self._pre_check(*args)
return VarOperationCall.create(self, *args).guess_type() return VarOperationCall.create(self, *args).guess_type()
def _pre_check(self, *args: Var | Any) -> bool:
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
return True
__call__ = call __call__ = call
@ -354,6 +367,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
"""Base class for immutable function defined via arguments and return expression.""" """Base class for immutable function defined via arguments and return expression."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
_validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
default_factory=tuple
)
_return_expr: Union[Var, Any] = dataclasses.field(default=None) _return_expr: Union[Var, Any] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False) _explicit_return: bool = dataclasses.field(default=False)
@ -368,12 +384,27 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
self._args, self._return_expr, self._explicit_return self._args, self._return_expr, self._explicit_return
) )
def _pre_check(self, *args: Var | Any) -> bool:
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
return all(
validator(arg)
for validator, arg in zip(self._validators, args, strict=False)
)
@classmethod @classmethod
def create( def create(
cls, cls,
args_names: Sequence[Union[str, DestructuredArg]], args_names: Sequence[Union[str, DestructuredArg]],
return_expr: Var | Any, return_expr: Var | Any,
rest: str | None = None, rest: str | None = None,
validators: Sequence[Callable[[Any], bool]] = (),
explicit_return: bool = False, explicit_return: bool = False,
_var_type: GenericType = Callable, _var_type: GenericType = Callable,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -395,6 +426,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
_var_type=_var_type, _var_type=_var_type,
_var_data=_var_data, _var_data=_var_data,
_args=FunctionArgs(args=tuple(args_names), rest=rest), _args=FunctionArgs(args=tuple(args_names), rest=rest),
_validators=tuple(validators),
_return_expr=return_expr, _return_expr=return_expr,
_explicit_return=explicit_return, _explicit_return=explicit_return,
) )