down to only two pyright error

This commit is contained in:
Khaleel Al-Adhami 2025-01-17 14:16:03 -08:00
parent 112b2ed948
commit 57d8ea02e9
26 changed files with 225 additions and 181 deletions

8
poetry.lock generated
View File

@ -2813,13 +2813,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
[[package]]
name = "virtualenv"
version = "20.28.1"
version = "20.29.1"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.8"
files = [
{file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"},
{file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"},
{file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
{file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
]
[package.dependencies]
@ -3063,4 +3063,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "ac7633388e00416af61c93e16c61c3b4a446e406afab2758585088e2b3416eee"
content-hash = "ccd6d6b00fdcf40562854380fafdb18c990b7f6a4f2883b33aaeb0351fcdbc06"

View File

@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
gunicorn = ">=20.1.0,<24.0"
jinja2 = ">=3.1.2,<4.0"
psutil = ">=5.9.4,<7.0"
pydantic = ">=1.10.2,<3.0"
pydantic = ">=1.10.15,<3.0"
python-multipart = ">=0.0.5,<0.1"
python-socketio = ">=5.7.0,<6.0"
redis = ">=4.3.5,<6.0"
@ -82,6 +82,7 @@ build-backend = "poetry.core.masonry.api"
[tool.pyright]
reportIncompatibleMethodOverride = false
reportIncompatibleVariableOverride = false
[tool.ruff]
target-version = "py39"

View File

@ -5,15 +5,9 @@ from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, List, Type
try:
import pydantic.v1.main as pydantic_main
from pydantic.v1 import BaseModel
from pydantic.v1.fields import ModelField
except ModuleNotFoundError:
if not TYPE_CHECKING:
import pydantic.main as pydantic_main
from pydantic import BaseModel
from pydantic.fields import ModelField # type: ignore
import pydantic.v1.main as pydantic_main
from pydantic.v1 import BaseModel
from pydantic.v1.fields import ModelField
def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:

View File

@ -113,11 +113,7 @@ class Cond(MemoizationLeaf):
@overload
def cond(condition: Any, c1: Component, c2: Any) -> Component: ...
@overload
def cond(condition: Any, c1: Component) -> Component: ...
def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ...
@overload

View File

@ -1,15 +1,17 @@
"""rx.match."""
import textwrap
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, List, cast
from reflex.components.base import Fragment
from reflex.components.component import BaseComponent, Component, MemoizationLeaf
from reflex.utils import types
from reflex.utils.exceptions import MatchTypeError
from reflex.vars.base import Var
from reflex.vars.base import VAR_TYPE, Var
from reflex.vars.number import MatchOperation
CASE_TYPE = tuple[*tuple[Any, ...], Var[VAR_TYPE] | VAR_TYPE]
class Match(MemoizationLeaf):
"""Match cases based on a condition."""
@ -24,7 +26,11 @@ class Match(MemoizationLeaf):
default: Any
@classmethod
def create(cls, cond: Any, *cases) -> Union[Component, Var]:
def create(
cls,
cond: Any,
*cases: *tuple[*tuple[CASE_TYPE[VAR_TYPE], ...], Var[VAR_TYPE] | VAR_TYPE],
) -> Var[VAR_TYPE]:
"""Create a Match Component.
Args:
@ -37,44 +43,6 @@ class Match(MemoizationLeaf):
Raises:
ValueError: When a default case is not provided for cases with Var return types.
"""
cases, default = cls._process_cases(cases)
cls._process_match_cases(cases)
cls._validate_return_types(cases)
if default is None and any(
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
)
for case in cases
):
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
elif default is None:
default = Fragment.create()
return cls._create_match_cond_var_or_component(cond, cases, default)
@classmethod
def _process_cases(
cls, cases: Sequence
) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
"""Process the list of match cases and the catchall default case.
Args:
cases: The list of match cases.
Returns:
The default case and the list of match case tuples.
Raises:
ValueError: If there are multiple default cases.
"""
default = None
if len([case for case in cases if not isinstance(case, tuple)]) > 1:
@ -86,12 +54,40 @@ class Match(MemoizationLeaf):
# Get the default case which should be the last non-tuple arg
if not isinstance(cases[-1], tuple):
default = cases[-1]
cases = cases[:-1]
actual_cases = cases[:-1]
else:
actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)
return list(cases), default
cls._process_match_cases(actual_cases)
cls._validate_return_types(actual_cases)
if default is None and any(
not (
isinstance((return_type := case[-1]), Component)
or (
isinstance(return_type, Var)
and types.typehint_issubclass(return_type._var_type, Component)
)
)
for case in actual_cases
):
raise ValueError(
"For cases with return types as Vars, a default case must be provided"
)
elif default is None:
default = Fragment.create()
default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
return cls._create_match_cond_var_or_component(
cond,
actual_cases,
default,
)
@classmethod
def _process_match_cases(cls, cases: Sequence):
def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
"""Process the individual match cases.
Args:
@ -116,7 +112,9 @@ class Match(MemoizationLeaf):
)
@classmethod
def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
def _validate_return_types(
cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]
) -> None:
"""Validate that match cases have the same return types.
Args:
@ -151,9 +149,9 @@ class Match(MemoizationLeaf):
def _create_match_cond_var_or_component(
cls,
match_cond_var: Var,
match_cases: List[List[Var]],
default: Union[Var, BaseComponent],
) -> Union[Component, Var]:
match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
default: VAR_TYPE | Var[VAR_TYPE],
) -> Var[VAR_TYPE]:
"""Create and return the match condition var or component.
Args:

View File

@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):
# Fired when editing is finished.
on_finished_editing: EventHandler[
passthrough_event_spec(Union[GridCell, None], tuple[int, int])
passthrough_event_spec(Union[GridCell, None], tuple[int, int]) # pyright: ignore[reportArgumentType]
]
# Fired when a row is appended.

View File

@ -197,7 +197,7 @@ class AccordionItem(AccordionComponent):
# The header of the accordion item.
header: Var[Union[Component, str]]
# The content of the accordion item.
content: Var[Union[Component, str]] = Var.create(None)
content: Var[Union[Component, str]] = Var.create("")
_valid_children: List[str] = [
"AccordionHeader",

View File

@ -4,9 +4,10 @@ from __future__ import annotations
import dataclasses
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args
from reflex.components.tags.tag import Tag
from reflex.utils import types
from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
if TYPE_CHECKING:
@ -31,7 +32,7 @@ class IterTag(Tag):
# The name of the index var.
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
def get_iterable_var_type(self) -> Type:
def get_iterable_var_type(self) -> types.GenericType:
"""Get the type of the iterable var.
Returns:
@ -41,10 +42,10 @@ class IterTag(Tag):
try:
if iterable._var_type.mro()[0] is dict:
# Arg is a tuple of (key, value).
return Tuple[get_args(iterable._var_type)] # type: ignore
return Tuple[get_args(iterable._var_type)]
elif iterable._var_type.mro()[0] is tuple:
# Arg is a union of any possible values in the tuple.
return Union[get_args(iterable._var_type)] # type: ignore
return Union[get_args(iterable._var_type)]
else:
return get_args(iterable._var_type)[0]
except Exception:

View File

@ -25,7 +25,6 @@ from typing import (
overload,
)
import typing_extensions
from typing_extensions import (
Concatenate,
ParamSpec,
@ -33,6 +32,8 @@ from typing_extensions import (
TypeAliasType,
TypedDict,
TypeVar,
TypeVarTuple,
deprecated,
get_args,
get_origin,
)
@ -620,14 +621,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
T = TypeVar("T")
U = TypeVar("U")
EVENT_T = TypeVar("EVENT_T")
EVENT_U = TypeVar("EVENT_U")
Ts = TypeVarTuple("Ts")
class IdentityEventReturn(Generic[T], Protocol):
class IdentityEventReturn(Generic[*Ts], Protocol):
"""Protocol for an identity event return."""
def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]:
def __call__(self, *values: *Ts) -> tuple[*Ts]:
"""Return the input values.
Args:
@ -641,21 +644,25 @@ class IdentityEventReturn(Generic[T], Protocol):
@overload
def passthrough_event_spec(
event_type: Type[T], /
) -> Callable[[Var[T]], Tuple[Var[T]]]: ... # type: ignore
event_type: Type[EVENT_T], /
) -> IdentityEventReturn[Var[EVENT_T]]: ...
@overload
def passthrough_event_spec(
event_type_1: Type[T], event_type2: Type[U], /
) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ...
event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...
@overload
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ...
def passthrough_event_spec(
*event_types: *tuple[Type[EVENT_T]],
) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: ...
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # type: ignore
def passthrough_event_spec( # pyright: ignore[reportInconsistentOverload]
*event_types: Type[EVENT_T],
) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]:
"""A helper function that returns the input event as output.
Args:
@ -665,7 +672,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: #
A function that returns the input event as output.
"""
def inner(*values: Var[T]) -> Tuple[Var[T], ...]:
def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
return values
inner_type = tuple(Var[event_type] for event_type in event_types)
@ -800,7 +807,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
return None
fn.__qualname__ = name
fn.__signature__ = sig
fn.__signature__ = sig # pyright: ignore[reportFunctionMemberAccess]
return EventSpec(
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
args=tuple(
@ -822,7 +829,7 @@ def redirect(
@overload
@typing_extensions.deprecated("`external` is deprecated use `is_external` instead")
@deprecated("`external` is deprecated use `is_external` instead")
def redirect(
path: str | Var[str],
is_external: Optional[bool] = None,
@ -1826,6 +1833,37 @@ class EventCallback(Generic[P, T]):
"""
self.func = func
def throttle(self, limit_ms: int):
"""Throttle the event handler.
Args:
limit_ms: The time in milliseconds to throttle the event handler.
Returns:
New EventHandler-like with throttle set to limit_ms.
"""
return self
def debounce(self, delay_ms: int):
"""Debounce the event handler.
Args:
delay_ms: The time in milliseconds to debounce the event handler.
Returns:
New EventHandler-like with debounce set to delay_ms.
"""
return self
@property
def temporal(self):
"""Do not queue the event if the backend is down.
Returns:
New EventHandler-like with temporal set to True.
"""
return self
@property
def prevent_default(self):
"""Prevent default behavior.

View File

@ -587,8 +587,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
if cls._item_is_event_handler(name, fn)
}
for mixin in cls._mixins():
for name, value in mixin.__dict__.items():
for mixin_class in cls._mixins():
for name, value in mixin_class.__dict__.items():
if name in cls.inherited_vars:
continue
if is_computed_var(value):
@ -599,7 +599,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
cls.computed_vars[newcv._js_expr] = newcv
cls.vars[newcv._js_expr] = newcv
continue
if types.is_backend_base_variable(name, mixin):
if types.is_backend_base_variable(name, mixin_class):
cls.backend_vars[name] = copy.deepcopy(value)
continue
if events.get(name) is not None:
@ -3710,6 +3710,9 @@ def get_state_manager() -> StateManager:
return app.state_manager
DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
class MutableProxy(wrapt.ObjectProxy):
"""A proxy for a mutable object that tracks changes."""
@ -3781,12 +3784,7 @@ class MutableProxy(wrapt.ObjectProxy):
cls.__dataclass_proxies__[wrapper_cls_name] = type(
wrapper_cls_name,
(cls,),
{
dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
wrapped_cls,
dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
),
},
{DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
)
cls = cls.__dataclass_proxies__[wrapper_cls_name]
return super().__new__(cls)
@ -3933,11 +3931,11 @@ class MutableProxy(wrapt.ObjectProxy):
if (
isinstance(self.__wrapped__, Base)
and __name not in self.__never_wrap_base_attrs__
and hasattr(value, "__func__")
and (value_func := getattr(value, "__func__", None))
):
# Wrap methods called on Base subclasses, which might do _anything_
return wrapt.FunctionWrapper(
functools.partial(value.__func__, self),
functools.partial(value_func, self),
self._wrap_recursive_decorator,
)

View File

@ -67,10 +67,8 @@ try:
from selenium.webdriver.remote.webelement import ( # pyright: ignore [reportMissingImports]
WebElement,
)
has_selenium = True
except ImportError:
has_selenium = False
webdriver = None
# The timeout (minutes) to check for the port.
DEFAULT_TIMEOUT = 15
@ -293,8 +291,12 @@ class AppHarness:
if p not in before_decorated_pages
]
self.app_instance = self.app_module.app
if self.app_instance is None:
raise RuntimeError("App was not initialized.")
if isinstance(self.app_instance._state_manager, StateManagerRedis):
# Create our own redis connection for testing.
if self.app_instance.state is None:
raise RuntimeError("App state is not initialized.")
self.state_manager = StateManagerRedis.create(self.app_instance.state)
else:
self.state_manager = self.app_instance._state_manager
@ -608,7 +610,7 @@ class AppHarness:
Raises:
RuntimeError: when selenium is not importable or frontend is not running
"""
if not has_selenium:
if webdriver is None:
raise RuntimeError(
"Frontend functionality requires `selenium` to be installed, "
"and it could not be imported."

View File

@ -203,10 +203,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
# Exclude utility modules that should never be the source of deprecated reflex usage.
exclude_modules = [click, rx, typer, typing_extensions]
exclude_roots = [
p.parent.resolve()
if (p := Path(m.__file__)).name == "__init__.py"
else p.resolve()
(
p.parent.resolve()
if (p := Path(m.__file__)).name == "__init__.py"
else p.resolve()
)
for m in exclude_modules
if m.__file__
]
# Specifically exclude the reflex cli module.
if reflex_bin := shutil.which(b"reflex"):

View File

@ -3197,17 +3197,20 @@ class Field(Generic[T]):
def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...
@overload
def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
def __get__(self: Field[int], instance: None, owner) -> NumberVar[int]: ...
@overload
def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ...
@overload
def __get__(self: Field[str], instance: None, owner) -> StringVar[str]: ...
@overload
def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...
@overload
def __get__(
self: Field[Sequence[V]] | Field[Set[V]],
self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]],
instance: None,
owner,
) -> ArrayVar[Sequence[V]]: ...

