improve var base typing (#4718)

* improve var base typing

* fix pyi

* dang it darglint

* drain _process in tests

* fixes #4576

* dang it darglint
This commit is contained in:
Khaleel Al-Adhami 2025-01-31 13:12:33 -08:00 committed by GitHub
parent 12a42b6c47
commit 8663dbcb97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 279 additions and 264 deletions

View File

@ -11,10 +11,11 @@ from reflex.event import EventHandler, set_clipboard
from reflex.state import FrontendEventExceptionState from reflex.state import FrontendEventExceptionState
from reflex.vars.base import Var from reflex.vars.base import Var
from reflex.vars.function import ArgsFunctionOperation from reflex.vars.function import ArgsFunctionOperation
from reflex.vars.object import ObjectVar
def on_error_spec( def on_error_spec(
error: Var[Dict[str, str]], info: Var[Dict[str, str]] error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
) -> Tuple[Var[str], Var[str]]: ) -> Tuple[Var[str], Var[str]]:
"""The spec for the on_error event handler. """The spec for the on_error event handler.

View File

@ -9,9 +9,10 @@ from reflex.components.component import Component
from reflex.event import BASE_STATE, EventType from reflex.event import BASE_STATE, EventType
from reflex.style import Style from reflex.style import Style
from reflex.vars.base import Var from reflex.vars.base import Var
from reflex.vars.object import ObjectVar
def on_error_spec( def on_error_spec(
error: Var[Dict[str, str]], info: Var[Dict[str, str]] error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]]
) -> Tuple[Var[str], Var[str]]: ... ) -> Tuple[Var[str], Var[str]]: ...
class ErrorBoundary(Component): class ErrorBoundary(Component):

View File

@ -2457,6 +2457,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) ->
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
slots=True,
) )
class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar):
"""A Var that represents a Component.""" """A Var that represents a Component."""

View File

@ -11,6 +11,7 @@ from reflex.components.component import Component
from reflex.components.tags import IterTag from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode from reflex.constants import MemoizationMode
from reflex.state import ComponentState from reflex.state import ComponentState
from reflex.utils.exceptions import UntypedVarError
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
@ -51,6 +52,7 @@ class Foreach(Component):
Raises: Raises:
ForeachVarError: If the iterable is of type Any. ForeachVarError: If the iterable is of type Any.
TypeError: If the render function is a ComponentState. TypeError: If the render function is a ComponentState.
UntypedVarError: If the iterable is of type Any without a type annotation.
""" """
iterable = LiteralVar.create(iterable) iterable = LiteralVar.create(iterable)
if iterable._var_type == Any: if iterable._var_type == Any:
@ -72,8 +74,14 @@ class Foreach(Component):
iterable=iterable, iterable=iterable,
render_fn=render_fn, render_fn=render_fn,
) )
# Keep a ref to a rendered component to determine correct imports/hooks/styles. try:
component.children = [component._render().render_component()] # Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
except UntypedVarError as e:
raise UntypedVarError(
f"Could not foreach over var `{iterable!s}` without a type annotation. "
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
) from e
return component return component
def _render(self) -> IterTag: def _render(self) -> IterTag:

View File

