implement type computers

This commit is contained in:
Khaleel Al-Adhami 2024-11-13 18:17:53 -08:00
parent ebc81811c0
commit f4aa1f58c3
6 changed files with 699 additions and 317 deletions

View File

@ -14,7 +14,7 @@ import re
import string import string
import sys import sys
import warnings import warnings
from types import CodeType, FunctionType from types import CodeType, EllipsisType, FunctionType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -26,7 +26,6 @@ from typing import (
Iterable, Iterable,
List, List,
Literal, Literal,
NoReturn,
Optional, Optional,
Set, Set,
Tuple, Tuple,
@ -38,7 +37,14 @@ from typing import (
overload, overload,
) )
from typing_extensions import ParamSpec, TypeGuard, deprecated, get_type_hints, override from typing_extensions import (
ParamSpec,
Protocol,
TypeGuard,
deprecated,
get_type_hints,
override,
)
from reflex import constants from reflex import constants
from reflex.base import Base from reflex.base import Base
@ -69,6 +75,7 @@ from reflex.utils.types import (
if TYPE_CHECKING: if TYPE_CHECKING:
from reflex.state import BaseState from reflex.state import BaseState
from .function import ArgsFunctionOperation, ReflexCallable
from .number import BooleanVar, NumberVar from .number import BooleanVar, NumberVar
from .object import ObjectVar from .object import ObjectVar
from .sequence import ArrayVar, StringVar from .sequence import ArrayVar, StringVar
@ -79,6 +86,36 @@ OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
warnings.filterwarnings("ignore", message="fields may not start with an underscore") warnings.filterwarnings("ignore", message="fields may not start with an underscore")
P = ParamSpec("P")
R = TypeVar("R")
class ReflexCallable(Protocol[P, R]):
"""Protocol for a callable."""
__call__: Callable[P, R]
def unwrap_reflex_callalbe(
callable_type: GenericType,
) -> Tuple[Union[EllipsisType, Tuple[GenericType, ...]], GenericType]:
"""Unwrap the ReflexCallable type.
Args:
callable_type: The ReflexCallable type to unwrap.
Returns:
The unwrapped ReflexCallable type.
"""
if callable_type is ReflexCallable:
return Ellipsis, Any
if get_origin(callable_type) is not ReflexCallable:
return Ellipsis, Any
args = get_args(callable_type)
if not args or len(args) != 2:
return Ellipsis, Any
return args
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
@ -409,9 +446,11 @@ class Var(Generic[VAR_TYPE]):
if _var_data or _js_expr != self._js_expr: if _var_data or _js_expr != self._js_expr:
self.__init__( self.__init__(
_js_expr=_js_expr, **{
_var_type=self._var_type, **dataclasses.asdict(self),
_var_data=VarData.merge(self._var_data, _var_data), "_js_expr": _js_expr,
"_var_data": VarData.merge(self._var_data, _var_data),
}
) )
def __hash__(self) -> int: def __hash__(self) -> int:
@ -690,6 +729,12 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ... def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
@overload
def guess_type(self: Var[list] | Var[tuple] | Var[set]) -> ArrayVar: ...
@overload
def guess_type(self: Var[dict]) -> ObjectVar[dict]: ...
@overload @overload
def guess_type(self) -> Self: ... def guess_type(self) -> Self: ...
@ -1413,71 +1458,94 @@ def get_python_literal(value: Union[LiteralVar, Any]) -> Any | None:
return value return value
def validate_arg(type_hint: GenericType) -> Callable[[Any], bool]:
"""Create a validator for an argument.
Args:
type_hint: The type hint of the argument.
Returns:
The validator.
"""
def validate(value: Any):
return True
return validate
P = ParamSpec("P") P = ParamSpec("P")
T = TypeVar("T") T = TypeVar("T")
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
# NoReturn is used to match CustomVarOperationReturn with no type hint. class TypeComputer(Protocol):
@overload """A protocol for type computers."""
def var_operation(
func: Callable[P, CustomVarOperationReturn[NoReturn]], def __call__(self, *args: Var) -> Tuple[GenericType, Union[TypeComputer, None]]:
) -> Callable[P, Var]: ... """Compute the type of the operation.
Args:
*args: The arguments to compute the type of.
Returns:
The type of the operation.
"""
...
@overload @overload
def var_operation( def var_operation(
func: Callable[P, CustomVarOperationReturn[bool]], func: Callable[[], CustomVarOperationReturn[T]],
) -> Callable[P, BooleanVar]: ... ) -> ArgsFunctionOperation[ReflexCallable[[], T]]: ...
NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float])
@overload @overload
def var_operation( def var_operation(
func: Callable[P, CustomVarOperationReturn[NUMBER_T]], func: Callable[[Var[V1]], CustomVarOperationReturn[T]],
) -> Callable[P, NumberVar[NUMBER_T]]: ... ) -> ArgsFunctionOperation[ReflexCallable[[V1], T]]: ...
@overload @overload
def var_operation( def var_operation(
func: Callable[P, CustomVarOperationReturn[str]], func: Callable[[Var[V1], Var[V2]], CustomVarOperationReturn[T]],
) -> Callable[P, StringVar]: ... ) -> ArgsFunctionOperation[ReflexCallable[[V1, V2], T]]: ...
LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
@overload @overload
def var_operation( def var_operation(
func: Callable[P, CustomVarOperationReturn[LIST_T]], func: Callable[[Var[V1], Var[V2], Var[V3]], CustomVarOperationReturn[T]],
) -> Callable[P, ArrayVar[LIST_T]]: ... ) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3], T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
@overload @overload
def var_operation( def var_operation(
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]], func: Callable[[Var[V1], Var[V2], Var[V3], Var[V4]], CustomVarOperationReturn[T]],
) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ... ) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4], T]]: ...
@overload @overload
def var_operation( def var_operation(
func: Callable[P, CustomVarOperationReturn[T]], func: Callable[
) -> Callable[P, Var[T]]: ... [Var[V1], Var[V2], Var[V3], Var[V4], Var[V5]],
CustomVarOperationReturn[T],
],
) -> ArgsFunctionOperation[ReflexCallable[[V1, V2, V3, V4, V5], T]]: ...
def var_operation( def var_operation(
func: Callable[P, CustomVarOperationReturn[T]], func: Callable[..., CustomVarOperationReturn[T]],
) -> Callable[P, Var[T]]: ) -> ArgsFunctionOperation:
"""Decorator for creating a var operation. """Decorator for creating a var operation.
Example: Example:
```python ```python
@var_operation @var_operation
def add(a: NumberVar, b: NumberVar): def add(a: Var[int], b: Var[int]):
return custom_var_operation(f"{a} + {b}") return custom_var_operation(f"{a} + {b}")
``` ```
@ -1487,26 +1555,61 @@ def var_operation(
Returns: Returns:
The decorated function. The decorated function.
""" """
from .function import ArgsFunctionOperation, ReflexCallable
@functools.wraps(func) func_name = func.__name__
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Var[T]:
func_args = list(inspect.signature(func).parameters)
args_vars = {
func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg)
for i, arg in enumerate(args)
}
kwargs_vars = {
key: LiteralVar.create(value) if not isinstance(value, Var) else value
for key, value in kwargs.items()
}
return CustomVarOperation.create( func_arg_spec = inspect.getfullargspec(func)
name=func.__name__,
args=tuple(list(args_vars.items()) + list(kwargs_vars.items())),
return_var=func(*args_vars.values(), **kwargs_vars), # type: ignore
).guess_type()
return wrapper if func_arg_spec.kwonlyargs:
raise TypeError(f"Function {func_name} cannot have keyword-only arguments.")
if func_arg_spec.varargs:
raise TypeError(f"Function {func_name} cannot have variable arguments.")
arg_names = func_arg_spec.args
type_hints = get_type_hints(func)
if not all(
(get_origin((type_hint := type_hints.get(arg_name, Any))) or type_hint) is Var
and len(get_args(type_hint)) <= 1
for arg_name in arg_names
):
raise TypeError(
f"Function {func_name} must have type hints of the form `Var[Type]`."
)
args_with_type_hints = tuple(
(arg_name, (args[0] if (args := get_args(type_hints[arg_name])) else Any))
for arg_name in arg_names
)
arg_vars = tuple(
(
Var("_" + arg_name, _var_type=arg_python_type)
if not isinstance(arg_python_type, TypeVar)
else Var("_" + arg_name)
)
for arg_name, arg_python_type in args_with_type_hints
)
custom_operation_return = func(*arg_vars)
args_operation = ArgsFunctionOperation.create(
tuple(map(str, arg_vars)),
custom_operation_return,
validators=tuple(
validate_arg(type_hints.get(arg_name, Any)) for arg_name in arg_names
),
function_name=func_name,
type_computer=custom_operation_return._type_computer,
_var_type=ReflexCallable[
tuple(arg_python_type for _, arg_python_type in args_with_type_hints),
custom_operation_return._var_type,
],
)
return args_operation
def figure_out_type(value: Any) -> types.GenericType: def figure_out_type(value: Any) -> types.GenericType:
@ -1621,66 +1724,6 @@ class CachedVarOperation:
) )
def and_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return _and_operation(a, b) # type: ignore
@var_operation
def _and_operation(a: Var, b: Var):
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return var_operation_return(
js_expression=f"({a} && {b})",
var_type=unionize(a._var_type, b._var_type),
)
def or_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return _or_operation(a, b) # type: ignore
@var_operation
def _or_operation(a: Var, b: Var):
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return var_operation_return(
js_expression=f"({a} || {b})",
var_type=unionize(a._var_type, b._var_type),
)
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
@ -2289,14 +2332,22 @@ def computed_var(
RETURN = TypeVar("RETURN") RETURN = TypeVar("RETURN")
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class CustomVarOperationReturn(Var[RETURN]): class CustomVarOperationReturn(Var[RETURN]):
"""Base class for custom var operations.""" """Base class for custom var operations."""
_type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
@classmethod @classmethod
def create( def create(
cls, cls,
js_expression: str, js_expression: str,
_var_type: Type[RETURN] | None = None, _var_type: Type[RETURN] | None = None,
_type_computer: Optional[TypeComputer] = None,
_var_data: VarData | None = None, _var_data: VarData | None = None,
) -> CustomVarOperationReturn[RETURN]: ) -> CustomVarOperationReturn[RETURN]:
"""Create a CustomVarOperation. """Create a CustomVarOperation.
@ -2304,6 +2355,7 @@ class CustomVarOperationReturn(Var[RETURN]):
Args: Args:
js_expression: The JavaScript expression to evaluate. js_expression: The JavaScript expression to evaluate.
_var_type: The type of the var. _var_type: The type of the var.
_type_computer: A function to compute the type of the var given the arguments.
_var_data: Additional hooks and imports associated with the Var. _var_data: Additional hooks and imports associated with the Var.
Returns: Returns:
@ -2312,6 +2364,7 @@ class CustomVarOperationReturn(Var[RETURN]):
return CustomVarOperationReturn( return CustomVarOperationReturn(
_js_expr=js_expression, _js_expr=js_expression,
_var_type=_var_type or Any, _var_type=_var_type or Any,
_type_computer=_type_computer,
_var_data=_var_data, _var_data=_var_data,
) )
@ -2319,6 +2372,7 @@ class CustomVarOperationReturn(Var[RETURN]):
def var_operation_return( def var_operation_return(
js_expression: str, js_expression: str,
var_type: Type[RETURN] | None = None, var_type: Type[RETURN] | None = None,
type_computer: Optional[TypeComputer] = None,
var_data: VarData | None = None, var_data: VarData | None = None,
) -> CustomVarOperationReturn[RETURN]: ) -> CustomVarOperationReturn[RETURN]:
"""Shortcut for creating a CustomVarOperationReturn. """Shortcut for creating a CustomVarOperationReturn.
@ -2326,15 +2380,17 @@ def var_operation_return(
Args: Args:
js_expression: The JavaScript expression to evaluate. js_expression: The JavaScript expression to evaluate.
var_type: The type of the var. var_type: The type of the var.
type_computer: A function to compute the type of the var given the arguments.
var_data: Additional hooks and imports associated with the Var. var_data: Additional hooks and imports associated with the Var.
Returns: Returns:
The CustomVarOperationReturn. The CustomVarOperationReturn.
""" """
return CustomVarOperationReturn.create( return CustomVarOperationReturn.create(
js_expression, js_expression=js_expression,
var_type, _var_type=var_type,
var_data, _type_computer=type_computer,
_var_data=var_data,
) )
@ -2942,3 +2998,157 @@ def field(value: T) -> Field[T]:
The Field. The Field.
""" """
return value # type: ignore return value # type: ignore
def and_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return _and_operation(a, b) # type: ignore
@var_operation
def _and_operation(a: Var, b: Var):
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
def type_computer(*args: Var):
if not args:
return (ReflexCallable[[Any, Any], Any], type_computer)
if len(args) == 1:
return (
ReflexCallable[[Any], Any],
functools.partial(type_computer, args[0]),
)
return (
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
None,
)
return var_operation_return(
js_expression=f"({a} && {b})",
type_computer=type_computer,
)
def or_operation(a: Var | Any, b: Var | Any) -> Var:
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return _or_operation(a, b) # type: ignore
@var_operation
def _or_operation(a: Var, b: Var):
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
def type_computer(*args: Var):
if not args:
return (ReflexCallable[[Any, Any], Any], type_computer)
if len(args) == 1:
return (
ReflexCallable[[Any], Any],
functools.partial(type_computer, args[0]),
)
return (
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
None,
)
return var_operation_return(
js_expression=f"({a} || {b})",
type_computer=type_computer,
)
def passthrough_unary_type_computer(no_args: GenericType) -> TypeComputer:
"""Create a type computer for unary operations.
Args:
no_args: The type to return when no arguments are provided.
Returns:
The type computer.
"""
def type_computer(*args: Var):
if not args:
return (no_args, type_computer)
return (ReflexCallable[[], args[0]._var_type], None)
return type_computer
def unary_type_computer(
no_args: GenericType, computer: Callable[[Var], GenericType]
) -> TypeComputer:
"""Create a type computer for unary operations.
Args:
no_args: The type to return when no arguments are provided.
computer: The function to compute the type.
Returns:
The type computer.
"""
def type_computer(*args: Var):
if not args:
return (no_args, type_computer)
return (ReflexCallable[[], computer(args[0])], None)
return type_computer
def nary_type_computer(
*types: GenericType, computer: Callable[..., GenericType]
) -> TypeComputer:
"""Create a type computer for n-ary operations.
Args:
types: The types to return when no arguments are provided.
computer: The function to compute the type.
Returns:
The type computer.
"""
def type_computer(*args: Var):
if len(args) != len(types):
return (
ReflexCallable[[], types[len(args)]],
functools.partial(type_computer, *args),
)
return (
ReflexCallable[[], computer(args)],
None,
)
return type_computer

View File

@ -6,28 +6,31 @@ import dataclasses
import sys import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar from typing_extensions import Concatenate, Generic, ParamSpec, TypeVar
from reflex.utils import format from reflex.utils import format
from reflex.utils.exceptions import VarTypeError from reflex.utils.exceptions import VarTypeError
from reflex.utils.types import GenericType from reflex.utils.types import GenericType
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock from .base import (
CachedVarOperation,
LiteralVar,
ReflexCallable,
TypeComputer,
Var,
VarData,
cached_property_no_lock,
unwrap_reflex_callalbe,
)
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R")
V1 = TypeVar("V1") V1 = TypeVar("V1")
V2 = TypeVar("V2") V2 = TypeVar("V2")
V3 = TypeVar("V3") V3 = TypeVar("V3")
V4 = TypeVar("V4") V4 = TypeVar("V4")
V5 = TypeVar("V5") V5 = TypeVar("V5")
V6 = TypeVar("V6") V6 = TypeVar("V6")
R = TypeVar("R")
class ReflexCallable(Protocol[P, R]):
"""Protocol for a callable."""
__call__: Callable[P, R]
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True) CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
@ -112,20 +115,37 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
""" """
if not args: if not args:
return self return self
args = tuple(map(LiteralVar.create, args))
remaining_validators = self._pre_check(*args) remaining_validators = self._pre_check(*args)
partial_types, type_computer = self._partial_type(*args)
if self.__call__ is self.partial: if self.__call__ is self.partial:
# if the default behavior is partial, we should return a new partial function # if the default behavior is partial, we should return a new partial function
return ArgsFunctionOperationBuilder.create( return ArgsFunctionOperationBuilder.create(
(), (),
VarOperationCall.create(self, *args, Var(_js_expr="...args")), VarOperationCall.create(
self,
*args,
Var(_js_expr="...args"),
_var_type=self._return_type(*args),
),
rest="args", rest="args",
validators=remaining_validators, validators=remaining_validators,
type_computer=type_computer,
_var_type=partial_types,
) )
return ArgsFunctionOperation.create( return ArgsFunctionOperation.create(
(), (),
VarOperationCall.create(self, *args, Var(_js_expr="...args")), VarOperationCall.create(
self, *args, Var(_js_expr="...args"), _var_type=self._return_type(*args)
),
rest="args", rest="args",
validators=remaining_validators, validators=remaining_validators,
type_computer=type_computer,
_var_type=partial_types,
) )
@overload @overload
@ -194,9 +214,56 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
Returns: Returns:
The function call operation. The function call operation.
Raises:
VarTypeError: If the number of arguments is invalid
""" """
arg_len = self._arg_len()
if arg_len is not None and len(args) != arg_len:
raise VarTypeError(f"Invalid number of arguments provided to {str(self)}")
args = tuple(map(LiteralVar.create, args))
self._pre_check(*args) self._pre_check(*args)
return VarOperationCall.create(self, *args).guess_type() return_type = self._return_type(*args)
return VarOperationCall.create(self, *args, _var_type=return_type).guess_type()
def _partial_type(
self, *args: Var | Any
) -> Tuple[GenericType, Optional[TypeComputer]]:
"""Override the type of the function call with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The overridden type of the function call.
"""
args_types, return_type = unwrap_reflex_callalbe(self._var_type)
if isinstance(args_types, tuple):
return ReflexCallable[[*args_types[len(args) :]], return_type], None
return ReflexCallable[..., return_type], None
def _arg_len(self) -> int | None:
"""Get the number of arguments the function takes.
Returns:
The number of arguments the function takes.
"""
args_types, _ = unwrap_reflex_callalbe(self._var_type)
if isinstance(args_types, tuple):
return len(args_types)
return None
def _return_type(self, *args: Var | Any) -> GenericType:
"""Override the type of the function call with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
The overridden type of the function call.
"""
partial_types, _ = self._partial_type(*args)
return unwrap_reflex_callalbe(partial_types)[1]
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments. """Check if the function can be called with the given arguments.
@ -343,11 +410,12 @@ class FunctionArgs:
def format_args_function_operation( def format_args_function_operation(
args: FunctionArgs, return_expr: Var | Any, explicit_return: bool self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
) -> str: ) -> str:
"""Format an args function operation. """Format an args function operation.
Args: Args:
self: The function operation.
args: The function arguments. args: The function arguments.
return_expr: The return expression. return_expr: The return expression.
explicit_return: Whether to use explicit return syntax. explicit_return: Whether to use explicit return syntax.
@ -356,26 +424,76 @@ def format_args_function_operation(
The formatted args function operation. The formatted args function operation.
""" """
arg_names_str = ", ".join( arg_names_str = ", ".join(
[arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args] [
+ ([f"...{args.rest}"] if args.rest else []) arg if isinstance(arg, str) else arg.to_javascript()
for arg in self._args.args
]
+ ([f"...{self._args.rest}"] if self._args.rest else [])
) )
return_expr_str = str(LiteralVar.create(return_expr)) return_expr_str = str(LiteralVar.create(self._return_expr))
# Wrap return expression in curly braces if explicit return syntax is used. # Wrap return expression in curly braces if explicit return syntax is used.
return_expr_str_wrapped = ( return_expr_str_wrapped = (
format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str format.wrap(return_expr_str, "{", "}")
if self._explicit_return
else return_expr_str
) )
return f"(({arg_names_str}) => {return_expr_str_wrapped})" return f"(({arg_names_str}) => {return_expr_str_wrapped})"
def pre_check_args(
self: ArgsFunctionOperation | ArgsFunctionOperationBuilder, *args: Var | Any
) -> Tuple[Callable[[Any], bool], ...]:
"""Check if the function can be called with the given arguments.
Args:
self: The function operation.
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
def figure_partial_type(
self: ArgsFunctionOperation | ArgsFunctionOperationBuilder,
*args: Var | Any,
) -> Tuple[GenericType, Optional[TypeComputer]]:
"""Figure out the return type of the function.
Args:
self: The function operation.
*args: The arguments to call the function with.
Returns:
The return type of the function.
"""
return (
self._type_computer(*args)
if self._type_computer is not None
else FunctionVar._partial_type(self, *args)
)
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ArgsFunctionOperation(CachedVarOperation, FunctionVar): class ArgsFunctionOperation(CachedVarOperation, FunctionVar[CALLABLE_TYPE]):
"""Base class for immutable function defined via arguments and return expression.""" """Base class for immutable function defined via arguments and return expression."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@ -384,39 +502,14 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
) )
_return_expr: Union[Var, Any] = dataclasses.field(default=None) _return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="") _function_name: str = dataclasses.field(default="")
_type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False) _explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock _cached_var_name = cached_property_no_lock(format_args_function_operation)
def _cached_var_name(self) -> str:
"""The name of the var.
Returns: _pre_check = pre_check_args
The name of the var.
"""
return format_args_function_operation(
self._args, self._return_expr, self._explicit_return
)
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: _partial_type = figure_partial_type
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
@classmethod @classmethod
def create( def create(
@ -427,6 +520,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
validators: Sequence[Callable[[Any], bool]] = (), validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "", function_name: str = "",
explicit_return: bool = False, explicit_return: bool = False,
type_computer: Optional[TypeComputer] = None,
_var_type: GenericType = Callable, _var_type: GenericType = Callable,
_var_data: VarData | None = None, _var_data: VarData | None = None,
): ):
@ -439,6 +533,8 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
validators: The validators for the arguments. validators: The validators for the arguments.
function_name: The name of the function. function_name: The name of the function.
explicit_return: Whether to use explicit return syntax. explicit_return: Whether to use explicit return syntax.
type_computer: A function to compute the return type.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var. _var_data: Additional hooks and imports associated with the Var.
Returns: Returns:
@ -453,6 +549,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
_validators=tuple(validators), _validators=tuple(validators),
_return_expr=return_expr, _return_expr=return_expr,
_explicit_return=explicit_return, _explicit_return=explicit_return,
_type_computer=type_computer,
) )
@ -461,7 +558,9 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): class ArgsFunctionOperationBuilder(
CachedVarOperation, BuilderFunctionVar[CALLABLE_TYPE]
):
"""Base class for immutable function defined via arguments and return expression with the builder pattern.""" """Base class for immutable function defined via arguments and return expression with the builder pattern."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs) _args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
@ -470,39 +569,14 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
) )
_return_expr: Union[Var, Any] = dataclasses.field(default=None) _return_expr: Union[Var, Any] = dataclasses.field(default=None)
_function_name: str = dataclasses.field(default="") _function_name: str = dataclasses.field(default="")
_type_computer: Optional[TypeComputer] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False) _explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock _cached_var_name = cached_property_no_lock(format_args_function_operation)
def _cached_var_name(self) -> str:
"""The name of the var.
Returns: _pre_check = pre_check_args
The name of the var.
"""
return format_args_function_operation(
self._args, self._return_expr, self._explicit_return
)
def _pre_check(self, *args: Var | Any) -> Tuple[Callable[[Any], bool], ...]: _partial_type = figure_partial_type
"""Check if the function can be called with the given arguments.
Args:
*args: The arguments to call the function with.
Returns:
True if the function can be called with the given arguments.
"""
for i, (validator, arg) in enumerate(zip(self._validators, args)):
if not validator(arg):
arg_name = self._args.args[i] if i < len(self._args.args) else None
if arg_name is not None:
raise VarTypeError(
f"Invalid argument {str(arg)} provided to {arg_name} in {self._function_name or 'var operation'}"
)
raise VarTypeError(
f"Invalid argument {str(arg)} provided to argument {i} in {self._function_name or 'var operation'}"
)
return self._validators[len(args) :]
@classmethod @classmethod
def create( def create(
@ -513,6 +587,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
validators: Sequence[Callable[[Any], bool]] = (), validators: Sequence[Callable[[Any], bool]] = (),
function_name: str = "", function_name: str = "",
explicit_return: bool = False, explicit_return: bool = False,
type_computer: Optional[TypeComputer] = None,
_var_type: GenericType = Callable, _var_type: GenericType = Callable,
_var_data: VarData | None = None, _var_data: VarData | None = None,
): ):
@ -525,6 +600,8 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
validators: The validators for the arguments. validators: The validators for the arguments.
function_name: The name of the function. function_name: The name of the function.
explicit_return: Whether to use explicit return syntax. explicit_return: Whether to use explicit return syntax.
type_computer: A function to compute the return type.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var. _var_data: Additional hooks and imports associated with the Var.
Returns: Returns:
@ -539,6 +616,7 @@ class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
_validators=tuple(validators), _validators=tuple(validators),
_return_expr=return_expr, _return_expr=return_expr,
_explicit_return=explicit_return, _explicit_return=explicit_return,
_type_computer=type_computer,
) )

