From f4aa1f58c386a5e59bcc4bb9837093e7d3977438 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Wed, 13 Nov 2024 18:17:53 -0800 Subject: [PATCH] implement type computers --- reflex/vars/base.py | 440 +++++++++++++++++++++++++++++----------- reflex/vars/function.py | 232 ++++++++++++++------- reflex/vars/number.py | 111 +++++----- reflex/vars/object.py | 43 ++-- reflex/vars/sequence.py | 186 +++++++++++------ tests/units/test_var.py | 4 +- 6 files changed, 699 insertions(+), 317 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 200f693de..1e6d4163e 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -14,7 +14,7 @@ import re import string import sys import warnings -from types import CodeType, FunctionType +from types import CodeType, EllipsisType, FunctionType from typing import ( TYPE_CHECKING, Any, @@ -26,7 +26,6 @@ from typing import ( Iterable, List, Literal, - NoReturn, Optional, Set, Tuple, @@ -38,7 +37,14 @@ from typing import ( overload, ) -from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override +from typing_extensions import ( + ParamSpec, + Protocol, + TypeGuard, + deprecated, + get_type_hints, + override, +) from reflex import constants from reflex.base import Base @@ -69,6 +75,7 @@ from reflex.utils.types import ( if TYPE_CHECKING: from reflex.state import BaseState + from .function import ArgsFunctionOperation, ReflexCallable from .number import BooleanVar, NumberVar from .object import ObjectVar from .sequence import ArrayVar, StringVar @@ -79,6 +86,36 @@ OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE") warnings.filterwarnings("ignore", message="fields may not start with an underscore") +P = ParamSpec("P") +R = TypeVar("R") + + +class ReflexCallable(Protocol[P, R]): + """Protocol for a callable.""" + + __call__: Callable[P, R] + + +def unwrap_reflex_callalbe( + callable_type: GenericType, +) -> Tuple[Union[EllipsisType, Tuple[GenericType, ...]], GenericType]: + """Unwrap the ReflexCallable type. + + Args: + callable_type: The ReflexCallable type to unwrap. + + Returns: + The unwrapped ReflexCallable type. + """ + if callable_type is ReflexCallable: + return Ellipsis, Any + if get_origin(callable_type) is not ReflexCallable: + return Ellipsis, Any + args = get_args(callable_type) + if not args or len(args) != 2: + return Ellipsis, Any + return args + @dataclasses.dataclass( eq=False, @@ -409,9 +446,11 @@ class Var(Generic[VAR_TYPE]): if _var_data or _js_expr != self._js_expr: self.__init__( - _js_expr=_js_expr, - _var_type=self._var_type, - _var_data=VarData.merge(self._var_data, _var_data), + **{ + **dataclasses.asdict(self), + "_js_expr": _js_expr, + "_var_data": VarData.merge(self._var_data, _var_data), + } ) def __hash__(self) -> int: @@ -690,6 +729,12 @@ class Var(Generic[VAR_TYPE]): @overload def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ... + @overload + def guess_type(self: Var[list] | Var[tuple] | Var[set]) -> ArrayVar: ... + + @overload + def guess_type(self: Var[dict]) -> ObjectVar[dict]: ... + @overload def guess_type(self) -> Self: ... @@ -1413,71 +1458,94 @@ def get_python_literal(value: Union[LiteralVar, Any]) -> Any | None: return value +def validate_arg(type_hint: GenericType) -> Callable[[Any], bool]: + """Create a validator for an argument. + + Args: + type_hint: The type hint of the argument. + + Returns: + The validator. + """ + + def validate(value: Any): + return True + + return validate + + P = ParamSpec("P") T = TypeVar("T") +V1 = TypeVar("V1") +V2 = TypeVar("V2") +V3 = TypeVar("V3") +V4 = TypeVar("V4") +V5 = TypeVar("V5") -# NoReturn is used to match CustomVarOperationReturn with no type hint. -@overload -def var_operation( - func: Callable[P, CustomVarOperationReturn[NoReturn]], -) -> Callable[P, Var]: ... +class TypeComputer(Protocol): + """A protocol for type computers.""" + + def __call__(self, *args: Var) -> Tuple[GenericType, Union[TypeComputer, None]]: + """Compute the type of the operation. + + Args: + *args: The arguments to compute the type of. + + Returns: + The type of the operation. + """ + ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[bool]], -) -> Callable[P, BooleanVar]: ... - - -NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float]) + func: Callable[[], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[NUMBER_T]], -) -> Callable[P, NumberVar[NUMBER_T]]: ... + func: Callable[[Var[V1]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[str]], -) -> Callable[P, StringVar]: ... - - -LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set]) + func: Callable[[Var[V1], Var[V2]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1, V2], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[LIST_T]], -) -> Callable[P, ArrayVar[LIST_T]]: ... - - -OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) + func: Callable[[Var[V1], Var[V2], Var[V3]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]], -) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ... + func: Callable[[Var[V1], Var[V2], Var[V3], Var[V4]], CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4], T]]: ... @overload def var_operation( - func: Callable[P, CustomVarOperationReturn[T]], -) -> Callable[P, Var[T]]: ... + func: Callable[ + [Var[V1], Var[V2], Var[V3], Var[V4], Var[V5]], + CustomVarOperationReturn[T], + ], +) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4, V5], T]]: ... def var_operation( - func: Callable[P, CustomVarOperationReturn[T]], -) -> Callable[P, Var[T]]: + func: Callable[..., CustomVarOperationReturn[T]], +) -> ArgsFunctionOperation: """Decorator for creating a var operation. Example: ```python @var_operation - def add(a: NumberVar, b: NumberVar): + def add(a: Var[int], b: Var[int]): return custom_var_operation(f"{a} + {b}") ``` @@ -1487,26 +1555,61 @@ def var_operation( Returns: The decorated function. """ + from .function import ArgsFunctionOperation, ReflexCallable - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]: - func_args = list(inspect.signature(func).parameters) - args_vars = { - func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg) - for i, arg in enumerate(args) - } - kwargs_vars = { - key: LiteralVar.create(value) if not isinstance(value, Var) else value - for key, value in kwargs.items() - } + func_name = func.__name__ - return CustomVarOperation.create( - name=func.__name__, - args=tuple(list(args_vars.items()) + list(kwargs_vars.items())), - return_var=func(*args_vars.values(), **kwargs_vars), # type: ignore - ).guess_type() + func_arg_spec = inspect.getfullargspec(func) - return wrapper + if func_arg_spec.kwonlyargs: + raise TypeError(f"Function {func_name} cannot have keyword-only arguments.") + if func_arg_spec.varargs: + raise TypeError(f"Function {func_name} cannot have variable arguments.") + + arg_names = func_arg_spec.args + + type_hints = get_type_hints(func) + + if not all( + (get_origin((type_hint := type_hints.get(arg_name, Any))) or type_hint) is Var + and len(get_args(type_hint)) <= 1 + for arg_name in arg_names + ): + raise TypeError( + f"Function {func_name} must have type hints of the form `Var[Type]`." + ) + + args_with_type_hints = tuple( + (arg_name, (args[0] if (args := get_args(type_hints[arg_name])) else Any)) + for arg_name in arg_names + ) + + arg_vars = tuple( + ( + Var("_" + arg_name, _var_type=arg_python_type) + if not isinstance(arg_python_type, TypeVar) + else Var("_" + arg_name) + ) + for arg_name, arg_python_type in args_with_type_hints + ) + + custom_operation_return = func(*arg_vars) + + args_operation = ArgsFunctionOperation.create( + tuple(map(str, arg_vars)), + custom_operation_return, + validators=tuple( + validate_arg(type_hints.get(arg_name, Any)) for arg_name in arg_names + ), + function_name=func_name, + type_computer=custom_operation_return._type_computer, + _var_type=ReflexCallable[ + tuple(arg_python_type for _, arg_python_type in args_with_type_hints), + custom_operation_return._var_type, + ], + ) + + return args_operation def figure_out_type(value: Any) -> types.GenericType: @@ -1621,66 +1724,6 @@ class CachedVarOperation: ) -def and_operation(a: Var | Any, b: Var | Any) -> Var: - """Perform a logical AND operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical AND operation. - """ - return _and_operation(a, b) # type: ignore - - -@var_operation -def _and_operation(a: Var, b: Var): - """Perform a logical AND operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical AND operation. - """ - return var_operation_return( - js_expression=f"({a} && {b})", - var_type=unionize(a._var_type, b._var_type), - ) - - -def or_operation(a: Var | Any, b: Var | Any) -> Var: - """Perform a logical OR operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical OR operation. - """ - return _or_operation(a, b) # type: ignore - - -@var_operation -def _or_operation(a: Var, b: Var): - """Perform a logical OR operation on two variables. - - Args: - a: The first variable. - b: The second variable. - - Returns: - The result of the logical OR operation. - """ - return var_operation_return( - js_expression=f"({a} || {b})", - var_type=unionize(a._var_type, b._var_type), - ) - - @dataclasses.dataclass( eq=False, frozen=True, @@ -2289,14 +2332,22 @@ def computed_var( RETURN = TypeVar("RETURN") +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) class CustomVarOperationReturn(Var[RETURN]): """Base class for custom var operations.""" + _type_computer: Optional[TypeComputer] = dataclasses.field(default=None) + @classmethod def create( cls, js_expression: str, _var_type: Type[RETURN] | None = None, + _type_computer: Optional[TypeComputer] = None, _var_data: VarData | None = None, ) -> CustomVarOperationReturn[RETURN]: """Create a CustomVarOperation. @@ -2304,6 +2355,7 @@ class CustomVarOperationReturn(Var[RETURN]): Args: js_expression: The JavaScript expression to evaluate. _var_type: The type of the var. + _type_computer: A function to compute the type of the var given the arguments. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -2312,6 +2364,7 @@ class CustomVarOperationReturn(Var[RETURN]): return CustomVarOperationReturn( _js_expr=js_expression, _var_type=_var_type or Any, + _type_computer=_type_computer, _var_data=_var_data, ) @@ -2319,6 +2372,7 @@ class CustomVarOperationReturn(Var[RETURN]): def var_operation_return( js_expression: str, var_type: Type[RETURN] | None = None, + type_computer: Optional[TypeComputer] = None, var_data: VarData | None = None, ) -> CustomVarOperationReturn[RETURN]: """Shortcut for creating a CustomVarOperationReturn. @@ -2326,15 +2380,17 @@ def var_operation_return( Args: js_expression: The JavaScript expression to evaluate. var_type: The type of the var. + type_computer: A function to compute the type of the var given the arguments. var_data: Additional hooks and imports associated with the Var. Returns: The CustomVarOperationReturn. """ return CustomVarOperationReturn.create( - js_expression, - var_type, - var_data, + js_expression=js_expression, + _var_type=var_type, + _type_computer=type_computer, + _var_data=var_data, ) @@ -2942,3 +2998,157 @@ def field(value: T) -> Field[T]: The Field. """ return value # type: ignore + + +def and_operation(a: Var | Any, b: Var | Any) -> Var: + """Perform a logical AND operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical AND operation. + """ + return _and_operation(a, b) # type: ignore + + +@var_operation +def _and_operation(a: Var, b: Var): + """Perform a logical AND operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical AND operation. + """ + + def type_computer(*args: Var): + if not args: + return (ReflexCallable[[Any, Any], Any], type_computer) + if len(args) == 1: + return ( + ReflexCallable[[Any], Any], + functools.partial(type_computer, args[0]), + ) + return ( + ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)], + None, + ) + + return var_operation_return( + js_expression=f"({a} && {b})", + type_computer=type_computer, + ) + + +def or_operation(a: Var | Any, b: Var | Any) -> Var: + """Perform a logical OR operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical OR operation. + """ + return _or_operation(a, b) # type: ignore + + +@var_operation +def _or_operation(a: Var, b: Var): + """Perform a logical OR operation on two variables. + + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical OR operation. + """ + + def type_computer(*args: Var): + if not args: + return (ReflexCallable[[Any, Any], Any], type_computer) + if len(args) == 1: + return ( + ReflexCallable[[Any], Any], + functools.partial(type_computer, args[0]), + ) + return ( + ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)], + None, + ) + + return var_operation_return( + js_expression=f"({a} || {b})", + type_computer=type_computer, + ) + + +def passthrough_unary_type_computer(no_args: GenericType) -> TypeComputer: + """Create a type computer for unary operations. + + Args: + no_args: The type to return when no arguments are provided. + + Returns: + The type computer. + """ + + def type_computer(*args: Var): + if not args: + return (no_args, type_computer) + return (ReflexCallable[[], args[0]._var_type], None) + + return type_computer + + +def unary_type_computer( + no_args: GenericType, computer: Callable[[Var], GenericType] +) -> TypeComputer: + """Create a type computer for unary operations. + + Args: + no_args: The type to return when no arguments are provided. + computer: The function to compute the type. + + Returns: + The type computer. + """ + + def type_computer(*args: Var): + if not args: + return (no_args, type_computer) + return (ReflexCallable[[], computer(args[0])], None) + + return type_computer + + +def nary_type_computer( + *types: GenericType, computer: Callable[..., GenericType] +) -> TypeComputer: + """Create a type computer for n-ary operations. + + Args: + types: The types to return when no arguments are provided. + computer: The function to compute the type. + + Returns: + The type computer. + """ + + def type_computer(*args: Var): + if len(args) != len(types): + return ( + ReflexCallable[[], types[len(args)]], + functools.partial(type_computer, *args), + ) + return ( + ReflexCallable[[], computer(args)], + None, + ) + + return type_computer diff --git a/reflex/vars/function.py b/reflex/vars/function.py index d719e5ced..ebe3eba5c 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -6,28 +6,31 @@ import dataclasses import sys from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload -from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar +from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar from reflex.utils import format from reflex.utils.exceptions import VarTypeError from reflex.utils.types import GenericType -from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock +from .base import ( + CachedVarOperation, + LiteralVar, + ReflexCallable, + TypeComputer, + Var, + VarData, + cached_property_no_lock, + unwrap_reflex_callalbe, +) P = ParamSpec("P") +R = TypeVar("R") V1 = TypeVar("V1") V2 = TypeVar("V2") V3 = TypeVar("V3") V4 = TypeVar("V4") V5 = TypeVar("V5") V6 = TypeVar("V6") -R = TypeVar("R") - - -class ReflexCallable(Protocol[P, R]): - """Protocol for a callable.""" - - __call__: Callable[P, R] CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True) @@ -112,20 +115,37 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): """ if not args: return self + + args = tuple(map(LiteralVar.create, args)) + remaining_validators = self._pre_check(*args) + + partial_types, type_computer = self._partial_type(*args) + if self.__call__ is self.partial: # if the default behavior is partial, we should return a new partial function return ArgsFunctionOperationBuilder.create( (), - VarOperationCall.create(self, *args, Var(_js_expr="...args")), + VarOperationCall.create( + self, + *args, + Var(_js_expr="...args"), + _var_type=self._return_type(*args), + ), rest="args", validators=remaining_validators, + type_computer=type_computer, + _var_type=partial_types, ) return ArgsFunctionOperation.create( (), - VarOperationCall.create(self, *args, Var(_js_expr="...args")), + VarOperationCall.create( + self, *args, Var(_js_expr="...args"), _var_type=self._return_type(*args) + ), rest="args", validators=remaining_validators, + type_computer=type_computer, + _var_type=partial_types, ) @overload @@ -194,9 +214,56 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]): Returns: The function call operation. + + Raises: + VarTypeError: If the number of arguments is invalid """ + arg_len = self._arg_len() + if arg_len is not None and len(args) != arg_len: + raise VarTypeError(f"Invalid number of arguments provided to {str(self)}") + args = tuple(map(LiteralVar.create, args)) self._pre_check(*args) - return VarOperationCall.create(self, *args).guess_type() + return_type = self._return_type(*args) + return VarOperationCall.create(self, *args, _var_type=return_type).guess_type() + + def _partial_type( + self, *args: Var | Any + ) -> Tuple[GenericType, Optional[TypeComputer]]: + """Override the type of the function call with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The overridden type of the function call. + """ + args_types, return_type = unwrap_reflex_callalbe(self._var_type) + if isinstance(args_types, tuple): + return ReflexCallable[[*args_types[len(args) :]], return_type], None + return ReflexCallable[..., return_type], None + + def _arg_len(self) -> int | None: + """Get the number of arguments the function takes. + + Returns: + The number of arguments the function takes. + """ + args_types, _ = unwrap_reflex_callalbe(self._var_type) + if isinstance(args_types, tuple): + return len(args_types) + return None + + def _return_type(self, *args: Var | Any) -> GenericType: + """Override the type of the function call with the given arguments. + + Args: + *args: The arguments to call the function with. + + Returns: + The overridden type of the function call. + """ + partial_types, _ = self._partial_type(*args) + return unwrap_reflex_callalbe(partial_types)[1] def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: """Check if the function can be called with the given arguments. @@ -343,11 +410,12 @@ class FunctionArgs: def format_args_function_operation( - args: FunctionArgs, return_expr: Var | Any, explicit_return: bool + self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, ) -> str: """Format an args function operation. Args: + self: The function operation. args: The function arguments. return_expr: The return expression. explicit_return: Whether to use explicit return syntax. @@ -356,26 +424,76 @@ def format_args_function_operation( The formatted args function operation. """ arg_names_str = ", ".join( - [arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args] - + ([f"...{args.rest}"] if args.rest else []) + [ + arg if isinstance(arg, str) else arg.to_javascript() + for arg in self._args.args + ] + + ([f"...{self._args.rest}"] if self._args.rest else []) ) - return_expr_str = str(LiteralVar.create(return_expr)) + return_expr_str = str(LiteralVar.create(self._return_expr)) # Wrap return expression in curly braces if explicit return syntax is used. return_expr_str_wrapped = ( - format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str + format.wrap(return_expr_str, "{", "}") + if self._explicit_return + else return_expr_str ) return f"(({arg_names_str}) => {return_expr_str_wrapped})" +def pre_check_args( + self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any +) -> Tuple[Callable[[Any], bool], ...]: + """Check if the function can be called with the given arguments. + + Args: + self: The function operation. + *args: The arguments to call the function with. + + Returns: + True if the function can be called with the given arguments. + """ + for i, (validator, arg) in enumerate(zip(self._validators, args)): + if not validator(arg): + arg_name = self._args.args[i] if i < len(self._args.args) else None + if arg_name is not None: + raise VarTypeError( + f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}" + ) + raise VarTypeError( + f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}" + ) + return self._validators[len(args) :] + + +def figure_partial_type( + self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, + *args: Var | Any, +) -> Tuple[GenericType, Optional[TypeComputer]]: + """Figure out the return type of the function. + + Args: + self: The function operation. + *args: The arguments to call the function with. + + Returns: + The return type of the function. + """ + return ( + self._type_computer(*args) + if self._type_computer is not None + else FunctionVar._partial_type(self, *args) + ) + + @dataclasses.dataclass( eq=False, frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class ArgsFunctionOperation(CachedVarOperation, FunctionVar): +class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]): """Base class for immutable function defined via arguments and return expression.""" _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) @@ -384,39 +502,14 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): ) _return_expr: Union[Var, Any] = dataclasses.field(default=None) _function_name: str = dataclasses.field(default="") + _type_computer: Optional[TypeComputer] = dataclasses.field(default=None) _explicit_return: bool = dataclasses.field(default=False) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + _cached_var_name = cached_property_no_lock(format_args_function_operation) - Returns: - The name of the var. - """ - return format_args_function_operation( - self._args, self._return_expr, self._explicit_return - ) + _pre_check = pre_check_args - def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: - """Check if the function can be called with the given arguments. - - Args: - *args: The arguments to call the function with. - - Returns: - True if the function can be called with the given arguments. - """ - for i, (validator, arg) in enumerate(zip(self._validators, args)): - if not validator(arg): - arg_name = self._args.args[i] if i < len(self._args.args) else None - if arg_name is not None: - raise VarTypeError( - f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}" - ) - raise VarTypeError( - f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}" - ) - return self._validators[len(args) :] + _partial_type = figure_partial_type @classmethod def create( @@ -427,6 +520,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): validators: Sequence[Callable[[Any], bool]] = (), function_name: str = "", explicit_return: bool = False, + type_computer: Optional[TypeComputer] = None, _var_type: GenericType = Callable, _var_data: VarData | None = None, ): @@ -439,6 +533,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): validators: The validators for the arguments. function_name: The name of the function. explicit_return: Whether to use explicit return syntax. + type_computer: A function to compute the return type. + _var_type: The type of the var. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -453,6 +549,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): _validators=tuple(validators), _return_expr=return_expr, _explicit_return=explicit_return, + _type_computer=type_computer, ) @@ -461,7 +558,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): frozen=True, **{"slots": True} if sys.version_info >= (3, 10) else {}, ) -class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): +class ArgsFunctionOperationBuilder( + CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE] +): """Base class for immutable function defined via arguments and return expression with the builder pattern.""" _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) @@ -470,39 +569,14 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): ) _return_expr: Union[Var, Any] = dataclasses.field(default=None) _function_name: str = dataclasses.field(default="") + _type_computer: Optional[TypeComputer] = dataclasses.field(default=None) _explicit_return: bool = dataclasses.field(default=False) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + _cached_var_name = cached_property_no_lock(format_args_function_operation) - Returns: - The name of the var. - """ - return format_args_function_operation( - self._args, self._return_expr, self._explicit_return - ) + _pre_check = pre_check_args - def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: - """Check if the function can be called with the given arguments. - - Args: - *args: The arguments to call the function with. - - Returns: - True if the function can be called with the given arguments. - """ - for i, (validator, arg) in enumerate(zip(self._validators, args)): - if not validator(arg): - arg_name = self._args.args[i] if i < len(self._args.args) else None - if arg_name is not None: - raise VarTypeError( - f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}" - ) - raise VarTypeError( - f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}" - ) - return self._validators[len(args) :] + _partial_type = figure_partial_type @classmethod def create( @@ -513,6 +587,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): validators: Sequence[Callable[[Any], bool]] = (), function_name: str = "", explicit_return: bool = False, + type_computer: Optional[TypeComputer] = None, _var_type: GenericType = Callable, _var_data: VarData | None = None, ): @@ -525,6 +600,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): validators: The validators for the arguments. function_name: The name of the function. explicit_return: Whether to use explicit return syntax. + type_computer: A function to compute the return type. + _var_type: The type of the var. _var_data: Additional hooks and imports associated with the Var. Returns: @@ -539,6 +616,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): _validators=tuple(validators), _return_expr=return_expr, _explicit_return=explicit_return, + _type_computer=type_computer, ) diff --git a/reflex/vars/number.py b/reflex/vars/number.py index a762796e2..0f2cb4e46 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -3,19 +3,11 @@ from __future__ import annotations import dataclasses +import functools import json import math import sys -from typing import ( - TYPE_CHECKING, - Any, - Callable, - NoReturn, - Type, - TypeVar, - Union, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, NoReturn, TypeVar, Union, overload from reflex.constants.base import Dirs from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError @@ -25,8 +17,11 @@ from reflex.utils.types import is_optional from .base import ( CustomVarOperationReturn, LiteralVar, + ReflexCallable, Var, VarData, + nary_type_computer, + passthrough_unary_type_computer, unionize, var_operation, var_operation_return, @@ -544,8 +539,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): def binary_number_operation( - func: Callable[[NumberVar, NumberVar], str], -) -> Callable[[number_types, number_types], NumberVar]: + func: Callable[[Var[int | float], Var[int | float]], str], +): """Decorator to create a binary number operation. Args: @@ -555,30 +550,37 @@ def binary_number_operation( The binary number operation. """ - @var_operation - def operation(lhs: NumberVar, rhs: NumberVar): + def operation( + lhs: Var[int | float], rhs: Var[int | float] + ) -> CustomVarOperationReturn[int | float]: + def type_computer(*args: Var): + if not args: + return ( + ReflexCallable[[int | float, int | float], int | float], + type_computer, + ) + if len(args) == 1: + return ( + ReflexCallable[[int | float], int | float], + functools.partial(type_computer, args[0]), + ) + return ( + ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)], + None, + ) + return var_operation_return( js_expression=func(lhs, rhs), - var_type=unionize(lhs._var_type, rhs._var_type), + type_computer=type_computer, ) - def wrapper(lhs: number_types, rhs: number_types) -> NumberVar: - """Create the binary number operation. + operation.__name__ = func.__name__ - Args: - lhs: The first number. - rhs: The second number. - - Returns: - The binary number operation. - """ - return operation(lhs, rhs) # type: ignore - - return wrapper + return var_operation(operation) @binary_number_operation -def number_add_operation(lhs: NumberVar, rhs: NumberVar): +def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]): """Add two numbers. Args: @@ -592,7 +594,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_subtract_operation(lhs: NumberVar, rhs: NumberVar): +def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]): """Subtract two numbers. Args: @@ -605,8 +607,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar): return f"({lhs} - {rhs})" +unary_operation_type_computer = passthrough_unary_type_computer( + ReflexCallable[[int | float], int | float] +) + + @var_operation -def number_abs_operation(value: NumberVar): +def number_abs_operation( + value: Var[int | float], +) -> CustomVarOperationReturn[int | float]: """Get the absolute value of the number. Args: @@ -616,12 +625,12 @@ def number_abs_operation(value: NumberVar): The number absolute operation. """ return var_operation_return( - js_expression=f"Math.abs({value})", var_type=value._var_type + js_expression=f"Math.abs({value})", type_computer=unary_operation_type_computer ) @binary_number_operation -def number_multiply_operation(lhs: NumberVar, rhs: NumberVar): +def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]): """Multiply two numbers. Args: @@ -636,7 +645,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar): @var_operation def number_negate_operation( - value: NumberVar[NUMBER_T], + value: Var[NUMBER_T], ) -> CustomVarOperationReturn[NUMBER_T]: """Negate the number. @@ -646,11 +655,13 @@ def number_negate_operation( Returns: The number negation operation. """ - return var_operation_return(js_expression=f"-({value})", var_type=value._var_type) + return var_operation_return( + js_expression=f"-({value})", type_computer=unary_operation_type_computer + ) @binary_number_operation -def number_true_division_operation(lhs: NumberVar, rhs: NumberVar): +def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]): """Divide two numbers. Args: @@ -664,7 +675,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar): +def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]): """Floor divide two numbers. Args: @@ -678,7 +689,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_modulo_operation(lhs: NumberVar, rhs: NumberVar): +def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]): """Modulo two numbers. Args: @@ -692,7 +703,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar): @binary_number_operation -def number_exponent_operation(lhs: NumberVar, rhs: NumberVar): +def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]): """Exponentiate two numbers. Args: @@ -706,7 +717,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar): @var_operation -def number_round_operation(value: NumberVar): +def number_round_operation(value: Var[int | float]): """Round the number. Args: @@ -719,7 +730,7 @@ def number_round_operation(value: NumberVar): @var_operation -def number_ceil_operation(value: NumberVar): +def number_ceil_operation(value: Var[int | float]): """Ceil the number. Args: @@ -732,7 +743,7 @@ def number_ceil_operation(value: NumberVar): @var_operation -def number_floor_operation(value: NumberVar): +def number_floor_operation(value: Var[int | float]): """Floor the number. Args: @@ -745,7 +756,7 @@ def number_floor_operation(value: NumberVar): @var_operation -def number_trunc_operation(value: NumberVar): +def number_trunc_operation(value: Var[int | float]): """Trunc the number. Args: @@ -838,7 +849,7 @@ class BooleanVar(NumberVar[bool], python_types=bool): @var_operation -def boolean_to_number_operation(value: BooleanVar): +def boolean_to_number_operation(value: Var[bool]): """Convert the boolean to a number. Args: @@ -969,7 +980,7 @@ def not_equal_operation(lhs: Var, rhs: Var): @var_operation -def boolean_not_operation(value: BooleanVar): +def boolean_not_operation(value: Var[bool]): """Boolean NOT the boolean. Args: @@ -1117,7 +1128,7 @@ U = TypeVar("U") @var_operation def ternary_operation( - condition: BooleanVar, if_true: Var[T], if_false: Var[U] + condition: Var[bool], if_true: Var[T], if_false: Var[U] ) -> CustomVarOperationReturn[Union[T, U]]: """Create a ternary operation. @@ -1129,12 +1140,14 @@ def ternary_operation( Returns: The ternary operation. """ - type_value: Union[Type[T], Type[U]] = unionize( - if_true._var_type, if_false._var_type - ) value: CustomVarOperationReturn[Union[T, U]] = var_operation_return( js_expression=f"({condition} ? {if_true} : {if_false})", - var_type=type_value, + type_computer=nary_type_computer( + ReflexCallable[[bool, Any, Any], Any], + ReflexCallable[[Any, Any], Any], + ReflexCallable[[Any], Any], + computer=lambda args: unionize(args[1]._var_type, args[2]._var_type), + ), ) return value diff --git a/reflex/vars/object.py b/reflex/vars/object.py index e60ea09e3..6aabd8c80 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -21,15 +21,23 @@ from typing import ( from reflex.utils import types from reflex.utils.exceptions import VarAttributeError -from reflex.utils.types import GenericType, get_attribute_access_type, get_origin +from reflex.utils.types import ( + GenericType, + get_attribute_access_type, + get_origin, + unionize, +) from .base import ( CachedVarOperation, LiteralVar, + ReflexCallable, Var, VarData, cached_property_no_lock, figure_out_type, + nary_type_computer, + unary_type_computer, var_operation, var_operation_return, ) @@ -406,7 +414,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): @var_operation -def object_keys_operation(value: ObjectVar): +def object_keys_operation(value: Var): """Get the keys of an object. Args: @@ -422,7 +430,7 @@ def object_keys_operation(value: ObjectVar): @var_operation -def object_values_operation(value: ObjectVar): +def object_values_operation(value: Var): """Get the values of an object. Args: @@ -433,12 +441,15 @@ def object_values_operation(value: ObjectVar): """ return var_operation_return( js_expression=f"Object.values({value})", - var_type=List[value._value_type()], + type_computer=unary_type_computer( + ReflexCallable[[Any], List[Any]], + lambda x: List[x.to(ObjectVar)._value_type()], + ), ) @var_operation -def object_entries_operation(value: ObjectVar): +def object_entries_operation(value: Var): """Get the entries of an object. Args: @@ -447,14 +458,18 @@ def object_entries_operation(value: ObjectVar): Returns: The entries of the object. """ + value = value.to(ObjectVar) return var_operation_return( js_expression=f"Object.entries({value})", - var_type=List[Tuple[str, value._value_type()]], + type_computer=unary_type_computer( + ReflexCallable[[Any], List[Tuple[str, Any]]], + lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]], + ), ) @var_operation -def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): +def object_merge_operation(lhs: Var, rhs: Var): """Merge two objects. Args: @@ -466,10 +481,14 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): """ return var_operation_return( js_expression=f"({{...{lhs}, ...{rhs}}})", - var_type=Dict[ - Union[lhs._key_type(), rhs._key_type()], - Union[lhs._value_type(), rhs._value_type()], - ], + type_computer=nary_type_computer( + ReflexCallable[[Any, Any], Dict[Any, Any]], + ReflexCallable[[Any], Dict[Any, Any]], + computer=lambda args: Dict[ + unionize(*[arg.to(ObjectVar)._key_type() for arg in args]), + unionize(*[arg.to(ObjectVar)._value_type() for arg in args]), + ], + ), ) @@ -526,7 +545,7 @@ class ObjectItemOperation(CachedVarOperation, Var): @var_operation -def object_has_own_property_operation(object: ObjectVar, key: Var): +def object_has_own_property_operation(object: Var, key: Var): """Check if an object has a key. Args: diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index 08429883f..84f4d8146 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -3,6 +3,7 @@ from __future__ import annotations import dataclasses +import functools import inspect import json import re @@ -34,6 +35,7 @@ from .base import ( CachedVarOperation, CustomVarOperationReturn, LiteralVar, + ReflexCallable, Var, VarData, _global_vars, @@ -41,7 +43,10 @@ from .base import ( figure_out_type, get_python_literal, get_unique_variable_name, + nary_type_computer, + passthrough_unary_type_computer, unionize, + unwrap_reflex_callalbe, var_operation, var_operation_return, ) @@ -353,7 +358,7 @@ class StringVar(Var[STRING_TYPE], python_types=str): @var_operation -def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_lt_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is less than another string. Args: @@ -367,7 +372,7 @@ def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): @var_operation -def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_gt_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is greater than another string. Args: @@ -381,7 +386,7 @@ def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): @var_operation -def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_le_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is less than or equal to another string. Args: @@ -395,7 +400,7 @@ def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): @var_operation -def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): +def string_ge_operation(lhs: Var[str], rhs: Var[str]): """Check if a string is greater than or equal to another string. Args: @@ -409,7 +414,7 @@ def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): @var_operation -def string_lower_operation(string: StringVar[Any]): +def string_lower_operation(string: Var[str]): """Convert a string to lowercase. Args: @@ -422,7 +427,7 @@ def string_lower_operation(string: StringVar[Any]): @var_operation -def string_upper_operation(string: StringVar[Any]): +def string_upper_operation(string: Var[str]): """Convert a string to uppercase. Args: @@ -435,7 +440,7 @@ def string_upper_operation(string: StringVar[Any]): @var_operation -def string_strip_operation(string: StringVar[Any]): +def string_strip_operation(string: Var[str]): """Strip a string. Args: @@ -449,7 +454,7 @@ def string_strip_operation(string: StringVar[Any]): @var_operation def string_contains_field_operation( - haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str + haystack: Var[str], needle: Var[str], field: Var[str] ): """Check if a string contains another string. @@ -468,7 +473,7 @@ def string_contains_field_operation( @var_operation -def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str): +def string_contains_operation(haystack: Var[str], needle: Var[str]): """Check if a string contains another string. Args: @@ -484,9 +489,7 @@ def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | @var_operation -def string_starts_with_operation( - full_string: StringVar[Any], prefix: StringVar[Any] | str -): +def string_starts_with_operation(full_string: Var[str], prefix: Var[str]): """Check if a string starts with a prefix. Args: @@ -502,7 +505,7 @@ def string_starts_with_operation( @var_operation -def string_item_operation(string: StringVar[Any], index: NumberVar | int): +def string_item_operation(string: Var[str], index: Var[int]): """Get an item from a string. Args: @@ -515,23 +518,9 @@ def string_item_operation(string: StringVar[Any], index: NumberVar | int): return var_operation_return(js_expression=f"{string}.at({index})", var_type=str) -@var_operation -def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""): - """Join the elements of an array. - - Args: - array: The array. - sep: The separator. - - Returns: - The joined elements. - """ - return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str) - - @var_operation def string_replace_operation( - string: StringVar, search_value: StringVar | str, new_value: StringVar | str + string: Var[str], search_value: Var[str], new_value: Var[str] ): """Replace a string with a value. @@ -1046,7 +1035,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): Returns: The array pluck operation. """ - return array_pluck_operation(self, field) + return array_pluck_operation(self, field).guess_type() @overload def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ... @@ -1300,7 +1289,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): @var_operation -def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""): +def string_split_operation(string: Var[str], sep: Var[str]): """Split a string. Args: @@ -1394,9 +1383,9 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar): @var_operation def array_pluck_operation( - array: ArrayVar[ARRAY_VAR_TYPE], - field: StringVar | str, -) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: + array: Var[ARRAY_VAR_TYPE], + field: Var[str], +) -> CustomVarOperationReturn[List]: """Pluck a field from an array of objects. Args: @@ -1408,13 +1397,27 @@ def array_pluck_operation( """ return var_operation_return( js_expression=f"{array}.map(e=>e?.[{field}])", - var_type=array._var_type, + var_type=List[Any], ) +@var_operation +def array_join_operation(array: Var[ARRAY_VAR_TYPE], sep: Var[str]): + """Join the elements of an array. + + Args: + array: The array. + sep: The separator. + + Returns: + The joined elements. + """ + return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str) + + @var_operation def array_reverse_operation( - array: ArrayVar[ARRAY_VAR_TYPE], + array: Var[ARRAY_VAR_TYPE], ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: """Reverse an array. @@ -1426,12 +1429,12 @@ def array_reverse_operation( """ return var_operation_return( js_expression=f"{array}.slice().reverse()", - var_type=array._var_type, + type_computer=passthrough_unary_type_computer(ReflexCallable[[List], List]), ) @var_operation -def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): +def array_lt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): """Check if an array is less than another array. Args: @@ -1445,7 +1448,7 @@ def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl @var_operation -def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): +def array_gt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): """Check if an array is greater than another array. Args: @@ -1459,7 +1462,7 @@ def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl @var_operation -def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): +def array_le_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): """Check if an array is less than or equal to another array. Args: @@ -1473,7 +1476,7 @@ def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl @var_operation -def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): +def array_ge_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]): """Check if an array is greater than or equal to another array. Args: @@ -1487,7 +1490,7 @@ def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl @var_operation -def array_length_operation(array: ArrayVar): +def array_length_operation(array: Var[ARRAY_VAR_TYPE]): """Get the length of an array. Args: @@ -1517,7 +1520,7 @@ def is_tuple_type(t: GenericType) -> bool: @var_operation -def array_item_operation(array: ArrayVar, index: NumberVar | int): +def array_item_operation(array: Var[ARRAY_VAR_TYPE], index: Var[int]): """Get an item from an array. Args: @@ -1527,23 +1530,45 @@ def array_item_operation(array: ArrayVar, index: NumberVar | int): Returns: The item from the array. """ - args = typing.get_args(array._var_type) - if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type): - index_value = int(index._var_value) - element_type = args[index_value % len(args)] - else: - element_type = unionize(*args) + + def type_computer(*args): + if len(args) == 0: + return ( + ReflexCallable[[List[Any], int], Any], + functools.partial(type_computer, *args), + ) + + array = args[0] + array_args = typing.get_args(array._var_type) + + if len(args) == 1: + return ( + ReflexCallable[[int], unionize(*array_args)], + functools.partial(type_computer, *args), + ) + + index = args[1] + + if ( + array_args + and isinstance(index, LiteralNumberVar) + and is_tuple_type(array._var_type) + ): + index_value = int(index._var_value) + element_type = array_args[index_value % len(array_args)] + else: + element_type = unionize(*array_args) + + return (ReflexCallable[[], element_type], None) return var_operation_return( js_expression=f"{str(array)}.at({str(index)})", - var_type=element_type, + type_computer=type_computer, ) @var_operation -def array_range_operation( - start: NumberVar | int, stop: NumberVar | int, step: NumberVar | int -): +def array_range_operation(start: Var[int], stop: Var[int], step: Var[int]): """Create a range of numbers. Args: @@ -1562,7 +1587,7 @@ def array_range_operation( @var_operation def array_contains_field_operation( - haystack: ArrayVar, needle: Any | Var, field: StringVar | str + haystack: Var[ARRAY_VAR_TYPE], needle: Var, field: Var[str] ): """Check if an array contains an element. @@ -1581,7 +1606,7 @@ def array_contains_field_operation( @var_operation -def array_contains_operation(haystack: ArrayVar, needle: Any | Var): +def array_contains_operation(haystack: Var[ARRAY_VAR_TYPE], needle: Var): """Check if an array contains an element. Args: @@ -1599,7 +1624,7 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var): @var_operation def repeat_array_operation( - array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int + array: Var[ARRAY_VAR_TYPE], count: Var[int] ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: """Repeat an array a number of times. @@ -1610,20 +1635,34 @@ def repeat_array_operation( Returns: The repeated array. """ + + def type_computer(*args: Var): + if not args: + return ( + ReflexCallable[[List[Any], int], List[Any]], + type_computer, + ) + if len(args) == 1: + return ( + ReflexCallable[[int], args[0]._var_type], + functools.partial(type_computer, *args), + ) + return (ReflexCallable[[], args[0]._var_type], None) + return var_operation_return( js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})", - var_type=array._var_type, + type_computer=type_computer, ) if TYPE_CHECKING: - from .function import FunctionVar + pass @var_operation def map_array_operation( - array: ArrayVar[ARRAY_VAR_TYPE], - function: FunctionVar, + array: Var[ARRAY_VAR_TYPE], + function: Var[ReflexCallable], ): """Map a function over an array. @@ -1634,14 +1673,33 @@ def map_array_operation( Returns: The mapped array. """ + + def type_computer(*args: Var): + if not args: + return ( + ReflexCallable[[List[Any], ReflexCallable], List[Any]], + type_computer, + ) + if len(args) == 1: + return ( + ReflexCallable[[ReflexCallable], List[Any]], + functools.partial(type_computer, *args), + ) + return (ReflexCallable[[], List[args[0]._var_type]], None) + return var_operation_return( - js_expression=f"{array}.map({function})", var_type=List[Any] + js_expression=f"{array}.map({function})", + type_computer=nary_type_computer( + ReflexCallable[[List[Any], ReflexCallable], List[Any]], + ReflexCallable[[ReflexCallable], List[Any]], + computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]], + ), ) @var_operation def array_concat_operation( - lhs: ArrayVar[ARRAY_VAR_TYPE], rhs: ArrayVar[ARRAY_VAR_TYPE] + lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE] ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: """Concatenate two arrays. @@ -1654,7 +1712,11 @@ def array_concat_operation( """ return var_operation_return( js_expression=f"[...{lhs}, ...{rhs}]", - var_type=Union[lhs._var_type, rhs._var_type], + type_computer=nary_type_computer( + ReflexCallable[[List[Any], List[Any]], List[Any]], + ReflexCallable[[List[Any]], List[Any]], + computer=lambda args: unionize(args[0]._var_type, args[1]._var_type), + ), ) diff --git a/tests/units/test_var.py b/tests/units/test_var.py index 4940246e7..69732740f 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -963,11 +963,11 @@ def test_function_var(): def test_var_operation(): @var_operation - def add(a: Union[NumberVar, int], b: Union[NumberVar, int]): + def add(a: Var[int], b: Var[int]): return var_operation_return(js_expression=f"({a} + {b})", var_type=int) assert str(add(1, 2)) == "(1 + 2)" - assert str(add(a=4, b=-9)) == "(4 + -9)" + assert str(add(4, -9)) == "(4 + -9)" five = LiteralNumberVar.create(5) seven = add(2, five)