reflex/pynecone/vars.py
2023-06-23 12:56:05 -07:00

1159 lines
31 KiB
Python

"""Define a state var."""
from __future__ import annotations
import contextlib
import dis
import json
import random
import string
from abc import ABC
from types import FunctionType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Set,
Type,
Union,
_GenericAlias, # type: ignore
cast,
get_type_hints,
)
from plotly.graph_objects import Figure
from plotly.io import to_json
from pydantic.fields import ModelField
from pynecone import constants
from pynecone.base import Base
from pynecone.utils import format, types
if TYPE_CHECKING:
from pynecone.state import State
# Set of unique variable names.
USED_VARIABLES = set()


def get_unique_variable_name() -> str:
    """Generate a variable name that has not been handed out before.

    Candidates are 8 random lowercase ASCII letters. A candidate already
    present in ``USED_VARIABLES`` is discarded and a new one is drawn; the
    accepted name is recorded in ``USED_VARIABLES`` before being returned.

    Returns:
        The unique variable name.
    """
    while True:
        letters = [random.choice(string.ascii_lowercase) for _ in range(8)]
        candidate = "".join(letters)
        if candidate in USED_VARIABLES:
            # Collision with an earlier name: draw again.
            continue
        USED_VARIABLES.add(candidate)
        return candidate
