reflex/reflex/vars/function.py
2024-11-13 19:03:42 -08:00

645 lines
20 KiB
Python

"""Immutable function vars."""
from __future__ import annotations
import dataclasses
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar
from reflex.utils import format
from reflex.utils.exceptions import VarTypeError
from reflex.utils.types import GenericType
from .base import (
CachedVarOperation,
LiteralVar,
ReflexCallable,
TypeComputer,
Var,
VarData,
cached_property_no_lock,
unwrap_reflex_callalbe,
)
# Generic parameters shared by the typed ``partial``/``call`` overloads.
P = ParamSpec("P")  # parameter specification of a wrapped callable
R = TypeVar("R")  # return type of a wrapped callable

# Positional argument types for the fixed-arity overloads (one to six args).
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")

# The callable type carried by a FunctionVar, plus a second one used by
# factory methods that produce a FunctionVar of a (possibly) different type.
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
OTHER_CALLABLE_TYPE = TypeVar(
    "OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
)
class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
    """Base class for immutable function vars.

    Calling an instance (``var(...)``) builds a call expression via ``call``;
    ``partial`` pre-binds leading arguments and returns a new function var.
    """

    @overload
    def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
        arg1: Union[V1, Var[V1]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
        arg6: Union[V6, Var[V6]],
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> FunctionVar[ReflexCallable[P, R]]: ...

    @overload
    def partial(self, *args: Var | Any) -> FunctionVar: ...

    def partial(self, *args: Var | Any) -> FunctionVar:  # type: ignore
        """Partially apply the function with the given arguments.

        Args:
            *args: The arguments to partially apply the function with.

        Returns:
            The partially applied function. When this var's ``__call__`` is
            ``partial`` (the builder pattern), the result is again a builder so
            it keeps accumulating arguments on call.
        """
        if not args:
            return self

        args = tuple(map(LiteralVar.create, args))
        remaining_validators = self._pre_check(*args)
        partial_types, type_computer = self._partial_type(*args)

        # Compare the plain functions on the class. Accessing a method through
        # the instance creates a new bound-method object on every access, so
        # the previous ``self.__call__ is self.partial`` was always False and
        # builders silently degraded into plain ArgsFunctionOperation values.
        is_builder = type(self).__call__ is type(self).partial
        operation_cls = (
            ArgsFunctionOperationBuilder if is_builder else ArgsFunctionOperation
        )
        return operation_cls.create(
            (),
            VarOperationCall.create(
                self,
                *args,
                Var(_js_expr="...args"),
                _var_type=self._return_type(*args),
            ),
            rest="args",
            validators=remaining_validators,
            type_computer=type_computer,
            _var_type=partial_types,
        )

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
    ) -> VarOperationCall[[V1], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
    ) -> VarOperationCall[[V1, V2], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
    ) -> VarOperationCall[[V1, V2, V3], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
    ) -> VarOperationCall[[V1, V2, V3, V4], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
        arg1: Union[V1, Var[V1]],
        arg2: Union[V2, Var[V2]],
        arg3: Union[V3, Var[V3]],
        arg4: Union[V4, Var[V4]],
        arg5: Union[V5, Var[V5]],
        arg6: Union[V6, Var[V6]],
    ) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...

    @overload
    def call(
        self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
    ) -> VarOperationCall[P, R]: ...

    @overload
    def call(self, *args: Var | Any) -> Var: ...

    def call(self, *args: Var | Any) -> Var:  # type: ignore
        """Call the function with the given arguments.

        Args:
            *args: The arguments to call the function with.

        Returns:
            The function call operation.

        Raises:
            VarTypeError: If the number of arguments is invalid
        """
        # Only enforce arity when the callable type declares its parameters.
        arg_len = self._arg_len()
        if arg_len is not None and len(args) != arg_len:
            raise VarTypeError(f"Invalid number of arguments provided to {str(self)}")
        args = tuple(map(LiteralVar.create, args))
        self._pre_check(*args)
        return_type = self._return_type(*args)
        return VarOperationCall.create(self, *args, _var_type=return_type).guess_type()

    def _partial_type(
        self, *args: Var | Any
    ) -> Tuple[GenericType, Optional[TypeComputer]]:
        """Compute the callable type left after applying the given arguments.

        Args:
            *args: The arguments to call the function with.

        Returns:
            A pair of the remaining callable type (the leading parameter types
            consumed by ``args`` are dropped; ``...`` when parameters are not
            declared) and a type computer, which is always None here.
        """
        args_types, return_type = unwrap_reflex_callalbe(self._var_type)
        if isinstance(args_types, tuple):
            return ReflexCallable[[*args_types[len(args) :]], return_type], None  # type: ignore
        return ReflexCallable[..., return_type], None

    def _arg_len(self) -> int | None:
        """Get the number of arguments the function takes.

        Returns:
            The number of declared parameters, or None when the callable type
            does not declare a parameter tuple.
        """
        args_types, _ = unwrap_reflex_callalbe(self._var_type)
        if isinstance(args_types, tuple):
            return len(args_types)
        return None

    def _return_type(self, *args: Var | Any) -> GenericType:
        """Get the return type of calling the function with the given arguments.

        Args:
            *args: The arguments to call the function with.

        Returns:
            The return type taken from ``_partial_type``'s callable type.
        """
        partial_types, _ = self._partial_type(*args)
        return unwrap_reflex_callalbe(partial_types)[1]

    def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
        """Check if the function can be called with the given arguments.

        Args:
            *args: The arguments to call the function with.

        Returns:
            The validators still pending for parameters not yet supplied. The
            base implementation performs no checks and returns an empty tuple.
        """
        return tuple()

    # Calling a plain function var produces a call expression.
    __call__ = call
