improve var base typing (#4718)

* improve var base typing

* fix pyi

* dang it darglint

* drain _process in tests

* fixes #4576

* dang it darglint
This commit is contained in:
Khaleel Al-Adhami 2025-01-31 13:12:33 -08:00 committed by GitHub
parent 12a42b6c47
commit 8663dbcb97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 279 additions and 264 deletions

View File

@ -11,10 +11,11 @@ from reflex.event import EventHandler, set_clipboard
from reflex.state import FrontendEventExceptionState
from reflex.vars.base import Var
from reflex.vars.function import ArgsFunctionOperation
from reflex.vars.object import ObjectVar
def on_error_spec(
error: Var[Dict[str, str]], info: Var[Dict[str, str]]
error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
) -> Tuple[Var[str], Var[str]]:
"""The spec for the on_error event handler.

View File

@ -9,9 +9,10 @@ from reflex.components.component import Component
from reflex.event import BASE_STATE, EventType
from reflex.style import Style
from reflex.vars.base import Var
from reflex.vars.object import ObjectVar
def on_error_spec(
error: Var[Dict[str, str]], info: Var[Dict[str, str]]
error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
) -> Tuple[Var[str], Var[str]]: ...
class ErrorBoundary(Component):

View File

@ -2457,6 +2457,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
@dataclasses.dataclass(
eq=False,
frozen=True,
slots=True,
)
class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
"""A Var that represents a Component."""

View File

@ -11,6 +11,7 @@ from reflex.components.component import Component
from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode
from reflex.state import ComponentState
from reflex.utils.exceptions import UntypedVarError
from reflex.vars.base import LiteralVar, Var
@ -51,6 +52,7 @@ class Foreach(Component):
Raises:
ForeachVarError: If the iterable is of type Any.
TypeError: If the render function is a ComponentState.
UntypedVarError: If the iterable is of type Any without a type annotation.
"""
iterable = LiteralVar.create(iterable)
if iterable._var_type == Any:
@ -72,8 +74,14 @@ class Foreach(Component):
iterable=iterable,
render_fn=render_fn,
)
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
try:
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
except UntypedVarError as e:
raise UntypedVarError(
f"Could not foreach over var `{iterable!s}` without a type annotation. "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
) from e
return component
def _render(self) -> IterTag:

View File

