down to only two pyright errors
This commit is contained in:
parent
112b2ed948
commit
57d8ea02e9
poetry.lock (generated): 8 lines changed
@@ -2813,13 +2813,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",

 [[package]]
 name = "virtualenv"
-version = "20.28.1"
+version = "20.29.1"
 description = "Virtual Python Environment builder"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"},
-    {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"},
+    {file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
+    {file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
 ]

 [package.dependencies]
@@ -3063,4 +3063,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "ac7633388e00416af61c93e16c61c3b4a446e406afab2758585088e2b3416eee"
+content-hash = "ccd6d6b00fdcf40562854380fafdb18c990b7f6a4f2883b33aaeb0351fcdbc06"
@@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
 gunicorn = ">=20.1.0,<24.0"
 jinja2 = ">=3.1.2,<4.0"
 psutil = ">=5.9.4,<7.0"
-pydantic = ">=1.10.2,<3.0"
+pydantic = ">=1.10.15,<3.0"
 python-multipart = ">=0.0.5,<0.1"
 python-socketio = ">=5.7.0,<6.0"
 redis = ">=4.3.5,<6.0"
@@ -82,6 +82,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.pyright]
 reportIncompatibleMethodOverride = false
+reportIncompatibleVariableOverride = false

 [tool.ruff]
 target-version = "py39"
@@ -5,15 +5,9 @@ from __future__ import annotations
 import os
 from typing import TYPE_CHECKING, Any, List, Type

-try:
-    import pydantic.v1.main as pydantic_main
-    from pydantic.v1 import BaseModel
-    from pydantic.v1.fields import ModelField
-except ModuleNotFoundError:
-    if not TYPE_CHECKING:
-        import pydantic.main as pydantic_main
-        from pydantic import BaseModel
-        from pydantic.fields import ModelField  # type: ignore
+import pydantic.v1.main as pydantic_main
+from pydantic.v1 import BaseModel
+from pydantic.v1.fields import ModelField


 def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
@@ -113,11 +113,7 @@ class Cond(MemoizationLeaf):


 @overload
-def cond(condition: Any, c1: Component, c2: Any) -> Component: ...
-
-
-@overload
-def cond(condition: Any, c1: Component) -> Component: ...
+def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ...


 @overload
@@ -1,15 +1,17 @@
 """rx.match."""

 import textwrap
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, cast

 from reflex.components.base import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.utils import types
 from reflex.utils.exceptions import MatchTypeError
-from reflex.vars.base import Var
+from reflex.vars.base import VAR_TYPE, Var
 from reflex.vars.number import MatchOperation

+CASE_TYPE = tuple[*tuple[Any, ...], Var[VAR_TYPE] | VAR_TYPE]
+

 class Match(MemoizationLeaf):
     """Match cases based on a condition."""
@@ -24,7 +26,11 @@ class Match(MemoizationLeaf):
     default: Any

     @classmethod
-    def create(cls, cond: Any, *cases) -> Union[Component, Var]:
+    def create(
+        cls,
+        cond: Any,
+        *cases: *tuple[*tuple[CASE_TYPE[VAR_TYPE], ...], Var[VAR_TYPE] | VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create a Match Component.

         Args:
@@ -37,44 +43,6 @@ class Match(MemoizationLeaf):
         Raises:
             ValueError: When a default case is not provided for cases with Var return types.
         """
-        cases, default = cls._process_cases(cases)
-        cls._process_match_cases(cases)
-
-        cls._validate_return_types(cases)
-
-        if default is None and any(
-            not (
-                isinstance((return_type := case[-1]), Component)
-                or (
-                    isinstance(return_type, Var)
-                    and types.typehint_issubclass(return_type._var_type, Component)
-                )
-            )
-            for case in cases
-        ):
-            raise ValueError(
-                "For cases with return types as Vars, a default case must be provided"
-            )
-        elif default is None:
-            default = Fragment.create()
-
-        return cls._create_match_cond_var_or_component(cond, cases, default)
-
-    @classmethod
-    def _process_cases(
-        cls, cases: Sequence
-    ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
-        """Process the list of match cases and the catchall default case.
-
-        Args:
-            cases: The list of match cases.
-
-        Returns:
-            The default case and the list of match case tuples.
-
-        Raises:
-            ValueError: If there are multiple default cases.
-        """
         default = None

         if len([case for case in cases if not isinstance(case, tuple)]) > 1:
@@ -86,12 +54,40 @@ class Match(MemoizationLeaf):
         # Get the default case which should be the last non-tuple arg
         if not isinstance(cases[-1], tuple):
             default = cases[-1]
-            cases = cases[:-1]
+            actual_cases = cases[:-1]
+        else:
+            actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)

-        return list(cases), default
+        cls._process_match_cases(actual_cases)
+
+        cls._validate_return_types(actual_cases)
+
+        if default is None and any(
+            not (
+                isinstance((return_type := case[-1]), Component)
+                or (
+                    isinstance(return_type, Var)
+                    and types.typehint_issubclass(return_type._var_type, Component)
+                )
+            )
+            for case in actual_cases
+        ):
+            raise ValueError(
+                "For cases with return types as Vars, a default case must be provided"
+            )
+        elif default is None:
+            default = Fragment.create()
+
+        default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
+
+        return cls._create_match_cond_var_or_component(
+            cond,
+            actual_cases,
+            default,
+        )

     @classmethod
-    def _process_match_cases(cls, cases: Sequence):
+    def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
         """Process the individual match cases.

         Args:
@@ -116,7 +112,9 @@ class Match(MemoizationLeaf):
             )

     @classmethod
-    def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
+    def _validate_return_types(
+        cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]
+    ) -> None:
         """Validate that match cases have the same return types.

         Args:
@@ -151,9 +149,9 @@ class Match(MemoizationLeaf):
     def _create_match_cond_var_or_component(
         cls,
         match_cond_var: Var,
-        match_cases: List[List[Var]],
-        default: Union[Var, BaseComponent],
-    ) -> Union[Component, Var]:
+        match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
+        default: VAR_TYPE | Var[VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create and return the match condition var or component.

         Args:
@@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):

     # Fired when editing is finished.
     on_finished_editing: EventHandler[
-        passthrough_event_spec(Union[GridCell, None], tuple[int, int])
+        passthrough_event_spec(Union[GridCell, None], tuple[int, int])  # pyright: ignore[reportArgumentType]
     ]

     # Fired when a row is appended.
@@ -197,7 +197,7 @@ class AccordionItem(AccordionComponent):
     # The header of the accordion item.
     header: Var[Union[Component, str]]
     # The content of the accordion item.
-    content: Var[Union[Component, str]] = Var.create(None)
+    content: Var[Union[Component, str]] = Var.create("")

     _valid_children: List[str] = [
         "AccordionHeader",
@@ -4,9 +4,10 @@ from __future__ import annotations

 import dataclasses
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args

 from reflex.components.tags.tag import Tag
+from reflex.utils import types
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name

 if TYPE_CHECKING:
@@ -31,7 +32,7 @@ class IterTag(Tag):
     # The name of the index var.
     index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)

-    def get_iterable_var_type(self) -> Type:
+    def get_iterable_var_type(self) -> types.GenericType:
         """Get the type of the iterable var.

         Returns:
@@ -41,10 +42,10 @@ class IterTag(Tag):
         try:
             if iterable._var_type.mro()[0] is dict:
                 # Arg is a tuple of (key, value).
-                return Tuple[get_args(iterable._var_type)]  # type: ignore
+                return Tuple[get_args(iterable._var_type)]
             elif iterable._var_type.mro()[0] is tuple:
                 # Arg is a union of any possible values in the tuple.
-                return Union[get_args(iterable._var_type)]  # type: ignore
+                return Union[get_args(iterable._var_type)]
             else:
                 return get_args(iterable._var_type)[0]
         except Exception:
@@ -25,7 +25,6 @@ from typing import (
     overload,
 )

-import typing_extensions
 from typing_extensions import (
     Concatenate,
     ParamSpec,
@@ -33,6 +32,8 @@ from typing_extensions import (
     TypeAliasType,
     TypedDict,
     TypeVar,
+    TypeVarTuple,
+    deprecated,
     get_args,
     get_origin,
 )
@@ -620,14 +621,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
 prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default


-T = TypeVar("T")
-U = TypeVar("U")
+EVENT_T = TypeVar("EVENT_T")
+EVENT_U = TypeVar("EVENT_U")
+
+Ts = TypeVarTuple("Ts")


-class IdentityEventReturn(Generic[T], Protocol):
+class IdentityEventReturn(Generic[*Ts], Protocol):
     """Protocol for an identity event return."""

-    def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]:
+    def __call__(self, *values: *Ts) -> tuple[*Ts]:
         """Return the input values.

         Args:
@@ -641,21 +644,25 @@ class IdentityEventReturn(Generic[T], Protocol):

 @overload
 def passthrough_event_spec(
-    event_type: Type[T], /
-) -> Callable[[Var[T]], Tuple[Var[T]]]: ...  # type: ignore
+    event_type: Type[EVENT_T], /
+) -> IdentityEventReturn[Var[EVENT_T]]: ...


 @overload
 def passthrough_event_spec(
-    event_type_1: Type[T], event_type2: Type[U], /
-) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ...
+    event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
+) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...


 @overload
-def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ...
+def passthrough_event_spec(
+    *event_types: *tuple[Type[EVENT_T]],
+) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: ...


-def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]:  # type: ignore
+def passthrough_event_spec(  # pyright: ignore[reportInconsistentOverload]
+    *event_types: Type[EVENT_T],
+) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]:
     """A helper function that returns the input event as output.

     Args:
@@ -665,7 +672,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: #
         A function that returns the input event as output.
     """

-    def inner(*values: Var[T]) -> Tuple[Var[T], ...]:
+    def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
         return values

     inner_type = tuple(Var[event_type] for event_type in event_types)
@@ -800,7 +807,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
         return None

     fn.__qualname__ = name
-    fn.__signature__ = sig
+    fn.__signature__ = sig  # pyright: ignore[reportFunctionMemberAccess]
     return EventSpec(
         handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
         args=tuple(
@@ -822,7 +829,7 @@ def redirect(


 @overload
-@typing_extensions.deprecated("`external` is deprecated use `is_external` instead")
+@deprecated("`external` is deprecated use `is_external` instead")
 def redirect(
     path: str | Var[str],
     is_external: Optional[bool] = None,
@@ -1826,6 +1833,37 @@ class EventCallback(Generic[P, T]):
         """
         self.func = func

+    def throttle(self, limit_ms: int):
+        """Throttle the event handler.
+
+        Args:
+            limit_ms: The time in milliseconds to throttle the event handler.
+
+        Returns:
+            New EventHandler-like with throttle set to limit_ms.
+        """
+        return self
+
+    def debounce(self, delay_ms: int):
+        """Debounce the event handler.
+
+        Args:
+            delay_ms: The time in milliseconds to debounce the event handler.
+
+        Returns:
+            New EventHandler-like with debounce set to delay_ms.
+        """
+        return self
+
+    @property
+    def temporal(self):
+        """Do not queue the event if the backend is down.
+
+        Returns:
+            New EventHandler-like with temporal set to True.
+        """
+        return self
+
     @property
     def prevent_default(self):
         """Prevent default behavior.
@@ -587,8 +587,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
             if cls._item_is_event_handler(name, fn)
         }

-        for mixin in cls._mixins():
-            for name, value in mixin.__dict__.items():
+        for mixin_class in cls._mixins():
+            for name, value in mixin_class.__dict__.items():
                 if name in cls.inherited_vars:
                     continue
                 if is_computed_var(value):
@@ -599,7 +599,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
                     cls.computed_vars[newcv._js_expr] = newcv
                     cls.vars[newcv._js_expr] = newcv
                     continue
-                if types.is_backend_base_variable(name, mixin):
+                if types.is_backend_base_variable(name, mixin_class):
                     cls.backend_vars[name] = copy.deepcopy(value)
                     continue
                 if events.get(name) is not None:
@@ -3710,6 +3710,9 @@ def get_state_manager() -> StateManager:
     return app.state_manager


+DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
+
+
 class MutableProxy(wrapt.ObjectProxy):
     """A proxy for a mutable object that tracks changes."""

@@ -3781,12 +3784,7 @@ class MutableProxy(wrapt.ObjectProxy):
             cls.__dataclass_proxies__[wrapper_cls_name] = type(
                 wrapper_cls_name,
                 (cls,),
-                {
-                    dataclasses._FIELDS: getattr(  # pyright: ignore [reportGeneralTypeIssues]
-                        wrapped_cls,
-                        dataclasses._FIELDS,  # pyright: ignore [reportGeneralTypeIssues]
-                    ),
-                },
+                {DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
             )
             cls = cls.__dataclass_proxies__[wrapper_cls_name]
         return super().__new__(cls)
@@ -3933,11 +3931,11 @@ class MutableProxy(wrapt.ObjectProxy):
         if (
             isinstance(self.__wrapped__, Base)
             and __name not in self.__never_wrap_base_attrs__
-            and hasattr(value, "__func__")
+            and (value_func := getattr(value, "__func__", None))
         ):
             # Wrap methods called on Base subclasses, which might do _anything_
             return wrapt.FunctionWrapper(
-                functools.partial(value.__func__, self),
+                functools.partial(value_func, self),
                 self._wrap_recursive_decorator,
             )

@@ -67,10 +67,8 @@ try:
     from selenium.webdriver.remote.webelement import (  # pyright: ignore [reportMissingImports]
         WebElement,
     )
-
-    has_selenium = True
 except ImportError:
-    has_selenium = False
+    webdriver = None

 # The timeout (minutes) to check for the port.
 DEFAULT_TIMEOUT = 15
@@ -293,8 +291,12 @@ class AppHarness:
             if p not in before_decorated_pages
         ]
         self.app_instance = self.app_module.app
+        if self.app_instance is None:
+            raise RuntimeError("App was not initialized.")
         if isinstance(self.app_instance._state_manager, StateManagerRedis):
             # Create our own redis connection for testing.
+            if self.app_instance.state is None:
+                raise RuntimeError("App state is not initialized.")
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
         else:
             self.state_manager = self.app_instance._state_manager
@@ -608,7 +610,7 @@ class AppHarness:
        Raises:
            RuntimeError: when selenium is not importable or frontend is not running
        """
-        if not has_selenium:
+        if webdriver is None:
            raise RuntimeError(
                "Frontend functionality requires `selenium` to be installed, "
                "and it could not be imported."
@@ -203,10 +203,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
     # Exclude utility modules that should never be the source of deprecated reflex usage.
     exclude_modules = [click, rx, typer, typing_extensions]
     exclude_roots = [
-        p.parent.resolve()
-        if (p := Path(m.__file__)).name == "__init__.py"
-        else p.resolve()
+        (
+            p.parent.resolve()
+            if (p := Path(m.__file__)).name == "__init__.py"
+            else p.resolve()
+        )
         for m in exclude_modules
+        if m.__file__
     ]
     # Specifically exclude the reflex cli module.
     if reflex_bin := shutil.which(b"reflex"):
@@ -3197,17 +3197,20 @@ class Field(Generic[T]):
     def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...

     @overload
-    def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
+    def __get__(self: Field[int], instance: None, owner) -> NumberVar[int]: ...

     @overload
-    def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
+    def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ...
+
+    @overload
+    def __get__(self: Field[str], instance: None, owner) -> StringVar[str]: ...

     @overload
     def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...

     @overload
     def __get__(
-        self: Field[Sequence[V]] | Field[Set[V]],
+        self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]],
         instance: None,
         owner,
     ) -> ArrayVar[Sequence[V]]: ...
