down to only two pyright error

This commit is contained in:
Khaleel Al-Adhami 2025-01-17 14:16:03 -08:00
parent 112b2ed948
commit 57d8ea02e9
26 changed files with 225 additions and 181 deletions

8
poetry.lock generated
View File

@ -2813,13 +2813,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
[[package]] [[package]]
name = "virtualenv" name = "virtualenv"
version = "20.28.1" version = "20.29.1"
description = "Virtual Python Environment builder" description = "Virtual Python Environment builder"
optional = false optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"}, {file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
{file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"}, {file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
] ]
[package.dependencies] [package.dependencies]
@ -3063,4 +3063,4 @@ type = ["pytest-mypy"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "ac7633388e00416af61c93e16c61c3b4a446e406afab2758585088e2b3416eee" content-hash = "ccd6d6b00fdcf40562854380fafdb18c990b7f6a4f2883b33aaeb0351fcdbc06"

View File

@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
gunicorn = ">=20.1.0,<24.0" gunicorn = ">=20.1.0,<24.0"
jinja2 = ">=3.1.2,<4.0" jinja2 = ">=3.1.2,<4.0"
psutil = ">=5.9.4,<7.0" psutil = ">=5.9.4,<7.0"
pydantic = ">=1.10.2,<3.0" pydantic = ">=1.10.15,<3.0"
python-multipart = ">=0.0.5,<0.1" python-multipart = ">=0.0.5,<0.1"
python-socketio = ">=5.7.0,<6.0" python-socketio = ">=5.7.0,<6.0"
redis = ">=4.3.5,<6.0" redis = ">=4.3.5,<6.0"
@ -82,6 +82,7 @@ build-backend = "poetry.core.masonry.api"
[tool.pyright] [tool.pyright]
reportIncompatibleMethodOverride = false reportIncompatibleMethodOverride = false
reportIncompatibleVariableOverride = false
[tool.ruff] [tool.ruff]
target-version = "py39" target-version = "py39"

View File

@ -5,15 +5,9 @@ from __future__ import annotations
import os import os
from typing import TYPE_CHECKING, Any, List, Type from typing import TYPE_CHECKING, Any, List, Type
try: import pydantic.v1.main as pydantic_main
import pydantic.v1.main as pydantic_main from pydantic.v1 import BaseModel
from pydantic.v1 import BaseModel from pydantic.v1.fields import ModelField
from pydantic.v1.fields import ModelField
except ModuleNotFoundError:
if not TYPE_CHECKING:
import pydantic.main as pydantic_main
from pydantic import BaseModel
from pydantic.fields import ModelField # type: ignore
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None: def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:

View File

@ -113,11 +113,7 @@ class Cond(MemoizationLeaf):
@overload @overload
def cond(condition: Any, c1: Component, c2: Any) -> Component: ... def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ...
@overload
def cond(condition: Any, c1: Component) -> Component: ...
@overload @overload

View File

@ -1,15 +1,17 @@
"""rx.match.""" """rx.match."""
import textwrap import textwrap
from typing import Any, List, Optional, Sequence, Tuple, Union from typing import Any, List, cast
from reflex.components.base import Fragment from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component, MemoizationLeaf from reflex.components.component import BaseComponent, Component, MemoizationLeaf
from reflex.utils import types from reflex.utils import types
from reflex.utils.exceptions import MatchTypeError from reflex.utils.exceptions import MatchTypeError
from reflex.vars.base import Var from reflex.vars.base import VAR_TYPE, Var
from reflex.vars.number import MatchOperation from reflex.vars.number import MatchOperation
CASE_TYPE = tuple[*tuple[Any, ...], Var[VAR_TYPE] | VAR_TYPE]
class Match(MemoizationLeaf): class Match(MemoizationLeaf):
"""Match cases based on a condition.""" """Match cases based on a condition."""
@ -24,7 +26,11 @@ class Match(MemoizationLeaf):
default: Any default: Any
@classmethod @classmethod
def create(cls, cond: Any, *cases) -> Union[Component, Var]: def create(
cls,
cond: Any,
*cases: *tuple[*tuple[CASE_TYPE[VAR_TYPE], ...], Var[VAR_TYPE] | VAR_TYPE],
) -> Var[VAR_TYPE]:
"""Create a Match Component. """Create a Match Component.
Args: Args:
@ -37,44 +43,6 @@ class Match(MemoizationLeaf):
Raises: Raises:
ValueError: When a default case is not provided for cases with Var return types. ValueError: When a default case is not provided for cases with Var return types.
""" """
cases, default = cls._process_cases(cases)
cls._process_match_cases(cases)
cls._validate_return_types(cases)
if default is None and any(
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
)
for case in cases
):
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
elif default is None:
default = Fragment.create()
return cls._create_match_cond_var_or_component(cond, cases, default)
@classmethod
def _process_cases(
cls, cases: Sequence
) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
"""Process the list of match cases and the catchall default case.
Args:
cases: The list of match cases.
Returns:
The default case and the list of match case tuples.
Raises:
ValueError: If there are multiple default cases.
"""
default = None default = None
if len([case for case in cases if not isinstance(case, tuple)]) > 1: if len([case for case in cases if not isinstance(case, tuple)]) > 1:
@ -86,12 +54,40 @@ class Match(MemoizationLeaf):
# Get the default case which should be the last non-tuple arg # Get the default case which should be the last non-tuple arg
if not isinstance(cases[-1], tuple): if not isinstance(cases[-1], tuple):
default = cases[-1] default = cases[-1]
cases = cases[:-1] actual_cases = cases[:-1]
else:
actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
return list(cases), default cls._process_match_cases(actual_cases)
cls._validate_return_types(actual_cases)
if default is None and any(
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
)
for case in actual_cases
):
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
elif default is None:
default = Fragment.create()
default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
return cls._create_match_cond_var_or_component(
cond,
actual_cases,
default,
)
@classmethod @classmethod
def _process_match_cases(cls, cases: Sequence): def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
"""Process the individual match cases. """Process the individual match cases.
Args: Args:
@ -116,7 +112,9 @@ class Match(MemoizationLeaf):
) )
@classmethod @classmethod
def _validate_return_types(cls, match_cases: List[List[Var]]) -> None: def _validate_return_types(
cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]
) -> None:
"""Validate that match cases have the same return types. """Validate that match cases have the same return types.
Args: Args:
@ -151,9 +149,9 @@ class Match(MemoizationLeaf):
def _create_match_cond_var_or_component( def _create_match_cond_var_or_component(
cls, cls,
match_cond_var: Var, match_cond_var: Var,
match_cases: List[List[Var]], match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
default: Union[Var, BaseComponent], default: VAR_TYPE | Var[VAR_TYPE],
) -> Union[Component, Var]: ) -> Var[VAR_TYPE]:
"""Create and return the match condition var or component. """Create and return the match condition var or component.
Args: Args:

