From 772a3ef8939b5582bc0e397a25d5f294e3da8f67 Mon Sep 17 00:00:00 2001 From: Maxim Vlah <50658693+maximvlah@users.noreply.github.com> Date: Tue, 9 Jul 2024 04:26:11 +0200 Subject: [PATCH] Catch unhandled errors on both frontend and backend (#3572) --- integration/test_exception_handlers.py | 152 +++++++++++++++++++ reflex/.templates/web/utils/state.js | 28 ++++ reflex/app.py | 176 +++++++++++++++++++++- reflex/components/base/__init__.py | 4 + reflex/components/base/__init__.pyi | 2 + reflex/components/base/error_boundary.py | 78 ++++++++++ reflex/components/base/error_boundary.pyi | 98 ++++++++++++ reflex/constants/compiler.py | 10 ++ reflex/state.py | 70 +++++++-- tests/test_app.py | 164 ++++++++++++++++++++ tests/test_state.py | 4 +- 11 files changed, 771 insertions(+), 15 deletions(-) create mode 100644 integration/test_exception_handlers.py create mode 100644 reflex/components/base/error_boundary.py create mode 100644 reflex/components/base/error_boundary.pyi diff --git a/integration/test_exception_handlers.py b/integration/test_exception_handlers.py new file mode 100644 index 000000000..8ba1faa89 --- /dev/null +++ b/integration/test_exception_handlers.py @@ -0,0 +1,152 @@ +"""Integration tests for event exception handlers.""" + +from __future__ import annotations + +import time +from typing import Generator, Type + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.ui import WebDriverWait + +from reflex.testing import AppHarness + + +def TestApp(): + """A test app for event exception handler integration.""" + import reflex as rx + + class TestAppConfig(rx.Config): + """Config for the TestApp app.""" + + class TestAppState(rx.State): + """State for the TestApp app.""" + + def divide_by_number(self, number: int): + """Divide by number and print the result. + + Args: + number: number to divide by + + """ + print(1 / number) + + app = rx.App(state=rx.State) + + @app.add_page + def index(): + return rx.vstack( + rx.button( + "induce_frontend_error", + on_click=rx.call_script("induce_frontend_error()"), + id="induce-frontend-error-btn", + ), + rx.button( + "induce_backend_error", + on_click=lambda: TestAppState.divide_by_number(0), # type: ignore + id="induce-backend-error-btn", + ), + ) + + +@pytest.fixture(scope="module") +def test_app( + app_harness_env: Type[AppHarness], tmp_path_factory +) -> Generator[AppHarness, None, None]: + """Start TestApp app at tmp_path via AppHarness. + + Args: + app_harness_env: either AppHarness (dev) or AppHarnessProd (prod) + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + + """ + with app_harness_env.create( + root=tmp_path_factory.mktemp("test_app"), + app_name=f"testapp_{app_harness_env.__name__.lower()}", + app_source=TestApp, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(test_app: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the test_app app. + + Args: + test_app: harness for TestApp app + + Yields: + WebDriver instance. + + """ + assert test_app.app_instance is not None, "app is not running" + driver = test_app.frontend() + try: + yield driver + finally: + driver.quit() + + +def test_frontend_exception_handler_during_runtime( + driver: WebDriver, + capsys, +): + """Test calling frontend exception handler during runtime. + + We send an event containing a call to a non-existent function in the frontend. + This should trigger the default frontend exception handler. + + Args: + driver: WebDriver instance. + capsys: pytest fixture for capturing stdout and stderr. + + """ + reset_button = WebDriverWait(driver, 20).until( + EC.element_to_be_clickable((By.ID, "induce-frontend-error-btn")) + ) + + reset_button.click() + + # Wait for the error to be logged + time.sleep(2) + + captured_default_handler_output = capsys.readouterr() + assert ( + "induce_frontend_error" in captured_default_handler_output.out + and "ReferenceError" in captured_default_handler_output.out + ) + + +def test_backend_exception_handler_during_runtime( + driver: WebDriver, + capsys, +): + """Test calling backend exception handler during runtime. + + We invoke TestAppState.divide_by_zero to induce backend error. + This should trigger the default backend exception handler. + + Args: + driver: WebDriver instance. + capsys: pytest fixture for capturing stdout and stderr. + + """ + reset_button = WebDriverWait(driver, 20).until( + EC.element_to_be_clickable((By.ID, "induce-backend-error-btn")) + ) + + reset_button.click() + + # Wait for the error to be logged + time.sleep(2) + + captured_default_handler_output = capsys.readouterr() + assert ( + "divide_by_number" in captured_default_handler_output.out + and "ZeroDivisionError" in captured_default_handler_output.out + ) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 763788674..b20b80391 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -247,6 +247,9 @@ export const applyEvent = async (event, socket) => { } } catch (e) { console.log("_call_script", e); + if (window && window?.onerror) { + window.onerror(e.message, null, null, null, e) + } } return false; } @@ -687,6 +690,31 @@ export const useEventLoop = ( } }, [router.isReady]); + // Handle frontend errors and send them to the backend via websocket. + useEffect(() => { + + if (typeof window === 'undefined') { + return; + } + + window.onerror = function (msg, url, lineNo, columnNo, error) { + addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", { + stack: error.stack, + })]) + return false; + } + + //NOTE: Only works in Chrome v49+ + //https://github.com/mknichel/javascript-errors?tab=readme-ov-file#promise-rejection-events + window.onunhandledrejection = function (event) { + addEvents([Event("state.frontend_event_exception_state.handle_frontend_exception", { + stack: event.reason.stack, + })]) + return false; + } + + },[]) + // Main event loop. useEffect(() => { // Skip if the router is not ready. diff --git a/reflex/app.py b/reflex/app.py index 8e8544e7b..8f388e37b 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -7,11 +7,13 @@ import concurrent.futures import contextlib import copy import functools +import inspect import io import multiprocessing import os import platform import sys +import traceback from datetime import datetime from typing import ( Any, @@ -45,6 +47,7 @@ from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils from reflex.compiler.compiler import ExecutorSafeFunctions from reflex.components.base.app_wrap import AppWrap +from reflex.components.base.error_boundary import ErrorBoundary from reflex.components.base.fragment import Fragment from reflex.components.component import ( Component, @@ -60,7 +63,7 @@ from reflex.components.core.client_side_routing import ( from reflex.components.core.upload import Upload, get_upload_dir from reflex.components.radix import themes from reflex.config import get_config -from reflex.event import Event, EventHandler, EventSpec +from reflex.event import Event, EventHandler, EventSpec, window_alert from reflex.model import Model from reflex.page import ( DECORATED_PAGES, @@ -88,6 +91,33 @@ ComponentCallable = Callable[[], Component] Reducer = Callable[[Event], Coroutine[Any, Any, StateUpdate]] +def default_frontend_exception_handler(exception: Exception) -> None: + """Default frontend exception handler function. + + Args: + exception: The exception. + + """ + console.error(f"[Reflex Frontend Exception]\n {exception}\n") + + +def default_backend_exception_handler(exception: Exception) -> EventSpec: + """Default backend exception handler function. + + Args: + exception: The exception. + + Returns: + EventSpec: The window alert event. + + """ + error = traceback.format_exc() + + console.error(f"[Reflex Backend Exception]\n {error}\n") + + return window_alert("An error occurred. See logs for details.") + + def default_overlay_component() -> Component: """Default overlay_component attribute for App. @@ -101,6 +131,16 @@ def default_overlay_component() -> Component: ) +def default_error_boundary() -> Component: + """Default error_boundary attribute for App. + + Returns: + The default error_boundary, which is an ErrorBoundary. + + """ + return ErrorBoundary.create() + + class OverlayFragment(Fragment): """Alias for Fragment, used to wrap the overlay_component.""" @@ -142,6 +182,11 @@ class App(MiddlewareMixin, LifespanMixin, Base): default_overlay_component ) + # Error boundary component to wrap the app with. + error_boundary: Optional[Union[Component, ComponentCallable]] = ( + default_error_boundary + ) + # Components to add to the head of every page. head_components: List[Component] = [] @@ -178,6 +223,16 @@ class App(MiddlewareMixin, LifespanMixin, Base): # Background tasks that are currently running. PRIVATE. background_tasks: Set[asyncio.Task] = set() + # Frontend Error Handler Function + frontend_exception_handler: Callable[[Exception], None] = ( + default_frontend_exception_handler + ) + + # Backend Error Handler Function + backend_exception_handler: Callable[ + [Exception], Union[EventSpec, List[EventSpec], None] + ] = default_backend_exception_handler + def __init__(self, **kwargs): """Initialize the app. @@ -279,6 +334,9 @@ class App(MiddlewareMixin, LifespanMixin, Base): # Mount the socket app with the API. self.api.mount(str(constants.Endpoint.EVENT), socket_app) + # Check the exception handlers + self._validate_exception_handlers() + def __repr__(self) -> str: """Get the string representation of the app. @@ -688,6 +746,25 @@ class App(MiddlewareMixin, LifespanMixin, Base): for k, component in self.pages.items(): self.pages[k] = self._add_overlay_to_component(component) + def _add_error_boundary_to_component(self, component: Component) -> Component: + if self.error_boundary is None: + return component + + component = ErrorBoundary.create(*component.children) + + return component + + def _setup_error_boundary(self): + """If a State is not used and no error_boundary is specified, do not render the error boundary.""" + if self.state is None and self.error_boundary is default_error_boundary: + self.error_boundary = None + + for k, component in self.pages.items(): + # Skip the 404 page + if k == constants.Page404.SLUG: + continue + self.pages[k] = self._add_error_boundary_to_component(component) + def _apply_decorated_pages(self): """Add @rx.page decorated pages to the app. @@ -757,6 +834,7 @@ class App(MiddlewareMixin, LifespanMixin, Base): self._validate_var_dependencies() self._setup_overlay_component() + self._setup_error_boundary() # Create a progress bar. progress = Progress( @@ -1036,6 +1114,100 @@ class App(MiddlewareMixin, LifespanMixin, Base): task.add_done_callback(self.background_tasks.discard) return task + def _validate_exception_handlers(self): + """Validate the custom event exception handlers for front- and backend. + + Raises: + ValueError: If the custom exception handlers are invalid. + + """ + FRONTEND_ARG_SPEC = { + "exception": Exception, + } + + BACKEND_ARG_SPEC = { + "exception": Exception, + } + + for handler_domain, handler_fn, handler_spec in zip( + ["frontend", "backend"], + [self.frontend_exception_handler, self.backend_exception_handler], + [ + FRONTEND_ARG_SPEC, + BACKEND_ARG_SPEC, + ], + ): + if hasattr(handler_fn, "__name__"): + _fn_name = handler_fn.__name__ + else: + _fn_name = handler_fn.__class__.__name__ + + if isinstance(handler_fn, functools.partial): + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` is a partial function. Please provide a named function instead." + ) + + if not callable(handler_fn): + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` is not a function." + ) + + # Allow named functions only as lambda functions cannot be introspected + if _fn_name == "": + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` is a lambda function. Please use a named function instead." + ) + + # Check if the function has the necessary annotations and types in the right order + argspec = inspect.getfullargspec(handler_fn) + arg_annotations = { + k: eval(v) if isinstance(v, str) else v + for k, v in argspec.annotations.items() + if k not in ["args", "kwargs", "return"] + } + + for required_arg_index, required_arg in enumerate(handler_spec): + if required_arg not in arg_annotations: + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` does not take the required argument `{required_arg}`" + ) + elif ( + not list(arg_annotations.keys())[required_arg_index] == required_arg + ): + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong argument order." + f"Expected `{required_arg}` as the {required_arg_index+1} argument but got `{list(arg_annotations.keys())[required_arg_index]}`" + ) + + if not issubclass(arg_annotations[required_arg], Exception): + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong type for {required_arg} argument." + f"Expected to be `Exception` but got `{arg_annotations[required_arg]}`" + ) + + # Check if the return type is valid for backend exception handler + if handler_domain == "backend": + sig = inspect.signature(self.backend_exception_handler) + return_type = ( + eval(sig.return_annotation) + if isinstance(sig.return_annotation, str) + else sig.return_annotation + ) + + valid = bool( + return_type == EventSpec + or return_type == Optional[EventSpec] + or return_type == List[EventSpec] + or return_type == inspect.Signature.empty + or return_type is None + ) + + if not valid: + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong return type." + f"Expected `Union[EventSpec, List[EventSpec], None]` but got `{return_type}`" + ) + async def process( app: App, event: Event, sid: str, headers: Dict, client_ip: str @@ -1101,6 +1273,8 @@ async def process( yield update except Exception as ex: telemetry.send_error(ex, context="backend") + + app.backend_exception_handler(ex) raise diff --git a/reflex/components/base/__init__.py b/reflex/components/base/__init__.py index 7c552ce11..51c369bb9 100644 --- a/reflex/components/base/__init__.py +++ b/reflex/components/base/__init__.py @@ -13,6 +13,10 @@ _SUBMOD_ATTRS: dict[str, list[str]] = { "Fragment", "fragment", ], + "error_boundary": [ + "ErrorBoundary", + "error_boundary", + ], "head": [ "head", "Head", diff --git a/reflex/components/base/__init__.pyi b/reflex/components/base/__init__.pyi index 25a64d37f..d312caa9d 100644 --- a/reflex/components/base/__init__.pyi +++ b/reflex/components/base/__init__.pyi @@ -10,6 +10,8 @@ from .document import DocumentHead as DocumentHead from .document import Html as Html from .document import Main as Main from .document import NextScript as NextScript +from .error_boundary import ErrorBoundary as ErrorBoundary +from .error_boundary import error_boundary as error_boundary from .fragment import Fragment as Fragment from .fragment import fragment as fragment from .head import Head as Head diff --git a/reflex/components/base/error_boundary.py b/reflex/components/base/error_boundary.py new file mode 100644 index 000000000..e90f0ed63 --- /dev/null +++ b/reflex/components/base/error_boundary.py @@ -0,0 +1,78 @@ +"""A React Error Boundary component that catches unhandled frontend exceptions.""" + +from __future__ import annotations + +from typing import List + +from reflex.compiler.compiler import _compile_component +from reflex.components.component import Component +from reflex.components.el import div, p +from reflex.constants import Hooks, Imports +from reflex.event import EventChain, EventHandler +from reflex.utils.imports import ImportVar +from reflex.vars import Var + + +class ErrorBoundary(Component): + """A React Error Boundary component that catches unhandled frontend exceptions.""" + + library = "react-error-boundary" + tag = "ErrorBoundary" + + # Fired when the boundary catches an error. + on_error: EventHandler[lambda error, info: [error, info]] = Var.create_safe( # type: ignore + "logFrontendError", _var_is_string=False, _var_is_local=False + ).to(EventChain) + + # Rendered instead of the children when an error is caught. + Fallback_component: Var[Component] = Var.create_safe( + "Fallback", _var_is_string=False, _var_is_local=False + ).to(Component) + + def add_imports(self) -> dict[str, list[ImportVar]]: + """Add imports for the component. + + Returns: + The imports to add. + """ + return Imports.EVENTS + + def add_hooks(self) -> List[str | Var]: + """Add hooks for the component. + + Returns: + The hooks to add. + """ + return [Hooks.EVENTS, Hooks.FRONTEND_ERRORS] + + def add_custom_code(self) -> List[str]: + """Add custom Javascript code into the page that contains this component. + + Custom code is inserted at module level, after any imports. + + Returns: + The custom code to add. + """ + fallback_container = div( + p("Ooops...Unknown Reflex error has occured:"), + p( + Var.create("error.message", _var_is_local=False, _var_is_string=False), + color="red", + ), + p("Please contact the support."), + ) + + compiled_fallback = _compile_component(fallback_container) + + return [ + f""" + function Fallback({{ error, resetErrorBoundary }}) {{ + return ( + {compiled_fallback} + ); + }} + """ + ] + + +error_boundary = ErrorBoundary.create diff --git a/reflex/components/base/error_boundary.pyi b/reflex/components/base/error_boundary.pyi new file mode 100644 index 000000000..5bc260b3c --- /dev/null +++ b/reflex/components/base/error_boundary.pyi @@ -0,0 +1,98 @@ +"""Stub file for reflex/components/base/error_boundary.py""" + +# ------------------- DO NOT EDIT ---------------------- +# This file was generated by `reflex/utils/pyi_generator.py`! +# ------------------------------------------------------ +from typing import Any, Callable, Dict, List, Optional, Union, overload + +from reflex.components.component import Component +from reflex.event import EventHandler, EventSpec +from reflex.style import Style +from reflex.utils.imports import ImportVar +from reflex.vars import BaseVar, Var + +class ErrorBoundary(Component): + def add_imports(self) -> dict[str, list[ImportVar]]: ... + def add_hooks(self) -> List[str | Var]: ... + def add_custom_code(self) -> List[str]: ... + @overload + @classmethod + def create( # type: ignore + cls, + *children, + Fallback_component: Optional[Union[Var[Component], Component]] = None, + style: Optional[Style] = None, + key: Optional[Any] = None, + id: Optional[Any] = None, + class_name: Optional[Any] = None, + autofocus: Optional[bool] = None, + custom_attrs: Optional[Dict[str, Union[Var, str]]] = None, + on_blur: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_context_menu: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_double_click: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_error: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_focus: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_down: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_enter: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_leave: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_move: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_out: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_over: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_mouse_up: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_scroll: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + on_unmount: Optional[ + Union[EventHandler, EventSpec, list, Callable, BaseVar] + ] = None, + **props, + ) -> "ErrorBoundary": + """Create the component. + + Args: + *children: The children of the component. + Fallback_component: Rendered instead of the children when an error is caught. + style: The style of the component. + key: A unique key for the component. + id: The id for the component. + class_name: The class name for the component. + autofocus: Whether the component should take the focus once the page is loaded + custom_attrs: custom attribute + **props: The props of the component. + + Returns: + The component. + """ + ... + +error_boundary = ErrorBoundary.create diff --git a/reflex/constants/compiler.py b/reflex/constants/compiler.py index 0c71dae0d..3b1f480d2 100644 --- a/reflex/constants/compiler.py +++ b/reflex/constants/compiler.py @@ -124,6 +124,16 @@ class Hooks(SimpleNamespace): } })""" + FRONTEND_ERRORS = """ + const logFrontendError = (error, info) => { + if (process.env.NODE_ENV === "production") { + addEvents([Event("frontend_event_exception_state.handle_frontend_exception", { + stack: error.stack, + })]) + } + } + """ + class MemoizationDisposition(enum.Enum): """The conditions under which a component should be memoized.""" diff --git a/reflex/state.py b/reflex/state.py index 88eb7ae73..9569b2aba 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -8,7 +8,6 @@ import copy import functools import inspect import os -import traceback import uuid from abc import ABC, abstractmethod from collections import defaultdict @@ -25,6 +24,8 @@ from typing import ( Sequence, Set, Type, + Union, + cast, ) import dill @@ -47,7 +48,6 @@ from reflex.event import ( EventHandler, EventSpec, fix_events, - window_alert, ) from reflex.utils import console, format, prerequisites, types from reflex.utils.exceptions import ImmutableStateError, LockExpiredError @@ -1430,15 +1430,39 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # Convert valid EventHandler and EventSpec into Event fixed_events = fix_events(self._check_valid(handler, events), token) - # Get the delta after processing the event. - delta = state.get_delta() - state._clean() + try: + # Get the delta after processing the event. + delta = state.get_delta() + state._clean() - return StateUpdate( - delta=delta, - events=fixed_events, - final=final if not handler.is_background else True, - ) + return StateUpdate( + delta=delta, + events=fixed_events, + final=final if not handler.is_background else True, + ) + except Exception as ex: + state._clean() + + app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + + event_specs = app_instance.backend_exception_handler(ex) + + if event_specs is None: + return StateUpdate() + + event_specs_correct_type = cast( + Union[List[Union[EventSpec, EventHandler]], None], + [event_specs] if isinstance(event_specs, EventSpec) else event_specs, + ) + fixed_events = fix_events( + event_specs_correct_type, + token, + router_data=state.router_data, + ) + return StateUpdate( + events=fixed_events, + final=True, + ) async def _process_event( self, handler: EventHandler, state: BaseState | StateProxy, payload: Dict @@ -1491,12 +1515,15 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # If an error occurs, throw a window alert. except Exception as ex: - error = traceback.format_exc() - print(error) telemetry.send_error(ex, context="backend") + + app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + + event_specs = app_instance.backend_exception_handler(ex) + yield state._as_state_update( handler, - window_alert("An error occurred. See logs for details."), + event_specs, final=True, ) @@ -1798,6 +1825,23 @@ class State(BaseState): is_hydrated: bool = False +class FrontendEventExceptionState(State): + """Substate for handling frontend exceptions.""" + + def handle_frontend_exception(self, stack: str) -> None: + """Handle frontend exceptions. + + If a frontend exception handler is provided, it will be called. + Otherwise, the default frontend exception handler will be called. + + Args: + stack: The stack trace of the exception. + + """ + app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP) + app_instance.frontend_exception_handler(Exception(stack)) + + class UpdateVarsInternalState(State): """Substate for handling internal state var updates.""" diff --git a/tests/test_app.py b/tests/test_app.py index 2bf838c4b..f41d46c7c 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,11 +1,13 @@ from __future__ import annotations +import functools import io import json import os.path import re import unittest.mock import uuid +from contextlib import nullcontext as does_not_raise from pathlib import Path from typing import Generator, List, Tuple, Type from unittest.mock import AsyncMock @@ -1571,3 +1573,165 @@ def test_app_with_invalid_var_dependencies(compilable_app: tuple[App, Path]): app.state = InvalidDepState with pytest.raises(exceptions.VarDependencyError): app._compile() + + +# Test custom exception handlers + + +def valid_custom_handler(exception: Exception, logger: str = "test"): + print("Custom Backend Exception") + print(exception) + + +def custom_exception_handler_with_wrong_arg_order( + logger: str, + exception: Exception, # Should be first +): + print("Custom Backend Exception") + print(exception) + + +def custom_exception_handler_with_wrong_argspec( + exception: str, # Should be Exception +): + print("Custom Backend Exception") + print(exception) + + +class DummyExceptionHandler: + """Dummy exception handler class.""" + + def handle(self, exception: Exception): + """Handle the exception. + + Args: + exception: The exception. + + """ + print("Custom Backend Exception") + print(exception) + + +custom_exception_handlers = { + "lambda": lambda exception: print("Custom Exception Handler", exception), + "wrong_argspec": custom_exception_handler_with_wrong_argspec, + "wrong_arg_order": custom_exception_handler_with_wrong_arg_order, + "valid": valid_custom_handler, + "partial": functools.partial(valid_custom_handler, logger="test"), + "method": DummyExceptionHandler().handle, +} + + +@pytest.mark.parametrize( + "handler_fn, expected", + [ + pytest.param( + custom_exception_handlers["partial"], + pytest.raises(ValueError), + id="partial", + ), + pytest.param( + custom_exception_handlers["lambda"], + pytest.raises(ValueError), + id="lambda", + ), + pytest.param( + custom_exception_handlers["wrong_argspec"], + pytest.raises(ValueError), + id="wrong_argspec", + ), + pytest.param( + custom_exception_handlers["wrong_arg_order"], + pytest.raises(ValueError), + id="wrong_arg_order", + ), + pytest.param( + custom_exception_handlers["valid"], + does_not_raise(), + id="valid_handler", + ), + pytest.param( + custom_exception_handlers["method"], + does_not_raise(), + id="valid_class_method", + ), + ], +) +def test_frontend_exception_handler_validation(handler_fn, expected): + """Test that the custom frontend exception handler is properly validated. + + Args: + handler_fn: The handler function. + expected: The expected result. + + """ + with expected: + rx.App(frontend_exception_handler=handler_fn)._validate_exception_handlers() + + +def backend_exception_handler_with_wrong_return_type(exception: Exception) -> int: + """Custom backend exception handler with wrong return type. + + Args: + exception: The exception. + + Returns: + int: The wrong return type. + + """ + print("Custom Backend Exception") + print(exception) + + return 5 + + +@pytest.mark.parametrize( + "handler_fn, expected", + [ + pytest.param( + backend_exception_handler_with_wrong_return_type, + pytest.raises(ValueError), + id="wrong_return_type", + ), + pytest.param( + custom_exception_handlers["partial"], + pytest.raises(ValueError), + id="partial", + ), + pytest.param( + custom_exception_handlers["lambda"], + pytest.raises(ValueError), + id="lambda", + ), + pytest.param( + custom_exception_handlers["wrong_argspec"], + pytest.raises(ValueError), + id="wrong_argspec", + ), + pytest.param( + custom_exception_handlers["wrong_arg_order"], + pytest.raises(ValueError), + id="wrong_arg_order", + ), + pytest.param( + custom_exception_handlers["valid"], + does_not_raise(), + id="valid_handler", + ), + pytest.param( + custom_exception_handlers["method"], + does_not_raise(), + id="valid_class_method", + ), + ], +) +def test_backend_exception_handler_validation(handler_fn, expected): + """Test that the custom backend exception handler is properly validated. + + Args: + handler_fn: The handler function. + expected: The expected result. + + """ + with expected: + rx.App(backend_exception_handler=handler_fn)._validate_exception_handlers() diff --git a/tests/test_state.py b/tests/test_state.py index dbdfc5836..1254a800e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1464,11 +1464,13 @@ def test_error_on_state_method_shadow(): @pytest.mark.asyncio -async def test_state_with_invalid_yield(capsys): +async def test_state_with_invalid_yield(capsys, mock_app): """Test that an error is thrown when a state yields an invalid value. Args: capsys: Pytest fixture for capture standard streams. + mock_app: Mock app fixture. + """ class StateWithInvalidYield(BaseState):