class Var(ABC):
    """An abstract var.

    Python operators, indexing and attribute access applied to a var do not
    compute a value; they return a new var whose name is the equivalent
    JavaScript expression.
    """

    # The name of the var.
    name: str

    # The type of the var.
    type_: Type

    # The name of the enclosing state.
    state: str = ""

    # Whether this is a local javascript variable.
    is_local: bool = False

    # Whether the var is a string literal.
    is_string: bool = False

    @classmethod
    def create(
        cls, value: Any, is_local: bool = True, is_string: bool = False
    ) -> Optional[Var]:
        """Create a var from a value.

        Args:
            value: The value to create the var from.
            is_local: Whether the var is local.
            is_string: Whether the var is a string literal.

        Returns:
            The var, or None when the value is None.

        Raises:
            TypeError: If the value is JSON-unserializable.
        """
        # Check for none values.
        if value is None:
            return None

        # If the value is already a var, do nothing.
        if isinstance(value, Var):
            return value

        type_ = type(value)

        # Special case for plotly figures.
        if isinstance(value, Figure):
            value = json.loads(to_json(value))["data"]  # type: ignore
            type_ = Figure

        if isinstance(value, dict):
            value = format.format_dict(value)

        # Strings are used verbatim as the name; everything else is JSON-encoded.
        try:
            name = value if isinstance(value, str) else json.dumps(value)
        except TypeError as e:
            raise TypeError(
                f"To create a Var must be Var or JSON-serializable. Got {value} of type {type(value)}."
            ) from e

        return BaseVar(name=name, type_=type_, is_local=is_local, is_string=is_string)

    @classmethod
    def create_safe(
        cls, value: Any, is_local: bool = True, is_string: bool = False
    ) -> Var:
        """Create a var from a value, guaranteeing that it is not None.

        Args:
            value: The value to create the var from.
            is_local: Whether the var is local.
            is_string: Whether the var is a string literal.

        Returns:
            The var.
        """
        var = cls.create(value, is_local=is_local, is_string=is_string)
        assert var is not None
        return var

    @classmethod
    def __class_getitem__(cls, type_: str) -> _GenericAlias:
        """Get a typed var.

        Args:
            type_: The type of the var.

        Returns:
            The var class item.
        """
        return _GenericAlias(cls, type_)

    def equals(self, other: Var) -> bool:
        """Check if two vars are equal.

        This is the Python-level comparison; ``==`` is overloaded to build a
        JavaScript comparison expression instead.

        Args:
            other: The other var to compare.

        Returns:
            Whether the vars are equal.
        """
        return (
            self.name == other.name
            and self.type_ == other.type_
            and self.state == other.state
            and self.is_local == other.is_local
        )

    def to_string(self) -> Var:
        """Convert a var to a string.

        Returns:
            The stringified var.
        """
        return self.operation(fn="JSON.stringify")

    def __hash__(self) -> int:
        """Define a hash function for a var.

        Returns:
            The hash of the var.
        """
        return hash((self.name, str(self.type_)))

    def __str__(self) -> str:
        """Wrap the var so it can be used in templates.

        Returns:
            The wrapped var, i.e. {state.var}.
        """
        out = self.full_name if self.is_local else format.wrap(self.full_name, "{")
        if self.is_string:
            out = format.format_string(out)
        return out

    def __getitem__(self, i: Any) -> Var:
        """Index into a var.

        Args:
            i: The index to index into.

        Returns:
            The indexed var.

        Raises:
            TypeError: If the var is not indexable.
        """
        # Indexing is only supported for lists, dicts, and dataframes.
        if not (
            types._issubclass(self.type_, Union[List, Dict])
            or types.is_dataframe(self.type_)
        ):
            if self.type_ == Any:
                raise TypeError(
                    "Could not index into var of type Any. (If you are trying to index into a state var, add the correct type annotation to the var.)"
                )
            raise TypeError(
                f"Var {self.name} of type {self.type_} does not support indexing."
            )

        # Convert any vars to local vars.
        if isinstance(i, Var):
            i = BaseVar(name=i.name, type_=i.type_, state=i.state, is_local=True)

        # Handle list indexing.
        if types._issubclass(self.type_, List):
            # List indices must be ints, slices, or vars.
            if not isinstance(i, types.get_args(Union[int, slice, Var])):
                raise TypeError("Index must be an integer.")

            # Handle slices first.
            if isinstance(i, slice):
                # Only a missing bound falls back to the default. An explicit
                # 0 must be kept: `.slice(0, undefined)` is the whole array,
                # whereas `.slice(0, 0)` is empty.
                start = 0 if i.start is None else i.start
                stop = "undefined" if i.stop is None else i.stop

                # Use the slice function.
                return BaseVar(
                    name=f"{self.name}.slice({start}, {stop})",
                    type_=self.type_,
                    state=self.state,
                )

            # Get the type of the indexed var.
            type_ = (
                types.get_args(self.type_)[0]
                if types.is_generic_alias(self.type_)
                else Any
            )

            # Use `at` to support negative indices.
            return BaseVar(
                name=f"{self.name}.at({i})",
                type_=type_,
                state=self.state,
            )

        # Dictionary / dataframe indexing.
        # String keys are emitted as quoted literals.
        if isinstance(i, str):
            i = format.wrap(i, '"')

        # Get the type of the indexed var.
        type_ = (
            types.get_args(self.type_)[1] if types.is_generic_alias(self.type_) else Any
        )

        # Use normal indexing here.
        return BaseVar(
            name=f"{self.name}[{i}]",
            type_=type_,
            state=self.state,
        )

    def __getattribute__(self, name: str) -> Any:
        """Get a var attribute.

        Normal attribute lookup is tried first. If it fails and the name is
        public, the name is looked up in the ``__fields__`` of the var's type
        and a var for that field is returned.

        Args:
            name: The name of the attribute.

        Returns:
            The attribute, or a var for the matching field of the var's type.

        Raises:
            AttributeError: If the var is wrongly annotated or can't find attribute.
            TypeError: If an annotation to the var isn't provided.
        """
        try:
            return super().__getattribute__(name)
        except Exception as e:
            # Check if the attribute is one of the class fields.
            if not name.startswith("_"):
                if self.type_ == Any:
                    raise TypeError(
                        f"You must provide an annotation for the state var `{self.full_name}`. Annotation cannot be `{self.type_}`"
                    ) from None
                if hasattr(self.type_, "__fields__") and name in self.type_.__fields__:
                    type_ = self.type_.__fields__[name].outer_type_
                    if isinstance(type_, ModelField):
                        type_ = type_.type_
                    return BaseVar(
                        name=f"{self.name}.{name}",
                        type_=type_,
                        state=self.state,
                    )
            raise AttributeError(
                f"The State var `{self.full_name}` has no attribute '{name}' or may have been annotated "
                f"wrongly.\n"
                f"original message: {e.args[0]}"
            ) from e

    def operation(
        self,
        op: str = "",
        other: Optional[Var] = None,
        type_: Optional[Type] = None,
        flip: bool = False,
        fn: Optional[str] = None,
    ) -> Var:
        """Perform an operation on a var.

        Args:
            op: The operation to perform.
            other: The other var to perform the operation on.
            type_: The type of the operation result (defaults to this var's type).
            flip: Whether to flip the order of the operation.
            fn: A function to apply to the operation.

        Returns:
            The operation result.
        """
        # Wrap strings in quotes.
        if isinstance(other, str):
            other = Var.create(json.dumps(other))
        else:
            other = Var.create(other)

        if type_ is None:
            type_ = self.type_

        if other is None:
            # Unary operation: the operator is a prefix.
            name = f"{op}{self.full_name}"
        else:
            props = (other, self) if flip else (self, other)
            name = f"{props[0].full_name} {op} {props[1].full_name}"
            # Binary expressions are parenthesised unless a function call
            # is about to wrap them anyway.
            if fn is None:
                name = format.wrap(name, "(")

        if fn is not None:
            name = f"{fn}({name})"

        return BaseVar(
            name=name,
            type_=type_,
        )

    def compare(self, op: str, other: Var) -> Var:
        """Compare two vars with inequalities.

        Args:
            op: The comparison operator.
            other: The other var to compare with.

        Returns:
            The comparison result.
        """
        return self.operation(op, other, bool)

    def __invert__(self) -> Var:
        """Invert a var.

        Returns:
            The inverted var.
        """
        return self.operation("!", type_=bool)

    def __neg__(self) -> Var:
        """Negate a var.

        Returns:
            The negated var.
        """
        return self.operation(fn="-")

    def __abs__(self) -> Var:
        """Get the absolute value of a var.

        Returns:
            A var with the absolute value.
        """
        return self.operation(fn="Math.abs")

    def length(self) -> Var:
        """Get the length of a list var.

        Returns:
            A var with the length of the list.

        Raises:
            TypeError: If the var is not a list.
        """
        if not types._issubclass(self.type_, List):
            raise TypeError(f"Cannot get length of non-list var {self}.")
        return BaseVar(
            name=f"{self.full_name}.length",
            type_=int,
        )

    def __eq__(self, other: Var) -> Var:
        """Perform an equality comparison.

        Args:
            other: The other var to compare with.

        Returns:
            A var representing the equality comparison.
        """
        return self.compare("===", other)

    def __ne__(self, other: Var) -> Var:
        """Perform an inequality comparison.

        Args:
            other: The other var to compare with.

        Returns:
            A var representing the inequality comparison.
        """
        return self.compare("!==", other)

    def __gt__(self, other: Var) -> Var:
        """Perform a greater than comparison.

        Args:
            other: The other var to compare with.

        Returns:
            A var representing the greater than comparison.
        """
        return self.compare(">", other)

    def __ge__(self, other: Var) -> Var:
        """Perform a greater than or equal to comparison.

        Args:
            other: The other var to compare with.

        Returns:
            A var representing the greater than or equal to comparison.
        """
        return self.compare(">=", other)

    def __lt__(self, other: Var) -> Var:
        """Perform a less than comparison.

        Args:
            other: The other var to compare with.

        Returns:
            A var representing the less than comparison.
        """
        return self.compare("<", other)

    def __le__(self, other: Var) -> Var:
        """Perform a less than or equal to comparison.

        Args:
            other: The other var to compare with.

        Returns:
            A var representing the less than or equal to comparison.
        """
        return self.compare("<=", other)

    def __add__(self, other: Var) -> Var:
        """Add two vars.

        Args:
            other: The other var to add.

        Returns:
            A var representing the sum.
        """
        return self.operation("+", other)

    def __radd__(self, other: Var) -> Var:
        """Add two vars, with this var on the right-hand side.

        Args:
            other: The other var to add.

        Returns:
            A var representing the sum.
        """
        return self.operation("+", other, flip=True)

    def __sub__(self, other: Var) -> Var:
        """Subtract two vars.

        Args:
            other: The other var to subtract.

        Returns:
            A var representing the difference.
        """
        return self.operation("-", other)

    def __rsub__(self, other: Var) -> Var:
        """Subtract two vars, with this var on the right-hand side.

        Args:
            other: The other var to subtract from.

        Returns:
            A var representing the difference.
        """
        return self.operation("-", other, flip=True)

    def __mul__(self, other: Var) -> Var:
        """Multiply two vars.

        Args:
            other: The other var to multiply.

        Returns:
            A var representing the product.
        """
        return self.operation("*", other)

    def __rmul__(self, other: Var) -> Var:
        """Multiply two vars, with this var on the right-hand side.

        Args:
            other: The other var to multiply.

        Returns:
            A var representing the product.
        """
        return self.operation("*", other, flip=True)

    def __pow__(self, other: Var) -> Var:
        """Raise a var to a power.

        Args:
            other: The power to raise to.

        Returns:
            A var representing the power.
        """
        # "," is used as the operator so the operands become the two
        # arguments of Math.pow.
        return self.operation(",", other, fn="Math.pow")

    def __rpow__(self, other: Var) -> Var:
        """Raise a value to the power of this var.

        Args:
            other: The base to raise.

        Returns:
            A var representing the power.
        """
        return self.operation(",", other, flip=True, fn="Math.pow")

    def __truediv__(self, other: Var) -> Var:
        """Divide two vars.

        Args:
            other: The other var to divide.

        Returns:
            A var representing the quotient.
        """
        return self.operation("/", other)

    def __rtruediv__(self, other: Var) -> Var:
        """Divide two vars, with this var as the divisor.

        Args:
            other: The other var to divide.

        Returns:
            A var representing the quotient.
        """
        return self.operation("/", other, flip=True)

    def __floordiv__(self, other: Var) -> Var:
        """Floor-divide two vars.

        Args:
            other: The other var to divide.

        Returns:
            A var representing the floored quotient.
        """
        return self.operation("/", other, fn="Math.floor")

    def __rfloordiv__(self, other: Var) -> Var:
        """Floor-divide two vars, with this var as the divisor.

        Args:
            other: The other var to divide.

        Returns:
            A var representing the floored quotient.
        """
        return self.operation("/", other, flip=True, fn="Math.floor")

    def __mod__(self, other: Var) -> Var:
        """Get the remainder of two vars.

        Args:
            other: The other var to divide.

        Returns:
            A var representing the remainder.
        """
        return self.operation("%", other)

    def __rmod__(self, other: Var) -> Var:
        """Get the remainder of two vars, with this var as the divisor.

        Args:
            other: The other var to divide.

        Returns:
            A var representing the remainder.
        """
        return self.operation("%", other, flip=True)

    def __and__(self, other: Var) -> Var:
        """Perform a logical and.

        Args:
            other: The other var to perform the logical and with.

        Returns:
            A var representing the logical and.
        """
        return self.operation("&&", other)

    def __rand__(self, other: Var) -> Var:
        """Perform a logical and, with this var on the right-hand side.

        Args:
            other: The other var to perform the logical and with.

        Returns:
            A var representing the logical and.
        """
        return self.operation("&&", other, flip=True)

    def __or__(self, other: Var) -> Var:
        """Perform a logical or.

        Args:
            other: The other var to perform the logical or with.

        Returns:
            A var representing the logical or.
        """
        return self.operation("||", other)

    def __ror__(self, other: Var) -> Var:
        """Perform a logical or, with this var on the right-hand side.

        Args:
            other: The other var to perform the logical or with.

        Returns:
            A var representing the logical or.
        """
        return self.operation("||", other, flip=True)

    def foreach(self, fn: Callable) -> Var:
        """Return a list of components after doing a foreach on this var.

        Args:
            fn: The function to call on each component. It is called with a
                var standing for the current item and ``key='i'``.

        Returns:
            A var representing foreach operation.
        """
        # A fresh, unique name for the arrow function's item parameter.
        arg = BaseVar(
            name=get_unique_variable_name(),
            type_=self.type_,
        )
        return BaseVar(
            name=f"{self.full_name}.map(({arg.name}, i) => {fn(arg, key='i')})",
            type_=self.type_,
        )

    def to(self, type_: Type) -> Var:
        """Convert the type of the var.

        Args:
            type_: The type to convert to.

        Returns:
            The converted var.
        """
        return BaseVar(
            name=self.name,
            type_=type_,
            state=self.state,
            is_local=self.is_local,
        )

    @property
    def full_name(self) -> str:
        """Get the full name of the var.

        Returns:
            The full name of the var: the name prefixed by the state name
            when the var belongs to a state.
        """
        return self.name if self.state == "" else ".".join([self.state, self.name])

    def set_state(self, state: Type[State]) -> Any:
        """Set the state of the var.

        Args:
            state: The state to set.

        Returns:
            The var with the set state.
        """
        self.state = state.get_full_name()
        return self