@ -387,7 +387,8 @@ class DataEditor(NoSSRComponent):
raise ValueError( raise ValueError(
"DataEditor data must be an ArrayVar if rows is not provided." "DataEditor data must be an ArrayVar if rows is not provided."
) )
props["rows"] = data.length() if isinstance(data, Var) else len(data)
props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data)
if not isinstance(columns, Var) and len(columns): if not isinstance(columns, Var) and len(columns):
if types.is_dataframe(type(data)) or ( if types.is_dataframe(type(data)) or (

View File

@ -621,18 +621,22 @@ class ShikiCodeBlock(Component, MarkdownComponentMap):
Returns: Returns:
Imports for the component. Imports for the component.
Raises:
ValueError: If the transformers are not of type LiteralVar.
""" """
imports = defaultdict(list) imports = defaultdict(list)
if not isinstance(self.transformers, LiteralVar):
raise ValueError(
f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead."
)
for transformer in self.transformers._var_value: for transformer in self.transformers._var_value:
if isinstance(transformer, ShikiBaseTransformers): if isinstance(transformer, ShikiBaseTransformers):
imports[transformer.library].extend( imports[transformer.library].extend(
[ImportVar(tag=str(fn)) for fn in transformer.fns] [ImportVar(tag=str(fn)) for fn in transformer.fns]
) )
( if transformer.library not in self.lib_dependencies:
self.lib_dependencies.append(transformer.library) self.lib_dependencies.append(transformer.library)
if transformer.library not in self.lib_dependencies
else None
)
return imports return imports
@classmethod @classmethod

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
import sys
import types import types
import urllib.parse import urllib.parse
from base64 import b64encode from base64 import b64encode
@ -541,7 +540,7 @@ class JavasciptKeyboardEvent:
shiftKey: bool = False # noqa: N815 shiftKey: bool = False # noqa: N815
def input_event(e: Var[JavascriptInputEvent]) -> Tuple[Var[str]]: def input_event(e: ObjectVar[JavascriptInputEvent]) -> Tuple[Var[str]]:
"""Get the value from an input event. """Get the value from an input event.
Args: Args:
@ -562,7 +561,9 @@ class KeyInputInfo(TypedDict):
shift_key: bool shift_key: bool
def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInfo]]: def key_event(
e: ObjectVar[JavasciptKeyboardEvent],
) -> Tuple[Var[str], Var[KeyInputInfo]]:
"""Get the key from a keyboard event. """Get the key from a keyboard event.
Args: Args:
@ -572,7 +573,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
The key from the keyboard event. The key from the keyboard event.
""" """
return ( return (
e.key, e.key.to(str),
Var.create( Var.create(
{ {
"alt_key": e.altKey, "alt_key": e.altKey,
@ -580,7 +581,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf
"meta_key": e.metaKey, "meta_key": e.metaKey,
"shift_key": e.shiftKey, "shift_key": e.shiftKey,
}, },
), ).to(KeyInputInfo),
) )
@ -1354,7 +1355,7 @@ def unwrap_var_annotation(annotation: GenericType):
Returns: Returns:
The unwrapped annotation. The unwrapped annotation.
""" """
if get_origin(annotation) is Var and (args := get_args(annotation)): if get_origin(annotation) in (Var, ObjectVar) and (args := get_args(annotation)):
return args[0] return args[0]
return annotation return annotation
@ -1620,7 +1621,7 @@ class EventVar(ObjectVar, python_types=EventSpec):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralEventVar(VarOperationCall, LiteralVar, EventVar): class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
"""A literal event var.""" """A literal event var."""
@ -1681,7 +1682,7 @@ class EventChainVar(BuilderFunctionVar, python_types=EventChain):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
# Note: LiteralVar is second in the inheritance list allowing it act like a # Note: LiteralVar is second in the inheritance list allowing it act like a
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
@ -1713,6 +1714,9 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
Returns: Returns:
The created LiteralEventChainVar instance. The created LiteralEventChainVar instance.
Raises:
ValueError: If the invocation is not a FunctionVar.
""" """
arg_spec = ( arg_spec = (
value.args_spec[0] value.args_spec[0]
@ -1740,6 +1744,11 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV
else: else:
invocation = value.invocation invocation = value.invocation
if invocation is not None and not isinstance(invocation, FunctionVar):
raise ValueError(
f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}."
)
return cls( return cls(
_js_expr="", _js_expr="",
_var_type=EventChain, _var_type=EventChain,

View File

@ -4,7 +4,6 @@ from __future__ import annotations
import dataclasses import dataclasses
import re import re
import sys
from typing import Any, Callable, Union from typing import Any, Callable, Union
from reflex import constants from reflex import constants
@ -49,7 +48,7 @@ def _client_state_ref_dict(var_name: str) -> str:
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ClientStateVar(Var): class ClientStateVar(Var):
"""A Var that exists on the client via useState.""" """A Var that exists on the client via useState."""

View File

@ -1637,9 +1637,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if not isinstance(var, Var): if not isinstance(var, Var):
return var return var
unset = object()
# Fast case: this is a literal var and the value is known. # Fast case: this is a literal var and the value is known.
if hasattr(var, "_var_value"): if (var_value := getattr(var, "_var_value", unset)) is not unset:
return var._var_value return var_value # pyright: ignore [reportReturnType]
var_data = var._get_all_var_data() var_data = var._get_all_var_data()
if var_data is None or not var_data.state: if var_data is None or not var_data.state:

View File

@ -75,6 +75,10 @@ class VarAttributeError(ReflexError, AttributeError):
"""Custom AttributeError for var related errors.""" """Custom AttributeError for var related errors."""
class UntypedVarError(ReflexError, TypeError):
"""Custom TypeError for untyped var errors."""
class UntypedComputedVarError(ReflexError, TypeError): class UntypedComputedVarError(ReflexError, TypeError):
"""Custom TypeError for untyped computed var errors.""" """Custom TypeError for untyped computed var errors."""

View File

@ -12,7 +12,6 @@ import json
import random import random
import re import re
import string import string
import sys
import warnings import warnings
from types import CodeType, FunctionType from types import CodeType, FunctionType
from typing import ( from typing import (
@ -82,6 +81,7 @@ if TYPE_CHECKING:
VAR_TYPE = TypeVar("VAR_TYPE", covariant=True) VAR_TYPE = TypeVar("VAR_TYPE", covariant=True)
OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE") OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE")
STRING_T = TypeVar("STRING_T", bound=str) STRING_T = TypeVar("STRING_T", bound=str)
SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence)
warnings.filterwarnings("ignore", message="fields may not start with an underscore") warnings.filterwarnings("ignore", message="fields may not start with an underscore")
@ -449,7 +449,7 @@ class Var(Generic[VAR_TYPE]):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ToVarOperation(ToOperation, cls): class ToVarOperation(ToOperation, cls):
"""Base class of converting a var to another var type.""" """Base class of converting a var to another var type."""
@ -597,7 +597,7 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
@classmethod @classmethod
def create( def create( # pyright: ignore [reportOverlappingOverload]
cls, cls,
value: STRING_T, value: STRING_T,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -611,6 +611,22 @@ class Var(Generic[VAR_TYPE]):
_var_data: VarData | None = None, _var_data: VarData | None = None,
) -> NoneVar: ... ) -> NoneVar: ...
@overload
@classmethod
def create(
cls,
value: MAPPING_TYPE,
_var_data: VarData | None = None,
) -> ObjectVar[MAPPING_TYPE]: ...
@overload
@classmethod
def create(
cls,
value: SEQUENCE_TYPE,
_var_data: VarData | None = None,
) -> ArrayVar[SEQUENCE_TYPE]: ...
@overload @overload
@classmethod @classmethod
def create( def create(
@ -692,8 +708,8 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
def to( def to(
self, self,
output: type[Mapping], output: type[MAPPING_TYPE],
) -> ObjectVar[Mapping]: ... ) -> ObjectVar[MAPPING_TYPE]: ...
@overload @overload
def to( def to(
@ -744,7 +760,7 @@ class Var(Generic[VAR_TYPE]):
return get_to_operation(NoneVar).create(self) # pyright: ignore [reportReturnType] return get_to_operation(NoneVar).create(self) # pyright: ignore [reportReturnType]
# Handle fixed_output_type being Base or a dataclass. # Handle fixed_output_type being Base or a dataclass.
if can_use_in_object_var(fixed_output_type): if can_use_in_object_var(output):
return self.to(ObjectVar, output) return self.to(ObjectVar, output)
if inspect.isclass(output): if inspect.isclass(output):
@ -776,6 +792,9 @@ class Var(Generic[VAR_TYPE]):
return self return self
@overload
def guess_type(self: Var[NoReturn]) -> Var[Any]: ... # pyright: ignore [reportOverlappingOverload]
@overload @overload
def guess_type(self: Var[str]) -> StringVar: ... def guess_type(self: Var[str]) -> StringVar: ...
@ -785,6 +804,9 @@ class Var(Generic[VAR_TYPE]):
@overload @overload
def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ... def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ...
@overload
def guess_type(self: Var[BASE_TYPE]) -> ObjectVar[BASE_TYPE]: ...
@overload @overload
def guess_type(self) -> Self: ... def guess_type(self) -> Self: ...
@ -933,7 +955,7 @@ class Var(Generic[VAR_TYPE]):
return setter return setter
def _var_set_state(self, state: type[BaseState] | str): def _var_set_state(self, state: type[BaseState] | str) -> Self:
"""Set the state of the var. """Set the state of the var.
Args: Args:
@ -948,7 +970,7 @@ class Var(Generic[VAR_TYPE]):
else format_state_name(state.get_full_name()) else format_state_name(state.get_full_name())
) )
return StateOperation.create( return StateOperation.create( # pyright: ignore [reportReturnType]
formatted_state_name, formatted_state_name,
self, self,
_var_data=VarData.merge( _var_data=VarData.merge(
@ -1127,43 +1149,6 @@ class Var(Generic[VAR_TYPE]):
""" """
return self return self
def __getattr__(self, name: str):
"""Get an attribute of the var.
Args:
name: The name of the attribute.
Returns:
The attribute.
Raises:
VarAttributeError: If the attribute does not exist.
TypeError: If the var type is Any.
"""
if name.startswith("_"):
return self.__getattribute__(name)
if name == "contains":
raise TypeError(
f"Var of type {self._var_type} does not support contains check."
)
if name == "reverse":
raise TypeError("Cannot reverse non-list var.")
if self._var_type is Any:
raise TypeError(
f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
)
if name in REPLACED_NAMES:
raise VarAttributeError(
f"Field {name!r} was renamed to {REPLACED_NAMES[name]!r}"
)
raise VarAttributeError(
f"The State var has no attribute '{name}' or may have been annotated wrongly.",
)
def _decode(self) -> Any: def _decode(self) -> Any:
"""Decode Var as a python value. """Decode Var as a python value.
@ -1225,36 +1210,76 @@ class Var(Generic[VAR_TYPE]):
return ArrayVar.range(first_endpoint, second_endpoint, step) return ArrayVar.range(first_endpoint, second_endpoint, step)
def __bool__(self) -> bool: if not TYPE_CHECKING:
"""Raise exception if using Var in a boolean context.
Raises: def __getattr__(self, name: str):
VarTypeError: when attempting to bool-ify the Var. """Get an attribute of the var.
"""
raise VarTypeError(
f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
"Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
)
def __iter__(self) -> Any: Args:
"""Raise exception if using Var in an iterable context. name: The name of the attribute.
Raises: Raises:
VarTypeError: when attempting to iterate over the Var. VarAttributeError: If the attribute does not exist.
""" UntypedVarError: If the var type is Any.
raise VarTypeError( TypeError: If the var type is Any.
f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
)
def __contains__(self, _: Any) -> Var: # noqa: DAR101 self
"""Override the 'in' operator to alert the user that it is not supported. """
if name.startswith("_"):
raise VarAttributeError(f"Attribute {name} not found.")
Raises: if name == "contains":
VarTypeError: the operation is not supported raise TypeError(
""" f"Var of type {self._var_type} does not support contains check."
raise VarTypeError( )
"'in' operator not supported for Var types, use Var.contains() instead." if name == "reverse":
) raise TypeError("Cannot reverse non-list var.")
if self._var_type is Any:
raise exceptions.UntypedVarError(
f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`."
)
raise VarAttributeError(
f"The State var has no attribute '{name}' or may have been annotated wrongly.",
)
def __bool__(self) -> bool:
"""Raise exception if using Var in a boolean context.
Raises:
VarTypeError: when attempting to bool-ify the Var.
# noqa: DAR101 self
"""
raise VarTypeError(
f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. "
"Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)."
)
def __iter__(self) -> Any:
"""Raise exception if using Var in an iterable context.
Raises:
VarTypeError: when attempting to iterate over the Var.
# noqa: DAR101 self
"""
raise VarTypeError(
f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`."
)
def __contains__(self, _: Any) -> Var:
"""Override the 'in' operator to alert the user that it is not supported.
Raises:
VarTypeError: the operation is not supported
# noqa: DAR101 self
"""
raise VarTypeError(
"'in' operator not supported for Var types, use Var.contains() instead."
)
OUTPUT = TypeVar("OUTPUT", bound=Var) OUTPUT = TypeVar("OUTPUT", bound=Var)
@ -1471,6 +1496,12 @@ class LiteralVar(Var):
def __post_init__(self): def __post_init__(self):
"""Post-initialize the var.""" """Post-initialize the var."""
@property
def _var_value(self) -> Any:
raise NotImplementedError(
"LiteralVar subclasses must implement the _var_value property."
)
def json(self) -> str: def json(self) -> str:
"""Serialize the var to a JSON string. """Serialize the var to a JSON string.
@ -1543,7 +1574,7 @@ def var_operation(
) -> Callable[P, StringVar]: ... ) -> Callable[P, StringVar]: ...
LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set]) LIST_T = TypeVar("LIST_T", bound=Sequence)
@overload @overload
@ -1780,7 +1811,7 @@ def _or_operation(a: Var, b: Var):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class CallableVar(Var): class CallableVar(Var):
"""Decorate a Var-returning function to act as both a Var and a function. """Decorate a Var-returning function to act as both a Var and a function.
@ -1861,7 +1892,7 @@ def is_computed_var(obj: Any) -> TypeGuard[ComputedVar]:
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ComputedVar(Var[RETURN_TYPE]): class ComputedVar(Var[RETURN_TYPE]):
"""A field with computed getters.""" """A field with computed getters."""
@ -2070,13 +2101,6 @@ class ComputedVar(Var[RETURN_TYPE]):
owner: Type, owner: Type,
) -> ArrayVar[list[LIST_INSIDE]]: ... ) -> ArrayVar[list[LIST_INSIDE]]: ...
@overload
def __get__(
self: ComputedVar[set[LIST_INSIDE]],
instance: None,
owner: Type,
) -> ArrayVar[set[LIST_INSIDE]]: ...
@overload @overload
def __get__( def __get__(
self: ComputedVar[tuple[LIST_INSIDE, ...]], self: ComputedVar[tuple[LIST_INSIDE, ...]],
@ -2436,7 +2460,7 @@ def var_operation_return(
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class CustomVarOperation(CachedVarOperation, Var[T]): class CustomVarOperation(CachedVarOperation, Var[T]):
"""Base class for custom var operations.""" """Base class for custom var operations."""
@ -2507,7 +2531,7 @@ class NoneVar(Var[None], python_types=type(None)):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralNoneVar(LiteralVar, NoneVar): class LiteralNoneVar(LiteralVar, NoneVar):
"""A var representing None.""" """A var representing None."""
@ -2569,7 +2593,7 @@ def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]:
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class StateOperation(CachedVarOperation, Var): class StateOperation(CachedVarOperation, Var):
"""A var operation that accesses a field on an object.""" """A var operation that accesses a field on an object."""
@ -2716,19 +2740,6 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]:
return var_datas return var_datas
# These names were changed in reflex 0.3.0
REPLACED_NAMES = {
"full_name": "_var_full_name",
"name": "_js_expr",
"state": "_var_data.state",
"type_": "_var_type",
"is_local": "_var_is_local",
"is_string": "_var_is_string",
"set_state": "_var_set_state",
"deps": "_deps",
}
dispatchers: Dict[GenericType, Callable[[Var], Var]] = {} dispatchers: Dict[GenericType, Callable[[Var], Var]] = {}

View File

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import sys
from datetime import date, datetime from datetime import date, datetime
from typing import Any, NoReturn, TypeVar, Union, overload from typing import Any, NoReturn, TypeVar, Union, overload
@ -193,7 +192,7 @@ def date_compare_operation(
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralDatetimeVar(LiteralVar, DateTimeVar): class LiteralDatetimeVar(LiteralVar, DateTimeVar):
"""Base class for immutable datetime and date vars.""" """Base class for immutable datetime and date vars."""

View File

@ -226,7 +226,7 @@ class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]): class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
"""Base class for immutable vars that are the result of a function call.""" """Base class for immutable vars that are the result of a function call."""
@ -350,7 +350,7 @@ def format_args_function_operation(
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ArgsFunctionOperation(CachedVarOperation, FunctionVar): class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
"""Base class for immutable function defined via arguments and return expression.""" """Base class for immutable function defined via arguments and return expression."""
@ -407,7 +407,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
"""Base class for immutable function defined via arguments and return expression with the builder pattern.""" """Base class for immutable function defined via arguments and return expression with the builder pattern."""

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import dataclasses import dataclasses
import json import json
import math import math
import sys
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -160,7 +159,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
""" """
from .sequence import ArrayVar, LiteralArrayVar from .sequence import ArrayVar, LiteralArrayVar
if isinstance(other, (list, tuple, set, ArrayVar)): if isinstance(other, (list, tuple, ArrayVar)):
if isinstance(other, ArrayVar): if isinstance(other, ArrayVar):
return other * self return other * self
return LiteralArrayVar.create(other) * self return LiteralArrayVar.create(other) * self
@ -187,7 +186,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)):
""" """
from .sequence import ArrayVar, LiteralArrayVar from .sequence import ArrayVar, LiteralArrayVar
if isinstance(other, (list, tuple, set, ArrayVar)): if isinstance(other, (list, tuple, ArrayVar)):
if isinstance(other, ArrayVar): if isinstance(other, ArrayVar):
return other * self return other * self
return LiteralArrayVar.create(other) * self return LiteralArrayVar.create(other) * self
@ -973,7 +972,7 @@ def boolean_not_operation(value: BooleanVar):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralNumberVar(LiteralVar, NumberVar): class LiteralNumberVar(LiteralVar, NumberVar):
"""Base class for immutable literal number vars.""" """Base class for immutable literal number vars."""
@ -1032,7 +1031,7 @@ class LiteralNumberVar(LiteralVar, NumberVar):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralBooleanVar(LiteralVar, BooleanVar): class LiteralBooleanVar(LiteralVar, BooleanVar):
"""Base class for immutable literal boolean vars.""" """Base class for immutable literal boolean vars."""

View File

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import sys
import typing import typing
from inspect import isclass from inspect import isclass
from typing import ( from typing import (
@ -167,12 +166,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
key: Var | Any, key: Var | Any,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getitem__(
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
@ -229,12 +222,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
name: str, name: str,
) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ...
@overload
def __getattr__(
self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]],
name: str,
) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ...
@overload @overload
def __getattr__( def __getattr__(
self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]],
@ -305,7 +292,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
"""Base class for immutable literal object vars.""" """Base class for immutable literal object vars."""
@ -355,17 +342,20 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar):
Returns: Returns:
The JSON representation of the object. The JSON representation of the object.
Raises:
TypeError: The keys and values of the object must be literal vars to get the JSON representation
""" """
return ( keys_and_values = []
"{" for key, value in self._var_value.items():
+ ", ".join( key = LiteralVar.create(key)
[ value = LiteralVar.create(value)
f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}" if not isinstance(key, LiteralVar) or not isinstance(value, LiteralVar):
for key, value in self._var_value.items() raise TypeError(
] "The keys and values of the object must be literal vars to get the JSON representation."
) )
+ "}" keys_and_values.append(f"{key.json()}:{value.json()}")
) return "{" + ", ".join(keys_and_values) + "}"
def __hash__(self) -> int: def __hash__(self) -> int:
"""Get the hash of the var. """Get the hash of the var.
@ -487,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ObjectItemOperation(CachedVarOperation, Var): class ObjectItemOperation(CachedVarOperation, Var):
"""Operation to get an item from an object.""" """Operation to get an item from an object."""

