fix silly mistakes

This commit is contained in:
Khaleel Al-Adhami 2024-11-13 13:51:47 -08:00
parent 1e9743dcd6
commit ebc81811c0

View File

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overlo
from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
from reflex.utils import format from reflex.utils import format
from reflex.utils.exceptions import VarTypeError
from reflex.utils.types import GenericType from reflex.utils.types import GenericType
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
@ -109,12 +110,22 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns: Returns:
The partially applied function. The partially applied function.
""" """
self._pre_check(*args)
if not args: if not args:
return ArgsFunctionOperation.create((), self) return self
remaining_validators = self._pre_check(*args)
if self.__call__ is self.partial:
# if the default behavior is partial, we should return a new partial function
return ArgsFunctionOperationBuilder.create(
(),
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
rest="args",
validators=remaining_validators,
)
return ArgsFunctionOperation.create( return ArgsFunctionOperation.create(
("...args",), (),
VarOperationCall.create(self, *args, Var(_js_expr="...args")), VarOperationCall.create(self, *args, Var(_js_expr="...args")),
rest="args",
validators=remaining_validators,
) )
@overload @overload
@ -187,7 +198,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
self._pre_check(*args) self._pre_check(*args)
return VarOperationCall.create(self, *args).guess_type() return VarOperationCall.create(self, *args).guess_type()
def _pre_check(self, *args: Var | Any) -> bool: def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments. """Check if the function can be called with the given arguments.
Args: Args:
@ -196,7 +207,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns: Returns:
True if the function can be called with the given arguments. True if the function can be called with the given arguments.
""" """
return True return tuple()
__call__ = call __call__ = call
@ -346,7 +357,8 @@ def format_args_function_operation(
""" """
arg_names_str = ", ".join( arg_names_str = ", ".join(
[arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args] [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
) + (f", ...{args.rest}" if args.rest else "") + ([f"...{args.rest}"] if args.rest else [])
)
return_expr_str = str(LiteralVar.create(return_expr)) return_expr_str = str(LiteralVar.create(return_expr))
@ -371,6 +383,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
default_factory=tuple default_factory=tuple
) )
_return_expr: Union[Var, Any] = dataclasses.field(default=None) _return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="")
_explicit_return: bool = dataclasses.field(default=False) _explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock @cached_property_no_lock
@ -384,7 +397,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
self._args, self._return_expr, self._explicit_return self._args, self._return_expr, self._explicit_return
) )
def _pre_check(self, *args: Var | Any) -> bool: def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments. """Check if the function can be called with the given arguments.
Args: Args:
@ -393,10 +406,17 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Returns: Returns:
True if the function can be called with the given arguments. True if the function can be called with the given arguments.
""" """
return all( for i, (validator, arg) in enumerate(zip(self._validators, args)):
validator(arg) if not validator(arg):
for validator, arg in zip(self._validators, args, strict=False) arg_name = self._args.args[i] if i < len(self._args.args) else None
) if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
@classmethod @classmethod
def create( def create(
@ -405,6 +425,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
return_expr: Var | Any, return_expr: Var | Any,
rest: str | None = None, rest: str | None = None,
validators: Sequence[Callable[[Any], bool]] = (), validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "",
explicit_return: bool = False, explicit_return: bool = False,
_var_type: GenericType = Callable, _var_type: GenericType = Callable,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -415,6 +436,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
args_names: The names of the arguments. args_names: The names of the arguments.
return_expr: The return expression of the function. return_expr: The return expression of the function.
rest: The name of the rest argument. rest: The name of the rest argument.
validators: The validators for the arguments.
function_name: The name of the function.
explicit_return: Whether to use explicit return syntax. explicit_return: Whether to use explicit return syntax.
_var_data: Additional hooks and imports associated with the Var. _var_data: Additional hooks and imports associated with the Var.
@ -426,6 +449,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
_var_type=_var_type, _var_type=_var_type,
_var_data=_var_data, _var_data=_var_data,
_args=FunctionArgs(args=tuple(args_names), rest=rest), _args=FunctionArgs(args=tuple(args_names), rest=rest),
_function_name=function_name,
_validators=tuple(validators), _validators=tuple(validators),
_return_expr=return_expr, _return_expr=return_expr,
_explicit_return=explicit_return, _explicit_return=explicit_return,
@ -441,7 +465,11 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
"""Base class for immutable function defined via arguments and return expression with the builder pattern.""" """Base class for immutable function defined via arguments and return expression with the builder pattern."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
_validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
default_factory=tuple
)
_return_expr: Union[Var, Any] = dataclasses.field(default=None) _return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="")
_explicit_return: bool = dataclasses.field(default=False) _explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock @cached_property_no_lock
@ -455,12 +483,35 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
self._args, self._return_expr, self._explicit_return self._args, self._return_expr, self._explicit_return
) )
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
@classmethod @classmethod
def create( def create(
cls, cls,
args_names: Sequence[Union[str, DestructuredArg]], args_names: Sequence[Union[str, DestructuredArg]],
return_expr: Var | Any, return_expr: Var | Any,
rest: str | None = None, rest: str | None = None,
validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "",
explicit_return: bool = False, explicit_return: bool = False,
_var_type: GenericType = Callable, _var_type: GenericType = Callable,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -471,6 +522,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
args_names: The names of the arguments. args_names: The names of the arguments.
return_expr: The return expression of the function. return_expr: The return expression of the function.
rest: The name of the rest argument. rest: The name of the rest argument.
validators: The validators for the arguments.
function_name: The name of the function.
explicit_return: Whether to use explicit return syntax. explicit_return: Whether to use explicit return syntax.
_var_data: Additional hooks and imports associated with the Var. _var_data: Additional hooks and imports associated with the Var.
@ -482,6 +535,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
_var_type=_var_type, _var_type=_var_type,
_var_data=_var_data, _var_data=_var_data,
_args=FunctionArgs(args=tuple(args_names), rest=rest), _args=FunctionArgs(args=tuple(args_names), rest=rest),
_function_name=function_name,
_validators=tuple(validators),
_return_expr=return_expr, _return_expr=return_expr,
_explicit_return=explicit_return, _explicit_return=explicit_return,
) )