class BuilderFunctionVar(
    FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
):
    """Function var whose call syntax partially applies instead of calling.

    Invoking an instance (``var(x)``) is routed to ``FunctionVar.partial``,
    so arguments are accumulated rather than producing a call expression.
    """

    # Route ``__call__`` to partial application instead of ``FunctionVar.call``.
    __call__ = FunctionVar.partial
class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
    """Base class for immutable function vars from a string."""

    @classmethod
    def create(
        cls,
        func: str,
        _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
        _var_data: VarData | None = None,
    ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
        """Create a new function var from a string.

        Args:
            func: The JavaScript expression of the function (used as the var's
                ``_js_expr``).
            _var_type: The callable type of the var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var. It is an instance of ``cls``, so subclasses that
            inherit this factory get their own type back.
        """
        return cls(
            _js_expr=func,
            _var_type=_var_type,
            _var_data=_var_data,
        )
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
    """Base class for immutable vars that are the result of a function call."""

    # The function var being invoked.
    _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
    # Positional call arguments; each is rendered through LiteralVar.create.
    _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)

    @cached_property_no_lock
    def _cached_var_name(self) -> str:
        """The name of the var.

        Returns:
            The JavaScript call expression, wrapped in parentheses.
        """
        args_str = ", ".join(str(LiteralVar.create(arg)) for arg in self._args)
        return f"({str(self._func)}({args_str}))"

    @cached_property_no_lock
    def _cached_get_all_var_data(self) -> VarData | None:
        """Get all the var data associated with the var.

        Returns:
            The merged var data of the function, every argument, and this var.
        """
        return VarData.merge(
            self._func._get_all_var_data() if self._func is not None else None,
            *[LiteralVar.create(arg)._get_all_var_data() for arg in self._args],
            self._var_data,
        )

    @classmethod
    def create(
        cls,
        func: FunctionVar[ReflexCallable[P, R]],
        *args: Var | Any,
        _var_type: GenericType = Any,
        _var_data: VarData | None = None,
    ) -> VarOperationCall:
        """Create a new function call var.

        Args:
            func: The function to call.
            *args: The arguments to call the function with.
            _var_type: The type of the call result. When left as Any, it is
                derived from the last type argument of ``func``'s type.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function call var.
        """
        func_type_args = getattr(func._var_type, "__args__", None)
        # The return type is the last type argument: for two-argument callable
        # generics this equals index 1, but index 1 raised IndexError for
        # single-argument forms (e.g. ``Callable[[], T]``) and picked a
        # parameter type for multi-parameter ``typing.Callable`` types.
        function_return_type = func_type_args[-1] if func_type_args else Any
        var_type = _var_type if _var_type is not Any else function_return_type
        return cls(
            _js_expr="",
            _var_type=var_type,
            _var_data=_var_data,
            _func=func,
            _args=args,
        )