@@ -1069,24 +1069,11 @@ def ternary_operation(
     return value


-TUPLE_ENDS_IN_VAR = (
-    tuple[Var[VAR_TYPE]]
-    | tuple[Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[
-        Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]
-    ]
-)
+X = tuple[*tuple[Var, ...], str]
+
+TUPLE_ENDS_IN_VAR = tuple[*tuple[Var[Any], ...], Var[VAR_TYPE]]
+
+TUPLE_ENDS_IN_VAR_RELAXED = tuple[*tuple[Var[Any] | Any, ...], Var[VAR_TYPE] | VAR_TYPE]


 @dataclasses.dataclass(
@@ -1153,7 +1140,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
     def create(
         cls,
         cond: Any,
-        cases: Sequence[Sequence[Any | Var[VAR_TYPE]]],
+        cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
         default: Var[VAR_TYPE] | VAR_TYPE,
         _var_data: VarData | None = None,
         _var_type: type[VAR_TYPE] | None = None,
@@ -1175,6 +1162,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
             tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
             tuple(tuple(Var.create(c) for c in case) for case in cases),
         )

+        _default = cast(Var[VAR_TYPE], Var.create(default))
         var_type = _var_type or unionize(
             *(case[-1]._var_type for case in cases),
@@ -45,7 +45,7 @@ from .base import (
 from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
 from .sequence import ArrayVar, StringVar

-OBJECT_TYPE = TypeVar("OBJECT_TYPE")
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)

 KEY_TYPE = TypeVar("KEY_TYPE")
 VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -164,7 +164,8 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):

     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]]
