make var system expandable

This commit is contained in:
Khaleel Al-Adhami 2024-10-14 15:18:11 -07:00
parent d6797a1f1d
commit 617e53d5ff
7 changed files with 448 additions and 607 deletions

View File

@ -12,7 +12,6 @@ from functools import partial
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
List,
@ -33,9 +32,7 @@ from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData
from reflex.vars.base import (
CachedVarOperation,
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
cached_property_no_lock,
)
@ -1249,7 +1246,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
return signature.replace(parameters=(new_param, *signature.parameters.values()))
class EventVar(ObjectVar):
class EventVar(ObjectVar, python_types=EventSpec):
"""Base class for event vars."""
@ -1323,7 +1320,7 @@ class LiteralEventVar(CachedVarOperation, LiteralVar, EventVar):
)
class EventChainVar(FunctionVar):
class EventChainVar(FunctionVar, python_types=EventChain):
"""Base class for event chain vars."""
@ -1403,32 +1400,6 @@ class LiteralEventChainVar(CachedVarOperation, LiteralVar, EventChainVar):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToEventVarOperation(ToOperation, EventVar):
"""Result of a cast to an event var."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = EventSpec
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToEventChainVarOperation(ToOperation, EventChainVar):
"""Result of a cast to an event chain var."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = EventChain
G = ParamSpec("G")
IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var]
@ -1503,8 +1474,6 @@ class EventNamespace(types.SimpleNamespace):
LiteralEventVar = LiteralEventVar
EventChainVar = EventChainVar
LiteralEventChainVar = LiteralEventChainVar
ToEventVarOperation = ToEventVarOperation
ToEventChainVarOperation = ToEventChainVarOperation
EventType = EventType
__call__ = staticmethod(event_handler)

View File