View File

@ -6,7 +6,6 @@ import dataclasses
import inspect import inspect
import json import json
import re import re
import sys
import typing import typing
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -15,7 +14,7 @@ from typing import (
List, List,
Literal, Literal,
NoReturn, NoReturn,
Set, Sequence,
Tuple, Tuple,
Type, Type,
Union, Union,
@ -596,7 +595,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralStringVar(LiteralVar, StringVar[str]): class LiteralStringVar(LiteralVar, StringVar[str]):
"""Base class for immutable literal string vars.""" """Base class for immutable literal string vars."""
@ -718,7 +717,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ConcatVarOperation(CachedVarOperation, StringVar[str]): class ConcatVarOperation(CachedVarOperation, StringVar[str]):
"""Representing a concatenation of literal string vars.""" """Representing a concatenation of literal string vars."""
@ -794,7 +793,8 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]):
) )
ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set]) ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Sequence, covariant=True)
OTHER_ARRAY_VAR_TYPE = TypeVar("OTHER_ARRAY_VAR_TYPE", bound=Sequence)
OTHER_TUPLE = TypeVar("OTHER_TUPLE") OTHER_TUPLE = TypeVar("OTHER_TUPLE")
@ -887,6 +887,11 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
i: Literal[0, -2], i: Literal[0, -2],
) -> NumberVar: ... ) -> NumberVar: ...
@overload
def __getitem__(
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1]
) -> BooleanVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ( self: (
@ -914,7 +919,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
@overload @overload
def __getitem__( def __getitem__(
self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
) -> BooleanVar: ... ) -> BooleanVar: ...
@overload @overload
@ -932,23 +937,12 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar
) -> StringVar: ... ) -> StringVar: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar
) -> BooleanVar: ...
@overload @overload
def __getitem__( def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]], self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]],
i: int | NumberVar, i: int | NumberVar,
) -> ArrayVar[List[INNER_ARRAY_VAR]]: ... ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ...
@overload
def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]],
i: int | NumberVar,
) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ...
@overload @overload
def __getitem__( def __getitem__(
self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]], self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]],
@ -1239,26 +1233,18 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)):
LIST_ELEMENT = TypeVar("LIST_ELEMENT") LIST_ELEMENT = TypeVar("LIST_ELEMENT")
ARRAY_VAR_OF_LIST_ELEMENT = Union[ ARRAY_VAR_OF_LIST_ELEMENT = ArrayVar[Sequence[LIST_ELEMENT]]
ArrayVar[List[LIST_ELEMENT]],
ArrayVar[Set[LIST_ELEMENT]],
ArrayVar[Tuple[LIST_ELEMENT, ...]],
]
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
"""Base class for immutable literal array vars.""" """Base class for immutable literal array vars."""
_var_value: Union[ _var_value: Sequence[Union[Var, Any]] = dataclasses.field(default=())
List[Union[Var, Any]],
Set[Union[Var, Any]],
Tuple[Union[Var, Any], ...],
] = dataclasses.field(default_factory=list)
@cached_property_no_lock @cached_property_no_lock
def _cached_var_name(self) -> str: def _cached_var_name(self) -> str:
@ -1303,22 +1289,28 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
Returns: Returns:
The JSON representation of the var. The JSON representation of the var.
Raises:
TypeError: If the array elements are not of type LiteralVar.
""" """
return ( elements = []
"[" for element in self._var_value:
+ ", ".join( element_var = LiteralVar.create(element)
[LiteralVar.create(element).json() for element in self._var_value] if not isinstance(element_var, LiteralVar):
) raise TypeError(
+ "]" f"Array elements must be of type LiteralVar, not {type(element_var)}"
) )
elements.append(element_var.json())
return "[" + ", ".join(elements) + "]"
@classmethod @classmethod
def create( def create(
cls, cls,
value: ARRAY_VAR_TYPE, value: OTHER_ARRAY_VAR_TYPE,
_var_type: Type[ARRAY_VAR_TYPE] | None = None, _var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None,
_var_data: VarData | None = None, _var_data: VarData | None = None,
) -> LiteralArrayVar[ARRAY_VAR_TYPE]: ) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]:
"""Create a var from a string value. """Create a var from a string value.
Args: Args:
@ -1329,7 +1321,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]):
Returns: Returns:
The var. The var.
""" """
return cls( return LiteralArrayVar(
_js_expr="", _js_expr="",
_var_type=figure_out_type(value) if _var_type is None else _var_type, _var_type=figure_out_type(value) if _var_type is None else _var_type,
_var_data=_var_data, _var_data=_var_data,
@ -1356,7 +1348,7 @@ def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class ArraySliceOperation(CachedVarOperation, ArrayVar): class ArraySliceOperation(CachedVarOperation, ArrayVar):
"""Base class for immutable string vars that are the result of a string slice operation.""" """Base class for immutable string vars that are the result of a string slice operation."""
@ -1705,7 +1697,7 @@ class ColorVar(StringVar[Color], python_types=Color):
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, slots=True,
) )
class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar): class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar):
"""Base class for immutable literal color vars.""" """Base class for immutable literal color vars."""

