From c07a983f0520e953879f3397dae22a7bbe80c5dc Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Tue, 3 Sep 2024 11:39:05 -0700 Subject: [PATCH] add var_operation and move some operations to the new style (#3841) * add var_operations and move some operations to the new style * change bound style * can't assume int anymore * slice is not hashable (how did this work bef) * convert to int explicitly * move the rest of the operations to new style * fix bool guess type * forgot to precommit dangit * add type ignore to bool for now --- reflex/components/core/cond.py | 13 +- reflex/ivars/__init__.py | 1 - reflex/ivars/base.py | 415 +++++++----- reflex/ivars/number.py | 932 +++++++++++--------------- reflex/ivars/object.py | 318 +++------ reflex/ivars/sequence.py | 957 +++++++-------------------- tests/components/core/test_colors.py | 4 +- tests/components/core/test_cond.py | 7 +- tests/test_var.py | 13 +- 9 files changed, 960 insertions(+), 1700 deletions(-) diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 80dd35f0f..6e6272665 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -9,7 +9,7 @@ from reflex.components.component import BaseComponent, Component, MemoizationLea from reflex.components.tags import CondTag, Tag from reflex.constants import Dirs from reflex.ivars.base import ImmutableVar, LiteralVar -from reflex.ivars.number import TernaryOperator +from reflex.ivars.number import ternary_operation from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode from reflex.utils.imports import ImportDict, ImportVar from reflex.vars import Var, VarData @@ -163,11 +163,12 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | ImmutableVar: c2 = create_var(c2) # Create the conditional var. - return TernaryOperator.create( - condition=cond_var.to(bool), # type: ignore - if_true=c1, - if_false=c2, - _var_data=VarData(imports=_IS_TRUE_IMPORT), + return ternary_operation( + cond_var.bool()._replace( # type: ignore + merge_var_data=VarData(imports=_IS_TRUE_IMPORT), + ), # type: ignore + c1, + c2, ) diff --git a/reflex/ivars/__init__.py b/reflex/ivars/__init__.py index 8fa5196ff..2c1837510 100644 --- a/reflex/ivars/__init__.py +++ b/reflex/ivars/__init__.py @@ -12,7 +12,6 @@ from .number import LiteralNumberVar as LiteralNumberVar from .number import NumberVar as NumberVar from .object import LiteralObjectVar as LiteralObjectVar from .object import ObjectVar as ObjectVar -from .sequence import ArrayJoinOperation as ArrayJoinOperation from .sequence import ArrayVar as ArrayVar from .sequence import ConcatVarOperation as ConcatVarOperation from .sequence import LiteralArrayVar as LiteralArrayVar diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index f98d499a3..4cd3550dd 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -20,6 +20,7 @@ from typing import ( Generic, List, Literal, + NoReturn, Optional, Sequence, Set, @@ -384,10 +385,18 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): return self.to(BooleanVar, output) if issubclass(output, NumberVar): - if fixed_type is not None and not issubclass(fixed_type, (int, float)): - raise TypeError( - f"Unsupported type {var_type} for NumberVar. Must be int or float." - ) + if fixed_type is not None: + if fixed_type is Union: + inner_types = get_args(base_type) + if not all(issubclass(t, (int, float)) for t in inner_types): + raise TypeError( + f"Unsupported type {var_type} for NumberVar. Must be int or float." + ) + + elif not issubclass(fixed_type, (int, float)): + raise TypeError( + f"Unsupported type {var_type} for NumberVar. Must be int or float." + ) return ToNumberVarOperation.create(self, var_type or float) if issubclass(output, BooleanVar): @@ -440,7 +449,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Raises: TypeError: If the type is not supported for guessing. """ - from .number import NumberVar + from .number import BooleanVar, NumberVar from .object import ObjectVar from .sequence import ArrayVar, StringVar @@ -454,11 +463,16 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): fixed_type = get_origin(var_type) or var_type if fixed_type is Union: + inner_types = get_args(var_type) + if int in inner_types and float in inner_types: + return self.to(NumberVar, self._var_type) return self if not inspect.isclass(fixed_type): raise TypeError(f"Unsupported type {var_type} for guess_type.") + if issubclass(fixed_type, bool): + return self.to(BooleanVar, self._var_type) if issubclass(fixed_type, (int, float)): return self.to(NumberVar, self._var_type) if issubclass(fixed_type, dict): @@ -570,9 +584,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: BooleanVar: A BooleanVar object representing the result of the equality check. """ - from .number import EqualOperation + from .number import equal_operation - return EqualOperation.create(self, other) + return equal_operation(self, other) def __ne__(self, other: Var | Any) -> BooleanVar: """Check if the current object is not equal to the given object. @@ -583,9 +597,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: BooleanVar: A BooleanVar object representing the result of the comparison. """ - from .number import EqualOperation + from .number import equal_operation - return ~EqualOperation.create(self, other) + return ~equal_operation(self, other) def __gt__(self, other: Var | Any) -> BooleanVar: """Compare the current instance with another variable and return a BooleanVar representing the result of the greater than operation. @@ -596,9 +610,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: BooleanVar: A BooleanVar representing the result of the greater than operation. """ - from .number import GreaterThanOperation + from .number import greater_than_operation - return GreaterThanOperation.create(self, other) + return greater_than_operation(self, other) def __ge__(self, other: Var | Any) -> BooleanVar: """Check if the value of this variable is greater than or equal to the value of another variable or object. @@ -609,9 +623,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: BooleanVar: A BooleanVar object representing the result of the comparison. """ - from .number import GreaterThanOrEqualOperation + from .number import greater_than_or_equal_operation - return GreaterThanOrEqualOperation.create(self, other) + return greater_than_or_equal_operation(self, other) def __lt__(self, other: Var | Any) -> BooleanVar: """Compare the current instance with another variable using the less than (<) operator. @@ -622,9 +636,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the comparison. """ - from .number import LessThanOperation + from .number import less_than_operation - return LessThanOperation.create(self, other) + return less_than_operation(self, other) def __le__(self, other: Var | Any) -> BooleanVar: """Compare if the current instance is less than or equal to the given value. @@ -635,9 +649,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A BooleanVar object representing the result of the comparison. """ - from .number import LessThanOrEqualOperation + from .number import less_than_or_equal_operation - return LessThanOrEqualOperation.create(self, other) + return less_than_or_equal_operation(self, other) def bool(self) -> BooleanVar: """Convert the var to a boolean. @@ -645,9 +659,9 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: The boolean var. """ - from .number import ToBooleanVarOperation + from .number import boolify - return ToBooleanVarOperation.create(self) + return boolify(self) def __and__(self, other: Var | Any) -> ImmutableVar: """Perform a logical AND operation on the current instance and another variable. @@ -658,7 +672,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical AND operation. """ - return AndOperation.create(self, other) + return and_operation(self, other) def __rand__(self, other: Var | Any) -> ImmutableVar: """Perform a logical AND operation on the current instance and another variable. @@ -669,7 +683,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical AND operation. """ - return AndOperation.create(other, self) + return and_operation(other, self) def __or__(self, other: Var | Any) -> ImmutableVar: """Perform a logical OR operation on the current instance and another variable. @@ -680,7 +694,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical OR operation. """ - return OrOperation.create(self, other) + return or_operation(self, other) def __ror__(self, other: Var | Any) -> ImmutableVar: """Perform a logical OR operation on the current instance and another variable. @@ -691,7 +705,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical OR operation. """ - return OrOperation.create(other, self) + return or_operation(other, self) def __invert__(self) -> BooleanVar: """Perform a logical NOT operation on the current instance. @@ -699,9 +713,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): Returns: A `BooleanVar` object representing the result of the logical NOT operation. """ - from .number import BooleanNotOperation - - return BooleanNotOperation.create(self.bool()) + return ~self.bool() def to_string(self) -> ImmutableVar: """Convert the var to a string. @@ -926,52 +938,92 @@ class LiteralVar(ImmutableVar): P = ParamSpec("P") -T = TypeVar("T", bound=ImmutableVar) +T = TypeVar("T") -def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P, T]]: +# NoReturn is used to match CustomVarOperationReturn with no type hint. +@overload +def var_operation( + func: Callable[P, CustomVarOperationReturn[NoReturn]], +) -> Callable[P, ImmutableVar]: ... + + +@overload +def var_operation( + func: Callable[P, CustomVarOperationReturn[bool]], +) -> Callable[P, BooleanVar]: ... + + +NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float]) + + +@overload +def var_operation( + func: Callable[P, CustomVarOperationReturn[NUMBER_T]], +) -> Callable[P, NumberVar[NUMBER_T]]: ... + + +@overload +def var_operation( + func: Callable[P, CustomVarOperationReturn[str]], +) -> Callable[P, StringVar]: ... + + +LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set]) + + +@overload +def var_operation( + func: Callable[P, CustomVarOperationReturn[LIST_T]], +) -> Callable[P, ArrayVar[LIST_T]]: ... + + +OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict) + + +@overload +def var_operation( + func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]], +) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ... + + +def var_operation( + func: Callable[P, CustomVarOperationReturn[T]], +) -> Callable[P, ImmutableVar[T]]: """Decorator for creating a var operation. Example: ```python - @var_operation(output=NumberVar) + @var_operation def add(a: NumberVar, b: NumberVar): - return f"({a} + {b})" + return custom_var_operation(f"{a} + {b}") ``` Args: - output: The output type of the operation. + func: The function to decorate. Returns: - The decorator. + The decorated function. """ - def decorator(func: Callable[P, str], output=output): - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - args_vars = [ - LiteralVar.create(arg) if not isinstance(arg, Var) else arg - for arg in args - ] - kwargs_vars = { - key: LiteralVar.create(value) if not isinstance(value, Var) else value - for key, value in kwargs.items() - } - return output( - _var_name=func(*args_vars, **kwargs_vars), # type: ignore - _var_data=VarData.merge( - *[arg._get_all_var_data() for arg in args if isinstance(arg, Var)], - *[ - arg._get_all_var_data() - for arg in kwargs.values() - if isinstance(arg, Var) - ], - ), - ) + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> ImmutableVar[T]: + func_args = list(inspect.signature(func).parameters) + args_vars = { + func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg) + for i, arg in enumerate(args) + } + kwargs_vars = { + key: LiteralVar.create(value) if not isinstance(value, Var) else value + for key, value in kwargs.items() + } - return wrapper + return CustomVarOperation.create( + args=tuple(list(args_vars.items()) + list(kwargs_vars.items())), + return_var=func(*args_vars.values(), **kwargs_vars), # type: ignore + ).guess_type() - return decorator + return wrapper def unionize(*args: Type) -> Type: @@ -1100,114 +1152,64 @@ class CachedVarOperation: ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class AndOperation(CachedVarOperation, ImmutableVar): - """Class for the logical AND operation.""" +def and_operation(a: Var | Any, b: Var | Any) -> ImmutableVar: + """Perform a logical AND operation on two variables. - # The first var. - _var1: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + Args: + a: The first variable. + b: The second variable. - # The second var. - _var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """Get the cached var name. - - Returns: - The cached var name. - """ - return f"({str(self._var1)} && {str(self._var2)})" - - def __hash__(self) -> int: - """Calculates the hash value of the object. - - Returns: - int: The hash value of the object. - """ - return hash((self.__class__.__name__, self._var1, self._var2)) - - @classmethod - def create( - cls, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None - ) -> AndOperation: - """Create an AndOperation. - - Args: - var1: The first var. - var2: The second var. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The AndOperation. - """ - var1, var2 = map(LiteralVar.create, (var1, var2)) - return AndOperation( - _var_name="", - _var_type=unionize(var1._var_type, var2._var_type), - _var_data=ImmutableVarData.merge(_var_data), - _var1=var1, - _var2=var2, - ) + Returns: + The result of the logical AND operation. + """ + return _and_operation(a, b) # type: ignore -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class OrOperation(CachedVarOperation, ImmutableVar): - """Class for the logical OR operation.""" +@var_operation +def _and_operation(a: ImmutableVar, b: ImmutableVar): + """Perform a logical AND operation on two variables. - # The first var. - _var1: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + Args: + a: The first variable. + b: The second variable. - # The second var. - _var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) + Returns: + The result of the logical AND operation. + """ + return var_operation_return( + js_expression=f"({a} && {b})", + var_type=unionize(a._var_type, b._var_type), + ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """Get the cached var name. - Returns: - The cached var name. - """ - return f"({str(self._var1)} || {str(self._var2)})" +def or_operation(a: Var | Any, b: Var | Any) -> ImmutableVar: + """Perform a logical OR operation on two variables. - def __hash__(self) -> int: - """Calculates the hash value for the object. + Args: + a: The first variable. + b: The second variable. - Returns: - int: The hash value of the object. - """ - return hash((self.__class__.__name__, self._var1, self._var2)) + Returns: + The result of the logical OR operation. + """ + return _or_operation(a, b) # type: ignore - @classmethod - def create( - cls, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None - ) -> OrOperation: - """Create an OrOperation. - Args: - var1: The first var. - var2: The second var. - _var_data: Additional hooks and imports associated with the Var. +@var_operation +def _or_operation(a: ImmutableVar, b: ImmutableVar): + """Perform a logical OR operation on two variables. - Returns: - The OrOperation. - """ - var1, var2 = map(LiteralVar.create, (var1, var2)) - return OrOperation( - _var_name="", - _var_type=unionize(var1._var_type, var2._var_type), - _var_data=ImmutableVarData.merge(_var_data), - _var1=var1, - _var2=var2, - ) + Args: + a: The first variable. + b: The second variable. + + Returns: + The result of the logical OR operation. + """ + return var_operation_return( + js_expression=f"({a} || {b})", + var_type=unionize(a._var_type, b._var_type), + ) @dataclasses.dataclass( @@ -1797,3 +1799,114 @@ def immutable_computed_var( ) return wrapper + + +RETURN = TypeVar("RETURN") + + +class CustomVarOperationReturn(ImmutableVar[RETURN]): + """Base class for custom var operations.""" + + @classmethod + def create( + cls, + js_expression: str, + _var_type: Type[RETURN] | None = None, + _var_data: VarData | None = None, + ) -> CustomVarOperationReturn[RETURN]: + """Create a CustomVarOperation. + + Args: + js_expression: The JavaScript expression to evaluate. + _var_type: The type of the var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The CustomVarOperation. + """ + return CustomVarOperationReturn( + _var_name=js_expression, + _var_type=_var_type or Any, + _var_data=ImmutableVarData.merge(_var_data), + ) + + +def var_operation_return( + js_expression: str, + var_type: Type[RETURN] | None = None, +) -> CustomVarOperationReturn[RETURN]: + """Shortcut for creating a CustomVarOperationReturn. + + Args: + js_expression: The JavaScript expression to evaluate. + var_type: The type of the var. + + Returns: + The CustomVarOperationReturn. + """ + return CustomVarOperationReturn.create(js_expression, var_type) + + +@dataclasses.dataclass( + eq=False, + frozen=True, + **{"slots": True} if sys.version_info >= (3, 10) else {}, +) +class CustomVarOperation(CachedVarOperation, ImmutableVar[T]): + """Base class for custom var operations.""" + + _args: Tuple[Tuple[str, Var], ...] = dataclasses.field(default_factory=tuple) + + _return: CustomVarOperationReturn[T] = dataclasses.field( + default_factory=lambda: CustomVarOperationReturn.create("") + ) + + @cached_property_no_lock + def _cached_var_name(self) -> str: + """Get the cached var name. + + Returns: + The cached var name. + """ + return str(self._return) + + @cached_property_no_lock + def _cached_get_all_var_data(self) -> ImmutableVarData | None: + """Get the cached VarData. + + Returns: + The cached VarData. + """ + return ImmutableVarData.merge( + *map( + lambda arg: arg[1]._get_all_var_data(), + self._args, + ), + self._return._get_all_var_data(), + self._var_data, + ) + + @classmethod + def create( + cls, + args: Tuple[Tuple[str, Var], ...], + return_var: CustomVarOperationReturn[T], + _var_data: VarData | None = None, + ) -> CustomVarOperation[T]: + """Create a CustomVarOperation. + + Args: + args: The arguments to the operation. + return_var: The return var. + _var_data: Additional hooks and imports associated with the Var. + + Returns: + The CustomVarOperation. + """ + return CustomVarOperation( + _var_name="", + _var_type=return_var._var_type, + _var_data=ImmutableVarData.merge(_var_data), + _args=args, + _return=return_var, + ) diff --git a/reflex/ivars/number.py b/reflex/ivars/number.py index 241a1b595..68c856d18 100644 --- a/reflex/ivars/number.py +++ b/reflex/ivars/number.py @@ -5,23 +5,28 @@ from __future__ import annotations import dataclasses import json import sys -from typing import Any, Union +from typing import Any, Callable, TypeVar, Union from reflex.vars import ImmutableVarData, Var, VarData from .base import ( CachedVarOperation, + CustomVarOperationReturn, ImmutableVar, LiteralVar, cached_property_no_lock, unionize, + var_operation, + var_operation_return, ) +NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float]) -class NumberVar(ImmutableVar[Union[int, float]]): + +class NumberVar(ImmutableVar[NUMBER_T]): """Base class for immutable number vars.""" - def __add__(self, other: number_types | boolean_types) -> NumberAddOperation: + def __add__(self, other: number_types | boolean_types): """Add two numbers. Args: @@ -30,9 +35,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number addition operation. """ - return NumberAddOperation.create(self, +other) + return number_add_operation(self, +other) - def __radd__(self, other: number_types | boolean_types) -> NumberAddOperation: + def __radd__(self, other: number_types | boolean_types): """Add two numbers. Args: @@ -41,9 +46,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number addition operation. """ - return NumberAddOperation.create(+other, self) + return number_add_operation(+other, self) - def __sub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: + def __sub__(self, other: number_types | boolean_types): """Subtract two numbers. Args: @@ -52,9 +57,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number subtraction operation. """ - return NumberSubtractOperation.create(self, +other) + return number_subtract_operation(self, +other) - def __rsub__(self, other: number_types | boolean_types) -> NumberSubtractOperation: + def __rsub__(self, other: number_types | boolean_types): """Subtract two numbers. Args: @@ -63,17 +68,17 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number subtraction operation. """ - return NumberSubtractOperation.create(+other, self) + return number_subtract_operation(+other, self) - def __abs__(self) -> NumberAbsoluteOperation: + def __abs__(self): """Get the absolute value of the number. Returns: The number absolute operation. """ - return NumberAbsoluteOperation.create(self) + return number_abs_operation(self) - def __mul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: + def __mul__(self, other: number_types | boolean_types): """Multiply two numbers. Args: @@ -82,9 +87,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number multiplication operation. """ - return NumberMultiplyOperation.create(self, +other) + return number_multiply_operation(self, +other) - def __rmul__(self, other: number_types | boolean_types) -> NumberMultiplyOperation: + def __rmul__(self, other: number_types | boolean_types): """Multiply two numbers. Args: @@ -93,9 +98,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number multiplication operation. """ - return NumberMultiplyOperation.create(+other, self) + return number_multiply_operation(+other, self) - def __truediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: + def __truediv__(self, other: number_types | boolean_types): """Divide two numbers. Args: @@ -104,9 +109,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number true division operation. """ - return NumberTrueDivision.create(self, +other) + return number_true_division_operation(self, +other) - def __rtruediv__(self, other: number_types | boolean_types) -> NumberTrueDivision: + def __rtruediv__(self, other: number_types | boolean_types): """Divide two numbers. Args: @@ -115,9 +120,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number true division operation. """ - return NumberTrueDivision.create(+other, self) + return number_true_division_operation(+other, self) - def __floordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: + def __floordiv__(self, other: number_types | boolean_types): """Floor divide two numbers. Args: @@ -126,9 +131,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number floor division operation. """ - return NumberFloorDivision.create(self, +other) + return number_floor_division_operation(self, +other) - def __rfloordiv__(self, other: number_types | boolean_types) -> NumberFloorDivision: + def __rfloordiv__(self, other: number_types | boolean_types): """Floor divide two numbers. Args: @@ -137,9 +142,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number floor division operation. """ - return NumberFloorDivision.create(+other, self) + return number_floor_division_operation(+other, self) - def __mod__(self, other: number_types | boolean_types) -> NumberModuloOperation: + def __mod__(self, other: number_types | boolean_types): """Modulo two numbers. Args: @@ -148,9 +153,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number modulo operation. """ - return NumberModuloOperation.create(self, +other) + return number_modulo_operation(self, +other) - def __rmod__(self, other: number_types | boolean_types) -> NumberModuloOperation: + def __rmod__(self, other: number_types | boolean_types): """Modulo two numbers. Args: @@ -159,9 +164,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number modulo operation. """ - return NumberModuloOperation.create(+other, self) + return number_modulo_operation(+other, self) - def __pow__(self, other: number_types | boolean_types) -> NumberExponentOperation: + def __pow__(self, other: number_types | boolean_types): """Exponentiate two numbers. Args: @@ -170,9 +175,9 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number exponent operation. """ - return NumberExponentOperation.create(self, +other) + return number_exponent_operation(self, +other) - def __rpow__(self, other: number_types | boolean_types) -> NumberExponentOperation: + def __rpow__(self, other: number_types | boolean_types): """Exponentiate two numbers. Args: @@ -181,23 +186,23 @@ class NumberVar(ImmutableVar[Union[int, float]]): Returns: The number exponent operation. """ - return NumberExponentOperation.create(+other, self) + return number_exponent_operation(+other, self) - def __neg__(self) -> NumberNegateOperation: + def __neg__(self): """Negate the number. Returns: The number negation operation. """ - return NumberNegateOperation.create(self) + return number_negate_operation(self) - def __invert__(self) -> BooleanNotOperation: + def __invert__(self): """Boolean NOT the number. Returns: The boolean NOT operation. """ - return BooleanNotOperation.create(self.bool()) + return boolean_not_operation(self.bool()) def __pos__(self) -> NumberVar: """Positive the number. @@ -207,39 +212,39 @@ class NumberVar(ImmutableVar[Union[int, float]]): """ return self - def __round__(self) -> NumberRoundOperation: + def __round__(self): """Round the number. Returns: The number round operation. """ - return NumberRoundOperation.create(self) + return number_round_operation(self) - def __ceil__(self) -> NumberCeilOperation: + def __ceil__(self): """Ceil the number. Returns: The number ceil operation. """ - return NumberCeilOperation.create(self) + return number_ceil_operation(self) - def __floor__(self) -> NumberFloorOperation: + def __floor__(self): """Floor the number. Returns: The number floor operation. """ - return NumberFloorOperation.create(self) + return number_floor_operation(self) - def __trunc__(self) -> NumberTruncOperation: + def __trunc__(self): """Trunc the number. Returns: The number trunc operation. """ - return NumberTruncOperation.create(self) + return number_trunc_operation(self) - def __lt__(self, other: Any) -> LessThanOperation: + def __lt__(self, other: Any): """Less than comparison. Args: @@ -249,10 +254,10 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return LessThanOperation.create(self, +other) - return LessThanOperation.create(self, other) + return less_than_operation(self, +other) + return less_than_operation(self, other) - def __le__(self, other: Any) -> LessThanOrEqualOperation: + def __le__(self, other: Any): """Less than or equal comparison. Args: @@ -262,10 +267,10 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return LessThanOrEqualOperation.create(self, +other) - return LessThanOrEqualOperation.create(self, other) + return less_than_or_equal_operation(self, +other) + return less_than_or_equal_operation(self, other) - def __eq__(self, other: Any) -> EqualOperation: + def __eq__(self, other: Any): """Equal comparison. Args: @@ -275,10 +280,10 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return EqualOperation.create(self, +other) - return EqualOperation.create(self, other) + return equal_operation(self, +other) + return equal_operation(self, other) - def __ne__(self, other: Any) -> NotEqualOperation: + def __ne__(self, other: Any): """Not equal comparison. Args: @@ -288,10 +293,10 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return NotEqualOperation.create(self, +other) - return NotEqualOperation.create(self, other) + return not_equal_operation(self, +other) + return not_equal_operation(self, other) - def __gt__(self, other: Any) -> GreaterThanOperation: + def __gt__(self, other: Any): """Greater than comparison. Args: @@ -301,10 +306,10 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return GreaterThanOperation.create(self, +other) - return GreaterThanOperation.create(self, other) + return greater_than_operation(self, +other) + return greater_than_operation(self, other) - def __ge__(self, other: Any) -> GreaterThanOrEqualOperation: + def __ge__(self, other: Any): """Greater than or equal comparison. Args: @@ -314,308 +319,258 @@ class NumberVar(ImmutableVar[Union[int, float]]): The result of the comparison. """ if isinstance(other, (NumberVar, BooleanVar, int, float, bool)): - return GreaterThanOrEqualOperation.create(self, +other) - return GreaterThanOrEqualOperation.create(self, other) + return greater_than_or_equal_operation(self, +other) + return greater_than_or_equal_operation(self, other) - def bool(self) -> NotEqualOperation: + def bool(self): """Boolean conversion. Returns: The boolean value of the number. """ - return NotEqualOperation.create(self, 0) + return self != 0 -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class BinaryNumberOperation(CachedVarOperation, NumberVar): - """Base class for immutable number vars that are the result of a binary operation.""" +def binary_number_operation( + func: Callable[[NumberVar, NumberVar], str], +) -> Callable[[number_types, number_types], NumberVar]: + """Decorator to create a binary number operation. - _lhs: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) - ) - _rhs: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) - ) + Args: + func: The binary number operation function. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Returns: + The binary number operation. + """ - Raises: - NotImplementedError: Must be implemented by subclasses - """ - raise NotImplementedError( - "BinaryNumberOperation must implement _cached_var_name" + @var_operation + def operation(lhs: NumberVar, rhs: NumberVar): + return var_operation_return( + js_expression=func(lhs, rhs), + var_type=unionize(lhs._var_type, rhs._var_type), ) - @classmethod - def create( - cls, lhs: number_types, rhs: number_types, _var_data: VarData | None = None - ): - """Create the binary number operation var. + def wrapper(lhs: number_types, rhs: number_types) -> NumberVar: + """Create the binary number operation. Args: lhs: The first number. rhs: The second number. - _var_data: Additional hooks and imports associated with the Var. Returns: - The binary number operation var. + The binary number operation. """ - _lhs, _rhs = map( - lambda v: LiteralNumberVar.create(v) if not isinstance(v, NumberVar) else v, - (lhs, rhs), - ) - return cls( - _var_name="", - _var_type=unionize(_lhs._var_type, _rhs._var_type), - _var_data=ImmutableVarData.merge(_var_data), - _lhs=_lhs, - _rhs=_rhs, - ) + return operation(lhs, rhs) # type: ignore + + return wrapper -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class UnaryNumberOperation(CachedVarOperation, NumberVar): - """Base class for immutable number vars that are the result of a unary operation.""" +@binary_number_operation +def number_add_operation(lhs: NumberVar, rhs: NumberVar): + """Add two numbers. - _value: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) + Args: + lhs: The first number. + rhs: The second number. + + Returns: + The number addition operation. + """ + return f"({lhs} + {rhs})" + + +@binary_number_operation +def number_subtract_operation(lhs: NumberVar, rhs: NumberVar): + """Subtract two numbers. + + Args: + lhs: The first number. + rhs: The second number. + + Returns: + The number subtraction operation. + """ + return f"({lhs} - {rhs})" + + +@var_operation +def number_abs_operation(value: NumberVar): + """Get the absolute value of the number. + + Args: + value: The number. + + Returns: + The number absolute operation. + """ + return var_operation_return( + js_expression=f"Math.abs({value})", var_type=value._var_type ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError( - "UnaryNumberOperation must implement _cached_var_name" - ) +@binary_number_operation +def number_multiply_operation(lhs: NumberVar, rhs: NumberVar): + """Multiply two numbers. - @classmethod - def create(cls, value: NumberVar, _var_data: VarData | None = None): - """Create the unary number operation var. + Args: + lhs: The first number. + rhs: The second number. - Args: - value: The number. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The unary number operation var. - """ - return cls( - _var_name="", - _var_type=value._var_type, - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) + Returns: + The number multiplication operation. + """ + return f"({lhs} * {rhs})" -class NumberAddOperation(BinaryNumberOperation): - """Base class for immutable number vars that are the result of an addition operation.""" +@var_operation +def number_negate_operation( + value: NumberVar[NUMBER_T], +) -> CustomVarOperationReturn[NUMBER_T]: + """Negate the number. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + value: The number. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} + {str(self._rhs)})" + Returns: + The number negation operation. + """ + return var_operation_return(js_expression=f"-({value})", var_type=value._var_type) -class NumberSubtractOperation(BinaryNumberOperation): - """Base class for immutable number vars that are the result of a subtraction operation.""" +@binary_number_operation +def number_true_division_operation(lhs: NumberVar, rhs: NumberVar): + """Divide two numbers. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first number. + rhs: The second number. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} - {str(self._rhs)})" + Returns: + The number true division operation. + """ + return f"({lhs} / {rhs})" -class NumberAbsoluteOperation(UnaryNumberOperation): - """Base class for immutable number vars that are the result of an absolute operation.""" +@binary_number_operation +def number_floor_division_operation(lhs: NumberVar, rhs: NumberVar): + """Floor divide two numbers. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first number. + rhs: The second number. - Returns: - The name of the var. - """ - return f"Math.abs({str(self._value)})" + Returns: + The number floor division operation. + """ + return f"Math.floor({lhs} / {rhs})" -class NumberMultiplyOperation(BinaryNumberOperation): - """Base class for immutable number vars that are the result of a multiplication operation.""" +@binary_number_operation +def number_modulo_operation(lhs: NumberVar, rhs: NumberVar): + """Modulo two numbers. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first number. + rhs: The second number. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} * {str(self._rhs)})" + Returns: + The number modulo operation. + """ + return f"({lhs} % {rhs})" -class NumberNegateOperation(UnaryNumberOperation): - """Base class for immutable number vars that are the result of a negation operation.""" +@binary_number_operation +def number_exponent_operation(lhs: NumberVar, rhs: NumberVar): + """Exponentiate two numbers. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first number. + rhs: The second number. - Returns: - The name of the var. - """ - return f"-({str(self._value)})" + Returns: + The number exponent operation. + """ + return f"({lhs} ** {rhs})" -class NumberTrueDivision(BinaryNumberOperation): - """Base class for immutable number vars that are the result of a true division operation.""" +@var_operation +def number_round_operation(value: NumberVar): + """Round the number. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + value: The number. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} / {str(self._rhs)})" + Returns: + The number round operation. + """ + return var_operation_return(js_expression=f"Math.round({value})", var_type=int) -class NumberFloorDivision(BinaryNumberOperation): - """Base class for immutable number vars that are the result of a floor division operation.""" +@var_operation +def number_ceil_operation(value: NumberVar): + """Ceil the number. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + value: The number. - Returns: - The name of the var. - """ - return f"Math.floor({str(self._lhs)} / {str(self._rhs)})" + Returns: + The number ceil operation. + """ + return var_operation_return(js_expression=f"Math.ceil({value})", var_type=int) -class NumberModuloOperation(BinaryNumberOperation): - """Base class for immutable number vars that are the result of a modulo operation.""" +@var_operation +def number_floor_operation(value: NumberVar): + """Floor the number. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + value: The number. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} % {str(self._rhs)})" + Returns: + The number floor operation. + """ + return var_operation_return(js_expression=f"Math.floor({value})", var_type=int) -class NumberExponentOperation(BinaryNumberOperation): - """Base class for immutable number vars that are the result of an exponent operation.""" +@var_operation +def number_trunc_operation(value: NumberVar): + """Trunc the number. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + value: The number. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} ** {str(self._rhs)})" - - -class NumberRoundOperation(UnaryNumberOperation): - """Base class for immutable number vars that are the result of a round operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"Math.round({str(self._value)})" - - -class NumberCeilOperation(UnaryNumberOperation): - """Base class for immutable number vars that are the result of a ceil operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"Math.ceil({str(self._value)})" - - -class NumberFloorOperation(UnaryNumberOperation): - """Base class for immutable number vars that are the result of a floor operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"Math.floor({str(self._value)})" - - -class NumberTruncOperation(UnaryNumberOperation): - """Base class for immutable number vars that are the result of a trunc operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"Math.trunc({str(self._value)})" + Returns: + The number trunc operation. + """ + return var_operation_return(js_expression=f"Math.trunc({value})", var_type=int) class BooleanVar(ImmutableVar[bool]): """Base class for immutable boolean vars.""" - def __invert__(self) -> BooleanNotOperation: + def __invert__(self): """NOT the boolean. Returns: The boolean NOT operation. """ - return BooleanNotOperation.create(self) + return boolean_not_operation(self) - def __int__(self) -> BooleanToIntOperation: + def __int__(self): """Convert the boolean to an int. Returns: The boolean to int operation. """ - return BooleanToIntOperation.create(self) + return boolean_to_number_operation(self) - def __pos__(self) -> BooleanToIntOperation: + def __pos__(self): """Convert the boolean to an int. Returns: The boolean to int operation. """ - return BooleanToIntOperation.create(self) + return boolean_to_number_operation(self) def bool(self) -> BooleanVar: """Boolean conversion. @@ -625,7 +580,7 @@ class BooleanVar(ImmutableVar[bool]): """ return self - def __lt__(self, other: boolean_types | number_types) -> LessThanOperation: + def __lt__(self, other: boolean_types | number_types): """Less than comparison. Args: @@ -634,9 +589,9 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return LessThanOperation.create(+self, +other) + return less_than_operation(+self, +other) - def __le__(self, other: boolean_types | number_types) -> LessThanOrEqualOperation: + def __le__(self, other: boolean_types | number_types): """Less than or equal comparison. Args: @@ -645,9 +600,9 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return LessThanOrEqualOperation.create(+self, +other) + return less_than_or_equal_operation(+self, +other) - def __eq__(self, other: boolean_types | number_types) -> EqualOperation: + def __eq__(self, other: boolean_types | number_types): """Equal comparison. Args: @@ -656,9 +611,9 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return EqualOperation.create(+self, +other) + return equal_operation(+self, +other) - def __ne__(self, other: boolean_types | number_types) -> NotEqualOperation: + def __ne__(self, other: boolean_types | number_types): """Not equal comparison. Args: @@ -667,9 +622,9 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return NotEqualOperation.create(+self, +other) + return not_equal_operation(+self, +other) - def __gt__(self, other: boolean_types | number_types) -> GreaterThanOperation: + def __gt__(self, other: boolean_types | number_types): """Greater than comparison. Args: @@ -678,11 +633,9 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return GreaterThanOperation.create(+self, +other) + return greater_than_operation(+self, +other) - def __ge__( - self, other: boolean_types | number_types - ) -> GreaterThanOrEqualOperation: + def __ge__(self, other: boolean_types | number_types): """Greater than or equal comparison. Args: @@ -691,265 +644,151 @@ class BooleanVar(ImmutableVar[bool]): Returns: The result of the comparison. """ - return GreaterThanOrEqualOperation.create(+self, +other) + return greater_than_or_equal_operation(+self, +other) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class BooleanToIntOperation(CachedVarOperation, NumberVar): - """Base class for immutable number vars that are the result of a boolean to int operation.""" +@var_operation +def boolean_to_number_operation(value: BooleanVar): + """Convert the boolean to a number. - _value: BooleanVar = dataclasses.field( - default_factory=lambda: LiteralBooleanVar.create(False) - ) + Args: + value: The boolean. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Returns: + The boolean to number operation. + """ + return var_operation_return(js_expression=f"Number({value})", var_type=int) - Returns: - The name of the var. - """ - return f"({str(self._value)} ? 1 : 0)" - @classmethod - def create(cls, value: BooleanVar, _var_data: VarData | None = None): - """Create the boolean to int operation var. +def comparison_operator( + func: Callable[[Var, Var], str], +) -> Callable[[Var | Any, Var | Any], BooleanVar]: + """Decorator to create a comparison operation. - Args: - value: The boolean. - _var_data: Additional hooks and imports associated with the Var. + Args: + func: The comparison operation function. - Returns: - The boolean to int operation var. - """ - return cls( - _var_name="", - _var_type=int, - _var_data=ImmutableVarData.merge(_var_data), - _value=value, + Returns: + The comparison operation. + """ + + @var_operation + def operation(lhs: Var, rhs: Var): + return var_operation_return( + js_expression=func(lhs, rhs), + var_type=bool, ) - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ComparisonOperation(CachedVarOperation, BooleanVar): - """Base class for immutable boolean vars that are the result of a comparison operation.""" - - _lhs: Var = dataclasses.field( - default_factory=lambda: LiteralBooleanVar.create(False) - ) - _rhs: Var = dataclasses.field( - default_factory=lambda: LiteralBooleanVar.create(False) - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError("ComparisonOperation must implement _cached_var_name") - - @classmethod - def create(cls, lhs: Var | Any, rhs: Var | Any, _var_data: VarData | None = None): - """Create the comparison operation var. + def wrapper(lhs: Var | Any, rhs: Var | Any) -> BooleanVar: + """Create the comparison operation. Args: lhs: The first value. rhs: The second value. - _var_data: Additional hooks and imports associated with the Var. Returns: - The comparison operation var. + The comparison operation. """ - lhs, rhs = map(LiteralVar.create, (lhs, rhs)) - return cls( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - _lhs=lhs, - _rhs=rhs, - ) + return operation(lhs, rhs) + + return wrapper -class GreaterThanOperation(ComparisonOperation): - """Base class for immutable boolean vars that are the result of a greater than operation.""" +@comparison_operator +def greater_than_operation(lhs: Var, rhs: Var): + """Greater than comparison. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first value. + rhs: The second value. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} > {str(self._rhs)})" + Returns: + The result of the comparison. + """ + return f"({lhs} > {rhs})" -class GreaterThanOrEqualOperation(ComparisonOperation): - """Base class for immutable boolean vars that are the result of a greater than or equal operation.""" +@comparison_operator +def greater_than_or_equal_operation(lhs: Var, rhs: Var): + """Greater than or equal comparison. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first value. + rhs: The second value. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} >= {str(self._rhs)})" + Returns: + The result of the comparison. + """ + return f"({lhs} >= {rhs})" -class LessThanOperation(ComparisonOperation): - """Base class for immutable boolean vars that are the result of a less than operation.""" +@comparison_operator +def less_than_operation(lhs: Var, rhs: Var): + """Less than comparison. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first value. + rhs: The second value. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} < {str(self._rhs)})" + Returns: + The result of the comparison. + """ + return f"({lhs} < {rhs})" -class LessThanOrEqualOperation(ComparisonOperation): - """Base class for immutable boolean vars that are the result of a less than or equal operation.""" +@comparison_operator +def less_than_or_equal_operation(lhs: Var, rhs: Var): + """Less than or equal comparison. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first value. + rhs: The second value. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} <= {str(self._rhs)})" + Returns: + The result of the comparison. + """ + return f"({lhs} <= {rhs})" -class EqualOperation(ComparisonOperation): - """Base class for immutable boolean vars that are the result of an equal operation.""" +@comparison_operator +def equal_operation(lhs: Var, rhs: Var): + """Equal comparison. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first value. + rhs: The second value. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} === {str(self._rhs)})" + Returns: + The result of the comparison. + """ + return f"({lhs} === {rhs})" -class NotEqualOperation(ComparisonOperation): - """Base class for immutable boolean vars that are the result of a not equal operation.""" +@comparison_operator +def not_equal_operation(lhs: Var, rhs: Var): + """Not equal comparison. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + lhs: The first value. + rhs: The second value. - Returns: - The name of the var. - """ - return f"({str(self._lhs)} !== {str(self._rhs)})" + Returns: + The result of the comparison. + """ + return f"({lhs} !== {rhs})" -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class LogicalOperation(CachedVarOperation, BooleanVar): - """Base class for immutable boolean vars that are the result of a logical operation.""" +@var_operation +def boolean_not_operation(value: BooleanVar): + """Boolean NOT the boolean. - _lhs: BooleanVar = dataclasses.field( - default_factory=lambda: LiteralBooleanVar.create(False) - ) - _rhs: BooleanVar = dataclasses.field( - default_factory=lambda: LiteralBooleanVar.create(False) - ) + Args: + value: The boolean. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError("LogicalOperation must implement _cached_var_name") - - @classmethod - def create( - cls, lhs: boolean_types, rhs: boolean_types, _var_data: VarData | None = None - ): - """Create the logical operation var. - - Args: - lhs: The first boolean. - rhs: The second boolean. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The logical operation var. - """ - lhs, rhs = map( - lambda v: ( - LiteralBooleanVar.create(v) if not isinstance(v, BooleanVar) else v - ), - (lhs, rhs), - ) - return cls( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - _lhs=lhs, - _rhs=rhs, - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class BooleanNotOperation(CachedVarOperation, BooleanVar): - """Base class for immutable boolean vars that are the result of a logical NOT operation.""" - - _value: BooleanVar = dataclasses.field( - default_factory=lambda: LiteralBooleanVar.create(False) - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"!({str(self._value)})" - - @classmethod - def create(cls, value: boolean_types, _var_data: VarData | None = None): - """Create the logical NOT operation var. - - Args: - value: The value. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The logical NOT operation var. - """ - value = value if isinstance(value, Var) else LiteralBooleanVar.create(value) - return cls( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) + Returns: + The boolean NOT operation. + """ + return var_operation_return(js_expression=f"!({value})", var_type=bool) @dataclasses.dataclass( @@ -1111,7 +950,7 @@ class ToBooleanVarOperation(CachedVarOperation, BooleanVar): Returns: The name of the var. """ - return f"Boolean({str(self._original_value)})" + return str(self._original_value) @classmethod def create( @@ -1136,68 +975,35 @@ class ToBooleanVarOperation(CachedVarOperation, BooleanVar): ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class TernaryOperator(CachedVarOperation, ImmutableVar): - """Base class for immutable vars that are the result of a ternary operation.""" +@var_operation +def boolify(value: Var): + """Convert the value to a boolean. - _condition: BooleanVar = dataclasses.field( - default_factory=lambda: LiteralBooleanVar.create(False) - ) - _if_true: Var = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) - ) - _if_false: Var = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) + Args: + value: The value. + + Returns: + The boolean value. + """ + return var_operation_return( + js_expression=f"Boolean({value})", + var_type=bool, ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Returns: - The name of the var. - """ - return ( - f"({str(self._condition)} ? {str(self._if_true)} : {str(self._if_false)})" - ) +@var_operation +def ternary_operation(condition: BooleanVar, if_true: Var, if_false: Var): + """Create a ternary operation. - @classmethod - def create( - cls, - condition: boolean_types, - if_true: Var | Any, - if_false: Var | Any, - _var_data: VarData | None = None, - ): - """Create the ternary operation var. + Args: + condition: The condition. + if_true: The value if the condition is true. + if_false: The value if the condition is false. - Args: - condition: The condition. - if_true: The value if the condition is true. - if_false: The value if the condition is false. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The ternary operation var. - """ - condition = ( - condition - if isinstance(condition, Var) - else LiteralBooleanVar.create(condition) - ) - _if_true, _if_false = map( - lambda v: (LiteralVar.create(v) if not isinstance(v, Var) else v), - (if_true, if_false), - ) - return TernaryOperator( - _var_name="", - _var_type=unionize(_if_true._var_type, _if_false._var_type), - _var_data=ImmutableVarData.merge(_var_data), - _condition=condition, - _if_true=_if_true, - _if_false=_if_false, - ) + Returns: + The ternary operation. + """ + return var_operation_return( + js_expression=f"({condition} ? {if_true} : {if_false})", + var_type=unionize(if_true._var_type, if_false._var_type), + ) diff --git a/reflex/ivars/object.py b/reflex/ivars/object.py index 5401b678a..49a21226f 100644 --- a/reflex/ivars/object.py +++ b/reflex/ivars/object.py @@ -30,6 +30,8 @@ from .base import ( LiteralVar, cached_property_no_lock, figure_out_type, + var_operation, + var_operation_return, ) from .number import BooleanVar, NumberVar from .sequence import ArrayVar, StringVar @@ -56,7 +58,9 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): return str @overload - def _value_type(self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]]) -> VALUE_TYPE: ... + def _value_type( + self: ObjectVar[Dict[KEY_TYPE, VALUE_TYPE]], + ) -> Type[VALUE_TYPE]: ... @overload def _value_type(self) -> Type: ... @@ -79,7 +83,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The keys of the object. """ - return ObjectKeysOperation.create(self) + return object_keys_operation(self) @overload def values( @@ -95,7 +99,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The values of the object. """ - return ObjectValuesOperation.create(self) + return object_values_operation(self) @overload def entries( @@ -111,9 +115,9 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The entries of the object. """ - return ObjectEntriesOperation.create(self) + return object_entries_operation(self) - def merge(self, other: ObjectVar) -> ObjectMergeOperation: + def merge(self, other: ObjectVar): """Merge two objects. Args: @@ -122,7 +126,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The merged object. """ - return ObjectMergeOperation.create(self, other) + return object_merge_operation(self, other) # NoReturn is used here to catch when key value is Any @overload @@ -270,7 +274,7 @@ class ObjectVar(ImmutableVar[OBJECT_TYPE]): Returns: The result of the check. """ - return ObjectHasOwnProperty.create(self, key) + return object_has_own_property_operation(self, key) @dataclasses.dataclass( @@ -387,207 +391,72 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ObjectToArrayOperation(CachedVarOperation, ArrayVar): - """Base class for object to array operations.""" +@var_operation +def object_keys_operation(value: ObjectVar): + """Get the keys of an object. - _value: ObjectVar = dataclasses.field( - default_factory=lambda: LiteralObjectVar.create({}) + Args: + value: The object to get the keys from. + + Returns: + The keys of the object. + """ + return var_operation_return( + js_expression=f"Object.keys({value})", + var_type=List[str], ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the operation. - Raises: - NotImplementedError: Must implement _cached_var_name. - """ - raise NotImplementedError( - "ObjectToArrayOperation must implement _cached_var_name" - ) +@var_operation +def object_values_operation(value: ObjectVar): + """Get the values of an object. - @classmethod - def create( - cls, - value: ObjectVar, - _var_type: GenericType | None = None, - _var_data: VarData | None = None, - ) -> ObjectToArrayOperation: - """Create the object to array operation. + Args: + value: The object to get the values from. - Args: - value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - - Returns: - The object to array operation. - """ - return cls( - _var_name="", - _var_type=list if _var_type is None else _var_type, - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) - - -class ObjectKeysOperation(ObjectToArrayOperation): - """Operation to get the keys of an object.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the operation. - - Returns: - The name of the operation. - """ - return f"Object.keys({str(self._value)})" - - @classmethod - def create( - cls, - value: ObjectVar, - _var_data: VarData | None = None, - ) -> ObjectKeysOperation: - """Create the object keys operation. - - Args: - value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - - Returns: - The object keys operation. - """ - return cls( - _var_name="", - _var_type=List[str], - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) - - -class ObjectValuesOperation(ObjectToArrayOperation): - """Operation to get the values of an object.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the operation. - - Returns: - The name of the operation. - """ - return f"Object.values({str(self._value)})" - - @classmethod - def create( - cls, - value: ObjectVar, - _var_data: VarData | None = None, - ) -> ObjectValuesOperation: - """Create the object values operation. - - Args: - value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - - Returns: - The object values operation. - """ - return cls( - _var_name="", - _var_type=List[value._value_type()], - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) - - -class ObjectEntriesOperation(ObjectToArrayOperation): - """Operation to get the entries of an object.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the operation. - - Returns: - The name of the operation. - """ - return f"Object.entries({str(self._value)})" - - @classmethod - def create( - cls, - value: ObjectVar, - _var_data: VarData | None = None, - ) -> ObjectEntriesOperation: - """Create the object entries operation. - - Args: - value: The value of the operation. - _var_data: Additional hooks and imports associated with the operation. - - Returns: - The object entries operation. - """ - return cls( - _var_name="", - _var_type=List[Tuple[str, value._value_type()]], - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ObjectMergeOperation(CachedVarOperation, ObjectVar): - """Operation to merge two objects.""" - - _lhs: ObjectVar = dataclasses.field( - default_factory=lambda: LiteralObjectVar.create({}) - ) - _rhs: ObjectVar = dataclasses.field( - default_factory=lambda: LiteralObjectVar.create({}) + Returns: + The values of the object. + """ + return var_operation_return( + js_expression=f"Object.values({value})", + var_type=List[value._value_type()], ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the operation. - Returns: - The name of the operation. - """ - return f"({{...{str(self._lhs)}, ...{str(self._rhs)}}})" +@var_operation +def object_entries_operation(value: ObjectVar): + """Get the entries of an object. - @classmethod - def create( - cls, - lhs: ObjectVar, - rhs: ObjectVar, - _var_data: VarData | None = None, - ) -> ObjectMergeOperation: - """Create the object merge operation. + Args: + value: The object to get the entries from. - Args: - lhs: The left object to merge. - rhs: The right object to merge. - _var_data: Additional hooks and imports associated with the operation. + Returns: + The entries of the object. + """ + return var_operation_return( + js_expression=f"Object.entries({value})", + var_type=List[Tuple[str, value._value_type()]], + ) - Returns: - The object merge operation. - """ - # TODO: Figure out how to merge the types - return cls( - _var_name="", - _var_type=lhs._var_type, - _var_data=ImmutableVarData.merge(_var_data), - _lhs=lhs, - _rhs=rhs, - ) + +@var_operation +def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): + """Merge two objects. + + Args: + lhs: The first object to merge. + rhs: The second object to merge. + + Returns: + The merged object. + """ + return var_operation_return( + js_expression=f"({{...{lhs}, ...{rhs}}})", + var_type=Dict[ + Union[lhs._key_type(), rhs._key_type()], + Union[lhs._value_type(), rhs._value_type()], + ], + ) @dataclasses.dataclass( @@ -688,49 +557,18 @@ class ToObjectOperation(CachedVarOperation, ObjectVar): ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ObjectHasOwnProperty(CachedVarOperation, BooleanVar): - """Operation to check if an object has a property.""" +@var_operation +def object_has_own_property_operation(object: ObjectVar, key: Var): + """Check if an object has a key. - _object: ObjectVar = dataclasses.field( - default_factory=lambda: LiteralObjectVar.create({}) + Args: + object: The object to check. + key: The key to check. + + Returns: + The result of the check. + """ + return var_operation_return( + js_expression=f"{object}.hasOwnProperty({key})", + var_type=bool, ) - _key: Var | Any = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the operation. - - Returns: - The name of the operation. - """ - return f"{str(self._object)}.hasOwnProperty({str(self._key)})" - - @classmethod - def create( - cls, - object: ObjectVar, - key: Var | Any, - _var_data: VarData | None = None, - ) -> ObjectHasOwnProperty: - """Create the object has own property operation. - - Args: - object: The object to check. - key: The key to check. - _var_data: Additional hooks and imports associated with the operation. - - Returns: - The object has own property operation. - """ - return cls( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - _object=object, - _key=key if isinstance(key, Var) else LiteralVar.create(key), - ) diff --git a/reflex/ivars/sequence.py b/reflex/ivars/sequence.py index 47f5c6f01..25586e4df 100644 --- a/reflex/ivars/sequence.py +++ b/reflex/ivars/sequence.py @@ -34,16 +34,18 @@ from reflex.vars import ( from .base import ( CachedVarOperation, + CustomVarOperationReturn, ImmutableVar, LiteralVar, cached_property_no_lock, figure_out_type, unionize, + var_operation, + var_operation_return, ) from .number import ( BooleanVar, LiteralNumberVar, - NotEqualOperation, NumberVar, ) @@ -98,15 +100,7 @@ class StringVar(ImmutableVar[str]): """ return (self.split() * other).join() - @overload - def __getitem__(self, i: slice) -> ArrayJoinOperation: ... - - @overload - def __getitem__(self, i: int | NumberVar) -> StringItemOperation: ... - - def __getitem__( - self, i: slice | int | NumberVar - ) -> ArrayJoinOperation | StringItemOperation: + def __getitem__(self, i: slice | int | NumberVar) -> StringVar: """Get a slice of the string. Args: @@ -117,7 +111,7 @@ class StringVar(ImmutableVar[str]): """ if isinstance(i, slice): return self.split()[i].join() - return StringItemOperation.create(self, i) + return string_item_operation(self, i) def length(self) -> NumberVar: """Get the length of the string. @@ -133,7 +127,7 @@ class StringVar(ImmutableVar[str]): Returns: The string lower operation. """ - return StringLowerOperation.create(self) + return string_lower_operation(self) def upper(self) -> StringVar: """Convert the string to uppercase. @@ -141,7 +135,7 @@ class StringVar(ImmutableVar[str]): Returns: The string upper operation. """ - return StringUpperOperation.create(self) + return string_upper_operation(self) def strip(self) -> StringVar: """Strip the string. @@ -149,17 +143,17 @@ class StringVar(ImmutableVar[str]): Returns: The string strip operation. """ - return StringStripOperation.create(self) + return string_strip_operation(self) - def bool(self) -> NotEqualOperation: + def bool(self): """Boolean conversion. Returns: The boolean value of the string. """ - return NotEqualOperation.create(self.length(), 0) + return self.length() != 0 - def reversed(self) -> ArrayJoinOperation: + def reversed(self) -> StringVar: """Reverse the string. Returns: @@ -167,7 +161,7 @@ class StringVar(ImmutableVar[str]): """ return self.split().reverse().join() - def contains(self, other: StringVar | str) -> StringContainsOperation: + def contains(self, other: StringVar | str) -> BooleanVar: """Check if the string contains another string. Args: @@ -176,9 +170,9 @@ class StringVar(ImmutableVar[str]): Returns: The string contains operation. """ - return StringContainsOperation.create(self, other) + return string_contains_operation(self, other) - def split(self, separator: StringVar | str = "") -> StringSplitOperation: + def split(self, separator: StringVar | str = "") -> ArrayVar[List[str]]: """Split the string. Args: @@ -187,9 +181,9 @@ class StringVar(ImmutableVar[str]): Returns: The string split operation. """ - return StringSplitOperation.create(self, separator) + return string_split_operation(self, separator) - def startswith(self, prefix: StringVar | str) -> StringStartsWithOperation: + def startswith(self, prefix: StringVar | str) -> BooleanVar: """Check if the string starts with a prefix. Args: @@ -198,308 +192,106 @@ class StringVar(ImmutableVar[str]): Returns: The string starts with operation. """ - return StringStartsWithOperation.create(self, prefix) + return string_starts_with_operation(self, prefix) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringToStringOperation(CachedVarOperation, StringVar): - """Base class for immutable string vars that are the result of a string to string operation.""" +@var_operation +def string_lower_operation(string: StringVar): + """Convert a string to lowercase. - _value: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") + Args: + string: The string to convert. + + Returns: + The lowercase string. + """ + return var_operation_return(js_expression=f"{string}.toLowerCase()", var_type=str) + + +@var_operation +def string_upper_operation(string: StringVar): + """Convert a string to uppercase. + + Args: + string: The string to convert. + + Returns: + The uppercase string. + """ + return var_operation_return(js_expression=f"{string}.toUpperCase()", var_type=str) + + +@var_operation +def string_strip_operation(string: StringVar): + """Strip a string. + + Args: + string: The string to strip. + + Returns: + The stripped string. + """ + return var_operation_return(js_expression=f"{string}.trim()", var_type=str) + + +@var_operation +def string_contains_operation(haystack: StringVar, needle: StringVar | str): + """Check if a string contains another string. + + Args: + haystack: The haystack. + needle: The needle. + + Returns: + The string contains operation. + """ + return var_operation_return( + js_expression=f"{haystack}.includes({needle})", var_type=bool ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError( - "StringToStringOperation must implement _cached_var_name" - ) +@var_operation +def string_starts_with_operation(full_string: StringVar, prefix: StringVar | str): + """Check if a string starts with a prefix. - @classmethod - def create( - cls, - value: StringVar, - _var_data: VarData | None = None, - ) -> StringVar: - """Create a var from a string value. + Args: + full_string: The full string. + prefix: The prefix. - Args: - value: The value to create the var from. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) - - -class StringLowerOperation(StringToStringOperation): - """Base class for immutable string vars that are the result of a string lower operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._value)}.toLowerCase()" - - -class StringUpperOperation(StringToStringOperation): - """Base class for immutable string vars that are the result of a string upper operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._value)}.toUpperCase()" - - -class StringStripOperation(StringToStringOperation): - """Base class for immutable string vars that are the result of a string strip operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._value)}.trim()" - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringContainsOperation(CachedVarOperation, BooleanVar): - """Base class for immutable boolean vars that are the result of a string contains operation.""" - - _haystack: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - _needle: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") + Returns: + Whether the string starts with the prefix. + """ + return var_operation_return( + js_expression=f"{full_string}.startsWith({prefix})", var_type=bool ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Returns: - The name of the var. - """ - return f"{str(self._haystack)}.includes({str(self._needle)})" +@var_operation +def string_item_operation(string: StringVar, index: NumberVar | int): + """Get an item from a string. - @classmethod - def create( - cls, - haystack: StringVar | str, - needle: StringVar | str, - _var_data: VarData | None = None, - ) -> StringContainsOperation: - """Create a var from a string value. + Args: + string: The string. + index: The index of the item. - Args: - haystack: The haystack. - needle: The needle. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - _haystack=( - haystack - if isinstance(haystack, Var) - else LiteralStringVar.create(haystack) - ), - _needle=( - needle if isinstance(needle, Var) else LiteralStringVar.create(needle) - ), - ) + Returns: + The item from the string. + """ + return var_operation_return(js_expression=f"{string}.at({index})", var_type=str) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringStartsWithOperation(CachedVarOperation, BooleanVar): - """Base class for immutable boolean vars that are the result of a string starts with operation.""" +@var_operation +def array_join_operation(array: ArrayVar, sep: StringVar | str = ""): + """Join the elements of an array. - _full_string: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - _prefix: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) + Args: + array: The array. + sep: The separator. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._full_string)}.startsWith({str(self._prefix)})" - - @classmethod - def create( - cls, - full_string: StringVar | str, - prefix: StringVar | str, - _var_data: VarData | None = None, - ) -> StringStartsWithOperation: - """Create a var from a string value. - - Args: - full_string: The full string. - prefix: The prefix. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - _full_string=( - full_string - if isinstance(full_string, Var) - else LiteralStringVar.create(full_string) - ), - _prefix=( - prefix if isinstance(prefix, Var) else LiteralStringVar.create(prefix) - ), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringItemOperation(CachedVarOperation, StringVar): - """Base class for immutable string vars that are the result of a string item operation.""" - - _string: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - _index: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._string)}.at({str(self._index)})" - - @classmethod - def create( - cls, - string: StringVar | str, - index: NumberVar | int, - _var_data: VarData | None = None, - ) -> StringItemOperation: - """Create a var from a string value. - - Args: - string: The string. - index: The index. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - _string=( - string if isinstance(string, Var) else LiteralStringVar.create(string) - ), - _index=( - index if isinstance(index, Var) else LiteralNumberVar.create(index) - ), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArrayJoinOperation(CachedVarOperation, StringVar): - """Base class for immutable string vars that are the result of an array join operation.""" - - _array: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - _sep: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._array)}.join({str(self._sep)})" - - @classmethod - def create( - cls, - array: ArrayVar, - sep: StringVar | str = "", - _var_data: VarData | None = None, - ) -> ArrayJoinOperation: - """Create a var from a string value. - - Args: - array: The array. - sep: The separator. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=str, - _var_data=ImmutableVarData.merge(_var_data), - _array=array, - _sep=sep if isinstance(sep, Var) else LiteralStringVar.create(sep), - ) + Returns: + The joined elements. + """ + return var_operation_return(js_expression=f"{array}.join({sep})", var_type=str) # Compile regex for finding reflex var tags. @@ -721,7 +513,7 @@ VALUE_TYPE = TypeVar("VALUE_TYPE") class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): """Base class for immutable array vars.""" - def join(self, sep: StringVar | str = "") -> ArrayJoinOperation: + def join(self, sep: StringVar | str = "") -> StringVar: """Join the elements of the array. Args: @@ -730,7 +522,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The joined elements. """ - return ArrayJoinOperation.create(self, sep) + return array_join_operation(self, sep) def reverse(self) -> ArrayVar[ARRAY_VAR_TYPE]: """Reverse the array. @@ -738,9 +530,9 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The reversed array. """ - return ArrayReverseOperation.create(self) + return array_reverse_operation(self) - def __add__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> ArrayConcatOperation: + def __add__(self, other: ArrayVar[ARRAY_VAR_TYPE]) -> ArrayVar[ARRAY_VAR_TYPE]: """Concatenate two arrays. Parameters: @@ -749,7 +541,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: ArrayConcatOperation: The concatenation of the two arrays. """ - return ArrayConcatOperation.create(self, other) + return array_concat_operation(self, other) @overload def __getitem__(self, i: slice) -> ArrayVar[ARRAY_VAR_TYPE]: ... @@ -854,7 +646,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): """ if isinstance(i, slice): return ArraySliceOperation.create(self, i) - return ArrayItemOperation.create(self, i).guess_type() + return array_item_operation(self, i) def length(self) -> NumberVar: """Get the length of the array. @@ -862,7 +654,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The length of the array. """ - return ArrayLengthOperation.create(self) + return array_length_operation(self) @overload @classmethod @@ -902,7 +694,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): start = first_endpoint end = second_endpoint - return RangeOperation.create(start, end, step or 1) + return array_range_operation(start, end, step or 1) def contains(self, other: Any) -> BooleanVar: """Check if the array contains an element. @@ -913,7 +705,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: The array contains operation. """ - return ArrayContainsOperation.create(self, other) + return array_contains_operation(self, other) def __mul__(self, other: NumberVar | int) -> ArrayVar[ARRAY_VAR_TYPE]: """Multiply the sequence by a number or integer. @@ -924,7 +716,7 @@ class ArrayVar(ImmutableVar[ARRAY_VAR_TYPE]): Returns: ArrayVar[ARRAY_VAR_TYPE]: The result of multiplying the sequence by the given number or integer. """ - return ArrayRepeatOperation.create(self, other) + return repeat_array_operation(self, other) __rmul__ = __mul__ # type: ignore @@ -1026,102 +818,20 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class StringSplitOperation(CachedVarOperation, ArrayVar): - """Base class for immutable array vars that are the result of a string split operation.""" +@var_operation +def string_split_operation(string: StringVar, sep: StringVar | str = ""): + """Split a string. - _string: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") + Args: + string: The string to split. + sep: The separator. + + Returns: + The split string. + """ + return var_operation_return( + js_expression=f"{string}.split({sep})", var_type=List[str] ) - _sep: StringVar = dataclasses.field( - default_factory=lambda: LiteralStringVar.create("") - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._string)}.split({str(self._sep)})" - - @classmethod - def create( - cls, - string: StringVar | str, - sep: StringVar | str, - _var_data: VarData | None = None, - ) -> StringSplitOperation: - """Create a var from a string value. - - Args: - string: The string. - sep: The separator. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=List[str], - _var_data=ImmutableVarData.merge(_var_data), - _string=( - string if isinstance(string, Var) else LiteralStringVar.create(string) - ), - _sep=(sep if isinstance(sep, Var) else LiteralStringVar.create(sep)), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArrayToArrayOperation(CachedVarOperation, ArrayVar): - """Base class for immutable array vars that are the result of an array to array operation.""" - - _value: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError( - "ArrayToArrayOperation must implement _cached_var_name" - ) - - @classmethod - def create( - cls, - value: ArrayVar, - _var_data: VarData | None = None, - ) -> ArrayToArrayOperation: - """Create a var from a string value. - - Args: - value: The value to create the var from. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=value._var_type, - _var_data=ImmutableVarData.merge(_var_data), - _value=value, - ) @dataclasses.dataclass( @@ -1135,7 +845,9 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar): _array: ArrayVar = dataclasses.field( default_factory=lambda: LiteralArrayVar.create([]) ) - _slice: slice = dataclasses.field(default_factory=lambda: slice(None, None, None)) + _start: NumberVar | int = dataclasses.field(default_factory=lambda: 0) + _stop: NumberVar | int = dataclasses.field(default_factory=lambda: 0) + _step: NumberVar | int = dataclasses.field(default_factory=lambda: 1) @cached_property_no_lock def _cached_var_name(self) -> str: @@ -1147,7 +859,7 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar): Raises: ValueError: If the slice step is zero. """ - start, end, step = self._slice.start, self._slice.stop, self._slice.step + start, end, step = self._start, self._stop, self._step normalized_start = ( LiteralVar.create(start) @@ -1165,16 +877,7 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar): if step < 0: actual_start = end + 1 if end is not None else 0 actual_end = start + 1 if start is not None else self._array.length() - return str( - ArraySliceOperation.create( - ArrayReverseOperation.create( - ArraySliceOperation.create( - self._array, slice(actual_start, actual_end) - ) - ), - slice(None, None, -step), - ) - ) + return str(self._array[actual_start:actual_end].reverse()[::-step]) if step == 0: raise ValueError("slice step cannot be zero") return f"{str(self._array)}.slice({str(normalized_start)}, {str(normalized_end)}).filter((_, i) => i % {str(step)} === 0)" @@ -1206,80 +909,44 @@ class ArraySliceOperation(CachedVarOperation, ArrayVar): _var_type=array._var_type, _var_data=ImmutableVarData.merge(_var_data), _array=array, - _slice=slice, + _start=slice.start, + _stop=slice.stop, + _step=slice.step, ) -class ArrayReverseOperation(ArrayToArrayOperation): - """Base class for immutable string vars that are the result of a string reverse operation.""" +@var_operation +def array_reverse_operation( + array: ArrayVar[ARRAY_VAR_TYPE], +) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: + """Reverse an array. - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. + Args: + array: The array to reverse. - Returns: - The name of the var. - """ - return f"{str(self._value)}.slice().reverse()" - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArrayToNumberOperation(CachedVarOperation, NumberVar): - """Base class for immutable number vars that are the result of an array to number operation.""" - - _array: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]), + Returns: + The reversed array. + """ + return var_operation_return( + js_expression=f"{array}.slice().reverse()", + var_type=array._var_type, ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError( - "StringToNumberOperation must implement _cached_var_name" - ) +@var_operation +def array_length_operation(array: ArrayVar): + """Get the length of an array. - @classmethod - def create( - cls, - array: ArrayVar, - _var_data: VarData | None = None, - ) -> ArrayToNumberOperation: - """Create a var from a string value. + Args: + array: The array. - Args: - array: The array. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=int, - _var_data=ImmutableVarData.merge(_var_data), - _array=array, - ) - - -class ArrayLengthOperation(ArrayToNumberOperation): - """Base class for immutable number vars that are the result of an array length operation.""" - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._array)}.length" + Returns: + The length of the array. + """ + return var_operation_return( + js_expression=f"{array}.length", + var_type=int, + ) def is_tuple_type(t: GenericType) -> bool: @@ -1296,166 +963,65 @@ def is_tuple_type(t: GenericType) -> bool: return get_origin(t) is tuple -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArrayItemOperation(CachedVarOperation, ImmutableVar): - """Base class for immutable array vars that are the result of an array item operation.""" +@var_operation +def array_item_operation(array: ArrayVar, index: NumberVar | int): + """Get an item from an array. - _array: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - _index: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) + Args: + array: The array. + index: The index of the item. + + Returns: + The item from the array. + """ + args = typing.get_args(array._var_type) + if args and isinstance(index, LiteralNumberVar) and is_tuple_type(array._var_type): + index_value = int(index._var_value) + element_type = args[index_value % len(args)] + else: + element_type = unionize(*args) + + return var_operation_return( + js_expression=f"{str(array)}.at({str(index)})", + var_type=element_type, ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Returns: - The name of the var. - """ - return f"{str(self._array)}.at({str(self._index)})" +@var_operation +def array_range_operation( + start: NumberVar | int, stop: NumberVar | int, step: NumberVar | int +): + """Create a range of numbers. - @classmethod - def create( - cls, - array: ArrayVar, - index: NumberVar | int, - _var_type: GenericType | None = None, - _var_data: VarData | None = None, - ) -> ArrayItemOperation: - """Create a var from a string value. + Args: + start: The start of the range. + stop: The end of the range. + step: The step of the range. - Args: - array: The array. - index: The index. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - args = typing.get_args(array._var_type) - if args and isinstance(index, int) and is_tuple_type(array._var_type): - element_type = args[index % len(args)] - else: - element_type = unionize(*args) - - return cls( - _var_name="", - _var_type=element_type if _var_type is None else _var_type, - _var_data=ImmutableVarData.merge(_var_data), - _array=array, - _index=index if isinstance(index, Var) else LiteralNumberVar.create(index), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class RangeOperation(CachedVarOperation, ArrayVar): - """Base class for immutable array vars that are the result of a range operation.""" - - _start: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) - ) - _stop: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) - ) - _step: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(1) + Returns: + The range of numbers. + """ + return var_operation_return( + js_expression=f"Array.from({{ length: ({str(stop)} - {str(start)}) / {str(step)} }}, (_, i) => {str(start)} + i * {str(step)})", + var_type=List[int], ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Returns: - The name of the var. - """ - start, end, step = self._start, self._stop, self._step - return f"Array.from({{ length: ({str(end)} - {str(start)}) / {str(step)} }}, (_, i) => {str(start)} + i * {str(step)})" +@var_operation +def array_contains_operation(haystack: ArrayVar, needle: Any | Var): + """Check if an array contains an element. - @classmethod - def create( - cls, - start: NumberVar | int, - stop: NumberVar | int, - step: NumberVar | int, - _var_data: VarData | None = None, - ) -> RangeOperation: - """Create a var from a string value. + Args: + haystack: The array to check. + needle: The element to check for. - Args: - start: The start of the range. - stop: The end of the range. - step: The step of the range. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=List[int], - _var_data=ImmutableVarData.merge(_var_data), - _start=start if isinstance(start, Var) else LiteralNumberVar.create(start), - _stop=stop if isinstance(stop, Var) else LiteralNumberVar.create(stop), - _step=step if isinstance(step, Var) else LiteralNumberVar.create(step), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArrayContainsOperation(CachedVarOperation, BooleanVar): - """Base class for immutable boolean vars that are the result of an array contains operation.""" - - _haystack: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) + Returns: + The array contains operation. + """ + return var_operation_return( + js_expression=f"{haystack}.includes({needle})", + var_type=bool, ) - _needle: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None)) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"{str(self._haystack)}.includes({str(self._needle)})" - - @classmethod - def create( - cls, - haystack: ArrayVar, - needle: Any | Var, - _var_data: VarData | None = None, - ) -> ArrayContainsOperation: - """Create a var from a string value. - - Args: - haystack: The array. - needle: The element to check for. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=bool, - _var_data=ImmutableVarData.merge(_var_data), - _haystack=haystack, - _needle=needle if isinstance(needle, Var) else LiteralVar.create(needle), - ) @dataclasses.dataclass( @@ -1547,102 +1113,39 @@ class ToArrayOperation(CachedVarOperation, ArrayVar): ) -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArrayRepeatOperation(CachedVarOperation, ArrayVar): - """Base class for immutable array vars that are the result of an array repeat operation.""" +@var_operation +def repeat_array_operation( + array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int +) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: + """Repeat an array a number of times. - _array: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - _count: NumberVar = dataclasses.field( - default_factory=lambda: LiteralNumberVar.create(0) + Args: + array: The array to repeat. + count: The number of times to repeat the array. + + Returns: + The repeated array. + """ + return var_operation_return( + js_expression=f"Array.from({{ length: {count} }}).flatMap(() => {array})", + var_type=array._var_type, ) - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - Returns: - The name of the var. - """ - return f"Array.from({{ length: {str(self._count)} }}).flatMap(() => {str(self._array)})" +@var_operation +def array_concat_operation( + lhs: ArrayVar[ARRAY_VAR_TYPE], rhs: ArrayVar[ARRAY_VAR_TYPE] +) -> CustomVarOperationReturn[ARRAY_VAR_TYPE]: + """Concatenate two arrays. - @classmethod - def create( - cls, - array: ArrayVar, - count: NumberVar | int, - _var_data: VarData | None = None, - ) -> ArrayRepeatOperation: - """Create a var from a string value. + Args: + lhs: The left-hand side array. + rhs: The right-hand side array. - Args: - array: The array. - count: The number of times to repeat the array. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - return cls( - _var_name="", - _var_type=array._var_type, - _var_data=ImmutableVarData.merge(_var_data), - _array=array, - _count=count if isinstance(count, Var) else LiteralNumberVar.create(count), - ) - - -@dataclasses.dataclass( - eq=False, - frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, -) -class ArrayConcatOperation(CachedVarOperation, ArrayVar): - """Base class for immutable array vars that are the result of an array concat operation.""" - - _lhs: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) + Returns: + The concatenated array. + """ + return var_operation_return( + js_expression=f"[...{lhs}, ...{rhs}]", + var_type=Union[lhs._var_type, rhs._var_type], ) - _rhs: ArrayVar = dataclasses.field( - default_factory=lambda: LiteralArrayVar.create([]) - ) - - @cached_property_no_lock - def _cached_var_name(self) -> str: - """The name of the var. - - Returns: - The name of the var. - """ - return f"[...{str(self._lhs)}, ...{str(self._rhs)}]" - - @classmethod - def create( - cls, - lhs: ArrayVar, - rhs: ArrayVar, - _var_data: VarData | None = None, - ) -> ArrayConcatOperation: - """Create a var from a string value. - - Args: - lhs: The left-hand side array. - rhs: The right-hand side array. - _var_data: Additional hooks and imports associated with the Var. - - Returns: - The var. - """ - # TODO: Figure out how to merge the types of a and b - return cls( - _var_name="", - _var_type=Union[lhs._var_type, rhs._var_type], - _var_data=ImmutableVarData.merge(_var_data), - _lhs=lhs, - _rhs=rhs, - ) diff --git a/tests/components/core/test_colors.py b/tests/components/core/test_colors.py index d5d5dc995..ace7adad6 100644 --- a/tests/components/core/test_colors.py +++ b/tests/components/core/test_colors.py @@ -67,11 +67,11 @@ def test_color(color, expected): [ ( rx.cond(True, rx.color("mint"), rx.color("tomato", 5)), - '(Boolean(true) ? "var(--mint-7)" : "var(--tomato-5)")', + '(true ? "var(--mint-7)" : "var(--tomato-5)")', ), ( rx.cond(True, rx.color(ColorState.color), rx.color(ColorState.color, 5)), # type: ignore - f'(Boolean(true) ? ("var(--"+{str(color_state_name)}.color+"-7)") : ("var(--"+{str(color_state_name)}.color+"-5)"))', + f'(true ? ("var(--"+{str(color_state_name)}.color+"-7)") : ("var(--"+{str(color_state_name)}.color+"-5)"))', ), ( rx.match( diff --git a/tests/components/core/test_cond.py b/tests/components/core/test_cond.py index e3fc40ae3..53f8e0640 100644 --- a/tests/components/core/test_cond.py +++ b/tests/components/core/test_cond.py @@ -23,7 +23,7 @@ def cond_state(request): def test_f_string_cond_interpolation(): # make sure backticks inside interpolation don't get escaped var = LiteralVar.create(f"x {cond(True, 'a', 'b')}") - assert str(var) == '("x "+(Boolean(true) ? "a" : "b"))' + assert str(var) == '("x "+(true ? "a" : "b"))' @pytest.mark.parametrize( @@ -97,7 +97,7 @@ def test_prop_cond(c1: Any, c2: Any): c1 = json.dumps(c1) if not isinstance(c2, Var): c2 = json.dumps(c2) - assert str(prop_cond) == f"(Boolean(true) ? {c1} : {c2})" + assert str(prop_cond) == f"(true ? {c1} : {c2})" def test_cond_no_mix(): @@ -141,8 +141,7 @@ def test_cond_computed_var(): state_name = format_state_name(CondStateComputed.get_full_name()) assert ( - str(comp) - == f"(Boolean(true) ? {state_name}.computed_int : {state_name}.computed_str)" + str(comp) == f"(true ? {state_name}.computed_int : {state_name}.computed_str)" ) assert comp._var_type == Union[int, str] diff --git a/tests/test_var.py b/tests/test_var.py index cca6cbec4..ba9cf24c8 100644 --- a/tests/test_var.py +++ b/tests/test_var.py @@ -12,6 +12,7 @@ from reflex.ivars.base import ( ImmutableVar, LiteralVar, var_operation, + var_operation_return, ) from reflex.ivars.function import ArgsFunctionOperation, FunctionStringVar from reflex.ivars.number import ( @@ -925,9 +926,9 @@ def test_function_var(): def test_var_operation(): - @var_operation(output=NumberVar) - def add(a: Union[NumberVar, int], b: Union[NumberVar, int]) -> str: - return f"({a} + {b})" + @var_operation + def add(a: Union[NumberVar, int], b: Union[NumberVar, int]): + return var_operation_return(js_expression=f"({a} + {b})", var_type=int) assert str(add(1, 2)) == "(1 + 2)" assert str(add(a=4, b=-9)) == "(4 + -9)" @@ -967,14 +968,14 @@ def test_all_number_operations(): assert ( str(even_more_complicated_number) - == "!(Boolean((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))))))" + == "!(((Math.abs(Math.floor(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2))) || (2 && Math.round(((Math.floor(((-((-5.4 + 1)) * 2) / 3) / 2) % 3) ** 2)))) !== 0))" ) assert str(LiteralNumberVar.create(5) > False) == "(5 > 0)" - assert str(LiteralBooleanVar.create(False) < 5) == "((false ? 1 : 0) < 5)" + assert str(LiteralBooleanVar.create(False) < 5) == "(Number(false) < 5)" assert ( str(LiteralBooleanVar.create(False) < LiteralBooleanVar.create(True)) - == "((false ? 1 : 0) < (true ? 1 : 0))" + == "(Number(false) < Number(true))" )