+        | ObjectVar[Dict[Any, List[ARRAY_INNER_TYPE]]],
         key: Var | Any,
     ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...

@@ -28,9 +28,11 @@ def TestEventAction():
         def on_click2(self):
             self.order.append("on_click2")

+        @rx.event
         def on_click_throttle(self):
             self.order.append("on_click_throttle")

+        @rx.event
         def on_click_debounce(self):
             self.order.append("on_click_debounce")

@@ -22,9 +22,9 @@ def LifespanApp():

     @asynccontextmanager
     async def lifespan_context(app, inc: int = 1):
-        global lifespan_context_global
+        nonlocal lifespan_context_global
         print(f"Lifespan context entered: {app}.")
-        lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
+        lifespan_context_global += inc
         try:
             yield
         finally:
@@ -32,11 +32,11 @@ def LifespanApp():
             lifespan_context_global += inc

     async def lifespan_task(inc: int = 1):
-        global lifespan_task_global
+        nonlocal lifespan_task_global
         print("Lifespan global started.")
         try:
             while True:
-                lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable]
+                lifespan_task_global += inc
                 await asyncio.sleep(0.1)
         except asyncio.CancelledError as ce:
             print(f"Lifespan global cancelled: {ce}.")
@@ -19,25 +19,27 @@ def VarOperations():
     from reflex.vars.sequence import ArrayVar

     class Object(rx.Base):
