diff --git a/poetry.lock b/poetry.lock index e22568c60..a36f44fd1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2813,13 +2813,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", [[package]] name = "virtualenv" -version = "20.28.1" +version = "20.29.1" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.8" files = [ - {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"}, - {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"}, + {file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"}, + {file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"}, ] [package.dependencies] @@ -3063,4 +3063,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ac7633388e00416af61c93e16c61c3b4a446e406afab2758585088e2b3416eee" +content-hash = "ccd6d6b00fdcf40562854380fafdb18c990b7f6a4f2883b33aaeb0351fcdbc06" diff --git a/pyproject.toml b/pyproject.toml index 67bcef1cb..b64284ddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1" gunicorn = ">=20.1.0,<24.0" jinja2 = ">=3.1.2,<4.0" psutil = ">=5.9.4,<7.0" -pydantic = ">=1.10.2,<3.0" +pydantic = ">=1.10.15,<3.0" python-multipart = ">=0.0.5,<0.1" python-socketio = ">=5.7.0,<6.0" redis = ">=4.3.5,<6.0" @@ -82,6 +82,7 @@ build-backend = "poetry.core.masonry.api" [tool.pyright] reportIncompatibleMethodOverride = false +reportIncompatibleVariableOverride = false [tool.ruff] target-version = "py39" diff --git a/reflex/base.py b/reflex/base.py index a88e557ef..1f0bbe00d 100644 --- a/reflex/base.py +++ b/reflex/base.py @@ -5,15 +5,9 @@ from __future__ import annotations import os from typing import TYPE_CHECKING, Any, List, Type -try: - import pydantic.v1.main as pydantic_main - from pydantic.v1 import BaseModel - from pydantic.v1.fields import ModelField -except ModuleNotFoundError: - if not TYPE_CHECKING: - import pydantic.main as pydantic_main - from pydantic import BaseModel - from pydantic.fields import ModelField # type: ignore +import pydantic.v1.main as pydantic_main +from pydantic.v1 import BaseModel +from pydantic.v1.fields import ModelField def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None: diff --git a/reflex/components/core/cond.py b/reflex/components/core/cond.py index 98ce605d6..e34201c74 100644 --- a/reflex/components/core/cond.py +++ b/reflex/components/core/cond.py @@ -113,11 +113,7 @@ class Cond(MemoizationLeaf): @overload -def cond(condition: Any, c1: Component, c2: Any) -> Component: ... - - -@overload -def cond(condition: Any, c1: Component) -> Component: ... +def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ... @overload diff --git a/reflex/components/core/match.py b/reflex/components/core/match.py index 2def33d65..a053b9abe 100644 --- a/reflex/components/core/match.py +++ b/reflex/components/core/match.py @@ -1,15 +1,17 @@ """rx.match.""" import textwrap -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, cast from reflex.components.base import Fragment from reflex.components.component import BaseComponent, Component, MemoizationLeaf from reflex.utils import types from reflex.utils.exceptions import MatchTypeError -from reflex.vars.base import Var +from reflex.vars.base import VAR_TYPE, Var from reflex.vars.number import MatchOperation +CASE_TYPE = tuple[*tuple[Any, ...], Var[VAR_TYPE] | VAR_TYPE] + class Match(MemoizationLeaf): """Match cases based on a condition.""" @@ -24,7 +26,11 @@ class Match(MemoizationLeaf): default: Any @classmethod - def create(cls, cond: Any, *cases) -> Union[Component, Var]: + def create( + cls, + cond: Any, + *cases: *tuple[*tuple[CASE_TYPE[VAR_TYPE], ...], Var[VAR_TYPE] | VAR_TYPE], + ) -> Var[VAR_TYPE]: """Create a Match Component. Args: @@ -37,44 +43,6 @@ class Match(MemoizationLeaf): Raises: ValueError: When a default case is not provided for cases with Var return types. """ - cases, default = cls._process_cases(cases) - cls._process_match_cases(cases) - - cls._validate_return_types(cases) - - if default is None and any( - not ( - isinstance((return_type := case[-1]), Component) - or ( - isinstance(return_type, Var) - and types.typehint_issubclass(return_type._var_type, Component) - ) - ) - for case in cases - ): - raise ValueError( - "For cases with return types as Vars, a default case must be provided" - ) - elif default is None: - default = Fragment.create() - - return cls._create_match_cond_var_or_component(cond, cases, default) - - @classmethod - def _process_cases( - cls, cases: Sequence - ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]: - """Process the list of match cases and the catchall default case. - - Args: - cases: The list of match cases. - - Returns: - The default case and the list of match case tuples. - - Raises: - ValueError: If there are multiple default cases. - """ default = None if len([case for case in cases if not isinstance(case, tuple)]) > 1: @@ -86,12 +54,40 @@ class Match(MemoizationLeaf): # Get the default case which should be the last non-tuple arg if not isinstance(cases[-1], tuple): default = cases[-1] - cases = cases[:-1] + actual_cases = cases[:-1] + else: + actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases) - return list(cases), default + cls._process_match_cases(actual_cases) + + cls._validate_return_types(actual_cases) + + if default is None and any( + not ( + isinstance((return_type := case[-1]), Component) + or ( + isinstance(return_type, Var) + and types.typehint_issubclass(return_type._var_type, Component) + ) + ) + for case in actual_cases + ): + raise ValueError( + "For cases with return types as Vars, a default case must be provided" + ) + elif default is None: + default = Fragment.create() + + default = cast(Var[VAR_TYPE] | VAR_TYPE, default) + + return cls._create_match_cond_var_or_component( + cond, + actual_cases, + default, + ) @classmethod - def _process_match_cases(cls, cases: Sequence): + def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]): """Process the individual match cases. Args: @@ -116,7 +112,9 @@ class Match(MemoizationLeaf): ) @classmethod - def _validate_return_types(cls, match_cases: List[List[Var]]) -> None: + def _validate_return_types( + cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...] + ) -> None: """Validate that match cases have the same return types. Args: @@ -151,9 +149,9 @@ class Match(MemoizationLeaf): def _create_match_cond_var_or_component( cls, match_cond_var: Var, - match_cases: List[List[Var]], - default: Union[Var, BaseComponent], - ) -> Union[Component, Var]: + match_cases: tuple[CASE_TYPE[VAR_TYPE], ...], + default: VAR_TYPE | Var[VAR_TYPE], + ) -> Var[VAR_TYPE]: """Create and return the match condition var or component. Args: diff --git a/reflex/components/datadisplay/dataeditor.py b/reflex/components/datadisplay/dataeditor.py index f71f97713..b82196fe2 100644 --- a/reflex/components/datadisplay/dataeditor.py +++ b/reflex/components/datadisplay/dataeditor.py @@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent): # Fired when editing is finished. on_finished_editing: EventHandler[ - passthrough_event_spec(Union[GridCell, None], tuple[int, int]) + passthrough_event_spec(Union[GridCell, None], tuple[int, int]) # pyright: ignore[reportArgumentType] ] # Fired when a row is appended. diff --git a/reflex/components/radix/primitives/accordion.py b/reflex/components/radix/primitives/accordion.py index 0ba618e21..2964b6fa5 100644 --- a/reflex/components/radix/primitives/accordion.py +++ b/reflex/components/radix/primitives/accordion.py @@ -197,7 +197,7 @@ class AccordionItem(AccordionComponent): # The header of the accordion item. header: Var[Union[Component, str]] # The content of the accordion item. - content: Var[Union[Component, str]] = Var.create(None) + content: Var[Union[Component, str]] = Var.create("") _valid_children: List[str] = [ "AccordionHeader", diff --git a/reflex/components/tags/iter_tag.py b/reflex/components/tags/iter_tag.py index 3f7aa47f2..f3d9c4d8f 100644 --- a/reflex/components/tags/iter_tag.py +++ b/reflex/components/tags/iter_tag.py @@ -4,9 +4,10 @@ from __future__ import annotations import dataclasses import inspect -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args from reflex.components.tags.tag import Tag +from reflex.utils import types from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name if TYPE_CHECKING: @@ -31,7 +32,7 @@ class IterTag(Tag): # The name of the index var. index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) - def get_iterable_var_type(self) -> Type: + def get_iterable_var_type(self) -> types.GenericType: """Get the type of the iterable var. Returns: @@ -41,10 +42,10 @@ class IterTag(Tag): try: if iterable._var_type.mro()[0] is dict: # Arg is a tuple of (key, value). - return Tuple[get_args(iterable._var_type)] # type: ignore + return Tuple[get_args(iterable._var_type)] elif iterable._var_type.mro()[0] is tuple: # Arg is a union of any possible values in the tuple. - return Union[get_args(iterable._var_type)] # type: ignore + return Union[get_args(iterable._var_type)] else: return get_args(iterable._var_type)[0] except Exception: diff --git a/reflex/event.py b/reflex/event.py index 28852fde5..dee33bced 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -25,7 +25,6 @@ from typing import ( overload, ) -import typing_extensions from typing_extensions import ( Concatenate, ParamSpec, @@ -33,6 +32,8 @@ from typing_extensions import ( TypeAliasType, TypedDict, TypeVar, + TypeVarTuple, + deprecated, get_args, get_origin, ) @@ -620,14 +621,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default -T = TypeVar("T") -U = TypeVar("U") +EVENT_T = TypeVar("EVENT_T") +EVENT_U = TypeVar("EVENT_U") + +Ts = TypeVarTuple("Ts") -class IdentityEventReturn(Generic[T], Protocol): +class IdentityEventReturn(Generic[*Ts], Protocol): """Protocol for an identity event return.""" - def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]: + def __call__(self, *values: *Ts) -> tuple[*Ts]: """Return the input values. Args: @@ -641,21 +644,25 @@ class IdentityEventReturn(Generic[T], Protocol): @overload def passthrough_event_spec( - event_type: Type[T], / -) -> Callable[[Var[T]], Tuple[Var[T]]]: ... # type: ignore + event_type: Type[EVENT_T], / +) -> IdentityEventReturn[Var[EVENT_T]]: ... @overload def passthrough_event_spec( - event_type_1: Type[T], event_type2: Type[U], / -) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ... + event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], / +) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ... @overload -def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ... +def passthrough_event_spec( + *event_types: *tuple[Type[EVENT_T]], +) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: ... -def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # type: ignore +def passthrough_event_spec( # pyright: ignore[reportInconsistentOverload] + *event_types: Type[EVENT_T], +) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: """A helper function that returns the input event as output. Args: @@ -665,7 +672,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # A function that returns the input event as output. """ - def inner(*values: Var[T]) -> Tuple[Var[T], ...]: + def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]: return values inner_type = tuple(Var[event_type] for event_type in event_types) @@ -800,7 +807,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec: return None fn.__qualname__ = name - fn.__signature__ = sig + fn.__signature__ = sig # pyright: ignore[reportFunctionMemberAccess] return EventSpec( handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE), args=tuple( @@ -822,7 +829,7 @@ def redirect( @overload -@typing_extensions.deprecated("`external` is deprecated use `is_external` instead") +@deprecated("`external` is deprecated use `is_external` instead") def redirect( path: str | Var[str], is_external: Optional[bool] = None, @@ -1826,6 +1833,37 @@ class EventCallback(Generic[P, T]): """ self.func = func + def throttle(self, limit_ms: int): + """Throttle the event handler. + + Args: + limit_ms: The time in milliseconds to throttle the event handler. + + Returns: + New EventHandler-like with throttle set to limit_ms. + """ + return self + + def debounce(self, delay_ms: int): + """Debounce the event handler. + + Args: + delay_ms: The time in milliseconds to debounce the event handler. + + Returns: + New EventHandler-like with debounce set to delay_ms. + """ + return self + + @property + def temporal(self): + """Do not queue the event if the backend is down. + + Returns: + New EventHandler-like with temporal set to True. + """ + return self + @property def prevent_default(self): """Prevent default behavior. diff --git a/reflex/state.py b/reflex/state.py index f9931f920..1f7339d5c 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -587,8 +587,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if cls._item_is_event_handler(name, fn) } - for mixin in cls._mixins(): - for name, value in mixin.__dict__.items(): + for mixin_class in cls._mixins(): + for name, value in mixin_class.__dict__.items(): if name in cls.inherited_vars: continue if is_computed_var(value): @@ -599,7 +599,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls.computed_vars[newcv._js_expr] = newcv cls.vars[newcv._js_expr] = newcv continue - if types.is_backend_base_variable(name, mixin): + if types.is_backend_base_variable(name, mixin_class): cls.backend_vars[name] = copy.deepcopy(value) continue if events.get(name) is not None: @@ -3710,6 +3710,9 @@ def get_state_manager() -> StateManager: return app.state_manager +DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__") + + class MutableProxy(wrapt.ObjectProxy): """A proxy for a mutable object that tracks changes.""" @@ -3781,12 +3784,7 @@ class MutableProxy(wrapt.ObjectProxy): cls.__dataclass_proxies__[wrapper_cls_name] = type( wrapper_cls_name, (cls,), - { - dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues] - wrapped_cls, - dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues] - ), - }, + {DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)}, ) cls = cls.__dataclass_proxies__[wrapper_cls_name] return super().__new__(cls) @@ -3933,11 +3931,11 @@ class MutableProxy(wrapt.ObjectProxy): if ( isinstance(self.__wrapped__, Base) and __name not in self.__never_wrap_base_attrs__ - and hasattr(value, "__func__") + and (value_func := getattr(value, "__func__", None)) ): # Wrap methods called on Base subclasses, which might do _anything_ return wrapt.FunctionWrapper( - functools.partial(value.__func__, self), + functools.partial(value_func, self), self._wrap_recursive_decorator, ) diff --git a/reflex/testing.py b/reflex/testing.py index b3dedf398..68b4e99fd 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -67,10 +67,8 @@ try: from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports] WebElement, ) - - has_selenium = True except ImportError: - has_selenium = False + webdriver = None # The timeout (minutes) to check for the port. DEFAULT_TIMEOUT = 15 @@ -293,8 +291,12 @@ class AppHarness: if p not in before_decorated_pages ] self.app_instance = self.app_module.app + if self.app_instance is None: + raise RuntimeError("App was not initialized.") if isinstance(self.app_instance._state_manager, StateManagerRedis): # Create our own redis connection for testing. + if self.app_instance.state is None: + raise RuntimeError("App state is not initialized.") self.state_manager = StateManagerRedis.create(self.app_instance.state) else: self.state_manager = self.app_instance._state_manager @@ -608,7 +610,7 @@ class AppHarness: Raises: RuntimeError: when selenium is not importable or frontend is not running """ - if not has_selenium: + if webdriver is None: raise RuntimeError( "Frontend functionality requires `selenium` to be installed, " "and it could not be imported." diff --git a/reflex/utils/console.py b/reflex/utils/console.py index 8929b63b6..fc32ac73e 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -203,10 +203,13 @@ def _get_first_non_framework_frame() -> FrameType | None: # Exclude utility modules that should never be the source of deprecated reflex usage. exclude_modules = [click, rx, typer, typing_extensions] exclude_roots = [ - p.parent.resolve() - if (p := Path(m.__file__)).name == "__init__.py" - else p.resolve() + ( + p.parent.resolve() + if (p := Path(m.__file__)).name == "__init__.py" + else p.resolve() + ) for m in exclude_modules + if m.__file__ ] # Specifically exclude the reflex cli module. if reflex_bin := shutil.which(b"reflex"): diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 82b169db4..cf92efce3 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -3197,17 +3197,20 @@ class Field(Generic[T]): def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... @overload - def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... + def __get__(self: Field[int], instance: None, owner) -> NumberVar[int]: ... @overload - def __get__(self: Field[str], instance: None, owner) -> StringVar: ... + def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ... + + @overload + def __get__(self: Field[str], instance: None, owner) -> StringVar[str]: ... @overload def __get__(self: Field[None], instance: None, owner) -> NoneVar: ... @overload def __get__( - self: Field[Sequence[V]] | Field[Set[V]], + self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]], instance: None, owner, ) -> ArrayVar[Sequence[V]]: ... diff --git a/reflex/vars/number.py b/reflex/vars/number.py index 2d0fb3326..1e85a1c74 100644 --- a/reflex/vars/number.py +++ b/reflex/vars/number.py @@ -1069,24 +1069,11 @@ def ternary_operation( return value -TUPLE_ENDS_IN_VAR = ( - tuple[Var[VAR_TYPE]] - | tuple[Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]] - | tuple[ - Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE] - ] -) +X = tuple[*tuple[Var, ...], str] + +TUPLE_ENDS_IN_VAR = tuple[*tuple[Var[Any], ...], Var[VAR_TYPE]] + +TUPLE_ENDS_IN_VAR_RELAXED = tuple[*tuple[Var[Any] | Any, ...], Var[VAR_TYPE] | VAR_TYPE] @dataclasses.dataclass( @@ -1153,7 +1140,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]): def create( cls, cond: Any, - cases: Sequence[Sequence[Any | Var[VAR_TYPE]]], + cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]], default: Var[VAR_TYPE] | VAR_TYPE, _var_data: VarData | None = None, _var_type: type[VAR_TYPE] | None = None, @@ -1175,6 +1162,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]): tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...], tuple(tuple(Var.create(c) for c in case) for case in cases), ) + _default = cast(Var[VAR_TYPE], Var.create(default)) var_type = _var_type or unionize( *(case[-1]._var_type for case in cases), diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 7c793a0fb..0c3c001c9 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -45,7 +45,7 @@ from .base import ( from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .sequence import ArrayVar, StringVar -OBJECT_TYPE = TypeVar("OBJECT_TYPE") +OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True) KEY_TYPE = TypeVar("KEY_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE") @@ -164,7 +164,8 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict): @overload def __getitem__( - self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]], + self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]] + | ObjectVar[Dict[Any, List[ARRAY_INNER_TYPE]]], key: Var | Any, ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ... diff --git a/tests/integration/test_event_actions.py b/tests/integration/test_event_actions.py index 15f3c9877..084a8c4fa 100644 --- a/tests/integration/test_event_actions.py +++ b/tests/integration/test_event_actions.py @@ -28,9 +28,11 @@ def TestEventAction(): def on_click2(self): self.order.append("on_click2") + @rx.event def on_click_throttle(self): self.order.append("on_click_throttle") + @rx.event def on_click_debounce(self): self.order.append("on_click_debounce") diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index 0fa4a7e92..98d8addfa 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -22,9 +22,9 @@ def LifespanApp(): @asynccontextmanager async def lifespan_context(app, inc: int = 1): - global lifespan_context_global + nonlocal lifespan_context_global print(f"Lifespan context entered: {app}.") - lifespan_context_global += inc # pyright: ignore[reportUnboundVariable] + lifespan_context_global += inc try: yield finally: @@ -32,11 +32,11 @@ def LifespanApp(): lifespan_context_global += inc async def lifespan_task(inc: int = 1): - global lifespan_task_global + nonlocal lifespan_task_global print("Lifespan global started.") try: while True: - lifespan_task_global += inc # pyright: ignore[reportUnboundVariable] + lifespan_task_global += inc await asyncio.sleep(0.1) except asyncio.CancelledError as ce: print(f"Lifespan global cancelled: {ce}.") diff --git a/tests/integration/test_var_operations.py b/tests/integration/test_var_operations.py index 7a7c8328d..fff4c975c 100644 --- a/tests/integration/test_var_operations.py +++ b/tests/integration/test_var_operations.py @@ -19,25 +19,27 @@ def VarOperations(): from reflex.vars.sequence import ArrayVar class Object(rx.Base): - str: str = "hello" + name: str = "hello" class VarOperationState(rx.State): - int_var1: int = 10 - int_var2: int = 5 - int_var3: int = 7 - float_var1: float = 10.5 - float_var2: float = 5.5 - list1: List = [1, 2] - list2: List = [3, 4] - list3: List = ["first", "second", "third"] - list4: List = [Object(name="obj_1"), Object(name="obj_2")] - str_var1: str = "first" - str_var2: str = "second" - str_var3: str = "ThIrD" - str_var4: str = "a long string" - dict1: Dict[int, int] = {1: 2} - dict2: Dict[int, int] = {3: 4} - html_str: str = "
hello
" + int_var1: rx.Field[int] = rx.field(10) + int_var2: rx.Field[int] = rx.field(5) + int_var3: rx.Field[int] = rx.field(7) + float_var1: rx.Field[float] = rx.field(10.5) + float_var2: rx.Field[float] = rx.field(5.5) + list1: rx.Field[List[int]] = rx.field([1, 2]) + list2: rx.Field[List[int]] = rx.field([3, 4]) + list3: rx.Field[List[str]] = rx.field(["first", "second", "third"]) + list4: rx.Field[List[Object]] = rx.field( + [Object(name="obj_1"), Object(name="obj_2")] + ) + str_var1: rx.Field[str] = rx.field("first") + str_var2: rx.Field[str] = rx.field("second") + str_var3: rx.Field[str] = rx.field("ThIrD") + str_var4: rx.Field[str] = rx.field("a long string") + dict1: rx.Field[Dict[int, int]] = rx.field({1: 2}) + dict2: rx.Field[Dict[int, int]] = rx.field({3: 4}) + html_str: rx.Field[str] = rx.field("
hello
") app = rx.App(state=rx.State) diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index 71e19c7cc..06b108317 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -13,7 +13,7 @@ from reflex.vars.base import LiteralVar, Var, computed_var @pytest.fixture def cond_state(request): class CondState(BaseState): - value: request.param["value_type"] = request.param["value"] # noqa + value: request.param["value_type"] = request.param["value"] # pyright: ignore[reportInvalidTypeForm, reportUndefinedVariable] # noqa: F821 return CondState diff --git a/tests/units/components/core/test_match.py b/tests/units/components/core/test_match.py index 59111d183..83581a415 100644 --- a/tests/units/components/core/test_match.py +++ b/tests/units/components/core/test_match.py @@ -67,7 +67,7 @@ def test_match_vars(cases, expected): cases: The match cases. expected: The expected var full name. """ - match_comp = Match.create(MatchState.value, *cases) + match_comp = Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue] assert isinstance(match_comp, Var) assert str(match_comp) == expected @@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case): ValueError, match="rx.match should have tuples of cases and a default case as the last argument.", ): - Match.create(MatchState.value, *match_case) + Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case): ValueError, match="A case tuple should have at least a match case element and a return value.", ): - Match.create(MatchState.value, *match_case) + Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str): error_msg: Expected error message. """ with pytest.raises(MatchTypeError, match=error_msg): - Match.create(MatchState.value, *cases) + Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue] @pytest.mark.parametrize( @@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case): match_case: the cases to match. """ with pytest.raises(ValueError, match="rx.match can only have one default case."): - Match.create(MatchState.value, *match_case) + Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue] def test_match_no_cond(): with pytest.raises(ValueError): - _ = Match.create(None) + _ = Match.create(None) # pyright: ignore[reportCallIssue] diff --git a/tests/units/components/datadisplay/test_datatable.py b/tests/units/components/datadisplay/test_datatable.py index b3d31ea32..79dd233ba 100644 --- a/tests/units/components/datadisplay/test_datatable.py +++ b/tests/units/components/datadisplay/test_datatable.py @@ -13,7 +13,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe pytest.param( { "data": pd.DataFrame( - [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"] + [["foo", "bar"], ["foo1", "bar1"]], + columns=["column1", "column2"], # pyright: ignore [reportArgumentType] ) }, "data", @@ -113,7 +114,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram def test_serialize_dataframe(): """Test if dataframe is serialized correctly.""" df = pd.DataFrame( - [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"] + [["foo", "bar"], ["foo1", "bar1"]], + columns=["column1", "column2"], # pyright: ignore [reportArgumentType] ) value = serialize(df) assert value == serialize_dataframe(df) diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 48a4bdda1..b8ae06aae 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -9,7 +9,7 @@ import unittest.mock import uuid from contextlib import nullcontext as does_not_raise from pathlib import Path -from typing import Generator, List, Tuple, Type +from typing import Generator, List, Tuple, Type, cast from unittest.mock import AsyncMock import pytest @@ -33,7 +33,7 @@ from reflex.components import Component from reflex.components.base.fragment import Fragment from reflex.components.core.cond import Cond from reflex.components.radix.themes.typography.text import Text -from reflex.event import Event +from reflex.event import Event, EventHandler from reflex.middleware import HydrateMiddleware from reflex.model import Model from reflex.state import ( @@ -917,7 +917,7 @@ class DynamicState(BaseState): """ return self.dynamic - on_load_internal = OnLoadInternalState.on_load_internal.fn + on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn def test_dynamic_arg_shadow( @@ -1190,7 +1190,7 @@ async def test_process_events(mocker, token: str): pass assert (await app.state_manager.get_state(event.substate_token)).value == 5 - assert app._postprocess.call_count == 6 + assert getattr(app._postprocess, "call_count", None) == 6 if isinstance(app.state_manager, StateManagerRedis): await app.state_manager.close() @@ -1247,7 +1247,7 @@ def test_overlay_component( if exp_page_child is not None: assert len(page.children) == 3 - children_types = (type(child) for child in page.children) + children_types = [type(child) for child in page.children] assert exp_page_child in children_types else: assert len(page.children) == 2 diff --git a/tests/units/test_event.py b/tests/units/test_event.py index d7e993efa..bc827078d 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -5,6 +5,7 @@ import pytest import reflex as rx from reflex.event import ( Event, + EventActionsMixin, EventChain, EventHandler, EventSpec, @@ -410,6 +411,7 @@ def test_event_actions(): def test_event_actions_on_state(): class EventActionState(BaseState): + @rx.event def handler(self): pass @@ -418,6 +420,7 @@ def test_event_actions_on_state(): assert not handler.event_actions sp_handler = EventActionState.handler.stop_propagation + assert isinstance(sp_handler, EventActionsMixin) assert sp_handler.event_actions == {"stopPropagation": True} # should NOT affect other references to the handler assert not handler.event_actions diff --git a/tests/units/test_health_endpoint.py b/tests/units/test_health_endpoint.py index 6d12d79d6..abfa6cc62 100644 --- a/tests/units/test_health_endpoint.py +++ b/tests/units/test_health_endpoint.py @@ -122,9 +122,12 @@ async def test_health( # Call the async health function response = await health() - print(json.loads(response.body)) + body = response.body + assert isinstance(body, bytes) + + print(json.loads(body)) print(expected_status) # Verify the response content and status code assert response.status_code == expected_code - assert json.loads(response.body) == expected_status + assert json.loads(body) == expected_status diff --git a/tests/units/test_sqlalchemy.py b/tests/units/test_sqlalchemy.py index 23e315785..4434f5ee1 100644 --- a/tests/units/test_sqlalchemy.py +++ b/tests/units/test_sqlalchemy.py @@ -59,7 +59,7 @@ def test_automigration( id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None) # initial table - class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration] t1: Mapped[str] = mapped_column(default="") with Model.get_db_engine().connect() as connection: @@ -78,7 +78,7 @@ def test_automigration( model_registry.get_metadata().clear() # Create column t2, mark t1 as optional with default - class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration] t1: Mapped[Optional[str]] = mapped_column(default="default") t2: Mapped[str] = mapped_column(default="bar") @@ -98,7 +98,7 @@ def test_automigration( model_registry.get_metadata().clear() # Drop column t1 - class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration] t2: Mapped[str] = mapped_column(default="bar") assert Model.migrate(autogenerate=True) @@ -133,7 +133,7 @@ def test_automigration( # drop table (AlembicSecond) model_registry.get_metadata().clear() - class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] + class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration] t2: Mapped[str] = mapped_column(default="bar") assert Model.migrate(autogenerate=True) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 3898e4658..51b31e8e1 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -17,6 +17,7 @@ from typing import ( Dict, List, Optional, + Sequence, Set, Tuple, Union, @@ -120,8 +121,8 @@ class TestState(BaseState): num2: float = 3.14 key: str map_key: str = "a" - array: List[float] = [1, 2, 3.14] - mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]} + array: rx.Field[List[float]] = rx.field([1, 2, 3.14]) + mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]}) obj: Object = Object() complex: Dict[int, Object] = {1: Object(), 2: Object()} fig: Figure = Figure() @@ -1357,6 +1358,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): class HandlerState(BaseState): x: int = 42 + @rx.event def handler(self): self.x = self.x + 1 @@ -1367,11 +1369,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool): counter += 1 return counter + assert isinstance(HandlerState.handler, EventHandler) if use_partial: - HandlerState.handler = functools.partial(HandlerState.handler.fn) + partial_guy = functools.partial(HandlerState.handler.fn) + HandlerState.handler = partial_guy # pyright: ignore[reportAttributeAccessIssue] assert isinstance(HandlerState.handler, functools.partial) - else: - assert isinstance(HandlerState.handler, EventHandler) s = HandlerState() assert "cached_x_side_effect" in s._computed_var_dependencies["x"] @@ -2025,8 +2027,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App): # ensure state update was emitted assert mock_app.event_namespace is not None - mock_app.event_namespace.emit.assert_called_once() - mcall = mock_app.event_namespace.emit.mock_calls[0] + mock_app.event_namespace.emit.assert_called_once() # pyright: ignore[reportFunctionMemberAccess] + mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None) + assert mock_calls is not None + assert isinstance(mock_calls, Sequence) + mcall = mock_calls[0] assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[1] == StateUpdate( delta={ @@ -2231,7 +2236,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit - first_ws_message = emit_mock.mock_calls[0].args[1] + mock_calls = getattr(emit_mock, "mock_calls", None) + assert mock_calls is not None + assert isinstance(mock_calls, Sequence) + + first_ws_message = mock_calls[0].args[1] assert ( first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router") is not None @@ -2246,7 +2255,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): events=[], final=True, ) - for call in emit_mock.mock_calls[1:5]: + for call in mock_calls[1:5]: assert call.args[1] == StateUpdate( delta={ BackgroundTaskState.get_full_name(): { @@ -2256,7 +2265,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): events=[], final=True, ) - assert emit_mock.mock_calls[-2].args[1] == StateUpdate( + assert mock_calls[-2].args[1] == StateUpdate( delta={ BackgroundTaskState.get_full_name(): { "order": exp_order, @@ -2267,7 +2276,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): events=[], final=True, ) - assert emit_mock.mock_calls[-1].args[1] == StateUpdate( + assert mock_calls[-1].args[1] == StateUpdate( delta={ BackgroundTaskState.get_full_name(): { "computed_order": exp_order,