@dataclasses.dataclass(frozen=True)
class DestructuredArg:
    """Class for destructured arguments."""

    # Names of the properties pulled out of the destructured object.
    fields: Tuple[str, ...] = tuple()
    # Name bound to the remaining properties (``...rest``), if any.
    rest: Optional[str] = None

    def to_javascript(self) -> str:
        """Convert the destructured argument to JavaScript.

        Returns:
            The destructuring pattern, wrapped in braces via ``format.wrap``
            (e.g. ``{a, b, ...rest}``).
        """
        # Collect the parts before joining so that an empty ``fields`` with a
        # rest name yields ``{...rest}`` instead of the invalid ``{, ...rest}``.
        parts = list(self.fields)
        if self.rest:
            parts.append(f"...{self.rest}")
        return format.wrap(", ".join(parts), "{", "}")
@dataclasses.dataclass(frozen=True)
class FunctionArgs:
    """Parameter list of a function: positional parameters plus an optional rest name."""

    # Positional parameters, each a plain name or a destructuring pattern.
    args: Tuple[Union[str, DestructuredArg], ...] = ()
    # Name of the trailing ``...rest`` parameter, if any.
    rest: Optional[str] = None
def format_args_function_operation(
    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
) -> str:
    """Render an args-based function operation as a JavaScript arrow function.

    Args:
        self: The function operation.

    Returns:
        A string of the form ``((params) => body)``. The body is wrapped in
        curly braces when explicit return syntax is enabled.
    """
    params = [
        name if isinstance(name, str) else name.to_javascript()
        for name in self._args.args
    ]
    if self._args.rest:
        params.append(f"...{self._args.rest}")

    body = str(LiteralVar.create(self._return_expr))
    if self._explicit_return:
        # Explicit return syntax needs the body inside a brace-delimited block.
        body = format.wrap(body, "{", "}")

    return f"(({', '.join(params)}) => {body})"
def pre_check_args(
    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any
) -> Tuple[Callable[[Any], bool], ...]:
    """Validate the supplied arguments against the function's validators.

    Args:
        self: The function operation.
        *args: The arguments to call the function with.

    Returns:
        The validators for the parameters that were not supplied (those at
        positions ``len(args)`` and later), to be applied later.

    Raises:
        VarTypeError: If an argument fails its validator.
    """
    function_name = self._function_name or "var operation"
    for i, (validator, arg) in enumerate(zip(self._validators, args)):
        if validator(arg):
            continue
        if i < len(self._args.args):
            arg_name = self._args.args[i]
            # Show destructuring patterns as JavaScript, not as a dataclass repr.
            if isinstance(arg_name, DestructuredArg):
                arg_name = arg_name.to_javascript()
            raise VarTypeError(
                f"Invalid argument {str(arg)} provided to {arg_name} in {function_name}"
            )
        raise VarTypeError(
            f"Invalid argument {str(arg)} provided to argument {i} in {function_name}"
        )
    return self._validators[len(args) :]
