reflex/reflex/ivars/base.py
2024-08-05 15:30:19 -07:00

1015 lines
30 KiB
Python

"""Collection of base classes."""
from __future__ import annotations
import dataclasses
import functools
import inspect
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
get_args,
overload,
)
from typing_extensions import ParamSpec, get_origin
from reflex import constants
from reflex.base import Base
from reflex.utils import console, imports, serializers, types
from reflex.utils.exceptions import VarTypeError
from reflex.vars import (
ImmutableVarData,
Var,
VarData,
_decode_var_immutable,
_extract_var_data,
_global_vars,
)
if TYPE_CHECKING:
from reflex.state import BaseState
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
# Type variable for the Python type an ImmutableVar represents.
VAR_TYPE = TypeVar("VAR_TYPE")


@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ImmutableVar(Var, Generic[VAR_TYPE]):
    """Base class for immutable vars."""

    # The name of the var.
    _var_name: str = dataclasses.field()

    # The type of the var.
    _var_type: types.GenericType = dataclasses.field(default=Any)

    # Extra metadata associated with the Var
    _var_data: Optional[ImmutableVarData] = dataclasses.field(default=None)

    def __str__(self) -> str:
        """String representation of the var. Guaranteed to be a valid Javascript expression.

        Returns:
            The name of the var.
        """
        return self._var_name

    @property
    def _var_is_local(self) -> bool:
        """Whether this is a local javascript variable.

        Returns:
            False
        """
        return False

    @property
    def _var_is_string(self) -> bool:
        """Whether the var is a string literal.

        Returns:
            False
        """
        return False

    @property
    def _var_full_name_needs_state_prefix(self) -> bool:
        """Whether the full name of the var needs a _var_state prefix.

        Returns:
            False
        """
        return False

    def __post_init__(self):
        """Post-initialize the var.

        Decodes any inline Var markup found in ``_var_name`` (presumably the
        tags emitted by ``__format__``) and, if something was decoded,
        re-initializes the instance with the cleaned name and merged var data.
        """
        # Decode any inline Var markup and apply it to the instance
        _var_data, _var_name = _decode_var_immutable(self._var_name)
        if _var_data or _var_name != self._var_name:
            # The dataclass is frozen, so re-run the generated __init__ rather
            # than assigning fields directly.
            self.__init__(
                _var_name=_var_name,
                _var_type=self._var_type,
                _var_data=ImmutableVarData.merge(self._var_data, _var_data),
            )

    def __hash__(self) -> int:
        """Define a hash function for the var.

        Returns:
            The hash of the var's name, type and var data.
        """
        return hash((self._var_name, self._var_type, self._var_data))

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._var_data

    def _replace(self, merge_var_data=None, **kwargs: Any):
        """Make a copy of this Var with updated fields.

        Args:
            merge_var_data: VarData to merge into the existing VarData.
            **kwargs: Var fields to update. Only ``_var_name``, ``_var_type`` and
                ``_var_data`` are used; other keys are ignored.

        Returns:
            A new ImmutableVar with the updated fields overwriting the corresponding fields in this Var.

        Raises:
            TypeError: If _var_is_local, _var_is_string, or
                _var_full_name_needs_state_prefix is passed with any value other
                than False.
        """
        # These legacy Var flags are fixed to False for immutable vars.
        for unsupported in (
            "_var_is_local",
            "_var_is_string",
            "_var_full_name_needs_state_prefix",
        ):
            if kwargs.get(unsupported, False) is not False:
                raise TypeError(
                    f"The {unsupported} argument is not supported for ImmutableVar."
                )

        field_values = dict(
            _var_name=kwargs.pop("_var_name", self._var_name),
            _var_type=kwargs.pop("_var_type", self._var_type),
            _var_data=ImmutableVarData.merge(
                kwargs.get("_var_data", self._var_data), merge_var_data
            ),
        )
        return type(self)(**field_values)

    @classmethod
    def create(
        cls,
        value: Any,
        _var_is_local: bool | None = None,
        _var_is_string: bool | None = None,
        _var_data: VarData | None = None,
    ) -> ImmutableVar | Var | None:
        """Create a var from a value.

        Args:
            value: The value to create the var from.
            _var_is_local: Whether the var is local. Deprecated; must be None.
            _var_is_string: Whether the var is a string literal. Deprecated; must be None.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            None if value is None, value itself if it is already a Var,
            otherwise a new var named after the value (JSON types) or its
            serialized form.

        Raises:
            VarTypeError: If the value is JSON-unserializable.
            TypeError: If _var_is_local or _var_is_string is not None.
        """
        if _var_is_local is not None:
            raise TypeError(
                "The _var_is_local argument is not supported for ImmutableVar."
            )
        if _var_is_string is not None:
            raise TypeError(
                "The _var_is_string argument is not supported for ImmutableVar."
            )

        from reflex.utils import format

        # Check for none values.
        if value is None:
            return None

        # If the value is already a var, do nothing.
        if isinstance(value, Var):
            return value

        # Try to pull the imports and hooks from contained values.
        if not isinstance(value, str):
            _var_data = VarData.merge(*_extract_var_data(value), _var_data)

        # Try to serialize the value.
        type_ = type(value)
        if type_ in types.JSONType:
            name = value
        else:
            name, _ = serializers.serialize(value, get_type=True)
        if name is None:
            raise VarTypeError(
                f"No JSON serializer found for var {value} of type {type_}."
            )
        name = name if isinstance(name, str) else format.json_dumps(name)

        return cls(
            _var_name=name,
            _var_type=type_,
            _var_data=(
                ImmutableVarData(
                    state=_var_data.state,
                    imports=_var_data.imports,
                    hooks=_var_data.hooks,
                )
                if _var_data
                else None
            ),
        )

    @classmethod
    def create_safe(
        cls,
        value: Any,
        _var_is_local: bool | None = None,
        _var_is_string: bool | None = None,
        _var_data: VarData | None = None,
    ) -> Var | ImmutableVar:
        """Create a var from a value, asserting that it is not None.

        Args:
            value: The value to create the var from.
            _var_is_local: Whether the var is local. Deprecated.
            _var_is_string: Whether the var is a string literal. Deprecated.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The var.
        """
        var = cls.create(
            value,
            _var_is_local=_var_is_local,
            _var_is_string=_var_is_string,
            _var_data=_var_data,
        )
        assert var is not None
        return var

    def __format__(self, format_spec: str) -> str:
        """Format the var into a Javascript equivalent to an f-string.

        The var is registered in ``_global_vars`` under its hash and the hash is
        embedded between the Reflex var tags so the var can be recovered later.

        Args:
            format_spec: The format specifier (Ignored for now).

        Returns:
            The formatted var.
        """
        hashed_var = hash(self)

        _global_vars[hashed_var] = self

        # Encode the _var_data into the formatted output for tracking purposes.
        return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"

    @overload
    def to(
        self, output: Type[NumberVar], var_type: type[int] | type[float] = float
    ) -> ToNumberVarOperation: ...

    @overload
    def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ...

    @overload
    def to(
        self,
        output: Type[ArrayVar],
        var_type: type[list] | type[tuple] | type[set] = list,
    ) -> ToArrayOperation: ...

    @overload
    def to(self, output: Type[StringVar]) -> ToStringOperation: ...

    @overload
    def to(
        self, output: Type[ObjectVar], var_type: types.GenericType = dict
    ) -> ToObjectOperation: ...

    @overload
    def to(
        self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
    ) -> ToFunctionOperation: ...

    @overload
    def to(
        self, output: Type[OUTPUT], var_type: types.GenericType | None = None
    ) -> OUTPUT: ...

    def to(
        self, output: Type[OUTPUT], var_type: types.GenericType | None = None
    ) -> Var:
        """Convert the var to a different type.

        Args:
            output: The output type.
            var_type: The type of the var.

        Raises:
            TypeError: If the var_type is not a supported type for the output.

        Returns:
            The converted var.
        """
        from .number import (
            BooleanVar,
            NumberVar,
            ToBooleanVarOperation,
            ToNumberVarOperation,
        )

        # Reduce generic aliases (e.g. List[int]) to their origin (list) so
        # they can be validated with issubclass.
        fixed_type = (
            var_type
            if var_type is None or inspect.isclass(var_type)
            else get_origin(var_type)
        )

        if issubclass(output, NumberVar):
            if fixed_type is not None and not issubclass(fixed_type, (int, float)):
                raise TypeError(
                    f"Unsupported type {var_type} for NumberVar. Must be int or float."
                )
            return ToNumberVarOperation(self, var_type or float)

        if issubclass(output, BooleanVar):
            return ToBooleanVarOperation(self)

        from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation

        if issubclass(output, ArrayVar):
            if fixed_type is not None and not issubclass(
                fixed_type, (list, tuple, set)
            ):
                raise TypeError(
                    f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set."
                )
            return ToArrayOperation(self, var_type or list)

        if issubclass(output, StringVar):
            return ToStringOperation(self)

        from .object import ObjectVar, ToObjectOperation

        if issubclass(output, ObjectVar):
            return ToObjectOperation(self, var_type or dict)

        from .function import FunctionVar, ToFunctionOperation

        if issubclass(output, FunctionVar):
            # NOTE: var_type is not validated against Callable here.
            return ToFunctionOperation(self, var_type or Callable)

        # Any other ImmutableVar subclass: rebuild it from this var's fields.
        return output(
            _var_name=self._var_name,
            _var_type=self._var_type if var_type is None else var_type,
            _var_data=self._var_data,
        )

    def guess_type(self) -> ImmutableVar:
        """Guess the type of the var.

        Returns:
            The var converted to the specialized var class matching its
            ``_var_type``, or the var itself when no single class applies
            (``Any``, unions, literals, type variables, unknown classes).
        """
        from .number import BooleanVar, NumberVar
        from .object import ObjectVar
        from .sequence import ArrayVar, StringVar

        if self._var_type is Any:
            return self

        var_type = self._var_type

        fixed_type = var_type if inspect.isclass(var_type) else get_origin(var_type)

        # Unions, Literals, TypeVars, forward refs etc. have no class origin that
        # issubclass can handle; leave such vars untouched instead of crashing.
        if fixed_type is Union or not inspect.isclass(fixed_type):
            return self

        # bool must be checked before int/float because bool subclasses int.
        if issubclass(fixed_type, bool):
            return self.to(BooleanVar)
        if issubclass(fixed_type, (int, float)):
            return self.to(NumberVar, var_type)
        if issubclass(fixed_type, dict):
            return self.to(ObjectVar, var_type)
        if issubclass(fixed_type, (list, tuple, set)):
            return self.to(ArrayVar, var_type)
        if issubclass(fixed_type, str):
            return self.to(StringVar)
        if issubclass(fixed_type, Base):
            return self.to(ObjectVar, var_type)
        return self

    def get_default_value(self) -> Any:
        """Get the default value of the var.

        Returns:
            The default value of the var; None for optional types and for types
            that are not classes (e.g. a non-optional Union).

        Raises:
            ImportError: If the var is a dataframe and pandas is not installed.
        """
        if types.is_optional(self._var_type):
            return None

        type_ = (
            get_origin(self._var_type)
            if types.is_generic_alias(self._var_type)
            else self._var_type
        )
        if type_ is Literal:
            args = get_args(self._var_type)
            return args[0] if args else None
        # issubclass below would raise TypeError for non-class types.
        if not inspect.isclass(type_):
            return None
        if issubclass(type_, str):
            return ""
        # bool must be checked before int/float: bool subclasses int, so
        # checking numbers first would return 0 instead of False.
        if issubclass(type_, bool):
            return False
        if issubclass(type_, types.get_args(Union[int, float])):
            return 0
        if issubclass(type_, list):
            return []
        if issubclass(type_, dict):
            return {}
        if issubclass(type_, tuple):
            return ()
        if types.is_dataframe(type_):
            try:
                import pandas as pd

                return pd.DataFrame()
            except ImportError as e:
                raise ImportError(
                    "Please install pandas to use dataframes in your app."
                ) from e
        return set() if issubclass(type_, set) else None

    def get_setter_name(self, include_state: bool = True) -> str:
        """Get the name of the var's generated setter function.

        Args:
            include_state: Whether to include the state name in the setter name.

        Returns:
            The name of the setter function, prefixed with ``<state>.`` when
            include_state is set and the var data names a state.
        """
        setter = constants.SETTER_PREFIX + self._var_name
        if self._var_data is None:
            return setter
        if not include_state or self._var_data.state == "":
            return setter
        return ".".join((self._var_data.state, setter))

    def get_setter(self) -> Callable[[BaseState, Any], None]:
        """Get the var's setter function.

        Returns:
            A function that that creates a setter for the var.
        """

        def setter(state: BaseState, value: Any):
            """Get the setter for the var.

            int/float vars convert the incoming value first; if the conversion
            raises ValueError the value is not set and a debug message is logged.

            Args:
                state: The state within which we add the setter function.
                value: The value to set.
            """
            if self._var_type in [int, float]:
                # Only the conversion is guarded: errors raised by setattr
                # itself must not be reported as a failed conversion.
                try:
                    value = self._var_type(value)
                except ValueError:
                    console.debug(
                        f"{type(state).__name__}.{self._var_name}: Failed conversion of {value} to '{self._var_type.__name__}'. Value not set.",
                    )
                    return
            setattr(state, self._var_name, value)

        setter.__qualname__ = self.get_setter_name()

        return setter

    def __eq__(self, other: Var | Any) -> BooleanVar:
        """Check if the current variable is equal to the given variable.

        Args:
            other (Var | Any): The variable to compare with.

        Returns:
            BooleanVar: A BooleanVar object representing the result of the equality check.
        """
        from .number import EqualOperation

        return EqualOperation(self, other)

    def __ne__(self, other: Var | Any) -> BooleanVar:
        """Check if the current object is not equal to the given object.

        Args:
            other (Var | Any): The object to compare with.

        Returns:
            BooleanVar: A BooleanVar object representing the result of the comparison.
        """
        from .number import EqualOperation

        return ~EqualOperation(self, other)

    def __gt__(self, other: Var | Any) -> BooleanVar:
        """Compare the current instance with another variable and return a BooleanVar representing the result of the greater than operation.

        Args:
            other (Var | Any): The variable to compare with.

        Returns:
            BooleanVar: A BooleanVar representing the result of the greater than operation.
        """
        from .number import GreaterThanOperation

        return GreaterThanOperation(self, other)

    def __ge__(self, other: Var | Any) -> BooleanVar:
        """Check if the value of this variable is greater than or equal to the value of another variable or object.

        Args:
            other (Var | Any): The variable or object to compare with.

        Returns:
            BooleanVar: A BooleanVar object representing the result of the comparison.
        """
        from .number import GreaterThanOrEqualOperation

        return GreaterThanOrEqualOperation(self, other)

    def __lt__(self, other: Var | Any) -> BooleanVar:
        """Compare the current instance with another variable using the less than (<) operator.

        Args:
            other: The variable to compare with.

        Returns:
            A `BooleanVar` object representing the result of the comparison.
        """
        from .number import LessThanOperation

        return LessThanOperation(self, other)

    def __le__(self, other: Var | Any) -> BooleanVar:
        """Compare if the current instance is less than or equal to the given value.

        Args:
            other: The value to compare with.

        Returns:
            A BooleanVar object representing the result of the comparison.
        """
        from .number import LessThanOrEqualOperation

        return LessThanOrEqualOperation(self, other)

    def bool(self) -> BooleanVar:
        """Convert the var to a boolean.

        Returns:
            The boolean var.
        """
        from .number import ToBooleanVarOperation

        return ToBooleanVarOperation(self)

    def __and__(self, other: Var | Any) -> ImmutableVar:
        """Perform a logical AND operation on the current instance and another variable.

        Args:
            other: The variable to perform the logical AND operation with.

        Returns:
            An AndOperation var representing ``(self && other)``.
        """
        return AndOperation(self, other)

    def __rand__(self, other: Var | Any) -> ImmutableVar:
        """Perform a logical AND operation with this var as the right operand.

        Args:
            other: The variable to perform the logical AND operation with.

        Returns:
            An AndOperation var representing ``(other && self)``.
        """
        return AndOperation(other, self)

    def __or__(self, other: Var | Any) -> ImmutableVar:
        """Perform a logical OR operation on the current instance and another variable.

        Args:
            other: The variable to perform the logical OR operation with.

        Returns:
            An OrOperation var representing ``(self || other)``.
        """
        return OrOperation(self, other)

    def __ror__(self, other: Var | Any) -> ImmutableVar:
        """Perform a logical OR operation with this var as the right operand.

        Args:
            other: The variable to perform the logical OR operation with.

        Returns:
            An OrOperation var representing ``(other || self)``.
        """
        return OrOperation(other, self)

    def __invert__(self) -> BooleanVar:
        """Perform a logical NOT operation on the current instance.

        Returns:
            A `BooleanVar` object representing the result of the logical NOT operation.
        """
        from .number import BooleanNotOperation

        return BooleanNotOperation(self.bool())

    def to_string(self) -> ImmutableVar:
        """Convert the var to a string.

        Returns:
            A var calling the ``JSON_STRINGIFY`` function var on this var.
        """
        from .function import JSON_STRINGIFY

        return JSON_STRINGIFY.call(self)

    def as_ref(self) -> ImmutableVar:
        """Get a reference to the var.

        Returns:
            This var used as a key into the ``refs`` object imported from
            ``constants.Dirs.STATE_PATH``.
        """
        from .object import ObjectVar

        refs = ImmutableVar(
            _var_name="refs",
            _var_data=ImmutableVarData(
                imports={
                    f"/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
                }
            ),
        ).to(ObjectVar)
        return refs[self]