class BaseVar(Var, Base):
    """A base (non-computed) var of the app state."""

    # The name of the var.
    name: str

    # The type of the var.
    type_: Any

    # The name of the enclosing state.
    state: str = ""

    # Whether this is a local javascript variable.
    is_local: bool = False

    # Whether this var is a raw string.
    is_string: bool = False

    def __hash__(self) -> int:
        """Define a hash function for a var.

        Returns:
            The hash of the var.
        """
        return hash((self.name, str(self.type_)))

    def get_default_value(self) -> Any:
        """Get the default value of the var.

        The default is chosen from the var's type (or from its origin type
        when the type is a generic alias).

        Returns:
            The default value of the var, or None when the type has no
            known default.

        Raises:
            ImportError: If the var is a dataframe and pandas is not installed.
        """
        type_ = (
            self.type_.__origin__ if types.is_generic_alias(self.type_) else self.type_
        )
        if issubclass(type_, str):
            return ""
        # bool is a subclass of int, so it has to be tested before the
        # numeric types; otherwise bool vars would default to 0.
        if issubclass(type_, bool):
            return False
        if issubclass(type_, types.get_args(Union[int, float])):
            return 0
        if issubclass(type_, list):
            return []
        if issubclass(type_, dict):
            return {}
        if issubclass(type_, tuple):
            return ()
        if types.is_dataframe(type_):
            # pandas is imported lazily so it is only required when used.
            try:
                import pandas as pd

                return pd.DataFrame()
            except ImportError as e:
                raise ImportError(
                    "Please install pandas to use dataframes in your app."
                ) from e
        return set() if issubclass(type_, set) else None

    def get_setter_name(self, include_state: bool = True) -> str:
        """Get the name of the var's generated setter function.

        Args:
            include_state: Whether to include the state name in the setter name.

        Returns:
            The name of the setter function.
        """
        setter = constants.SETTER_PREFIX + self.name
        if not include_state or self.state == "":
            return setter
        return ".".join((self.state, setter))

    def get_setter(self) -> Callable[[State, Any], None]:
        """Get the var's setter function.

        Returns:
            A function that sets this var on a given state.
        """

        def setter(state: State, value: Any):
            """Set the value of the var on the state.

            Args:
                state: The state within which we add the setter function.
                value: The value to set.
            """
            setattr(state, self.name, value)

        # Expose the setter under its generated, state-qualified name.
        setter.__qualname__ = self.get_setter_name()

        return setter
