down to only two pyright errors
This commit is contained in:
parent 112b2ed948
commit 57d8ea02e9
poetry.lock (generated)
@@ -2813,13 +2813,13 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",

 [[package]]
 name = "virtualenv"
-version = "20.28.1"
+version = "20.29.1"
 description = "Virtual Python Environment builder"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"},
+    {file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"},
-    {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"},
+    {file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"},
 ]

 [package.dependencies]
@@ -3063,4 +3063,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "ac7633388e00416af61c93e16c61c3b4a446e406afab2758585088e2b3416eee"
+content-hash = "ccd6d6b00fdcf40562854380fafdb18c990b7f6a4f2883b33aaeb0351fcdbc06"
@@ -23,7 +23,7 @@ fastapi = ">=0.96.0,!=0.111.0,!=0.111.1"
 gunicorn = ">=20.1.0,<24.0"
 jinja2 = ">=3.1.2,<4.0"
 psutil = ">=5.9.4,<7.0"
-pydantic = ">=1.10.2,<3.0"
+pydantic = ">=1.10.15,<3.0"
 python-multipart = ">=0.0.5,<0.1"
 python-socketio = ">=5.7.0,<6.0"
 redis = ">=4.3.5,<6.0"
@@ -82,6 +82,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.pyright]
 reportIncompatibleMethodOverride = false
+reportIncompatibleVariableOverride = false

 [tool.ruff]
 target-version = "py39"
@@ -5,15 +5,9 @@ from __future__ import annotations
 import os
 from typing import TYPE_CHECKING, Any, List, Type

-try:
-    import pydantic.v1.main as pydantic_main
-    from pydantic.v1 import BaseModel
-    from pydantic.v1.fields import ModelField
-except ModuleNotFoundError:
-    if not TYPE_CHECKING:
-        import pydantic.main as pydantic_main
-        from pydantic import BaseModel
-        from pydantic.fields import ModelField  # type: ignore
+import pydantic.v1.main as pydantic_main
+from pydantic.v1 import BaseModel
+from pydantic.v1.fields import ModelField


 def validate_field_name(bases: List[Type["BaseModel"]], field_name: str) -> None:
@@ -113,11 +113,7 @@ class Cond(MemoizationLeaf):


 @overload
-def cond(condition: Any, c1: Component, c2: Any) -> Component: ...
+def cond(condition: Any, c1: Component, c2: Any = None) -> Component: ...


-@overload
-def cond(condition: Any, c1: Component) -> Component: ...
-
-
 @overload
@@ -1,15 +1,17 @@
 """rx.match."""

 import textwrap
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, cast

 from reflex.components.base import Fragment
 from reflex.components.component import BaseComponent, Component, MemoizationLeaf
 from reflex.utils import types
 from reflex.utils.exceptions import MatchTypeError
-from reflex.vars.base import Var
+from reflex.vars.base import VAR_TYPE, Var
 from reflex.vars.number import MatchOperation

+CASE_TYPE = tuple[*tuple[Any, ...], Var[VAR_TYPE] | VAR_TYPE]
+

 class Match(MemoizationLeaf):
     """Match cases based on a condition."""
@@ -24,7 +26,11 @@ class Match(MemoizationLeaf):
     default: Any

     @classmethod
-    def create(cls, cond: Any, *cases) -> Union[Component, Var]:
+    def create(
+        cls,
+        cond: Any,
+        *cases: *tuple[*tuple[CASE_TYPE[VAR_TYPE], ...], Var[VAR_TYPE] | VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create a Match Component.

         Args:
@@ -37,44 +43,6 @@ class Match(MemoizationLeaf):
         Raises:
             ValueError: When a default case is not provided for cases with Var return types.
         """
-        cases, default = cls._process_cases(cases)
-        cls._process_match_cases(cases)
-
-        cls._validate_return_types(cases)
-
-        if default is None and any(
-            not (
-                isinstance((return_type := case[-1]), Component)
-                or (
-                    isinstance(return_type, Var)
-                    and types.typehint_issubclass(return_type._var_type, Component)
-                )
-            )
-            for case in cases
-        ):
-            raise ValueError(
-                "For cases with return types as Vars, a default case must be provided"
-            )
-        elif default is None:
-            default = Fragment.create()
-
-        return cls._create_match_cond_var_or_component(cond, cases, default)
-
-    @classmethod
-    def _process_cases(
-        cls, cases: Sequence
-    ) -> Tuple[List, Optional[Union[Var, BaseComponent]]]:
-        """Process the list of match cases and the catchall default case.
-
-        Args:
-            cases: The list of match cases.
-
-        Returns:
-            The default case and the list of match case tuples.
-
-        Raises:
-            ValueError: If there are multiple default cases.
-        """
         default = None

         if len([case for case in cases if not isinstance(case, tuple)]) > 1:
@@ -86,12 +54,40 @@ class Match(MemoizationLeaf):
         # Get the default case which should be the last non-tuple arg
         if not isinstance(cases[-1], tuple):
             default = cases[-1]
-            cases = cases[:-1]
+            actual_cases = cases[:-1]
+        else:
+            actual_cases = cast(tuple[CASE_TYPE[VAR_TYPE], ...], cases)

-        return list(cases), default
+        cls._process_match_cases(actual_cases)
+
+        cls._validate_return_types(actual_cases)
+
+        if default is None and any(
+            not (
+                isinstance((return_type := case[-1]), Component)
+                or (
+                    isinstance(return_type, Var)
+                    and types.typehint_issubclass(return_type._var_type, Component)
+                )
+            )
+            for case in actual_cases
+        ):
+            raise ValueError(
+                "For cases with return types as Vars, a default case must be provided"
+            )
+        elif default is None:
+            default = Fragment.create()
+
+        default = cast(Var[VAR_TYPE] | VAR_TYPE, default)
+
+        return cls._create_match_cond_var_or_component(
+            cond,
+            actual_cases,
+            default,
+        )

     @classmethod
-    def _process_match_cases(cls, cases: Sequence):
+    def _process_match_cases(cls, cases: tuple[CASE_TYPE[VAR_TYPE], ...]):
         """Process the individual match cases.

         Args:
@@ -116,7 +112,9 @@ class Match(MemoizationLeaf):
             )

     @classmethod
-    def _validate_return_types(cls, match_cases: List[List[Var]]) -> None:
+    def _validate_return_types(
+        cls, match_cases: tuple[CASE_TYPE[VAR_TYPE], ...]
+    ) -> None:
         """Validate that match cases have the same return types.

         Args:
@@ -151,9 +149,9 @@ class Match(MemoizationLeaf):
     def _create_match_cond_var_or_component(
         cls,
         match_cond_var: Var,
-        match_cases: List[List[Var]],
-        default: Union[Var, BaseComponent],
-    ) -> Union[Component, Var]:
+        match_cases: tuple[CASE_TYPE[VAR_TYPE], ...],
+        default: VAR_TYPE | Var[VAR_TYPE],
+    ) -> Var[VAR_TYPE]:
         """Create and return the match condition var or component.

         Args:
@@ -303,7 +303,7 @@ class DataEditor(NoSSRComponent):

     # Fired when editing is finished.
     on_finished_editing: EventHandler[
-        passthrough_event_spec(Union[GridCell, None], tuple[int, int])
+        passthrough_event_spec(Union[GridCell, None], tuple[int, int])  # pyright: ignore[reportArgumentType]
     ]

     # Fired when a row is appended.
@@ -197,7 +197,7 @@ class AccordionItem(AccordionComponent):
     # The header of the accordion item.
     header: Var[Union[Component, str]]
     # The content of the accordion item.
-    content: Var[Union[Component, str]] = Var.create(None)
+    content: Var[Union[Component, str]] = Var.create("")

     _valid_children: List[str] = [
         "AccordionHeader",
@@ -4,9 +4,10 @@ from __future__ import annotations

 import dataclasses
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, Union, get_args
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Union, get_args

 from reflex.components.tags.tag import Tag
+from reflex.utils import types
 from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name

 if TYPE_CHECKING:
|
|||||||
# The name of the index var.
|
# The name of the index var.
|
||||||
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
|
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)
|
||||||
|
|
||||||
def get_iterable_var_type(self) -> Type:
|
def get_iterable_var_type(self) -> types.GenericType:
|
||||||
"""Get the type of the iterable var.
|
"""Get the type of the iterable var.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -41,10 +42,10 @@ class IterTag(Tag):
|
|||||||
try:
|
try:
|
||||||
if iterable._var_type.mro()[0] is dict:
|
if iterable._var_type.mro()[0] is dict:
|
||||||
# Arg is a tuple of (key, value).
|
# Arg is a tuple of (key, value).
|
||||||
return Tuple[get_args(iterable._var_type)] # type: ignore
|
return Tuple[get_args(iterable._var_type)]
|
||||||
elif iterable._var_type.mro()[0] is tuple:
|
elif iterable._var_type.mro()[0] is tuple:
|
||||||
# Arg is a union of any possible values in the tuple.
|
# Arg is a union of any possible values in the tuple.
|
||||||
return Union[get_args(iterable._var_type)] # type: ignore
|
return Union[get_args(iterable._var_type)]
|
||||||
else:
|
else:
|
||||||
return get_args(iterable._var_type)[0]
|
return get_args(iterable._var_type)[0]
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@@ -25,7 +25,6 @@ from typing import (
     overload,
 )

-import typing_extensions
 from typing_extensions import (
     Concatenate,
     ParamSpec,
|
|||||||
TypeAliasType,
|
TypeAliasType,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
TypeVarTuple,
|
||||||
|
deprecated,
|
||||||
get_args,
|
get_args,
|
||||||
get_origin,
|
get_origin,
|
||||||
)
|
)
|
||||||
@ -620,14 +621,16 @@ stop_propagation = EventChain(events=[], args_spec=no_args_event_spec).stop_prop
|
|||||||
prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
|
prevent_default = EventChain(events=[], args_spec=no_args_event_spec).prevent_default
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
EVENT_T = TypeVar("EVENT_T")
|
||||||
U = TypeVar("U")
|
EVENT_U = TypeVar("EVENT_U")
|
||||||
|
|
||||||
|
Ts = TypeVarTuple("Ts")
|
||||||
|
|
||||||
|
|
||||||
class IdentityEventReturn(Generic[T], Protocol):
|
class IdentityEventReturn(Generic[*Ts], Protocol):
|
||||||
"""Protocol for an identity event return."""
|
"""Protocol for an identity event return."""
|
||||||
|
|
||||||
def __call__(self, *values: Var[T]) -> Tuple[Var[T], ...]:
|
def __call__(self, *values: *Ts) -> tuple[*Ts]:
|
||||||
"""Return the input values.
|
"""Return the input values.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -641,21 +644,25 @@ class IdentityEventReturn(Generic[T], Protocol):
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
def passthrough_event_spec(
|
def passthrough_event_spec(
|
||||||
event_type: Type[T], /
|
event_type: Type[EVENT_T], /
|
||||||
) -> Callable[[Var[T]], Tuple[Var[T]]]: ... # type: ignore
|
) -> IdentityEventReturn[Var[EVENT_T]]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def passthrough_event_spec(
|
def passthrough_event_spec(
|
||||||
event_type_1: Type[T], event_type2: Type[U], /
|
event_type_1: Type[EVENT_T], event_type2: Type[EVENT_U], /
|
||||||
) -> Callable[[Var[T], Var[U]], Tuple[Var[T], Var[U]]]: ...
|
) -> IdentityEventReturn[Var[EVENT_T], Var[EVENT_U]]: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: ...
|
def passthrough_event_spec(
|
||||||
|
*event_types: *tuple[Type[EVENT_T]],
|
||||||
|
) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]: ...
|
||||||
|
|
||||||
|
|
||||||
def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: # type: ignore
|
def passthrough_event_spec( # pyright: ignore[reportInconsistentOverload]
|
||||||
|
*event_types: Type[EVENT_T],
|
||||||
|
) -> IdentityEventReturn[*tuple[Var[EVENT_T], ...]]:
|
||||||
"""A helper function that returns the input event as output.
|
"""A helper function that returns the input event as output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -665,7 +672,7 @@ def passthrough_event_spec(*event_types: Type[T]) -> IdentityEventReturn[T]: #
|
|||||||
A function that returns the input event as output.
|
A function that returns the input event as output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def inner(*values: Var[T]) -> Tuple[Var[T], ...]:
|
def inner(*values: Var[EVENT_T]) -> Tuple[Var[EVENT_T], ...]:
|
||||||
return values
|
return values
|
||||||
|
|
||||||
inner_type = tuple(Var[event_type] for event_type in event_types)
|
inner_type = tuple(Var[event_type] for event_type in event_types)
|
||||||
@ -800,7 +807,7 @@ def server_side(name: str, sig: inspect.Signature, **kwargs) -> EventSpec:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
fn.__qualname__ = name
|
fn.__qualname__ = name
|
||||||
fn.__signature__ = sig
|
fn.__signature__ = sig # pyright: ignore[reportFunctionMemberAccess]
|
||||||
return EventSpec(
|
return EventSpec(
|
||||||
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
|
handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE),
|
||||||
args=tuple(
|
args=tuple(
|
||||||
@ -822,7 +829,7 @@ def redirect(
|
|||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@typing_extensions.deprecated("`external` is deprecated use `is_external` instead")
|
@deprecated("`external` is deprecated use `is_external` instead")
|
||||||
def redirect(
|
def redirect(
|
||||||
path: str | Var[str],
|
path: str | Var[str],
|
||||||
is_external: Optional[bool] = None,
|
is_external: Optional[bool] = None,
|
||||||
@ -1826,6 +1833,37 @@ class EventCallback(Generic[P, T]):
|
|||||||
"""
|
"""
|
||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
|
def throttle(self, limit_ms: int):
|
||||||
|
"""Throttle the event handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit_ms: The time in milliseconds to throttle the event handler.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New EventHandler-like with throttle set to limit_ms.
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def debounce(self, delay_ms: int):
|
||||||
|
"""Debounce the event handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delay_ms: The time in milliseconds to debounce the event handler.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New EventHandler-like with debounce set to delay_ms.
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def temporal(self):
|
||||||
|
"""Do not queue the event if the backend is down.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New EventHandler-like with temporal set to True.
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prevent_default(self):
|
def prevent_default(self):
|
||||||
"""Prevent default behavior.
|
"""Prevent default behavior.
|
||||||
|
@ -587,8 +587,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
if cls._item_is_event_handler(name, fn)
|
if cls._item_is_event_handler(name, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
for mixin in cls._mixins():
|
for mixin_class in cls._mixins():
|
||||||
for name, value in mixin.__dict__.items():
|
for name, value in mixin_class.__dict__.items():
|
||||||
if name in cls.inherited_vars:
|
if name in cls.inherited_vars:
|
||||||
continue
|
continue
|
||||||
if is_computed_var(value):
|
if is_computed_var(value):
|
||||||
@ -599,7 +599,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
|
|||||||
cls.computed_vars[newcv._js_expr] = newcv
|
cls.computed_vars[newcv._js_expr] = newcv
|
||||||
cls.vars[newcv._js_expr] = newcv
|
cls.vars[newcv._js_expr] = newcv
|
||||||
continue
|
continue
|
||||||
if types.is_backend_base_variable(name, mixin):
|
if types.is_backend_base_variable(name, mixin_class):
|
||||||
cls.backend_vars[name] = copy.deepcopy(value)
|
cls.backend_vars[name] = copy.deepcopy(value)
|
||||||
continue
|
continue
|
||||||
if events.get(name) is not None:
|
if events.get(name) is not None:
|
||||||
@ -3710,6 +3710,9 @@ def get_state_manager() -> StateManager:
|
|||||||
return app.state_manager
|
return app.state_manager
|
||||||
|
|
||||||
|
|
||||||
|
DATACLASS_FIELDS = getattr(dataclasses, "_FIELDS", "__dataclass_fields__")
|
||||||
|
|
||||||
|
|
||||||
class MutableProxy(wrapt.ObjectProxy):
|
class MutableProxy(wrapt.ObjectProxy):
|
||||||
"""A proxy for a mutable object that tracks changes."""
|
"""A proxy for a mutable object that tracks changes."""
|
||||||
|
|
||||||
@ -3781,12 +3784,7 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
cls.__dataclass_proxies__[wrapper_cls_name] = type(
|
cls.__dataclass_proxies__[wrapper_cls_name] = type(
|
||||||
wrapper_cls_name,
|
wrapper_cls_name,
|
||||||
(cls,),
|
(cls,),
|
||||||
{
|
{DATACLASS_FIELDS: getattr(wrapped_cls, DATACLASS_FIELDS)},
|
||||||
dataclasses._FIELDS: getattr( # pyright: ignore [reportGeneralTypeIssues]
|
|
||||||
wrapped_cls,
|
|
||||||
dataclasses._FIELDS, # pyright: ignore [reportGeneralTypeIssues]
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
cls = cls.__dataclass_proxies__[wrapper_cls_name]
|
cls = cls.__dataclass_proxies__[wrapper_cls_name]
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
@ -3933,11 +3931,11 @@ class MutableProxy(wrapt.ObjectProxy):
|
|||||||
if (
|
if (
|
||||||
isinstance(self.__wrapped__, Base)
|
isinstance(self.__wrapped__, Base)
|
||||||
and __name not in self.__never_wrap_base_attrs__
|
and __name not in self.__never_wrap_base_attrs__
|
||||||
and hasattr(value, "__func__")
|
and (value_func := getattr(value, "__func__", None))
|
||||||
):
|
):
|
||||||
# Wrap methods called on Base subclasses, which might do _anything_
|
# Wrap methods called on Base subclasses, which might do _anything_
|
||||||
return wrapt.FunctionWrapper(
|
return wrapt.FunctionWrapper(
|
||||||
functools.partial(value.__func__, self),
|
functools.partial(value_func, self),
|
||||||
self._wrap_recursive_decorator,
|
self._wrap_recursive_decorator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -67,10 +67,8 @@ try:
     from selenium.webdriver.remote.webelement import (  # pyright: ignore [reportMissingImports]
         WebElement,
     )
-
-    has_selenium = True
 except ImportError:
-    has_selenium = False
+    webdriver = None

 # The timeout (minutes) to check for the port.
 DEFAULT_TIMEOUT = 15
@@ -293,8 +291,12 @@ class AppHarness:
             if p not in before_decorated_pages
         ]
         self.app_instance = self.app_module.app
+        if self.app_instance is None:
+            raise RuntimeError("App was not initialized.")
         if isinstance(self.app_instance._state_manager, StateManagerRedis):
             # Create our own redis connection for testing.
+            if self.app_instance.state is None:
+                raise RuntimeError("App state is not initialized.")
             self.state_manager = StateManagerRedis.create(self.app_instance.state)
         else:
             self.state_manager = self.app_instance._state_manager
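The two `raise RuntimeError(...)` guards added above do double duty: they fail fast on a half-initialized harness and they narrow the Optional attributes, so pyright stops flagging the accesses that follow. The pattern in isolation, with hypothetical names rather than the AppHarness API:

# Hedged sketch: an explicit None check narrows Optional[T] to T for the checker.
from typing import Optional


class Harness:
    app: Optional[str] = None

    def use_app(self) -> str:
        if self.app is None:
            raise RuntimeError("App was not initialized.")
        # After the guard, a checker treats self.app as str rather than Optional[str].
        return self.app.upper()


h = Harness()
h.app = "demo"
print(h.use_app())  # DEMO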
@@ -608,7 +610,7 @@ class AppHarness:
        Raises:
            RuntimeError: when selenium is not importable or frontend is not running
        """
-        if not has_selenium:
+        if webdriver is None:
            raise RuntimeError(
                "Frontend functionality requires `selenium` to be installed, "
                "and it could not be imported."
@@ -203,10 +203,13 @@ def _get_first_non_framework_frame() -> FrameType | None:
     # Exclude utility modules that should never be the source of deprecated reflex usage.
     exclude_modules = [click, rx, typer, typing_extensions]
     exclude_roots = [
+        (
             p.parent.resolve()
             if (p := Path(m.__file__)).name == "__init__.py"
             else p.resolve()
+        )
         for m in exclude_modules
+        if m.__file__
     ]
     # Specifically exclude the reflex cli module.
     if reflex_bin := shutil.which(b"reflex"):
@@ -3197,17 +3197,20 @@ class Field(Generic[T]):
     def __get__(self: Field[bool], instance: None, owner) -> BooleanVar: ...

     @overload
-    def __get__(self: Field[int], instance: None, owner) -> NumberVar: ...
+    def __get__(self: Field[int], instance: None, owner) -> NumberVar[int]: ...

     @overload
-    def __get__(self: Field[str], instance: None, owner) -> StringVar: ...
+    def __get__(self: Field[float], instance: None, owner) -> NumberVar[float]: ...
+
+    @overload
+    def __get__(self: Field[str], instance: None, owner) -> StringVar[str]: ...

     @overload
     def __get__(self: Field[None], instance: None, owner) -> NoneVar: ...

     @overload
     def __get__(
-        self: Field[Sequence[V]] | Field[Set[V]],
+        self: Field[Sequence[V]] | Field[Set[V]] | Field[List[V]],
         instance: None,
         owner,
     ) -> ArrayVar[Sequence[V]]: ...
@@ -1069,24 +1069,11 @@ def ternary_operation(
     return value


-TUPLE_ENDS_IN_VAR = (
-    tuple[Var[VAR_TYPE]]
-    | tuple[Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]]
-    | tuple[
-        Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var, Var[VAR_TYPE]
-    ]
-)
+X = tuple[*tuple[Var, ...], str]
+
+TUPLE_ENDS_IN_VAR = tuple[*tuple[Var[Any], ...], Var[VAR_TYPE]]
+
+TUPLE_ENDS_IN_VAR_RELAXED = tuple[*tuple[Var[Any] | Any, ...], Var[VAR_TYPE] | VAR_TYPE]


 @dataclasses.dataclass(
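The rewrite above replaces a union spelled out once per tuple length with a single PEP 646 unpacked-tuple type. A rough illustration of the idea using plain builtins instead of reflex's `Var` (assumed names, Python 3.11+):

# Hedged illustration: "any number of leading strs, then exactly one int" as one type.
from typing import Union

# Old style: one union member per supported arity, capped at an arbitrary maximum.
ENDS_IN_INT_OLD = Union[
    tuple[int],
    tuple[str, int],
    tuple[str, str, int],
    # ...and so on, one entry per length
]

# Unpacked form: covers every length at once.
ENDS_IN_INT_NEW = tuple[*tuple[str, ...], int]


def last_value(t: ENDS_IN_INT_NEW) -> int:
    # The final element is statically known to be the int, whatever the length.
    return t[-1]


print(last_value(("a", "b", "c", 7)))  # 7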
@@ -1153,7 +1140,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
     def create(
         cls,
         cond: Any,
-        cases: Sequence[Sequence[Any | Var[VAR_TYPE]]],
+        cases: Sequence[TUPLE_ENDS_IN_VAR_RELAXED[VAR_TYPE]],
         default: Var[VAR_TYPE] | VAR_TYPE,
         _var_data: VarData | None = None,
         _var_type: type[VAR_TYPE] | None = None,
@@ -1175,6 +1162,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
             tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
             tuple(tuple(Var.create(c) for c in case) for case in cases),
         )

+        _default = cast(Var[VAR_TYPE], Var.create(default))
         var_type = _var_type or unionize(
             *(case[-1]._var_type for case in cases),
@@ -45,7 +45,7 @@ from .base import (
 from .number import BooleanVar, NumberVar, raise_unsupported_operand_types
 from .sequence import ArrayVar, StringVar

-OBJECT_TYPE = TypeVar("OBJECT_TYPE")
+OBJECT_TYPE = TypeVar("OBJECT_TYPE", covariant=True)

 KEY_TYPE = TypeVar("KEY_TYPE")
 VALUE_TYPE = TypeVar("VALUE_TYPE")
@@ -164,7 +164,8 @@ class ObjectVar(Var[OBJECT_TYPE], python_types=dict):

     @overload
     def __getitem__(
-        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]],
+        self: ObjectVar[Dict[Any, Sequence[ARRAY_INNER_TYPE]]]
+        | ObjectVar[Dict[Any, List[ARRAY_INNER_TYPE]]],
         key: Var | Any,
     ) -> ArrayVar[Sequence[ARRAY_INNER_TYPE]]: ...

|
@ -28,9 +28,11 @@ def TestEventAction():
|
|||||||
def on_click2(self):
|
def on_click2(self):
|
||||||
self.order.append("on_click2")
|
self.order.append("on_click2")
|
||||||
|
|
||||||
|
@rx.event
|
||||||
def on_click_throttle(self):
|
def on_click_throttle(self):
|
||||||
self.order.append("on_click_throttle")
|
self.order.append("on_click_throttle")
|
||||||
|
|
||||||
|
@rx.event
|
||||||
def on_click_debounce(self):
|
def on_click_debounce(self):
|
||||||
self.order.append("on_click_debounce")
|
self.order.append("on_click_debounce")
|
||||||
|
|
||||||
|
@@ -22,9 +22,9 @@ def LifespanApp():

    @asynccontextmanager
    async def lifespan_context(app, inc: int = 1):
-        global lifespan_context_global
+        nonlocal lifespan_context_global
        print(f"Lifespan context entered: {app}.")
-        lifespan_context_global += inc  # pyright: ignore[reportUnboundVariable]
+        lifespan_context_global += inc
        try:
            yield
        finally:
@@ -32,11 +32,11 @@ def LifespanApp():
            lifespan_context_global += inc

    async def lifespan_task(inc: int = 1):
-        global lifespan_task_global
+        nonlocal lifespan_task_global
        print("Lifespan global started.")
        try:
            while True:
-                lifespan_task_global += inc  # pyright: ignore[reportUnboundVariable]
+                lifespan_task_global += inc
                await asyncio.sleep(0.1)
        except asyncio.CancelledError as ce:
            print(f"Lifespan global cancelled: {ce}.")
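The `global` to `nonlocal` fixes above are also what let the `reportUnboundVariable` ignores go away: the counters live in the enclosing `LifespanApp()` function, so `nonlocal` is the binding that actually reaches them. A stripped-down sketch of the difference, with hypothetical names:

# Hedged sketch: rebinding an enclosing-function variable needs `nonlocal`.
def make_counter():
    count = 0

    def bump(inc: int = 1):
        nonlocal count  # `global count` would raise NameError: no module-level `count`
        count += inc
        return count

    return bump


bump = make_counter()
print(bump(), bump(2))  # 1 3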
|
@ -19,25 +19,27 @@ def VarOperations():
|
|||||||
from reflex.vars.sequence import ArrayVar
|
from reflex.vars.sequence import ArrayVar
|
||||||
|
|
||||||
class Object(rx.Base):
|
class Object(rx.Base):
|
||||||
str: str = "hello"
|
name: str = "hello"
|
||||||
|
|
||||||
class VarOperationState(rx.State):
|
class VarOperationState(rx.State):
|
||||||
int_var1: int = 10
|
int_var1: rx.Field[int] = rx.field(10)
|
||||||
int_var2: int = 5
|
int_var2: rx.Field[int] = rx.field(5)
|
||||||
int_var3: int = 7
|
int_var3: rx.Field[int] = rx.field(7)
|
||||||
float_var1: float = 10.5
|
float_var1: rx.Field[float] = rx.field(10.5)
|
||||||
float_var2: float = 5.5
|
float_var2: rx.Field[float] = rx.field(5.5)
|
||||||
list1: List = [1, 2]
|
list1: rx.Field[List[int]] = rx.field([1, 2])
|
||||||
list2: List = [3, 4]
|
list2: rx.Field[List[int]] = rx.field([3, 4])
|
||||||
list3: List = ["first", "second", "third"]
|
list3: rx.Field[List[str]] = rx.field(["first", "second", "third"])
|
||||||
list4: List = [Object(name="obj_1"), Object(name="obj_2")]
|
list4: rx.Field[List[Object]] = rx.field(
|
||||||
str_var1: str = "first"
|
[Object(name="obj_1"), Object(name="obj_2")]
|
||||||
str_var2: str = "second"
|
)
|
||||||
str_var3: str = "ThIrD"
|
str_var1: rx.Field[str] = rx.field("first")
|
||||||
str_var4: str = "a long string"
|
str_var2: rx.Field[str] = rx.field("second")
|
||||||
dict1: Dict[int, int] = {1: 2}
|
str_var3: rx.Field[str] = rx.field("ThIrD")
|
||||||
dict2: Dict[int, int] = {3: 4}
|
str_var4: rx.Field[str] = rx.field("a long string")
|
||||||
html_str: str = "<div>hello</div>"
|
dict1: rx.Field[Dict[int, int]] = rx.field({1: 2})
|
||||||
|
dict2: rx.Field[Dict[int, int]] = rx.field({3: 4})
|
||||||
|
html_str: rx.Field[str] = rx.field("<div>hello</div>")
|
||||||
|
|
||||||
app = rx.App(state=rx.State)
|
app = rx.App(state=rx.State)
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from reflex.vars.base import LiteralVar, Var, computed_var
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def cond_state(request):
|
def cond_state(request):
|
||||||
class CondState(BaseState):
|
class CondState(BaseState):
|
||||||
value: request.param["value_type"] = request.param["value"] # noqa
|
value: request.param["value_type"] = request.param["value"] # pyright: ignore[reportInvalidTypeForm, reportUndefinedVariable] # noqa: F821
|
||||||
|
|
||||||
return CondState
|
return CondState
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ def test_match_vars(cases, expected):
|
|||||||
cases: The match cases.
|
cases: The match cases.
|
||||||
expected: The expected var full name.
|
expected: The expected var full name.
|
||||||
"""
|
"""
|
||||||
match_comp = Match.create(MatchState.value, *cases)
|
match_comp = Match.create(MatchState.value, *cases) # pyright: ignore[reportCallIssue]
|
||||||
assert isinstance(match_comp, Var)
|
assert isinstance(match_comp, Var)
|
||||||
assert str(match_comp) == expected
|
assert str(match_comp) == expected
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ def test_match_default_not_last_arg(match_case):
        ValueError,
        match="rx.match should have tuples of cases and a default case as the last argument.",
    ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]


 @pytest.mark.parametrize(
@@ -161,7 +161,7 @@ def test_match_case_tuple_elements(match_case):
        ValueError,
        match="A case tuple should have at least a match case element and a return value.",
    ):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]


 @pytest.mark.parametrize(
@@ -203,7 +203,7 @@ def test_match_different_return_types(cases: Tuple, error_msg: str):
        error_msg: Expected error message.
    """
    with pytest.raises(MatchTypeError, match=error_msg):
-        Match.create(MatchState.value, *cases)
+        Match.create(MatchState.value, *cases)  # pyright: ignore[reportCallIssue]


 @pytest.mark.parametrize(
@@ -235,9 +235,9 @@ def test_match_multiple_default_cases(match_case):
        match_case: the cases to match.
    """
    with pytest.raises(ValueError, match="rx.match can only have one default case."):
-        Match.create(MatchState.value, *match_case)
+        Match.create(MatchState.value, *match_case)  # pyright: ignore[reportCallIssue]


 def test_match_no_cond():
    with pytest.raises(ValueError):
-        _ = Match.create(None)
+        _ = Match.create(None)  # pyright: ignore[reportCallIssue]
@@ -13,7 +13,8 @@ from reflex.utils.serializers import serialize, serialize_dataframe
        pytest.param(
            {
                "data": pd.DataFrame(
-                    [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+                    [["foo", "bar"], ["foo1", "bar1"]],
+                    columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
                )
            },
            "data",
@@ -113,7 +114,8 @@ def test_computed_var_without_annotation(fixture, request, err_msg, is_data_fram
 def test_serialize_dataframe():
     """Test if dataframe is serialized correctly."""
     df = pd.DataFrame(
-        [["foo", "bar"], ["foo1", "bar1"]], columns=["column1", "column2"]
+        [["foo", "bar"], ["foo1", "bar1"]],
+        columns=["column1", "column2"],  # pyright: ignore [reportArgumentType]
     )
     value = serialize(df)
     assert value == serialize_dataframe(df)
@@ -9,7 +9,7 @@ import unittest.mock
 import uuid
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
-from typing import Generator, List, Tuple, Type
+from typing import Generator, List, Tuple, Type, cast
 from unittest.mock import AsyncMock

 import pytest
@@ -33,7 +33,7 @@ from reflex.components import Component
 from reflex.components.base.fragment import Fragment
 from reflex.components.core.cond import Cond
 from reflex.components.radix.themes.typography.text import Text
-from reflex.event import Event
+from reflex.event import Event, EventHandler
 from reflex.middleware import HydrateMiddleware
 from reflex.model import Model
 from reflex.state import (
@@ -917,7 +917,7 @@ class DynamicState(BaseState):
        """
        return self.dynamic

-    on_load_internal = OnLoadInternalState.on_load_internal.fn
+    on_load_internal = cast(EventHandler, OnLoadInternalState.on_load_internal).fn


 def test_dynamic_arg_shadow(
@@ -1190,7 +1190,7 @@ async def test_process_events(mocker, token: str):
        pass

    assert (await app.state_manager.get_state(event.substate_token)).value == 5
-    assert app._postprocess.call_count == 6
+    assert getattr(app._postprocess, "call_count", None) == 6

    if isinstance(app.state_manager, StateManagerRedis):
        await app.state_manager.close()
@@ -1247,7 +1247,7 @@ def test_overlay_component(

    if exp_page_child is not None:
        assert len(page.children) == 3
-        children_types = (type(child) for child in page.children)
+        children_types = [type(child) for child in page.children]
        assert exp_page_child in children_types
    else:
        assert len(page.children) == 2
@@ -5,6 +5,7 @@ import pytest
 import reflex as rx
 from reflex.event import (
     Event,
+    EventActionsMixin,
     EventChain,
     EventHandler,
     EventSpec,
@@ -410,6 +411,7 @@ def test_event_actions():

 def test_event_actions_on_state():
     class EventActionState(BaseState):
+        @rx.event
         def handler(self):
             pass

@@ -418,6 +420,7 @@ def test_event_actions_on_state():
    assert not handler.event_actions

    sp_handler = EventActionState.handler.stop_propagation
+    assert isinstance(sp_handler, EventActionsMixin)
    assert sp_handler.event_actions == {"stopPropagation": True}
    # should NOT affect other references to the handler
    assert not handler.event_actions
@@ -122,9 +122,12 @@ async def test_health(
    # Call the async health function
    response = await health()

-    print(json.loads(response.body))
+    body = response.body
+    assert isinstance(body, bytes)
+
+    print(json.loads(body))
    print(expected_status)

    # Verify the response content and status code
    assert response.status_code == expected_code
-    assert json.loads(response.body) == expected_status
+    assert json.loads(body) == expected_status
@@ -59,7 +59,7 @@ def test_automigration(
        id: Mapped[Optional[int]] = mapped_column(primary_key=True, default=None)

    # initial table
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
        t1: Mapped[str] = mapped_column(default="")

    with Model.get_db_engine().connect() as connection:
@@ -78,7 +78,7 @@ def test_automigration(
    model_registry.get_metadata().clear()

    # Create column t2, mark t1 as optional with default
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
        t1: Mapped[Optional[str]] = mapped_column(default="default")
        t2: Mapped[str] = mapped_column(default="bar")

@@ -98,7 +98,7 @@ def test_automigration(
    model_registry.get_metadata().clear()

    # Drop column t1
-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
        t2: Mapped[str] = mapped_column(default="bar")

    assert Model.migrate(autogenerate=True)
@@ -133,7 +133,7 @@ def test_automigration(
    # drop table (AlembicSecond)
    model_registry.get_metadata().clear()

-    class AlembicThing(ModelBase):  # pyright: ignore[reportGeneralTypeIssues]
+    class AlembicThing(ModelBase):  # pyright: ignore[reportRedeclaration]
        t2: Mapped[str] = mapped_column(default="bar")

    assert Model.migrate(autogenerate=True)
@@ -17,6 +17,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -120,8 +121,8 @@ class TestState(BaseState):
     num2: float = 3.14
     key: str
     map_key: str = "a"
-    array: List[float] = [1, 2, 3.14]
-    mapping: Dict[str, List[int]] = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    array: rx.Field[List[float]] = rx.field([1, 2, 3.14])
+    mapping: rx.Field[Dict[str, List[int]]] = rx.field({"a": [1, 2, 3], "b": [4, 5, 6]})
     obj: Object = Object()
     complex: Dict[int, Object] = {1: Object(), 2: Object()}
     fig: Figure = Figure()
@@ -1357,6 +1358,7 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
    class HandlerState(BaseState):
        x: int = 42

+        @rx.event
        def handler(self):
            self.x = self.x + 1

@@ -1367,11 +1369,11 @@ def test_cached_var_depends_on_event_handler(use_partial: bool):
        counter += 1
        return counter

-    if use_partial:
-        HandlerState.handler = functools.partial(HandlerState.handler.fn)
-        assert isinstance(HandlerState.handler, functools.partial)
-    else:
-        assert isinstance(HandlerState.handler, EventHandler)
+    assert isinstance(HandlerState.handler, EventHandler)
+    if use_partial:
+        partial_guy = functools.partial(HandlerState.handler.fn)
+        HandlerState.handler = partial_guy  # pyright: ignore[reportAttributeAccessIssue]
+        assert isinstance(HandlerState.handler, functools.partial)

    s = HandlerState()
    assert "cached_x_side_effect" in s._computed_var_dependencies["x"]
@@ -2025,8 +2027,11 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):

    # ensure state update was emitted
    assert mock_app.event_namespace is not None
-    mock_app.event_namespace.emit.assert_called_once()
-    mcall = mock_app.event_namespace.emit.mock_calls[0]
+    mock_app.event_namespace.emit.assert_called_once()  # pyright: ignore[reportFunctionMemberAccess]
+    mock_calls = getattr(mock_app.event_namespace.emit, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+    mcall = mock_calls[0]
    assert mcall.args[0] == str(SocketEvent.EVENT)
    assert mcall.args[1] == StateUpdate(
        delta={
@@ -2231,7 +2236,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
    assert mock_app.event_namespace is not None
    emit_mock = mock_app.event_namespace.emit

-    first_ws_message = emit_mock.mock_calls[0].args[1]
+    mock_calls = getattr(emit_mock, "mock_calls", None)
+    assert mock_calls is not None
+    assert isinstance(mock_calls, Sequence)
+
+    first_ws_message = mock_calls[0].args[1]
    assert (
        first_ws_message.delta[BackgroundTaskState.get_full_name()].pop("router")
        is not None
@@ -2246,7 +2255,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
        events=[],
        final=True,
    )
-    for call in emit_mock.mock_calls[1:5]:
+    for call in mock_calls[1:5]:
        assert call.args[1] == StateUpdate(
            delta={
                BackgroundTaskState.get_full_name(): {
@@ -2256,7 +2265,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
            events=[],
            final=True,
        )
-    assert emit_mock.mock_calls[-2].args[1] == StateUpdate(
+    assert mock_calls[-2].args[1] == StateUpdate(
        delta={
            BackgroundTaskState.get_full_name(): {
                "order": exp_order,
@@ -2267,7 +2276,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
            events=[],
            final=True,
        )
-    assert emit_mock.mock_calls[-1].args[1] == StateUpdate(
+    assert mock_calls[-1].args[1] == StateUpdate(
        delta={
            BackgroundTaskState.get_full_name(): {
                "computed_order": exp_order,