make var system expandable (#4175)

* make var system expandable

* use old syntax

* remove newer features

* that's a weird error

* remove unnecessary error message

* remove hacky getattr as it's no longer necessary

* improve color handling

* get it right pyright

* dang it darglint

* fix prototype to string

* don't try twice

* adjust test case

* add test for var alpha

* change place of type ignore

* fix json

* add name to custom var operation

* don't delete that you silly

* change logic

* remove extra word
This commit is contained in:
Khaleel Al-Adhami 2024-10-21 17:05:13 -07:00 committed by GitHub
parent f39e8c9667
commit 54ad9f0f4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 682 additions and 631 deletions

View File

@ -12,7 +12,6 @@ from functools import partial
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
List,
@ -33,9 +32,7 @@ from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData
from reflex.vars.base import (
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
)
from reflex.vars.function import (
@ -1254,7 +1251,7 @@ def get_fn_signature(fn: Callable) -> inspect.Signature:
return signature.replace(parameters=(new_param, *signature.parameters.values()))
class EventVar(ObjectVar):
class EventVar(ObjectVar, python_types=EventSpec):
"""Base class for event vars."""
@ -1315,7 +1312,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
)
class EventChainVar(FunctionVar):
class EventChainVar(FunctionVar, python_types=EventChain):
"""Base class for event chain vars."""
@ -1384,32 +1381,6 @@ class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToEventVarOperation(ToOperation, EventVar):
"""Result of a cast to an event var."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = EventSpec
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToEventChainVarOperation(ToOperation, EventChainVar):
"""Result of a cast to an event chain var."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = EventChain
G = ParamSpec("G")
IndividualEventType = Union[EventSpec, EventHandler, Callable[G, Any], Var[Any]]
@ -1537,8 +1508,6 @@ class EventNamespace(types.SimpleNamespace):
LiteralEventVar = LiteralEventVar
EventChainVar = EventChainVar
LiteralEventChainVar = LiteralEventChainVar
ToEventVarOperation = ToEventVarOperation
ToEventChainVarOperation = ToEventChainVarOperation
EventType = EventType
__call__ = staticmethod(event_handler)

File diff suppressed because it is too large Load Diff

View File

@ -4,21 +4,20 @@ from __future__ import annotations
import dataclasses
import sys
from typing import Any, Callable, ClassVar, Optional, Tuple, Type, Union
from typing import Any, Callable, Optional, Tuple, Type, Union
from reflex.utils.types import GenericType
from .base import (
CachedVarOperation,
LiteralVar,
ToOperation,
Var,
VarData,
cached_property_no_lock,
)
class FunctionVar(Var[Callable]):
class FunctionVar(Var[Callable], python_types=Callable):
"""Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation:
@ -180,17 +179,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToFunctionOperation(ToOperation, FunctionVar):
"""Base class of converting a var to a function."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralVar.create(None))
_default_var_type: ClassVar[GenericType] = Callable
JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify")
PROTOTYPE_TO_STRING = FunctionStringVar.create(
"((__to_string) => __to_string.toString())"
)

View File

@ -10,7 +10,6 @@ from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
NoReturn,
Type,
TypeVar,
@ -25,9 +24,7 @@ from reflex.utils.types import is_optional
from .base import (
CustomVarOperationReturn,
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
VarData,
unionize,
@ -58,7 +55,7 @@ def raise_unsupported_operand_types(
)
class NumberVar(Var[NUMBER_T]):
class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""Base class for immutable number vars."""
@overload
@ -760,7 +757,7 @@ def number_trunc_operation(value: NumberVar):
return var_operation_return(js_expression=f"Math.trunc({value})", var_type=int)
class BooleanVar(NumberVar[bool]):
class BooleanVar(NumberVar[bool], python_types=bool):
"""Base class for immutable boolean vars."""
def __invert__(self):
@ -984,51 +981,6 @@ def boolean_not_operation(value: BooleanVar):
return var_operation_return(js_expression=f"!({value})", var_type=bool)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralBooleanVar(LiteralVar, BooleanVar):
"""Base class for immutable literal boolean vars."""
_var_value: bool = dataclasses.field(default=False)
def json(self) -> str:
"""Get the JSON representation of the var.
Returns:
The JSON representation of the var.
"""
return "true" if self._var_value else "false"
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
@classmethod
def create(cls, value: bool, _var_data: VarData | None = None):
"""Create the boolean var.
Args:
value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The boolean var.
"""
return cls(
_js_expr="true" if value else "false",
_var_type=bool,
_var_data=_var_data,
_var_value=value,
)
@dataclasses.dataclass(
eq=False,
frozen=True,
@ -1088,36 +1040,55 @@ class LiteralNumberVar(LiteralVar, NumberVar):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralBooleanVar(LiteralVar, BooleanVar):
"""Base class for immutable literal boolean vars."""
_var_value: bool = dataclasses.field(default=False)
def json(self) -> str:
"""Get the JSON representation of the var.
Returns:
The JSON representation of the var.
"""
return "true" if self._var_value else "false"
def __hash__(self) -> int:
"""Calculate the hash value of the object.
Returns:
int: The hash value of the object.
"""
return hash((self.__class__.__name__, self._var_value))
@classmethod
def create(cls, value: bool, _var_data: VarData | None = None):
"""Create the boolean var.
Args:
value: The value of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The boolean var.
"""
return cls(
_js_expr="true" if value else "false",
_var_type=bool,
_var_data=_var_data,
_var_value=value,
)
number_types = Union[NumberVar, int, float]
boolean_types = Union[BooleanVar, bool]
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToNumberVarOperation(ToOperation, NumberVar):
"""Base class for immutable number vars that are the result of a number operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = float
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToBooleanVarOperation(ToOperation, BooleanVar):
"""Base class for immutable boolean vars that are the result of a boolean operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = bool
_IS_TRUE_IMPORT: ImportDict = {
f"/{Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
}
@ -1140,8 +1111,12 @@ def boolify(value: Var):
)
T = TypeVar("T")
U = TypeVar("U")
@var_operation
def ternary_operation(condition: BooleanVar, if_true: Var, if_false: Var):
def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]):
"""Create a ternary operation.
Args:
@ -1152,10 +1127,14 @@ def ternary_operation(condition: BooleanVar, if_true: Var, if_false: Var):
Returns:
The ternary operation.
"""
return var_operation_return(
js_expression=f"({condition} ? {if_true} : {if_false})",
var_type=unionize(if_true._var_type, if_false._var_type),
type_value: Union[Type[T], Type[U]] = unionize(
if_true._var_type, if_false._var_type
)
value: CustomVarOperationReturn[Union[T, U]] = var_operation_return(
js_expression=f"({condition} ? {if_true} : {if_false})",
var_type=type_value,
)
return value
NUMBER_TYPES = (int, float, NumberVar)

View File

@ -8,7 +8,6 @@ import typing
from inspect import isclass
from typing import (
Any,
ClassVar,
Dict,
List,
NoReturn,
@ -27,7 +26,6 @@ from reflex.utils.types import GenericType, get_attribute_access_type, get_origi
from .base import (
CachedVarOperation,
LiteralVar,
ToOperation,
Var,
VarData,
cached_property_no_lock,
@ -48,7 +46,7 @@ ARRAY_INNER_TYPE = TypeVar("ARRAY_INNER_TYPE")
OTHER_KEY_TYPE = TypeVar("OTHER_KEY_TYPE")
class ObjectVar(Var[OBJECT_TYPE]):
class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
"""Base class for immutable object vars."""
def _key_type(self) -> Type:
@ -521,34 +519,6 @@ class ObjectItemOperation(CachedVarOperation, Var):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToObjectOperation(ToOperation, ObjectVar):
"""Operation to convert a var to an object."""
_original: Var = dataclasses.field(
default_factory=lambda: LiteralObjectVar.create({})
)
_default_var_type: ClassVar[GenericType] = dict
def __getattr__(self, name: str) -> Any:
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute of the var.
"""
if name == "_js_expr":
return self._original._js_expr
return ObjectVar.__getattr__(self, name)
@var_operation
def object_has_own_property_operation(object: ObjectVar, key: Var):
"""Check if an object has a key.

View File

@ -11,7 +11,6 @@ import typing
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Literal,
@ -19,27 +18,28 @@ from typing import (
Set,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import TypeVar
from reflex import constants
from reflex.constants.base import REFLEX_VAR_OPENING_TAG
from reflex.constants.colors import Color
from reflex.utils.exceptions import VarTypeError
from reflex.utils.types import GenericType, get_origin
from .base import (
CachedVarOperation,
CustomVarOperationReturn,
LiteralNoneVar,
LiteralVar,
ToOperation,
Var,
VarData,
_global_vars,
cached_property_no_lock,
figure_out_type,
get_python_literal,
get_unique_variable_name,
unionize,
var_operation,
@ -50,13 +50,16 @@ from .number import (
LiteralNumberVar,
NumberVar,
raise_unsupported_operand_types,
ternary_operation,
)
if TYPE_CHECKING:
from .object import ObjectVar
STRING_TYPE = TypeVar("STRING_TYPE", default=str)
class StringVar(Var[str]):
class StringVar(Var[STRING_TYPE], python_types=str):
"""Base class for immutable string vars."""
@overload
@ -350,7 +353,7 @@ class StringVar(Var[str]):
@var_operation
def string_lt_operation(lhs: StringVar | str, rhs: StringVar | str):
def string_lt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
"""Check if a string is less than another string.
Args:
@ -364,7 +367,7 @@ def string_lt_operation(lhs: StringVar | str, rhs: StringVar | str):
@var_operation
def string_gt_operation(lhs: StringVar | str, rhs: StringVar | str):
def string_gt_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
"""Check if a string is greater than another string.
Args:
@ -378,7 +381,7 @@ def string_gt_operation(lhs: StringVar | str, rhs: StringVar | str):
@var_operation
def string_le_operation(lhs: StringVar | str, rhs: StringVar | str):
def string_le_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
"""Check if a string is less than or equal to another string.
Args:
@ -392,7 +395,7 @@ def string_le_operation(lhs: StringVar | str, rhs: StringVar | str):
@var_operation
def string_ge_operation(lhs: StringVar | str, rhs: StringVar | str):
def string_ge_operation(lhs: StringVar[Any] | str, rhs: StringVar[Any] | str):
"""Check if a string is greater than or equal to another string.
Args:
@ -406,7 +409,7 @@ def string_ge_operation(lhs: StringVar | str, rhs: StringVar | str):
@var_operation
def string_lower_operation(string: StringVar):
def string_lower_operation(string: StringVar[Any]):
"""Convert a string to lowercase.
Args:
@ -419,7 +422,7 @@ def string_lower_operation(string: StringVar):
@var_operation
def string_upper_operation(string: StringVar):
def string_upper_operation(string: StringVar[Any]):
"""Convert a string to uppercase.
Args:
@ -432,7 +435,7 @@ def string_upper_operation(string: StringVar):
@var_operation
def string_strip_operation(string: StringVar):
def string_strip_operation(string: StringVar[Any]):
"""Strip a string.
Args:
@ -446,7 +449,7 @@ def string_strip_operation(string: StringVar):
@var_operation
def string_contains_field_operation(
haystack: StringVar, needle: StringVar | str, field: StringVar | str
haystack: StringVar[Any], needle: StringVar[Any] | str, field: StringVar[Any] | str
):
"""Check if a string contains another string.
@ -465,7 +468,7 @@ def string_contains_field_operation(
@var_operation
def string_contains_operation(haystack: StringVar, needle: StringVar | str):
def string_contains_operation(haystack: StringVar[Any], needle: StringVar[Any] | str):
"""Check if a string contains another string.
Args:
@ -481,7 +484,9 @@ def string_contains_operation(haystack: StringVar, needle: StringVar | str):
@var_operation
def string_starts_with_operation(full_string: StringVar, prefix: StringVar | str):
def string_starts_with_operation(
full_string: StringVar[Any], prefix: StringVar[Any] | str
):
"""Check if a string starts with a prefix.
Args:
@ -497,7 +502,7 @@ def string_starts_with_operation(full_string: StringVar, prefix: StringVar | str
@var_operation
def string_item_operation(string: StringVar, index: NumberVar | int):
def string_item_operation(string: StringVar[Any], index: NumberVar | int):
"""Get an item from a string.
Args:
@ -511,7 +516,7 @@ def string_item_operation(string: StringVar, index: NumberVar | int):
@var_operation
def array_join_operation(array: ArrayVar, sep: StringVar | str = ""):
def array_join_operation(array: ArrayVar, sep: StringVar[Any] | str = ""):
"""Join the elements of an array.
Args:
@ -536,7 +541,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralStringVar(LiteralVar, StringVar):
class LiteralStringVar(LiteralVar, StringVar[str]):
"""Base class for immutable literal string vars."""
_var_value: str = dataclasses.field(default="")
@ -658,7 +663,7 @@ class LiteralStringVar(LiteralVar, StringVar):
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ConcatVarOperation(CachedVarOperation, StringVar):
class ConcatVarOperation(CachedVarOperation, StringVar[str]):
"""Representing a concatenation of literal string vars."""
_var_value: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)
@ -742,7 +747,7 @@ KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
class ArrayVar(Var[ARRAY_VAR_TYPE]):
class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
"""Base class for immutable array vars."""
@overload
@ -1275,7 +1280,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
@var_operation
def string_split_operation(string: StringVar, sep: StringVar | str = ""):
def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
"""Split a string.
Args:
@ -1572,32 +1577,6 @@ def array_contains_operation(haystack: ArrayVar, needle: Any | Var):
)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToStringOperation(ToOperation, StringVar):
"""Base class for immutable string vars that are the result of a to string operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = str
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class ToArrayOperation(ToOperation, ArrayVar):
"""Base class for immutable array vars that are the result of a to array operation."""
_original: Var = dataclasses.field(default_factory=lambda: LiteralNoneVar.create())
_default_var_type: ClassVar[Type] = List[Any]
@var_operation
def repeat_array_operation(
array: ArrayVar[ARRAY_VAR_TYPE], count: NumberVar | int
@ -1657,3 +1636,134 @@ def array_concat_operation(
js_expression=f"[...{lhs}, ...{rhs}]",
var_type=Union[lhs._var_type, rhs._var_type],
)
class ColorVar(StringVar[Color], python_types=Color):
"""Base class for immutable color vars."""
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
)
class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar):
"""Base class for immutable literal color vars."""
_var_value: Color = dataclasses.field(default_factory=lambda: Color(color="black"))
@classmethod
def create(
cls,
value: Color,
_var_type: Type[Color] | None = None,
_var_data: VarData | None = None,
) -> ColorVar:
"""Create a var from a string value.
Args:
value: The value to create the var from.
_var_type: The type of the var.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The var.
"""
return cls(
_js_expr="",
_var_type=_var_type or Color,
_var_data=_var_data,
_var_value=value,
)
def __hash__(self) -> int:
"""Get the hash of the var.
Returns:
The hash of the var.
"""
return hash(
(
self.__class__.__name__,
self._var_value.color,
self._var_value.alpha,
self._var_value.shade,
)
)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
alpha = self._var_value.alpha
alpha = (
ternary_operation(
alpha,
LiteralStringVar.create("a"),
LiteralStringVar.create(""),
)
if isinstance(alpha, Var)
else LiteralStringVar.create("a" if alpha else "")
)
shade = self._var_value.shade
shade = (
shade.to_string(use_json=False)
if isinstance(shade, Var)
else LiteralStringVar.create(str(shade))
)
return str(
ConcatVarOperation.create(
LiteralStringVar.create("var(--"),
self._var_value.color,
LiteralStringVar.create("-"),
alpha,
shade,
LiteralStringVar.create(")"),
)
)
@cached_property_no_lock
def _cached_get_all_var_data(self) -> VarData | None:
"""Get all the var data.
Returns:
The var data.
"""
return VarData.merge(
*[
LiteralVar.create(var)._get_all_var_data()
for var in (
self._var_value.color,
self._var_value.alpha,
self._var_value.shade,
)
],
self._var_data,
)
def json(self) -> str:
"""Get the JSON representation of the var.
Returns:
The JSON representation of the var.
Raises:
TypeError: If the color is not a valid color.
"""
color, alpha, shade = map(
get_python_literal,
(self._var_value.color, self._var_value.alpha, self._var_value.shade),
)
if color is None or alpha is None or shade is None:
raise TypeError("Cannot serialize color that contains non-literal vars.")
if (
not isinstance(color, str)
or not isinstance(alpha, bool)
or not isinstance(shade, int)
):
raise TypeError("Color is not a valid color.")
return f"var(--{color}-{'a' if alpha else ''}{shade})"

View File

@ -14,6 +14,7 @@ class ColorState(rx.State):
color: str = "mint"
color_part: str = "tom"
shade: int = 4
alpha: bool = False
color_state_name = ColorState.get_full_name().replace(".", "__")
@ -31,7 +32,14 @@ def create_color_var(color):
(create_color_var(rx.color("mint", 3, True)), '"var(--mint-a3)"', Color),
(
create_color_var(rx.color(ColorState.color, ColorState.shade)), # type: ignore
f'("var(--"+{str(color_state_name)}.color+"-"+{str(color_state_name)}.shade+")")',
f'("var(--"+{str(color_state_name)}.color+"-"+(((__to_string) => __to_string.toString())({str(color_state_name)}.shade))+")")',
Color,
),
(
create_color_var(
rx.color(ColorState.color, ColorState.shade, ColorState.alpha) # type: ignore
),
f'("var(--"+{str(color_state_name)}.color+"-"+({str(color_state_name)}.alpha ? "a" : "")+(((__to_string) => __to_string.toString())({str(color_state_name)}.shade))+")")',
Color,
),
(
@ -43,7 +51,7 @@ def create_color_var(color):
create_color_var(
rx.color(f"{ColorState.color_part}ato", f"{ColorState.shade}") # type: ignore
),
f'("var(--"+{str(color_state_name)}.color_part+"ato-"+{str(color_state_name)}.shade+")")',
f'("var(--"+({str(color_state_name)}.color_part+"ato")+"-"+{str(color_state_name)}.shade+")")',
Color,
),
(

View File

@ -519,8 +519,8 @@ def test_var_indexing_types(var, type_):
type_ : The type on indexed object.
"""
assert var[2]._var_type == type_[0]
assert var[3]._var_type == type_[1]
assert var[0]._var_type == type_[0]
assert var[1]._var_type == type_[1]
def test_var_indexing_str():