@ -387,7 +387,8 @@ class DataEditor(NoSSRComponent):
raise ValueError(
"DataEditor data must be an ArrayVar if rows is not provided."
)
props["rows"] = data.length() if isinstance(data, Var) else len(data)
props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data)
if not isinstance(columns, Var) and len(columns):
if types.is_dataframe(type(data)) or (

View File

@ -621,18 +621,22 @@ class ShikiCodeBlock(Component, MarkdownComponentMap):
Returns:
Imports for the component.
Raises:
ValueError: If the transformers are not of type LiteralVar.
"""
imports = defaultdict(list)
if not isinstance(self.transformers, LiteralVar):
raise ValueError(
f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead."
)
for transformer in self.transformers._var_value:
if isinstance(transformer, ShikiBaseTransformers):
imports[transformer.library].extend(
[ImportVar(tag=str(fn)) for fn in transformer.fns]
)
(
if transformer.library not in self.lib_dependencies:
self.lib_dependencies.append(transformer.library)
if transformer.library not in self.lib_dependencies
else None
)
return imports
@classmethod

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import dataclasses
import inspect
import sys
import types
import urllib.parse
from base64 import b64encode
@ -541,7 +540,7 @@ class JavasciptKeyboardEvent:
shiftKey: bool = False # noqa: N815
def input_event(e: Var[JavascriptInputEvent]) -> Tuple[Var[str]]:
def input_event(e: ObjectVar[JavascriptInputEvent]) -> Tuple[Var[str]]:
"""Get the value from an input event.
Args:
@ -562,7 +561,9 @@ class KeyInputInfo(TypedDict):
shift_key: bool
def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInfo]]:
def key_event(
e: ObjectVar[JavasciptKeyboardEvent],
) -> Tuple[Var[str], Var[KeyInputInfo]]:
"""Get the key from a keyboard event.
Args:
@ -572,7 +573,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
The key from the keyboard event.
"""
return (
e.key,
e.key.to(str),
Var.create(
{
"alt_key": e.altKey,
@ -580,7 +581,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
"meta_key": e.metaKey,
"shift_key": e.shiftKey,
},
),
).to(KeyInputInfo),
)
@ -1354,7 +1355,7 @@ def unwrap_var_annotation(annotation: GenericType):
Returns:
The unwrapped annotation.
"""
if get_origin(annotation) is Var and (args := get_args(annotation)):
if get_origin(annotation) in (Var, ObjectVar) and (args := get_args(annotation)):
return args[0]
return annotation
@ -1620,7 +1621,7 @@ class EventVar(ObjectVar, python_types=EventSpec):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
"""A literal event var."""
@ -1681,7 +1682,7 @@ class EventChainVar(BuilderFunctionVar, python_types=EventChain):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
# Note: LiteralVar is second in the inheritance list allowing it act like a
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
@ -1713,6 +1714,9 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
Returns:
The created LiteralEventChainVar instance.
Raises:
ValueError: If the invocation is not a FunctionVar.
"""
arg_spec = (
value.args_spec[0]
@ -1740,6 +1744,11 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
else:
invocation = value.invocation
if invocation is not None and not isinstance(invocation, FunctionVar):
raise ValueError(
f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}."
)
return cls(
_js_expr="",
_var_type=EventChain,

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import dataclasses
import re
import sys
from typing import Any, Callable, Union
from reflex import constants
@ -49,7 +48,7 @@ def _client_state_ref_dict(var_name: str) -> str:
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ClientStateVar(Var):
"""A Var that exists on the client via useState."""

View File

@ -1637,9 +1637,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if not isinstance(var, Var):
return var
unset = object()
# Fast case: this is a literal var and the value is known.
if hasattr(var, "_var_value"):
return var._var_value
if (var_value := getattr(var, "_var_value", unset)) is not unset:
return var_value # pyright: ignore [reportReturnType]
var_data = var._get_all_var_data()
if var_data is None or not var_data.state:

View File

@ -75,6 +75,10 @@ class VarAttributeError(ReflexError, AttributeError):
"""Custom AttributeError for var related errors."""
class UntypedVarError(ReflexError, TypeError):
"""Custom TypeError for untyped var errors."""
class UntypedComputedVarError(ReflexError, TypeError):
"""Custom TypeError for untyped computed var errors."""

View File

@ -12,7 +12,6 @@ import json
import random
import re
import string
import sys
import warnings
from types import CodeType, FunctionType
from typing import (
@ -82,6 +81,7 @@ if TYPE_CHECKING:
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
STRING_T = TypeVar("STRING_T", bound=str)
SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence)
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
@ -449,7 +449,7 @@ class Var(Generic[VAR_TYPE]):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ToVarOperation(ToOperation, cls):
"""Base class of converting a var to another var type."""
@ -597,7 +597,7 @@ class Var(Generic[VAR_TYPE]):
@overload
@classmethod
def create(
def create( # pyright: ignore [reportOverlappingOverload]
cls,
value: STRING_T,
_var_data: VarData | None = None,
@ -611,6 +611,22 @@ class Var(Generic[VAR_TYPE]):
_var_data: VarData | None = None,
) -> NoneVar: ...
@overload
@classmethod
def create(
cls,
value: MAPPING_TYPE,
_var_data: VarData | None = None,
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
@classmethod
def create(
cls,
value: SEQUENCE_TYPE,
_var_data: VarData | None = None,
) -> ArrayVar[SEQUENCE_TYPE]: ...
@overload
@classmethod
def create(
@ -692,8 +708,8 @@ class Var(Generic[VAR_TYPE]):
@overload
def to(
self,
output: type[Mapping],
) -> ObjectVar[Mapping]: ...
output: type[MAPPING_TYPE],
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
def to(
@ -744,7 +760,7 @@ class Var(Generic[VAR_TYPE]):
return get_to_operation(NoneVar).create(self) # pyright: ignore [reportReturnType]
# Handle fixed_output_type being Base or a dataclass.
if can_use_in_object_var(fixed_output_type):
if can_use_in_object_var(output):
return self.to(ObjectVar, output)
if inspect.isclass(output):
@ -776,6 +792,9 @@ class Var(Generic[VAR_TYPE]):
return self
@overload
def guess_type(self: Var[NoReturn]) -> Var[Any]: ... # pyright: ignore [reportOverlappingOverload]
@overload
def guess_type(self: Var[str]) -> StringVar: ...
@ -785,6 +804,9 @@ class Var(Generic[VAR_TYPE]):
@overload
def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
@overload
def guess_type(self: Var[BASE_TYPE]) -> ObjectVar[BASE_TYPE]: ...
@overload
def guess_type(self) -> Self: ...
@ -933,7 +955,7 @@ class Var(Generic[VAR_TYPE]):
return setter
def _var_set_state(self, state: type[BaseState] | str):
def _var_set_state(self, state: type[BaseState] | str) -> Self:
"""Set the state of the var.
Args:
@ -948,7 +970,7 @@ class Var(Generic[VAR_TYPE]):
else format_state_name(state.get_full_name())
)
return StateOperation.create(
return StateOperation.create( # pyright: ignore [reportReturnType]
formatted_state_name,
self,
_var_data=VarData.merge(
@ -1127,43 +1149,6 @@ class Var(Generic[VAR_TYPE]):
"""
return self
def __getattr__(self, name: str):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute.
Raises:
VarAttributeError: If the attribute does not exist.
TypeError: If the var type is Any.
"""
if name.startswith("_"):
return self.__getattribute__(name)
if name == "contains":
raise TypeError(
f"Var of type {self._var_type} does not support contains check."
)
if name == "reverse":
raise TypeError("Cannot reverse non-list var.")
if self._var_type is Any:
raise TypeError(
f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
)
if name in REPLACED_NAMES:
raise VarAttributeError(
f"Field {name!r} was renamed to {REPLACED_NAMES[name]!r}"
)
raise VarAttributeError(
f"The State var has no attribute '{name}' or may have been annotated wrongly.",
)
def _decode(self) -> Any:
"""Decode Var as a python value.
@ -1225,36 +1210,76 @@ class Var(Generic[VAR_TYPE]):
return ArrayVar.range(first_endpoint, second_endpoint, step)
def __bool__(self) -> bool:
"""Raise exception if using Var in a boolean context.
if not TYPE_CHECKING:
Raises:
VarTypeError: when attempting to bool-ify the Var.
"""
raise VarTypeError(
f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
"Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
)
def __getattr__(self, name: str):
"""Get an attribute of the var.
def __iter__(self) -> Any:
"""Raise exception if using Var in an iterable context.
Args:
name: The name of the attribute.
Raises:
VarTypeError: when attempting to iterate over the Var.
"""
raise VarTypeError(
f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
)
Raises:
VarAttributeError: If the attribute does not exist.
UntypedVarError: If the var type is Any.
TypeError: If the var type is Any.
def __contains__(self, _: Any) -> Var:
"""Override the 'in' operator to alert the user that it is not supported.
# noqa: DAR101 self
"""
if name.startswith("_"):
raise VarAttributeError(f"Attribute {name} not found.")
Raises:
VarTypeError: the operation is not supported
"""
raise VarTypeError(
"'in' operator not supported for Var types, use Var.contains() instead."
)
if name == "contains":
raise TypeError(
f"Var of type {self._var_type} does not support contains check."
)
if name == "reverse":
raise TypeError("Cannot reverse non-list var.")
if self._var_type is Any:
raise exceptions.UntypedVarError(
f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
)
raise VarAttributeError(
f"The State var has no attribute '{name}' or may have been annotated wrongly.",
)
def __bool__(self) -> bool:
"""Raise exception if using Var in a boolean context.
Raises:
VarTypeError: when attempting to bool-ify the Var.
# noqa: DAR101 self
"""
raise VarTypeError(
f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
"Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
)
def __iter__(self) -> Any:
"""Raise exception if using Var in an iterable context.
Raises:
VarTypeError: when attempting to iterate over the Var.
# noqa: DAR101 self
"""
raise VarTypeError(
f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
)
def __contains__(self, _: Any) -> Var:
"""Override the 'in' operator to alert the user that it is not supported.
Raises:
VarTypeError: the operation is not supported
# noqa: DAR101 self
"""
raise VarTypeError(
"'in' operator not supported for Var types, use Var.contains() instead."
)
OUTPUT = TypeVar("OUTPUT", bound=Var)
@ -1471,6 +1496,12 @@ class LiteralVar(Var):
def __post_init__(self):
"""Post-initialize the var."""
@property
def _var_value(self) -> Any:
raise NotImplementedError(
"LiteralVar subclasses must implement the _var_value property."
)
def json(self) -> str:
"""Serialize the var to a JSON string.
@ -1543,7 +1574,7 @@ def var_operation(
) -> Callable[P, StringVar]: ...
LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
LIST_T = TypeVar("LIST_T", bound=Sequence)
@overload
@ -1780,7 +1811,7 @@ def _or_operation(a: Var, b: Var):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class CallableVar(Var):
"""Decorate a Var-returning function to act as both a Var and a function.
@ -1861,7 +1892,7 @@ def is_computed_var(obj: Any) -> TypeGuard[ComputedVar]:
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ComputedVar(Var[RETURN_TYPE]):
"""A field with computed getters."""
@ -2070,13 +2101,6 @@ class ComputedVar(Var[RETURN_TYPE]):
owner: Type,
) -> ArrayVar[list[LIST_INSIDE]]: ...
@overload
def __get__(
self: ComputedVar[set[LIST_INSIDE]],
instance: None,
owner: Type,
) -> ArrayVar[set[LIST_INSIDE]]: ...
@overload
def __get__(
self: ComputedVar[tuple[LIST_INSIDE, ...]],
@ -2436,7 +2460,7 @@ def var_operation_return(
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class CustomVarOperation(CachedVarOperation, Var[T]):
"""Base class for custom var operations."""
@ -2507,7 +2531,7 @@ class NoneVar(Var[None], python_types=type(None)):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralNoneVar(LiteralVar, NoneVar):
"""A var representing None."""
@ -2569,7 +2593,7 @@ def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]:
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class StateOperation(CachedVarOperation, Var):
"""A var operation that accesses a field on an object."""
@ -2716,19 +2740,6 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]:
return var_datas
# These names were changed in reflex 0.3.0
REPLACED_NAMES = {
"full_name": "_var_full_name",
"name": "_js_expr",
"state": "_var_data.state",
"type_": "_var_type",
"is_local": "_var_is_local",
"is_string": "_var_is_string",
"set_state": "_var_set_state",
"deps": "_deps",
}
dispatchers: Dict[GenericType, Callable[[Var], Var]] = {}

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import dataclasses
import sys
from datetime import date, datetime
from typing import Any, NoReturn, TypeVar, Union, overload
@ -193,7 +192,7 @@ def date_compare_operation(
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralDatetimeVar(LiteralVar, DateTimeVar):
"""Base class for immutable datetime and date vars."""

View File

@ -226,7 +226,7 @@ class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
"""Base class for immutable vars that are the result of a function call."""
@ -350,7 +350,7 @@ def format_args_function_operation(
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
"""Base class for immutable function defined via arguments and return expression."""
@ -407,7 +407,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import dataclasses
import json
import math
import sys
from typing import (
TYPE_CHECKING,
Any,
@ -160,7 +159,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
from .sequence import ArrayVar, LiteralArrayVar
if isinstance(other, (list, tuple, set, ArrayVar)):
if isinstance(other, (list, tuple, ArrayVar)):
if isinstance(other, ArrayVar):
return other * self
return LiteralArrayVar.create(other) * self
@ -187,7 +186,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
"""
from .sequence import ArrayVar, LiteralArrayVar
if isinstance(other, (list, tuple, set, ArrayVar)):
if isinstance(other, (list, tuple, ArrayVar)):
if isinstance(other, ArrayVar):
return other * self
return LiteralArrayVar.create(other) * self
@ -973,7 +972,7 @@ def boolean_not_operation(value: BooleanVar):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralNumberVar(LiteralVar, NumberVar):
"""Base class for immutable literal number vars."""
@ -1032,7 +1031,7 @@ class LiteralNumberVar(LiteralVar, NumberVar):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralBooleanVar(LiteralVar, BooleanVar):
"""Base class for immutable literal boolean vars."""

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import dataclasses
import sys
import typing
from inspect import isclass
from typing import (
@ -167,12 +166,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
@ -229,12 +222,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
@ -305,7 +292,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars."""
@ -355,17 +342,20 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
Returns:
The JSON representation of the object.
Raises:
TypeError: The keys and values of the object must be literal vars to get the JSON representation
"""
return (
"{"
+ ", ".join(
[
f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}"
for key, value in self._var_value.items()
]
)
+ "}"
)
keys_and_values = []
for key, value in self._var_value.items():
key = LiteralVar.create(key)
value = LiteralVar.create(value)
if not isinstance(key, LiteralVar) or not isinstance(value, LiteralVar):
raise TypeError(
"The keys and values of the object must be literal vars to get the JSON representation."
)
keys_and_values.append(f"{key.json()}:{value.json()}")
return "{" + ", ".join(keys_and_values) + "}"
def __hash__(self) -> int:
"""Get the hash of the var.
@ -487,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ObjectItemOperation(CachedVarOperation, Var):
"""Operation to get an item from an object."""

View File

@ -6,7 +6,6 @@ import dataclasses
import inspect
import json
import re
import sys
import typing
from typing import (
TYPE_CHECKING,
@ -15,7 +14,7 @@ from typing import (
List,
Literal,
NoReturn,
Set,
Sequence,
Tuple,
Type,
Union,
@ -596,7 +595,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralStringVar(LiteralVar, StringVar[str]):
"""Base class for immutable literal string vars."""
@ -718,7 +717,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ConcatVarOperation(CachedVarOperation, StringVar[str]):
"""Representing a concatenation of literal string vars."""
@ -794,7 +793,8 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]):
)
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set])
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Sequence, covariant=True)
OTHER_ARRAY_VAR_TYPE = TypeVar("OTHER_ARRAY_VAR_TYPE", bound=Sequence)
OTHER_TUPLE = TypeVar("OTHER_TUPLE")
@ -887,6 +887,11 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
i: Literal[0, -2],
) -> NumberVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
) -> BooleanVar: ...
@overload
def __getitem__(
self: (
@ -914,7 +919,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
@overload
def __getitem__(
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
) -> BooleanVar: ...
@overload
@ -932,23 +937,12 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
) -> StringVar: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
) -> BooleanVar: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
i: int | NumberVar,
) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
i: int | NumberVar,
) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]],
@ -1239,26 +1233,18 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
LIST_ELEMENT = TypeVar("LIST_ELEMENT")
ARRAY_VAR_OF_LIST_ELEMENT = Union[
ArrayVar[List[LIST_ELEMENT]],
ArrayVar[Set[LIST_ELEMENT]],
ArrayVar[Tuple[LIST_ELEMENT, ...]],
]
ARRAY_VAR_OF_LIST_ELEMENT = ArrayVar[Sequence[LIST_ELEMENT]]
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
"""Base class for immutable literal array vars."""
_var_value: Union[
List[Union[Var, Any]],
Set[Union[Var, Any]],
Tuple[Union[Var, Any], ...],
] = dataclasses.field(default_factory=list)
_var_value: Sequence[Union[Var, Any]] = dataclasses.field(default=())
@cached_property_no_lock
def _cached_var_name(self) -> str:
@ -1303,22 +1289,28 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
Returns:
The JSON representation of the var.
Raises:
TypeError: If the array elements are not of type LiteralVar.
"""
return (
"["
+ ", ".join(
[LiteralVar.create(element).json() for element in self._var_value]
)
+ "]"
)
elements = []
for element in self._var_value:
element_var = LiteralVar.create(element)
if not isinstance(element_var, LiteralVar):
raise TypeError(
f"Array elements must be of type LiteralVar, not {type(element_var)}"
)
elements.append(element_var.json())
return "[" + ", ".join(elements) + "]"
@classmethod
def create(
cls,
value: ARRAY_VAR_TYPE,
_var_type: Type[ARRAY_VAR_TYPE] | None = None,
value: OTHER_ARRAY_VAR_TYPE,
_var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None,
_var_data: VarData | None = None,
) -> LiteralArrayVar[ARRAY_VAR_TYPE]:
) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]:
"""Create a var from a string value.
Args:
@ -1329,7 +1321,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
Returns:
The var.
"""
return cls(
return LiteralArrayVar(
_js_expr="",
_var_type=figure_out_type(value) if _var_type is None else _var_type,
_var_data=_var_data,
@ -1356,7 +1348,7 @@ def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class ArraySliceOperation(CachedVarOperation, ArrayVar):
"""Base class for immutable string vars that are the result of a string slice operation."""
@ -1705,7 +1697,7 @@ class ColorVar(StringVar[Color], python_types=Color):
@dataclasses.dataclass(
eq=False,
frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {},
slots=True,
)
class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar):
"""Base class for immutable literal color vars."""

View File

@ -3,6 +3,7 @@ from typing import List, Mapping, Tuple
import pytest
import reflex as rx
from reflex.components.component import Component
from reflex.components.core.match import Match
from reflex.state import BaseState
from reflex.utils.exceptions import MatchTypeError
@ -29,6 +30,8 @@ def test_match_components():
rx.text("default value"),
)
match_comp = Match.create(MatchState.value, *match_case_tuples)
assert isinstance(match_comp, Component)
match_dict = match_comp.render()
assert match_dict["name"] == "Fragment"
@ -151,6 +154,7 @@ def test_match_on_component_without_default():
)
match_comp = Match.create(MatchState.value, *match_case_tuples)
assert isinstance(match_comp, Component)
default = match_comp.render()["children"][0]["default"]
assert isinstance(default, dict) and default["name"] == Fragment.__name__

View File

@ -36,6 +36,7 @@ from reflex.utils.exceptions import (
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.object import ObjectVar
@pytest.fixture
@ -842,12 +843,12 @@ def test_component_event_trigger_arbitrary_args():
"""Test that we can define arbitrary types for the args of an event trigger."""
def on_foo_spec(
_e: Var[JavascriptInputEvent],
_e: ObjectVar[JavascriptInputEvent],
alpha: Var[str],
bravo: dict[str, Any],
charlie: Var[_Obj],
charlie: ObjectVar[_Obj],
):
return [_e.target.value, bravo["nested"], charlie.custom + 42]
return [_e.target.value, bravo["nested"], charlie.custom.to(int) + 42]
class C1(Component):
library = "/local"
@ -1328,7 +1329,7 @@ class EventState(rx.State):
),
pytest.param(
rx.fragment(class_name=[TEST_VAR, "other-class"]),
[LiteralVar.create([TEST_VAR, "other-class"]).join(" ")],
[Var.create([TEST_VAR, "other-class"]).join(" ")],
id="fstring-dual-class_name",
),
pytest.param(

View File

@ -471,15 +471,15 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
"""
state = test_state() # pyright: ignore [reportCallIssue]
state.add_var("int_val", int, 0)
result = await state._process(
async for result in state._process(
Event(
token=token,
name=f"{test_state.get_name()}.set_int_val",
router_data={"pathname": "/", "query": {}},
payload={"value": 50},
)
).__anext__()
assert result.delta == {test_state.get_name(): {"int_val": 50}}
):
assert result.delta == {test_state.get_name(): {"int_val": 50}}
@pytest.mark.asyncio
@ -583,18 +583,17 @@ async def test_list_mutation_detection__plain_list(
token: a Token.
"""
for event_name, expected_delta in event_tuples:
result = await list_mutation_state._process(
async for result in list_mutation_state._process(
Event(
token=token,
name=f"{list_mutation_state.get_name()}.{event_name}",
router_data={"pathname": "/", "query": {}},
payload={},
)
).__anext__()
# prefix keys in expected_delta with the state name
expected_delta = {list_mutation_state.get_name(): expected_delta}
assert result.delta == expected_delta
):
# prefix keys in expected_delta with the state name
expected_delta = {list_mutation_state.get_name(): expected_delta}
assert result.delta == expected_delta
@pytest.mark.asyncio
@ -709,19 +708,18 @@ async def test_dict_mutation_detection__plain_list(
token: a Token.
"""
for event_name, expected_delta in event_tuples:
result = await dict_mutation_state._process(
async for result in dict_mutation_state._process(
Event(
token=token,
name=f"{dict_mutation_state.get_name()}.{event_name}",
router_data={"pathname": "/", "query": {}},
payload={},
)
).__anext__()
):
# prefix keys in expected_delta with the state name
expected_delta = {dict_mutation_state.get_name(): expected_delta}
# prefix keys in expected_delta with the state name
expected_delta = {dict_mutation_state.get_name(): expected_delta}
assert result.delta == expected_delta
assert result.delta == expected_delta
@pytest.mark.asyncio

View File

@ -789,17 +789,16 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 0
event = Event(token="t", name="set_num1", payload={"value": 69})
update = await test_state._process(event).__anext__()
async for update in test_state._process(event):
# The event should update the value.
assert test_state.num1 == 69
# The event should update the value.
assert test_state.num1 == 69
# The delta should contain the changes, including computed vars.
assert update.delta == {
TestState.get_full_name(): {"num1": 69, "sum": 72.14},
GrandchildState3.get_full_name(): {"computed": ""},
}
assert update.events == []
# The delta should contain the changes, including computed vars.
assert update.delta == {
TestState.get_full_name(): {"num1": 69, "sum": 72.14},
GrandchildState3.get_full_name(): {"computed": ""},
}
assert update.events == []
@pytest.mark.asyncio
@ -819,15 +818,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name=f"{ChildState.get_name()}.change_both",
payload={"value": "hi", "count": 12},
)
update = await test_state._process(event).__anext__()
assert child_state.value == "HI"
assert child_state.count == 24
assert update.delta == {
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
ChildState.get_full_name(): {"value": "HI", "count": 24},
GrandchildState3.get_full_name(): {"computed": ""},
}
test_state._clean()
async for update in test_state._process(event):
assert child_state.value == "HI"
assert child_state.count == 24
assert update.delta == {
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
ChildState.get_full_name(): {"value": "HI", "count": 24},
GrandchildState3.get_full_name(): {"computed": ""},
}
test_state._clean()
# Test with the granchild state.
assert grandchild_state.value2 == ""
@ -836,13 +835,13 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name=f"{GrandchildState.get_full_name()}.set_value2",
payload={"value": "new"},
)
update = await test_state._process(event).__anext__()
assert grandchild_state.value2 == "new"
assert update.delta == {
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
GrandchildState.get_full_name(): {"value2": "new"},
GrandchildState3.get_full_name(): {"computed": ""},
}
async for update in test_state._process(event):
assert grandchild_state.value2 == "new"
assert update.delta == {
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
GrandchildState.get_full_name(): {"value2": "new"},
GrandchildState3.get_full_name(): {"computed": ""},
}
@pytest.mark.asyncio
@ -2909,10 +2908,10 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
events = updates[0].events
assert len(events) == 2
assert (await state._process(events[0]).__anext__()).delta == {
test_state.get_full_name(): {"num": 1}
}
assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
async for update in state._process(events[0]):
assert update.delta == {test_state.get_full_name(): {"num": 1}}
async for update in state._process(events[1]):
assert update.delta == exp_is_hydrated(state)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@ -2957,13 +2956,12 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
events = updates[0].events
assert len(events) == 3
assert (await state._process(events[0]).__anext__()).delta == {
OnLoadState.get_full_name(): {"num": 1}
}
assert (await state._process(events[1]).__anext__()).delta == {
OnLoadState.get_full_name(): {"num": 2}
}
assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
async for update in state._process(events[0]):
assert update.delta == {OnLoadState.get_full_name(): {"num": 1}}
async for update in state._process(events[1]):
assert update.delta == {OnLoadState.get_full_name(): {"num": 2}}
async for update in state._process(events[2]):
assert update.delta == exp_is_hydrated(state)
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()