# Type variable for the var class requested from ``ImmutableVar.to``.
OUTPUT = TypeVar("OUTPUT", bound=ImmutableVar)
class LiteralVar(ImmutableVar):
    """Base class for immutable vars that wrap a Python literal value."""

    @classmethod
    def create(
        cls,
        value: Any,
        _var_data: VarData | None = None,
    ) -> Var:
        """Build the literal var matching the Python type of ``value``.

        Args:
            value: The value to create the var from. Existing vars pass through.
            _var_data: Additional hooks and imports associated with the Var.

        Returns:
            The var.

        Raises:
            TypeError: If the value is not a supported type for LiteralVar.
        """
        if isinstance(value, Var):
            # Already a var: only fold in the extra var data, if any.
            return (
                value
                if _var_data is None
                else value._replace(merge_var_data=_var_data)
            )

        if value is None:
            return ImmutableVar.create_safe("null", _var_data=_var_data)

        # Imported lazily: these modules depend on this one.
        from .object import LiteralObjectVar

        if isinstance(value, Base):
            return LiteralObjectVar(
                value.dict(), _var_type=type(value), _var_data=_var_data
            )
        if isinstance(value, dict):
            return LiteralObjectVar(value, _var_data=_var_data)

        from .number import LiteralBooleanVar, LiteralNumberVar
        from .sequence import LiteralArrayVar, LiteralStringVar

        if isinstance(value, str):
            return LiteralStringVar.create(value, _var_data=_var_data)

        # Exact-type dispatch: subclasses of these builtins are not accepted.
        exact_type = type(value)
        dispatch = (
            ((int, float), LiteralNumberVar),
            ((bool,), LiteralBooleanVar),
            ((list, tuple, set), LiteralArrayVar),
        )
        for accepted_types, literal_cls in dispatch:
            if exact_type in accepted_types:
                return literal_cls(value, _var_data=_var_data)

        raise TypeError(f"Unsupported type {type(value)} for LiteralVar.")

    def __post_init__(self):
        """Post-initialize the var (intentionally skips inline markup decoding)."""

    def json(self) -> str:
        """Serialize the var to a JSON string.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(
            "LiteralVar subclasses must implement the json method."
        )
# Parameter specification of the function wrapped by ``var_operation``.
P = ParamSpec("P")

# The var class produced by a ``var_operation``.
T = TypeVar("T", bound=ImmutableVar)


def var_operation(*, output: Type[T]) -> Callable[[Callable[P, str]], Callable[P, T]]:
    """Decorator for creating a var operation.

    The decorated function receives every argument as a Var (non-Var values are
    wrapped with ``LiteralVar.create``) and returns the Javascript expression of
    the operation. The wrapper builds an ``output`` var from that expression,
    carrying the merged var data of all the (converted) arguments.

    Example:
    ```python
    @var_operation(output=NumberVar)
    def add(a: NumberVar, b: NumberVar):
        return f"({a} + {b})"
    ```

    Args:
        output: The output type of the operation.

    Returns:
        The decorator.
    """

    def decorator(func: Callable[P, str], output=output):
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            args_vars = [
                arg if isinstance(arg, Var) else LiteralVar.create(arg)
                for arg in args
            ]
            kwargs_vars = {
                key: value if isinstance(value, Var) else LiteralVar.create(value)
                for key, value in kwargs.items()
            }
            return output(
                _var_name=func(*args_vars, **kwargs_vars),  # type: ignore
                # Merge from the converted vars, not the raw arguments, so that
                # var data carried by literal-wrapped arguments is kept, and use
                # ImmutableVarData as ImmutableVar._var_data expects.
                _var_data=ImmutableVarData.merge(
                    *[var._get_all_var_data() for var in args_vars],
                    *[var._get_all_var_data() for var in kwargs_vars.values()],
                ),
            )

        return wrapper

    return decorator
def unionize(*args: Type) -> Type:
    """Unionize the types.

    Args:
        args: The types to unionize.

    Returns:
        ``Any`` when no types are given, the single type when exactly one is
        given, otherwise the ``Union`` of all of them (``typing.Union`` flattens
        nested unions and collapses duplicates).
    """
    if not args:
        return Any
    first, *rest = args
    if not rest:
        return first
    # Build the Union in one step: the previous recursive form exceeded the
    # interpreter recursion limit for ~1000 types (e.g. figure_out_type on a
    # long list) and re-flattened the partial Union at every level.
    return Union[tuple(args)]
def figure_out_type(value: Any) -> Type:
    """Infer a typing annotation that describes ``value``.

    Lists, sets, tuples and dicts are inspected recursively and annotated with
    the union of their element (and key) types; empty containers use ``Any``.
    Any other value is described by its runtime class.

    Args:
        value: The value to figure out the type of.

    Returns:
        The type of the value.
    """

    def union_of(items):
        # Union of the inferred types of every item (Any when empty).
        return unionize(*map(figure_out_type, items))

    if isinstance(value, list):
        return List[union_of(value)]
    if isinstance(value, set):
        return Set[union_of(value)]
    if isinstance(value, tuple):
        return Tuple[union_of(value), ...]
    if isinstance(value, dict):
        key_type = union_of(value.keys())
        value_type = union_of(value.values())
        return Dict[key_type, value_type]
    return type(value)
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class AndOperation(ImmutableVar):
    """Class for the logical AND operation."""

    # The first var.
    _var1: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))

    # The second var.
    _var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))

    def __init__(
        self, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None
    ):
        """Initialize the AndOperation.

        Args:
            var1: The first var; non-Var values are wrapped with LiteralVar.create.
            var2: The second var; non-Var values are wrapped with LiteralVar.create.
            _var_data: Additional hooks and imports associated with the Var.
        """
        # Wrap plain Python values before reading _var_type: __rand__ passes the
        # raw left operand (e.g. ``True & var``), which has no _var_type.
        var1 = var1 if isinstance(var1, Var) else LiteralVar.create(var1)
        var2 = var2 if isinstance(var2, Var) else LiteralVar.create(var2)
        super(AndOperation, self).__init__(
            _var_name="",
            _var_type=Union[var1._var_type, var2._var_type],
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "_var1", var1)
        object.__setattr__(self, "_var2", var2)
        # Remove the placeholder name so reads of _var_name fall through to
        # __getattr__, which serves the lazily built expression.
        object.__delattr__(self, "_var_name")

    @functools.cached_property
    def _cached_var_name(self) -> str:
        """Get the cached var name.

        Returns:
            The Javascript expression ``(var1 && var2)``.
        """
        return f"({str(self._var1)} && {str(self._var2)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The cached expression for ``_var_name``; otherwise the attribute as
            resolved by the parent classes.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Name the class explicitly: super(type(self), self) recurses forever in
        # subclasses, and zero-arg super() is unreliable with slots=True.
        return getattr(super(AndOperation, self), name)

    @functools.cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get the cached VarData.

        Returns:
            The merged VarData of both operands and this var.
        """
        return ImmutableVarData.merge(
            self._var1._get_all_var_data(),
            self._var2._get_all_var_data(),
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data
@dataclasses.dataclass(
    eq=False,
    frozen=True,
    **{"slots": True} if sys.version_info >= (3, 10) else {},
)
class OrOperation(ImmutableVar):
    """Class for the logical OR operation."""

    # The first var.
    _var1: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))

    # The second var.
    _var2: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))

    def __init__(
        self, var1: Var | Any, var2: Var | Any, _var_data: VarData | None = None
    ):
        """Initialize the OrOperation.

        Args:
            var1: The first var; non-Var values are wrapped with LiteralVar.create.
            var2: The second var; non-Var values are wrapped with LiteralVar.create.
            _var_data: Additional hooks and imports associated with the Var.
        """
        # Wrap plain Python values before reading _var_type: __ror__ passes the
        # raw left operand (e.g. ``False | var``), which has no _var_type.
        var1 = var1 if isinstance(var1, Var) else LiteralVar.create(var1)
        var2 = var2 if isinstance(var2, Var) else LiteralVar.create(var2)
        super(OrOperation, self).__init__(
            _var_name="",
            _var_type=Union[var1._var_type, var2._var_type],
            _var_data=ImmutableVarData.merge(_var_data),
        )
        object.__setattr__(self, "_var1", var1)
        object.__setattr__(self, "_var2", var2)
        # Remove the placeholder name so reads of _var_name fall through to
        # __getattr__, which serves the lazily built expression.
        object.__delattr__(self, "_var_name")

    @functools.cached_property
    def _cached_var_name(self) -> str:
        """Get the cached var name.

        Returns:
            The Javascript expression ``(var1 || var2)``.
        """
        return f"({str(self._var1)} || {str(self._var2)})"

    def __getattr__(self, name: str) -> Any:
        """Get an attribute of the var.

        Args:
            name: The name of the attribute.

        Returns:
            The cached expression for ``_var_name``; otherwise the attribute as
            resolved by the parent classes.
        """
        if name == "_var_name":
            return self._cached_var_name
        # Name the class explicitly: super(type(self), self) recurses forever in
        # subclasses, and zero-arg super() is unreliable with slots=True.
        return getattr(super(OrOperation, self), name)

    @functools.cached_property
    def _cached_get_all_var_data(self) -> ImmutableVarData | None:
        """Get the cached VarData.

        Returns:
            The merged VarData of both operands and this var.
        """
        return ImmutableVarData.merge(
            self._var1._get_all_var_data(),
            self._var2._get_all_var_data(),
            self._var_data,
        )

    def _get_all_var_data(self) -> ImmutableVarData | None:
        """Get all VarData associated with the Var.

        Returns:
            The VarData of the components and all of its children.
        """
        return self._cached_get_all_var_data