-        str: str = "hello"
+        name: str = "hello"

     class VarOperationState(rx.State):
-        int_var1: int = 10
-        int_var2: int = 5
-        int_var3: int = 7
-        float_var1: float = 10.5
-        float_var2: float = 5.5
-        list1: List = [1, 2]
-        list2: List = [3, 4]
-        list3: List = ["first", "second", "third"]
-        list4: List = [Object(name="obj_1"), Object(name="obj_2")]
-        str_var1: str = "first"
-        str_var2: str = "second"
-        str_var3: str = "ThIrD"
-        str_var4: str = "a long string"
-        dict1: Dict[int, int] = {1: 2}
-        dict2: Dict[int, int] = {3: 4}
-        html_str: str = "<div>hello</div>"
+        int_var1: rx.Field[int] = rx.field(10)
+        int_var2: rx.Field[int] = rx.field(5)
+        int_var3: rx.Field[int] = rx.field(7)
+        float_var1: rx.Field[float] = rx.field(10.5)
+        float_var2: rx.Field[float] = rx.field(5.5)
+        list1: rx.Field[List[int]] = rx.field([1, 2])
+        list2: rx.Field[List[int]] = rx.field([3, 4])
+        list3: rx.Field[List[str]] = rx.field(["first", "second", "third"])
+        list4: rx.Field[List[Object]] = rx.field(
+            [Object(name="obj_1"), Object(name="obj_2")]
+        )
+        str_var1: rx.Field[str] = rx.field("first")
+        str_var2: rx.Field[str] = rx.field("second")
+        str_var3: rx.Field[str] = rx.field("ThIrD")
+        str_var4: rx.Field[str] = rx.field("a long string")
+        dict1: rx.Field[Dict[int, int]] = rx.field({1: 2})
+        dict2: rx.Field[Dict[int, int]] = rx.field({3: 4})
+        html_str: rx.Field[str] = rx.field("<div>hello</div>")

     app = rx.App(state=rx.State)