View File

@ -1069,24 +1069,11 @@ def ternary_operation(
return value
TUPLE_ENDS_IN_VAR = (
tuple[Var[VAR_TYPE]]
| tuple[Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
| tuple[
Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]
]
)
X = tuple[*tuple[Var, ...], str]
TUPLE_ENDS_IN_VAR = tuple[*tuple[Var[Any], ...], Var[VAR_TYPE]]
TUPLE_ENDS_IN_VAR_RELAXED = tuple[*tuple[Var[Any] | Any, ...], Var[VAR_TYPE] | VAR_TYPE]
@dataclasses.dataclass(
@ -1153,7 +1140,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
def create(
cls,
cond: Any,
cases: Sequence[Sequence[Any | Var[VAR_TYPE]]],
cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
default: Var[VAR_TYPE] | VAR_TYPE,
_var_data: VarData | None = None,
_var_type: type[VAR_TYPE] | None = None,
@ -1175,6 +1162,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
tuple(tuple(Var.create(c) for c in case) for case in cases),
)
_default = cast(Var[VAR_TYPE], Var.create(default))
var_type = _var_type or unionize(
*(case[-1]._var_type for case in cases),

View File

@ -45,7 +45,7 @@ from .base import (
from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
from .sequence import ArrayVar, StringVar
OBJECT_TYPE = TypeVar("OBJECT_TYPE")
OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)
KEY_TYPE = TypeVar("KEY_TYPE")
VALUE_TYPE = TypeVar("VALUE_TYPE")
@ -164,7 +164,8 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):
@overload
def __getitem__(
self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]],
self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]]
| ObjectVar[Dict[Any, List[ARRAY_INNER_TYPE]]],
key: Var | Any,
) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...

