ignore lambdas when resolving annotations
This commit is contained in:
parent
66d06574a2
commit
62916ac6a1
@ -39,7 +39,11 @@ from typing_extensions import (
|
||||
from reflex import constants
|
||||
from reflex.constants.state import FRONTEND_EVENT_STATE
|
||||
from reflex.utils import console, format
|
||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch
|
||||
from reflex.utils.exceptions import (
|
||||
EventFnArgMismatchError,
|
||||
EventHandlerArgTypeMismatchError,
|
||||
MissingAnnotationError,
|
||||
)
|
||||
from reflex.utils.types import ArgsSpec, GenericType, typehint_issubclass
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
@ -1218,7 +1222,7 @@ def call_event_handler(
|
||||
key: The key to pass to the event handler.
|
||||
|
||||
Raises:
|
||||
EventHandlerArgTypeMismatch: If the event handler arguments do not match the event spec.
|
||||
EventHandlerArgTypeMismatchError: If the event handler arguments do not match the event spec.
|
||||
TypeError: If the event handler arguments are invalid.
|
||||
|
||||
Returns:
|
||||
@ -1296,7 +1300,7 @@ def call_event_handler(
|
||||
if compare_result:
|
||||
continue
|
||||
else:
|
||||
raise EventHandlerArgTypeMismatch(
|
||||
raise EventHandlerArgTypeMismatchError(
|
||||
f"Event handler {key} expects {args_types_without_vars[i]} for argument {arg} but got {type_hints_of_provided_callback[arg]} as annotated in {event_callback.fn.__qualname__} instead."
|
||||
)
|
||||
|
||||
@ -1341,19 +1345,23 @@ def unwrap_var_annotation(annotation: GenericType):
|
||||
return annotation
|
||||
|
||||
|
||||
def resolve_annotation(annotations: dict[str, Any], arg_name: str):
|
||||
def resolve_annotation(annotations: dict[str, Any], arg_name: str, spec: ArgsSpec):
|
||||
"""Resolve the annotation for the given argument name.
|
||||
|
||||
Args:
|
||||
annotations: The annotations.
|
||||
arg_name: The argument name.
|
||||
spec: The specs which the annotations come from.
|
||||
|
||||
Raises:
|
||||
MissingAnnotationError: If the annotation is missing for non-lambda methods.
|
||||
|
||||
Returns:
|
||||
The resolved annotation.
|
||||
"""
|
||||
annotation = annotations.get(arg_name)
|
||||
if annotation is None:
|
||||
console.error(f"Invalid annotation '{annotation}' for var '{arg_name}'.")
|
||||
if annotation is None and not isinstance(spec, types.LambdaType):
|
||||
raise MissingAnnotationError(var_name=arg_name)
|
||||
return annotation
|
||||
|
||||
|
||||
@ -1375,7 +1383,13 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]):
|
||||
arg_spec(
|
||||
*[
|
||||
Var(f"_{l_arg}").to(
|
||||
unwrap_var_annotation(resolve_annotation(annotations, l_arg))
|
||||
unwrap_var_annotation(
|
||||
resolve_annotation(
|
||||
annotations,
|
||||
l_arg,
|
||||
spec=arg_spec,
|
||||
)
|
||||
)
|
||||
)
|
||||
for l_arg in spec.args
|
||||
]
|
||||
@ -1391,7 +1405,7 @@ def check_fn_match_arg_spec(
|
||||
func_name: str | None = None,
|
||||
):
|
||||
"""Ensures that the function signature matches the passed argument specification
|
||||
or raises an EventFnArgMismatch if they do not.
|
||||
or raises an EventFnArgMismatchError if they do not.
|
||||
|
||||
Args:
|
||||
user_func: The function to be validated.
|
||||
@ -1401,7 +1415,7 @@ def check_fn_match_arg_spec(
|
||||
func_name: The name of the function to be validated.
|
||||
|
||||
Raises:
|
||||
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
|
||||
EventFnArgMismatchError: Raised if the number of mandatory arguments do not match
|
||||
"""
|
||||
user_args = inspect.getfullargspec(user_func).args
|
||||
# Drop the first argument if it's a bound method
|
||||
@ -1417,7 +1431,7 @@ def check_fn_match_arg_spec(
|
||||
number_of_event_args = len(parsed_event_args)
|
||||
|
||||
if number_of_user_args - number_of_user_default_args > number_of_event_args:
|
||||
raise EventFnArgMismatch(
|
||||
raise EventFnArgMismatchError(
|
||||
f"Event {key} only provides {number_of_event_args} arguments, but "
|
||||
f"{func_name or user_func} requires at least {number_of_user_args - number_of_user_default_args} "
|
||||
"arguments to be passed to the event handler.\n"
|
||||
|
@ -1,7 +1,5 @@
|
||||
"""Custom Exceptions."""
|
||||
|
||||
from typing import NoReturn
|
||||
|
||||
|
||||
class ReflexError(Exception):
|
||||
"""Base exception for all Reflex exceptions."""
|
||||
@ -71,6 +69,18 @@ class UntypedComputedVarError(ReflexError, TypeError):
|
||||
super().__init__(f"Computed var '{var_name}' must have a type annotation.")
|
||||
|
||||
|
||||
class MissingAnnotationError(ReflexError, TypeError):
|
||||
"""Custom TypeError for missing annotations."""
|
||||
|
||||
def __init__(self, var_name):
|
||||
"""Initialize the MissingAnnotationError.
|
||||
|
||||
Args:
|
||||
var_name: The name of the var.
|
||||
"""
|
||||
super().__init__(f"Var '{var_name}' must have a type annotation.")
|
||||
|
||||
|
||||
class UploadValueError(ReflexError, ValueError):
|
||||
"""Custom ValueError for upload related errors."""
|
||||
|
||||
@ -107,11 +117,11 @@ class MatchTypeError(ReflexError, TypeError):
|
||||
"""Raised when the return types of match cases are different."""
|
||||
|
||||
|
||||
class EventHandlerArgTypeMismatch(ReflexError, TypeError):
|
||||
class EventHandlerArgTypeMismatchError(ReflexError, TypeError):
|
||||
"""Raised when the annotations of args accepted by an EventHandler differs from the spec of the event trigger."""
|
||||
|
||||
|
||||
class EventFnArgMismatch(ReflexError, TypeError):
|
||||
class EventFnArgMismatchError(ReflexError, TypeError):
|
||||
"""Raised when the number of args required by an event handler is more than provided by the event trigger."""
|
||||
|
||||
|
||||
@ -178,23 +188,21 @@ class StateSerializationError(ReflexError):
|
||||
class SystemPackageMissingError(ReflexError):
|
||||
"""Raised when a system package is missing."""
|
||||
|
||||
def __init__(self, package: str):
|
||||
"""Initialize the SystemPackageMissingError.
|
||||
|
||||
def raise_system_package_missing_error(package: str) -> NoReturn:
|
||||
"""Raise a SystemPackageMissingError.
|
||||
Args:
|
||||
package: The missing package.
|
||||
"""
|
||||
from reflex.constants import IS_MACOS
|
||||
|
||||
Args:
|
||||
package: The name of the missing system package.
|
||||
|
||||
Raises:
|
||||
SystemPackageMissingError: The raised exception.
|
||||
"""
|
||||
from reflex.constants import IS_MACOS
|
||||
|
||||
raise SystemPackageMissingError(
|
||||
f"System package '{package}' is missing."
|
||||
" Please install it through your system package manager."
|
||||
+ (f" You can do so by running 'brew install {package}'." if IS_MACOS else "")
|
||||
)
|
||||
extra = (
|
||||
f" You can do so by running 'brew install {package}'." if IS_MACOS else ""
|
||||
)
|
||||
super().__init__(
|
||||
f"System package '{package}' is missing."
|
||||
f" Please install it through your system package manager.{extra}"
|
||||
)
|
||||
|
||||
|
||||
class InvalidLockWarningThresholdError(ReflexError):
|
||||
|
@ -37,7 +37,7 @@ from reflex.config import Config, environment, get_config
|
||||
from reflex.utils import console, net, path_ops, processes, redir
|
||||
from reflex.utils.exceptions import (
|
||||
GeneratedCodeHasNoFunctionDefs,
|
||||
raise_system_package_missing_error,
|
||||
SystemPackageMissingError,
|
||||
)
|
||||
from reflex.utils.format import format_library_name
|
||||
from reflex.utils.registry import _get_npm_registry
|
||||
@ -817,7 +817,11 @@ def install_node():
|
||||
|
||||
|
||||
def install_bun():
|
||||
"""Install bun onto the user's system."""
|
||||
"""Install bun onto the user's system.
|
||||
|
||||
Raises:
|
||||
SystemPackageMissingError: If "unzip" is missing.
|
||||
"""
|
||||
win_supported = is_windows_bun_supported()
|
||||
one_drive_in_path = windows_check_onedrive_in_path()
|
||||
if constants.IS_WINDOWS and (not win_supported or one_drive_in_path):
|
||||
@ -856,7 +860,7 @@ def install_bun():
|
||||
else:
|
||||
unzip_path = path_ops.which("unzip")
|
||||
if unzip_path is None:
|
||||
raise_system_package_missing_error("unzip")
|
||||
raise SystemPackageMissingError("unzip")
|
||||
|
||||
# Run the bun install script.
|
||||
download_and_run(
|
||||
|
@ -16,7 +16,7 @@ from .utils import SessionStorage
|
||||
def CallScript():
|
||||
"""A test app for browser javascript integration."""
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import reflex as rx
|
||||
|
||||
@ -43,15 +43,17 @@ def CallScript():
|
||||
external_scripts = inline_scripts.replace("inline", "external")
|
||||
|
||||
class CallScriptState(rx.State):
|
||||
results: List[Optional[Union[str, Dict, List]]] = []
|
||||
inline_counter: int = 0
|
||||
external_counter: int = 0
|
||||
results: rx.Field[list[Optional[Union[str, dict, list]]]] = rx.field([])
|
||||
inline_counter: rx.Field[int] = rx.field(0)
|
||||
external_counter: rx.Field[int] = rx.field(0)
|
||||
value: str = "Initial"
|
||||
last_result: str = ""
|
||||
last_result: int = 0
|
||||
|
||||
@rx.event
|
||||
def call_script_callback(self, result):
|
||||
self.results.append(result)
|
||||
|
||||
@rx.event
|
||||
def call_script_callback_other_arg(self, result, other_arg):
|
||||
self.results.append([other_arg, result])
|
||||
|
||||
@ -91,7 +93,7 @@ def CallScript():
|
||||
def call_script_inline_return_lambda(self):
|
||||
return rx.call_script(
|
||||
"inline2()",
|
||||
callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore
|
||||
callback=lambda result: CallScriptState.call_script_callback_other_arg(
|
||||
result, "lambda"
|
||||
),
|
||||
)
|
||||
@ -100,7 +102,7 @@ def CallScript():
|
||||
def get_inline_counter(self):
|
||||
return rx.call_script(
|
||||
"inline_counter",
|
||||
callback=CallScriptState.set_inline_counter, # type: ignore
|
||||
callback=CallScriptState.setvar("inline_counter"),
|
||||
)
|
||||
|
||||
@rx.event
|
||||
@ -139,7 +141,7 @@ def CallScript():
|
||||
def call_script_external_return_lambda(self):
|
||||
return rx.call_script(
|
||||
"external2()",
|
||||
callback=lambda result: CallScriptState.call_script_callback_other_arg( # type: ignore
|
||||
callback=lambda result: CallScriptState.call_script_callback_other_arg(
|
||||
result, "lambda"
|
||||
),
|
||||
)
|
||||
@ -148,28 +150,28 @@ def CallScript():
|
||||
def get_external_counter(self):
|
||||
return rx.call_script(
|
||||
"external_counter",
|
||||
callback=CallScriptState.set_external_counter, # type: ignore
|
||||
callback=CallScriptState.setvar("external_counter"),
|
||||
)
|
||||
|
||||
@rx.event
|
||||
def call_with_var_f_string(self):
|
||||
return rx.call_script(
|
||||
f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}",
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
)
|
||||
|
||||
@rx.event
|
||||
def call_with_var_str_cast(self):
|
||||
return rx.call_script(
|
||||
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}",
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
)
|
||||
|
||||
@rx.event
|
||||
def call_with_var_f_string_wrapped(self):
|
||||
return rx.call_script(
|
||||
rx.Var(f"{rx.Var('inline_counter')} + {rx.Var('external_counter')}"),
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
)
|
||||
|
||||
@rx.event
|
||||
@ -178,7 +180,7 @@ def CallScript():
|
||||
rx.Var(
|
||||
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}"
|
||||
),
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
)
|
||||
|
||||
@rx.event
|
||||
@ -193,17 +195,17 @@ def CallScript():
|
||||
def index():
|
||||
return rx.vstack(
|
||||
rx.input(
|
||||
value=CallScriptState.inline_counter.to(str), # type: ignore
|
||||
value=CallScriptState.inline_counter.to(str),
|
||||
id="inline_counter",
|
||||
read_only=True,
|
||||
),
|
||||
rx.input(
|
||||
value=CallScriptState.external_counter.to(str), # type: ignore
|
||||
value=CallScriptState.external_counter.to(str),
|
||||
id="external_counter",
|
||||
read_only=True,
|
||||
),
|
||||
rx.text_area(
|
||||
value=CallScriptState.results.to_string(), # type: ignore
|
||||
value=CallScriptState.results.to_string(),
|
||||
id="results",
|
||||
read_only=True,
|
||||
),
|
||||
@ -273,7 +275,7 @@ def CallScript():
|
||||
CallScriptState.value,
|
||||
on_click=rx.call_script(
|
||||
"'updated'",
|
||||
callback=CallScriptState.set_value, # type: ignore
|
||||
callback=CallScriptState.setvar("value"),
|
||||
),
|
||||
id="update_value",
|
||||
),
|
||||
@ -282,7 +284,7 @@ def CallScript():
|
||||
value=CallScriptState.last_result,
|
||||
id="last_result",
|
||||
read_only=True,
|
||||
on_click=CallScriptState.set_last_result(""), # type: ignore
|
||||
on_click=CallScriptState.setvar("last_result", 0),
|
||||
),
|
||||
rx.button(
|
||||
"call_with_var_f_string",
|
||||
@ -308,7 +310,7 @@ def CallScript():
|
||||
"call_with_var_f_string_inline",
|
||||
on_click=rx.call_script(
|
||||
f"{rx.Var('inline_counter')} + {CallScriptState.last_result}",
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
),
|
||||
id="call_with_var_f_string_inline",
|
||||
),
|
||||
@ -316,7 +318,7 @@ def CallScript():
|
||||
"call_with_var_str_cast_inline",
|
||||
on_click=rx.call_script(
|
||||
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}",
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
),
|
||||
id="call_with_var_str_cast_inline",
|
||||
),
|
||||
@ -326,7 +328,7 @@ def CallScript():
|
||||
rx.Var(
|
||||
f"{rx.Var('inline_counter')} + {CallScriptState.last_result}"
|
||||
),
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
),
|
||||
id="call_with_var_f_string_wrapped_inline",
|
||||
),
|
||||
@ -336,7 +338,7 @@ def CallScript():
|
||||
rx.Var(
|
||||
f"{rx.Var('inline_counter')!s} + {rx.Var('external_counter')!s}"
|
||||
),
|
||||
callback=CallScriptState.set_last_result, # type: ignore
|
||||
callback=CallScriptState.setvar("last_result"),
|
||||
),
|
||||
id="call_with_var_str_cast_wrapped_inline",
|
||||
),
|
||||
@ -483,7 +485,7 @@ def test_call_script_w_var(
|
||||
"""
|
||||
assert_token(driver)
|
||||
last_result = driver.find_element(By.ID, "last_result")
|
||||
assert last_result.get_attribute("value") == ""
|
||||
assert last_result.get_attribute("value") == "0"
|
||||
|
||||
inline_return_button = driver.find_element(By.ID, "inline_return")
|
||||
|
||||
|
@ -28,7 +28,10 @@ from reflex.event import (
|
||||
from reflex.state import BaseState
|
||||
from reflex.style import Style
|
||||
from reflex.utils import imports
|
||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgTypeMismatch
|
||||
from reflex.utils.exceptions import (
|
||||
EventFnArgMismatchError,
|
||||
EventHandlerArgTypeMismatchError,
|
||||
)
|
||||
from reflex.utils.imports import ImportDict, ImportVar, ParsedImportDict, parse_imports
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
@ -910,23 +913,23 @@ def test_invalid_event_handler_args(component2, test_state):
|
||||
test_state: A test state.
|
||||
"""
|
||||
# EventHandler args must match
|
||||
with pytest.raises(EventFnArgMismatch):
|
||||
with pytest.raises(EventFnArgMismatchError):
|
||||
component2.create(on_click=test_state.do_something_arg)
|
||||
|
||||
# Multiple EventHandler args: all must match
|
||||
with pytest.raises(EventFnArgMismatch):
|
||||
with pytest.raises(EventFnArgMismatchError):
|
||||
component2.create(
|
||||
on_click=[test_state.do_something_arg, test_state.do_something]
|
||||
)
|
||||
|
||||
# # Event Handler types must match
|
||||
with pytest.raises(EventHandlerArgTypeMismatch):
|
||||
with pytest.raises(EventHandlerArgTypeMismatchError):
|
||||
component2.create(
|
||||
on_user_visited_count_changed=test_state.do_something_with_bool
|
||||
)
|
||||
with pytest.raises(EventHandlerArgTypeMismatch):
|
||||
with pytest.raises(EventHandlerArgTypeMismatchError):
|
||||
component2.create(on_user_list_changed=test_state.do_something_with_int)
|
||||
with pytest.raises(EventHandlerArgTypeMismatch):
|
||||
with pytest.raises(EventHandlerArgTypeMismatchError):
|
||||
component2.create(on_user_list_changed=test_state.do_something_with_list_int)
|
||||
|
||||
component2.create(on_open=test_state.do_something_with_int)
|
||||
@ -945,15 +948,15 @@ def test_invalid_event_handler_args(component2, test_state):
|
||||
)
|
||||
|
||||
# lambda signature must match event trigger.
|
||||
with pytest.raises(EventFnArgMismatch):
|
||||
with pytest.raises(EventFnArgMismatchError):
|
||||
component2.create(on_click=lambda _: test_state.do_something_arg(1))
|
||||
|
||||
# lambda returning EventHandler must match spec
|
||||
with pytest.raises(EventFnArgMismatch):
|
||||
with pytest.raises(EventFnArgMismatchError):
|
||||
component2.create(on_click=lambda: test_state.do_something_arg)
|
||||
|
||||
# Mixed EventSpec and EventHandler must match spec.
|
||||
with pytest.raises(EventFnArgMismatch):
|
||||
with pytest.raises(EventFnArgMismatchError):
|
||||
component2.create(
|
||||
on_click=lambda: [
|
||||
test_state.do_something_arg(1),
|
||||
|
Loading…
Reference in New Issue
Block a user