@ -14,11 +14,12 @@ import re
import string
import sys
import warnings
from types import CodeType, FunctionType
from types import CodeType, FunctionType, get_original_bases
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
FrozenSet,
Generic,
@ -61,15 +62,13 @@ from reflex.utils.types import GenericType, Self, get_origin, has_args, unionize
if TYPE_CHECKING:
from reflex.state import BaseState
from .function import FunctionVar, ToFunctionOperation
from .function import FunctionVar
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
from .object import ObjectVar
from .sequence import ArrayVar, StringVar
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
@ -78,6 +77,184 @@ OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
@dataclasses.dataclass(
eq=False,
frozen=True,
)
class VarSubclassEntry:
"""Entry for a Var subclass."""
var_subclass: Type[Var]
to_var_subclass: Type[ToOperation]
python_types: Tuple[GenericType, ...]
_var_subclasses: List[VarSubclassEntry] = []
_var_literal_subclasses: List[Tuple[Type[LiteralVar], VarSubclassEntry]] = []
@dataclasses.dataclass(
eq=True,
frozen=True,
)
class VarData:
"""Metadata associated with a x."""
# The name of the enclosing state.
state: str = dataclasses.field(default="")
# The name of the field in the state.
field_name: str = dataclasses.field(default="")
# Imports needed to render this var
imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple)
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
):
"""Initialize the var data.
Args:
state: The name of the enclosing state.
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
((k, tuple(sorted(v))) for k, v in parse_imports(imports or {}).items())
)
)
object.__setattr__(self, "state", state)
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
Returns:
The imports as a mutable dict.
"""
return dict((k, list(v)) for k, v in self.imports)
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
Args:
*others: The var data objects to merge.
Returns:
The merged var data object.
"""
state = ""
field_name = ""
_imports = {}
hooks = {}
for var_data in others:
if var_data is None:
continue
state = state or var_data.state
field_name = field_name or var_data.field_name
_imports = imports.merge_imports(_imports, var_data.imports)
hooks.update(
var_data.hooks
if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks}
)
if state or _imports or hooks or field_name:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
)
return None
def __bool__(self) -> bool:
"""Check if the var data is non-empty.
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks or self.field_name)
@classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
"""Set the state of the var.
Args:
state: The state to set or the full name of the state.
field_name: The name of the field in the state. Optional.
Returns:
The var with the set state.
"""
from reflex.utils import format
state_name = state if isinstance(state, str) else state.get_full_name()
return VarData(
state=state_name,
field_name=field_name,
hooks={
"const {0} = useContext(StateContexts.{0})".format(
format.format_state_name(state_name)
): None
},
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
"react": [ImportVar(tag="useContext")],
},
)
def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
"""Decode the state name from a formatted var.
Args:
value: The value to extract the state name from.
Returns:
The extracted state name and the value without the state name.
"""
var_datas = []
if isinstance(value, str):
# fast path if there is no encoded VarData
if constants.REFLEX_VAR_OPENING_TAG not in value:
return None, value
offset = 0
# Find all tags.
while m := _decode_var_pattern.search(value):
start, end = m.span()
value = value[:start] + value[end:]
serialized_data = m.group(1)
if serialized_data.isnumeric() or (
serialized_data[0] == "-" and serialized_data[1:].isnumeric()
):
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._get_all_var_data()
if var_data is not None:
var_datas.append(var_data)
offset += end - start
return VarData.merge(*var_datas) if var_datas else None, value
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -151,6 +328,40 @@ class Var(Generic[VAR_TYPE]):
"""
return False
def __init_subclass__(
cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset, **kwargs
):
"""Initialize the subclass.
Args:
python_types: The python types that the var represents.
**kwargs: Additional keyword arguments.
"""
super().__init_subclass__(**kwargs)
if python_types is not types.Unset:
python_types = (
python_types if isinstance(python_types, tuple) else (python_types,)
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToVarOperation(ToOperation, cls):
"""Base class of converting a var to another var type."""
_original: Var = dataclasses.field(
default=Var(_js_expr="null", _var_type=None),
)
_default_var_type: ClassVar[GenericType] = python_types[0]
ToVarOperation.__name__ = f"To{cls.__name__.removesuffix("Var")}Operation"
_var_subclasses.append(VarSubclassEntry(cls, ToVarOperation, python_types))
def __post_init__(self):
"""Post-initialize the var."""
# Decode any inline Var markup and apply it to the instance
@ -331,35 +542,35 @@ class Var(Generic[VAR_TYPE]):
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}"
@overload
def to(self, output: Type[StringVar]) -> ToStringOperation: ...
def to(self, output: Type[StringVar]) -> StringVar: ...
@overload
def to(self, output: Type[str]) -> ToStringOperation: ...
def to(self, output: Type[str]) -> StringVar: ...
@overload
def to(self, output: Type[BooleanVar]) -> ToBooleanVarOperation: ...
def to(self, output: Type[BooleanVar]) -> BooleanVar: ...
@overload
def to(
self, output: Type[NumberVar], var_type: type[int] | type[float] = float
) -> ToNumberVarOperation: ...
) -> NumberVar: ...
@overload
def to(
self,
output: Type[ArrayVar],
var_type: type[list] | type[tuple] | type[set] = list,
) -> ToArrayOperation: ...
) -> ArrayVar: ...
@overload
def to(
self, output: Type[ObjectVar], var_type: types.GenericType = dict
) -> ToObjectOperation: ...
) -> ObjectVar: ...
@overload
def to(
self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
) -> ToFunctionOperation: ...
) -> FunctionVar: ...
@overload
def to(
@ -379,56 +590,26 @@ class Var(Generic[VAR_TYPE]):
output: The output type.
var_type: The type of the var.
Raises:
TypeError: If the var_type is not a supported type for the output.
Returns:
The converted var.
"""
from reflex.event import (
EventChain,
EventChainVar,
EventSpec,
EventVar,
ToEventChainVarOperation,
ToEventVarOperation,
)
from .function import FunctionVar, ToFunctionOperation
from .number import (
BooleanVar,
NumberVar,
ToBooleanVarOperation,
ToNumberVarOperation,
)
from .object import ObjectVar, ToObjectOperation
from .sequence import ArrayVar, StringVar, ToArrayOperation, ToStringOperation
from .object import ObjectVar
base_type = var_type
if types.is_optional(base_type):
base_type = types.get_args(base_type)[0]
fixed_type = get_origin(base_type) or base_type
fixed_output_type = get_origin(output) or output
# If the first argument is a python type, we map it to the corresponding Var type.
if fixed_output_type is dict:
return self.to(ObjectVar, output)
if fixed_output_type in (list, tuple, set):
return self.to(ArrayVar, output)
if fixed_output_type in (int, float):
return self.to(NumberVar, output)
if fixed_output_type is str:
return self.to(StringVar, output)
if fixed_output_type is bool:
return self.to(BooleanVar, output)
for var_subclass in _var_subclasses[::-1]:
if fixed_output_type in var_subclass.python_types:
return self.to(var_subclass.var_subclass, output)
if fixed_output_type is None:
return ToNoneOperation.create(self)
if fixed_output_type is EventSpec:
return self.to(EventVar, output)
if fixed_output_type is EventChain:
return self.to(EventChainVar, output)
return get_to_operation(NoneVar).create(self) # type: ignore
# Handle fixed_output_type being Base or a dataclass.
try:
if issubclass(fixed_output_type, Base):
return self.to(ObjectVar, output)
@ -440,57 +621,12 @@ class Var(Generic[VAR_TYPE]):
return self.to(ObjectVar, output)
if inspect.isclass(output):
if issubclass(output, BooleanVar):
return ToBooleanVarOperation.create(self)
if issubclass(output, NumberVar):
if fixed_type is not None:
if fixed_type in types.UnionTypes:
inner_types = get_args(base_type)
if not all(issubclass(t, (int, float)) for t in inner_types):
raise TypeError(
f"Unsupported type {var_type} for NumberVar. Must be int or float."
)
elif not issubclass(fixed_type, (int, float)):
raise TypeError(
f"Unsupported type {var_type} for NumberVar. Must be int or float."
)
return ToNumberVarOperation.create(self, var_type or float)
if issubclass(output, ArrayVar):
if fixed_type is not None and not issubclass(
fixed_type, (list, tuple, set)
):
raise TypeError(
f"Unsupported type {var_type} for ArrayVar. Must be list, tuple, or set."
for var_subclass in _var_subclasses[::-1]:
if issubclass(output, var_subclass.var_subclass):
to_operation_return = var_subclass.to_var_subclass.create(
value=self, _var_type=var_type
)
return ToArrayOperation.create(self, var_type or list)
if issubclass(output, StringVar):
return ToStringOperation.create(self, var_type or str)
if issubclass(output, EventVar):
return ToEventVarOperation.create(self, var_type or EventSpec)
if issubclass(output, EventChainVar):
return ToEventChainVarOperation.create(self, var_type or EventChain)
if issubclass(output, (ObjectVar, Base)):
return ToObjectOperation.create(self, var_type or dict)
if issubclass(output, FunctionVar):
# if fixed_type is not None and not issubclass(fixed_type, Callable):
# raise TypeError(
# f"Unsupported type {var_type} for FunctionVar. Must be Callable."
# )
return ToFunctionOperation.create(self, var_type or Callable)
if issubclass(output, NoneVar):
return ToNoneOperation.create(self)
if dataclasses.is_dataclass(output):
return ToObjectOperation.create(self, var_type or dict)
return to_operation_return # type: ignore
# If we can't determine the first argument, we just replace the _var_type.
if not issubclass(output, Var) or var_type is None:
@ -517,11 +653,8 @@ class Var(Generic[VAR_TYPE]):
Raises:
TypeError: If the type is not supported for guessing.
"""
from reflex.event import EventChain, EventChainVar, EventSpec, EventVar
from .number import BooleanVar, NumberVar
from .number import NumberVar
from .object import ObjectVar
from .sequence import ArrayVar, StringVar
var_type = self._var_type
if var_type is None:
@ -558,20 +691,13 @@ class Var(Generic[VAR_TYPE]):
if not inspect.isclass(fixed_type):
raise TypeError(f"Unsupported type {var_type} for guess_type.")
if issubclass(fixed_type, bool):
return self.to(BooleanVar, self._var_type)
if issubclass(fixed_type, (int, float)):
return self.to(NumberVar, self._var_type)
if issubclass(fixed_type, dict):
return self.to(ObjectVar, self._var_type)
if issubclass(fixed_type, (list, tuple, set)):
return self.to(ArrayVar, self._var_type)
if issubclass(fixed_type, str):
return self.to(StringVar, self._var_type)
if issubclass(fixed_type, EventSpec):
return self.to(EventVar, self._var_type)
if issubclass(fixed_type, EventChain):
return self.to(EventChainVar, self._var_type)
if fixed_type is None:
return self.to(None)
for var_subclass in _var_subclasses[::-1]:
if issubclass(fixed_type, var_subclass.python_types):
return self.to(var_subclass.var_subclass, self._var_type)
try:
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
@ -1017,9 +1143,131 @@ class Var(Generic[VAR_TYPE]):
OUTPUT = TypeVar("OUTPUT", bound=Var)
class ToOperation:
"""A var operation that converts a var to another type."""
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
from .object import ObjectVar
if isinstance(self, ObjectVar) and name != "_js_expr":
return ObjectVar.__getattr__(self, name)
return getattr(object.__getattribute__(self, "_original"), name)
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_js_expr")
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash(object.__getattribute__(self, "_original"))
def _get_all_var_data(self) -> VarData | None:
"""Get all the var data.
Returns:
The var data.
"""
return VarData.merge(
object.__getattribute__(self, "_original")._get_all_var_data(),
self._var_data, # type: ignore
)
@classmethod
def create(
cls,
value: Var,
_var_type: GenericType | None = None,
_var_data: VarData | None = None,
):
"""Create a ToOperation.
Args:
value: The value of the var.
_var_type: The type of the Var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The ToOperation.
"""
return cls(
_js_expr="", # type: ignore
_var_data=_var_data, # type: ignore
_var_type=_var_type or cls._default_var_type, # type: ignore
_original=value, # type: ignore
)
class LiteralVar(Var):
"""Base class for immutable literal vars."""
def __init_subclass__(cls, **kwargs):
"""Initialize the subclass.
Args:
**kwargs: Additional keyword arguments.
Raises:
TypeError: If the subclass is not a subclass of LiteralVar.
"""
super().__init_subclass__(**kwargs)
bases = get_original_bases(cls)
bases_normalized = [
base if inspect.isclass(base) else get_origin(base) for base in bases
]
possible_bases = [
base
for base in bases_normalized
if issubclass(base, Var) and base != LiteralVar
]
if not issubclass(cls, LiteralVar):
raise TypeError(
f"LiteralVar subclass {cls} must be a subclass of LiteralVar."
)
if len(possible_bases) != 1:
raise TypeError(
f"LiteralVar subclass {cls} must have exactly one base class that is a subclass of Var and not LiteralVar."
)
base_class = possible_bases[0]
var_subclasses = [
var_subclass
for var_subclass in _var_subclasses
if var_subclass.var_subclass is base_class
]
if not var_subclasses:
raise TypeError(
f"Var subclass {base_class} must have a corresponding `python_types` because a literal subclass {cls} is derived from it."
)
var_subclass = var_subclasses[0]
# Remove the old subclass, happens because __init_subclass__ is called twice
# for each subclass. This is because of __slots__ in dataclasses.
for var_literal_subclass in list(_var_literal_subclasses):
if var_literal_subclass[1] is var_subclass:
_var_literal_subclasses.remove(var_literal_subclass)
_var_literal_subclasses.append((cls, var_subclass))
@classmethod
def create(
cls,
@ -1038,50 +1286,21 @@ class LiteralVar(Var):
Raises:
TypeError: If the value is not a supported type for LiteralVar.
"""
from .number import LiteralBooleanVar, LiteralNumberVar
from .object import LiteralObjectVar
from .sequence import LiteralArrayVar, LiteralStringVar
from .sequence import LiteralStringVar
if isinstance(value, Var):
if _var_data is None:
return value
return value._replace(merge_var_data=_var_data)
if isinstance(value, str):
return LiteralStringVar.create(value, _var_data=_var_data)
for literal_subclass, var_subclass in _var_literal_subclasses[::-1]:
if isinstance(value, var_subclass.python_types):
return literal_subclass.create(value, _var_data=_var_data)
if isinstance(value, bool):
return LiteralBooleanVar.create(value, _var_data=_var_data)
if isinstance(value, (int, float)):
return LiteralNumberVar.create(value, _var_data=_var_data)
if isinstance(value, dict):
return LiteralObjectVar.create(value, _var_data=_var_data)
if isinstance(value, (list, tuple, set)):
return LiteralArrayVar.create(value, _var_data=_var_data)
if value is None:
return LiteralNoneVar.create(_var_data=_var_data)
from reflex.event import (
EventChain,
EventHandler,
EventSpec,
LiteralEventChainVar,
LiteralEventVar,
)
from reflex.event import EventHandler
from reflex.utils.format import get_event_handler_parts
from .object import LiteralObjectVar
if isinstance(value, EventSpec):
return LiteralEventVar.create(value, _var_data=_var_data)
if isinstance(value, EventChain):
return LiteralEventChainVar.create(value, _var_data=_var_data)
if isinstance(value, EventHandler):
return Var(_js_expr=".".join(filter(None, get_event_handler_parts(value))))
@ -2116,7 +2335,7 @@ class CustomVarOperation(CachedVarOperation, Var[T]):
)
class NoneVar(Var[None]):
class NoneVar(Var[None], python_types=type(None)):
"""A var representing None."""
@ -2141,11 +2360,13 @@ class LiteralNoneVar(LiteralVar, NoneVar):
@classmethod
def create(
cls,
value: None = None,
_var_data: VarData | None = None,
) -> LiteralNoneVar:
"""Create a var from a value.
Args:
value: The value of the var. Must be None. Existed for compatibility with LiteralVar.
_var_data: Additional hooks and imports associated with the Var.
Returns:
@ -2158,48 +2379,26 @@ class LiteralNoneVar(LiteralVar, NoneVar):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToNoneOperation(CachedVarOperation, NoneVar):
"""A var operation that converts a var to None."""
def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]:
"""Get the ToOperation class for a given Var subclass.
_original_var: Var = dataclasses.field(
default_factory=lambda: LiteralNoneVar.create()
)
Args:
var_subclass: The Var subclass.
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""Get the cached var name.
Returns:
The ToOperation class.
Returns:
The cached var name.
"""
return str(self._original_var)
@classmethod
def create(
cls,
var: Var,
_var_data: VarData | None = None,
) -> ToNoneOperation:
"""Create a ToNoneOperation.
Args:
var: The var to convert to None.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The ToNoneOperation.
"""
return ToNoneOperation(
_js_expr="",
_var_type=None,
_var_data=_var_data,
_original_var=var,
)
Raises:
ValueError: If the ToOperation class cannot be found.
"""
possible_classes = [
var_subclass.to_var_subclass
for var_subclass in _var_subclasses
if var_subclass.var_subclass is var_subclass
]
if not possible_classes:
raise ValueError(f"Could not find ToOperation for {var_subclass}.")
return possible_classes[0]
@dataclasses.dataclass(
@ -2262,68 +2461,6 @@ class StateOperation(CachedVarOperation, Var):
)
class ToOperation:
"""A var operation that converts a var to another type."""
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
return getattr(object.__getattribute__(self, "_original"), name)
def __post_init__(self):
"""Post initialization."""
object.__delattr__(self, "_js_expr")
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash(object.__getattribute__(self, "_original"))
def _get_all_var_data(self) -> VarData | None:
"""Get all the var data.
Returns:
The var data.
"""
return VarData.merge(
object.__getattribute__(self, "_original")._get_all_var_data(),
self._var_data, # type: ignore
)
@classmethod
def create(
cls,
value: Var,
_var_type: GenericType | None = None,
_var_data: VarData | None = None,
):
"""Create a ToOperation.
Args:
value: The value of the var.
_var_type: The type of the Var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The ToOperation.
"""
return cls(
_js_expr="", # type: ignore
_var_data=_var_data, # type: ignore
_var_type=_var_type or cls._default_var_type, # type: ignore
_original=value, # type: ignore
)
def get_uuid_string_var() -> Var:
"""Return a Var that generates a single memoized UUID via .web/utils/state.js.
@ -2369,168 +2506,6 @@ def get_unique_variable_name() -> str:
return get_unique_variable_name()
@dataclasses.dataclass(
eq=True,
frozen=True,
)
class VarData:
"""Metadata associated with a x."""
# The name of the enclosing state.
state: str = dataclasses.field(default="")
# The name of the field in the state.
field_name: str = dataclasses.field(default="")
# Imports needed to render this var
imports: ImmutableParsedImportDict = dataclasses.field(default_factory=tuple)
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
):
"""Initialize the var data.
Args:
state: The name of the enclosing state.
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
((k, tuple(sorted(v))) for k, v in parse_imports(imports or {}).items())
)
)
object.__setattr__(self, "state", state)
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
Returns:
The imports as a mutable dict.
"""
return dict((k, list(v)) for k, v in self.imports)
@classmethod
def merge(cls, *others: VarData | None) -> VarData | None:
"""Merge multiple var data objects.
Args:
*others: The var data objects to merge.
Returns:
The merged var data object.
"""
state = ""
field_name = ""
_imports = {}
hooks = {}
for var_data in others:
if var_data is None:
continue
state = state or var_data.state
field_name = field_name or var_data.field_name
_imports = imports.merge_imports(_imports, var_data.imports)
hooks.update(
var_data.hooks
if isinstance(var_data.hooks, dict)
else {k: None for k in var_data.hooks}
)
if state or _imports or hooks or field_name:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
)
return None
def __bool__(self) -> bool:
"""Check if the var data is non-empty.
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks or self.field_name)
@classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
"""Set the state of the var.
Args:
state: The state to set or the full name of the state.
field_name: The name of the field in the state. Optional.
Returns:
The var with the set state.
"""
from reflex.utils import format
state_name = state if isinstance(state, str) else state.get_full_name()
return VarData(
state=state_name,
field_name=field_name,
hooks={
"const {0} = useContext(StateContexts.{0})".format(
format.format_state_name(state_name)
): None
},
imports={
f"/{constants.Dirs.CONTEXTS_PATH}": [ImportVar(tag="StateContexts")],
"react": [ImportVar(tag="useContext")],
},
)
def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
"""Decode the state name from a formatted var.
Args:
value: The value to extract the state name from.
Returns:
The extracted state name and the value without the state name.
"""
var_datas = []
if isinstance(value, str):
# fast path if there is no encoded VarData
if constants.REFLEX_VAR_OPENING_TAG not in value:
return None, value
offset = 0
# Find all tags.
while m := _decode_var_pattern.search(value):
start, end = m.span()
value = value[:start] + value[end:]
serialized_data = m.group(1)
if serialized_data.isnumeric() or (
serialized_data[0] == "-" and serialized_data[1:].isnumeric()
):
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._get_all_var_data()
if var_data is not None:
var_datas.append(var_data)
offset += end - start
return VarData.merge(*var_datas) if var_datas else None, value
# Compile regex for finding reflex var tags.
_decode_var_pattern_re = (
rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}"

