Catch unhandled errors on both frontend and backend (#3572)

This commit is contained in:
Maxim Vlah 2024-07-09 04:26:11 +02:00 committed by GitHub
parent 6d3321284c
commit 772a3ef893
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 771 additions and 15 deletions

View File

@ -0,0 +1,152 @@
"""Integration tests for event exception handlers."""
from __future__ import annotations
import time
from typing import Generator, Type
import pytest
from selenium.webdriver.common.by import By
from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait
from reflex.testing import AppHarness
def TestApp():
"""A test app for event exception handler integration."""
import reflex as rx
class TestAppConfig(rx.Config):
"""Config for the TestApp app."""
class TestAppState(rx.State):
"""State for the TestApp app."""
def divide_by_number(self, number: int):
"""Divide by number and print the result.
Args:
number: number to divide by
"""
print(1 / number)
app = rx.App(state=rx.State)
@app.add_page
def index():
return rx.vstack(
rx.button(
"induce_frontend_error",
on_click=rx.call_script("induce_frontend_error()"),
id="induce-frontend-error-btn",
),
rx.button(
"induce_backend_error",
on_click=lambda: TestAppState.divide_by_number(0), # type: ignore
id="induce-backend-error-btn",
),
)
@pytest.fixture(scope="module")
def test_app(
app_harness_env: Type[AppHarness], tmp_path_factory
) -> Generator[AppHarness, None, None]:
"""Start TestApp app at tmp_path via AppHarness.
Args:
app_harness_env: either AppHarness (dev) or AppHarnessProd (prod)
tmp_path_factory: pytest tmp_path_factory fixture
Yields:
running AppHarness instance
"""
with app_harness_env.create(
root=tmp_path_factory.mktemp("test_app"),
app_name=f"testapp_{app_harness_env.__name__.lower()}",
app_source=TestApp, # type: ignore
) as harness:
yield harness
@pytest.fixture
def driver(test_app: AppHarness) -> Generator[WebDriver, None, None]:
"""Get an instance of the browser open to the test_app app.
Args:
test_app: harness for TestApp app
Yields:
WebDriver instance.
"""
assert test_app.app_instance is not None, "app is not running"
driver = test_app.frontend()
try:
yield driver
finally:
driver.quit()
def test_frontend_exception_handler_during_runtime(
driver: WebDriver,
capsys,
):
"""Test calling frontend exception handler during runtime.
We send an event containing a call to a non-existent function in the frontend.
This should trigger the default frontend exception handler.
Args:
driver: WebDriver instance.
capsys: pytest fixture for capturing stdout and stderr.
"""
reset_button = WebDriverWait(driver, 20).until(
EC.element_to_be_clickable((By.ID, "induce-frontend-error-btn"))
)
reset_button.click()
# Wait for the error to be logged
time.sleep(2)
captured_default_handler_output = capsys.readouterr()
assert (
"induce_frontend_error" in captured_default_handler_output.out
and "ReferenceError" in captured_default_handler_output.out
)
def test_backend_exception_handler_during_runtime(
driver: WebDriver,
capsys,
):
"""Test calling backend exception handler during runtime.
We invoke TestAppState.divide_by_zero to induce backend error.
This should trigger the default backend exception handler.
Args:
driver: WebDriver instance.
capsys: pytest fixture for capturing stdout and stderr.
"""
reset_button = WebDriverWait(driver, 20).until(
EC.element_to_be_clickable((By.ID, "induce-backend-error-btn"))
)
reset_button.click()
# Wait for the error to be logged
time.sleep(2)
captured_default_handler_output = capsys.readouterr()
assert (
"divide_by_number" in captured_default_handler_output.out
and "ZeroDivisionError" in captured_default_handler_output.out
)

View File

@ -247,6 +247,9 @@ export const applyEvent = async (event, socket) => {
}
} catch (e) {
console.log("_call_script", e);
if (window && window?.onerror) {
window.onerror(e.message, null, null, null, e)
}
}
return false;
}
@ -687,6 +690,31 @@ export const useEventLoop = (
}
}, [router.isReady]);
// Handle frontend errors and send them to the backend via websocket.
useEffect(() => {
if (typeof window === 'undefined') {
return;
}
window.onerror = function (msg, url, lineNo, columnNo, error) {
addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", {
stack: error.stack,
})])
return false;
}
//NOTE: Only works in Chrome v49+
//https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events
window.onunhandledrejection = function (event) {
addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", {
stack: event.reason.stack,
})])
return false;
}
},[])
// Main event loop.
useEffect(() => {
// Skip if the router is not ready.

View File

@ -7,11 +7,13 @@ import concurrent.futures
import contextlib
import copy
import functools
import inspect
import io
import multiprocessing
import os
import platform
import sys
import traceback
from datetime import datetime
from typing import (
Any,
@ -45,6 +47,7 @@ from reflex.compiler import compiler
from reflex.compiler import utils as compiler_utils
from reflex.compiler.compiler import ExecutorSafeFunctions
from reflex.components.base.app_wrap import AppWrap
from reflex.components.base.error_boundary import ErrorBoundary
from reflex.components.base.fragment import Fragment
from reflex.components.component import (
Component,
@ -60,7 +63,7 @@ from reflex.components.core.client_side_routing import (
from reflex.components.core.upload import Upload, get_upload_dir
from reflex.components.radix import themes
from reflex.config import get_config
from reflex.event import Event, EventHandler, EventSpec
from reflex.event import Event, EventHandler, EventSpec, window_alert
from reflex.model import Model
from reflex.page import (
DECORATED_PAGES,
@ -88,6 +91,33 @@ ComponentCallable = Callable[[], Component]
Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]]
def default_frontend_exception_handler(exception: Exception) -> None:
"""Default frontend exception handler function.
Args:
exception: The exception.
"""
console.error(f"[Reflex Frontend Exception]\n {exception}\n")
def default_backend_exception_handler(exception: Exception) -> EventSpec:
"""Default backend exception handler function.
Args:
exception: The exception.
Returns:
EventSpec: The window alert event.
"""
error = traceback.format_exc()
console.error(f"[Reflex Backend Exception]\n {error}\n")
return window_alert("An error occurred. See logs for details.")
def default_overlay_component() -> Component:
"""Default overlay_component attribute for App.
@ -101,6 +131,16 @@ def default_overlay_component() -> Component:
)
def default_error_boundary() -> Component:
"""Default error_boundary attribute for App.
Returns:
The default error_boundary, which is an ErrorBoundary.
"""
return ErrorBoundary.create()
class OverlayFragment(Fragment):
"""Alias for Fragment, used to wrap the overlay_component."""
@ -142,6 +182,11 @@ class App(MiddlewareMixin, LifespanMixin, Base):
default_overlay_component
)
# Error boundary component to wrap the app with.
error_boundary: Optional[Union[Component, ComponentCallable]] = (
default_error_boundary
)
# Components to add to the head of every page.
head_components: List[Component] = []
@ -178,6 +223,16 @@ class App(MiddlewareMixin, LifespanMixin, Base):
# Background tasks that are currently running. PRIVATE.
background_tasks: Set[asyncio.Task] = set()
# Frontend Error Handler Function
frontend_exception_handler: Callable[[Exception], None] = (
default_frontend_exception_handler
)
# Backend Error Handler Function
backend_exception_handler: Callable[
[Exception], Union[EventSpec, List[EventSpec], None]
] = default_backend_exception_handler
def __init__(self, **kwargs):
"""Initialize the app.
@ -279,6 +334,9 @@ class App(MiddlewareMixin, LifespanMixin, Base):
# Mount the socket app with the API.
self.api.mount(str(constants.Endpoint.EVENT), socket_app)
# Check the exception handlers
self._validate_exception_handlers()
def __repr__(self) -> str:
"""Get the string representation of the app.
@ -688,6 +746,25 @@ class App(MiddlewareMixin, LifespanMixin, Base):
for k, component in self.pages.items():
self.pages[k] = self._add_overlay_to_component(component)
def _add_error_boundary_to_component(self, component: Component) -> Component:
if self.error_boundary is None:
return component
component = ErrorBoundary.create(*component.children)
return component
def _setup_error_boundary(self):
"""If a State is not used and no error_boundary is specified, do not render the error boundary."""
if self.state is None and self.error_boundary is default_error_boundary:
self.error_boundary = None
for k, component in self.pages.items():
# Skip the 404 page
if k == constants.Page404.SLUG:
continue
self.pages[k] = self._add_error_boundary_to_component(component)
def _apply_decorated_pages(self):
"""Add @rx.page decorated pages to the app.
@ -757,6 +834,7 @@ class App(MiddlewareMixin, LifespanMixin, Base):
self._validate_var_dependencies()
self._setup_overlay_component()
self._setup_error_boundary()
# Create a progress bar.
progress = Progress(
@ -1036,6 +1114,100 @@ class App(MiddlewareMixin, LifespanMixin, Base):
task.add_done_callback(self.background_tasks.discard)
return task
def _validate_exception_handlers(self):
"""Validate the custom event exception handlers for front- and backend.
Raises:
ValueError: If the custom exception handlers are invalid.
"""
FRONTEND_ARG_SPEC = {
"exception": Exception,
}
BACKEND_ARG_SPEC = {
"exception": Exception,
}
for handler_domain, handler_fn, handler_spec in zip(
["frontend", "backend"],
[self.frontend_exception_handler, self.backend_exception_handler],
[
FRONTEND_ARG_SPEC,
BACKEND_ARG_SPEC,
],
):
if hasattr(handler_fn, "__name__"):
_fn_name = handler_fn.__name__
else:
_fn_name = handler_fn.__class__.__name__
if isinstance(handler_fn, functools.partial):
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` is a partial function. Please provide a named function instead."
)
if not callable(handler_fn):
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` is not a function."
)
# Allow named functions only as lambda functions cannot be introspected
if _fn_name == "<lambda>":
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` is a lambda function. Please use a named function instead."
)
# Check if the function has the necessary annotations and types in the right order
argspec = inspect.getfullargspec(handler_fn)
arg_annotations = {
k: eval(v) if isinstance(v, str) else v
for k, v in argspec.annotations.items()
if k not in ["args", "kwargs", "return"]
}
for required_arg_index, required_arg in enumerate(handler_spec):
if required_arg not in arg_annotations:
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` does not take the required argument `{required_arg}`"
)
elif (
not list(arg_annotations.keys())[required_arg_index] == required_arg
):
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong argument order."
f"Expected `{required_arg}` as the {required_arg_index+1} argument but got `{list(arg_annotations.keys())[required_arg_index]}`"
)
if not issubclass(arg_annotations[required_arg], Exception):
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong type for {required_arg} argument."
f"Expected to be `Exception` but got `{arg_annotations[required_arg]}`"
)
# Check if the return type is valid for backend exception handler
if handler_domain == "backend":
sig = inspect.signature(self.backend_exception_handler)
return_type = (
eval(sig.return_annotation)
if isinstance(sig.return_annotation, str)
else sig.return_annotation
)
valid = bool(
return_type == EventSpec
or return_type == Optional[EventSpec]
or return_type == List[EventSpec]
or return_type == inspect.Signature.empty
or return_type is None
)
if not valid:
raise ValueError(
f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong return type."
f"Expected `Union[EventSpec, List[EventSpec], None]` but got `{return_type}`"
)
async def process(
app: App, event: Event, sid: str, headers: Dict, client_ip: str
@ -1101,6 +1273,8 @@ async def process(
yield update
except Exception as ex:
telemetry.send_error(ex, context="backend")
app.backend_exception_handler(ex)
raise

View File

@ -13,6 +13,10 @@ _SUBMOD_ATTRS: dict[str, list[str]] = {
"Fragment",
"fragment",
],
"error_boundary": [
"ErrorBoundary",
"error_boundary",
],
"head": [
"head",
"Head",

View File

@ -10,6 +10,8 @@ from .document import DocumentHead as DocumentHead
from .document import Html as Html
from .document import Main as Main
from .document import NextScript as NextScript
from .error_boundary import ErrorBoundary as ErrorBoundary
from .error_boundary import error_boundary as error_boundary
from .fragment import Fragment as Fragment
from .fragment import fragment as fragment
from .head import Head as Head

View File

@ -0,0 +1,78 @@
"""A React Error Boundary component that catches unhandled frontend exceptions."""
from __future__ import annotations
from typing import List
from reflex.compiler.compiler import _compile_component
from reflex.components.component import Component
from reflex.components.el import div, p
from reflex.constants import Hooks, Imports
from reflex.event import EventChain, EventHandler
from reflex.utils.imports import ImportVar
from reflex.vars import Var
class ErrorBoundary(Component):
"""A React Error Boundary component that catches unhandled frontend exceptions."""
library = "react-error-boundary"
tag = "ErrorBoundary"
# Fired when the boundary catches an error.
on_error: EventHandler[lambda error, info: [error, info]] = Var.create_safe( # type: ignore
"logFrontendError", _var_is_string=False, _var_is_local=False
).to(EventChain)
# Rendered instead of the children when an error is caught.
Fallback_component: Var[Component] = Var.create_safe(
"Fallback", _var_is_string=False, _var_is_local=False
).to(Component)
def add_imports(self) -> dict[str, list[ImportVar]]:
"""Add imports for the component.
Returns:
The imports to add.
"""
return Imports.EVENTS
def add_hooks(self) -> List[str | Var]:
"""Add hooks for the component.
Returns:
The hooks to add.
"""
return [Hooks.EVENTS, Hooks.FRONTEND_ERRORS]
def add_custom_code(self) -> List[str]:
"""Add custom Javascript code into the page that contains this component.
Custom code is inserted at module level, after any imports.
Returns:
The custom code to add.
"""
fallback_container = div(
p("Ooops...Unknown Reflex error has occured:"),
p(
Var.create("error.message", _var_is_local=False, _var_is_string=False),
color="red",
),
p("Please contact the support."),
)
compiled_fallback = _compile_component(fallback_container)
return [
f"""
function Fallback({{ error, resetErrorBoundary }}) {{
return (
{compiled_fallback}
);
}}
"""
]
error_boundary = ErrorBoundary.create

View File

@ -0,0 +1,98 @@
"""Stub file for reflex/components/base/error_boundary.py"""
# ------------------- DO NOT EDIT ----------------------
# This file was generated by `reflex/utils/pyi_generator.py`!
# ------------------------------------------------------
from typing import Any, Callable, Dict, List, Optional, Union, overload
from reflex.components.component import Component
from reflex.event import EventHandler, EventSpec
from reflex.style import Style
from reflex.utils.imports import ImportVar
from reflex.vars import BaseVar, Var
class ErrorBoundary(Component):
def add_imports(self) -> dict[str, list[ImportVar]]: ...
def add_hooks(self) -> List[str | Var]: ...
def add_custom_code(self) -> List[str]: ...
@overload
@classmethod
def create( # type: ignore
cls,
*children,
Fallback_component: Optional[Union[Var[Component], Component]] = None,
style: Optional[Style] = None,
key: Optional[Any] = None,
id: Optional[Any] = None,
class_name: Optional[Any] = None,
autofocus: Optional[bool] = None,
custom_attrs: Optional[Dict[str, Union[Var, str]]] = None,
on_blur: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_click: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_context_menu: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_double_click: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_error: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_focus: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mount: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mouse_down: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mouse_enter: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mouse_leave: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mouse_move: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mouse_out: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mouse_over: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_mouse_up: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_scroll: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
on_unmount: Optional[
Union[EventHandler, EventSpec, list, Callable, BaseVar]
] = None,
**props,
) -> "ErrorBoundary":
"""Create the component.
Args:
*children: The children of the component.
Fallback_component: Rendered instead of the children when an error is caught.
style: The style of the component.
key: A unique key for the component.
id: The id for the component.
class_name: The class name for the component.
autofocus: Whether the component should take the focus once the page is loaded
custom_attrs: custom attribute
**props: The props of the component.
Returns:
The component.
"""
...
error_boundary = ErrorBoundary.create

View File

@ -124,6 +124,16 @@ class Hooks(SimpleNamespace):
}
})"""
FRONTEND_ERRORS = """
const logFrontendError = (error, info) => {
if (process.env.NODE_ENV === "production") {
addEvents([Event("frontend_event_exception_state.handle_frontend_exception", {
stack: error.stack,
})])
}
}
"""
class MemoizationDisposition(enum.Enum):
"""The conditions under which a component should be memoized."""

View File

@ -8,7 +8,6 @@ import copy
import functools
import inspect
import os
import traceback
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
@ -25,6 +24,8 @@ from typing import (
Sequence,
Set,
Type,
Union,
cast,
)
import dill
@ -47,7 +48,6 @@ from reflex.event import (
EventHandler,
EventSpec,
fix_events,
window_alert,
)
from reflex.utils import console, format, prerequisites, types
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
@ -1430,15 +1430,39 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Convert valid EventHandler and EventSpec into Event
fixed_events = fix_events(self._check_valid(handler, events), token)
# Get the delta after processing the event.
delta = state.get_delta()
state._clean()
try:
# Get the delta after processing the event.
delta = state.get_delta()
state._clean()
return StateUpdate(
delta=delta,
events=fixed_events,
final=final if not handler.is_background else True,
)
return StateUpdate(
delta=delta,
events=fixed_events,
final=final if not handler.is_background else True,
)
except Exception as ex:
state._clean()
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
event_specs = app_instance.backend_exception_handler(ex)
if event_specs is None:
return StateUpdate()
event_specs_correct_type = cast(
Union[List[Union[EventSpec, EventHandler]], None],
[event_specs] if isinstance(event_specs, EventSpec) else event_specs,
)
fixed_events = fix_events(
event_specs_correct_type,
token,
router_data=state.router_data,
)
return StateUpdate(
events=fixed_events,
final=True,
)
async def _process_event(
self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict
@ -1491,12 +1515,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# If an error occurs, throw a window alert.
except Exception as ex:
error = traceback.format_exc()
print(error)
telemetry.send_error(ex, context="backend")
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
event_specs = app_instance.backend_exception_handler(ex)
yield state._as_state_update(
handler,
window_alert("An error occurred. See logs for details."),
event_specs,
final=True,
)
@ -1798,6 +1825,23 @@ class State(BaseState):
is_hydrated: bool = False
class FrontendEventExceptionState(State):
"""Substate for handling frontend exceptions."""
def handle_frontend_exception(self, stack: str) -> None:
"""Handle frontend exceptions.
If a frontend exception handler is provided, it will be called.
Otherwise, the default frontend exception handler will be called.
Args:
stack: The stack trace of the exception.
"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance.frontend_exception_handler(Exception(stack))
class UpdateVarsInternalState(State):
"""Substate for handling internal state var updates."""

View File

@ -1,11 +1,13 @@
from __future__ import annotations
import functools
import io
import json
import os.path
import re
import unittest.mock
import uuid
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import Generator, List, Tuple, Type
from unittest.mock import AsyncMock
@ -1571,3 +1573,165 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]):
app.state = InvalidDepState
with pytest.raises(exceptions.VarDependencyError):
app._compile()
# Test custom exception handlers
def valid_custom_handler(exception: Exception, logger: str = "test"):
print("Custom Backend Exception")
print(exception)
def custom_exception_handler_with_wrong_arg_order(
logger: str,
exception: Exception, # Should be first
):
print("Custom Backend Exception")
print(exception)
def custom_exception_handler_with_wrong_argspec(
exception: str, # Should be Exception
):
print("Custom Backend Exception")
print(exception)
class DummyExceptionHandler:
"""Dummy exception handler class."""
def handle(self, exception: Exception):
"""Handle the exception.
Args:
exception: The exception.
"""
print("Custom Backend Exception")
print(exception)
custom_exception_handlers = {
"lambda": lambda exception: print("Custom Exception Handler", exception),
"wrong_argspec": custom_exception_handler_with_wrong_argspec,
"wrong_arg_order": custom_exception_handler_with_wrong_arg_order,
"valid": valid_custom_handler,
"partial": functools.partial(valid_custom_handler, logger="test"),
"method": DummyExceptionHandler().handle,
}
@pytest.mark.parametrize(
"handler_fn, expected",
[
pytest.param(
custom_exception_handlers["partial"],
pytest.raises(ValueError),
id="partial",
),
pytest.param(
custom_exception_handlers["lambda"],
pytest.raises(ValueError),
id="lambda",
),
pytest.param(
custom_exception_handlers["wrong_argspec"],
pytest.raises(ValueError),
id="wrong_argspec",
),
pytest.param(
custom_exception_handlers["wrong_arg_order"],
pytest.raises(ValueError),
id="wrong_arg_order",
),
pytest.param(
custom_exception_handlers["valid"],
does_not_raise(),
id="valid_handler",
),
pytest.param(
custom_exception_handlers["method"],
does_not_raise(),
id="valid_class_method",
),
],
)
def test_frontend_exception_handler_validation(handler_fn, expected):
"""Test that the custom frontend exception handler is properly validated.
Args:
handler_fn: The handler function.
expected: The expected result.
"""
with expected:
rx.App(frontend_exception_handler=handler_fn)._validate_exception_handlers()
def backend_exception_handler_with_wrong_return_type(exception: Exception) -> int:
"""Custom backend exception handler with wrong return type.
Args:
exception: The exception.
Returns:
int: The wrong return type.
"""
print("Custom Backend Exception")
print(exception)
return 5
@pytest.mark.parametrize(
"handler_fn, expected",
[
pytest.param(
backend_exception_handler_with_wrong_return_type,
pytest.raises(ValueError),
id="wrong_return_type",
),
pytest.param(
custom_exception_handlers["partial"],
pytest.raises(ValueError),
id="partial",
),
pytest.param(
custom_exception_handlers["lambda"],
pytest.raises(ValueError),
id="lambda",
),
pytest.param(
custom_exception_handlers["wrong_argspec"],
pytest.raises(ValueError),
id="wrong_argspec",
),
pytest.param(
custom_exception_handlers["wrong_arg_order"],
pytest.raises(ValueError),
id="wrong_arg_order",
),
pytest.param(
custom_exception_handlers["valid"],
does_not_raise(),
id="valid_handler",
),
pytest.param(
custom_exception_handlers["method"],
does_not_raise(),
id="valid_class_method",
),
],
)
def test_backend_exception_handler_validation(handler_fn, expected):
"""Test that the custom backend exception handler is properly validated.
Args:
handler_fn: The handler function.
expected: The expected result.
"""
with expected:
rx.App(backend_exception_handler=handler_fn)._validate_exception_handlers()

View File

@ -1464,11 +1464,13 @@ def test_error_on_state_method_shadow():
@pytest.mark.asyncio
async def test_state_with_invalid_yield(capsys):
async def test_state_with_invalid_yield(capsys, mock_app):
"""Test that an error is thrown when a state yields an invalid value.
Args:
capsys: Pytest fixture for capture standard streams.
mock_app: Mock app fixture.
"""
class StateWithInvalidYield(BaseState):