From 8663dbcb974bacc1e03d9f5158f62d7a98e398eb Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 31 Jan 2025 13:12:33 -0800 Subject: [PATCH] improve var base typing (#4718) * improve var base typing * fix pyi * dang it darglint * drain _process in tests * fixes #4576 * dang it darglint --- reflex/components/base/error_boundary.py | 3 +- reflex/components/base/error_boundary.pyi | 3 +- reflex/components/component.py | 1 + reflex/components/core/foreach.py | 12 +- reflex/components/datadisplay/dataeditor.py | 3 +- .../datadisplay/shiki_code_block.py | 12 +- reflex/event.py | 25 ++- reflex/experimental/client_state.py | 3 +- reflex/state.py | 6 +- reflex/utils/exceptions.py | 4 + reflex/vars/base.py | 203 +++++++++--------- reflex/vars/datetime.py | 3 +- reflex/vars/function.py | 6 +- reflex/vars/number.py | 9 +- reflex/vars/object.py | 40 ++-- reflex/vars/sequence.py | 74 +++---- tests/units/components/core/test_match.py | 4 + tests/units/components/test_component.py | 9 +- tests/units/test_app.py | 28 ++- tests/units/test_state.py | 72 +++---- tests/units/test_var.py | 23 +- 21 files changed, 279 insertions(+), 264 deletions(-) diff --git a/reflex/components/base/error_boundary.py b/reflex/components/base/error_boundary.py index f328773c2..74867a757 100644 --- a/reflex/components/base/error_boundary.py +++ b/reflex/components/base/error_boundary.py @@ -11,10 +11,11 @@ from reflex.event import EventHandler, set_clipboard from reflex.state import FrontendEventExceptionState from reflex.vars.base import Var from reflex.vars.function import ArgsFunctionOperation +from reflex.vars.object import ObjectVar def on_error_spec( - error: Var[Dict[str, str]], info: Var[Dict[str, str]] + error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]] ) -> Tuple[Var[str], Var[str]]: """The spec for the on_error event handler. diff --git a/reflex/components/base/error_boundary.pyi b/reflex/components/base/error_boundary.pyi index 2e01c7da0..8d27af0f3 100644 --- a/reflex/components/base/error_boundary.pyi +++ b/reflex/components/base/error_boundary.pyi @@ -9,9 +9,10 @@ from reflex.components.component import Component from reflex.event import BASE_STATE, EventType from reflex.style import Style from reflex.vars.base import Var +from reflex.vars.object import ObjectVar def on_error_spec( - error: Var[Dict[str, str]], info: Var[Dict[str, str]] + error: ObjectVar[Dict[str, str]], info: ObjectVar[Dict[str, str]] ) -> Tuple[Var[str], Var[str]]: ... class ErrorBoundary(Component): diff --git a/reflex/components/component.py b/reflex/components/component.py index 8982e4b4f..440a408df 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -2457,6 +2457,7 @@ def render_dict_to_var(tag: dict | Component | str, imported_names: set[str]) -> @dataclasses.dataclass( eq=False, frozen=True, + slots=True, ) class LiteralComponentVar(CachedVarOperation, LiteralVar, ComponentVar): """A Var that represents a Component.""" diff --git a/reflex/components/core/foreach.py b/reflex/components/core/foreach.py index 30dda9c6a..927b01333 100644 --- a/reflex/components/core/foreach.py +++ b/reflex/components/core/foreach.py @@ -11,6 +11,7 @@ from reflex.components.component import Component from reflex.components.tags import IterTag from reflex.constants import MemoizationMode from reflex.state import ComponentState +from reflex.utils.exceptions import UntypedVarError from reflex.vars.base import LiteralVar, Var @@ -51,6 +52,7 @@ class Foreach(Component): Raises: ForeachVarError: If the iterable is of type Any. TypeError: If the render function is a ComponentState. + UntypedVarError: If the iterable is of type Any without a type annotation. """ iterable = LiteralVar.create(iterable) if iterable._var_type == Any: @@ -72,8 +74,14 @@ class Foreach(Component): iterable=iterable, render_fn=render_fn, ) - # Keep a ref to a rendered component to determine correct imports/hooks/styles. - component.children = [component._render().render_component()] + try: + # Keep a ref to a rendered component to determine correct imports/hooks/styles. + component.children = [component._render().render_component()] + except UntypedVarError as e: + raise UntypedVarError( + f"Could not foreach over var `{iterable!s}` without a type annotation. " + "See https://reflex.dev/docs/library/dynamic-rendering/foreach/" + ) from e return component def _render(self) -> IterTag: diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index b2d6417bd..338fb2e44 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -387,7 +387,8 @@ class DataEditor(NoSSRComponent): raise ValueError( "DataEditor data must be an ArrayVar if rows is not provided." ) - props["rows"] = data.length() if isinstance(data, Var) else len(data) + + props["rows"] = data.length() if isinstance(data, ArrayVar) else len(data) if not isinstance(columns, Var) and len(columns): if types.is_dataframe(type(data)) or ( diff --git a/reflex/components/datadisplay/shiki_code_block.py b/reflex/components/datadisplay/shiki_code_block.py index 2d3040966..a4aaec1d4 100644 --- a/reflex/components/datadisplay/shiki_code_block.py +++ b/reflex/components/datadisplay/shiki_code_block.py @@ -621,18 +621,22 @@ class ShikiCodeBlock(Component, MarkdownComponentMap): Returns: Imports for the component. + + Raises: + ValueError: If the transformers are not of type LiteralVar. """ imports = defaultdict(list) + if not isinstance(self.transformers, LiteralVar): + raise ValueError( + f"transformers should be a LiteralVar type. Got {type(self.transformers)} instead." + ) for transformer in self.transformers._var_value: if isinstance(transformer, ShikiBaseTransformers): imports[transformer.library].extend( [ImportVar(tag=str(fn)) for fn in transformer.fns] ) - ( + if transformer.library not in self.lib_dependencies: self.lib_dependencies.append(transformer.library) - if transformer.library not in self.lib_dependencies - else None - ) return imports @classmethod diff --git a/reflex/event.py b/reflex/event.py index 96790e24c..5ce0f3dc1 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -4,7 +4,6 @@ from __future__ import annotations import dataclasses import inspect -import sys import types import urllib.parse from base64 import b64encode @@ -541,7 +540,7 @@ class JavasciptKeyboardEvent: shiftKey: bool = False # noqa: N815 -def input_event(e: Var[JavascriptInputEvent]) -> Tuple[Var[str]]: +def input_event(e: ObjectVar[JavascriptInputEvent]) -> Tuple[Var[str]]: """Get the value from an input event. Args: @@ -562,7 +561,9 @@ class KeyInputInfo(TypedDict): shift_key: bool -def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInfo]]: +def key_event( + e: ObjectVar[JavasciptKeyboardEvent], +) -> Tuple[Var[str], Var[KeyInputInfo]]: """Get the key from a keyboard event. Args: @@ -572,7 +573,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf The key from the keyboard event. """ return ( - e.key, + e.key.to(str), Var.create( { "alt_key": e.altKey, @@ -580,7 +581,7 @@ def key_event(e: Var[JavasciptKeyboardEvent]) -> Tuple[Var[str], Var[KeyInputInf "meta_key": e.metaKey, "shift_key": e.shiftKey, }, - ), + ).to(KeyInputInfo), ) @@ -1354,7 +1355,7 @@ def unwrap_var_annotation(annotation: GenericType): Returns: The unwrapped annotation. """ - if get_origin(annotation) is Var and (args := get_args(annotation)): + if get_origin(annotation) in (Var, ObjectVar) and (args := get_args(annotation)): return args[0] return annotation @@ -1620,7 +1621,7 @@ class EventVar(ObjectVar, python_types=EventSpec): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralEventVar(VarOperationCall, LiteralVar, EventVar): """A literal event var.""" @@ -1681,7 +1682,7 @@ class EventChainVar(BuilderFunctionVar, python_types=EventChain): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) # Note: LiteralVar is second in the inheritance list allowing it act like a # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the @@ -1713,6 +1714,9 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV Returns: The created LiteralEventChainVar instance. + + Raises: + ValueError: If the invocation is not a FunctionVar. """ arg_spec = ( value.args_spec[0] @@ -1740,6 +1744,11 @@ class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainV else: invocation = value.invocation + if invocation is not None and not isinstance(invocation, FunctionVar): + raise ValueError( + f"EventChain invocation must be a FunctionVar, got {invocation!s} of type {invocation._var_type!s}." + ) + return cls( _js_expr="", _var_type=EventChain, diff --git a/reflex/experimental/client_state.py b/reflex/experimental/client_state.py index ce3a941bb..8138c2721 100644 --- a/reflex/experimental/client_state.py +++ b/reflex/experimental/client_state.py @@ -4,7 +4,6 @@ from __future__ import annotations import dataclasses import re -import sys from typing import Any, Callable, Union from reflex import constants @@ -49,7 +48,7 @@ def _client_state_ref_dict(var_name: str) -> str: @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ClientStateVar(Var): """A Var that exists on the client via useState.""" diff --git a/reflex/state.py b/reflex/state.py index a0b91c94f..6c74d5e55 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1637,9 +1637,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if not isinstance(var, Var): return var + unset = object() + # Fast case: this is a literal var and the value is known. - if hasattr(var, "_var_value"): - return var._var_value + if (var_value := getattr(var, "_var_value", unset)) is not unset: + return var_value # pyright: ignore [reportReturnType] var_data = var._get_all_var_data() if var_data is None or not var_data.state: diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index be3f6ab69..05fbb297c 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -75,6 +75,10 @@ class VarAttributeError(ReflexError, AttributeError): """Custom AttributeError for var related errors.""" +class UntypedVarError(ReflexError, TypeError): + """Custom TypeError for untyped var errors.""" + + class UntypedComputedVarError(ReflexError, TypeError): """Custom TypeError for untyped computed var errors.""" diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 0b84c1036..d34bc8ff5 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -12,7 +12,6 @@ import json import random import re import string -import sys import warnings from types import CodeType, FunctionType from typing import ( @@ -82,6 +81,7 @@ if TYPE_CHECKING: VAR_TYPE = TypeVar("VAR_TYPE", covariant=True) OTHER_VAR_TYPE = TypeVar("OTHER_VAR_TYPE") STRING_T = TypeVar("STRING_T", bound=str) +SEQUENCE_TYPE = TypeVar("SEQUENCE_TYPE", bound=Sequence) warnings.filterwarnings("ignore", message="fields may not start with an underscore") @@ -449,7 +449,7 @@ class Var(Generic[VAR_TYPE]): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ToVarOperation(ToOperation, cls): """Base class of converting a var to another var type.""" @@ -597,7 +597,7 @@ class Var(Generic[VAR_TYPE]): @overload @classmethod - def create( + def create( # pyright: ignore [reportOverlappingOverload] cls, value: STRING_T, _var_data: VarData | None = None, @@ -611,6 +611,22 @@ class Var(Generic[VAR_TYPE]): _var_data: VarData | None = None, ) -> NoneVar: ... + @overload + @classmethod + def create( + cls, + value: MAPPING_TYPE, + _var_data: VarData | None = None, + ) -> ObjectVar[MAPPING_TYPE]: ... + + @overload + @classmethod + def create( + cls, + value: SEQUENCE_TYPE, + _var_data: VarData | None = None, + ) -> ArrayVar[SEQUENCE_TYPE]: ... + @overload @classmethod def create( @@ -692,8 +708,8 @@ class Var(Generic[VAR_TYPE]): @overload def to( self, - output: type[Mapping], - ) -> ObjectVar[Mapping]: ... + output: type[MAPPING_TYPE], + ) -> ObjectVar[MAPPING_TYPE]: ... @overload def to( @@ -744,7 +760,7 @@ class Var(Generic[VAR_TYPE]): return get_to_operation(NoneVar).create(self) # pyright: ignore [reportReturnType] # Handle fixed_output_type being Base or a dataclass. - if can_use_in_object_var(fixed_output_type): + if can_use_in_object_var(output): return self.to(ObjectVar, output) if inspect.isclass(output): @@ -776,6 +792,9 @@ class Var(Generic[VAR_TYPE]): return self + @overload + def guess_type(self: Var[NoReturn]) -> Var[Any]: ... # pyright: ignore [reportOverlappingOverload] + @overload def guess_type(self: Var[str]) -> StringVar: ... @@ -785,6 +804,9 @@ class Var(Generic[VAR_TYPE]): @overload def guess_type(self: Var[int] | Var[float] | Var[int | float]) -> NumberVar: ... + @overload + def guess_type(self: Var[BASE_TYPE]) -> ObjectVar[BASE_TYPE]: ... + @overload def guess_type(self) -> Self: ... @@ -933,7 +955,7 @@ class Var(Generic[VAR_TYPE]): return setter - def _var_set_state(self, state: type[BaseState] | str): + def _var_set_state(self, state: type[BaseState] | str) -> Self: """Set the state of the var. Args: @@ -948,7 +970,7 @@ class Var(Generic[VAR_TYPE]): else format_state_name(state.get_full_name()) ) - return StateOperation.create( + return StateOperation.create( # pyright: ignore [reportReturnType] formatted_state_name, self, _var_data=VarData.merge( @@ -1127,43 +1149,6 @@ class Var(Generic[VAR_TYPE]): """ return self - def __getattr__(self, name: str): - """Get an attribute of the var. - - Args: - name: The name of the attribute. - - Returns: - The attribute. - - Raises: - VarAttributeError: If the attribute does not exist. - TypeError: If the var type is Any. - """ - if name.startswith("_"): - return self.__getattribute__(name) - - if name == "contains": - raise TypeError( - f"Var of type {self._var_type} does not support contains check." - ) - if name == "reverse": - raise TypeError("Cannot reverse non-list var.") - - if self._var_type is Any: - raise TypeError( - f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`." - ) - - if name in REPLACED_NAMES: - raise VarAttributeError( - f"Field {name!r} was renamed to {REPLACED_NAMES[name]!r}" - ) - - raise VarAttributeError( - f"The State var has no attribute '{name}' or may have been annotated wrongly.", - ) - def _decode(self) -> Any: """Decode Var as a python value. @@ -1225,36 +1210,76 @@ class Var(Generic[VAR_TYPE]): return ArrayVar.range(first_endpoint, second_endpoint, step) - def __bool__(self) -> bool: - """Raise exception if using Var in a boolean context. + if not TYPE_CHECKING: - Raises: - VarTypeError: when attempting to bool-ify the Var. - """ - raise VarTypeError( - f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. " - "Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)." - ) + def __getattr__(self, name: str): + """Get an attribute of the var. - def __iter__(self) -> Any: - """Raise exception if using Var in an iterable context. + Args: + name: The name of the attribute. - Raises: - VarTypeError: when attempting to iterate over the Var. - """ - raise VarTypeError( - f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`." - ) + Raises: + VarAttributeError: If the attribute does not exist. + UntypedVarError: If the var type is Any. + TypeError: If the var type is Any. - def __contains__(self, _: Any) -> Var: - """Override the 'in' operator to alert the user that it is not supported. + # noqa: DAR101 self + """ + if name.startswith("_"): + raise VarAttributeError(f"Attribute {name} not found.") - Raises: - VarTypeError: the operation is not supported - """ - raise VarTypeError( - "'in' operator not supported for Var types, use Var.contains() instead." - ) + if name == "contains": + raise TypeError( + f"Var of type {self._var_type} does not support contains check." + ) + if name == "reverse": + raise TypeError("Cannot reverse non-list var.") + + if self._var_type is Any: + raise exceptions.UntypedVarError( + f"You must provide an annotation for the state var `{self!s}`. Annotation cannot be `{self._var_type}`." + ) + + raise VarAttributeError( + f"The State var has no attribute '{name}' or may have been annotated wrongly.", + ) + + def __bool__(self) -> bool: + """Raise exception if using Var in a boolean context. + + Raises: + VarTypeError: when attempting to bool-ify the Var. + + # noqa: DAR101 self + """ + raise VarTypeError( + f"Cannot convert Var {str(self)!r} to bool for use with `if`, `and`, `or`, and `not`. " + "Instead use `rx.cond` and bitwise operators `&` (and), `|` (or), `~` (invert)." + ) + + def __iter__(self) -> Any: + """Raise exception if using Var in an iterable context. + + Raises: + VarTypeError: when attempting to iterate over the Var. + + # noqa: DAR101 self + """ + raise VarTypeError( + f"Cannot iterate over Var {str(self)!r}. Instead use `rx.foreach`." + ) + + def __contains__(self, _: Any) -> Var: + """Override the 'in' operator to alert the user that it is not supported. + + Raises: + VarTypeError: the operation is not supported + + # noqa: DAR101 self + """ + raise VarTypeError( + "'in' operator not supported for Var types, use Var.contains() instead." + ) OUTPUT = TypeVar("OUTPUT", bound=Var) @@ -1471,6 +1496,12 @@ class LiteralVar(Var): def __post_init__(self): """Post-initialize the var.""" + @property + def _var_value(self) -> Any: + raise NotImplementedError( + "LiteralVar subclasses must implement the _var_value property." + ) + def json(self) -> str: """Serialize the var to a JSON string. @@ -1543,7 +1574,7 @@ def var_operation( ) -> Callable[P, StringVar]: ... -LIST_T = TypeVar("LIST_T", bound=Union[List[Any], Tuple, Set]) +LIST_T = TypeVar("LIST_T", bound=Sequence) @overload @@ -1780,7 +1811,7 @@ def _or_operation(a: Var, b: Var): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class CallableVar(Var): """Decorate a Var-returning function to act as both a Var and a function. @@ -1861,7 +1892,7 @@ def is_computed_var(obj: Any) -> TypeGuard[ComputedVar]: @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ComputedVar(Var[RETURN_TYPE]): """A field with computed getters.""" @@ -2070,13 +2101,6 @@ class ComputedVar(Var[RETURN_TYPE]): owner: Type, ) -> ArrayVar[list[LIST_INSIDE]]: ... - @overload - def __get__( - self: ComputedVar[set[LIST_INSIDE]], - instance: None, - owner: Type, - ) -> ArrayVar[set[LIST_INSIDE]]: ... - @overload def __get__( self: ComputedVar[tuple[LIST_INSIDE, ...]], @@ -2436,7 +2460,7 @@ def var_operation_return( @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class CustomVarOperation(CachedVarOperation, Var[T]): """Base class for custom var operations.""" @@ -2507,7 +2531,7 @@ class NoneVar(Var[None], python_types=type(None)): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralNoneVar(LiteralVar, NoneVar): """A var representing None.""" @@ -2569,7 +2593,7 @@ def get_to_operation(var_subclass: Type[Var]) -> Type[ToOperation]: @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class StateOperation(CachedVarOperation, Var): """A var operation that accesses a field on an object.""" @@ -2716,19 +2740,6 @@ def _extract_var_data(value: Iterable) -> list[VarData | None]: return var_datas -# These names were changed in reflex 0.3.0 -REPLACED_NAMES = { - "full_name": "_var_full_name", - "name": "_js_expr", - "state": "_var_data.state", - "type_": "_var_type", - "is_local": "_var_is_local", - "is_string": "_var_is_string", - "set_state": "_var_set_state", - "deps": "_deps", -} - - dispatchers: Dict[GenericType, Callable[[Var], Var]] = {} diff --git a/reflex/vars/datetime.py b/reflex/vars/datetime.py index b20cfc7a6..c43c24165 100644 --- a/reflex/vars/datetime.py +++ b/reflex/vars/datetime.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses -import sys from datetime import date, datetime from typing import Any, NoReturn, TypeVar, Union, overload @@ -193,7 +192,7 @@ def date_compare_operation( @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralDatetimeVar(LiteralVar, DateTimeVar): """Base class for immutable datetime and date vars.""" diff --git a/reflex/vars/function.py b/reflex/vars/function.py index e8691cfb1..505a69b4c 100644 --- a/reflex/vars/function.py +++ b/reflex/vars/function.py @@ -226,7 +226,7 @@ class FunctionStringVar(FunctionVar[CALLABLE_TYPE]): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]): """Base class for immutable vars that are the result of a function call.""" @@ -350,7 +350,7 @@ def format_args_function_operation( @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ArgsFunctionOperation(CachedVarOperation, FunctionVar): """Base class for immutable function defined via arguments and return expression.""" @@ -407,7 +407,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar): """Base class for immutable function defined via arguments and return expression with the builder pattern.""" diff --git a/reflex/vars/number.py b/reflex/vars/number.py index 050dc2329..35a55490a 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses import json import math -import sys from typing import ( TYPE_CHECKING, Any, @@ -160,7 +159,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ from .sequence import ArrayVar, LiteralArrayVar - if isinstance(other, (list, tuple, set, ArrayVar)): + if isinstance(other, (list, tuple, ArrayVar)): if isinstance(other, ArrayVar): return other * self return LiteralArrayVar.create(other) * self @@ -187,7 +186,7 @@ class NumberVar(Var[NUMBER_T], python_types=(int, float)): """ from .sequence import ArrayVar, LiteralArrayVar - if isinstance(other, (list, tuple, set, ArrayVar)): + if isinstance(other, (list, tuple, ArrayVar)): if isinstance(other, ArrayVar): return other * self return LiteralArrayVar.create(other) * self @@ -973,7 +972,7 @@ def boolean_not_operation(value: BooleanVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralNumberVar(LiteralVar, NumberVar): """Base class for immutable literal number vars.""" @@ -1032,7 +1031,7 @@ class LiteralNumberVar(LiteralVar, NumberVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralBooleanVar(LiteralVar, BooleanVar): """Base class for immutable literal boolean vars.""" diff --git a/reflex/vars/object.py b/reflex/vars/object.py index ed4221e4c..cb29cabfb 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses -import sys import typing from inspect import isclass from typing import ( @@ -167,12 +166,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): key: Var | Any, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... - @overload - def __getitem__( - self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], - key: Var | Any, - ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... - @overload def __getitem__( self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], @@ -229,12 +222,6 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): name: str, ) -> ArrayVar[list[ARRAY_INNER_TYPE]]: ... - @overload - def __getattr__( - self: ObjectVar[Mapping[Any, set[ARRAY_INNER_TYPE]]], - name: str, - ) -> ArrayVar[set[ARRAY_INNER_TYPE]]: ... - @overload def __getattr__( self: ObjectVar[Mapping[Any, tuple[ARRAY_INNER_TYPE, ...]]], @@ -305,7 +292,7 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=Mapping): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): """Base class for immutable literal object vars.""" @@ -355,17 +342,20 @@ class LiteralObjectVar(CachedVarOperation, ObjectVar[OBJECT_TYPE], LiteralVar): Returns: The JSON representation of the object. + + Raises: + TypeError: The keys and values of the object must be literal vars to get the JSON representation """ - return ( - "{" - + ", ".join( - [ - f"{LiteralVar.create(key).json()}:{LiteralVar.create(value).json()}" - for key, value in self._var_value.items() - ] - ) - + "}" - ) + keys_and_values = [] + for key, value in self._var_value.items(): + key = LiteralVar.create(key) + value = LiteralVar.create(value) + if not isinstance(key, LiteralVar) or not isinstance(value, LiteralVar): + raise TypeError( + "The keys and values of the object must be literal vars to get the JSON representation." + ) + keys_and_values.append(f"{key.json()}:{value.json()}") + return "{" + ", ".join(keys_and_values) + "}" def __hash__(self) -> int: """Get the hash of the var. @@ -487,7 +477,7 @@ def object_merge_operation(lhs: ObjectVar, rhs: ObjectVar): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ObjectItemOperation(CachedVarOperation, Var): """Operation to get an item from an object.""" diff --git a/reflex/vars/sequence.py b/reflex/vars/sequence.py index f7a9958f5..dfd9a6af8 100644 --- a/reflex/vars/sequence.py +++ b/reflex/vars/sequence.py @@ -6,7 +6,6 @@ import dataclasses import inspect import json import re -import sys import typing from typing import ( TYPE_CHECKING, @@ -15,7 +14,7 @@ from typing import ( List, Literal, NoReturn, - Set, + Sequence, Tuple, Type, Union, @@ -596,7 +595,7 @@ _decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralStringVar(LiteralVar, StringVar[str]): """Base class for immutable literal string vars.""" @@ -718,7 +717,7 @@ class LiteralStringVar(LiteralVar, StringVar[str]): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ConcatVarOperation(CachedVarOperation, StringVar[str]): """Representing a concatenation of literal string vars.""" @@ -794,7 +793,8 @@ class ConcatVarOperation(CachedVarOperation, StringVar[str]): ) -ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Union[List, Tuple, Set]) +ARRAY_VAR_TYPE = TypeVar("ARRAY_VAR_TYPE", bound=Sequence, covariant=True) +OTHER_ARRAY_VAR_TYPE = TypeVar("OTHER_ARRAY_VAR_TYPE", bound=Sequence) OTHER_TUPLE = TypeVar("OTHER_TUPLE") @@ -887,6 +887,11 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): i: Literal[0, -2], ) -> NumberVar: ... + @overload + def __getitem__( + self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] + ) -> BooleanVar: ... + @overload def __getitem__( self: ( @@ -914,7 +919,7 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): @overload def __getitem__( - self: ArrayVar[Tuple[Any, bool]], i: Literal[1, -1] + self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar ) -> BooleanVar: ... @overload @@ -932,23 +937,12 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): self: ARRAY_VAR_OF_LIST_ELEMENT[str], i: int | NumberVar ) -> StringVar: ... - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[bool], i: int | NumberVar - ) -> BooleanVar: ... - @overload def __getitem__( self: ARRAY_VAR_OF_LIST_ELEMENT[List[INNER_ARRAY_VAR]], i: int | NumberVar, ) -> ArrayVar[List[INNER_ARRAY_VAR]]: ... - @overload - def __getitem__( - self: ARRAY_VAR_OF_LIST_ELEMENT[Set[INNER_ARRAY_VAR]], - i: int | NumberVar, - ) -> ArrayVar[Set[INNER_ARRAY_VAR]]: ... - @overload def __getitem__( self: ARRAY_VAR_OF_LIST_ELEMENT[Tuple[KEY_TYPE, VALUE_TYPE]], @@ -1239,26 +1233,18 @@ class ArrayVar(Var[ARRAY_VAR_TYPE], python_types=(list, tuple, set)): LIST_ELEMENT = TypeVar("LIST_ELEMENT") -ARRAY_VAR_OF_LIST_ELEMENT = Union[ - ArrayVar[List[LIST_ELEMENT]], - ArrayVar[Set[LIST_ELEMENT]], - ArrayVar[Tuple[LIST_ELEMENT, ...]], -] +ARRAY_VAR_OF_LIST_ELEMENT = ArrayVar[Sequence[LIST_ELEMENT]] @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): """Base class for immutable literal array vars.""" - _var_value: Union[ - List[Union[Var, Any]], - Set[Union[Var, Any]], - Tuple[Union[Var, Any], ...], - ] = dataclasses.field(default_factory=list) + _var_value: Sequence[Union[Var, Any]] = dataclasses.field(default=()) @cached_property_no_lock def _cached_var_name(self) -> str: @@ -1303,22 +1289,28 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): Returns: The JSON representation of the var. + + Raises: + TypeError: If the array elements are not of type LiteralVar. """ - return ( - "[" - + ", ".join( - [LiteralVar.create(element).json() for element in self._var_value] - ) - + "]" - ) + elements = [] + for element in self._var_value: + element_var = LiteralVar.create(element) + if not isinstance(element_var, LiteralVar): + raise TypeError( + f"Array elements must be of type LiteralVar, not {type(element_var)}" + ) + elements.append(element_var.json()) + + return "[" + ", ".join(elements) + "]" @classmethod def create( cls, - value: ARRAY_VAR_TYPE, - _var_type: Type[ARRAY_VAR_TYPE] | None = None, + value: OTHER_ARRAY_VAR_TYPE, + _var_type: Type[OTHER_ARRAY_VAR_TYPE] | None = None, _var_data: VarData | None = None, - ) -> LiteralArrayVar[ARRAY_VAR_TYPE]: + ) -> LiteralArrayVar[OTHER_ARRAY_VAR_TYPE]: """Create a var from a string value. Args: @@ -1329,7 +1321,7 @@ class LiteralArrayVar(CachedVarOperation, LiteralVar, ArrayVar[ARRAY_VAR_TYPE]): Returns: The var. """ - return cls( + return LiteralArrayVar( _js_expr="", _var_type=figure_out_type(value) if _var_type is None else _var_type, _var_data=_var_data, @@ -1356,7 +1348,7 @@ def string_split_operation(string: StringVar[Any], sep: StringVar | str = ""): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class ArraySliceOperation(CachedVarOperation, ArrayVar): """Base class for immutable string vars that are the result of a string slice operation.""" @@ -1705,7 +1697,7 @@ class ColorVar(StringVar[Color], python_types=Color): @dataclasses.dataclass( eq=False, frozen=True, - **{"slots": True} if sys.version_info >= (3, 10) else {}, + slots=True, ) class LiteralColorVar(CachedVarOperation, LiteralVar, ColorVar): """Base class for immutable literal color vars.""" diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index 47652cd43..11602b77a 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -3,6 +3,7 @@ from typing import List, Mapping, Tuple import pytest import reflex as rx +from reflex.components.component import Component from reflex.components.core.match import Match from reflex.state import BaseState from reflex.utils.exceptions import MatchTypeError @@ -29,6 +30,8 @@ def test_match_components(): rx.text("default value"), ) match_comp = Match.create(MatchState.value, *match_case_tuples) + + assert isinstance(match_comp, Component) match_dict = match_comp.render() assert match_dict["name"] == "Fragment" @@ -151,6 +154,7 @@ def test_match_on_component_without_default(): ) match_comp = Match.create(MatchState.value, *match_case_tuples) + assert isinstance(match_comp, Component) default = match_comp.render()["children"][0]["default"] assert isinstance(default, dict) and default["name"] == Fragment.__name__ diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 26e530f7c..8cffa6e0e 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -36,6 +36,7 @@ from reflex.utils.exceptions import ( from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var +from reflex.vars.object import ObjectVar @pytest.fixture @@ -842,12 +843,12 @@ def test_component_event_trigger_arbitrary_args(): """Test that we can define arbitrary types for the args of an event trigger.""" def on_foo_spec( - _e: Var[JavascriptInputEvent], + _e: ObjectVar[JavascriptInputEvent], alpha: Var[str], bravo: dict[str, Any], - charlie: Var[_Obj], + charlie: ObjectVar[_Obj], ): - return [_e.target.value, bravo["nested"], charlie.custom + 42] + return [_e.target.value, bravo["nested"], charlie.custom.to(int) + 42] class C1(Component): library = "/local" @@ -1328,7 +1329,7 @@ class EventState(rx.State): ), pytest.param( rx.fragment(class_name=[TEST_VAR, "other-class"]), - [LiteralVar.create([TEST_VAR, "other-class"]).join(" ")], + [Var.create([TEST_VAR, "other-class"]).join(" ")], id="fstring-dual-class_name", ), pytest.param( diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 074e7f2ef..4a6c16d6e 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -471,15 +471,15 @@ async def test_dynamic_var_event(test_state: Type[ATestState], token: str): """ state = test_state() # pyright: ignore [reportCallIssue] state.add_var("int_val", int, 0) - result = await state._process( + async for result in state._process( Event( token=token, name=f"{test_state.get_name()}.set_int_val", router_data={"pathname": "/", "query": {}}, payload={"value": 50}, ) - ).__anext__() - assert result.delta == {test_state.get_name(): {"int_val": 50}} + ): + assert result.delta == {test_state.get_name(): {"int_val": 50}} @pytest.mark.asyncio @@ -583,18 +583,17 @@ async def test_list_mutation_detection__plain_list( token: a Token. """ for event_name, expected_delta in event_tuples: - result = await list_mutation_state._process( + async for result in list_mutation_state._process( Event( token=token, name=f"{list_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) - ).__anext__() - - # prefix keys in expected_delta with the state name - expected_delta = {list_mutation_state.get_name(): expected_delta} - assert result.delta == expected_delta + ): + # prefix keys in expected_delta with the state name + expected_delta = {list_mutation_state.get_name(): expected_delta} + assert result.delta == expected_delta @pytest.mark.asyncio @@ -709,19 +708,18 @@ async def test_dict_mutation_detection__plain_list( token: a Token. """ for event_name, expected_delta in event_tuples: - result = await dict_mutation_state._process( + async for result in dict_mutation_state._process( Event( token=token, name=f"{dict_mutation_state.get_name()}.{event_name}", router_data={"pathname": "/", "query": {}}, payload={}, ) - ).__anext__() + ): + # prefix keys in expected_delta with the state name + expected_delta = {dict_mutation_state.get_name(): expected_delta} - # prefix keys in expected_delta with the state name - expected_delta = {dict_mutation_state.get_name(): expected_delta} - - assert result.delta == expected_delta + assert result.delta == expected_delta @pytest.mark.asyncio diff --git a/tests/units/test_state.py b/tests/units/test_state.py index b276bad4b..9e1932305 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -789,17 +789,16 @@ async def test_process_event_simple(test_state): assert test_state.num1 == 0 event = Event(token="t", name="set_num1", payload={"value": 69}) - update = await test_state._process(event).__anext__() + async for update in test_state._process(event): + # The event should update the value. + assert test_state.num1 == 69 - # The event should update the value. - assert test_state.num1 == 69 - - # The delta should contain the changes, including computed vars. - assert update.delta == { - TestState.get_full_name(): {"num1": 69, "sum": 72.14}, - GrandchildState3.get_full_name(): {"computed": ""}, - } - assert update.events == [] + # The delta should contain the changes, including computed vars. + assert update.delta == { + TestState.get_full_name(): {"num1": 69, "sum": 72.14}, + GrandchildState3.get_full_name(): {"computed": ""}, + } + assert update.events == [] @pytest.mark.asyncio @@ -819,15 +818,15 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) name=f"{ChildState.get_name()}.change_both", payload={"value": "hi", "count": 12}, ) - update = await test_state._process(event).__anext__() - assert child_state.value == "HI" - assert child_state.count == 24 - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - ChildState.get_full_name(): {"value": "HI", "count": 24}, - GrandchildState3.get_full_name(): {"computed": ""}, - } - test_state._clean() + async for update in test_state._process(event): + assert child_state.value == "HI" + assert child_state.count == 24 + assert update.delta == { + # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + ChildState.get_full_name(): {"value": "HI", "count": 24}, + GrandchildState3.get_full_name(): {"computed": ""}, + } + test_state._clean() # Test with the granchild state. assert grandchild_state.value2 == "" @@ -836,13 +835,13 @@ async def test_process_event_substate(test_state, child_state, grandchild_state) name=f"{GrandchildState.get_full_name()}.set_value2", payload={"value": "new"}, ) - update = await test_state._process(event).__anext__() - assert grandchild_state.value2 == "new" - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - GrandchildState.get_full_name(): {"value2": "new"}, - GrandchildState3.get_full_name(): {"computed": ""}, - } + async for update in test_state._process(event): + assert grandchild_state.value2 == "new" + assert update.delta == { + # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, + GrandchildState.get_full_name(): {"value2": "new"}, + GrandchildState3.get_full_name(): {"computed": ""}, + } @pytest.mark.asyncio @@ -2909,10 +2908,10 @@ async def test_preprocess(app_module_mock, token, test_state, expected, mocker): events = updates[0].events assert len(events) == 2 - assert (await state._process(events[0]).__anext__()).delta == { - test_state.get_full_name(): {"num": 1} - } - assert (await state._process(events[1]).__anext__()).delta == exp_is_hydrated(state) + async for update in state._process(events[0]): + assert update.delta == {test_state.get_full_name(): {"num": 1}} + async for update in state._process(events[1]): + assert update.delta == exp_is_hydrated(state) if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @@ -2957,13 +2956,12 @@ async def test_preprocess_multiple_load_events(app_module_mock, token, mocker): events = updates[0].events assert len(events) == 3 - assert (await state._process(events[0]).__anext__()).delta == { - OnLoadState.get_full_name(): {"num": 1} - } - assert (await state._process(events[1]).__anext__()).delta == { - OnLoadState.get_full_name(): {"num": 2} - } - assert (await state._process(events[2]).__anext__()).delta == exp_is_hydrated(state) + async for update in state._process(events[0]): + assert update.delta == {OnLoadState.get_full_name(): {"num": 1}} + async for update in state._process(events[1]): + assert update.delta == {OnLoadState.get_full_name(): {"num": 2}} + async for update in state._process(events[2]): + assert update.delta == exp_is_hydrated(state) if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() diff --git a/tests/units/test_var.py b/tests/units/test_var.py index a5cd56a91..ef19e86e8 100644 --- a/tests/units/test_var.py +++ b/tests/units/test_var.py @@ -1,6 +1,5 @@ import json import math -import sys import typing from typing import Dict, List, Mapping, Optional, Set, Tuple, Union, cast @@ -422,19 +421,13 @@ class Bar(rx.Base): @pytest.mark.parametrize( ("var", "var_type"), - ( - [ - (Var(_js_expr="", _var_type=Foo | Bar).guess_type(), Foo | Bar), - (Var(_js_expr="", _var_type=Foo | Bar).guess_type().bar, Union[int, str]), - ] - if sys.version_info >= (3, 10) - else [] - ) - + [ - (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type(), Union[Foo, Bar]), - (Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().baz, str), + [ + (Var(_js_expr="").to(Foo | Bar), Foo | Bar), + (Var(_js_expr="").to(Foo | Bar).bar, Union[int, str]), + (Var(_js_expr="").to(Union[Foo, Bar]), Union[Foo, Bar]), + (Var(_js_expr="").to(Union[Foo, Bar]).baz, str), ( - Var(_js_expr="", _var_type=Union[Foo, Bar]).guess_type().foo, + Var(_js_expr="").to(Union[Foo, Bar]).foo, Union[int, None], ), ], @@ -1358,7 +1351,7 @@ def test_unsupported_types_for_contains(var: Var): var: The base var. """ with pytest.raises(TypeError) as err: - assert var.contains(1) + assert var.contains(1) # pyright: ignore [reportAttributeAccessIssue] assert ( err.value.args[0] == f"Var of type {var._var_type} does not support contains check." @@ -1388,7 +1381,7 @@ def test_unsupported_types_for_string_contains(other): def test_unsupported_default_contains(): with pytest.raises(TypeError) as err: - assert 1 in Var(_js_expr="var", _var_type=str).guess_type() + assert 1 in Var(_js_expr="var", _var_type=str).guess_type() # pyright: ignore [reportOperatorIssue] assert ( err.value.args[0] == "'in' operator not supported for Var types, use Var.contains() instead."