View File

@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):
# Fired when editing is finished. # Fired when editing is finished.
on_finished_editing: EventHandler[ on_finished_editing: EventHandler[
passthrough_event_spec(Union[GridCell, None], tuple[int, int]) passthrough_event_spec(Union[GridCell, None], tuple[int, int]) # pyright: ignore[reportArgumentType]
] ]
# Fired when a row is appended. # Fired when a row is appended.

View File

@ -197,7 +197,7 @@ class AccordionItem(AccordionComponent):
# The header of the accordion item. # The header of the accordion item.
header: Var[Union[Component, str]] header: Var[Union[Component, str]]
# The content of the accordion item. # The content of the accordion item.
content: Var[Union[Component, str]] = Var.create(None) content: Var[Union[Component, str]] = Var.create("")
_valid_children: List[str] = [ _valid_children: List[str] = [
"AccordionHeader", "AccordionHeader",

View File

@ -4,9 +4,10 @@ from __future__ import annotations
import dataclasses import dataclasses
import inspect import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args
from reflex.components.tags.tag import Tag from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
if TYPE_CHECKING: if TYPE_CHECKING:
@ -31,7 +32,7 @@ class IterTag(Tag):
# The name of the index var. # The name of the index var.
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name) index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
def get_iterable_var_type(self) -> Type: def get_iterable_var_type(self) -> types.GenericType:
"""Get the type of the iterable var. """Get the type of the iterable var.
Returns: Returns:
@ -41,10 +42,10 @@ class IterTag(Tag):
try: try:
if iterable._var_type.mro()[0] is dict: if iterable._var_type.mro()[0] is dict:
# Arg is a tuple of (key, value). # Arg is a tuple of (key, value).
return Tuple[get_args(iterable._var_type)] # type: ignore return Tuple[get_args(iterable._var_type)]
elif iterable._var_type.mro()[0] is tuple: elif iterable._var_type.mro()[0] is tuple:
# Arg is a union of any possible values in the tuple. # Arg is a union of any possible values in the tuple.
return Union[get_args(iterable._var_type)] # type: ignore return Union[get_args(iterable._var_type)]
else: else:
return get_args(iterable._var_type)[0] return get_args(iterable._var_type)[0]
except Exception: except Exception:

View File

@ -25,7 +25,6 @@ from typing import (
overload, overload,
) )
import typing_extensions
from typing_extensions import ( from typing_extensions import (
Concatenate, Concatenate,
ParamSpec, ParamSpec,
@ -33,6 +32,8 @@ from typing_extensions import (
TypeAliasType, TypeAliasType,
TypedDict, TypedDict,
TypeVar, TypeVar,
TypeVarTuple,
deprecated,
get_args, get_args,
get_origin, get_origin,
) )
@ -620,14 +621,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
T = TypeVar("T") EVENT_T = TypeVar("EVENT_T")
U = TypeVar("U") EVENT_U = TypeVar("EVENT_U")
Ts = TypeVarTuple("Ts")
class IdentityEventReturn(Generic[T], Protocol): class IdentityEventReturn(Generic[*Ts], Protocol):
"""Protocol for an identity event return.""" """Protocol for an identity event return."""
def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]: def __call__(self, *values: *Ts) -> tuple[*Ts]:
"""Return the input values. """Return the input values.
Args: Args:
@ -641,21 +644,25 @@ class IdentityEventReturn(Generic[T], Protocol):
@overload @overload
def passthrough_event_spec( def passthrough_event_spec(
event_type: Type[T], / event_type: Type[EVENT_T], /
) -> Callable[[Var[T]], Tuple[Var[T]]]: ... # type: ignore ) -> IdentityEventReturn[Var[EVENT_T]]: ...
@overload @overload
def passthrough_event_spec( def passthrough_event_spec(
event_type_1: Type[T], event_type2: Type[U], / event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ... ) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...
@overload @overload
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ... def passthrough_event_spec(
*event_types: *tuple[Type[EVENT_T]],
) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: ...
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # type: ignore def passthrough_event_spec( # pyright: ignore[reportInconsistentOverload]
*event_types: Type[EVENT_T],
) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]:
"""A helper function that returns the input event as output. """A helper function that returns the input event as output.
Args: Args:
@ -665,7 +672,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: #
A function that returns the input event as output. A function that returns the input event as output.
""" """
def inner(*values: Var[T]) -> Tuple[Var[T], ...]: def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
return values return values
inner_type = tuple(Var[event_type] for event_type in event_types) inner_type = tuple(Var[event_type] for event_type in event_types)
@ -800,7 +807,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
return None return None
fn.__qualname__ = name fn.__qualname__ = name
fn.__signature__ = sig fn.__signature__ = sig # pyright: ignore[reportFunctionMemberAccess]
return EventSpec( return EventSpec(
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE), handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
args=tuple( args=tuple(
@ -822,7 +829,7 @@ def redirect(
@overload @overload
@typing_extensions.deprecated("`external` is deprecated use `is_external` instead") @deprecated("`external` is deprecated use `is_external` instead")
def redirect( def redirect(
path: str | Var[str], path: str | Var[str],
is_external: Optional[bool] = None, is_external: Optional[bool] = None,
@ -1826,6 +1833,37 @@ class EventCallback(Generic[P, T]):
""" """
self.func = func self.func = func
def throttle(self, limit_ms: int):
"""Throttle the event handler.
Args:
limit_ms: The time in milliseconds to throttle the event handler.
Returns:
New EventHandler-like with throttle set to limit_ms.
"""
return self
def debounce(self, delay_ms: int):
"""Debounce the event handler.
Args:
delay_ms: The time in milliseconds to debounce the event handler.
Returns:
New EventHandler-like with debounce set to delay_ms.
"""
return self
@property
def temporal(self):
"""Do not queue the event if the backend is down.
Returns:
New EventHandler-like with temporal set to True.
"""
return self
@property @property
def prevent_default(self): def prevent_default(self):
"""Prevent default behavior. """Prevent default behavior.

View File

@ -587,8 +587,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if cls._item_is_event_handler(name, fn) if cls._item_is_event_handler(name, fn)
} }
for mixin in cls._mixins(): for mixin_class in cls._mixins():
for name, value in mixin.__dict__.items(): for name, value in mixin_class.__dict__.items():
if name in cls.inherited_vars: if name in cls.inherited_vars:
continue continue
if is_computed_var(value): if is_computed_var(value):
@ -599,7 +599,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls.computed_vars[newcv._js_expr] = newcv cls.computed_vars[newcv._js_expr] = newcv
cls.vars[newcv._js_expr] = newcv cls.vars[newcv._js_expr] = newcv
continue continue
if types.is_backend_base_variable(name, mixin): if types.is_backend_base_variable(name, mixin_class):
cls.backend_vars[name] = copy.deepcopy(value) cls.backend_vars[name] = copy.deepcopy(value)
continue continue
if events.get(name) is not None: if events.get(name) is not None:
@ -3710,6 +3710,9 @@ def get_state_manager() -> StateManager:
return app.state_manager return app.state_manager
DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
class MutableProxy(wrapt.ObjectProxy): class MutableProxy(wrapt.ObjectProxy):
"""A proxy for a mutable object that tracks changes.""" """A proxy for a mutable object that tracks changes."""
@ -3781,12 +3784,7 @@ class MutableProxy(wrapt.ObjectProxy):
cls.__dataclass_proxies__[wrapper_cls_name] = type( cls.__dataclass_proxies__[wrapper_cls_name] = type(
wrapper_cls_name, wrapper_cls_name,
(cls,), (cls,),
{ {DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
wrapped_cls,
dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
),
},
) )
cls = cls.__dataclass_proxies__[wrapper_cls_name] cls = cls.__dataclass_proxies__[wrapper_cls_name]
return super().__new__(cls) return super().__new__(cls)
@ -3933,11 +3931,11 @@ class MutableProxy(wrapt.ObjectProxy):
if ( if (
isinstance(self.__wrapped__, Base) isinstance(self.__wrapped__, Base)
and __name not in self.__never_wrap_base_attrs__ and __name not in self.__never_wrap_base_attrs__
and hasattr(value, "__func__") and (value_func := getattr(value, "__func__", None))
): ):
# Wrap methods called on Base subclasses, which might do _anything_ # Wrap methods called on Base subclasses, which might do _anything_
return wrapt.FunctionWrapper( return wrapt.FunctionWrapper(
functools.partial(value.__func__, self), functools.partial(value_func, self),
self._wrap_recursive_decorator, self._wrap_recursive_decorator,
) )

View File

@ -67,10 +67,8 @@ try:
from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports] from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports]
WebElement, WebElement,
) )
has_selenium = True
except ImportError: except ImportError:
has_selenium = False webdriver = None
# The timeout (minutes) to check for the port. # The timeout (minutes) to check for the port.
DEFAULT_TIMEOUT = 15 DEFAULT_TIMEOUT = 15
@ -293,8 +291,12 @@ class AppHarness:
if p not in before_decorated_pages if p not in before_decorated_pages
] ]
self.app_instance = self.app_module.app self.app_instance = self.app_module.app
if self.app_instance is None:
raise RuntimeError("App was not initialized.")
if isinstance(self.app_instance._state_manager, StateManagerRedis): if isinstance(self.app_instance._state_manager, StateManagerRedis):
# Create our own redis connection for testing. # Create our own redis connection for testing.
if self.app_instance.state is None:
raise RuntimeError("App state is not initialized.")
self.state_manager = StateManagerRedis.create(self.app_instance.state) self.state_manager = StateManagerRedis.create(self.app_instance.state)
else: else:
self.state_manager = self.app_instance._state_manager self.state_manager = self.app_instance._state_manager
@ -608,7 +610,7 @@ class AppHarness:
Raises: Raises:
RuntimeError: when selenium is not importable or frontend is not running RuntimeError: when selenium is not importable or frontend is not running
""" """
if not has_selenium: if webdriver is None:
raise RuntimeError( raise RuntimeError(
"Frontend functionality requires `selenium` to be installed, " "Frontend functionality requires `selenium` to be installed, "
"and it could not be imported." "and it could not be imported."

View File

@ -203,10 +203,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
# Exclude utility modules that should never be the source of deprecated reflex usage. # Exclude utility modules that should never be the source of deprecated reflex usage.
exclude_modules = [click, rx, typer, typing_extensions] exclude_modules = [click, rx, typer, typing_extensions]
exclude_roots = [ exclude_roots = [
(
p.parent.resolve() p.parent.resolve()
if (p := Path(m.__file__)).name == "__init__.py" if (p := Path(m.__file__)).name == "__init__.py"
else p.resolve() else p.resolve()
)
for m in exclude_modules for m in exclude_modules
if m.__file__
] ]
# Specifically exclude the reflex cli module. # Specifically exclude the reflex cli module.
if reflex_bin := shutil.which(b"reflex"): if reflex_bin := shutil.which(b"reflex"):

View File

@ -3197,17 +3197,20 @@ class Field(Generic[T]):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ... def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
@overload @overload
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ... def __get__(self: Field[int], instance: None, owner) -> NumberVar[int]: ...
@overload @overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ... def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ...
@overload
def __get__(self: Field[str], instance: None, owner) -> StringVar[str]: ...
@overload @overload
def __get__(self: Field[None], instance: None, owner) -> NoneVar: ... def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...
@overload @overload
def __get__( def __get__(
self: Field[Sequence[V]] | Field[Set[V]], self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]],
instance: None, instance: None,
owner, owner,
) -> ArrayVar[Sequence[V]]: ... ) -> ArrayVar[Sequence[V]]: ...