View File

@ -28,9 +28,11 @@ def TestEventAction():
def on_click2(self):
self.order.append("on_click2")
@rx.event
def on_click_throttle(self):
self.order.append("on_click_throttle")
@rx.event
def on_click_debounce(self):
self.order.append("on_click_debounce")

View File

@ -22,9 +22,9 @@ def LifespanApp():
@asynccontextmanager
async def lifespan_context(app, inc: int = 1):
global lifespan_context_global
nonlocal lifespan_context_global
print(f"Lifespan context entered: {app}.")
lifespan_context_global += inc # pyright: ignore[reportUnboundVariable]
lifespan_context_global += inc
try:
yield
finally:
@ -32,11 +32,11 @@ def LifespanApp():
lifespan_context_global += inc
async def lifespan_task(inc: int = 1):
global lifespan_task_global
nonlocal lifespan_task_global
print("Lifespan global started.")
try:
while True:
lifespan_task_global += inc # pyright: ignore[reportUnboundVariable]
lifespan_task_global += inc
await asyncio.sleep(0.1)
except asyncio.CancelledError as ce:
print(f"Lifespan global cancelled: {ce}.")

View File

@ -19,25 +19,27 @@ def VarOperations():
from reflex.vars.sequence import ArrayVar
class Object(rx.Base):
str: str = "hello"
name: str = "hello"
class VarOperationState(rx.State):
int_var1: int = 10
int_var2: int = 5
int_var3: int = 7
float_var1: float = 10.5
float_var2: float = 5.5
list1: List = [1, 2]
list2: List = [3, 4]
list3: List = ["first", "second", "third"]
list4: List = [Object(name="obj_1"), Object(name="obj_2")]
str_var1: str = "first"
str_var2: str = "second"
str_var3: str = "ThIrD"
str_var4: str = "a long string"
dict1: Dict[int, int] = {1: 2}
dict2: Dict[int, int] = {3: 4}
html_str: str = "<div>hello</div>"
int_var1: rx.Field[int] = rx.field(10)
int_var2: rx.Field[int] = rx.field(5)
int_var3: rx.Field[int] = rx.field(7)
float_var1: rx.Field[float] = rx.field(10.5)
float_var2: rx.Field[float] = rx.field(5.5)
list1: rx.Field[List[int]] = rx.field([1, 2])
list2: rx.Field[List[int]] = rx.field([3, 4])
list3: rx.Field[List[str]] = rx.field(["first", "second", "third"])
list4: rx.Field[List[Object]] = rx.field(
[Object(name="obj_1"), Object(name="obj_2")]
)
str_var1: rx.Field[str] = rx.field("first")
str_var2: rx.Field[str] = rx.field("second")
str_var3: rx.Field[str] = rx.field("ThIrD")
str_var4: rx.Field[str] = rx.field("a long string")
dict1: rx.Field[Dict[int, int]] = rx.field({1: 2})
dict2: rx.Field[Dict[int, int]] = rx.field({3: 4})
html_str: rx.Field[str] = rx.field("<div>hello</div>")
app = rx.App(state=rx.State)

