fix silly mistakes

This commit is contained in:
Khaleel Al-Adhami 2024-11-13 13:51:47 -08:00
parent 1e9743dcd6
commit ebc81811c0

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overlo
from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
from reflex.utils import format
from reflex.utils.exceptions import VarTypeError
from reflex.utils.types import GenericType
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
@ -109,12 +110,22 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns:
The partially applied function.
"""
self._pre_check(*args)
if not args:
return ArgsFunctionOperation.create((), self)
return self
remaining_validators = self._pre_check(*args)
if self.__call__ is self.partial:
# if the default behavior is partial, we should return a new partial function
return ArgsFunctionOperationBuilder.create(
(),
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
rest="args",
validators=remaining_validators,
)
return ArgsFunctionOperation.create(
("...args",),
(),
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
rest="args",
validators=remaining_validators,
)
@overload
@ -187,7 +198,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
self._pre_check(*args)
return VarOperationCall.create(self, *args).guess_type()
def _pre_check(self, *args: Var | Any) -> bool:
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
@ -196,7 +207,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns:
True if the function can be called with the given arguments.
"""
return True
return tuple()
__call__ = call
@ -346,7 +357,8 @@ def format_args_function_operation(
"""
arg_names_str = ", ".join(
[arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
) + (f", ...{args.rest}" if args.rest else "")
+ ([f"...{args.rest}"] if args.rest else [])
)
return_expr_str = str(LiteralVar.create(return_expr))
@ -371,6 +383,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
default_factory=tuple
)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="")
_explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock
@ -384,7 +397,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
self._args, self._return_expr, self._explicit_return
)
def _pre_check(self, *args: Var | Any) -> bool:
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
@ -393,10 +406,17 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Returns:
True if the function can be called with the given arguments.
"""
return all(
validator(arg)
for validator, arg in zip(self._validators, args, strict=False)
)
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
@classmethod
def create(
@ -405,6 +425,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
return_expr: Var | Any,
rest: str | None = None,
validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "",
explicit_return: bool = False,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
@ -415,6 +436,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
args_names: The names of the arguments.
return_expr: The return expression of the function.
rest: The name of the rest argument.
validators: The validators for the arguments.
function_name: The name of the function.
explicit_return: Whether to use explicit return syntax.
_var_data: Additional hooks and imports associated with the Var.
@ -426,6 +449,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
_var_type=_var_type,
_var_data=_var_data,
_args=FunctionArgs(args=tuple(args_names), rest=rest),
_function_name=function_name,
_validators=tuple(validators),
_return_expr=return_expr,
_explicit_return=explicit_return,
@ -441,7 +465,11 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
_validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
default_factory=tuple
)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="")
_explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock
@ -455,12 +483,35 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
self._args, self._return_expr, self._explicit_return
)
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
@classmethod
def create(
cls,
args_names: Sequence[Union[str, DestructuredArg]],
return_expr: Var | Any,
rest: str | None = None,
validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "",
explicit_return: bool = False,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
@ -471,6 +522,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
args_names: The names of the arguments.
return_expr: The return expression of the function.
rest: The name of the rest argument.
validators: The validators for the arguments.
function_name: The name of the function.
explicit_return: Whether to use explicit return syntax.
_var_data: Additional hooks and imports associated with the Var.
@ -482,6 +535,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
_var_type=_var_type,
_var_data=_var_data,
_args=FunctionArgs(args=tuple(args_names), rest=rest),
_function_name=function_name,
_validators=tuple(validators),
_return_expr=return_expr,
_explicit_return=explicit_return,
)