View File

@ -1069,24 +1069,11 @@ def ternary_operation(
return value return value
TUPLE_ENDS_IN_VAR = ( X = tuple[*tuple[Var, ...], str]
tuple[Var[VAR_TYPE]]
| tuple[Var, Var[VAR_TYPE]] TUPLE_ENDS_IN_VAR = tuple[*tuple[Var[Any], ...], Var[VAR_TYPE]]
| tuple[Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var[VAR_TYPE]] TUPLE_ENDS_IN_VAR_RELAXED = tuple[*tuple[Var[Any] | Any, ...], Var[VAR_TYPE] | VAR_TYPE]
| tuple[Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[
Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]
]
)
@dataclasses.dataclass( @dataclasses.dataclass(
@ -1153,7 +1140,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
def create( def create(
cls, cls,
cond: Any, cond: Any,
cases: Sequence[Sequence[Any | Var[VAR_TYPE]]], cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
default: Var[VAR_TYPE] | VAR_TYPE, default: Var[VAR_TYPE] | VAR_TYPE,
_var_data: VarData | None = None, _var_data: VarData | None = None,
_var_type: type[VAR_TYPE] | None = None, _var_type: type[VAR_TYPE] | None = None,
@ -1175,6 +1162,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...], tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
tuple(tuple(Var.create(c) for c in case) for case in cases), tuple(tuple(Var.create(c) for c in case) for case in cases),
) )
_default = cast(Var[VAR_TYPE], Var.create(default)) _default = cast(Var[VAR_TYPE], Var.create(default))
var_type = _var_type or unionize( var_type = _var_type or unionize(
*(case[-1]._var_type for case in cases), *(case[-1]._var_type for case in cases),

View File

@ -45,7 +45,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE") OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
KEY_TYPE = TypeVar("KEY_TYPE") KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE") VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -164,7 +164,8 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload @overload
def __getitem__( def __getitem__(
self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]], self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]]
| ObjectVar[Dict[Any, List[ARRAY_INNER_TYPE]]],
key: Var | Any, key: Var | Any,
) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ... ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...

View File

@ -28,9 +28,11 @@ def TestEventAction():
def on_click2(self): def on_click2(self):
self.order.append("on_click2") self.order.append("on_click2")
@rx.event
def on_click_throttle(self): def on_click_throttle(self):
self.order.append("on_click_throttle") self.order.append("on_click_throttle")
@rx.event
def on_click_debounce(self): def on_click_debounce(self):
self.order.append("on_click_debounce") self.order.append("on_click_debounce")

View File

@ -22,9 +22,9 @@ def LifespanApp():
@asynccontextmanager @asynccontextmanager
async def lifespan_context(app, inc: int = 1): async def lifespan_context(app, inc: int = 1):
global lifespan_context_global nonlocal lifespan_context_global
print(f"Lifespan context entered: {app}.") print(f"Lifespan context entered: {app}.")
lifespan_context_global += inc # pyright: ignore[reportUnboundVariable] lifespan_context_global += inc
try: try:
yield yield
finally: finally:
@ -32,11 +32,11 @@ def LifespanApp():
lifespan_context_global += inc lifespan_context_global += inc
async def lifespan_task(inc: int = 1): async def lifespan_task(inc: int = 1):
global lifespan_task_global nonlocal lifespan_task_global
print("Lifespan global started.") print("Lifespan global started.")
try: try:
while True: while True:
lifespan_task_global += inc # pyright: ignore[reportUnboundVariable] lifespan_task_global += inc
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
except asyncio.CancelledError as ce: except asyncio.CancelledError as ce:
print(f"Lifespan global cancelled: {ce}.") print(f"Lifespan global cancelled: {ce}.")

View File

@ -19,25 +19,27 @@ def VarOperations():
from reflex.vars.sequence import ArrayVar from reflex.vars.sequence import ArrayVar
class Object(rx.Base): class Object(rx.Base):
str: str = "hello" name: str = "hello"
class VarOperationState(rx.State): class VarOperationState(rx.State):
int_var1: int = 10 int_var1: rx.Field[int] = rx.field(10)
int_var2: int = 5 int_var2: rx.Field[int] = rx.field(5)
int_var3: int = 7 int_var3: rx.Field[int] = rx.field(7)
float_var1: float = 10.5 float_var1: rx.Field[float] = rx.field(10.5)
float_var2: float = 5.5 float_var2: rx.Field[float] = rx.field(5.5)
list1: List = [1, 2] list1: rx.Field[List[int]] = rx.field([1, 2])
list2: List = [3, 4] list2: rx.Field[List[int]] = rx.field([3, 4])
list3: List = ["first", "second", "third"] list3: rx.Field[List[str]] = rx.field(["first", "second", "third"])
list4: List = [Object(name="obj_1"), Object(name="obj_2")] list4: rx.Field[List[Object]] = rx.field(
str_var1: str = "first" [Object(name="obj_1"), Object(name="obj_2")]
str_var2: str = "second" )
str_var3: str = "ThIrD" str_var1: rx.Field[str] = rx.field("first")
str_var4: str = "a long string" str_var2: rx.Field[str] = rx.field("second")
dict1: Dict[int, int] = {1: 2} str_var3: rx.Field[str] = rx.field("ThIrD")
dict2: Dict[int, int] = {3: 4} str_var4: rx.Field[str] = rx.field("a long string")
html_str: str = "<div>hello</div>" dict1: rx.Field[Dict[int, int]] = rx.field({1: 2})
dict2: rx.Field[Dict[int, int]] = rx.field({3: 4})
html_str: rx.Field[str] = rx.field("<div>hello</div>")
app = rx.App(state=rx.State) app = rx.App(state=rx.State)

View File

@ -13,7 +13,7 @@ from reflex.vars.base import LiteralVar, Var, computed_var
@pytest.fixture @pytest.fixture
def cond_state(request): def cond_state(request):
class CondState(BaseState): class CondState(BaseState):
value: request.param["value_type"] = request.param["value"] # noqa value: request.param["value_type"] = request.param["value"] # pyright: ignore[reportInvalidTypeForm, reportUndefinedVariable] # noqa: F821
return CondState return CondState

View File

@ -67,7 +67,7 @@ def test_match_vars(cases, expected):
cases: The match cases. cases: The match cases.
expected: The expected var full name. expected: The expected var full name.
""" """
match_comp = Match.create(MatchState.value, *cases) match_comp = Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
assert isinstance(match_comp, Var) assert isinstance(match_comp, Var)
assert str(match_comp) == expected assert str(match_comp) == expected
@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case):
ValueError, ValueError,
match="rx.match should have tuples of cases and a default case as the last argument.", match="rx.match should have tuples of cases and a default case as the last argument.",
): ):
Match.create(MatchState.value, *match_case) Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case):
ValueError, ValueError,
match="A case tuple should have at least a match case element and a return value.", match="A case tuple should have at least a match case element and a return value.",
): ):
Match.create(MatchState.value, *match_case) Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
error_msg: Expected error message. error_msg: Expected error message.
""" """
with pytest.raises(MatchTypeError, match=error_msg): with pytest.raises(MatchTypeError, match=error_msg):
Match.create(MatchState.value, *cases) Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case):
match_case: the cases to match. match_case: the cases to match.
""" """
with pytest.raises(ValueError, match="rx.match can only have one default case."): with pytest.raises(ValueError, match="rx.match can only have one default case."):
Match.create(MatchState.value, *match_case) Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
def test_match_no_cond(): def test_match_no_cond():
with pytest.raises(ValueError): with pytest.raises(ValueError):
_ = Match.create(None) _ = Match.create(None) # pyright: ignore[reportCallIssue]

