ignore lambdas when resolving annotations

This commit is contained in:
Lendemor 2025-01-15 14:31:14 +01:00
parent 66d06574a2
commit 62916ac6a1
5 changed files with 95 additions and 64 deletions

View File

@ -39,7 +39,11 @@ from typing_extensions import (
from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import console, format
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch
from reflex.utils.exceptions import (
EventFnArgMismatchError,
EventHandlerArgTypeMismatchError,
MissingAnnotationError,
)
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@ -1218,7 +1222,7 @@ def call_event_handler(
key: The key to pass to the event handler.
Raises:
EventHandlerArgTypeMismatch: If the event handler arguments do not match the event spec.
EventHandlerArgTypeMismatchError: If the event handler arguments do not match the event spec.
TypeError: If the event handler arguments are invalid.
Returns:
@ -1296,7 +1300,7 @@ def call_event_handler(
if compare_result:
continue
else:
raise EventHandlerArgTypeMismatch(
raise EventHandlerArgTypeMismatchError(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
)
@ -1341,19 +1345,23 @@ def unwrap_var_annotation(annotation: GenericType):
return annotation
def resolve_annotation(annotations: dict[str, Any], arg_name: str):
def resolve_annotation(annotations: dict[str, Any], arg_name: str, spec: ArgsSpec):
"""Resolve the annotation for the given argument name.
Args:
annotations: The annotations.
arg_name: The argument name.
spec: The specs which the annotations come from.
Raises:
MissingAnnotationError: If the annotation is missing for non-lambda methods.
Returns:
The resolved annotation.
"""
annotation = annotations.get(arg_name)
if annotation is None:
console.error(f"Invalid annotation '{annotation}' for var '{arg_name}'.")
if annotation is None and not isinstance(spec, types.LambdaType):
raise MissingAnnotationError(var_name=arg_name)
return annotation
@ -1375,7 +1383,13 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
arg_spec(
*[
Var(f"_{l_arg}").to(
unwrap_var_annotation(resolve_annotation(annotations, l_arg))
unwrap_var_annotation(
resolve_annotation(
annotations,
l_arg,
spec=arg_spec,
)
)
)
for l_arg in spec.args
]
@ -1391,7 +1405,7 @@ def check_fn_match_arg_spec(
func_name: str | None = None,
):
"""Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not.
or raises an EventFnArgMismatchError if they do not.
Args:
user_func: The function to be validated.
@ -1401,7 +1415,7 @@ def check_fn_match_arg_spec(
func_name: The name of the function to be validated.
Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
EventFnArgMismatchError: Raised if the number of mandatory arguments do not match
"""
user_args = inspect.getfullargspec(user_func).args
# Drop the first argument if it's a bound method
@ -1417,7 +1431,7 @@ def check_fn_match_arg_spec(
number_of_event_args = len(parsed_event_args)
if number_of_user_args - number_of_user_default_args > number_of_event_args:
raise EventFnArgMismatch(
raise EventFnArgMismatchError(
f"Event {key} only provides {number_of_event_args} arguments, but "
f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
"arguments to be passed to the event handler.\n"

View File

@ -1,7 +1,5 @@
"""Custom Exceptions."""
from typing import NoReturn
class ReflexError(Exception):
"""Base exception for all Reflex exceptions."""
@ -71,6 +69,18 @@ class UntypedComputedVarError(ReflexError, TypeError):
super().__init__(f"Computed var '{var_name}' must have a type annotation.")
class MissingAnnotationError(ReflexError, TypeError):
"""Custom TypeError for missing annotations."""
def __init__(self, var_name):
"""Initialize the MissingAnnotationError.
Args:
var_name: The name of the var.
"""
super().__init__(f"Var '{var_name}' must have a type annotation.")
class UploadValueError(ReflexError, ValueError):
"""Custom ValueError for upload related errors."""
@ -107,11 +117,11 @@ class MatchTypeError(ReflexError, TypeError):
"""Raised when the return types of match cases are different."""
class EventHandlerArgTypeMismatch(ReflexError, TypeError):
class EventHandlerArgTypeMismatchError(ReflexError, TypeError):
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
class EventFnArgMismatch(ReflexError, TypeError):
class EventFnArgMismatchError(ReflexError, TypeError):
"""Raised when the number of args required by an event handler is more than provided by the event trigger."""
@ -178,23 +188,21 @@ class StateSerializationError(ReflexError):
class SystemPackageMissingError(ReflexError):
"""Raised when a system package is missing."""
def __init__(self, package: str):
"""Initialize the SystemPackageMissingError.
def raise_system_package_missing_error(package: str) -> NoReturn:
"""Raise a SystemPackageMissingError.
Args:
package: The missing package.
"""
from reflex.constants import IS_MACOS
Args:
package: The name of the missing system package.
Raises:
SystemPackageMissingError: The raised exception.
"""
from reflex.constants import IS_MACOS
raise SystemPackageMissingError(
f"System package '{package}' is missing."
" Please install it through your system package manager."
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
)
extra = (
f" You can do so by running 'brew install {package}'." if IS_MACOS else ""
)
super().__init__(
f"System package '{package}' is missing."
f" Please install it through your system package manager.{extra}"
)
class InvalidLockWarningThresholdError(ReflexError):

View File

@ -37,7 +37,7 @@ from reflex.config import Config, environment, get_config
from reflex.utils import console, net, path_ops, processes, redir
from reflex.utils.exceptions import (
GeneratedCodeHasNoFunctionDefs,
raise_system_package_missing_error,
SystemPackageMissingError,
)
from reflex.utils.format import format_library_name
from reflex.utils.registry import _get_npm_registry
@ -817,7 +817,11 @@ def install_node():
def install_bun():
"""Install bun onto the user's system."""
"""Install bun onto the user's system.
Raises:
SystemPackageMissingError: If "unzip" is missing.
"""
win_supported = is_windows_bun_supported()
one_drive_in_path = windows_check_onedrive_in_path()
if constants.IS_WINDOWS and (not win_supported or one_drive_in_path):
@ -856,7 +860,7 @@ def install_bun():
else:
unzip_path = path_ops.which("unzip")
if unzip_path is None:
raise_system_package_missing_error("unzip")
raise SystemPackageMissingError("unzip")
# Run the bun install script.
download_and_run(

View File

@ -16,7 +16,7 @@ from .utils import SessionStorage
def CallScript():
"""A test app for browser javascript integration."""
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import reflex as rx
@ -43,15 +43,17 @@ def CallScript():
external_scripts = inline_scripts.replace("inline", "external")
class CallScriptState(rx.State):
results: List[Optional[Union[str, Dict, List]]] = []
inline_counter: int = 0
external_counter: int = 0
results: rx.Field[list[Optional[Union[str, dict, list]]]] = rx.field([])
inline_counter: rx.Field[int] = rx.field(0)
external_counter: rx.Field[int] = rx.field(0)
value: str = "Initial"
last_result: str = ""
last_result: int = 0
@rx.event
def call_script_callback(self, result):
self.results.append(result)
@rx.event
def call_script_callback_other_arg(self, result, other_arg):
self.results.append([other_arg, result])
@ -91,7 +93,7 @@ def CallScript():
def call_script_inline_return_lambda(self):
return rx.call_script(
"inline2()",
callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore
callback=lambda result: CallScriptState.call_script_callback_other_arg(
result, "lambda"
),
)
@ -100,7 +102,7 @@ def CallScript():
def get_inline_counter(self):
return rx.call_script(
"inline_counter",
callback=CallScriptState.set_inline_counter, # type: ignore
callback=CallScriptState.setvar("inline_counter"),
)
@rx.event
@ -139,7 +141,7 @@ def CallScript():
def call_script_external_return_lambda(self):
return rx.call_script(
"external2()",
callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore
callback=lambda result: CallScriptState.call_script_callback_other_arg(
result, "lambda"
),
)
@ -148,28 +150,28 @@ def CallScript():
def get_external_counter(self):
return rx.call_script(
"external_counter",
callback=CallScriptState.set_external_counter, # type: ignore
callback=CallScriptState.setvar("external_counter"),
)
@rx.event
def call_with_var_f_string(self):
return rx.call_script(
f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}",
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
)
@rx.event
def call_with_var_str_cast(self):
return rx.call_script(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}",
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
)
@rx.event
def call_with_var_f_string_wrapped(self):
return rx.call_script(
rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"),
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
)
@rx.event
@ -178,7 +180,7 @@ def CallScript():
rx.Var(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}"
),
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
)
@rx.event
@ -193,17 +195,17 @@ def CallScript():
def index():
return rx.vstack(
rx.input(
value=CallScriptState.inline_counter.to(str), # type: ignore
value=CallScriptState.inline_counter.to(str),
id="inline_counter",
read_only=True,
),
rx.input(
value=CallScriptState.external_counter.to(str), # type: ignore
value=CallScriptState.external_counter.to(str),
id="external_counter",
read_only=True,
),
rx.text_area(
value=CallScriptState.results.to_string(), # type: ignore
value=CallScriptState.results.to_string(),
id="results",
read_only=True,
),
@ -273,7 +275,7 @@ def CallScript():
CallScriptState.value,
on_click=rx.call_script(
"'updated'",
callback=CallScriptState.set_value, # type: ignore
callback=CallScriptState.setvar("value"),
),
id="update_value",
),
@ -282,7 +284,7 @@ def CallScript():
value=CallScriptState.last_result,
id="last_result",
read_only=True,
on_click=CallScriptState.set_last_result(""), # type: ignore
on_click=CallScriptState.setvar("last_result", 0),
),
rx.button(
"call_with_var_f_string",
@ -308,7 +310,7 @@ def CallScript():
"call_with_var_f_string_inline",
on_click=rx.call_script(
f"{rx.Var('inline_counter')} + {CallScriptState.last_result}",
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
),
id="call_with_var_f_string_inline",
),
@ -316,7 +318,7 @@ def CallScript():
"call_with_var_str_cast_inline",
on_click=rx.call_script(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}",
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
),
id="call_with_var_str_cast_inline",
),
@ -326,7 +328,7 @@ def CallScript():
rx.Var(
f"{rx.Var('inline_counter')} + {CallScriptState.last_result}"
),
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
),
id="call_with_var_f_string_wrapped_inline",
),
@ -336,7 +338,7 @@ def CallScript():
rx.Var(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}"
),
callback=CallScriptState.set_last_result, # type: ignore
callback=CallScriptState.setvar("last_result"),
),
id="call_with_var_str_cast_wrapped_inline",
),
@ -483,7 +485,7 @@ def test_call_script_w_var(
"""
assert_token(driver)
last_result = driver.find_element(By.ID, "last_result")
assert last_result.get_attribute("value") == ""
assert last_result.get_attribute("value") == "0"
inline_return_button = driver.find_element(By.ID, "inline_return")

View File

@ -28,7 +28,10 @@ from reflex.event import (
from reflex.state import BaseState
from reflex.style import Style
from reflex.utils import imports
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch
from reflex.utils.exceptions import (
EventFnArgMismatchError,
EventHandlerArgTypeMismatchError,
)
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
@ -910,23 +913,23 @@ def test_invalid_event_handler_args(component2, test_state):
test_state: A test state.
"""
# EventHandler args must match
with pytest.raises(EventFnArgMismatch):
with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=test_state.do_something_arg)
# Multiple EventHandler args: all must match
with pytest.raises(EventFnArgMismatch):
with pytest.raises(EventFnArgMismatchError):
component2.create(
on_click=[test_state.do_something_arg, test_state.do_something]
)
# # Event Handler types must match
with pytest.raises(EventHandlerArgTypeMismatch):
with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(
on_user_visited_count_changed=test_state.do_something_with_bool
)
with pytest.raises(EventHandlerArgTypeMismatch):
with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(on_user_list_changed=test_state.do_something_with_int)
with pytest.raises(EventHandlerArgTypeMismatch):
with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(on_user_list_changed=test_state.do_something_with_list_int)
component2.create(on_open=test_state.do_something_with_int)
@ -945,15 +948,15 @@ def test_invalid_event_handler_args(component2, test_state):
)
# lambda signature must match event trigger.
with pytest.raises(EventFnArgMismatch):
with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=lambda _: test_state.do_something_arg(1))
# lambda returning EventHandler must match spec
with pytest.raises(EventFnArgMismatch):
with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=lambda: test_state.do_something_arg)
# Mixed EventSpec and EventHandler must match spec.
with pytest.raises(EventFnArgMismatch):
with pytest.raises(EventFnArgMismatchError):
component2.create(
on_click=lambda: [
test_state.do_something_arg(1),