@@ -13,7 +13,7 @@ from reflex.vars.base import LiteralVar, Var, computed_var
 @pytest.fixture
 def cond_state(request):
     class CondState(BaseState):
-        value: request.param["value_type"] = request.param["value"]  # noqa
+        value: request.param["value_type"] = request.param["value"]  # pyright: ignore[reportInvalidTypeForm, reportUndefinedVariable] # noqa: F821

     return CondState

@@ -67,7 +67,7 @@ def test_match_vars(cases, expected):
         cases: The match cases.
         expected: The expected var full name.
     """
-    match_comp = Match.create(MatchState.value, *cases)
+    match_comp = Match.create(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]
     assert isinstance(match_comp, Var)
     assert str(match_comp) == expected

@@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case):
         ValueError,
         match="rx.match should have tuples of cases and a default case as the last argument.",
     ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]


 @pytest.mark.parametrize(
@@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case):
         ValueError,
         match="A case tuple should have at least a match case element and a return value.",
     ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]


 @pytest.mark.parametrize(
@@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
         error_msg: Expected error message.
     """
     with pytest.raises(MatchTypeError, match=error_msg):
-        Match.create(MatchState.value, *cases)
+        Match.create(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]


 @pytest.mark.parametrize(
@@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case):
         match_case: the cases to match.
     """
     with pytest.raises(ValueError, match="rx.match can only have one default case."):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]


 def test_match_no_cond():
     with pytest.raises(ValueError):
-        _ = Match.create(None)
+        _ = Match.create(None)  # pyright: ignore[reportCallIssue]
@@ -13,7 +13,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe
         pytest.param(
             {
                 "data": pd.DataFrame(
-                    [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+                    [["foo", "bar"], ["foo1", "bar1"]],
+                    columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
                 )
             },
             "data",
@@ -113,7 +114,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
 def test_serialize_dataframe():
     """Test if dataframe is serialized correctly."""
     df = pd.DataFrame(
-        [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+        [["foo", "bar"], ["foo1", "bar1"]],
+        columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
     )
     value = serialize(df)
     assert value == serialize_dataframe(df)
@@ -9,7 +9,7 @@ import unittest.mock
 import uuid
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
-from typing import Generator, List, Tuple, Type
+from typing import Generator, List, Tuple, Type, cast
 from unittest.mock import AsyncMock

 import pytest
@@ -33,7 +33,7 @@ from reflex.components import Component
 from reflex.components.base.fragment import Fragment
 from reflex.components.core.cond import Cond
 from reflex.components.radix.themes.typography.text import Text
-from reflex.event import Event
+from reflex.event import Event, EventHandler
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.state import (
@@ -917,7 +917,7 @@ class DynamicState(BaseState):
         """
         return self.dynamic