View File

@ -3,6 +3,7 @@ from typing import List, Mapping, Tuple
import pytest import pytest
import reflex as rx import reflex as rx
from reflex.components.component import Component
from reflex.components.core.match import Match from reflex.components.core.match import Match
from reflex.state import BaseState from reflex.state import BaseState
from reflex.utils.exceptions import MatchTypeError from reflex.utils.exceptions import MatchTypeError
@ -29,6 +30,8 @@ def test_match_components():
rx.text("default value"), rx.text("default value"),
) )
match_comp = Match.create(MatchState.value, *match_case_tuples) match_comp = Match.create(MatchState.value, *match_case_tuples)
assert isinstance(match_comp, Component)
match_dict = match_comp.render() match_dict = match_comp.render()
assert match_dict["name"] == "Fragment" assert match_dict["name"] == "Fragment"
@ -151,6 +154,7 @@ def test_match_on_component_without_default():
) )
match_comp = Match.create(MatchState.value, *match_case_tuples) match_comp = Match.create(MatchState.value, *match_case_tuples)
assert isinstance(match_comp, Component)
default = match_comp.render()["children"][0]["default"] default = match_comp.render()["children"][0]["default"]
assert isinstance(default, dict) and default["name"] == Fragment.__name__ assert isinstance(default, dict) and default["name"] == Fragment.__name__