View File

@ -3,19 +3,11 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import functools
import json import json
import math import math
import sys import sys
from typing import ( from typing import TYPE_CHECKING, Any, Callable, NoReturn, TypeVar, Union, overload
TYPE_CHECKING,
Any,
Callable,
NoReturn,
Type,
TypeVar,
Union,
overload,
)
from reflex.constants.base import Dirs from reflex.constants.base import Dirs
from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
@ -25,8 +17,11 @@ from reflex.utils.types import is_optional
from .base import ( from .base import (
CustomVarOperationReturn, CustomVarOperationReturn,
LiteralVar, LiteralVar,
ReflexCallable,
Var, Var,
VarData, VarData,
nary_type_computer,
passthrough_unary_type_computer,
unionize, unionize,
var_operation, var_operation,
var_operation_return, var_operation_return,
@ -544,8 +539,8 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
def binary_number_operation( def binary_number_operation(
func: Callable[[NumberVar, NumberVar], str], func: Callable[[Var[int | float], Var[int | float]], str],
) -> Callable[[number_types, number_types], NumberVar]: ):
"""Decorator to create a binary number operation. """Decorator to create a binary number operation.
Args: Args:
@ -555,30 +550,37 @@ def binary_number_operation(
The binary number operation. The binary number operation.
""" """
@var_operation def operation(
def operation(lhs: NumberVar, rhs: NumberVar): lhs: Var[int | float], rhs: Var[int | float]
) -> CustomVarOperationReturn[int | float]:
def type_computer(*args: Var):
if not args:
return (
ReflexCallable[[int | float, int | float], int | float],
type_computer,
)
if len(args) == 1:
return (
ReflexCallable[[int | float], int | float],
functools.partial(type_computer, args[0]),
)
return (
ReflexCallable[[], unionize(args[0]._var_type, args[1]._var_type)],
None,
)
return var_operation_return( return var_operation_return(
js_expression=func(lhs, rhs), js_expression=func(lhs, rhs),
var_type=unionize(lhs._var_type, rhs._var_type), type_computer=type_computer,
) )
def wrapper(lhs: number_types, rhs: number_types) -> NumberVar: operation.__name__ = func.__name__
"""Create the binary number operation.
Args: return var_operation(operation)
lhs: The first number.
rhs: The second number.
Returns:
The binary number operation.
"""
return operation(lhs, rhs) # type: ignore
return wrapper
@binary_number_operation @binary_number_operation
def number_add_operation(lhs: NumberVar, rhs: NumberVar): def number_add_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Add two numbers. """Add two numbers.
Args: Args:
@ -592,7 +594,7 @@ def number_add_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation @binary_number_operation
def number_subtract_operation(lhs: NumberVar, rhs: NumberVar): def number_subtract_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Subtract two numbers. """Subtract two numbers.
Args: Args:
@ -605,8 +607,15 @@ def number_subtract_operation(lhs: NumberVar, rhs: NumberVar):
return f"({lhs} - {rhs})" return f"({lhs} - {rhs})"
unary_operation_type_computer = passthrough_unary_type_computer(
ReflexCallable[[int | float], int | float]
)
@var_operation @var_operation
def number_abs_operation(value: NumberVar): def number_abs_operation(
value: Var[int | float],
) -> CustomVarOperationReturn[int | float]:
"""Get the absolute value of the number. """Get the absolute value of the number.
Args: Args:
@ -616,12 +625,12 @@ def number_abs_operation(value: NumberVar):
The number absolute operation. The number absolute operation.
""" """
return var_operation_return( return var_operation_return(
js_expression=f"Math.abs({value})", var_type=value._var_type js_expression=f"Math.abs({value})", type_computer=unary_operation_type_computer
) )
@binary_number_operation @binary_number_operation
def number_multiply_operation(lhs: NumberVar, rhs: NumberVar): def number_multiply_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Multiply two numbers. """Multiply two numbers.
Args: Args:
@ -636,7 +645,7 @@ def number_multiply_operation(lhs: NumberVar, rhs: NumberVar):
@var_operation @var_operation
def number_negate_operation( def number_negate_operation(
value: NumberVar[NUMBER_T], value: Var[NUMBER_T],
) -> CustomVarOperationReturn[NUMBER_T]: ) -> CustomVarOperationReturn[NUMBER_T]:
"""Negate the number. """Negate the number.
@ -646,11 +655,13 @@ def number_negate_operation(
Returns: Returns:
The number negation operation. The number negation operation.
""" """
return var_operation_return(js_expression=f"-({value})", var_type=value._var_type) return var_operation_return(
js_expression=f"-({value})", type_computer=unary_operation_type_computer
)
@binary_number_operation @binary_number_operation
def number_true_division_operation(lhs: NumberVar, rhs: NumberVar): def number_true_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Divide two numbers. """Divide two numbers.
Args: Args:
@ -664,7 +675,7 @@ def number_true_division_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation @binary_number_operation
def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar): def number_floor_division_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Floor divide two numbers. """Floor divide two numbers.
Args: Args:
@ -678,7 +689,7 @@ def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation @binary_number_operation
def number_modulo_operation(lhs: NumberVar, rhs: NumberVar): def number_modulo_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Modulo two numbers. """Modulo two numbers.
Args: Args:
@ -692,7 +703,7 @@ def number_modulo_operation(lhs: NumberVar, rhs: NumberVar):
@binary_number_operation @binary_number_operation
def number_exponent_operation(lhs: NumberVar, rhs: NumberVar): def number_exponent_operation(lhs: Var[int | float], rhs: Var[int | float]):
"""Exponentiate two numbers. """Exponentiate two numbers.
Args: Args:
@ -706,7 +717,7 @@ def number_exponent_operation(lhs: NumberVar, rhs: NumberVar):
@var_operation @var_operation
def number_round_operation(value: NumberVar): def number_round_operation(value: Var[int | float]):
"""Round the number. """Round the number.
Args: Args:
@ -719,7 +730,7 @@ def number_round_operation(value: NumberVar):
@var_operation @var_operation
def number_ceil_operation(value: NumberVar): def number_ceil_operation(value: Var[int | float]):
"""Ceil the number. """Ceil the number.
Args: Args:
@ -732,7 +743,7 @@ def number_ceil_operation(value: NumberVar):
@var_operation @var_operation
def number_floor_operation(value: NumberVar): def number_floor_operation(value: Var[int | float]):
"""Floor the number. """Floor the number.
Args: Args:
@ -745,7 +756,7 @@ def number_floor_operation(value: NumberVar):
@var_operation @var_operation
def number_trunc_operation(value: NumberVar): def number_trunc_operation(value: Var[int | float]):
"""Trunc the number. """Trunc the number.
Args: Args:
@ -838,7 +849,7 @@ class BooleanVar(NumberVar[bool], python_types=bool):
@var_operation @var_operation
def boolean_to_number_operation(value: BooleanVar): def boolean_to_number_operation(value: Var[bool]):
"""Convert the boolean to a number. """Convert the boolean to a number.
Args: Args:
@ -969,7 +980,7 @@ def not_equal_operation(lhs: Var, rhs: Var):
@var_operation @var_operation
def boolean_not_operation(value: BooleanVar): def boolean_not_operation(value: Var[bool]):
"""Boolean NOT the boolean. """Boolean NOT the boolean.
Args: Args:
@ -1117,7 +1128,7 @@ U = TypeVar("U")
@var_operation @var_operation
def ternary_operation( def ternary_operation(
condition: BooleanVar, if_true: Var[T], if_false: Var[U] condition: Var[bool], if_true: Var[T], if_false: Var[U]
) -> CustomVarOperationReturn[Union[T, U]]: ) -> CustomVarOperationReturn[Union[T, U]]:
"""Create a ternary operation. """Create a ternary operation.
@ -1129,12 +1140,14 @@ def ternary_operation(
Returns: Returns:
The ternary operation. The ternary operation.
""" """
type_value: Union[Type[T], Type[U]] = unionize(
if_true._var_type, if_false._var_type
)
value: CustomVarOperationReturn[Union[T, U]] = var_operation_return( value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
js_expression=f"({condition} ? {if_true} : {if_false})", js_expression=f"({condition} ? {if_true} : {if_false})",
var_type=type_value, type_computer=nary_type_computer(
ReflexCallable[[bool, Any, Any], Any],
ReflexCallable[[Any, Any], Any],
ReflexCallable[[Any], Any],
computer=lambda args: unionize(args[1]._var_type, args[2]._var_type),
),
) )
return value return value

View File

@ -21,15 +21,23 @@ from typing import (
from reflex.utils import types from reflex.utils import types
from reflex.utils.exceptions import VarAttributeError from reflex.utils.exceptions import VarAttributeError
from reflex.utils.types import GenericType, get_attribute_access_type, get_origin from reflex.utils.types import (
GenericType,
get_attribute_access_type,
get_origin,
unionize,
)
from .base import ( from .base import (
CachedVarOperation, CachedVarOperation,
LiteralVar, LiteralVar,
ReflexCallable,
Var, Var,
VarData, VarData,
cached_property_no_lock, cached_property_no_lock,
figure_out_type, figure_out_type,
nary_type_computer,
unary_type_computer,
var_operation, var_operation,
var_operation_return, var_operation_return,
) )
@ -406,7 +414,7 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
@var_operation @var_operation
def object_keys_operation(value: ObjectVar): def object_keys_operation(value: Var):
"""Get the keys of an object. """Get the keys of an object.
Args: Args:
@ -422,7 +430,7 @@ def object_keys_operation(value: ObjectVar):
@var_operation @var_operation
def object_values_operation(value: ObjectVar): def object_values_operation(value: Var):
"""Get the values of an object. """Get the values of an object.
Args: Args:
@ -433,12 +441,15 @@ def object_values_operation(value: ObjectVar):
""" """
return var_operation_return( return var_operation_return(
js_expression=f"Object.values({value})", js_expression=f"Object.values({value})",
var_type=List[value._value_type()], type_computer=unary_type_computer(
ReflexCallable[[Any], List[Any]],
lambda x: List[x.to(ObjectVar)._value_type()],
),
) )
@var_operation @var_operation
def object_entries_operation(value: ObjectVar): def object_entries_operation(value: Var):
"""Get the entries of an object. """Get the entries of an object.
Args: Args:
@ -447,14 +458,18 @@ def object_entries_operation(value: ObjectVar):
Returns: Returns:
The entries of the object. The entries of the object.
""" """
value = value.to(ObjectVar)
return var_operation_return( return var_operation_return(
js_expression=f"Object.entries({value})", js_expression=f"Object.entries({value})",
var_type=List[Tuple[str, value._value_type()]], type_computer=unary_type_computer(
ReflexCallable[[Any], List[Tuple[str, Any]]],
lambda x: List[Tuple[str, x.to(ObjectVar)._value_type()]],
),
) )
@var_operation @var_operation
def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): def object_merge_operation(lhs: Var, rhs: Var):
"""Merge two objects. """Merge two objects.
Args: Args:
@ -466,10 +481,14 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
""" """
return var_operation_return( return var_operation_return(
js_expression=f"({{...{lhs}, ...{rhs}}})", js_expression=f"({{...{lhs}, ...{rhs}}})",
var_type=Dict[ type_computer=nary_type_computer(
Union[lhs._key_type(), rhs._key_type()], ReflexCallable[[Any, Any], Dict[Any, Any]],
Union[lhs._value_type(), rhs._value_type()], ReflexCallable[[Any], Dict[Any, Any]],
], computer=lambda args: Dict[
unionize(*[arg.to(ObjectVar)._key_type() for arg in args]),
unionize(*[arg.to(ObjectVar)._value_type() for arg in args]),
],
),
) )
@ -526,7 +545,7 @@ class ObjectItemOperation(CachedVarOperation, Var):
@var_operation @var_operation
def object_has_own_property_operation(object: ObjectVar, key: Var): def object_has_own_property_operation(object: Var, key: Var):
"""Check if an object has a key. """Check if an object has a key.
Args: Args:

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import functools
import inspect import inspect
import json import json
import re import re
@ -34,6 +35,7 @@ from .base import (
CachedVarOperation, CachedVarOperation,
CustomVarOperationReturn, CustomVarOperationReturn,
LiteralVar, LiteralVar,
ReflexCallable,
Var, Var,
VarData, VarData,
_global_vars, _global_vars,
@ -41,7 +43,10 @@ from .base import (
figure_out_type, figure_out_type,
get_python_literal, get_python_literal,
get_unique_variable_name, get_unique_variable_name,
nary_type_computer,
passthrough_unary_type_computer,
unionize, unionize,
unwrap_reflex_callalbe,
var_operation, var_operation,
var_operation_return, var_operation_return,
) )
@ -353,7 +358,7 @@ class StringVar(Var[STRING_TYPE], python_types=str):
@var_operation @var_operation
def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): def string_lt_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is less than another string. """Check if a string is less than another string.
Args: Args:
@ -367,7 +372,7 @@ def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation @var_operation
def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): def string_gt_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is greater than another string. """Check if a string is greater than another string.
Args: Args:
@ -381,7 +386,7 @@ def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation @var_operation
def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): def string_le_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is less than or equal to another string. """Check if a string is less than or equal to another string.
Args: Args:
@ -395,7 +400,7 @@ def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation @var_operation
def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str): def string_ge_operation(lhs: Var[str], rhs: Var[str]):
"""Check if a string is greater than or equal to another string. """Check if a string is greater than or equal to another string.
Args: Args:
@ -409,7 +414,7 @@ def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
@var_operation @var_operation
def string_lower_operation(string: StringVar[Any]): def string_lower_operation(string: Var[str]):
"""Convert a string to lowercase. """Convert a string to lowercase.
Args: Args:
@ -422,7 +427,7 @@ def string_lower_operation(string: StringVar[Any]):
@var_operation @var_operation
def string_upper_operation(string: StringVar[Any]): def string_upper_operation(string: Var[str]):
"""Convert a string to uppercase. """Convert a string to uppercase.
Args: Args:
@ -435,7 +440,7 @@ def string_upper_operation(string: StringVar[Any]):
@var_operation @var_operation
def string_strip_operation(string: StringVar[Any]): def string_strip_operation(string: Var[str]):
"""Strip a string. """Strip a string.
Args: Args:
@ -449,7 +454,7 @@ def string_strip_operation(string: StringVar[Any]):
@var_operation @var_operation
def string_contains_field_operation( def string_contains_field_operation(
haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str haystack: Var[str], needle: Var[str], field: Var[str]
): ):
"""Check if a string contains another string. """Check if a string contains another string.
@ -468,7 +473,7 @@ def string_contains_field_operation(
@var_operation @var_operation
def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str): def string_contains_operation(haystack: Var[str], needle: Var[str]):
"""Check if a string contains another string. """Check if a string contains another string.
Args: Args:
@ -484,9 +489,7 @@ def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] |
@var_operation @var_operation
def string_starts_with_operation( def string_starts_with_operation(full_string: Var[str], prefix: Var[str]):
full_string: StringVar[Any], prefix: StringVar[Any] | str
):
"""Check if a string starts with a prefix. """Check if a string starts with a prefix.
Args: Args:
@ -502,7 +505,7 @@ def string_starts_with_operation(
@var_operation @var_operation
def string_item_operation(string: StringVar[Any], index: NumberVar | int): def string_item_operation(string: Var[str], index: Var[int]):
"""Get an item from a string. """Get an item from a string.
Args: Args:
@ -515,23 +518,9 @@ def string_item_operation(string: StringVar[Any], index: NumberVar | int):
return var_operation_return(js_expression=f"{string}.at({index})", var_type=str) return var_operation_return(js_expression=f"{string}.at({index})", var_type=str)
@var_operation
def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""):
"""Join the elements of an array.
Args:
array: The array.
sep: The separator.
Returns:
The joined elements.
"""
return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
@var_operation @var_operation
def string_replace_operation( def string_replace_operation(
string: StringVar, search_value: StringVar | str, new_value: StringVar | str string: Var[str], search_value: Var[str], new_value: Var[str]
): ):
"""Replace a string with a value. """Replace a string with a value.
@ -1046,7 +1035,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
Returns: Returns:
The array pluck operation. The array pluck operation.
""" """
return array_pluck_operation(self, field) return array_pluck_operation(self, field).guess_type()
@overload @overload
def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ... def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: ...
@ -1300,7 +1289,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
@var_operation @var_operation
def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""): def string_split_operation(string: Var[str], sep: Var[str]):
"""Split a string. """Split a string.
Args: Args:
@ -1394,9 +1383,9 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar):
@var_operation @var_operation
def array_pluck_operation( def array_pluck_operation(
array: ArrayVar[ARRAY_VAR_TYPE], array: Var[ARRAY_VAR_TYPE],
field: StringVar | str, field: Var[str],
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: ) -> CustomVarOperationReturn[List]:
"""Pluck a field from an array of objects. """Pluck a field from an array of objects.
Args: Args:
@ -1408,13 +1397,27 @@ def array_pluck_operation(
""" """
return var_operation_return( return var_operation_return(
js_expression=f"{array}.map(e=>e?.[{field}])", js_expression=f"{array}.map(e=>e?.[{field}])",
var_type=array._var_type, var_type=List[Any],
) )
@var_operation
def array_join_operation(array: Var[ARRAY_VAR_TYPE], sep: Var[str]):
"""Join the elements of an array.
Args:
array: The array.
sep: The separator.
Returns:
The joined elements.
"""
return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str)
@var_operation @var_operation
def array_reverse_operation( def array_reverse_operation(
array: ArrayVar[ARRAY_VAR_TYPE], array: Var[ARRAY_VAR_TYPE],
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
"""Reverse an array. """Reverse an array.
@ -1426,12 +1429,12 @@ def array_reverse_operation(
""" """
return var_operation_return( return var_operation_return(
js_expression=f"{array}.slice().reverse()", js_expression=f"{array}.slice().reverse()",
var_type=array._var_type, type_computer=passthrough_unary_type_computer(ReflexCallable[[List], List]),
) )
@var_operation @var_operation
def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): def array_lt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is less than another array. """Check if an array is less than another array.
Args: Args:
@ -1445,7 +1448,7 @@ def array_lt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation @var_operation
def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): def array_gt_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is greater than another array. """Check if an array is greater than another array.
Args: Args:
@ -1459,7 +1462,7 @@ def array_gt_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation @var_operation
def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): def array_le_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is less than or equal to another array. """Check if an array is less than or equal to another array.
Args: Args:
@ -1473,7 +1476,7 @@ def array_le_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation @var_operation
def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tuple): def array_ge_operation(lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]):
"""Check if an array is greater than or equal to another array. """Check if an array is greater than or equal to another array.
Args: Args:
@ -1487,7 +1490,7 @@ def array_ge_operation(lhs: ArrayVar | list | tuple, rhs: ArrayVar | list | tupl
@var_operation @var_operation
def array_length_operation(array: ArrayVar): def array_length_operation(array: Var[ARRAY_VAR_TYPE]):
"""Get the length of an array. """Get the length of an array.
Args: Args:
@ -1517,7 +1520,7 @@ def is_tuple_type(t: GenericType) -> bool:
@var_operation @var_operation
def array_item_operation(array: ArrayVar, index: NumberVar | int): def array_item_operation(array: Var[ARRAY_VAR_TYPE], index: Var[int]):
"""Get an item from an array. """Get an item from an array.
Args: Args:
@ -1527,23 +1530,45 @@ def array_item_operation(array: ArrayVar, index: NumberVar | int):
Returns: Returns:
The item from the array. The item from the array.
""" """
args = typing.get_args(array._var_type)
if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type): def type_computer(*args):
index_value = int(index._var_value) if len(args) == 0:
element_type = args[index_value % len(args)] return (
else: ReflexCallable[[List[Any], int], Any],
element_type = unionize(*args) functools.partial(type_computer, *args),
)
array = args[0]
array_args = typing.get_args(array._var_type)
if len(args) == 1:
return (
ReflexCallable[[int], unionize(*array_args)],
functools.partial(type_computer, *args),
)
index = args[1]
if (
array_args
and isinstance(index, LiteralNumberVar)
and is_tuple_type(array._var_type)
):
index_value = int(index._var_value)
element_type = array_args[index_value % len(array_args)]
else:
element_type = unionize(*array_args)
return (ReflexCallable[[], element_type], None)
return var_operation_return( return var_operation_return(
js_expression=f"{str(array)}.at({str(index)})", js_expression=f"{str(array)}.at({str(index)})",
var_type=element_type, type_computer=type_computer,
) )
@var_operation @var_operation
def array_range_operation( def array_range_operation(start: Var[int], stop: Var[int], step: Var[int]):
start: NumberVar | int, stop: NumberVar | int, step: NumberVar | int
):
"""Create a range of numbers. """Create a range of numbers.
Args: Args:
@ -1562,7 +1587,7 @@ def array_range_operation(
@var_operation @var_operation
def array_contains_field_operation( def array_contains_field_operation(
haystack: ArrayVar, needle: Any | Var, field: StringVar | str haystack: Var[ARRAY_VAR_TYPE], needle: Var, field: Var[str]
): ):
"""Check if an array contains an element. """Check if an array contains an element.
@ -1581,7 +1606,7 @@ def array_contains_field_operation(
@var_operation @var_operation
def array_contains_operation(haystack: ArrayVar, needle: Any | Var): def array_contains_operation(haystack: Var[ARRAY_VAR_TYPE], needle: Var):
"""Check if an array contains an element. """Check if an array contains an element.
Args: Args:
@ -1599,7 +1624,7 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
@var_operation @var_operation
def repeat_array_operation( def repeat_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int array: Var[ARRAY_VAR_TYPE], count: Var[int]
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
"""Repeat an array a number of times. """Repeat an array a number of times.
@ -1610,20 +1635,34 @@ def repeat_array_operation(
Returns: Returns:
The repeated array. The repeated array.
""" """
def type_computer(*args: Var):
if not args:
return (
ReflexCallable[[List[Any], int], List[Any]],
type_computer,
)
if len(args) == 1:
return (
ReflexCallable[[int], args[0]._var_type],
functools.partial(type_computer, *args),
)
return (ReflexCallable[[], args[0]._var_type], None)
return var_operation_return( return var_operation_return(
js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})", js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})",
var_type=array._var_type, type_computer=type_computer,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .function import FunctionVar pass
@var_operation @var_operation
def map_array_operation( def map_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE], array: Var[ARRAY_VAR_TYPE],
function: FunctionVar, function: Var[ReflexCallable],
): ):
"""Map a function over an array. """Map a function over an array.
@ -1634,14 +1673,33 @@ def map_array_operation(
Returns: Returns:
The mapped array. The mapped array.
""" """
def type_computer(*args: Var):
if not args:
return (
ReflexCallable[[List[Any], ReflexCallable], List[Any]],
type_computer,
)
if len(args) == 1:
return (
ReflexCallable[[ReflexCallable], List[Any]],
functools.partial(type_computer, *args),
)
return (ReflexCallable[[], List[args[0]._var_type]], None)
return var_operation_return( return var_operation_return(
js_expression=f"{array}.map({function})", var_type=List[Any] js_expression=f"{array}.map({function})",
type_computer=nary_type_computer(
ReflexCallable[[List[Any], ReflexCallable], List[Any]],
ReflexCallable[[ReflexCallable], List[Any]],
computer=lambda args: List[unwrap_reflex_callalbe(args[1]._var_type)[1]],
),
) )
@var_operation @var_operation
def array_concat_operation( def array_concat_operation(
lhs: ArrayVar[ARRAY_VAR_TYPE], rhs: ArrayVar[ARRAY_VAR_TYPE] lhs: Var[ARRAY_VAR_TYPE], rhs: Var[ARRAY_VAR_TYPE]
) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: ) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]:
"""Concatenate two arrays. """Concatenate two arrays.
@ -1654,7 +1712,11 @@ def array_concat_operation(
""" """
return var_operation_return( return var_operation_return(
js_expression=f"[...{lhs}, ...{rhs}]", js_expression=f"[...{lhs}, ...{rhs}]",
var_type=Union[lhs._var_type, rhs._var_type], type_computer=nary_type_computer(
ReflexCallable[[List[Any], List[Any]], List[Any]],
ReflexCallable[[List[Any]], List[Any]],
computer=lambda args: unionize(args[0]._var_type, args[1]._var_type),
),
) )

View File

@ -963,11 +963,11 @@ def test_function_var():
def test_var_operation(): def test_var_operation():
@var_operation @var_operation
def add(a: Union[NumberVar, int], b: Union[NumberVar, int]): def add(a: Var[int], b: Var[int]):
return var_operation_return(js_expression=f"({a} + {b})", var_type=int) return var_operation_return(js_expression=f"({a} + {b})", var_type=int)
assert str(add(1, 2)) == "(1 + 2)" assert str(add(1, 2)) == "(1 + 2)"
assert str(add(a=4, b=-9)) == "(4 + -9)" assert str(add(4, -9)) == "(4 + -9)"
five = LiteralNumberVar.create(5) five = LiteralNumberVar.create(5)
seven = add(2, five) seven = add(2, five)