View File

@ -4,21 +4,20 @@ from __future__ import annotations
import dataclasses
import sys
from typing import Any, Callable, ClassVar, Optional, Tuple, Type, Union
from typing import Any, Callable, Optional, Tuple, Type, Union
from reflex.utils.types import GenericType
from .base import (
CachedVarOperation,
LiteralVar,
ToOperation,
Var,
VarData,
cached_property_no_lock,
)
class FunctionVar(Var[Callable]):
class FunctionVar(Var[Callable], python_types=Callable):
"""Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
@ -180,17 +179,4 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToFunctionOperation(ToOperation, FunctionVar):
"""Base class of converting a var to a function."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
_default_var_type: ClassVar[GenericType] = Callable
JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify")

View File

@ -10,9 +10,7 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
NoReturn,
Type,
TypeVar,
Union,
overload,
@ -25,9 +23,7 @@ from reflex.utils.types import is_optional
from .base import (
CustomVarOperationReturn,
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
VarData,
unionize,
@ -58,7 +54,7 @@ def raise_unsupported_operand_types(
)
class NumberVar(Var[NUMBER_T]):
class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""Base class for immutable number vars."""
@overload
@ -760,7 +756,7 @@ def number_trunc_operation(value: NumberVar):
return var_operation_return(js_expression=f"Math.trunc({value})", var_type=int)
class BooleanVar(NumberVar[bool]):
class BooleanVar(NumberVar[bool], python_types=bool):
"""Base class for immutable boolean vars."""
def __invert__(self):
@ -984,51 +980,6 @@ def boolean_not_operation(value: BooleanVar):
return var_operation_return(js_expression=f"!({value})", var_type=bool)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralBooleanVar(LiteralVar, BooleanVar):
"""Base class for immutable literal boolean vars."""
_var_value: bool = dataclasses.field(default=False)
def json(self) -> str:
"""Get the JSON representation of the var.
Returns:
The JSON representation of the var.
"""
return "true" if self._var_value else "false"
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
@classmethod
def create(cls, value: bool, _var_data: VarData | None = None):
"""Create the boolean var.
Args:
value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The boolean var.
"""
return cls(
_js_expr="true" if value else "false",
_var_type=bool,
_var_data=_var_data,
_var_value=value,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -1088,36 +1039,55 @@ class LiteralNumberVar(LiteralVar, NumberVar):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralBooleanVar(LiteralVar, BooleanVar):
"""Base class for immutable literal boolean vars."""
_var_value: bool = dataclasses.field(default=False)
def json(self) -> str:
"""Get the JSON representation of the var.
Returns:
The JSON representation of the var.
"""
return "true" if self._var_value else "false"
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
@classmethod
def create(cls, value: bool, _var_data: VarData | None = None):
"""Create the boolean var.
Args:
value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The boolean var.
"""
return cls(
_js_expr="true" if value else "false",
_var_type=bool,
_var_data=_var_data,
_var_value=value,
)
number_types = Union[NumberVar, int, float]
boolean_types = Union[BooleanVar, bool]
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToNumberVarOperation(ToOperation, NumberVar):
"""Base class for immutable number vars that are the result of a number operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = float
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToBooleanVarOperation(ToOperation, BooleanVar):
"""Base class for immutable boolean vars that are the result of a boolean operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = bool
_IS_TRUE_IMPORT: ImportDict = {
f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
}

View File

@ -8,7 +8,6 @@ import typing
from inspect import isclass
from typing import (
Any,
ClassVar,
Dict,
List,
NoReturn,
@ -27,7 +26,6 @@ from reflex.utils.types import GenericType, get_attribute_access_type, get_origi
from .base import (
CachedVarOperation,
LiteralVar,
ToOperation,
Var,
VarData,
cached_property_no_lock,
@ -48,7 +46,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
class ObjectVar(Var[OBJECT_TYPE]):
class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
"""Base class for immutable object vars."""
def _key_type(self) -> Type:
@ -521,34 +519,6 @@ class ObjectItemOperation(CachedVarOperation, Var):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToObjectOperation(ToOperation, ObjectVar):
"""Operation to convert a var to an object."""
_original: Var = dataclasses.field(
default_factory=lambda: LiteralObjectVar.create({})
)
_default_var_type: ClassVar[GenericType] = dict
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_js_expr":
return self._original._js_expr
return ObjectVar.__getattr__(self, name)
@var_operation
def object_has_own_property_operation(object: ObjectVar, key: Var):
"""Check if an object has a key.

View File

@ -11,7 +11,6 @@ import typing
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Literal,
@ -32,9 +31,7 @@ from reflex.utils.types import GenericType, get_origin
from .base import (
CachedVarOperation,
CustomVarOperationReturn,
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
VarData,
_global_vars,
@ -56,7 +53,7 @@ if TYPE_CHECKING:
from .object import ObjectVar
class StringVar(Var[str]):
class StringVar(Var[str], python_types=str):
"""Base class for immutable string vars."""
@overload
@ -742,7 +739,7 @@ KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
class ArrayVar(Var[ARRAY_VAR_TYPE]):
class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
"""Base class for immutable array vars."""
@overload
@ -1569,32 +1566,6 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToStringOperation(ToOperation, StringVar):
"""Base class for immutable string vars that are the result of a to string operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = str
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToArrayOperation(ToOperation, ArrayVar):
"""Base class for immutable array vars that are the result of a to array operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = List[Any]
@var_operation
def repeat_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int

View File

@ -519,8 +519,8 @@ def test_var_indexing_types(var, type_):
type_ : The type on indexed object.
"""
assert var[2]._var_type == type_[0]
assert var[3]._var_type == type_[1]
assert var[0]._var_type == type_[0]
assert var[1]._var_type == type_[1]
def test_var_indexing_str():