reflex/reflex/ivars/base.py
benedikt-bartscher 84ec08616b
simplify ImmutableComputedVar.__get__ (#3902)
* simplify ImmutableComputedVar.__get__

* ruff it
2024-09-09 15:50:40 -07:00

1880 lines
58 KiB
Python

"""Collection of base classes."""
from __future__ import annotations
import contextlib
import dataclasses
import datetime
import dis
import functools
import inspect
import json
import sys
import warnings
from types import CodeType, FunctionType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Literal,
NoReturn,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
get_args,
overload,
)
from typing_extensions import ParamSpec, get_type_hints, override
from reflex import constants
from reflex.base import Base
from reflex.constants.colors import Color
from reflex.utils import console, imports, serializers, types
from reflex.utils.exceptions import VarDependencyError, VarTypeError, VarValueError
from reflex.utils.format import format_state_name
from reflex.utils.types import get_origin
from reflex.vars import (
ComputedVar,
ImmutableVarData,
Var,
VarData,
_decode_var_immutable,
_extract_var_data,
_global_vars,
)
if TYPE_CHECKING:
from reflex.state import BaseState
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
VAR_TYPE = TypeVar("VAR_TYPE")
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ImmutableVar(Var, Generic[VAR_TYPE]):
"""Base class for immutable vars."""
# The name of the var.
_var_name: str = dataclasses.field()
# The type of the var.
_var_type: types.GenericType = dataclasses.field(default=Any)
# Extra metadata associated with the Var
_var_data: Optional[ImmutableVarData] = dataclasses.field(default=None)
def __str__(self) -> str:
"""String representation of the var. Guaranteed to be a valid Javascript expression.
Returns:
The name of the var.
"""
return self._var_name
@property
def _var_is_local(self) -> bool:
"""Whether this is a local javascript variable.
Returns:
False
"""
return False
@property
def _var_is_string(self) -> bool:
"""Whether the var is a string literal.
Returns:
False
"""
return False
@property
def _var_full_name_needs_state_prefix(self) -> bool:
"""Whether the full name of the var needs a _var_state prefix.
Returns:
False
"""
return False
def __post_init__(self):
"""Post-initialize the var."""
# Decode any inline Var markup and apply it to the instance
_var_data, _var_name = _decode_var_immutable(self._var_name)
if _var_data or _var_name != self._var_name:
self.__init__(
_var_name=_var_name,
_var_type=self._var_type,
_var_data=ImmutableVarData.merge(self._var_data, _var_data),
)
def __hash__(self) -> int:
"""Define a hash function for the var.
Returns:
The hash of the var.
"""
return hash((self._var_name, self._var_type, self._var_data))
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return self._var_data
def _replace(self, merge_var_data=None, **kwargs: Any):
"""Make a copy of this Var with updated fields.
Args:
merge_var_data: VarData to merge into the existing VarData.
**kwargs: Var fields to update.
Returns:
A new ImmutableVar with the updated fields overwriting the corresponding fields in this Var.
Raises:
TypeError: If _var_is_local, _var_is_string, or _var_full_name_needs_state_prefix is not None.
"""
if kwargs.get("_var_is_local", False) is not False:
raise TypeError(
"The _var_is_local argument is not supported for ImmutableVar."
)
if kwargs.get("_var_is_string", False) is not False:
raise TypeError(
"The _var_is_string argument is not supported for ImmutableVar."
)
if kwargs.get("_var_full_name_needs_state_prefix", False) is not False:
raise TypeError(
"The _var_full_name_needs_state_prefix argument is not supported for ImmutableVar."
)
return dataclasses.replace(
self,
_var_data=ImmutableVarData.merge(
kwargs.get("_var_data", self._var_data), merge_var_data
),
**kwargs,
)
@classmethod
def create(
cls,
value: Any,
_var_is_local: bool | None = None,
_var_is_string: bool | None = None,
_var_data: VarData | None = None,
) -> ImmutableVar | Var | None:
"""Create a var from a value.
Args:
value: The value to create the var from.
_var_is_local: Whether the var is local. Deprecated.
_var_is_string: Whether the var is a string literal. Deprecated.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
Raises:
VarTypeError: If the value is JSON-unserializable.
TypeError: If _var_is_local or _var_is_string is not None.
"""
if _var_is_local is not None:
raise TypeError(
"The _var_is_local argument is not supported for ImmutableVar."
)
if _var_is_string is not None:
raise TypeError(
"The _var_is_string argument is not supported for ImmutableVar."
)
from reflex.utils import format
# Check for none values.
if value is None:
return None
# If the value is already a var, do nothing.
if isinstance(value, Var):
return value
# Try to pull the imports and hooks from contained values.
if not isinstance(value, str):
_var_data = VarData.merge(*_extract_var_data(value), _var_data)
# Try to serialize the value.
type_ = type(value)
if type_ in types.JSONType:
name = value
else:
name, _serialized_type = serializers.serialize(value, get_type=True)
if name is None:
raise VarTypeError(
f"No JSON serializer found for var {value} of type {type_}."
)
name = name if isinstance(name, str) else format.json_dumps(name)
return cls(
_var_name=name,
_var_type=type_,
_var_data=(
ImmutableVarData(
state=_var_data.state,
imports=_var_data.imports,
hooks=_var_data.hooks,
)
if _var_data
else None
),
)
@classmethod
def create_safe(
cls,
value: Any,
_var_is_local: bool | None = None,
_var_is_string: bool | None = None,
_var_data: VarData | None = None,
) -> Var | ImmutableVar:
"""Create a var from a value, asserting that it is not None.
Args:
value: The value to create the var from.
_var_is_local: Whether the var is local. Deprecated.
_var_is_string: Whether the var is a string literal. Deprecated.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
var = cls.create(
value,
_var_is_local=_var_is_local,
_var_is_string=_var_is_string,
_var_data=_var_data,
)
assert var is not None
return var
def __format__(self, format_spec: str) -> str:
"""Format the var into a Javascript equivalent to an f-string.
Args:
format_spec: The format specifier (Ignored for now).
Returns:
The formatted var.
"""
hashed_var = hash(self)
_global_vars[hashed_var] = self
# Encode the _var_data into the formatted output for tracking purposes.
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._var_name}"
@overload
def to(
self, output: Type[NumberVar], var_type: type[int] | type[float] = float
) -> ToNumberVarOperation: ...
@overload
def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ...
@overload
def to(
self,
output: Type[ArrayVar],
var_type: type[list] | type[tuple] | type[set] = list,
) -> ToArrayOperation: ...
@overload
def to(self, output: Type[StringVar]) -> ToStringOperation: ...
@overload
def to(
self, output: Type[ObjectVar], var_type: types.GenericType = dict
) -> ToObjectOperation: ...
@overload
def to(
self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
) -> ToFunctionOperation: ...
@overload
def to(
self,
output: Type[OUTPUT] | types.GenericType,
var_type: types.GenericType | None = None,
) -> OUTPUT: ...
def to(
self,
output: Type[OUTPUT] | types.GenericType,
var_type: types.GenericType | None = None,
) -> Var:
"""Convert the var to a different type.
Args:
output: The output type.
var_type: The type of the var.
Raises:
TypeError: If the var_type is not a supported type for the output.
Returns:
The converted var.
"""
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
base_type = var_type
if types.is_optional(base_type):
base_type = types.get_args(base_type)[0]
fixed_type = get_origin(base_type) or base_type
fixed_output_type = get_origin(output) or output
# If the first argument is a python type, we map it to the corresponding Var type.
if fixed_output_type is dict:
return self.to(ObjectVar, output)
if fixed_output_type in (list, tuple, set):
return self.to(ArrayVar, output)
if fixed_output_type in (int, float):
return self.to(NumberVar, output)
if fixed_output_type is str:
return self.to(StringVar, output)
if fixed_output_type is bool:
return self.to(BooleanVar, output)
if issubclass(output, NumberVar):
if fixed_type is not None:
if fixed_type is Union:
inner_types = get_args(base_type)
if not all(issubclass(t, (int, float)) for t in inner_types):
raise TypeError(
f"Unsupported type {var_type} for NumberVar. Must be int or float."
)
elif not issubclass(fixed_type, (int, float)):
raise TypeError(
f"Unsupported type {var_type} for NumberVar. Must be int or float."
)
return ToNumberVarOperation.create(self, var_type or float)
if issubclass(output, BooleanVar):
return ToBooleanVarOperation.create(self)
if issubclass(output, ArrayVar):
if fixed_type is not None and not issubclass(
fixed_type, (list, tuple, set)
):
raise TypeError(
f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set."
)
return ToArrayOperation.create(self, var_type or list)
if issubclass(output, StringVar):
return ToStringOperation.create(self)
if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, FunctionVar):
# if fixed_type is not None and not issubclass(fixed_type, Callable):
# raise TypeError(
# f"Unsupported type {var_type} for FunctionVar. Must be Callable."
# )
return ToFunctionOperation.create(self, var_type or Callable)
# If we can't determine the first argument, we just replace the _var_type.
if not issubclass(output, Var) or var_type is None:
return dataclasses.replace(
self,
_var_type=output,
)
# We couldn't determine the output type to be any other Var type, so we replace the _var_type.
if var_type is not None:
return dataclasses.replace(
self,
_var_type=var_type,
)
return self
def guess_type(self) -> ImmutableVar:
"""Guesses the type of the variable based on its `_var_type` attribute.
Returns:
ImmutableVar: The guessed type of the variable.
Raises:
TypeError: If the type is not supported for guessing.
"""
from .number import BooleanVar, NumberVar
from .object import ObjectVar
from .sequence import ArrayVar, StringVar
var_type = self._var_type
if types.is_optional(var_type):
var_type = types.get_args(var_type)[0]
if var_type is Any:
return self
fixed_type = get_origin(var_type) or var_type
if fixed_type is Union:
inner_types = get_args(var_type)
if int in inner_types and float in inner_types:
return self.to(NumberVar, self._var_type)
return self
if not inspect.isclass(fixed_type):
raise TypeError(f"Unsupported type {var_type} for guess_type.")
if issubclass(fixed_type, bool):
return self.to(BooleanVar, self._var_type)
if issubclass(fixed_type, (int, float)):
return self.to(NumberVar, self._var_type)
if issubclass(fixed_type, dict):
return self.to(ObjectVar, self._var_type)
if issubclass(fixed_type, (list, tuple, set)):
return self.to(ArrayVar, self._var_type)
if issubclass(fixed_type, str):
return self.to(StringVar)
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
return self
def get_default_value(self) -> Any:
"""Get the default value of the var.
Returns:
The default value of the var.
Raises:
ImportError: If the var is a dataframe and pandas is not installed.
"""
if types.is_optional(self._var_type):
return None
type_ = (
get_origin(self._var_type)
if types.is_generic_alias(self._var_type)
else self._var_type
)
if type_ is Literal:
args = get_args(self._var_type)
return args[0] if args else None
if issubclass(type_, str):
return ""
if issubclass(type_, types.get_args(Union[int, float])):
return 0
if issubclass(type_, bool):
return False
if issubclass(type_, list):
return []
if issubclass(type_, dict):
return {}
if issubclass(type_, tuple):
return ()
if types.is_dataframe(type_):
try:
import pandas as pd
return pd.DataFrame()
except ImportError as e:
raise ImportError(
"Please install pandas to use dataframes in your app."
) from e
return set() if issubclass(type_, set) else None
def get_setter_name(self, include_state: bool = True) -> str:
"""Get the name of the var's generated setter function.
Args:
include_state: Whether to include the state name in the setter name.
Returns:
The name of the setter function.
"""
var_name_parts = self._var_name.split(".")
setter = constants.SETTER_PREFIX + var_name_parts[-1]
if self._var_data is None:
return setter
if not include_state or self._var_data.state == "":
return setter
return ".".join((self._var_data.state, setter))
def get_setter(self) -> Callable[[BaseState, Any], None]:
"""Get the var's setter function.
Returns:
A function that that creates a setter for the var.
"""
actual_name = self._var_name.split(".")[-1]
def setter(state: BaseState, value: Any):
"""Get the setter for the var.
Args:
state: The state within which we add the setter function.
value: The value to set.
"""
if self._var_type in [int, float]:
try:
value = self._var_type(value)
setattr(state, actual_name, value)
except ValueError:
console.debug(
f"{type(state).__name__}.{self._var_name}: Failed conversion of {value} to '{self._var_type.__name__}'. Value not set.",
)
else:
setattr(state, actual_name, value)
setter.__qualname__ = self.get_setter_name()
return setter
def __eq__(self, other: Var | Any) -> BooleanVar:
"""Check if the current variable is equal to the given variable.
Args:
other (Var | Any): The variable to compare with.
Returns:
BooleanVar: A BooleanVar object representing the result of the equality check.
"""
from .number import equal_operation
return equal_operation(self, other)
def __ne__(self, other: Var | Any) -> BooleanVar:
"""Check if the current object is not equal to the given object.
Parameters:
other (Var | Any): The object to compare with.
Returns:
BooleanVar: A BooleanVar object representing the result of the comparison.
"""
from .number import equal_operation
return ~equal_operation(self, other)
def __gt__(self, other: Var | Any) -> BooleanVar:
"""Compare the current instance with another variable and return a BooleanVar representing the result of the greater than operation.
Args:
other (Var | Any): The variable to compare with.
Returns:
BooleanVar: A BooleanVar representing the result of the greater than operation.
"""
from .number import greater_than_operation
return greater_than_operation(self, other)
def __ge__(self, other: Var | Any) -> BooleanVar:
"""Check if the value of this variable is greater than or equal to the value of another variable or object.
Args:
other (Var | Any): The variable or object to compare with.
Returns:
BooleanVar: A BooleanVar object representing the result of the comparison.
"""
from .number import greater_than_or_equal_operation
return greater_than_or_equal_operation(self, other)
def __lt__(self, other: Var | Any) -> BooleanVar:
"""Compare the current instance with another variable using the less than (<) operator.
Args:
other: The variable to compare with.
Returns:
A `BooleanVar` object representing the result of the comparison.
"""
from .number import less_than_operation
return less_than_operation(self, other)
def __le__(self, other: Var | Any) -> BooleanVar:
"""Compare if the current instance is less than or equal to the given value.
Args:
other: The value to compare with.
Returns:
A BooleanVar object representing the result of the comparison.
"""
from .number import less_than_or_equal_operation
return less_than_or_equal_operation(self, other)
def bool(self) -> BooleanVar:
"""Convert the var to a boolean.
Returns:
The boolean var.
"""
from .number import boolify
return boolify(self)
def __and__(self, other: Var | Any) -> ImmutableVar:
"""Perform a logical AND operation on the current instance and another variable.
Args:
other: The variable to perform the logical AND operation with.
Returns:
A `BooleanVar` object representing the result of the logical AND operation.
"""
return and_operation(self, other)
def __rand__(self, other: Var | Any) -> ImmutableVar:
"""Perform a logical AND operation on the current instance and another variable.
Args:
other: The variable to perform the logical AND operation with.
Returns:
A `BooleanVar` object representing the result of the logical AND operation.
"""
return and_operation(other, self)
def __or__(self, other: Var | Any) -> ImmutableVar:
"""Perform a logical OR operation on the current instance and another variable.
Args:
other: The variable to perform the logical OR operation with.
Returns:
A `BooleanVar` object representing the result of the logical OR operation.
"""
return or_operation(self, other)
def __ror__(self, other: Var | Any) -> ImmutableVar:
"""Perform a logical OR operation on the current instance and another variable.
Args:
other: The variable to perform the logical OR operation with.
Returns:
A `BooleanVar` object representing the result of the logical OR operation.
"""
return or_operation(other, self)
def __invert__(self) -> BooleanVar:
"""Perform a logical NOT operation on the current instance.
Returns:
A `BooleanVar` object representing the result of the logical NOT operation.
"""
return ~self.bool()
def to_string(self) -> ImmutableVar:
"""Convert the var to a string.
Returns:
The string var.
"""
from .function import JSON_STRINGIFY
from .sequence import StringVar
return JSON_STRINGIFY.call(self).to(StringVar)
def as_ref(self) -> ImmutableVar:
"""Get a reference to the var.
Returns:
The reference to the var.
"""
from .object import ObjectVar
refs = ImmutableVar(
_var_name="refs",
_var_data=ImmutableVarData(
imports={
f"/{constants.Dirs.STATE_PATH}": [imports.ImportVar(tag="refs")]
}
),
).to(ObjectVar)
return refs[LiteralVar.create(str(self))]
def _type(self) -> StringVar:
"""Returns the type of the object.
This method uses the `typeof` function from the `FunctionStringVar` class
to determine the type of the object.
Returns:
StringVar: A string variable representing the type of the object.
"""
from .function import FunctionStringVar
from .sequence import StringVar
type_of = FunctionStringVar("typeof")
return type_of.call(self).to(StringVar)
OUTPUT = TypeVar("OUTPUT", bound=ImmutableVar)
class LiteralVar(ImmutableVar):
"""Base class for immutable literal vars."""
@classmethod
def create(
cls,
value: Any,
_var_data: VarData | None = None,
) -> Var:
"""Create a var from a value.
Args:
value: The value to create the var from.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
Raises:
TypeError: If the value is not a supported type for LiteralVar.
"""
from .number import LiteralBooleanVar, LiteralNumberVar
from .object import LiteralObjectVar
from .sequence import LiteralArrayVar, LiteralStringVar
if isinstance(value, Var):
if _var_data is None:
return value
return value._replace(merge_var_data=_var_data)
if isinstance(value, str):
return LiteralStringVar.create(value, _var_data=_var_data)
if isinstance(value, bool):
return LiteralBooleanVar.create(value, _var_data=_var_data)
if isinstance(value, (int, float)):
return LiteralNumberVar.create(value, _var_data=_var_data)
if isinstance(value, dict):
return LiteralObjectVar.create(value, _var_data=_var_data)
if isinstance(value, Color):
return LiteralStringVar.create(f"{value}", _var_data=_var_data)
if isinstance(value, (list, tuple, set)):
return LiteralArrayVar.create(value, _var_data=_var_data)
if value is None:
return ImmutableVar.create_safe("null", _var_data=_var_data)
from reflex.event import EventChain, EventSpec
from reflex.utils.format import get_event_handler_parts
from .function import ArgsFunctionOperation, FunctionStringVar
from .object import LiteralObjectVar
if isinstance(value, EventSpec):
event_name = LiteralVar.create(
".".join(filter(None, get_event_handler_parts(value.handler)))
)
event_args = LiteralVar.create(
{str(name): value for name, value in value.args}
)
event_client_name = LiteralVar.create(value.client_handler_name)
return FunctionStringVar("Event").call(
event_name,
event_args,
*([event_client_name] if value.client_handler_name else []),
)
if isinstance(value, EventChain):
sig = inspect.signature(value.args_spec) # type: ignore
if sig.parameters:
arg_def = tuple((f"_{p}" for p in sig.parameters))
arg_def_expr = LiteralVar.create(
[ImmutableVar.create_safe(arg) for arg in arg_def]
)
else:
# add a default argument for addEvents if none were specified in value.args_spec
# used to trigger the preventDefault() on the event.
arg_def = ("...args",)
arg_def_expr = ImmutableVar.create_safe("args")
return ArgsFunctionOperation.create(
arg_def,
FunctionStringVar.create("addEvents").call(
LiteralVar.create(
[LiteralVar.create(event) for event in value.events]
),
arg_def_expr,
LiteralVar.create(value.event_actions),
),
)
try:
from plotly.graph_objects import Figure, layout
from plotly.io import to_json
if isinstance(value, Figure):
return LiteralObjectVar.create(
json.loads(str(to_json(value))),
_var_type=Figure,
_var_data=_var_data,
)
if isinstance(value, layout.Template):
return LiteralObjectVar.create(
{
"data": json.loads(str(to_json(value.data))),
"layout": json.loads(str(to_json(value.layout))),
},
_var_type=layout.Template,
_var_data=_var_data,
)
except ImportError:
pass
try:
import base64
import io
from PIL.Image import MIME
from PIL.Image import Image as Img
if isinstance(value, Img):
with io.BytesIO() as buffer:
image_format = getattr(value, "format", None) or "PNG"
value.save(buffer, format=image_format)
try:
# Newer method to get the mime type, but does not always work.
mimetype = value.get_format_mimetype() # type: ignore
except AttributeError:
try:
# Fallback method
mimetype = MIME[image_format]
except KeyError:
# Unknown mime_type: warn and return image/png and hope the browser can sort it out.
warnings.warn( # noqa: B028
f"Unknown mime type for {value} {image_format}. Defaulting to image/png"
)
mimetype = "image/png"
return LiteralStringVar.create(
f"data:{mimetype};base64,{base64.b64encode(buffer.getvalue()).decode()}",
_var_data=_var_data,
)
except ImportError:
pass
if isinstance(value, Base):
return LiteralObjectVar.create(
{k: (None if callable(v) else v) for k, v in value.dict().items()},
_var_type=type(value),
_var_data=_var_data,
)
raise TypeError(
f"Unsupported type {type(value)} for LiteralVar. Tried to create a LiteralVar from {value}."
)
def __post_init__(self):
"""Post-initialize the var."""
def json(self) -> str:
"""Serialize the var to a JSON string.
Raises:
NotImplementedError: If the method is not implemented.
"""
raise NotImplementedError(
"LiteralVar subclasses must implement the json method."
)
P = ParamSpec("P")
T = TypeVar("T")
# NoReturn is used to match CustomVarOperationReturn with no type hint.
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[NoReturn]],
) -> Callable[P, ImmutableVar]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[bool]],
) -> Callable[P, BooleanVar]: ...
NUMBER_T = TypeVar("NUMBER_T", int, float, Union[int, float])
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
) -> Callable[P, NumberVar[NUMBER_T]]: ...
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[str]],
) -> Callable[P, StringVar]: ...
LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[LIST_T]],
) -> Callable[P, ArrayVar[LIST_T]]: ...
OBJECT_TYPE = TypeVar("OBJECT_TYPE", bound=Dict)
@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...
def var_operation(
func: Callable[P, CustomVarOperationReturn[T]],
) -> Callable[P, ImmutableVar[T]]:
"""Decorator for creating a var operation.
Example:
```python
@var_operation
def add(a: NumberVar, b: NumberVar):
return custom_var_operation(f"{a} + {b}")
```
Args:
func: The function to decorate.
Returns:
The decorated function.
"""
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ImmutableVar[T]:
func_args = list(inspect.signature(func).parameters)
args_vars = {
func_args[i]: (LiteralVar.create(arg) if not isinstance(arg, Var) else arg)
for i, arg in enumerate(args)
}
kwargs_vars = {
key: LiteralVar.create(value) if not isinstance(value, Var) else value
for key, value in kwargs.items()
}
return CustomVarOperation.create(
args=tuple(list(args_vars.items()) + list(kwargs_vars.items())),
return_var=func(*args_vars.values(), **kwargs_vars), # type: ignore
).guess_type()
return wrapper
def unionize(*args: Type) -> Type:
"""Unionize the types.
Args:
args: The types to unionize.
Returns:
The unionized types.
"""
if not args:
return Any
first, *rest = args
if not rest:
return first
return Union[first, unionize(*rest)]
def figure_out_type(value: Any) -> types.GenericType:
"""Figure out the type of the value.
Args:
value: The value to figure out the type of.
Returns:
The type of the value.
"""
if isinstance(value, list):
return List[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, set):
return Set[unionize(*(figure_out_type(v) for v in value))]
if isinstance(value, tuple):
return Tuple[unionize(*(figure_out_type(v) for v in value)), ...]
if isinstance(value, dict):
return Dict[
unionize(*(figure_out_type(k) for k in value)),
unionize(*(figure_out_type(v) for v in value.values())),
]
if isinstance(value, Var):
return value._var_type
return type(value)
class cached_property_no_lock(functools.cached_property):
"""A special version of functools.cached_property that does not use a lock."""
def __init__(self, func):
"""Initialize the cached_property_no_lock.
Args:
func: The function to cache.
"""
super().__init__(func)
self.lock = contextlib.nullcontext()
class CachedVarOperation:
"""Base class for cached var operations to lower boilerplate code."""
def __post_init__(self):
"""Post-initialize the CachedVarOperation."""
object.__delattr__(self, "_var_name")
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute.
"""
if name == "_var_name":
return self._cached_var_name
parent_classes = inspect.getmro(self.__class__)
return parent_classes[parent_classes.index(CachedVarOperation) + 1].__getattr__( # type: ignore
self, name
)
def _get_all_var_data(self) -> ImmutableVarData | None:
"""Get all VarData associated with the Var.
Returns:
The VarData of the components and all of its children.
"""
return self._cached_get_all_var_data
@cached_property_no_lock
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get the cached VarData.
Returns:
The cached VarData.
"""
return ImmutableVarData.merge(
*map(
lambda value: (
value._get_all_var_data() if isinstance(value, Var) else None
),
map(
lambda field: getattr(self, field.name),
dataclasses.fields(self), # type: ignore
),
),
self._var_data,
)
def __hash__(self) -> int:
"""Calculate the hash of the object.
Returns:
The hash of the object.
"""
return hash(
(
self.__class__.__name__,
*[
getattr(self, field.name)
for field in dataclasses.fields(self) # type: ignore
if field.name not in ["_var_name", "_var_data", "_var_type"]
],
)
)
def and_operation(a: Var | Any, b: Var | Any) -> ImmutableVar:
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return _and_operation(a, b) # type: ignore
@var_operation
def _and_operation(a: ImmutableVar, b: ImmutableVar):
"""Perform a logical AND operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical AND operation.
"""
return var_operation_return(
js_expression=f"({a} && {b})",
var_type=unionize(a._var_type, b._var_type),
)
def or_operation(a: Var | Any, b: Var | Any) -> ImmutableVar:
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return _or_operation(a, b) # type: ignore
@var_operation
def _or_operation(a: ImmutableVar, b: ImmutableVar):
"""Perform a logical OR operation on two variables.
Args:
a: The first variable.
b: The second variable.
Returns:
The result of the logical OR operation.
"""
return var_operation_return(
js_expression=f"({a} || {b})",
var_type=unionize(a._var_type, b._var_type),
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ImmutableCallableVar(ImmutableVar):
"""Decorate a Var-returning function to act as both a Var and a function.
This is used as a compatibility shim for replacing Var objects in the
API with functions that return a family of Var.
"""
fn: Callable[..., Var] = dataclasses.field(
default_factory=lambda: lambda: ImmutableVar(_var_name="undefined")
)
original_var: Var = dataclasses.field(
default_factory=lambda: ImmutableVar(_var_name="undefined")
)
def __init__(self, fn: Callable[..., Var]):
"""Initialize a CallableVar.
Args:
fn: The function to decorate (must return Var)
"""
original_var = fn()
super(ImmutableCallableVar, self).__init__(
_var_name=original_var._var_name,
_var_type=original_var._var_type,
_var_data=ImmutableVarData.merge(original_var._get_all_var_data()),
)
object.__setattr__(self, "fn", fn)
object.__setattr__(self, "original_var", original_var)
def __call__(self, *args, **kwargs) -> Var:
"""Call the decorated function.
Args:
*args: The args to pass to the function.
**kwargs: The kwargs to pass to the function.
Returns:
The Var returned from calling the function.
"""
return self.fn(*args, **kwargs)
def __hash__(self) -> int:
"""Calculate the hash of the object.
Returns:
The hash of the object.
"""
return hash((self.__class__.__name__, self.original_var))
RETURN_TYPE = TypeVar("RETURN_TYPE")
DICT_KEY = TypeVar("DICT_KEY")
DICT_VAL = TypeVar("DICT_VAL")
LIST_INSIDE = TypeVar("LIST_INSIDE")
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ImmutableComputedVar(ImmutableVar[RETURN_TYPE]):
"""A field with computed getters."""
# Whether to track dependencies and cache computed values
_cache: bool = dataclasses.field(default=False)
# Whether the computed var is a backend var
_backend: bool = dataclasses.field(default=False)
# The initial value of the computed var
_initial_value: RETURN_TYPE | types.Unset = dataclasses.field(default=types.Unset())
# Explicit var dependencies to track
_static_deps: set[str] = dataclasses.field(default_factory=set)
# Whether var dependencies should be auto-determined
_auto_deps: bool = dataclasses.field(default=True)
# Interval at which the computed var should be updated
_update_interval: Optional[datetime.timedelta] = dataclasses.field(default=None)
_fget: Callable[[BaseState], RETURN_TYPE] = dataclasses.field(
default_factory=lambda: lambda _: None
) # type: ignore
def __init__(
self,
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[int, datetime.timedelta]] = None,
backend: bool | None = None,
**kwargs,
):
"""Initialize a ComputedVar.
Args:
fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
**kwargs: additional attributes to set on the instance
Raises:
TypeError: If the computed var dependencies are not Var instances or var names.
"""
hints = get_type_hints(fget)
hint = hints.get("return", Any)
kwargs["_var_name"] = kwargs.pop("_var_name", fget.__name__)
kwargs["_var_type"] = kwargs.pop("_var_type", hint)
super(ImmutableComputedVar, self).__init__(
_var_name=kwargs.pop("_var_name"),
_var_type=kwargs.pop("_var_type"),
_var_data=ImmutableVarData.merge(kwargs.pop("_var_data", None)),
)
if backend is None:
backend = fget.__name__.startswith("_")
object.__setattr__(self, "_backend", backend)
object.__setattr__(self, "_initial_value", initial_value)
object.__setattr__(self, "_cache", cache)
if isinstance(interval, int):
interval = datetime.timedelta(seconds=interval)
object.__setattr__(self, "_update_interval", interval)
if deps is None:
deps = []
else:
for dep in deps:
if isinstance(dep, Var):
continue
if isinstance(dep, str) and dep != "":
continue
raise TypeError(
"ComputedVar dependencies must be Var instances or var names (non-empty strings)."
)
object.__setattr__(
self,
"_static_deps",
{dep._var_name if isinstance(dep, Var) else dep for dep in deps},
)
object.__setattr__(self, "_auto_deps", auto_deps)
object.__setattr__(self, "_fget", fget)
@override
def _replace(self, merge_var_data=None, **kwargs: Any) -> ImmutableComputedVar:
"""Replace the attributes of the ComputedVar.
Args:
merge_var_data: VarData to merge into the existing VarData.
**kwargs: Var fields to update.
Returns:
The new ComputedVar instance.
Raises:
TypeError: If kwargs contains keys that are not allowed.
"""
field_values = dict(
fget=kwargs.pop("fget", self._fget),
initial_value=kwargs.pop("initial_value", self._initial_value),
cache=kwargs.pop("cache", self._cache),
deps=kwargs.pop("deps", self._static_deps),
auto_deps=kwargs.pop("auto_deps", self._auto_deps),
interval=kwargs.pop("interval", self._update_interval),
backend=kwargs.pop("backend", self._backend),
_var_name=kwargs.pop("_var_name", self._var_name),
_var_type=kwargs.pop("_var_type", self._var_type),
_var_data=kwargs.pop(
"_var_data", VarData.merge(self._var_data, merge_var_data)
),
)
if kwargs:
unexpected_kwargs = ", ".join(kwargs.keys())
raise TypeError(f"Unexpected keyword arguments: {unexpected_kwargs}")
return ImmutableComputedVar(**field_values)
@property
def _cache_attr(self) -> str:
"""Get the attribute used to cache the value on the instance.
Returns:
An attribute name.
"""
return f"__cached_{self._var_name}"
@property
def _last_updated_attr(self) -> str:
"""Get the attribute used to store the last updated timestamp.
Returns:
An attribute name.
"""
return f"__last_updated_{self._var_name}"
def needs_update(self, instance: BaseState) -> bool:
"""Check if the computed var needs to be updated.
Args:
instance: The state instance that the computed var is attached to.
Returns:
True if the computed var needs to be updated, False otherwise.
"""
if self._update_interval is None:
return False
last_updated = getattr(instance, self._last_updated_attr, None)
if last_updated is None:
return True
return datetime.datetime.now() - last_updated > self._update_interval
@overload
def __get__(
self: ImmutableComputedVar[int] | ImmutableComputedVar[float],
instance: None,
owner: Type,
) -> NumberVar: ...
@overload
def __get__(
self: ImmutableComputedVar[str],
instance: None,
owner: Type,
) -> StringVar: ...
@overload
def __get__(
self: ImmutableComputedVar[dict[DICT_KEY, DICT_VAL]],
instance: None,
owner: Type,
) -> ObjectVar[dict[DICT_KEY, DICT_VAL]]: ...
@overload
def __get__(
self: ImmutableComputedVar[list[LIST_INSIDE]],
instance: None,
owner: Type,
) -> ArrayVar[list[LIST_INSIDE]]: ...
@overload
def __get__(
self: ImmutableComputedVar[set[LIST_INSIDE]],
instance: None,
owner: Type,
) -> ArrayVar[set[LIST_INSIDE]]: ...
@overload
def __get__(
self: ImmutableComputedVar[tuple[LIST_INSIDE, ...]],
instance: None,
owner: Type,
) -> ArrayVar[tuple[LIST_INSIDE, ...]]: ...
@overload
def __get__(
self, instance: None, owner: Type
) -> ImmutableComputedVar[RETURN_TYPE]: ...
@overload
def __get__(self, instance: BaseState, owner: Type) -> RETURN_TYPE: ...
def __get__(self, instance: BaseState | None, owner):
"""Get the ComputedVar value.
If the value is already cached on the instance, return the cached value.
Args:
instance: the instance of the class accessing this computed var.
owner: the class that this descriptor is attached to.
Returns:
The value of the var for the given instance.
"""
if instance is None:
state_where_defined = owner
while self.fget.__name__ in state_where_defined.inherited_vars:
state_where_defined = state_where_defined.get_parent_state()
return self._replace(
_var_name=format_state_name(state_where_defined.get_full_name())
+ "."
+ self._var_name,
merge_var_data=ImmutableVarData.from_state(state_where_defined),
).guess_type()
if not self._cache:
return self.fget(instance)
# handle caching
if not hasattr(instance, self._cache_attr) or self.needs_update(instance):
# Set cache attr on state instance.
setattr(instance, self._cache_attr, self.fget(instance))
# Ensure the computed var gets serialized to redis.
instance._was_touched = True
# Set the last updated timestamp on the state instance.
setattr(instance, self._last_updated_attr, datetime.datetime.now())
return getattr(instance, self._cache_attr)
def _deps(
self,
objclass: Type,
obj: FunctionType | CodeType | None = None,
self_name: Optional[str] = None,
) -> set[str]:
"""Determine var dependencies of this ComputedVar.
Save references to attributes accessed on "self". Recursively called
when the function makes a method call on "self" or define comprehensions
or nested functions that may reference "self".
Args:
objclass: the class obj this ComputedVar is attached to.
obj: the object to disassemble (defaults to the fget function).
self_name: if specified, look for this name in LOAD_FAST and LOAD_DEREF instructions.
Returns:
A set of variable names accessed by the given obj.
Raises:
VarValueError: if the function references the get_state, parent_state, or substates attributes
(cannot track deps in a related state, only implicitly via parent state).
"""
if not self._auto_deps:
return self._static_deps
d = self._static_deps.copy()
if obj is None:
fget = self._fget
if fget is not None:
obj = cast(FunctionType, fget)
else:
return set()
with contextlib.suppress(AttributeError):
# unbox functools.partial
obj = cast(FunctionType, obj.func) # type: ignore
with contextlib.suppress(AttributeError):
# unbox EventHandler
obj = cast(FunctionType, obj.fn) # type: ignore
if self_name is None and isinstance(obj, FunctionType):
try:
# the first argument to the function is the name of "self" arg
self_name = obj.__code__.co_varnames[0]
except (AttributeError, IndexError):
self_name = None
if self_name is None:
# cannot reference attributes on self if method takes no args
return set()
invalid_names = ["get_state", "parent_state", "substates", "get_substate"]
self_is_top_of_stack = False
for instruction in dis.get_instructions(obj):
if (
instruction.opname in ("LOAD_FAST", "LOAD_DEREF")
and instruction.argval == self_name
):
# bytecode loaded the class instance to the top of stack, next load instruction
# is referencing an attribute on self
self_is_top_of_stack = True
continue
if self_is_top_of_stack and instruction.opname in (
"LOAD_ATTR",
"LOAD_METHOD",
):
try:
ref_obj = getattr(objclass, instruction.argval)
except Exception:
ref_obj = None
if instruction.argval in invalid_names:
raise VarValueError(
f"Cached var {self._var_full_name} cannot access arbitrary state via `{instruction.argval}`."
)
if callable(ref_obj):
# recurse into callable attributes
d.update(
self._deps(
objclass=objclass,
obj=ref_obj,
)
)
# recurse into property fget functions
elif isinstance(ref_obj, property) and not isinstance(
ref_obj, ImmutableComputedVar
):
d.update(
self._deps(
objclass=objclass,
obj=ref_obj.fget, # type: ignore
)
)
elif (
instruction.argval in objclass.backend_vars
or instruction.argval in objclass.vars
):
# var access
d.add(instruction.argval)
elif instruction.opname == "LOAD_CONST" and isinstance(
instruction.argval, CodeType
):
# recurse into nested functions / comprehensions, which can reference
# instance attributes from the outer scope
d.update(
self._deps(
objclass=objclass,
obj=instruction.argval,
self_name=self_name,
)
)
self_is_top_of_stack = False
return d
def mark_dirty(self, instance) -> None:
"""Mark this ComputedVar as dirty.
Args:
instance: the state instance that needs to recompute the value.
"""
with contextlib.suppress(AttributeError):
delattr(instance, self._cache_attr)
def _determine_var_type(self) -> Type:
"""Get the type of the var.
Returns:
The type of the var.
"""
hints = get_type_hints(self._fget)
if "return" in hints:
return hints["return"]
return Any
@property
def __class__(self) -> Type:
"""Get the class of the var.
Returns:
The class of the var.
"""
return ComputedVar
@property
def fget(self) -> Callable[[BaseState], RETURN_TYPE]:
"""Get the getter function.
Returns:
The getter function.
"""
return self._fget
if TYPE_CHECKING:
BASE_STATE = TypeVar("BASE_STATE", bound=BaseState)
@overload
def immutable_computed_var(
fget: None = None,
initial_value: Any | types.Unset = types.Unset(),
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None,
**kwargs,
) -> Callable[
[Callable[[BASE_STATE], RETURN_TYPE]], ImmutableComputedVar[RETURN_TYPE]
]: ...
@overload
def immutable_computed_var(
fget: Callable[[BASE_STATE], RETURN_TYPE],
initial_value: RETURN_TYPE | types.Unset = types.Unset(),
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None,
**kwargs,
) -> ImmutableComputedVar[RETURN_TYPE]: ...
def immutable_computed_var(
fget: Callable[[BASE_STATE], Any] | None = None,
initial_value: Any | types.Unset = types.Unset(),
cache: bool = False,
deps: Optional[List[Union[str, Var]]] = None,
auto_deps: bool = True,
interval: Optional[Union[datetime.timedelta, int]] = None,
backend: bool | None = None,
**kwargs,
) -> (
ImmutableComputedVar | Callable[[Callable[[BASE_STATE], Any]], ImmutableComputedVar]
):
"""A ComputedVar decorator with or without kwargs.
Args:
fget: The getter function.
initial_value: The initial value of the computed var.
cache: Whether to cache the computed value.
deps: Explicit var dependencies to track.
auto_deps: Whether var dependencies should be auto-determined.
interval: Interval at which the computed var should be updated.
backend: Whether the computed var is a backend var.
**kwargs: additional attributes to set on the instance
Returns:
A ComputedVar instance.
Raises:
ValueError: If caching is disabled and an update interval is set.
VarDependencyError: If user supplies dependencies without caching.
"""
if cache is False and interval is not None:
raise ValueError("Cannot set update interval without caching.")
if cache is False and (deps is not None or auto_deps is False):
raise VarDependencyError("Cannot track dependencies without caching.")
if fget is not None:
return ImmutableComputedVar(fget, cache=cache)
def wrapper(fget: Callable[[BASE_STATE], Any]) -> ImmutableComputedVar:
return ImmutableComputedVar(
fget,
initial_value=initial_value,
cache=cache,
deps=deps,
auto_deps=auto_deps,
interval=interval,
backend=backend,
**kwargs,
)
return wrapper
RETURN = TypeVar("RETURN")
class CustomVarOperationReturn(ImmutableVar[RETURN]):
"""Base class for custom var operations."""
@classmethod
def create(
cls,
js_expression: str,
_var_type: Type[RETURN] | None = None,
_var_data: VarData | None = None,
) -> CustomVarOperationReturn[RETURN]:
"""Create a CustomVarOperation.
Args:
js_expression: The JavaScript expression to evaluate.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The CustomVarOperation.
"""
return CustomVarOperationReturn(
_var_name=js_expression,
_var_type=_var_type or Any,
_var_data=ImmutableVarData.merge(_var_data),
)
def var_operation_return(
js_expression: str,
var_type: Type[RETURN] | None = None,
) -> CustomVarOperationReturn[RETURN]:
"""Shortcut for creating a CustomVarOperationReturn.
Args:
js_expression: The JavaScript expression to evaluate.
var_type: The type of the var.
Returns:
The CustomVarOperationReturn.
"""
return CustomVarOperationReturn.create(js_expression, var_type)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class CustomVarOperation(CachedVarOperation, ImmutableVar[T]):
"""Base class for custom var operations."""
_args: Tuple[Tuple[str, Var], ...] = dataclasses.field(default_factory=tuple)
_return: CustomVarOperationReturn[T] = dataclasses.field(
default_factory=lambda: CustomVarOperationReturn.create("")
)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""Get the cached var name.
Returns:
The cached var name.
"""
return str(self._return)
@cached_property_no_lock
def _cached_get_all_var_data(self) -> ImmutableVarData | None:
"""Get the cached VarData.
Returns:
The cached VarData.
"""
return ImmutableVarData.merge(
*map(
lambda arg: arg[1]._get_all_var_data(),
self._args,
),
self._return._get_all_var_data(),
self._var_data,
)
@classmethod
def create(
cls,
args: Tuple[Tuple[str, Var], ...],
return_var: CustomVarOperationReturn[T],
_var_data: VarData | None = None,
) -> CustomVarOperation[T]:
"""Create a CustomVarOperation.
Args:
args: The arguments to the operation.
return_var: The return var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The CustomVarOperation.
"""
return CustomVarOperation(
_var_name="",
_var_type=return_var._var_type,
_var_data=ImmutableVarData.merge(_var_data),
_args=args,
_return=return_var,
)