-    on_load_internal = OnLoadInternalState.on_load_internal.fn
+    on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn


 def test_dynamic_arg_shadow(
@@ -1190,7 +1190,7 @@ async def test_process_events(mocker, token: str):
         pass

     assert (await app.state_manager.get_state(event.substate_token)).value == 5
-    assert app._postprocess.call_count == 6
+    assert getattr(app._postprocess, "call_count", None) == 6

     if isinstance(app.state_manager, StateManagerRedis):
         await app.state_manager.close()
@@ -1247,7 +1247,7 @@ def test_overlay_component(

     if exp_page_child is not None:
         assert len(page.children) == 3
-        children_types = (type(child) for child in page.children)
+        children_types = [type(child) for child in page.children]
         assert exp_page_child in children_types
     else:
         assert len(page.children) == 2
@@ -5,6 +5,7 @@ import pytest
 import reflex as rx
 from reflex.event import (
     Event,
+    EventActionsMixin,
     EventChain,
     EventHandler,
     EventSpec,
@@ -410,6 +411,7 @@ def test_event_actions():

 def test_event_actions_on_state():
     class EventActionState(BaseState):
+        @rx.event
         def handler(self):
             pass

@@ -418,6 +420,7 @@ def test_event_actions_on_state():
     assert not handler.event_actions

     sp_handler = EventActionState.handler.stop_propagation
+    assert isinstance(sp_handler, EventActionsMixin)
     assert sp_handler.event_actions == {"stopPropagation": True}
     # should NOT affect other references to the handler
     assert not handler.event_actions
@@ -122,9 +122,12 @@ async def test_health(
     # Call the async health function
     response = await health()

-    print(json.loads(response.body))
+    body = response.body
+    assert isinstance(body, bytes)
+
+    print(json.loads(body))
     print(expected_status)

     # Verify the response content and status code
     assert response.status_code == expected_code
-    assert json.loads(response.body) == expected_status
+    assert json.loads(body) == expected_status
@@ -59,7 +59,7 @@ def test_automigration(
         id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)

     # initial table
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t1: Mapped[str] = mapped_column(default="")

     with Model.get_db_engine().connect() as connection:
@@ -78,7 +78,7 @@ def test_automigration(
     model_registry.get_metadata().clear()

     # Create column t2, mark t1 as optional with default
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t1: Mapped[Optional[str]] = mapped_column(default="default")
         t2: Mapped[str] = mapped_column(default="bar")

@@ -98,7 +98,7 @@ def test_automigration(
     model_registry.get_metadata().clear()

     # Drop column t1
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t2: Mapped[str] = mapped_column(default="bar")

     assert Model.migrate(autogenerate=True)
@@ -133,7 +133,7 @@ def test_automigration(
     # drop table (AlembicSecond)
     model_registry.get_metadata().clear()

-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
         t2: Mapped[str] = mapped_column(default="bar")

     assert Model.migrate(autogenerate=True)
@@ -17,6 +17,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -120,8 +121,8 @@ class TestState(BaseState):
     num2: float = 3.14
     key: str
     map_key: str = "a"
-    array: List[float] = [1, 2, 3.14]
-    mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
+    mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
     obj: Object = Object()
     complex: Dict[int, Object] = {1: Object(), 2: Object()}
     fig: Figure = Figure()
@@ -1357,6 +1358,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
     class HandlerState(BaseState):
         x: int = 42

+        @rx.event
         def handler(self):
             self.x = self.x + 1

@@ -1367,11 +1369,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
         counter += 1
         return counter

     assert isinstance(HandlerState.handler, EventHandler)
     if use_partial:
-        HandlerState.handler = functools.partial(HandlerState.handler.fn)
+        partial_guy = functools.partial(HandlerState.handler.fn)
+        HandlerState.handler = partial_guy  # pyright: ignore[reportAttributeAccessIssue]
         assert isinstance(HandlerState.handler, functools.partial)
     else:
         assert isinstance(HandlerState.handler, EventHandler)

     s = HandlerState()
     assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
@@ -2025,8 +2027,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):

     # ensure state update was emitted
     assert mock_app.event_namespace is not None
-    mock_app.event_namespace.emit.assert_called_once()
-    mcall = mock_app.event_namespace.emit.mock_calls[0]
+    mock_app.event_namespace.emit.assert_called_once()  # pyright: ignore[reportFunctionMemberAccess]
+    mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+    mcall = mock_calls[0]
     assert mcall.args[0] == str(SocketEvent.EVENT)
     assert mcall.args[1] == StateUpdate(
         delta={
@@ -2231,7 +2236,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
     assert mock_app.event_namespace is not None
     emit_mock = mock_app.event_namespace.emit

-    first_ws_message = emit_mock.mock_calls[0].args[1]
+    mock_calls = getattr(emit_mock, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+
+    first_ws_message = mock_calls[0].args[1]
     assert (
         first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
         is not None
@@ -2246,7 +2255,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
         events=[],
         final=True,
     )
-    for call in emit_mock.mock_calls[1:5]:
+    for call in mock_calls[1:5]:
        assert call.args[1] == StateUpdate(
            delta={
                BackgroundTaskState.get_full_name(): {
@@ -2256,7 +2265,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
            events=[],
            final=True,
        )
-    assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+    assert mock_calls[-2].args[1] == StateUpdate(
        delta={
            BackgroundTaskState.get_full_name(): {
                "order": exp_order,
@@ -2267,7 +2276,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
            events=[],
            final=True,
        )
-    assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+    assert mock_calls[-1].args[1] == StateUpdate(
        delta={
            BackgroundTaskState.get_full_name(): {
                "computed_order": exp_order,