improve var base typing (#4718)
* improve var base typing * fix pyi * dang it darglint * drain _process in tests * fixes #4576 * dang it darglint
This commit is contained in:
parent
12a42b6c47
commit
8663dbcb97
@ -11,10 +11,11 @@ from reflex.event import EventHandler, set_clipboard
|
||||
from reflex.state import FrontendEventExceptionState
|
||||
from reflex.vars.base import Var
|
||||
from reflex.vars.function import ArgsFunctionOperation
|
||||
from reflex.vars.object import ObjectVar
|
||||
|
||||
|
||||
def on_error_spec(
|
||||
error: Var[Dict[str, str]], info: Var[Dict[str, str]]
|
||||
error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
|
||||
) -> Tuple[Var[str], Var[str]]:
|
||||
"""The spec for the on_error event handler.
|
||||
|
||||
|
@ -9,9 +9,10 @@ from reflex.components.component import Component
|
||||
from reflex.event import BASE_STATE, EventType
|
||||
from reflex.style import Style
|
||||
from reflex.vars.base import Var
|
||||
from reflex.vars.object import ObjectVar
|
||||
|
||||
def on_error_spec(
|
||||
error: Var[Dict[str, str]], info: Var[Dict[str, str]]
|
||||
error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
|
||||
) -> Tuple[Var[str], Var[str]]: ...
|
||||
|
||||
class ErrorBoundary(Component):
|
||||
|
@ -2457,6 +2457,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
slots=True,
|
||||
)
|
||||
class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
|
||||
"""A Var that represents a Component."""
|
||||
|
@ -11,6 +11,7 @@ from reflex.components.component import Component
|
||||
from reflex.components.tags import IterTag
|
||||
from reflex.constants import MemoizationMode
|
||||
from reflex.state import ComponentState
|
||||
from reflex.utils.exceptions import UntypedVarError
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
|
||||
|
||||
@ -51,6 +52,7 @@ class Foreach(Component):
|
||||
Raises:
|
||||
ForeachVarError: If the iterable is of type Any.
|
||||
TypeError: If the render function is a ComponentState.
|
||||
UntypedVarError: If the iterable is of type Any without a type annotation.
|
||||
"""
|
||||
iterable = LiteralVar.create(iterable)
|
||||
if iterable._var_type == Any:
|
||||
@ -72,8 +74,14 @@ class Foreach(Component):
|
||||
iterable=iterable,
|
||||
render_fn=render_fn,
|
||||
)
|
||||
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
|
||||
component.children = [component._render().render_component()]
|
||||
try:
|
||||
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
|
||||
component.children = [component._render().render_component()]
|
||||
except UntypedVarError as e:
|
||||
raise UntypedVarError(
|
||||
f"Could not foreach over var `{iterable!s}` without a type annotation. "
|
||||
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
|
||||
) from e
|
||||
return component
|
||||
|
||||
def _render(self) -> IterTag:
|
||||
|
@ -387,7 +387,8 @@ class DataEditor(NoSSRComponent):
|
||||
raise ValueError(
|
||||
"DataEditor data must be an ArrayVar if rows is not provided."
|
||||
)
|
||||
props["rows"] = data.length() if isinstance(data, Var) else len(data)
|
||||
|
||||
props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data)
|
||||
|
||||
if not isinstance(columns, Var) and len(columns):
|
||||
if types.is_dataframe(type(data)) or (
|
||||
|
@ -621,18 +621,22 @@ class ShikiCodeBlock(Component, MarkdownComponentMap):
|
||||
|
||||
Returns:
|
||||
Imports for the component.
|
||||
|
||||
Raises:
|
||||
ValueError: If the transformers are not of type LiteralVar.
|
||||
"""
|
||||
imports = defaultdict(list)
|
||||
if not isinstance(self.transformers, LiteralVar):
|
||||
raise ValueError(
|
||||
f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead."
|
||||
)
|
||||
for transformer in self.transformers._var_value:
|
||||
if isinstance(transformer, ShikiBaseTransformers):
|
||||
imports[transformer.library].extend(
|
||||
[ImportVar(tag=str(fn)) for fn in transformer.fns]
|
||||
)
|
||||
(
|
||||
if transformer.library not in self.lib_dependencies:
|
||||
self.lib_dependencies.append(transformer.library)
|
||||
if transformer.library not in self.lib_dependencies
|
||||
else None
|
||||
)
|
||||
return imports
|
||||
|
||||
@classmethod
|
||||
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import types
|
||||
import urllib.parse
|
||||
from base64 import b64encode
|
||||
@ -541,7 +540,7 @@ class JavasciptKeyboardEvent:
|
||||
shiftKey: bool = False # noqa: N815
|
||||
|
||||
|
||||
def input_event(e: Var[JavascriptInputEvent]) -> Tuple[Var[str]]:
|
||||
def input_event(e: ObjectVar[JavascriptInputEvent]) -> Tuple[Var[str]]:
|
||||
"""Get the value from an input event.
|
||||
|
||||
Args:
|
||||
@ -562,7 +561,9 @@ class KeyInputInfo(TypedDict):
|
||||
shift_key: bool
|
||||
|
||||
|
||||
def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInfo]]:
|
||||
def key_event(
|
||||
e: ObjectVar[JavasciptKeyboardEvent],
|
||||
) -> Tuple[Var[str], Var[KeyInputInfo]]:
|
||||
"""Get the key from a keyboard event.
|
||||
|
||||
Args:
|
||||
@ -572,7 +573,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
|
||||
The key from the keyboard event.
|
||||
"""
|
||||
return (
|
||||
e.key,
|
||||
e.key.to(str),
|
||||
Var.create(
|
||||
{
|
||||
"alt_key": e.altKey,
|
||||
@ -580,7 +581,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
|
||||
"meta_key": e.metaKey,
|
||||
"shift_key": e.shiftKey,
|
||||
},
|
||||
),
|
||||
).to(KeyInputInfo),
|
||||
)
|
||||
|
||||
|
||||
@ -1354,7 +1355,7 @@ def unwrap_var_annotation(annotation: GenericType):
|
||||
Returns:
|
||||
The unwrapped annotation.
|
||||
"""
|
||||
if get_origin(annotation) is Var and (args := get_args(annotation)):
|
||||
if get_origin(annotation) in (Var, ObjectVar) and (args := get_args(annotation)):
|
||||
return args[0]
|
||||
return annotation
|
||||
|
||||
@ -1620,7 +1621,7 @@ class EventVar(ObjectVar, python_types=EventSpec):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
|
||||
"""A literal event var."""
|
||||
@ -1681,7 +1682,7 @@ class EventChainVar(BuilderFunctionVar, python_types=EventChain):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
# Note: LiteralVar is second in the inheritance list allowing it act like a
|
||||
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
|
||||
@ -1713,6 +1714,9 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
|
||||
|
||||
Returns:
|
||||
The created LiteralEventChainVar instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the invocation is not a FunctionVar.
|
||||
"""
|
||||
arg_spec = (
|
||||
value.args_spec[0]
|
||||
@ -1740,6 +1744,11 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
|
||||
else:
|
||||
invocation = value.invocation
|
||||
|
||||
if invocation is not None and not isinstance(invocation, FunctionVar):
|
||||
raise ValueError(
|
||||
f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}."
|
||||
)
|
||||
|
||||
return cls(
|
||||
_js_expr="",
|
||||
_var_type=EventChain,
|
||||
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from reflex import constants
|
||||
@ -49,7 +48,7 @@ def _client_state_ref_dict(var_name: str) -> str:
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ClientStateVar(Var):
|
||||
"""A Var that exists on the client via useState."""
|
||||
|
@ -1637,9 +1637,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
||||
if not isinstance(var, Var):
|
||||
return var
|
||||
|
||||
unset = object()
|
||||
|
||||
# Fast case: this is a literal var and the value is known.
|
||||
if hasattr(var, "_var_value"):
|
||||
return var._var_value
|
||||
if (var_value := getattr(var, "_var_value", unset)) is not unset:
|
||||
return var_value # pyright: ignore [reportReturnType]
|
||||
|
||||
var_data = var._get_all_var_data()
|
||||
if var_data is None or not var_data.state:
|
||||
|
@ -75,6 +75,10 @@ class VarAttributeError(ReflexError, AttributeError):
|
||||
"""Custom AttributeError for var related errors."""
|
||||
|
||||
|
||||
class UntypedVarError(ReflexError, TypeError):
|
||||
"""Custom TypeError for untyped var errors."""
|
||||
|
||||
|
||||
class UntypedComputedVarError(ReflexError, TypeError):
|
||||
"""Custom TypeError for untyped computed var errors."""
|
||||
|
||||
|
@ -12,7 +12,6 @@ import json
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
import warnings
|
||||
from types import CodeType, FunctionType
|
||||
from typing import (
|
||||
@ -82,6 +81,7 @@ if TYPE_CHECKING:
|
||||
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
|
||||
OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
|
||||
STRING_T = TypeVar("STRING_T", bound=str)
|
||||
SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence)
|
||||
|
||||
warnings.filterwarnings("ignore", message="fields may not start with an underscore")
|
||||
|
||||
@ -449,7 +449,7 @@ class Var(Generic[VAR_TYPE]):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ToVarOperation(ToOperation, cls):
|
||||
"""Base class of converting a var to another var type."""
|
||||
@ -597,7 +597,7 @@ class Var(Generic[VAR_TYPE]):
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def create(
|
||||
def create( # pyright: ignore [reportOverlappingOverload]
|
||||
cls,
|
||||
value: STRING_T,
|
||||
_var_data: VarData | None = None,
|
||||
@ -611,6 +611,22 @@ class Var(Generic[VAR_TYPE]):
|
||||
_var_data: VarData | None = None,
|
||||
) -> NoneVar: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
value: MAPPING_TYPE,
|
||||
_var_data: VarData | None = None,
|
||||
) -> ObjectVar[MAPPING_TYPE]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
value: SEQUENCE_TYPE,
|
||||
_var_data: VarData | None = None,
|
||||
) -> ArrayVar[SEQUENCE_TYPE]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def create(
|
||||
@ -692,8 +708,8 @@ class Var(Generic[VAR_TYPE]):
|
||||
@overload
|
||||
def to(
|
||||
self,
|
||||
output: type[Mapping],
|
||||
) -> ObjectVar[Mapping]: ...
|
||||
output: type[MAPPING_TYPE],
|
||||
) -> ObjectVar[MAPPING_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def to(
|
||||
@ -744,7 +760,7 @@ class Var(Generic[VAR_TYPE]):
|
||||
return get_to_operation(NoneVar).create(self) # pyright: ignore [reportReturnType]
|
||||
|
||||
# Handle fixed_output_type being Base or a dataclass.
|
||||
if can_use_in_object_var(fixed_output_type):
|
||||
if can_use_in_object_var(output):
|
||||
return self.to(ObjectVar, output)
|
||||
|
||||
if inspect.isclass(output):
|
||||
@ -776,6 +792,9 @@ class Var(Generic[VAR_TYPE]):
|
||||
|
||||
return self
|
||||
|
||||
@overload
|
||||
def guess_type(self: Var[NoReturn]) -> Var[Any]: ... # pyright: ignore [reportOverlappingOverload]
|
||||
|
||||
@overload
|
||||
def guess_type(self: Var[str]) -> StringVar: ...
|
||||
|
||||
@ -785,6 +804,9 @@ class Var(Generic[VAR_TYPE]):
|
||||
@overload
|
||||
def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def guess_type(self: Var[BASE_TYPE]) -> ObjectVar[BASE_TYPE]: ...
|
||||
|
||||
@overload
|
||||
def guess_type(self) -> Self: ...
|
||||
|
||||
@ -933,7 +955,7 @@ class Var(Generic[VAR_TYPE]):
|
||||
|
||||
return setter
|
||||
|
||||
def _var_set_state(self, state: type[BaseState] | str):
|
||||
def _var_set_state(self, state: type[BaseState] | str) -> Self:
|
||||
"""Set the state of the var.
|
||||
|
||||
Args:
|
||||
@ -948,7 +970,7 @@ class Var(Generic[VAR_TYPE]):
|
||||
else format_state_name(state.get_full_name())
|
||||
)
|
||||
|
||||
return StateOperation.create(
|
||||
return StateOperation.create( # pyright: ignore [reportReturnType]
|
||||
formatted_state_name,
|
||||
self,
|
||||
_var_data=VarData.merge(
|
||||
@ -1127,43 +1149,6 @@ class Var(Generic[VAR_TYPE]):
|
||||
"""
|
||||
return self
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Returns:
|
||||
The attribute.
|
||||
|
||||
Raises:
|
||||
VarAttributeError: If the attribute does not exist.
|
||||
TypeError: If the var type is Any.
|
||||
"""
|
||||
if name.startswith("_"):
|
||||
return self.__getattribute__(name)
|
||||
|
||||
if name == "contains":
|
||||
raise TypeError(
|
||||
f"Var of type {self._var_type} does not support contains check."
|
||||
)
|
||||
if name == "reverse":
|
||||
raise TypeError("Cannot reverse non-list var.")
|
||||
|
||||
if self._var_type is Any:
|
||||
raise TypeError(
|
||||
f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
|
||||
)
|
||||
|
||||
if name in REPLACED_NAMES:
|
||||
raise VarAttributeError(
|
||||
f"Field {name!r} was renamed to {REPLACED_NAMES[name]!r}"
|
||||
)
|
||||
|
||||
raise VarAttributeError(
|
||||
f"The State var has no attribute '{name}' or may have been annotated wrongly.",
|
||||
)
|
||||
|
||||
def _decode(self) -> Any:
|
||||
"""Decode Var as a python value.
|
||||
|
||||
@ -1225,36 +1210,76 @@ class Var(Generic[VAR_TYPE]):
|
||||
|
||||
return ArrayVar.range(first_endpoint, second_endpoint, step)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Raise exception if using Var in a boolean context.
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
Raises:
|
||||
VarTypeError: when attempting to bool-ify the Var.
|
||||
"""
|
||||
raise VarTypeError(
|
||||
f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
|
||||
"Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
|
||||
)
|
||||
def __getattr__(self, name: str):
|
||||
"""Get an attribute of the var.
|
||||
|
||||
def __iter__(self) -> Any:
|
||||
"""Raise exception if using Var in an iterable context.
|
||||
Args:
|
||||
name: The name of the attribute.
|
||||
|
||||
Raises:
|
||||
VarTypeError: when attempting to iterate over the Var.
|
||||
"""
|
||||
raise VarTypeError(
|
||||
f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
|
||||
)
|
||||
Raises:
|
||||
VarAttributeError: If the attribute does not exist.
|
||||
UntypedVarError: If the var type is Any.
|
||||
TypeError: If the var type is Any.
|
||||
|
||||
def __contains__(self, _: Any) -> Var:
|
||||
"""Override the 'in' operator to alert the user that it is not supported.
|
||||
# noqa: DAR101 self
|
||||
"""
|
||||
if name.startswith("_"):
|
||||
raise VarAttributeError(f"Attribute {name} not found.")
|
||||
|
||||
Raises:
|
||||
VarTypeError: the operation is not supported
|
||||
"""
|
||||
raise VarTypeError(
|
||||
"'in' operator not supported for Var types, use Var.contains() instead."
|
||||
)
|
||||
if name == "contains":
|
||||
raise TypeError(
|
||||
f"Var of type {self._var_type} does not support contains check."
|
||||
)
|
||||
if name == "reverse":
|
||||
raise TypeError("Cannot reverse non-list var.")
|
||||
|
||||
if self._var_type is Any:
|
||||
raise exceptions.UntypedVarError(
|
||||
f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
|
||||
)
|
||||
|
||||
raise VarAttributeError(
|
||||
f"The State var has no attribute '{name}' or may have been annotated wrongly.",
|
||||
)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Raise exception if using Var in a boolean context.
|
||||
|
||||
Raises:
|
||||
VarTypeError: when attempting to bool-ify the Var.
|
||||
|
||||
# noqa: DAR101 self
|
||||
"""
|
||||
raise VarTypeError(
|
||||
f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
|
||||
"Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
|
||||
)
|
||||
|
||||
def __iter__(self) -> Any:
|
||||
"""Raise exception if using Var in an iterable context.
|
||||
|
||||
Raises:
|
||||
VarTypeError: when attempting to iterate over the Var.
|
||||
|
||||
# noqa: DAR101 self
|
||||
"""
|
||||
raise VarTypeError(
|
||||
f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
|
||||
)
|
||||
|
||||
def __contains__(self, _: Any) -> Var:
|
||||
"""Override the 'in' operator to alert the user that it is not supported.
|
||||
|
||||
Raises:
|
||||
VarTypeError: the operation is not supported
|
||||
|
||||
# noqa: DAR101 self
|
||||
"""
|
||||
raise VarTypeError(
|
||||
"'in' operator not supported for Var types, use Var.contains() instead."
|
||||
)
|
||||
|
||||
|
||||
OUTPUT = TypeVar("OUTPUT", bound=Var)
|
||||
@ -1471,6 +1496,12 @@ class LiteralVar(Var):
|
||||
def __post_init__(self):
|
||||
"""Post-initialize the var."""
|
||||
|
||||
@property
|
||||
def _var_value(self) -> Any:
|
||||
raise NotImplementedError(
|
||||
"LiteralVar subclasses must implement the _var_value property."
|
||||
)
|
||||
|
||||
def json(self) -> str:
|
||||
"""Serialize the var to a JSON string.
|
||||
|
||||
@ -1543,7 +1574,7 @@ def var_operation(
|
||||
) -> Callable[P, StringVar]: ...
|
||||
|
||||
|
||||
LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set])
|
||||
LIST_T = TypeVar("LIST_T", bound=Sequence)
|
||||
|
||||
|
||||
@overload
|
||||
@ -1780,7 +1811,7 @@ def _or_operation(a: Var, b: Var):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class CallableVar(Var):
|
||||
"""Decorate a Var-returning function to act as both a Var and a function.
|
||||
@ -1861,7 +1892,7 @@ def is_computed_var(obj: Any) -> TypeGuard[ComputedVar]:
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ComputedVar(Var[RETURN_TYPE]):
|
||||
"""A field with computed getters."""
|
||||
@ -2070,13 +2101,6 @@ class ComputedVar(Var[RETURN_TYPE]):
|
||||
owner: Type,
|
||||
) -> ArrayVar[list[LIST_INSIDE]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ComputedVar[set[LIST_INSIDE]],
|
||||
instance: None,
|
||||
owner: Type,
|
||||
) -> ArrayVar[set[LIST_INSIDE]]: ...
|
||||
|
||||
@overload
|
||||
def __get__(
|
||||
self: ComputedVar[tuple[LIST_INSIDE, ...]],
|
||||
@ -2436,7 +2460,7 @@ def var_operation_return(
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class CustomVarOperation(CachedVarOperation, Var[T]):
|
||||
"""Base class for custom var operations."""
|
||||
@ -2507,7 +2531,7 @@ class NoneVar(Var[None], python_types=type(None)):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralNoneVar(LiteralVar, NoneVar):
|
||||
"""A var representing None."""
|
||||
@ -2569,7 +2593,7 @@ def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]:
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class StateOperation(CachedVarOperation, Var):
|
||||
"""A var operation that accesses a field on an object."""
|
||||
@ -2716,19 +2740,6 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]:
|
||||
return var_datas
|
||||
|
||||
|
||||
# These names were changed in reflex 0.3.0
|
||||
REPLACED_NAMES = {
|
||||
"full_name": "_var_full_name",
|
||||
"name": "_js_expr",
|
||||
"state": "_var_data.state",
|
||||
"type_": "_var_type",
|
||||
"is_local": "_var_is_local",
|
||||
"is_string": "_var_is_string",
|
||||
"set_state": "_var_set_state",
|
||||
"deps": "_deps",
|
||||
}
|
||||
|
||||
|
||||
dispatchers: Dict[GenericType, Callable[[Var], Var]] = {}
|
||||
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
from datetime import date, datetime
|
||||
from typing import Any, NoReturn, TypeVar, Union, overload
|
||||
|
||||
@ -193,7 +192,7 @@ def date_compare_operation(
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralDatetimeVar(LiteralVar, DateTimeVar):
|
||||
"""Base class for immutable datetime and date vars."""
|
||||
|
@ -226,7 +226,7 @@ class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
|
||||
"""Base class for immutable vars that are the result of a function call."""
|
||||
@ -350,7 +350,7 @@ def format_args_function_operation(
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
||||
"""Base class for immutable function defined via arguments and return expression."""
|
||||
@ -407,7 +407,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
|
||||
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""
|
||||
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -160,7 +159,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
from .sequence import ArrayVar, LiteralArrayVar
|
||||
|
||||
if isinstance(other, (list, tuple, set, ArrayVar)):
|
||||
if isinstance(other, (list, tuple, ArrayVar)):
|
||||
if isinstance(other, ArrayVar):
|
||||
return other * self
|
||||
return LiteralArrayVar.create(other) * self
|
||||
@ -187,7 +186,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
|
||||
"""
|
||||
from .sequence import ArrayVar, LiteralArrayVar
|
||||
|
||||
if isinstance(other, (list, tuple, set, ArrayVar)):
|
||||
if isinstance(other, (list, tuple, ArrayVar)):
|
||||
if isinstance(other, ArrayVar):
|
||||
return other * self
|
||||
return LiteralArrayVar.create(other) * self
|
||||
@ -973,7 +972,7 @@ def boolean_not_operation(value: BooleanVar):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralNumberVar(LiteralVar, NumberVar):
|
||||
"""Base class for immutable literal number vars."""
|
||||
@ -1032,7 +1031,7 @@ class LiteralNumberVar(LiteralVar, NumberVar):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralBooleanVar(LiteralVar, BooleanVar):
|
||||
"""Base class for immutable literal boolean vars."""
|
||||
|
@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import typing
|
||||
from inspect import isclass
|
||||
from typing import (
|
||||
@ -167,12 +166,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
key: Var | Any,
|
||||
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
|
||||
key: Var | Any,
|
||||
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
|
||||
@ -229,12 +222,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
name: str,
|
||||
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
|
||||
name: str,
|
||||
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
|
||||
|
||||
@overload
|
||||
def __getattr__(
|
||||
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
|
||||
@ -305,7 +292,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
|
||||
"""Base class for immutable literal object vars."""
|
||||
@ -355,17 +342,20 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
|
||||
|
||||
Returns:
|
||||
The JSON representation of the object.
|
||||
|
||||
Raises:
|
||||
TypeError: The keys and values of the object must be literal vars to get the JSON representation
|
||||
"""
|
||||
return (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
[
|
||||
f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}"
|
||||
for key, value in self._var_value.items()
|
||||
]
|
||||
)
|
||||
+ "}"
|
||||
)
|
||||
keys_and_values = []
|
||||
for key, value in self._var_value.items():
|
||||
key = LiteralVar.create(key)
|
||||
value = LiteralVar.create(value)
|
||||
if not isinstance(key, LiteralVar) or not isinstance(value, LiteralVar):
|
||||
raise TypeError(
|
||||
"The keys and values of the object must be literal vars to get the JSON representation."
|
||||
)
|
||||
keys_and_values.append(f"{key.json()}:{value.json()}")
|
||||
return "{" + ", ".join(keys_and_values) + "}"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Get the hash of the var.
|
||||
@ -487,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ObjectItemOperation(CachedVarOperation, Var):
|
||||
"""Operation to get an item from an object."""
|
||||
|
@ -6,7 +6,6 @@ import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import typing
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -15,7 +14,7 @@ from typing import (
|
||||
List,
|
||||
Literal,
|
||||
NoReturn,
|
||||
Set,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
@ -596,7 +595,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralStringVar(LiteralVar, StringVar[str]):
|
||||
"""Base class for immutable literal string vars."""
|
||||
@ -718,7 +717,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ConcatVarOperation(CachedVarOperation, StringVar[str]):
|
||||
"""Representing a concatenation of literal string vars."""
|
||||
@ -794,7 +793,8 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]):
|
||||
)
|
||||
|
||||
|
||||
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set])
|
||||
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Sequence, covariant=True)
|
||||
OTHER_ARRAY_VAR_TYPE = TypeVar("OTHER_ARRAY_VAR_TYPE", bound=Sequence)
|
||||
|
||||
OTHER_TUPLE = TypeVar("OTHER_TUPLE")
|
||||
|
||||
@ -887,6 +887,11 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
|
||||
i: Literal[0, -2],
|
||||
) -> NumberVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
|
||||
) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: (
|
||||
@ -914,7 +919,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
|
||||
) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
@ -932,23 +937,12 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
|
||||
) -> StringVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
|
||||
) -> BooleanVar: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
|
||||
i: int | NumberVar,
|
||||
) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
|
||||
i: int | NumberVar,
|
||||
) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(
|
||||
self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]],
|
||||
@ -1239,26 +1233,18 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
|
||||
|
||||
LIST_ELEMENT = TypeVar("LIST_ELEMENT")
|
||||
|
||||
ARRAY_VAR_OF_LIST_ELEMENT = Union[
|
||||
ArrayVar[List[LIST_ELEMENT]],
|
||||
ArrayVar[Set[LIST_ELEMENT]],
|
||||
ArrayVar[Tuple[LIST_ELEMENT, ...]],
|
||||
]
|
||||
ARRAY_VAR_OF_LIST_ELEMENT = ArrayVar[Sequence[LIST_ELEMENT]]
|
||||
|
||||
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
|
||||
"""Base class for immutable literal array vars."""
|
||||
|
||||
_var_value: Union[
|
||||
List[Union[Var, Any]],
|
||||
Set[Union[Var, Any]],
|
||||
Tuple[Union[Var, Any], ...],
|
||||
] = dataclasses.field(default_factory=list)
|
||||
_var_value: Sequence[Union[Var, Any]] = dataclasses.field(default=())
|
||||
|
||||
@cached_property_no_lock
|
||||
def _cached_var_name(self) -> str:
|
||||
@ -1303,22 +1289,28 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
|
||||
|
||||
Returns:
|
||||
The JSON representation of the var.
|
||||
|
||||
Raises:
|
||||
TypeError: If the array elements are not of type LiteralVar.
|
||||
"""
|
||||
return (
|
||||
"["
|
||||
+ ", ".join(
|
||||
[LiteralVar.create(element).json() for element in self._var_value]
|
||||
)
|
||||
+ "]"
|
||||
)
|
||||
elements = []
|
||||
for element in self._var_value:
|
||||
element_var = LiteralVar.create(element)
|
||||
if not isinstance(element_var, LiteralVar):
|
||||
raise TypeError(
|
||||
f"Array elements must be of type LiteralVar, not {type(element_var)}"
|
||||
)
|
||||
elements.append(element_var.json())
|
||||
|
||||
return "[" + ", ".join(elements) + "]"
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
value: ARRAY_VAR_TYPE,
|
||||
_var_type: Type[ARRAY_VAR_TYPE] | None = None,
|
||||
value: OTHER_ARRAY_VAR_TYPE,
|
||||
_var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None,
|
||||
_var_data: VarData | None = None,
|
||||
) -> LiteralArrayVar[ARRAY_VAR_TYPE]:
|
||||
) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]:
|
||||
"""Create a var from a string value.
|
||||
|
||||
Args:
|
||||
@ -1329,7 +1321,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
|
||||
Returns:
|
||||
The var.
|
||||
"""
|
||||
return cls(
|
||||
return LiteralArrayVar(
|
||||
_js_expr="",
|
||||
_var_type=figure_out_type(value) if _var_type is None else _var_type,
|
||||
_var_data=_var_data,
|
||||
@ -1356,7 +1348,7 @@ def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class ArraySliceOperation(CachedVarOperation, ArrayVar):
|
||||
"""Base class for immutable string vars that are the result of a string slice operation."""
|
||||
@ -1705,7 +1697,7 @@ class ColorVar(StringVar[Color], python_types=Color):
|
||||
@dataclasses.dataclass(
|
||||
eq=False,
|
||||
frozen=True,
|
||||
**{"slots": True} if sys.version_info >= (3, 10) else {},
|
||||
slots=True,
|
||||
)
|
||||
class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar):
|
||||
"""Base class for immutable literal color vars."""
|
||||
|
@ -3,6 +3,7 @@ from typing import List, Mapping, Tuple
|
||||
import pytest
|
||||
|
||||
import reflex as rx
|
||||
from reflex.components.component import Component
|
||||
from reflex.components.core.match import Match
|
||||
from reflex.state import BaseState
|
||||
from reflex.utils.exceptions import MatchTypeError
|
||||
@ -29,6 +30,8 @@ def test_match_components():
|
||||
rx.text("default value"),
|
||||
)
|
||||
match_comp = Match.create(MatchState.value, *match_case_tuples)
|
||||
|
||||
assert isinstance(match_comp, Component)
|
||||
match_dict = match_comp.render()
|
||||
assert match_dict["name"] == "Fragment"
|
||||
|
||||
@ -151,6 +154,7 @@ def test_match_on_component_without_default():
|
||||
)
|
||||
|
||||
match_comp = Match.create(MatchState.value, *match_case_tuples)
|
||||
assert isinstance(match_comp, Component)
|
||||
default = match_comp.render()["children"][0]["default"]
|
||||
|
||||
assert isinstance(default, dict) and default["name"] == Fragment.__name__
|
||||
|
@ -36,6 +36,7 @@ from reflex.utils.exceptions import (
|
||||
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
from reflex.vars.object import ObjectVar
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -842,12 +843,12 @@ def test_component_event_trigger_arbitrary_args():
|
||||
"""Test that we can define arbitrary types for the args of an event trigger."""
|
||||
|
||||
def on_foo_spec(
|
||||
_e: Var[JavascriptInputEvent],
|
||||
_e: ObjectVar[JavascriptInputEvent],
|
||||
alpha: Var[str],
|
||||
bravo: dict[str, Any],
|
||||
charlie: Var[_Obj],
|
||||
charlie: ObjectVar[_Obj],
|
||||
):
|
||||
return [_e.target.value, bravo["nested"], charlie.custom + 42]
|
||||
return [_e.target.value, bravo["nested"], charlie.custom.to(int) + 42]
|
||||
|
||||
class C1(Component):
|
||||
library = "/local"
|
||||
@ -1328,7 +1329,7 @@ class EventState(rx.State):
|
||||
),
|
||||
pytest.param(
|
||||
rx.fragment(class_name=[TEST_VAR, "other-class"]),
|
||||
[LiteralVar.create([TEST_VAR, "other-class"]).join(" ")],
|
||||
[Var.create([TEST_VAR, "other-class"]).join(" ")],
|
||||
id="fstring-dual-class_name",
|
||||
),
|
||||
pytest.param(
|
||||
|
@ -471,15 +471,15 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
|
||||
"""
|
||||
state = test_state() # pyright: ignore [reportCallIssue]
|
||||
state.add_var("int_val", int, 0)
|
||||
result = await state._process(
|
||||
async for result in state._process(
|
||||
Event(
|
||||
token=token,
|
||||
name=f"{test_state.get_name()}.set_int_val",
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={"value": 50},
|
||||
)
|
||||
).__anext__()
|
||||
assert result.delta == {test_state.get_name(): {"int_val": 50}}
|
||||
):
|
||||
assert result.delta == {test_state.get_name(): {"int_val": 50}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -583,18 +583,17 @@ async def test_list_mutation_detection__plain_list(
|
||||
token: a Token.
|
||||
"""
|
||||
for event_name, expected_delta in event_tuples:
|
||||
result = await list_mutation_state._process(
|
||||
async for result in list_mutation_state._process(
|
||||
Event(
|
||||
token=token,
|
||||
name=f"{list_mutation_state.get_name()}.{event_name}",
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={},
|
||||
)
|
||||
).__anext__()
|
||||
|
||||
# prefix keys in expected_delta with the state name
|
||||
expected_delta = {list_mutation_state.get_name(): expected_delta}
|
||||
assert result.delta == expected_delta
|
||||
):
|
||||
# prefix keys in expected_delta with the state name
|
||||
expected_delta = {list_mutation_state.get_name(): expected_delta}
|
||||
assert result.delta == expected_delta
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -709,19 +708,18 @@ async def test_dict_mutation_detection__plain_list(
|
||||
token: a Token.
|
||||
"""
|
||||
for event_name, expected_delta in event_tuples:
|
||||
result = await dict_mutation_state._process(
|
||||
async for result in dict_mutation_state._process(
|
||||
Event(
|
||||
token=token,
|
||||
name=f"{dict_mutation_state.get_name()}.{event_name}",
|
||||
router_data={"pathname": "/", "query": {}},
|
||||
payload={},
|
||||
)
|
||||
).__anext__()
|
||||
):
|
||||
# prefix keys in expected_delta with the state name
|
||||
expected_delta = {dict_mutation_state.get_name(): expected_delta}
|
||||
|
||||
# prefix keys in expected_delta with the state name
|
||||
expected_delta = {dict_mutation_state.get_name(): expected_delta}
|
||||
|
||||
assert result.delta == expected_delta
|
||||
assert result.delta == expected_delta
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -789,17 +789,16 @@ async def test_process_event_simple(test_state):
|
||||
assert test_state.num1 == 0
|
||||
|
||||
event = Event(token="t", name="set_num1", payload={"value": 69})
|
||||
update = await test_state._process(event).__anext__()
|
||||
async for update in test_state._process(event):
|
||||
# The event should update the value.
|
||||
assert test_state.num1 == 69
|
||||
|
||||
# The event should update the value.
|
||||
assert test_state.num1 == 69
|
||||
|
||||
# The delta should contain the changes, including computed vars.
|
||||
assert update.delta == {
|
||||
TestState.get_full_name(): {"num1": 69, "sum": 72.14},
|
||||
GrandchildState3.get_full_name(): {"computed": ""},
|
||||
}
|
||||
assert update.events == []
|
||||
# The delta should contain the changes, including computed vars.
|
||||
assert update.delta == {
|
||||
TestState.get_full_name(): {"num1": 69, "sum": 72.14},
|
||||
GrandchildState3.get_full_name(): {"computed": ""},
|
||||
}
|
||||
assert update.events == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -819,15 +818,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
name=f"{ChildState.get_name()}.change_both",
|
||||
payload={"value": "hi", "count": 12},
|
||||
)
|
||||
update = await test_state._process(event).__anext__()
|
||||
assert child_state.value == "HI"
|
||||
assert child_state.count == 24
|
||||
assert update.delta == {
|
||||
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
|
||||
ChildState.get_full_name(): {"value": "HI", "count": 24},
|
||||
GrandchildState3.get_full_name(): {"computed": ""},
|
||||
}
|
||||
test_state._clean()
|
||||
async for update in test_state._process(event):
|
||||
assert child_state.value == "HI"
|
||||
assert child_state.count == 24
|
||||
assert update.delta == {
|
||||
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
|
||||
ChildState.get_full_name(): {"value": "HI", "count": 24},
|
||||
GrandchildState3.get_full_name(): {"computed": ""},
|
||||
}
|
||||
test_state._clean()
|
||||
|
||||
# Test with the granchild state.
|
||||
assert grandchild_state.value2 == ""
|
||||
@ -836,13 +835,13 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
|
||||
name=f"{GrandchildState.get_full_name()}.set_value2",
|
||||
payload={"value": "new"},
|
||||
)
|
||||
update = await test_state._process(event).__anext__()
|
||||
assert grandchild_state.value2 == "new"
|
||||
assert update.delta == {
|
||||
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
|
||||
GrandchildState.get_full_name(): {"value2": "new"},
|
||||
GrandchildState3.get_full_name(): {"computed": ""},
|
||||
}
|
||||
async for update in test_state._process(event):
|
||||
assert grandchild_state.value2 == "new"
|
||||
assert update.delta == {
|
||||
# TestState.get_full_name(): {"sum": 3.14, "upper": ""},
|
||||
GrandchildState.get_full_name(): {"value2": "new"},
|
||||
GrandchildState3.get_full_name(): {"computed": ""},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -2909,10 +2908,10 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
|
||||
|
||||
events = updates[0].events
|
||||
assert len(events) == 2
|
||||
assert (await state._process(events[0]).__anext__()).delta == {
|
||||
test_state.get_full_name(): {"num": 1}
|
||||
}
|
||||
assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state)
|
||||
async for update in state._process(events[0]):
|
||||
assert update.delta == {test_state.get_full_name(): {"num": 1}}
|
||||
async for update in state._process(events[1]):
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.close()
|
||||
@ -2957,13 +2956,12 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
|
||||
|
||||
events = updates[0].events
|
||||
assert len(events) == 3
|
||||
assert (await state._process(events[0]).__anext__()).delta == {
|
||||
OnLoadState.get_full_name(): {"num": 1}
|
||||
}
|
||||
assert (await state._process(events[1]).__anext__()).delta == {
|
||||
OnLoadState.get_full_name(): {"num": 2}
|
||||
}
|
||||
assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
|
||||
async for update in state._process(events[0]):
|
||||
assert update.delta == {OnLoadState.get_full_name(): {"num": 1}}
|
||||
async for update in state._process(events[1]):
|
||||
assert update.delta == {OnLoadState.get_full_name(): {"num": 2}}
|
||||
async for update in state._process(events[2]):
|
||||
assert update.delta == exp_is_hydrated(state)
|
||||
|
||||
if isinstance(app.state_manager, StateManagerRedis):
|
||||
await app.state_manager.close()
|
||||
|
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
|
||||
|
||||
@ -422,19 +421,13 @@ class Bar(rx.Base):
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("var", "var_type"),
|
||||
(
|
||||
[
|
||||
(Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar),
|
||||
(Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]),
|
||||
]
|
||||
if sys.version_info >= (3, 10)
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
|
||||
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
|
||||
[
|
||||
(Var(_js_expr="").to(Foo | Bar), Foo | Bar),
|
||||
(Var(_js_expr="").to(Foo | Bar).bar, Union[int, str]),
|
||||
(Var(_js_expr="").to(Union[Foo, Bar]), Union[Foo, Bar]),
|
||||
(Var(_js_expr="").to(Union[Foo, Bar]).baz, str),
|
||||
(
|
||||
Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo,
|
||||
Var(_js_expr="").to(Union[Foo, Bar]).foo,
|
||||
Union[int, None],
|
||||
),
|
||||
],
|
||||
@ -1358,7 +1351,7 @@ def test_unsupported_types_for_contains(var: Var):
|
||||
var: The base var.
|
||||
"""
|
||||
with pytest.raises(TypeError) as err:
|
||||
assert var.contains(1)
|
||||
assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue]
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== f"Var of type {var._var_type} does not support contains check."
|
||||
@ -1388,7 +1381,7 @@ def test_unsupported_types_for_string_contains(other):
|
||||
|
||||
def test_unsupported_default_contains():
|
||||
with pytest.raises(TypeError) as err:
|
||||
assert 1 in Var(_js_expr="var", _var_type=str).guess_type()
|
||||
assert 1 in Var(_js_expr="var", _var_type=str).guess_type() # pyright: ignore [reportOperatorIssue]
|
||||
assert (
|
||||
err.value.args[0]
|
||||
== "'in' operator not supported for Var types, use Var.contains() instead."
|
||||
|
Loading…
Reference in New Issue
Block a user