View File

@ -36,6 +36,7 @@ from reflex.utils.exceptions import (
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
from reflex.vars.object import ObjectVar
@pytest.fixture @pytest.fixture
@ -842,12 +843,12 @@ def test_component_event_trigger_arbitrary_args():
"""Test that we can define arbitrary types for the args of an event trigger.""" """Test that we can define arbitrary types for the args of an event trigger."""
def on_foo_spec( def on_foo_spec(
_e: Var[JavascriptInputEvent], _e: ObjectVar[JavascriptInputEvent],
alpha: Var[str], alpha: Var[str],
bravo: dict[str, Any], bravo: dict[str, Any],
charlie: Var[_Obj], charlie: ObjectVar[_Obj],
): ):
return [_e.target.value, bravo["nested"], charlie.custom + 42] return [_e.target.value, bravo["nested"], charlie.custom.to(int) + 42]
class C1(Component): class C1(Component):
library = "/local" library = "/local"
@ -1328,7 +1329,7 @@ class EventState(rx.State):
), ),
pytest.param( pytest.param(
rx.fragment(class_name=[TEST_VAR, "other-class"]), rx.fragment(class_name=[TEST_VAR, "other-class"]),
[LiteralVar.create([TEST_VAR, "other-class"]).join(" ")], [Var.create([TEST_VAR, "other-class"]).join(" ")],
id="fstring-dual-class_name", id="fstring-dual-class_name",
), ),
pytest.param( pytest.param(

View File

@ -471,15 +471,15 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str):
""" """
state = test_state() # pyright: ignore [reportCallIssue] state = test_state() # pyright: ignore [reportCallIssue]
state.add_var("int_val", int, 0) state.add_var("int_val", int, 0)
result = await state._process( async for result in state._process(
Event( Event(
token=token, token=token,
name=f"{test_state.get_name()}.set_int_val", name=f"{test_state.get_name()}.set_int_val",
router_data={"pathname": "/", "query": {}}, router_data={"pathname": "/", "query": {}},
payload={"value": 50}, payload={"value": 50},
) )
).__anext__() ):
assert result.delta == {test_state.get_name(): {"int_val": 50}} assert result.delta == {test_state.get_name(): {"int_val": 50}}
@pytest.mark.asyncio @pytest.mark.asyncio
@ -583,18 +583,17 @@ async def test_list_mutation_detection__plain_list(
token: a Token. token: a Token.
""" """
for event_name, expected_delta in event_tuples: for event_name, expected_delta in event_tuples:
result = await list_mutation_state._process( async for result in list_mutation_state._process(
Event( Event(
token=token, token=token,
name=f"{list_mutation_state.get_name()}.{event_name}", name=f"{list_mutation_state.get_name()}.{event_name}",
router_data={"pathname": "/", "query": {}}, router_data={"pathname": "/", "query": {}},
payload={}, payload={},
) )
).__anext__() ):
# prefix keys in expected_delta with the state name
# prefix keys in expected_delta with the state name expected_delta = {list_mutation_state.get_name(): expected_delta}
expected_delta = {list_mutation_state.get_name(): expected_delta} assert result.delta == expected_delta
assert result.delta == expected_delta
@pytest.mark.asyncio @pytest.mark.asyncio
@ -709,19 +708,18 @@ async def test_dict_mutation_detection__plain_list(
token: a Token. token: a Token.
""" """
for event_name, expected_delta in event_tuples: for event_name, expected_delta in event_tuples:
result = await dict_mutation_state._process( async for result in dict_mutation_state._process(
Event( Event(
token=token, token=token,
name=f"{dict_mutation_state.get_name()}.{event_name}", name=f"{dict_mutation_state.get_name()}.{event_name}",
router_data={"pathname": "/", "query": {}}, router_data={"pathname": "/", "query": {}},
payload={}, payload={},
) )
).__anext__() ):
# prefix keys in expected_delta with the state name
expected_delta = {dict_mutation_state.get_name(): expected_delta}
# prefix keys in expected_delta with the state name assert result.delta == expected_delta
expected_delta = {dict_mutation_state.get_name(): expected_delta}
assert result.delta == expected_delta
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -789,17 +789,16 @@ async def test_process_event_simple(test_state):
assert test_state.num1 == 0 assert test_state.num1 == 0
event = Event(token="t", name="set_num1", payload={"value": 69}) event = Event(token="t", name="set_num1", payload={"value": 69})
update = await test_state._process(event).__anext__() async for update in test_state._process(event):
# The event should update the value.
assert test_state.num1 == 69
# The event should update the value. # The delta should contain the changes, including computed vars.
assert test_state.num1 == 69 assert update.delta == {
TestState.get_full_name(): {"num1": 69, "sum": 72.14},
# The delta should contain the changes, including computed vars. GrandchildState3.get_full_name(): {"computed": ""},
assert update.delta == { }
TestState.get_full_name(): {"num1": 69, "sum": 72.14}, assert update.events == []
GrandchildState3.get_full_name(): {"computed": ""},
}
assert update.events == []
@pytest.mark.asyncio @pytest.mark.asyncio
@ -819,15 +818,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name=f"{ChildState.get_name()}.change_both", name=f"{ChildState.get_name()}.change_both",
payload={"value": "hi", "count": 12}, payload={"value": "hi", "count": 12},
) )
update = await test_state._process(event).__anext__() async for update in test_state._process(event):
assert child_state.value == "HI" assert child_state.value == "HI"
assert child_state.count == 24 assert child_state.count == 24
assert update.delta == { assert update.delta == {
# TestState.get_full_name(): {"sum": 3.14, "upper": ""}, # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
ChildState.get_full_name(): {"value": "HI", "count": 24}, ChildState.get_full_name(): {"value": "HI", "count": 24},
GrandchildState3.get_full_name(): {"computed": ""}, GrandchildState3.get_full_name(): {"computed": ""},
} }
test_state._clean() test_state._clean()
# Test with the granchild state. # Test with the granchild state.
assert grandchild_state.value2 == "" assert grandchild_state.value2 == ""
@ -836,13 +835,13 @@ async def test_process_event_substate(test_state, child_state, grandchild_state)
name=f"{GrandchildState.get_full_name()}.set_value2", name=f"{GrandchildState.get_full_name()}.set_value2",
payload={"value": "new"}, payload={"value": "new"},
) )
update = await test_state._process(event).__anext__() async for update in test_state._process(event):
assert grandchild_state.value2 == "new" assert grandchild_state.value2 == "new"
assert update.delta == { assert update.delta == {
# TestState.get_full_name(): {"sum": 3.14, "upper": ""}, # TestState.get_full_name(): {"sum": 3.14, "upper": ""},
GrandchildState.get_full_name(): {"value2": "new"}, GrandchildState.get_full_name(): {"value2": "new"},
GrandchildState3.get_full_name(): {"computed": ""}, GrandchildState3.get_full_name(): {"computed": ""},
} }
@pytest.mark.asyncio @pytest.mark.asyncio
@ -2909,10 +2908,10 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker):
events = updates[0].events events = updates[0].events
assert len(events) == 2 assert len(events) == 2
assert (await state._process(events[0]).__anext__()).delta == { async for update in state._process(events[0]):
test_state.get_full_name(): {"num": 1} assert update.delta == {test_state.get_full_name(): {"num": 1}}
} async for update in state._process(events[1]):
assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state) assert update.delta == exp_is_hydrated(state)
if isinstance(app.state_manager, StateManagerRedis): if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close() await app.state_manager.close()
@ -2957,13 +2956,12 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker):
events = updates[0].events events = updates[0].events
assert len(events) == 3 assert len(events) == 3
assert (await state._process(events[0]).__anext__()).delta == { async for update in state._process(events[0]):
OnLoadState.get_full_name(): {"num": 1} assert update.delta == {OnLoadState.get_full_name(): {"num": 1}}
} async for update in state._process(events[1]):
assert (await state._process(events[1]).__anext__()).delta == { assert update.delta == {OnLoadState.get_full_name(): {"num": 2}}
OnLoadState.get_full_name(): {"num": 2} async for update in state._process(events[2]):
} assert update.delta == exp_is_hydrated(state)
assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state)
if isinstance(app.state_manager, StateManagerRedis): if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close() await app.state_manager.close()

