implement type computers

This commit is contained in:
Khaleel Al-Adhami 2024-11-13 18:17:53 -08:00
parent ebc81811c0
commit f4aa1f58c3
6 changed files with 699 additions and 317 deletions

View File

@ -14,7 +14,7 @@ import re
import string
import sys
import warnings
from types import CodeType, FunctionType
from types import CodeType, EllipsisType, FunctionType
from typing import (
TYPE_CHECKING,
Any,
@ -26,7 +26,6 @@ from typing import (
Iterable,
List,
Literal,
NoReturn,
Optional,
Set,
Tuple,
@ -38,7 +37,14 @@ from typing import (
overload,
)
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override
from typing_extensions import (
ParamSpec,
Protocol,
TypeGuard,
deprecated,
get_type_hints,
override,
)
from reflex import constants
from reflex.base import Base
@ -69,6 +75,7 @@ from reflex.utils.types import (
if TYPE_CHECKING:
from reflex.state import BaseState
from .function import ArgsFunctionOperation, ReflexCallable
from .number import BooleanVar, NumberVar
from .object import ObjectVar
from .sequence import ArrayVar, StringVar
@ -79,6 +86,36 @@ OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
P = ParamSpec("P")
R = TypeVar("R")
class ReflexCallable(Protocol[P, R]):
"""Protocol for a callable."""
__call__: Callable[P, R]
def unwrap_reflex_callalbe(
callable_type: GenericType,
) -> Tuple[Union[EllipsisType, Tuple[GenericType, ...]], GenericType]:
"""Unwrap the ReflexCallable type.
Args:
callable_type: The ReflexCallable type to unwrap.
Returns:
The unwrapped ReflexCallable type.
"""
if callable_type is ReflexCallable:
return Ellipsis, Any
if get_origin(callable_type) is not ReflexCallable:
return Ellipsis, Any
args = get_args(callable_type)
if not args or len(args) != 2:
return Ellipsis, Any
return args
@dataclasses.dataclass(
eq=False,
@ -409,9 +446,11 @@ class Var(Generic[VAR_TYPE]):
if _var_data or _js_expr != self._js_expr:
self.__init__(
_js_expr=_js_expr,
_var_type=self._var_type,
_var_data=VarData.merge(self._var_data, _var_data),
**{
**dataclasses.asdict(self),
"_js_expr": _js_expr,
"_var_data": VarData.merge(self._var_data, _var_data),
}
)
def __hash__(self) -> int:
@ -690,6 +729,12 @@ class Var(Generic[VAR_TYPE]):
@overload
def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
@overload
def guess_type(self: Var[list] | Var[tuple] | Var[set]) -> ArrayVar: ...
@overload
def guess_type(self: Var[dict]) -> ObjectVar[dict]: ...
@overload
def guess_type(self) -> Self: ...
@ -1413,71 +1458,94 @@ def get_python_literal(value: Union[LiteralVar, Any]) -> Any | None:
return value
def validate_arg(type_hint: GenericType) -> Callable[[Any], bool]:
"""Create a validator for an argument.
Args:
type_hint: The type hint of the argument.
Returns:
The validator.
"""
def validate(value: Any):
return True
return validate
P = ParamSpec("P")
T = TypeVar("T")
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
# NoReturn is used to match CustomVarOperationReturn with no type hint.
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[NoReturn]],
) -> Callable[P, Var]: ...
class TypeComputer(Protocol):
"""A protocol for type computers."""
def __call__(self, *args: Var) -> Tuple[GenericType, Union[TypeComputer, None]]:
"""Compute the type of the operation.
Args:
*args: The arguments to compute the type of.
Returns:
The type of the operation.
"""
...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[bool]],
) -> Callable[P, BooleanVar]: ...
NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float])
func: Callable[[], CustomVarOperationReturn[T]],
) -> ArgsFunctionOperation[ReflexCallable[[], T]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
) -> Callable[P, NumberVar[NUMBER_T]]: ...
func: Callable[[Var[V1]], CustomVarOperationReturn[T]],
) -> ArgsFunctionOperation[ReflexCallable[[V1], T]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[str]],
) -> Callable[P, StringVar]: ...
LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
func: Callable[[Var[V1], Var[V2]], CustomVarOperationReturn[T]],
) -> ArgsFunctionOperation[ReflexCallable[[V1, V2], T]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[LIST_T]],
) -> Callable[P, ArrayVar[LIST_T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
func: Callable[[Var[V1], Var[V2], Var[V3]], CustomVarOperationReturn[T]],
) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3], T]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
func: Callable[[Var[V1], Var[V2], Var[V3], Var[V4]], CustomVarOperationReturn[T]],
) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4], T]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[T]],
) -> Callable[P, Var[T]]: ...
func: Callable[
[Var[V1], Var[V2], Var[V3], Var[V4], Var[V5]],
CustomVarOperationReturn[T],
],
) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4, V5], T]]: ...
def var_operation(
func: Callable[P, CustomVarOperationReturn[T]],
) -> Callable[P, Var[T]]:
func: Callable[..., CustomVarOperationReturn[T]],
) -> ArgsFunctionOperation:
"""Decorator for creating a var operation.
Example:
```python
@var_operation
def add(a: NumberVar, b: NumberVar):
def add(a: Var[int], b: Var[int]):
return custom_var_operation(f"{a} + {b}")
```
@ -1487,26 +1555,61 @@ def var_operation(
Returns:
The decorated function.
"""
from .function import ArgsFunctionOperation, ReflexCallable
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]:
func_args = list(inspect.signature(func).parameters)
args_vars = {
func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg)
for i, arg in enumerate(args)
}
kwargs_vars = {
key: LiteralVar.create(value) if not isinstance(value, Var) else value
for key, value in kwargs.items()
}
func_name = func.__name__
return CustomVarOperation.create(
name=func.__name__,
args=tuple(list(args_vars.items()) + list(kwargs_vars.items())),
return_var=func(*args_vars.values(), **kwargs_vars), # type: ignore
).guess_type()
func_arg_spec = inspect.getfullargspec(func)
return wrapper
if func_arg_spec.kwonlyargs:
raise TypeError(f"Function {func_name} cannot have keyword-only arguments.")
if func_arg_spec.varargs:
raise TypeError(f"Function {func_name} cannot have variable arguments.")
arg_names = func_arg_spec.args
type_hints = get_type_hints(func)
if not all(
(get_origin((type_hint := type_hints.get(arg_name, Any))) or type_hint) is Var
and len(get_args(type_hint)) <= 1
for arg_name in arg_names
):
raise TypeError(
f"Function {func_name} must have type hints of the form `Var[Type]`."
)
args_with_type_hints = tuple(
(arg_name, (args[0] if (args := get_args(type_hints[arg_name])) else Any))
for arg_name in arg_names
)
arg_vars = tuple(
(
Var("_" + arg_name, _var_type=arg_python_type)
if not isinstance(arg_python_type, TypeVar)
else Var("_" + arg_name)
)
for arg_name, arg_python_type in args_with_type_hints
)
custom_operation_return = func(*arg_vars)
args_operation = ArgsFunctionOperation.create(
tuple(map(str, arg_vars)),
custom_operation_return,
validators=tuple(
validate_arg(type_hints.get(arg_name, Any)) for arg_name in arg_names
),
function_name=func_name,
type_computer=custom_operation_return._type_computer,
_var_type=ReflexCallable[
tuple(arg_python_type for _, arg_python_type in args_with_type_hints),
custom_operation_return._var_type,
],
)
return args_operation
def figure_out_type(value: Any) -> types.GenericType:
@ -1621,66 +1724,6 @@ class CachedVarOperation:
)
def and_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return _and_operation(a, b) # type: ignore
@var_operation
def _and_operation(a: Var, b: Var):
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return var_operation_return(
js_expression=f"({a} && {b})",
var_type=unionize(a._var_type, b._var_type),
)
def or_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return _or_operation(a, b) # type: ignore
@var_operation
def _or_operation(a: Var, b: Var):
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return var_operation_return(
js_expression=f"({a} || {b})",
var_type=unionize(a._var_type, b._var_type),
)
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -2289,14 +2332,22 @@ def computed_var(
RETURN = TypeVar("RETURN")
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class CustomVarOperationReturn(Var[RETURN]):
"""Base class for custom var operations."""
_type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
@classmethod
def create(
cls,
js_expression: str,
_var_type: Type[RETURN] | None = None,
_type_computer: Optional[TypeComputer] = None,
_var_data: VarData | None = None,
) -> CustomVarOperationReturn[RETURN]:
"""Create a CustomVarOperation.
@ -2304,6 +2355,7 @@ class CustomVarOperationReturn(Var[RETURN]):
Args:
js_expression: The JavaScript expression to evaluate.
_var_type: The type of the var.
_type_computer: A function to compute the type of the var given the arguments.
_var_data: Additional hooks and imports associated with the Var.
Returns:
@ -2312,6 +2364,7 @@ class CustomVarOperationReturn(Var[RETURN]):
return CustomVarOperationReturn(
_js_expr=js_expression,
_var_type=_var_type or Any,
_type_computer=_type_computer,
_var_data=_var_data,
)
@ -2319,6 +2372,7 @@ class CustomVarOperationReturn(Var[RETURN]):
def var_operation_return(
js_expression: str,
var_type: Type[RETURN] | None = None,
type_computer: Optional[TypeComputer] = None,
var_data: VarData | None = None,
) -> CustomVarOperationReturn[RETURN]:
"""Shortcut for creating a CustomVarOperationReturn.
@ -2326,15 +2380,17 @@ def var_operation_return(
Args:
js_expression: The JavaScript expression to evaluate.
var_type: The type of the var.
type_computer: A function to compute the type of the var given the arguments.
var_data: Additional hooks and imports associated with the Var.
Returns:
The CustomVarOperationReturn.
"""
return CustomVarOperationReturn.create(
js_expression,
var_type,
var_data,
js_expression=js_expression,
_var_type=var_type,
_type_computer=type_computer,
_var_data=var_data,
)
@ -2942,3 +2998,157 @@ def field(value: T) -> Field[T]:
The Field.
"""
return value # type: ignore
def and_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return _and_operation(a, b) # type: ignore
@var_operation
def _and_operation(a: Var, b: Var):
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
def type_computer(*args: Var):
if not args:
return (ReflexCallable[[Any, Any], Any], type_computer)
if len(args) == 1:
return (
ReflexCallable[[Any], Any],
functools.partial(type_computer, args[0]),
)
return (
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
None,
)
return var_operation_return(
js_expression=f"({a} && {b})",
type_computer=type_computer,
)
def or_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return _or_operation(a, b) # type: ignore
@var_operation
def _or_operation(a: Var, b: Var):
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
def type_computer(*args: Var):
if not args:
return (ReflexCallable[[Any, Any], Any], type_computer)
if len(args) == 1:
return (
ReflexCallable[[Any], Any],
functools.partial(type_computer, args[0]),
)
return (
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
None,
)
return var_operation_return(
js_expression=f"({a} || {b})",
type_computer=type_computer,
)
def passthrough_unary_type_computer(no_args: GenericType) -> TypeComputer:
"""Create a type computer for unary operations.
Args:
no_args: The type to return when no arguments are provided.
Returns:
The type computer.
"""
def type_computer(*args: Var):
if not args:
return (no_args, type_computer)
return (ReflexCallable[[], args[0]._var_type], None)
return type_computer
def unary_type_computer(
no_args: GenericType, computer: Callable[[Var], GenericType]
) -> TypeComputer:
"""Create a type computer for unary operations.
Args:
no_args: The type to return when no arguments are provided.
computer: The function to compute the type.
Returns:
The type computer.
"""
def type_computer(*args: Var):
if not args:
return (no_args, type_computer)
return (ReflexCallable[[], computer(args[0])], None)
return type_computer
def nary_type_computer(
*types: GenericType, computer: Callable[..., GenericType]
) -> TypeComputer:
"""Create a type computer for n-ary operations.
Args:
types: The types to return when no arguments are provided.
computer: The function to compute the type.
Returns:
The type computer.
"""
def type_computer(*args: Var):
if len(args) != len(types):
return (
ReflexCallable[[], types[len(args)]],
functools.partial(type_computer, *args),
)
return (
ReflexCallable[[], computer(args)],
None,
)
return type_computer

View File

@ -6,28 +6,31 @@ import dataclasses
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar
from reflex.utils import format
from reflex.utils.exceptions import VarTypeError
from reflex.utils.types import GenericType
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
from .base import (
CachedVarOperation,
LiteralVar,
ReflexCallable,
TypeComputer,
Var,
VarData,
cached_property_no_lock,
unwrap_reflex_callalbe,
)
P = ParamSpec("P")
R = TypeVar("R")
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")
R = TypeVar("R")
class ReflexCallable(Protocol[P, R]):
"""Protocol for a callable."""
__call__: Callable[P, R]
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
@ -112,20 +115,37 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
"""
if not args:
return self
args = tuple(map(LiteralVar.create, args))
remaining_validators = self._pre_check(*args)
partial_types, type_computer = self._partial_type(*args)
if self.__call__ is self.partial:
# if the default behavior is partial, we should return a new partial function
return ArgsFunctionOperationBuilder.create(
(),
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
VarOperationCall.create(
self,
*args,
Var(_js_expr="...args"),
_var_type=self._return_type(*args),
),
rest="args",
validators=remaining_validators,
type_computer=type_computer,
_var_type=partial_types,
)
return ArgsFunctionOperation.create(
(),
VarOperationCall.create(self, *args, Var(_js_expr="...args")),
VarOperationCall.create(
self, *args, Var(_js_expr="...args"), _var_type=self._return_type(*args)
),
rest="args",
validators=remaining_validators,
type_computer=type_computer,
_var_type=partial_types,
)
@overload
@ -194,9 +214,56 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns:
The function call operation.
Raises:
VarTypeError: If the number of arguments is invalid
"""
arg_len = self._arg_len()
if arg_len is not None and len(args) != arg_len:
raise VarTypeError(f"Invalid number of arguments provided to {str(self)}")
args = tuple(map(LiteralVar.create, args))
self._pre_check(*args)
return VarOperationCall.create(self, *args).guess_type()
return_type = self._return_type(*args)
return VarOperationCall.create(self, *args, _var_type=return_type).guess_type()
def _partial_type(
self, *args: Var | Any
) -> Tuple[GenericType, Optional[TypeComputer]]:
"""Override the type of the function call with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The overridden type of the function call.
"""
args_types, return_type = unwrap_reflex_callalbe(self._var_type)
if isinstance(args_types, tuple):
return ReflexCallable[[*args_types[len(args) :]], return_type], None
return ReflexCallable[..., return_type], None
def _arg_len(self) -> int | None:
"""Get the number of arguments the function takes.
Returns:
The number of arguments the function takes.
"""
args_types, _ = unwrap_reflex_callalbe(self._var_type)
if isinstance(args_types, tuple):
return len(args_types)
return None
def _return_type(self, *args: Var | Any) -> GenericType:
"""Override the type of the function call with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The overridden type of the function call.
"""
partial_types, _ = self._partial_type(*args)
return unwrap_reflex_callalbe(partial_types)[1]
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
@ -343,11 +410,12 @@ class FunctionArgs:
def format_args_function_operation(
args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
) -> str:
"""Format an args function operation.
Args:
self: The function operation.
args: The function arguments.
return_expr: The return expression.
explicit_return: Whether to use explicit return syntax.
@ -356,26 +424,76 @@ def format_args_function_operation(
The formatted args function operation.
"""
arg_names_str = ", ".join(
[arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
+ ([f"...{args.rest}"] if args.rest else [])
[
arg if isinstance(arg, str) else arg.to_javascript()
for arg in self._args.args
]
+ ([f"...{self._args.rest}"] if self._args.rest else [])
)
return_expr_str = str(LiteralVar.create(return_expr))
return_expr_str = str(LiteralVar.create(self._return_expr))
# Wrap return expression in curly braces if explicit return syntax is used.
return_expr_str_wrapped = (
format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
format.wrap(return_expr_str, "{", "}")
if self._explicit_return
else return_expr_str
)
return f"(({arg_names_str}) => {return_expr_str_wrapped})"
def pre_check_args(
self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any
) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
self: The function operation.
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
def figure_partial_type(
self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
*args: Var | Any,
) -> Tuple[GenericType, Optional[TypeComputer]]:
"""Figure out the return type of the function.
Args:
self: The function operation.
*args: The arguments to call the function with.
Returns:
The return type of the function.
"""
return (
self._type_computer(*args)
if self._type_computer is not None
else FunctionVar._partial_type(self, *args)
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
"""Base class for immutable function defined via arguments and return expression."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@ -384,39 +502,14 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="")
_type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
_cached_var_name = cached_property_no_lock(format_args_function_operation)
Returns:
The name of the var.
"""
return format_args_function_operation(
self._args, self._return_expr, self._explicit_return
)
_pre_check = pre_check_args
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
_partial_type = figure_partial_type
@classmethod
def create(
@ -427,6 +520,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "",
explicit_return: bool = False,
type_computer: Optional[TypeComputer] = None,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
):
@ -439,6 +533,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
validators: The validators for the arguments.
function_name: The name of the function.
explicit_return: Whether to use explicit return syntax.
type_computer: A function to compute the return type.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
@ -453,6 +549,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
_validators=tuple(validators),
_return_expr=return_expr,
_explicit_return=explicit_return,
_type_computer=type_computer,
)
@ -461,7 +558,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
class ArgsFunctionOperationBuilder(
CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE]
):
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@ -470,39 +569,14 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="")
_type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
_cached_var_name = cached_property_no_lock(format_args_function_operation)
Returns:
The name of the var.
"""
return format_args_function_operation(
self._args, self._return_expr, self._explicit_return
)
_pre_check = pre_check_args
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
_partial_type = figure_partial_type
@classmethod
def create(
@ -513,6 +587,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "",
explicit_return: bool = False,
type_computer: Optional[TypeComputer] = None,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
):
@ -525,6 +600,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
validators: The validators for the arguments.
function_name: The name of the function.
explicit_return: Whether to use explicit return syntax.
type_computer: A function to compute the return type.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
@ -539,6 +616,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
_validators=tuple(validators),
_return_expr=return_expr,
_explicit_return=explicit_return,
_type_computer=type_computer,
)

View File

@ -3,19 +3,11 @@
from __future__ import annotations
import dataclasses
import functools
import json
import math
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
NoReturn,
Type,
TypeVar,
Union,
overload,
)
from typing import TYPE_CHECKING, Any, Callable, NoReturn, TypeVar, Union, overload
from reflex.constants.base import Dirs
from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
@ -25,8 +17,11 @@ from reflex.utils.types import is_optional
from .base import (
CustomVarOperationReturn,
LiteralVar,
ReflexCallable,
Var,
VarData,
nary_type_computer,
passthrough_unary_type_computer,
unionize,
var_operation,
var_operation_return,
@ -544,8 +539,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
def binary_number_operation(
func: Callable[[NumberVar, NumberVar], str],
) -> Callable[[number_types, number_types], NumberVar]:
func: Callable[[Var[int | float], Var[int | float]], str],
):
"""Decorator to create a binary number operation.
Args:
@ -555,30 +550,37 @@ def binary_number_operation(
The binary number operation.
"""
@var_operation
def operation(lhs: NumberVar, rhs: NumberVar):
def operation(
lhs: Var[int | float], rhs: Var[int | float]
) -> CustomVarOperationReturn[int | float]:
def type_computer(*args: Var):
if not args:
return (
ReflexCallable[[int | float, int | float], int | float],
type_computer,
)
if len(args) == 1:
return (
ReflexCallable[[int | float], int | float],
functools.partial(type_computer, args[0]),
)
return (
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
None,
)
return var_operation_return(
js_expression=func(lhs, rhs),
var_type=unionize(lhs._var_type, rhs._var_type),
type_computer=type_computer,
)
def wrapper(lhs: number_types, rhs: number_types) -> NumberVar:
"""Create the binary number operation.
operation.__name__ = func.__name__
Args:
lhs: The first number.
rhs: The second number.
Returns:
The binary number operation.
"""
return operation(lhs, rhs) # type: ignore
return wrapper
return var_operation(operation)
@binary_number_operation
def number_add_operation(lhs: NumberVar, rhs: NumberVar):
def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Add two numbers.
Args:
@ -592,7 +594,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Subtract two numbers.
Args:
@ -605,8 +607,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
return f"({lhs} - {rhs})"
unary_operation_type_computer = passthrough_unary_type_computer(
ReflexCallable[[int | float], int | float]
)
@var_operation
def number_abs_operation(value: NumberVar):
def number_abs_operation(
value: Var[int | float],
) -> CustomVarOperationReturn[int | float]:
"""Get the absolute value of the number.
Args:
@ -616,12 +625,12 @@ def number_abs_operation(value: NumberVar):
The number absolute operation.
"""
return var_operation_return(
js_expression=f"Math.abs({value})", var_type=value._var_type
js_expression=f"Math.abs({value})", type_computer=unary_operation_type_computer
)
@binary_number_operation
def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Multiply two numbers.
Args:
@ -636,7 +645,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
@var_operation
def number_negate_operation(
value: NumberVar[NUMBER_T],
value: Var[NUMBER_T],
) -> CustomVarOperationReturn[NUMBER_T]:
"""Negate the number.
@ -646,11 +655,13 @@ def number_negate_operation(
Returns:
The number negation operation.
"""
return var_operation_return(js_expression=f"-({value})", var_type=value._var_type)
return var_operation_return(
js_expression=f"-({value})", type_computer=unary_operation_type_computer
)
@binary_number_operation
def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Divide two numbers.
Args:
@ -664,7 +675,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Floor divide two numbers.
Args:
@ -678,7 +689,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Modulo two numbers.
Args:
@ -692,7 +703,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation
def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Exponentiate two numbers.
Args:
@ -706,7 +717,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
@var_operation
def number_round_operation(value: NumberVar):
def number_round_operation(value: Var[int | float]):
"""Round the number.
Args:
@ -719,7 +730,7 @@ def number_round_operation(value: NumberVar):
@var_operation
def number_ceil_operation(value: NumberVar):
def number_ceil_operation(value: Var[int | float]):
"""Ceil the number.
Args:
@ -732,7 +743,7 @@ def number_ceil_operation(value: NumberVar):
@var_operation
def number_floor_operation(value: NumberVar):
def number_floor_operation(value: Var[int | float]):
"""Floor the number.
Args:
@ -745,7 +756,7 @@ def number_floor_operation(value: NumberVar):
@var_operation
def number_trunc_operation(value: NumberVar):
def number_trunc_operation(value: Var[int | float]):
"""Trunc the number.
Args:
@ -838,7 +849,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
@var_operation
def boolean_to_number_operation(value: BooleanVar):
def boolean_to_number_operation(value: Var[bool]):
"""Convert the boolean to a number.
Args:
@ -969,7 +980,7 @@ def not_equal_operation(lhs: Var, rhs: Var):
@var_operation
def boolean_not_operation(value: BooleanVar):
def boolean_not_operation(value: Var[bool]):
"""Boolean NOT the boolean.
Args:
@ -1117,7 +1128,7 @@ U = TypeVar("U")
@var_operation
def ternary_operation(
condition: BooleanVar, if_true: Var[T], if_false: Var[U]
condition: Var[bool], if_true: Var[T], if_false: Var[U]
) -> CustomVarOperationReturn[Union[T, U]]:
"""Create a ternary operation.
@ -1129,12 +1140,14 @@ def ternary_operation(
Returns:
The ternary operation.
"""
type_value: Union[Type[T], Type[U]] = unionize(
if_true._var_type, if_false._var_type
)
value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
js_expression=f"({condition} ? {if_true} : {if_false})",
var_type=type_value,
type_computer=nary_type_computer(
ReflexCallable[[bool, Any, Any], Any],
ReflexCallable[[Any, Any], Any],
ReflexCallable[[Any], Any],
computer=lambda args: unionize(args[1]._var_type, args[2]._var_type),
),
)
return value

View File

@ -21,15 +21,23 @@ from typing import (
from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin
from reflex.utils.types import (
GenericType,
get_attribute_access_type,
get_origin,
unionize,
)
from .base import (
CachedVarOperation,
LiteralVar,
ReflexCallable,
Var,
VarData,
cached_property_no_lock,
figure_out_type,
nary_type_computer,
unary_type_computer,
var_operation,
var_operation_return,
)
@ -406,7 +414,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@var_operation
def object_keys_operation(value: ObjectVar):
def object_keys_operation(value: Var):
"""Get the keys of an object.
Args:
@ -422,7 +430,7 @@ def object_keys_operation(value: ObjectVar):
@var_operation
def object_values_operation(value: ObjectVar):
def object_values_operation(value: Var):
"""Get the values of an object.
Args:
@ -433,12 +441,15 @@ def object_values_operation(value: ObjectVar):
"""
return var_operation_return(
js_expression=f"Object.values({value})",
var_type=List[value._value_type()],
type_computer=unary_type_computer(
ReflexCallable[[Any], List[Any]],
lambda x: List[x.to(ObjectVar)._value_type()],
),
)
@var_operation
def object_entries_operation(value: ObjectVar):
def object_entries_operation(value: Var):
"""Get the entries of an object.
Args:
@ -447,14 +458,18 @@ def object_entries_operation(value: ObjectVar):
Returns:
The entries of the object.
"""
value = value.to(ObjectVar)
return var_operation_return(
js_expression=f"Object.entries({value})",
var_type=List[Tuple[str, value._value_type()]],
type_computer=unary_type_computer(
ReflexCallable[[Any], List[Tuple[str, Any]]],
lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]],
),
)
@var_operation
def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
def object_merge_operation(lhs: Var, rhs: Var):
"""Merge two objects.
Args:
@ -466,10 +481,14 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
"""
return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Dict[
Union[lhs._key_type(), rhs._key_type()],
Union[lhs._value_type(), rhs._value_type()],
],
type_computer=nary_type_computer(
ReflexCallable[[Any, Any], Dict[Any, Any]],
ReflexCallable[[Any], Dict[Any, Any]],
computer=lambda args: Dict[
unionize(*[arg.to(ObjectVar)._key_type() for arg in args]),
unionize(*[arg.to(ObjectVar)._value_type() for arg in args]),
],
),
)
@ -526,7 +545,7 @@ class ObjectItemOperation(CachedVarOperation, Var):
@var_operation
def object_has_own_property_operation(object: ObjectVar, key: Var):
def object_has_own_property_operation(object: Var, key: Var):
"""Check if an object has a key.
Args:

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import dataclasses
import functools
import inspect
import json
import re
@ -34,6 +35,7 @@ from .base import (
CachedVarOperation,
CustomVarOperationReturn,
LiteralVar,
ReflexCallable,
Var,
VarData,
_global_vars,
@ -41,7 +43,10 @@ from .base import (
figure_out_type,
get_python_literal,
get_unique_variable_name,
nary_type_computer,
passthrough_unary_type_computer,
unionize,
unwrap_reflex_callalbe,
var_operation,
var_operation_return,
)
@ -353,7 +358,7 @@ class StringVar(Var[STRING_TYPE], python_types=str):
@var_operation
def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
def string_lt_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is less than another string.
Args:
@ -367,7 +372,7 @@ def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation
def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
def string_gt_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is greater than another string.
Args:
@ -381,7 +386,7 @@ def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation
def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
def string_le_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is less than or equal to another string.
Args:
@ -395,7 +400,7 @@ def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation
def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
def string_ge_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is greater than or equal to another string.
Args:
@ -409,7 +414,7 @@ def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation
def string_lower_operation(string: StringVar[Any]):
def string_lower_operation(string: Var[str]):
"""Convert a string to lowercase.
Args:
@ -422,7 +427,7 @@ def string_lower_operation(string: StringVar[Any]):
@var_operation
def string_upper_operation(string: StringVar[Any]):
def string_upper_operation(string: Var[str]):
"""Convert a string to uppercase.
Args:
@ -435,7 +440,7 @@ def string_upper_operation(string: StringVar[Any]):
@var_operation
def string_strip_operation(string: StringVar[Any]):
def string_strip_operation(string: Var[str]):
"""Strip a string.
Args:
@ -449,7 +454,7 @@ def string_strip_operation(string: StringVar[Any]):
@var_operation
def string_contains_field_operation(
haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str
haystack: Var[str], needle: Var[str], field: Var[str]
):
"""Check if a string contains another string.
@ -468,7 +473,7 @@ def string_contains_field_operation(
@var_operation
def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str):
def string_contains_operation(haystack: Var[str], needle: Var[str]):
"""Check if a string contains another string.
Args:
@ -484,9 +489,7 @@ def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] |
@var_operation
def string_starts_with_operation(
full_string: StringVar[Any], prefix: StringVar[Any] | str
):
def string_starts_with_operation(full_string: Var[str], prefix: Var[str]):
"""Check if a string starts with a prefix.
Args:
@ -502,7 +505,7 @@ def string_starts_with_operation(
@var_operation
def string_item_operation(string: StringVar[Any], index: NumberVar | int):
def string_item_operation(string: Var[str], index: Var[int]):
"""Get an item from a string.
Args:
@ -515,23 +518,9 @@ def string_item_operation(string: StringVar[Any], index: NumberVar | int):
return var_operation_return(js_expression=f"{string}.at({index})", var_type=str)
@var_operation
def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""):
"""Join the elements of an array.
Args:
array: The array.
sep: The separator.
Returns:
The joined elements.
"""
return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
@var_operation
def string_replace_operation(
string: StringVar, search_value: StringVar | str, new_value: StringVar | str
string: Var[str], search_value: Var[str], new_value: Var[str]
):
"""Replace a string with a value.
@ -1046,7 +1035,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
Returns:
The array pluck operation.
"""
return array_pluck_operation(self, field)
return array_pluck_operation(self, field).guess_type()
@overload
def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ...
@ -1300,7 +1289,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
@var_operation
def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
def string_split_operation(string: Var[str], sep: Var[str]):
"""Split a string.
Args:
@ -1394,9 +1383,9 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar):
@var_operation
def array_pluck_operation(
array: ArrayVar[ARRAY_VAR_TYPE],
field: StringVar | str,
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
array: Var[ARRAY_VAR_TYPE],
field: Var[str],
) -> CustomVarOperationReturn[List]:
"""Pluck a field from an array of objects.
Args:
@ -1408,13 +1397,27 @@ def array_pluck_operation(
"""
return var_operation_return(
js_expression=f"{array}.map(e=>e?.[{field}])",
var_type=array._var_type,
var_type=List[Any],
)
@var_operation
def array_join_operation(array: Var[ARRAY_VAR_TYPE], sep: Var[str]):
"""Join the elements of an array.
Args:
array: The array.
sep: The separator.
Returns:
The joined elements.
"""
return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
@var_operation
def array_reverse_operation(
array: ArrayVar[ARRAY_VAR_TYPE],
array: Var[ARRAY_VAR_TYPE],
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
"""Reverse an array.
@ -1426,12 +1429,12 @@ def array_reverse_operation(
"""
return var_operation_return(
js_expression=f"{array}.slice().reverse()",
var_type=array._var_type,
type_computer=passthrough_unary_type_computer(ReflexCallable[[List], List]),
)
@var_operation
def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
def array_lt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is less than another array.
Args:
@ -1445,7 +1448,7 @@ def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation
def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
def array_gt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is greater than another array.
Args:
@ -1459,7 +1462,7 @@ def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation
def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
def array_le_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is less than or equal to another array.
Args:
@ -1473,7 +1476,7 @@ def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation
def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple):
def array_ge_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is greater than or equal to another array.
Args:
@ -1487,7 +1490,7 @@ def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation
def array_length_operation(array: ArrayVar):
def array_length_operation(array: Var[ARRAY_VAR_TYPE]):
"""Get the length of an array.
Args:
@ -1517,7 +1520,7 @@ def is_tuple_type(t: GenericType) -> bool:
@var_operation
def array_item_operation(array: ArrayVar, index: NumberVar | int):
def array_item_operation(array: Var[ARRAY_VAR_TYPE], index: Var[int]):
"""Get an item from an array.
Args:
@ -1527,23 +1530,45 @@ def array_item_operation(array: ArrayVar, index: NumberVar | int):
Returns:
The item from the array.
"""
args = typing.get_args(array._var_type)
if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type):
index_value = int(index._var_value)
element_type = args[index_value % len(args)]
else:
element_type = unionize(*args)
def type_computer(*args):
if len(args) == 0:
return (
ReflexCallable[[List[Any], int], Any],
functools.partial(type_computer, *args),
)
array = args[0]
array_args = typing.get_args(array._var_type)
if len(args) == 1:
return (
ReflexCallable[[int], unionize(*array_args)],
functools.partial(type_computer, *args),
)
index = args[1]
if (
array_args
and isinstance(index, LiteralNumberVar)
and is_tuple_type(array._var_type)
):
index_value = int(index._var_value)
element_type = array_args[index_value % len(array_args)]
else:
element_type = unionize(*array_args)
return (ReflexCallable[[], element_type], None)
return var_operation_return(
js_expression=f"{str(array)}.at({str(index)})",
var_type=element_type,
type_computer=type_computer,
)
@var_operation
def array_range_operation(
start: NumberVar | int, stop: NumberVar | int, step: NumberVar | int
):
def array_range_operation(start: Var[int], stop: Var[int], step: Var[int]):
"""Create a range of numbers.
Args:
@ -1562,7 +1587,7 @@ def array_range_operation(
@var_operation
def array_contains_field_operation(
haystack: ArrayVar, needle: Any | Var, field: StringVar | str
haystack: Var[ARRAY_VAR_TYPE], needle: Var, field: Var[str]
):
"""Check if an array contains an element.
@ -1581,7 +1606,7 @@ def array_contains_field_operation(
@var_operation
def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
def array_contains_operation(haystack: Var[ARRAY_VAR_TYPE], needle: Var):
"""Check if an array contains an element.
Args:
@ -1599,7 +1624,7 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
@var_operation
def repeat_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int
array: Var[ARRAY_VAR_TYPE], count: Var[int]
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
"""Repeat an array a number of times.
@ -1610,20 +1635,34 @@ def repeat_array_operation(
Returns:
The repeated array.
"""
def type_computer(*args: Var):
if not args:
return (
ReflexCallable[[List[Any], int], List[Any]],
type_computer,
)
if len(args) == 1:
return (
ReflexCallable[[int], args[0]._var_type],
functools.partial(type_computer, *args),
)
return (ReflexCallable[[], args[0]._var_type], None)
return var_operation_return(
js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})",
var_type=array._var_type,
type_computer=type_computer,
)
if TYPE_CHECKING:
from .function import FunctionVar
pass
@var_operation
def map_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE],
function: FunctionVar,
array: Var[ARRAY_VAR_TYPE],
function: Var[ReflexCallable],
):
"""Map a function over an array.
@ -1634,14 +1673,33 @@ def map_array_operation(
Returns:
The mapped array.
"""
def type_computer(*args: Var):
if not args:
return (
ReflexCallable[[List[Any], ReflexCallable], List[Any]],
type_computer,
)
if len(args) == 1:
return (
ReflexCallable[[ReflexCallable], List[Any]],
functools.partial(type_computer, *args),
)
return (ReflexCallable[[], List[args[0]._var_type]], None)
return var_operation_return(
js_expression=f"{array}.map({function})", var_type=List[Any]
js_expression=f"{array}.map({function})",
type_computer=nary_type_computer(
ReflexCallable[[List[Any], ReflexCallable], List[Any]],
ReflexCallable[[ReflexCallable], List[Any]],
computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]],
),
)
@var_operation
def array_concat_operation(
lhs: ArrayVar[ARRAY_VAR_TYPE], rhs: ArrayVar[ARRAY_VAR_TYPE]
lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
"""Concatenate two arrays.
@ -1654,7 +1712,11 @@ def array_concat_operation(
"""
return var_operation_return(
js_expression=f"[...{lhs}, ...{rhs}]",
var_type=Union[lhs._var_type, rhs._var_type],
type_computer=nary_type_computer(
ReflexCallable[[List[Any], List[Any]], List[Any]],
ReflexCallable[[List[Any]], List[Any]],
computer=lambda args: unionize(args[0]._var_type, args[1]._var_type),
),
)

View File

@ -963,11 +963,11 @@ def test_function_var():
def test_var_operation():
@var_operation
def add(a: Union[NumberVar, int], b: Union[NumberVar, int]):
def add(a: Var[int], b: Var[int]):
return var_operation_return(js_expression=f"({a} + {b})", var_type=int)
assert str(add(1, 2)) == "(1 + 2)"
assert str(add(a=4, b=-9)) == "(4 + -9)"
assert str(add(4, -9)) == "(4 + -9)"
five = LiteralNumberVar.create(5)
seven = add(2, five)