class ComputedVar(Var, property):
    """A var whose value is produced by a getter function."""

    # Whether to track dependencies and cache computed values
    cache: bool = False

    @property
    def name(self) -> str:
        """Get the name of the var.

        Returns:
            The name of the getter function.
        """
        getter = self.fget
        assert getter is not None, "Var must have a getter."
        return getter.__name__

    @property
    def cache_attr(self) -> str:
        """Get the attribute used to cache the value on the instance.

        Returns:
            An attribute name.
        """
        return "__cached_" + self.name

    def __get__(self, instance, owner):
        """Get the ComputedVar value.

        With caching enabled, the value is computed once per instance and
        stored on the instance under ``cache_attr``; later reads return the
        stored value.

        Args:
            instance: the instance of the class accessing this computed var.
            owner: the class that this descriptor is attached to.

        Returns:
            The value of the var for the given instance.
        """
        if instance is not None and self.cache:
            slot = self.cache_attr
            if not hasattr(instance, slot):
                # First access on this instance: compute and remember.
                setattr(instance, slot, super().__get__(instance, owner))
            return getattr(instance, slot)

        # Class-level access or caching disabled: plain property behaviour.
        return super().__get__(instance, owner)

    def deps(
        self,
        objclass: Type,
        obj: Optional[FunctionType] = None,
    ) -> Set[str]:
        """Determine var dependencies of this ComputedVar.

        The bytecode of the function is scanned. An attribute load that
        directly follows a load of the function's first argument is recorded
        as a dependency; a method load in that position is looked up on
        ``objclass`` and scanned recursively.

        Args:
            objclass: the class obj this ComputedVar is attached to.
            obj: the object to disassemble (defaults to the fget function).

        Returns:
            A set of variable names accessed by the given obj.
        """
        if obj is None:
            if self.fget is None:
                return set()
            obj = cast(FunctionType, self.fget)

        local_names = obj.__code__.co_varnames
        if not local_names:
            # cannot reference self if method takes no args
            return set()

        self_name = local_names[0]
        found = set()
        # True only for the instruction right after "self" was loaded.
        after_self_load = False
        for instr in dis.get_instructions(obj):
            if instr.opname == "LOAD_FAST" and instr.argval == self_name:
                after_self_load = True
                continue
            if after_self_load:
                if instr.opname == "LOAD_ATTR":
                    found.add(instr.argval)
                elif instr.opname == "LOAD_METHOD":
                    found.update(
                        self.deps(
                            objclass=objclass,
                            obj=getattr(objclass, instr.argval),
                        )
                    )
            after_self_load = False
        return found

    def mark_dirty(self, instance) -> None:
        """Mark this ComputedVar as dirty.

        Args:
            instance: the state instance that needs to recompute the value.
        """
        try:
            delattr(instance, self.cache_attr)
        except AttributeError:
            # Nothing was cached on this instance.
            pass

    @property
    def type_(self):
        """Get the type of the var.

        Returns:
            The return annotation of the getter, or Any when there is none.
        """
        return get_type_hints(self.fget).get("return", Any)
