diff --git a/reflex/vars/__init__.py b/reflex/vars/__init__.py index 1a4cebe19..cb02319bc 100644 --- a/reflex/vars/__init__.py +++ b/reflex/vars/__init__.py @@ -9,6 +9,7 @@ from .base import get_unique_variable_name as get_unique_variable_name from .base import get_uuid_string_var as get_uuid_string_var from .base import var_operation as var_operation from .base import var_operation_return as var_operation_return +from .datetime import DateTimeVar as DateTimeVar from .function import FunctionStringVar as FunctionStringVar from .function import FunctionVar as FunctionVar from .function import VarOperationCall as VarOperationCall diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py new file mode 100644 index 000000000..f7691b0bb --- /dev/null +++ b/reflex/vars/datetime.py @@ -0,0 +1,182 @@ +"""Immutable datetime and date vars.""" + +from datetime import date, datetime +from typing import Any, NoReturn, TypeVar, Union, overload + +from reflex.utils.exceptions import VarTypeError +from reflex.vars.number import BooleanVar + +from .base import CustomVarOperationReturn, Var, var_operation, var_operation_return + +DATETIME_T = TypeVar("DATETIME_T", datetime, date) + +datetime_types = Union[datetime, date] + + +def raise_var_type_error(): + """Raise a VarTypeError. + + Raises: + VarTypeError: Cannot compare a datetime object with a non-datetime object. + """ + raise VarTypeError("Cannot compare a datetime object with a non-datetime object.") + + +class DateTimeVar(Var[DATETIME_T], python_types=(datetime, date)): + """A variable that holds a datetime or date object.""" + + @overload + def __lt__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __lt__(self, other: NoReturn) -> NoReturn: ... + + def __lt__(self, other: Any): + """Less than comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_lt_operation(self, other) + + @overload + def __le__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __le__(self, other: NoReturn) -> NoReturn: ... + + def __le__(self, other: Any): + """Less than or equal comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_le_operation(self, other) + + @overload + def __gt__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __gt__(self, other: NoReturn) -> NoReturn: ... + + def __gt__(self, other: Any): + """Greater than comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_gt_operation(self, other) + + @overload + def __ge__(self, other: datetime_types) -> BooleanVar: ... + + @overload + def __ge__(self, other: NoReturn) -> NoReturn: ... + + def __ge__(self, other: Any): + """Greater than or equal comparison. + + Args: + other: The other datetime to compare. + + Returns: + The result of the comparison. + """ + if not isinstance(other, DATETIME_TYPES): + raise_var_type_error() + return date_ge_operation(self, other) + + +@var_operation +def date_gt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Greater than comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(rhs, lhs, strict=True) + + +@var_operation +def date_lt_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Less than comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(lhs, rhs, strict=True) + + +@var_operation +def date_le_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Less than or equal comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(lhs, rhs) + + +@var_operation +def date_ge_operation(lhs: Var | Any, rhs: Var | Any) -> CustomVarOperationReturn: + """Greater than or equal comparison. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + + Returns: + The result of the operation. + """ + return date_compare_operation(rhs, lhs) + + +def date_compare_operation( + lhs: DateTimeVar[DATETIME_T] | Any, + rhs: DateTimeVar[DATETIME_T] | Any, + strict: bool = False, +) -> CustomVarOperationReturn: + """Check if the value is less than the other value. + + Args: + lhs: The left-hand side of the operation. + rhs: The right-hand side of the operation. + strict: Whether to use strict comparison. + + Returns: + The result of the operation. + """ + return var_operation_return( + f"isTrue({lhs} { '<' if strict else '<='} {rhs})", + bool, + ) + + +DATETIME_TYPES = (datetime, date, DateTimeVar)