View File

@ -1,6 +1,5 @@
import json import json
import math import math
import sys
import typing import typing
from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast
@ -422,19 +421,13 @@ class Bar(rx.Base):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("var", "var_type"), ("var", "var_type"),
( [
[ (Var(_js_expr="").to(Foo | Bar), Foo | Bar),
(Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar), (Var(_js_expr="").to(Foo | Bar).bar, Union[int, str]),
(Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]), (Var(_js_expr="").to(Union[Foo, Bar]), Union[Foo, Bar]),
] (Var(_js_expr="").to(Union[Foo, Bar]).baz, str),
if sys.version_info >= (3, 10)
else []
)
+ [
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]),
(Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str),
( (
Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo, Var(_js_expr="").to(Union[Foo, Bar]).foo,
Union[int, None], Union[int, None],
), ),
], ],
@ -1358,7 +1351,7 @@ def test_unsupported_types_for_contains(var: Var):
var: The base var. var: The base var.
""" """
with pytest.raises(TypeError) as err: with pytest.raises(TypeError) as err:
assert var.contains(1) assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue]
assert ( assert (
err.value.args[0] err.value.args[0]
== f"Var of type {var._var_type} does not support contains check." == f"Var of type {var._var_type} does not support contains check."
@ -1388,7 +1381,7 @@ def test_unsupported_types_for_string_contains(other):
def test_unsupported_default_contains(): def test_unsupported_default_contains():
with pytest.raises(TypeError) as err: with pytest.raises(TypeError) as err:
assert 1 in Var(_js_expr="var", _var_type=str).guess_type() assert 1 in Var(_js_expr="var", _var_type=str).guess_type() # pyright: ignore [reportOperatorIssue]
assert ( assert (
err.value.args[0] err.value.args[0]
== "'in' operator not supported for Var types, use Var.contains() instead." == "'in' operator not supported for Var types, use Var.contains() instead."