def cached_var(fget: Callable[[Any], Any]) -> ComputedVar:
    """Build a ComputedVar with caching switched on.

    The cached_var will only be recalculated when other state vars that it
    depends on are modified.

    Args:
        fget: the function that calculates the variable value.

    Returns:
        ComputedVar that is recomputed when dependencies change.
    """
    computed = ComputedVar(fget=fget)
    computed.cache = True
    return computed
class PCList(list):
    """A custom list that pynecone can detect its mutation.

    Every mutating method performs the normal list operation and then calls
    the ``reassign_field`` callback with ``field_name``.
    """

    def __init__(
        self,
        original_list: List,
        reassign_field: Callable = lambda _field_name: None,
        field_name: str = "",
    ):
        """Initialize PCList.

        Args:
            original_list (List): The original list
            reassign_field (Callable):
                The method in the parent state to reassign the field.
                Default to be a no-op function
            field_name (str): the name of field in the parent state
        """
        self._reassign_field = lambda: reassign_field(field_name)

        super().__init__(original_list)

    def append(self, *args, **kwargs):
        """Append.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().append(*args, **kwargs)
        self._reassign_field()

    def insert(self, *args, **kwargs):
        """Insert an item before the given index.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().insert(*args, **kwargs)
        self._reassign_field()

    def __setitem__(self, *args, **kwargs):
        """Set item.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().__setitem__(*args, **kwargs)
        self._reassign_field()

    def __delitem__(self, *args, **kwargs):
        """Delete item.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().__delitem__(*args, **kwargs)
        self._reassign_field()

    def clear(self, *args, **kwargs):
        """Remove all items from the list.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().clear(*args, **kwargs)
        self._reassign_field()

    def extend(self, *args, **kwargs):
        """Add all items of an iterable to the end of the list.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().extend(*args, **kwargs)
        # NOTE(review): the guard tolerates calls made before __init__ has
        # set the callback — presumably during unpickling/copying, where the
        # list is repopulated first. Confirm before removing.
        self._reassign_field() if hasattr(self, "_reassign_field") else None

    def pop(self, *args, **kwargs):
        """Remove an element and return it.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.

        Returns:
            The removed element.
        """
        item = super().pop(*args, **kwargs)
        self._reassign_field()
        return item

    def remove(self, *args, **kwargs):
        """Remove the first occurrence of an element.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().remove(*args, **kwargs)
        self._reassign_field()

    def reverse(self, *args, **kwargs):
        """Reverse the list in place.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().reverse(*args, **kwargs)
        self._reassign_field()

    def sort(self, *args, **kwargs):
        """Sort the list in place.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().sort(*args, **kwargs)
        self._reassign_field()
class PCDict(dict):
    """A custom dict that pynecone can detect its mutation.

    Every mutating method performs the normal dict operation and then calls
    the ``reassign_field`` callback with ``field_name``.
    """

    def __init__(
        self,
        original_dict: Dict,
        reassign_field: Callable = lambda _field_name: None,
        field_name: str = "",
    ):
        """Initialize PCDict.

        Args:
            original_dict: The original dict
            reassign_field:
                The method in the parent state to reassign the field.
                Default to be a no-op function
            field_name: the name of field in the parent state
        """
        super().__init__(original_dict)
        self._reassign_field = lambda: reassign_field(field_name)

    def clear(self):
        """Remove all items from the dict."""
        super().clear()
        self._reassign_field()

    def setdefault(self, *args, **kwargs):
        """Return the value of a key, inserting the default if it is missing.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.

        Returns:
            The value stored under the key after the call.
        """
        value = super().setdefault(*args, **kwargs)
        self._reassign_field()
        return value

    def popitem(self):
        """Pop the last item.

        Returns:
            The removed (key, value) pair.
        """
        item = super().popitem()
        self._reassign_field()
        return item

    def pop(self, k, d=None):
        """Remove a key and return its value.

        Unlike ``dict.pop``, a missing key never raises: ``d`` (None by
        default) is returned instead.

        Args:
            k: The key to remove.
            d: The value to return when the key is missing.

        Returns:
            The removed value, or ``d`` when the key is missing.
        """
        value = super().pop(k, d)
        self._reassign_field()
        return value

    def update(self, *args, **kwargs):
        """Update the dict with another dict.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().update(*args, **kwargs)
        self._reassign_field()

    def __setitem__(self, *args, **kwargs):
        """Set an item in the dict.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().__setitem__(*args, **kwargs)
        # NOTE(review): the guard tolerates calls made before __init__ has
        # set the callback — presumably during unpickling/copying, where the
        # dict is repopulated first. Confirm before removing.
        self._reassign_field() if hasattr(self, "_reassign_field") else None

    def __delitem__(self, *args, **kwargs):
        """Delete an item in the dict.

        Args:
            args: The args passed.
            kwargs: The kwargs passed.
        """
        super().__delitem__(*args, **kwargs)
        self._reassign_field()
class ImportVar(Base):
    """A single imported name, optionally aliased."""

    # The name of the import tag.
    tag: Optional[str]

    # whether the import is default or named.
    is_default: Optional[bool] = False

    # The tag alias.
    alias: Optional[str] = None

    @property
    def name(self) -> str:
        """The name of the import.

        Returns:
            The tag alone when there is no alias, otherwise "tag as alias".
        """
        if self.alias:
            return " as ".join((self.tag, self.alias))  # type: ignore
        return self.tag  # type: ignore

    def __hash__(self) -> int:
        """Define a hash function for the import var.

        Returns:
            The hash of the var.
        """
        identity = (self.tag, self.is_default, self.alias)
        return hash(identity)
def get_local_storage(key: Optional[Union[Var, str]] = None) -> BaseVar:
    """Provide a base var as payload to get local storage item(s).

    Args:
        key: Key to obtain value in the local storage. When omitted (None),
            the returned var fetches all local storage items.

    Returns:
        A BaseVar of the local storage method/function to call.

    Raises:
        TypeError: if the wrong key type is provided.
    """
    # Only a missing key means "all items". Testing truthiness here would let
    # falsy keys of the wrong type (0, [], ...) skip the type check below and
    # would make the empty-string key impossible to look up.
    if key is None:
        return BaseVar(name="getAllLocalStorageItems()", type_=Dict)

    if not (isinstance(key, Var) and key.type_ == str) and not isinstance(key, str):
        type_ = type(key) if not isinstance(key, Var) else key.type_
        raise TypeError(
            f"Local storage keys can only be of type `str` or `var` of type `str`. Got `{type_}` instead."
        )

    # Vars are referenced by name; plain strings become quoted literals.
    key = key.full_name if isinstance(key, Var) else format.wrap(key, "'")
    return BaseVar(name=f"localStorage.getItem({key})", type_=str)