From 62916ac6a1217995af774ce7aee84c30645ee64a Mon Sep 17 00:00:00 2001 From: Lendemor Date: Wed, 15 Jan 2025 14:31:14 +0100 Subject: [PATCH] ignore lambdas when resolving annotations --- reflex/event.py | 34 ++++++++++++----- reflex/utils/exceptions.py | 46 +++++++++++++---------- reflex/utils/prerequisites.py | 10 +++-- tests/integration/test_call_script.py | 48 ++++++++++++------------ tests/units/components/test_component.py | 21 ++++++----- 5 files changed, 95 insertions(+), 64 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 4a8d38e26..015251441 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -39,7 +39,11 @@ from typing_extensions import ( from reflex import constants from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.utils import console, format -from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch +from reflex.utils.exceptions import ( + EventFnArgMismatchError, + EventHandlerArgTypeMismatchError, + MissingAnnotationError, +) from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -1218,7 +1222,7 @@ def call_event_handler( key: The key to pass to the event handler. Raises: - EventHandlerArgTypeMismatch: If the event handler arguments do not match the event spec. + EventHandlerArgTypeMismatchError: If the event handler arguments do not match the event spec. TypeError: If the event handler arguments are invalid. Returns: @@ -1296,7 +1300,7 @@ def call_event_handler( if compare_result: continue else: - raise EventHandlerArgTypeMismatch( + raise EventHandlerArgTypeMismatchError( f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead." ) @@ -1341,19 +1345,23 @@ def unwrap_var_annotation(annotation: GenericType): return annotation -def resolve_annotation(annotations: dict[str, Any], arg_name: str): +def resolve_annotation(annotations: dict[str, Any], arg_name: str, spec: ArgsSpec): """Resolve the annotation for the given argument name. Args: annotations: The annotations. arg_name: The argument name. + spec: The specs which the annotations come from. + + Raises: + MissingAnnotationError: If the annotation is missing for non-lambda methods. Returns: The resolved annotation. """ annotation = annotations.get(arg_name) - if annotation is None: - console.error(f"Invalid annotation '{annotation}' for var '{arg_name}'.") + if annotation is None and not isinstance(spec, types.LambdaType): + raise MissingAnnotationError(var_name=arg_name) return annotation @@ -1375,7 +1383,13 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]): arg_spec( *[ Var(f"_{l_arg}").to( - unwrap_var_annotation(resolve_annotation(annotations, l_arg)) + unwrap_var_annotation( + resolve_annotation( + annotations, + l_arg, + spec=arg_spec, + ) + ) ) for l_arg in spec.args ] @@ -1391,7 +1405,7 @@ def check_fn_match_arg_spec( func_name: str | None = None, ): """Ensures that the function signature matches the passed argument specification - or raises an EventFnArgMismatch if they do not. + or raises an EventFnArgMismatchError if they do not. Args: user_func: The function to be validated. @@ -1401,7 +1415,7 @@ def check_fn_match_arg_spec( func_name: The name of the function to be validated. Raises: - EventFnArgMismatch: Raised if the number of mandatory arguments do not match + EventFnArgMismatchError: Raised if the number of mandatory arguments do not match """ user_args = inspect.getfullargspec(user_func).args # Drop the first argument if it's a bound method @@ -1417,7 +1431,7 @@ def check_fn_match_arg_spec( number_of_event_args = len(parsed_event_args) if number_of_user_args - number_of_user_default_args > number_of_event_args: - raise EventFnArgMismatch( + raise EventFnArgMismatchError( f"Event {key} only provides {number_of_event_args} arguments, but " f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} " "arguments to be passed to the event handler.\n" diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index df5a8285f..6b9ebe0e5 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -1,7 +1,5 @@ """Custom Exceptions.""" -from typing import NoReturn - class ReflexError(Exception): """Base exception for all Reflex exceptions.""" @@ -71,6 +69,18 @@ class UntypedComputedVarError(ReflexError, TypeError): super().__init__(f"Computed var '{var_name}' must have a type annotation.") +class MissingAnnotationError(ReflexError, TypeError): + """Custom TypeError for missing annotations.""" + + def __init__(self, var_name): + """Initialize the MissingAnnotationError. + + Args: + var_name: The name of the var. + """ + super().__init__(f"Var '{var_name}' must have a type annotation.") + + class UploadValueError(ReflexError, ValueError): """Custom ValueError for upload related errors.""" @@ -107,11 +117,11 @@ class MatchTypeError(ReflexError, TypeError): """Raised when the return types of match cases are different.""" -class EventHandlerArgTypeMismatch(ReflexError, TypeError): +class EventHandlerArgTypeMismatchError(ReflexError, TypeError): """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger.""" -class EventFnArgMismatch(ReflexError, TypeError): +class EventFnArgMismatchError(ReflexError, TypeError): """Raised when the number of args required by an event handler is more than provided by the event trigger.""" @@ -178,23 +188,21 @@ class StateSerializationError(ReflexError): class SystemPackageMissingError(ReflexError): """Raised when a system package is missing.""" + def __init__(self, package: str): + """Initialize the SystemPackageMissingError. -def raise_system_package_missing_error(package: str) -> NoReturn: - """Raise a SystemPackageMissingError. + Args: + package: The missing package. + """ + from reflex.constants import IS_MACOS - Args: - package: The name of the missing system package. - - Raises: - SystemPackageMissingError: The raised exception. - """ - from reflex.constants import IS_MACOS - - raise SystemPackageMissingError( - f"System package '{package}' is missing." - " Please install it through your system package manager." - + (f" You can do so by running 'brew install {package}'." if IS_MACOS else "") - ) + extra = ( + f" You can do so by running 'brew install {package}'." if IS_MACOS else "" + ) + super().__init__( + f"System package '{package}' is missing." + f" Please install it through your system package manager.{extra}" + ) class InvalidLockWarningThresholdError(ReflexError): diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 3b750be70..f208994d1 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -37,7 +37,7 @@ from reflex.config import Config, environment, get_config from reflex.utils import console, net, path_ops, processes, redir from reflex.utils.exceptions import ( GeneratedCodeHasNoFunctionDefs, - raise_system_package_missing_error, + SystemPackageMissingError, ) from reflex.utils.format import format_library_name from reflex.utils.registry import _get_npm_registry @@ -817,7 +817,11 @@ def install_node(): def install_bun(): - """Install bun onto the user's system.""" + """Install bun onto the user's system. + + Raises: + SystemPackageMissingError: If "unzip" is missing. + """ win_supported = is_windows_bun_supported() one_drive_in_path = windows_check_onedrive_in_path() if constants.IS_WINDOWS and (not win_supported or one_drive_in_path): @@ -856,7 +860,7 @@ def install_bun(): else: unzip_path = path_ops.which("unzip") if unzip_path is None: - raise_system_package_missing_error("unzip") + raise SystemPackageMissingError("unzip") # Run the bun install script. download_and_run( diff --git a/tests/integration/test_call_script.py b/tests/integration/test_call_script.py index 203c20e9b..be0bafdbb 100644 --- a/tests/integration/test_call_script.py +++ b/tests/integration/test_call_script.py @@ -16,7 +16,7 @@ from .utils import SessionStorage def CallScript(): """A test app for browser javascript integration.""" from pathlib import Path - from typing import Dict, List, Optional, Union + from typing import Optional, Union import reflex as rx @@ -43,15 +43,17 @@ def CallScript(): external_scripts = inline_scripts.replace("inline", "external") class CallScriptState(rx.State): - results: List[Optional[Union[str, Dict, List]]] = [] - inline_counter: int = 0 - external_counter: int = 0 + results: rx.Field[list[Optional[Union[str, dict, list]]]] = rx.field([]) + inline_counter: rx.Field[int] = rx.field(0) + external_counter: rx.Field[int] = rx.field(0) value: str = "Initial" - last_result: str = "" + last_result: int = 0 + @rx.event def call_script_callback(self, result): self.results.append(result) + @rx.event def call_script_callback_other_arg(self, result, other_arg): self.results.append([other_arg, result]) @@ -91,7 +93,7 @@ def CallScript(): def call_script_inline_return_lambda(self): return rx.call_script( "inline2()", - callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore + callback=lambda result: CallScriptState.call_script_callback_other_arg( result, "lambda" ), ) @@ -100,7 +102,7 @@ def CallScript(): def get_inline_counter(self): return rx.call_script( "inline_counter", - callback=CallScriptState.set_inline_counter, # type: ignore + callback=CallScriptState.setvar("inline_counter"), ) @rx.event @@ -139,7 +141,7 @@ def CallScript(): def call_script_external_return_lambda(self): return rx.call_script( "external2()", - callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore + callback=lambda result: CallScriptState.call_script_callback_other_arg( result, "lambda" ), ) @@ -148,28 +150,28 @@ def CallScript(): def get_external_counter(self): return rx.call_script( "external_counter", - callback=CallScriptState.set_external_counter, # type: ignore + callback=CallScriptState.setvar("external_counter"), ) @rx.event def call_with_var_f_string(self): return rx.call_script( f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}", - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ) @rx.event def call_with_var_str_cast(self): return rx.call_script( f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}", - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ) @rx.event def call_with_var_f_string_wrapped(self): return rx.call_script( rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"), - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ) @rx.event @@ -178,7 +180,7 @@ def CallScript(): rx.Var( f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}" ), - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ) @rx.event @@ -193,17 +195,17 @@ def CallScript(): def index(): return rx.vstack( rx.input( - value=CallScriptState.inline_counter.to(str), # type: ignore + value=CallScriptState.inline_counter.to(str), id="inline_counter", read_only=True, ), rx.input( - value=CallScriptState.external_counter.to(str), # type: ignore + value=CallScriptState.external_counter.to(str), id="external_counter", read_only=True, ), rx.text_area( - value=CallScriptState.results.to_string(), # type: ignore + value=CallScriptState.results.to_string(), id="results", read_only=True, ), @@ -273,7 +275,7 @@ def CallScript(): CallScriptState.value, on_click=rx.call_script( "'updated'", - callback=CallScriptState.set_value, # type: ignore + callback=CallScriptState.setvar("value"), ), id="update_value", ), @@ -282,7 +284,7 @@ def CallScript(): value=CallScriptState.last_result, id="last_result", read_only=True, - on_click=CallScriptState.set_last_result(""), # type: ignore + on_click=CallScriptState.setvar("last_result", 0), ), rx.button( "call_with_var_f_string", @@ -308,7 +310,7 @@ def CallScript(): "call_with_var_f_string_inline", on_click=rx.call_script( f"{rx.Var('inline_counter')} + {CallScriptState.last_result}", - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ), id="call_with_var_f_string_inline", ), @@ -316,7 +318,7 @@ def CallScript(): "call_with_var_str_cast_inline", on_click=rx.call_script( f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}", - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ), id="call_with_var_str_cast_inline", ), @@ -326,7 +328,7 @@ def CallScript(): rx.Var( f"{rx.Var('inline_counter')} + {CallScriptState.last_result}" ), - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ), id="call_with_var_f_string_wrapped_inline", ), @@ -336,7 +338,7 @@ def CallScript(): rx.Var( f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}" ), - callback=CallScriptState.set_last_result, # type: ignore + callback=CallScriptState.setvar("last_result"), ), id="call_with_var_str_cast_wrapped_inline", ), @@ -483,7 +485,7 @@ def test_call_script_w_var( """ assert_token(driver) last_result = driver.find_element(By.ID, "last_result") - assert last_result.get_attribute("value") == "" + assert last_result.get_attribute("value") == "0" inline_return_button = driver.find_element(By.ID, "inline_return") diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 97582af75..a114d1844 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -28,7 +28,10 @@ from reflex.event import ( from reflex.state import BaseState from reflex.style import Style from reflex.utils import imports -from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch +from reflex.utils.exceptions import ( + EventFnArgMismatchError, + EventHandlerArgTypeMismatchError, +) from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var @@ -910,23 +913,23 @@ def test_invalid_event_handler_args(component2, test_state): test_state: A test state. """ # EventHandler args must match - with pytest.raises(EventFnArgMismatch): + with pytest.raises(EventFnArgMismatchError): component2.create(on_click=test_state.do_something_arg) # Multiple EventHandler args: all must match - with pytest.raises(EventFnArgMismatch): + with pytest.raises(EventFnArgMismatchError): component2.create( on_click=[test_state.do_something_arg, test_state.do_something] ) # # Event Handler types must match - with pytest.raises(EventHandlerArgTypeMismatch): + with pytest.raises(EventHandlerArgTypeMismatchError): component2.create( on_user_visited_count_changed=test_state.do_something_with_bool ) - with pytest.raises(EventHandlerArgTypeMismatch): + with pytest.raises(EventHandlerArgTypeMismatchError): component2.create(on_user_list_changed=test_state.do_something_with_int) - with pytest.raises(EventHandlerArgTypeMismatch): + with pytest.raises(EventHandlerArgTypeMismatchError): component2.create(on_user_list_changed=test_state.do_something_with_list_int) component2.create(on_open=test_state.do_something_with_int) @@ -945,15 +948,15 @@ def test_invalid_event_handler_args(component2, test_state): ) # lambda signature must match event trigger. - with pytest.raises(EventFnArgMismatch): + with pytest.raises(EventFnArgMismatchError): component2.create(on_click=lambda _: test_state.do_something_arg(1)) # lambda returning EventHandler must match spec - with pytest.raises(EventFnArgMismatch): + with pytest.raises(EventFnArgMismatchError): component2.create(on_click=lambda: test_state.do_something_arg) # Mixed EventSpec and EventHandler must match spec. - with pytest.raises(EventFnArgMismatch): + with pytest.raises(EventFnArgMismatchError): component2.create( on_click=lambda: [ test_state.do_something_arg(1),