View File

@ -1,6 +1,5 @@
import json
import math
import sys
import typing
from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
@ -422,19 +421,13 @@ class Bar(rx.Base):
@pytest.mark.parametrize(
("var", "var_type"),
(
[
(Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
(Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
]
if sys.version_info >= (3, 10)
else []
)
+ [
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
[
(Var(_js_expr="").to(Foo | Bar), Foo | Bar),
(Var(_js_expr="").to(Foo | Bar).bar, Union[int, str]),
(Var(_js_expr="").to(Union[Foo, Bar]), Union[Foo, Bar]),
(Var(_js_expr="").to(Union[Foo, Bar]).baz, str),
(
Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
Var(_js_expr="").to(Union[Foo, Bar]).foo,
Union[int, None],
),
],
@ -1358,7 +1351,7 @@ def test_unsupported_types_for_contains(var: Var):
var: The base var.
"""
with pytest.raises(TypeError) as err:
assert var.contains(1)
assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue]
assert (
err.value.args[0]
== f"Var of type {var._var_type} does not support contains check."
@ -1388,7 +1381,7 @@ def test_unsupported_types_for_string_contains(other):
def test_unsupported_default_contains():
with pytest.raises(TypeError) as err:
assert 1 in Var(_js_expr="var", _var_type=str).guess_type()
assert 1 in Var(_js_expr="var", _var_type=str).guess_type() # pyright: ignore [reportOperatorIssue]
assert (
err.value.args[0]
== "'in' operator not supported for Var types, use Var.contains() instead."