ignore lambdas when resolving annotations

This commit is contained in:
Lendemor 2025-01-15 14:31:14 +01:00
parent 66d06574a2
commit 62916ac6a1
5 changed files with 95 additions and 64 deletions

View File

@ -39,7 +39,11 @@ from typing_extensions import (
from reflex import constants from reflex import constants
from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.constants.state import FRONTEND_EVENT_STATE
from reflex.utils import console, format from reflex.utils import console, format
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch from reflex.utils.exceptions import (
EventFnArgMismatchError,
EventHandlerArgTypeMismatchError,
MissingAnnotationError,
)
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
from reflex.vars import VarData from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
@ -1218,7 +1222,7 @@ def call_event_handler(
key: The key to pass to the event handler. key: The key to pass to the event handler.
Raises: Raises:
EventHandlerArgTypeMismatch: If the event handler arguments do not match the event spec. EventHandlerArgTypeMismatchError: If the event handler arguments do not match the event spec.
TypeError: If the event handler arguments are invalid. TypeError: If the event handler arguments are invalid.
Returns: Returns:
@ -1296,7 +1300,7 @@ def call_event_handler(
if compare_result: if compare_result:
continue continue
else: else:
raise EventHandlerArgTypeMismatch( raise EventHandlerArgTypeMismatchError(
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead." f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
) )
@ -1341,19 +1345,23 @@ def unwrap_var_annotation(annotation: GenericType):
return annotation return annotation
def resolve_annotation(annotations: dict[str, Any], arg_name: str): def resolve_annotation(annotations: dict[str, Any], arg_name: str, spec: ArgsSpec):
"""Resolve the annotation for the given argument name. """Resolve the annotation for the given argument name.
Args: Args:
annotations: The annotations. annotations: The annotations.
arg_name: The argument name. arg_name: The argument name.
spec: The specs which the annotations come from.
Raises:
MissingAnnotationError: If the annotation is missing for non-lambda methods.
Returns: Returns:
The resolved annotation. The resolved annotation.
""" """
annotation = annotations.get(arg_name) annotation = annotations.get(arg_name)
if annotation is None: if annotation is None and not isinstance(spec, types.LambdaType):
console.error(f"Invalid annotation '{annotation}' for var '{arg_name}'.") raise MissingAnnotationError(var_name=arg_name)
return annotation return annotation
@ -1375,7 +1383,13 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
arg_spec( arg_spec(
*[ *[
Var(f"_{l_arg}").to( Var(f"_{l_arg}").to(
unwrap_var_annotation(resolve_annotation(annotations, l_arg)) unwrap_var_annotation(
resolve_annotation(
annotations,
l_arg,
spec=arg_spec,
)
)
) )
for l_arg in spec.args for l_arg in spec.args
] ]
@ -1391,7 +1405,7 @@ def check_fn_match_arg_spec(
func_name: str | None = None, func_name: str | None = None,
): ):
"""Ensures that the function signature matches the passed argument specification """Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not. or raises an EventFnArgMismatchError if they do not.
Args: Args:
user_func: The function to be validated. user_func: The function to be validated.
@ -1401,7 +1415,7 @@ def check_fn_match_arg_spec(
func_name: The name of the function to be validated. func_name: The name of the function to be validated.
Raises: Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match EventFnArgMismatchError: Raised if the number of mandatory arguments do not match
""" """
user_args = inspect.getfullargspec(user_func).args user_args = inspect.getfullargspec(user_func).args
# Drop the first argument if it's a bound method # Drop the first argument if it's a bound method
@ -1417,7 +1431,7 @@ def check_fn_match_arg_spec(
number_of_event_args = len(parsed_event_args) number_of_event_args = len(parsed_event_args)
if number_of_user_args - number_of_user_default_args > number_of_event_args: if number_of_user_args - number_of_user_default_args > number_of_event_args:
raise EventFnArgMismatch( raise EventFnArgMismatchError(
f"Event {key} only provides {number_of_event_args} arguments, but " f"Event {key} only provides {number_of_event_args} arguments, but "
f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} " f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
"arguments to be passed to the event handler.\n" "arguments to be passed to the event handler.\n"

View File

@ -1,7 +1,5 @@
"""Custom Exceptions.""" """Custom Exceptions."""
from typing import NoReturn
class ReflexError(Exception): class ReflexError(Exception):
"""Base exception for all Reflex exceptions.""" """Base exception for all Reflex exceptions."""
@ -71,6 +69,18 @@ class UntypedComputedVarError(ReflexError, TypeError):
super().__init__(f"Computed var '{var_name}' must have a type annotation.") super().__init__(f"Computed var '{var_name}' must have a type annotation.")
class MissingAnnotationError(ReflexError, TypeError):
"""Custom TypeError for missing annotations."""
def __init__(self, var_name):
"""Initialize the MissingAnnotationError.
Args:
var_name: The name of the var.
"""
super().__init__(f"Var '{var_name}' must have a type annotation.")
class UploadValueError(ReflexError, ValueError): class UploadValueError(ReflexError, ValueError):
"""Custom ValueError for upload related errors.""" """Custom ValueError for upload related errors."""
@ -107,11 +117,11 @@ class MatchTypeError(ReflexError, TypeError):
"""Raised when the return types of match cases are different.""" """Raised when the return types of match cases are different."""
class EventHandlerArgTypeMismatch(ReflexError, TypeError): class EventHandlerArgTypeMismatchError(ReflexError, TypeError):
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger.""" """Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
class EventFnArgMismatch(ReflexError, TypeError): class EventFnArgMismatchError(ReflexError, TypeError):
"""Raised when the number of args required by an event handler is more than provided by the event trigger.""" """Raised when the number of args required by an event handler is more than provided by the event trigger."""
@ -178,23 +188,21 @@ class StateSerializationError(ReflexError):
class SystemPackageMissingError(ReflexError): class SystemPackageMissingError(ReflexError):
"""Raised when a system package is missing.""" """Raised when a system package is missing."""
def __init__(self, package: str):
"""Initialize the SystemPackageMissingError.
def raise_system_package_missing_error(package: str) -> NoReturn: Args:
"""Raise a SystemPackageMissingError. package: The missing package.
"""
from reflex.constants import IS_MACOS
Args: extra = (
package: The name of the missing system package. f" You can do so by running 'brew install {package}'." if IS_MACOS else ""
)
Raises: super().__init__(
SystemPackageMissingError: The raised exception. f"System package '{package}' is missing."
""" f" Please install it through your system package manager.{extra}"
from reflex.constants import IS_MACOS )
raise SystemPackageMissingError(
f"System package '{package}' is missing."
" Please install it through your system package manager."
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
)
class InvalidLockWarningThresholdError(ReflexError): class InvalidLockWarningThresholdError(ReflexError):

View File

@ -37,7 +37,7 @@ from reflex.config import Config, environment, get_config
from reflex.utils import console, net, path_ops, processes, redir from reflex.utils import console, net, path_ops, processes, redir
from reflex.utils.exceptions import ( from reflex.utils.exceptions import (
GeneratedCodeHasNoFunctionDefs, GeneratedCodeHasNoFunctionDefs,
raise_system_package_missing_error, SystemPackageMissingError,
) )
from reflex.utils.format import format_library_name from reflex.utils.format import format_library_name
from reflex.utils.registry import _get_npm_registry from reflex.utils.registry import _get_npm_registry
@ -817,7 +817,11 @@ def install_node():
def install_bun(): def install_bun():
"""Install bun onto the user's system.""" """Install bun onto the user's system.
Raises:
SystemPackageMissingError: If "unzip" is missing.
"""
win_supported = is_windows_bun_supported() win_supported = is_windows_bun_supported()
one_drive_in_path = windows_check_onedrive_in_path() one_drive_in_path = windows_check_onedrive_in_path()
if constants.IS_WINDOWS and (not win_supported or one_drive_in_path): if constants.IS_WINDOWS and (not win_supported or one_drive_in_path):
@ -856,7 +860,7 @@ def install_bun():
else: else:
unzip_path = path_ops.which("unzip") unzip_path = path_ops.which("unzip")
if unzip_path is None: if unzip_path is None:
raise_system_package_missing_error("unzip") raise SystemPackageMissingError("unzip")
# Run the bun install script. # Run the bun install script.
download_and_run( download_and_run(

View File

@ -16,7 +16,7 @@ from .utils import SessionStorage
def CallScript(): def CallScript():
"""A test app for browser javascript integration.""" """A test app for browser javascript integration."""
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Optional, Union
import reflex as rx import reflex as rx
@ -43,15 +43,17 @@ def CallScript():
external_scripts = inline_scripts.replace("inline", "external") external_scripts = inline_scripts.replace("inline", "external")
class CallScriptState(rx.State): class CallScriptState(rx.State):
results: List[Optional[Union[str, Dict, List]]] = [] results: rx.Field[list[Optional[Union[str, dict, list]]]] = rx.field([])
inline_counter: int = 0 inline_counter: rx.Field[int] = rx.field(0)
external_counter: int = 0 external_counter: rx.Field[int] = rx.field(0)
value: str = "Initial" value: str = "Initial"
last_result: str = "" last_result: int = 0
@rx.event
def call_script_callback(self, result): def call_script_callback(self, result):
self.results.append(result) self.results.append(result)
@rx.event
def call_script_callback_other_arg(self, result, other_arg): def call_script_callback_other_arg(self, result, other_arg):
self.results.append([other_arg, result]) self.results.append([other_arg, result])
@ -91,7 +93,7 @@ def CallScript():
def call_script_inline_return_lambda(self): def call_script_inline_return_lambda(self):
return rx.call_script( return rx.call_script(
"inline2()", "inline2()",
callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore callback=lambda result: CallScriptState.call_script_callback_other_arg(
result, "lambda" result, "lambda"
), ),
) )
@ -100,7 +102,7 @@ def CallScript():
def get_inline_counter(self): def get_inline_counter(self):
return rx.call_script( return rx.call_script(
"inline_counter", "inline_counter",
callback=CallScriptState.set_inline_counter, # type: ignore callback=CallScriptState.setvar("inline_counter"),
) )
@rx.event @rx.event
@ -139,7 +141,7 @@ def CallScript():
def call_script_external_return_lambda(self): def call_script_external_return_lambda(self):
return rx.call_script( return rx.call_script(
"external2()", "external2()",
callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore callback=lambda result: CallScriptState.call_script_callback_other_arg(
result, "lambda" result, "lambda"
), ),
) )
@ -148,28 +150,28 @@ def CallScript():
def get_external_counter(self): def get_external_counter(self):
return rx.call_script( return rx.call_script(
"external_counter", "external_counter",
callback=CallScriptState.set_external_counter, # type: ignore callback=CallScriptState.setvar("external_counter"),
) )
@rx.event @rx.event
def call_with_var_f_string(self): def call_with_var_f_string(self):
return rx.call_script( return rx.call_script(
f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}", f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}",
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
) )
@rx.event @rx.event
def call_with_var_str_cast(self): def call_with_var_str_cast(self):
return rx.call_script( return rx.call_script(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}", f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}",
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
) )
@rx.event @rx.event
def call_with_var_f_string_wrapped(self): def call_with_var_f_string_wrapped(self):
return rx.call_script( return rx.call_script(
rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"), rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"),
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
) )
@rx.event @rx.event
@ -178,7 +180,7 @@ def CallScript():
rx.Var( rx.Var(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}" f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}"
), ),
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
) )
@rx.event @rx.event
@ -193,17 +195,17 @@ def CallScript():
def index(): def index():
return rx.vstack( return rx.vstack(
rx.input( rx.input(
value=CallScriptState.inline_counter.to(str), # type: ignore value=CallScriptState.inline_counter.to(str),
id="inline_counter", id="inline_counter",
read_only=True, read_only=True,
), ),
rx.input( rx.input(
value=CallScriptState.external_counter.to(str), # type: ignore value=CallScriptState.external_counter.to(str),
id="external_counter", id="external_counter",
read_only=True, read_only=True,
), ),
rx.text_area( rx.text_area(
value=CallScriptState.results.to_string(), # type: ignore value=CallScriptState.results.to_string(),
id="results", id="results",
read_only=True, read_only=True,
), ),
@ -273,7 +275,7 @@ def CallScript():
CallScriptState.value, CallScriptState.value,
on_click=rx.call_script( on_click=rx.call_script(
"'updated'", "'updated'",
callback=CallScriptState.set_value, # type: ignore callback=CallScriptState.setvar("value"),
), ),
id="update_value", id="update_value",
), ),
@ -282,7 +284,7 @@ def CallScript():
value=CallScriptState.last_result, value=CallScriptState.last_result,
id="last_result", id="last_result",
read_only=True, read_only=True,
on_click=CallScriptState.set_last_result(""), # type: ignore on_click=CallScriptState.setvar("last_result", 0),
), ),
rx.button( rx.button(
"call_with_var_f_string", "call_with_var_f_string",
@ -308,7 +310,7 @@ def CallScript():
"call_with_var_f_string_inline", "call_with_var_f_string_inline",
on_click=rx.call_script( on_click=rx.call_script(
f"{rx.Var('inline_counter')} + {CallScriptState.last_result}", f"{rx.Var('inline_counter')} + {CallScriptState.last_result}",
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
), ),
id="call_with_var_f_string_inline", id="call_with_var_f_string_inline",
), ),
@ -316,7 +318,7 @@ def CallScript():
"call_with_var_str_cast_inline", "call_with_var_str_cast_inline",
on_click=rx.call_script( on_click=rx.call_script(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}", f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}",
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
), ),
id="call_with_var_str_cast_inline", id="call_with_var_str_cast_inline",
), ),
@ -326,7 +328,7 @@ def CallScript():
rx.Var( rx.Var(
f"{rx.Var('inline_counter')} + {CallScriptState.last_result}" f"{rx.Var('inline_counter')} + {CallScriptState.last_result}"
), ),
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
), ),
id="call_with_var_f_string_wrapped_inline", id="call_with_var_f_string_wrapped_inline",
), ),
@ -336,7 +338,7 @@ def CallScript():
rx.Var( rx.Var(
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}" f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}"
), ),
callback=CallScriptState.set_last_result, # type: ignore callback=CallScriptState.setvar("last_result"),
), ),
id="call_with_var_str_cast_wrapped_inline", id="call_with_var_str_cast_wrapped_inline",
), ),
@ -483,7 +485,7 @@ def test_call_script_w_var(
""" """
assert_token(driver) assert_token(driver)
last_result = driver.find_element(By.ID, "last_result") last_result = driver.find_element(By.ID, "last_result")
assert last_result.get_attribute("value") == "" assert last_result.get_attribute("value") == "0"
inline_return_button = driver.find_element(By.ID, "inline_return") inline_return_button = driver.find_element(By.ID, "inline_return")

View File

@ -28,7 +28,10 @@ from reflex.event import (
from reflex.state import BaseState from reflex.state import BaseState
from reflex.style import Style from reflex.style import Style
from reflex.utils import imports from reflex.utils import imports
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch from reflex.utils.exceptions import (
EventFnArgMismatchError,
EventHandlerArgTypeMismatchError,
)
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
from reflex.vars import VarData from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
@ -910,23 +913,23 @@ def test_invalid_event_handler_args(component2, test_state):
test_state: A test state. test_state: A test state.
""" """
# EventHandler args must match # EventHandler args must match
with pytest.raises(EventFnArgMismatch): with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=test_state.do_something_arg) component2.create(on_click=test_state.do_something_arg)
# Multiple EventHandler args: all must match # Multiple EventHandler args: all must match
with pytest.raises(EventFnArgMismatch): with pytest.raises(EventFnArgMismatchError):
component2.create( component2.create(
on_click=[test_state.do_something_arg, test_state.do_something] on_click=[test_state.do_something_arg, test_state.do_something]
) )
# # Event Handler types must match # # Event Handler types must match
with pytest.raises(EventHandlerArgTypeMismatch): with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create( component2.create(
on_user_visited_count_changed=test_state.do_something_with_bool on_user_visited_count_changed=test_state.do_something_with_bool
) )
with pytest.raises(EventHandlerArgTypeMismatch): with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(on_user_list_changed=test_state.do_something_with_int) component2.create(on_user_list_changed=test_state.do_something_with_int)
with pytest.raises(EventHandlerArgTypeMismatch): with pytest.raises(EventHandlerArgTypeMismatchError):
component2.create(on_user_list_changed=test_state.do_something_with_list_int) component2.create(on_user_list_changed=test_state.do_something_with_list_int)
component2.create(on_open=test_state.do_something_with_int) component2.create(on_open=test_state.do_something_with_int)
@ -945,15 +948,15 @@ def test_invalid_event_handler_args(component2, test_state):
) )
# lambda signature must match event trigger. # lambda signature must match event trigger.
with pytest.raises(EventFnArgMismatch): with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=lambda _: test_state.do_something_arg(1)) component2.create(on_click=lambda _: test_state.do_something_arg(1))
# lambda returning EventHandler must match spec # lambda returning EventHandler must match spec
with pytest.raises(EventFnArgMismatch): with pytest.raises(EventFnArgMismatchError):
component2.create(on_click=lambda: test_state.do_something_arg) component2.create(on_click=lambda: test_state.do_something_arg)
# Mixed EventSpec and EventHandler must match spec. # Mixed EventSpec and EventHandler must match spec.
with pytest.raises(EventFnArgMismatch): with pytest.raises(EventFnArgMismatchError):
component2.create( component2.create(
on_click=lambda: [ on_click=lambda: [
test_state.do_something_arg(1), test_state.do_something_arg(1),