def figure_partial_type(
    self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
    *args: Var | Any,
) -> Tuple[GenericType, Optional[TypeComputer]]:
    """Compute the callable type left after partially applying ``args``.

    Args:
        self: The function operation.
        *args: The arguments to call the function with.

    Returns:
        A ``(type, type_computer)`` pair. It comes from the operation's own type
        computer when one is set, otherwise from ``FunctionVar._partial_type``.
    """
    computer = self._type_computer
    if computer is None:
        return FunctionVar._partial_type(self, *args)
    return computer(*args)
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
    """Base class for immutable function defined via arguments and return expression."""

    # Parameter list rendered in the arrow function's head.
    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
    # Per-position argument validators, consumed by pre_check_args.
    _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
        default_factory=tuple
    )
    # Expression forming the function body.
    _return_expr: Union[Var, Any] = dataclasses.field(default=None)
    # Name reported in validation error messages.
    _function_name: str = dataclasses.field(default="")
    # Optional override used by figure_partial_type.
    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
    # Whether the body is wrapped in braces (explicit return syntax).
    _explicit_return: bool = dataclasses.field(default=False)

    _cached_var_name = cached_property_no_lock(format_args_function_operation)
    _pre_check = pre_check_args  # type: ignore
    _partial_type = figure_partial_type  # type: ignore

    @classmethod
    def create(
        cls,
        args_names: Sequence[Union[str, DestructuredArg]],
        return_expr: Var | Any,
        rest: str | None = None,
        validators: Sequence[Callable[[Any], bool]] = (),
        function_name: str = "",
        explicit_return: bool = False,
        type_computer: Optional[TypeComputer] = None,
        _var_type: GenericType = Callable,
        _var_data: VarData | None = None,
    ):
        """Build a function var from parameter names and a body expression.

        Args:
            args_names: The names of the arguments.
            return_expr: The return expression of the function.
            rest: The name of the rest argument.
            validators: The validators for the arguments.
            function_name: The name of the function.
            explicit_return: Whether to use explicit return syntax.
            type_computer: A function to compute the return type.
            _var_type: The type of the var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var.
        """
        function_args = FunctionArgs(args=tuple(args_names), rest=rest)
        return cls(
            _js_expr="",
            _var_type=_var_type,
            _var_data=_var_data,
            _args=function_args,
            _validators=tuple(validators),
            _return_expr=return_expr,
            _function_name=function_name,
            _type_computer=type_computer,
            _explicit_return=explicit_return,
        )
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperationBuilder(
    CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE]
):
    """Base class for immutable function defined via arguments and return expression with the builder pattern."""

    # Parameter list rendered in the arrow function's head.
    _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
    # Per-position argument validators, consumed by pre_check_args.
    _validators: Tuple[Callable[[Any], bool], ...] = dataclasses.field(
        default_factory=tuple
    )
    # Expression forming the function body.
    _return_expr: Union[Var, Any] = dataclasses.field(default=None)
    # Name reported in validation error messages.
    _function_name: str = dataclasses.field(default="")
    # Optional override used by figure_partial_type.
    _type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
    # Whether the body is wrapped in braces (explicit return syntax).
    _explicit_return: bool = dataclasses.field(default=False)

    _cached_var_name = cached_property_no_lock(format_args_function_operation)
    _pre_check = pre_check_args  # type: ignore
    _partial_type = figure_partial_type  # type: ignore

    @classmethod
    def create(
        cls,
        args_names: Sequence[Union[str, DestructuredArg]],
        return_expr: Var | Any,
        rest: str | None = None,
        validators: Sequence[Callable[[Any], bool]] = (),
        function_name: str = "",
        explicit_return: bool = False,
        type_computer: Optional[TypeComputer] = None,
        _var_type: GenericType = Callable,
        _var_data: VarData | None = None,
    ):
        """Build a builder-style function var from parameter names and a body.

        Args:
            args_names: The names of the arguments.
            return_expr: The return expression of the function.
            rest: The name of the rest argument.
            validators: The validators for the arguments.
            function_name: The name of the function.
            explicit_return: Whether to use explicit return syntax.
            type_computer: A function to compute the return type.
            _var_type: The type of the var.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The function var.
        """
        function_args = FunctionArgs(args=tuple(args_names), rest=rest)
        return cls(
            _js_expr="",
            _var_type=_var_type,
            _var_data=_var_data,
            _args=function_args,
            _validators=tuple(validators),
            _return_expr=return_expr,
            _function_name=function_name,
            _type_computer=type_computer,
            _explicit_return=explicit_return,
        )
# ``python_version`` holds a bool (is the interpreter 3.10 or newer?). The name is
# kept because it was previously bound at module level by a walrus expression.
# NOTE(review): the version split presumably exists because the list-form
# parameter spec (``ReflexCallable[[Any], str]``) is not accepted by the typing
# machinery on older Pythons. Confirm before simplifying.
python_version = sys.version_info[:2] >= (3, 10)
if python_version:
    JSON_STRINGIFY = FunctionStringVar.create(
        "JSON.stringify", _var_type=ReflexCallable[[Any], str]
    )
    ARRAY_ISARRAY = FunctionStringVar.create(
        "Array.isArray", _var_type=ReflexCallable[[Any], bool]
    )
    PROTOTYPE_TO_STRING = FunctionStringVar.create(
        "((__to_string) => __to_string.toString())",
        _var_type=ReflexCallable[[Any], str],
    )
else:
    JSON_STRINGIFY = FunctionStringVar.create(
        "JSON.stringify", _var_type=ReflexCallable[Any, str]
    )
    ARRAY_ISARRAY = FunctionStringVar.create(
        "Array.isArray", _var_type=ReflexCallable[Any, bool]
    )
    PROTOTYPE_TO_STRING = FunctionStringVar.create(
        "((__to_string) => __to_string.toString())",
        _var_type=ReflexCallable[Any, str],
    )