View File

@ -13,7 +13,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe
pytest.param( pytest.param(
{ {
"data": pd.DataFrame( "data": pd.DataFrame(
[["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"] [["foo", "bar"], ["foo1", "bar1"]],
columns=["column1", "column2"], # pyright: ignore [reportArgumentType]
) )
}, },
"data", "data",
@ -113,7 +114,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
def test_serialize_dataframe(): def test_serialize_dataframe():
"""Test if dataframe is serialized correctly.""" """Test if dataframe is serialized correctly."""
df = pd.DataFrame( df = pd.DataFrame(
[["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"] [["foo", "bar"], ["foo1", "bar1"]],
columns=["column1", "column2"], # pyright: ignore [reportArgumentType]
) )
value = serialize(df) value = serialize(df)
assert value == serialize_dataframe(df) assert value == serialize_dataframe(df)

View File

@ -9,7 +9,7 @@ import unittest.mock
import uuid import uuid
from contextlib import nullcontext as does_not_raise from contextlib import nullcontext as does_not_raise
from pathlib import Path from pathlib import Path
from typing import Generator, List, Tuple, Type from typing import Generator, List, Tuple, Type, cast
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
@ -33,7 +33,7 @@ from reflex.components import Component
from reflex.components.base.fragment import Fragment from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond from reflex.components.core.cond import Cond
from reflex.components.radix.themes.typography.text import Text from reflex.components.radix.themes.typography.text import Text
from reflex.event import Event from reflex.event import Event, EventHandler
from reflex.middleware import HydrateMiddleware from reflex.middleware import HydrateMiddleware
from reflex.model import Model from reflex.model import Model
from reflex.state import ( from reflex.state import (
@ -917,7 +917,7 @@ class DynamicState(BaseState):
""" """
return self.dynamic return self.dynamic
on_load_internal = OnLoadInternalState.on_load_internal.fn on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn
def test_dynamic_arg_shadow( def test_dynamic_arg_shadow(
@ -1190,7 +1190,7 @@ async def test_process_events(mocker, token: str):
pass pass
assert (await app.state_manager.get_state(event.substate_token)).value == 5 assert (await app.state_manager.get_state(event.substate_token)).value == 5
assert app._postprocess.call_count == 6 assert getattr(app._postprocess, "call_count", None) == 6
if isinstance(app.state_manager, StateManagerRedis): if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close() await app.state_manager.close()
@ -1247,7 +1247,7 @@ def test_overlay_component(
if exp_page_child is not None: if exp_page_child is not None:
assert len(page.children) == 3 assert len(page.children) == 3
children_types = (type(child) for child in page.children) children_types = [type(child) for child in page.children]
assert exp_page_child in children_types assert exp_page_child in children_types
else: else:
assert len(page.children) == 2 assert len(page.children) == 2

View File

@ -5,6 +5,7 @@ import pytest
import reflex as rx import reflex as rx
from reflex.event import ( from reflex.event import (
Event, Event,
EventActionsMixin,
EventChain, EventChain,
EventHandler, EventHandler,
EventSpec, EventSpec,
@ -410,6 +411,7 @@ def test_event_actions():
def test_event_actions_on_state(): def test_event_actions_on_state():
class EventActionState(BaseState): class EventActionState(BaseState):
@rx.event
def handler(self): def handler(self):
pass pass
@ -418,6 +420,7 @@ def test_event_actions_on_state():
assert not handler.event_actions assert not handler.event_actions
sp_handler = EventActionState.handler.stop_propagation sp_handler = EventActionState.handler.stop_propagation
assert isinstance(sp_handler, EventActionsMixin)
assert sp_handler.event_actions == {"stopPropagation": True} assert sp_handler.event_actions == {"stopPropagation": True}
# should NOT affect other references to the handler # should NOT affect other references to the handler
assert not handler.event_actions assert not handler.event_actions

View File

@ -122,9 +122,12 @@ async def test_health(
# Call the async health function # Call the async health function
response = await health() response = await health()
print(json.loads(response.body)) body = response.body
assert isinstance(body, bytes)
print(json.loads(body))
print(expected_status) print(expected_status)
# Verify the response content and status code # Verify the response content and status code
assert response.status_code == expected_code assert response.status_code == expected_code
assert json.loads(response.body) == expected_status assert json.loads(body) == expected_status

View File

@ -59,7 +59,7 @@ def test_automigration(
id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None) id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
# initial table # initial table
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t1: Mapped[str] = mapped_column(default="") t1: Mapped[str] = mapped_column(default="")
with Model.get_db_engine().connect() as connection: with Model.get_db_engine().connect() as connection:
@ -78,7 +78,7 @@ def test_automigration(
model_registry.get_metadata().clear() model_registry.get_metadata().clear()
# Create column t2, mark t1 as optional with default # Create column t2, mark t1 as optional with default
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t1: Mapped[Optional[str]] = mapped_column(default="default") t1: Mapped[Optional[str]] = mapped_column(default="default")
t2: Mapped[str] = mapped_column(default="bar") t2: Mapped[str] = mapped_column(default="bar")
@ -98,7 +98,7 @@ def test_automigration(
model_registry.get_metadata().clear() model_registry.get_metadata().clear()
# Drop column t1 # Drop column t1
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t2: Mapped[str] = mapped_column(default="bar") t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)
@ -133,7 +133,7 @@ def test_automigration(
# drop table (AlembicSecond) # drop table (AlembicSecond)
model_registry.get_metadata().clear() model_registry.get_metadata().clear()
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues] class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t2: Mapped[str] = mapped_column(default="bar") t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True) assert Model.migrate(autogenerate=True)

View File

@ -17,6 +17,7 @@ from typing import (
Dict, Dict,
List, List,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
Union, Union,
@ -120,8 +121,8 @@ class TestState(BaseState):
num2: float = 3.14 num2: float = 3.14
key: str key: str
map_key: str = "a" map_key: str = "a"
array: List[float] = [1, 2, 3.14] array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]} mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
obj: Object = Object() obj: Object = Object()
complex: Dict[int, Object] = {1: Object(), 2: Object()} complex: Dict[int, Object] = {1: Object(), 2: Object()}
fig: Figure = Figure() fig: Figure = Figure()
@ -1357,6 +1358,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
class HandlerState(BaseState): class HandlerState(BaseState):
x: int = 42 x: int = 42
@rx.event
def handler(self): def handler(self):
self.x = self.x + 1 self.x = self.x + 1
@ -1367,11 +1369,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
counter += 1 counter += 1
return counter return counter
if use_partial:
HandlerState.handler = functools.partial(HandlerState.handler.fn)
assert isinstance(HandlerState.handler, functools.partial)
else:
assert isinstance(HandlerState.handler, EventHandler) assert isinstance(HandlerState.handler, EventHandler)
if use_partial:
partial_guy = functools.partial(HandlerState.handler.fn)
HandlerState.handler = partial_guy # pyright: ignore[reportAttributeAccessIssue]
assert isinstance(HandlerState.handler, functools.partial)
s = HandlerState() s = HandlerState()
assert "cached_x_side_effect" in s._computed_var_dependencies["x"] assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
@ -2025,8 +2027,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# ensure state update was emitted # ensure state update was emitted
assert mock_app.event_namespace is not None assert mock_app.event_namespace is not None
mock_app.event_namespace.emit.assert_called_once() mock_app.event_namespace.emit.assert_called_once() # pyright: ignore[reportFunctionMemberAccess]
mcall = mock_app.event_namespace.emit.mock_calls[0] mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
assert mock_calls is not None
assert isinstance(mock_calls, Sequence)
mcall = mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT) assert mcall.args[0] == str(SocketEvent.EVENT)
assert mcall.args[1] == StateUpdate( assert mcall.args[1] == StateUpdate(
delta={ delta={
@ -2231,7 +2236,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit emit_mock = mock_app.event_namespace.emit
first_ws_message = emit_mock.mock_calls[0].args[1] mock_calls = getattr(emit_mock, "mock_calls", None)
assert mock_calls is not None
assert isinstance(mock_calls, Sequence)
first_ws_message = mock_calls[0].args[1]
assert ( assert (
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router") first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None is not None
@ -2246,7 +2255,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[], events=[],
final=True, final=True,
) )
for call in emit_mock.mock_calls[1:5]: for call in mock_calls[1:5]:
assert call.args[1] == StateUpdate( assert call.args[1] == StateUpdate(
delta={ delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
@ -2256,7 +2265,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[], events=[],
final=True, final=True,
) )
assert emit_mock.mock_calls[-2].args[1] == StateUpdate( assert mock_calls[-2].args[1] == StateUpdate(
delta={ delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"order": exp_order, "order": exp_order,
@ -2267,7 +2276,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[], events=[],
final=True, final=True,
) )
assert emit_mock.mock_calls[-1].args[1] == StateUpdate( assert mock_calls[-1].args[1] == StateUpdate(
delta={ delta={
BackgroundTaskState.get_full_name(): { BackgroundTaskState.get_full_name(): {
"computed_order": exp_order, "computed_order": exp_order,