View File

@ -13,7 +13,7 @@ from reflex.vars.base import LiteralVar, Var, computed_var
@pytest.fixture
def cond_state(request):
class CondState(BaseState):
value: request.param["value_type"] = request.param["value"] # noqa
value: request.param["value_type"] = request.param["value"] # pyright: ignore[reportInvalidTypeForm, reportUndefinedVariable] # noqa: F821
return CondState

View File

@ -67,7 +67,7 @@ def test_match_vars(cases, expected):
cases: The match cases.
expected: The expected var full name.
"""
match_comp = Match.create(MatchState.value, *cases)
match_comp = Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
assert isinstance(match_comp, Var)
assert str(match_comp) == expected
@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case):
ValueError,
match="rx.match should have tuples of cases and a default case as the last argument.",
):
Match.create(MatchState.value, *match_case)
Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize(
@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case):
ValueError,
match="A case tuple should have at least a match case element and a return value.",
):
Match.create(MatchState.value, *match_case)
Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize(
@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
error_msg: Expected error message.
"""
with pytest.raises(MatchTypeError, match=error_msg):
Match.create(MatchState.value, *cases)
Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
@pytest.mark.parametrize(
@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case):
match_case: the cases to match.
"""
with pytest.raises(ValueError, match="rx.match can only have one default case."):
Match.create(MatchState.value, *match_case)
Match.create(MatchState.value, *match_case) # pyright: ignore[reportCallIssue]
def test_match_no_cond():
with pytest.raises(ValueError):
_ = Match.create(None)
_ = Match.create(None) # pyright: ignore[reportCallIssue]

View File

@ -13,7 +13,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe
pytest.param(
{
"data": pd.DataFrame(
[["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
[["foo", "bar"], ["foo1", "bar1"]],
columns=["column1", "column2"], # pyright: ignore [reportArgumentType]
)
},
"data",
@ -113,7 +114,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
def test_serialize_dataframe():
"""Test if dataframe is serialized correctly."""
df = pd.DataFrame(
[["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
[["foo", "bar"], ["foo1", "bar1"]],
columns=["column1", "column2"], # pyright: ignore [reportArgumentType]
)
value = serialize(df)
assert value == serialize_dataframe(df)

View File

@ -9,7 +9,7 @@ import unittest.mock
import uuid
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import Generator, List, Tuple, Type
from typing import Generator, List, Tuple, Type, cast
from unittest.mock import AsyncMock
import pytest
@ -33,7 +33,7 @@ from reflex.components import Component
from reflex.components.base.fragment import Fragment
from reflex.components.core.cond import Cond
from reflex.components.radix.themes.typography.text import Text
from reflex.event import Event
from reflex.event import Event, EventHandler
from reflex.middleware import HydrateMiddleware
from reflex.model import Model
from reflex.state import (
@ -917,7 +917,7 @@ class DynamicState(BaseState):
"""
return self.dynamic
on_load_internal = OnLoadInternalState.on_load_internal.fn
on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn
def test_dynamic_arg_shadow(
@ -1190,7 +1190,7 @@ async def test_process_events(mocker, token: str):
pass
assert (await app.state_manager.get_state(event.substate_token)).value == 5
assert app._postprocess.call_count == 6
assert getattr(app._postprocess, "call_count", None) == 6
if isinstance(app.state_manager, StateManagerRedis):
await app.state_manager.close()
@ -1247,7 +1247,7 @@ def test_overlay_component(
if exp_page_child is not None:
assert len(page.children) == 3
children_types = (type(child) for child in page.children)
children_types = [type(child) for child in page.children]
assert exp_page_child in children_types
else:
assert len(page.children) == 2

View File

@ -5,6 +5,7 @@ import pytest
import reflex as rx
from reflex.event import (
Event,
EventActionsMixin,
EventChain,
EventHandler,
EventSpec,
@ -410,6 +411,7 @@ def test_event_actions():
def test_event_actions_on_state():
class EventActionState(BaseState):
@rx.event
def handler(self):
pass
@ -418,6 +420,7 @@ def test_event_actions_on_state():
assert not handler.event_actions
sp_handler = EventActionState.handler.stop_propagation
assert isinstance(sp_handler, EventActionsMixin)
assert sp_handler.event_actions == {"stopPropagation": True}
# should NOT affect other references to the handler
assert not handler.event_actions

View File

@ -122,9 +122,12 @@ async def test_health(
# Call the async health function
response = await health()
print(json.loads(response.body))
body = response.body
assert isinstance(body, bytes)
print(json.loads(body))
print(expected_status)
# Verify the response content and status code
assert response.status_code == expected_code
assert json.loads(response.body) == expected_status
assert json.loads(body) == expected_status

View File

@ -59,7 +59,7 @@ def test_automigration(
id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)
# initial table
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t1: Mapped[str] = mapped_column(default="")
with Model.get_db_engine().connect() as connection:
@ -78,7 +78,7 @@ def test_automigration(
model_registry.get_metadata().clear()
# Create column t2, mark t1 as optional with default
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t1: Mapped[Optional[str]] = mapped_column(default="default")
t2: Mapped[str] = mapped_column(default="bar")
@ -98,7 +98,7 @@ def test_automigration(
model_registry.get_metadata().clear()
# Drop column t1
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True)
@ -133,7 +133,7 @@ def test_automigration(
# drop table (AlembicSecond)
model_registry.get_metadata().clear()
class AlembicThing(ModelBase): # pyright: ignore[reportGeneralTypeIssues]
class AlembicThing(ModelBase): # pyright: ignore[reportRedeclaration]
t2: Mapped[str] = mapped_column(default="bar")
assert Model.migrate(autogenerate=True)

View File

@ -17,6 +17,7 @@ from typing import (
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
@ -120,8 +121,8 @@ class TestState(BaseState):
num2: float = 3.14
key: str
map_key: str = "a"
array: List[float] = [1, 2, 3.14]
mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
obj: Object = Object()
complex: Dict[int, Object] = {1: Object(), 2: Object()}
fig: Figure = Figure()
@ -1357,6 +1358,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
class HandlerState(BaseState):
x: int = 42
@rx.event
def handler(self):
self.x = self.x + 1
@ -1367,11 +1369,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
counter += 1
return counter
assert isinstance(HandlerState.handler, EventHandler)
if use_partial:
HandlerState.handler = functools.partial(HandlerState.handler.fn)
partial_guy = functools.partial(HandlerState.handler.fn)
HandlerState.handler = partial_guy # pyright: ignore[reportAttributeAccessIssue]
assert isinstance(HandlerState.handler, functools.partial)
else:
assert isinstance(HandlerState.handler, EventHandler)
s = HandlerState()
assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
@ -2025,8 +2027,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
# ensure state update was emitted
assert mock_app.event_namespace is not None
mock_app.event_namespace.emit.assert_called_once()
mcall = mock_app.event_namespace.emit.mock_calls[0]
mock_app.event_namespace.emit.assert_called_once() # pyright: ignore[reportFunctionMemberAccess]
mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
assert mock_calls is not None
assert isinstance(mock_calls, Sequence)
mcall = mock_calls[0]
assert mcall.args[0] == str(SocketEvent.EVENT)
assert mcall.args[1] == StateUpdate(
delta={
@ -2231,7 +2236,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
assert mock_app.event_namespace is not None
emit_mock = mock_app.event_namespace.emit
first_ws_message = emit_mock.mock_calls[0].args[1]
mock_calls = getattr(emit_mock, "mock_calls", None)
assert mock_calls is not None
assert isinstance(mock_calls, Sequence)
first_ws_message = mock_calls[0].args[1]
assert (
first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
is not None
@ -2246,7 +2255,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
for call in emit_mock.mock_calls[1:5]:
for call in mock_calls[1:5]:
assert call.args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
@ -2256,7 +2265,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
assert mock_calls[-2].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"order": exp_order,
@ -2267,7 +2276,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
events=[],
final=True,
)
assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
assert mock_calls[-1].args[1] == StateUpdate(
delta={
BackgroundTaskState.get_full_name(): {
"computed_order": exp_order,