EventFnArgMismatch fix to support defaults args (#4004)

* EventFnArgMismatch fix to support defaults args

* fixing type hint and docstring raises

* enforce stronger type checking

* unwrap var annotations :(

---------

Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
LeoH 2024-09-26 22:56:53 +02:00 committed by GitHub
parent 54c7b5a261
commit 60276cf1ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 22 deletions

View File

@ -18,10 +18,12 @@ from typing import (
get_type_hints,
)
from typing_extensions import get_args, get_origin
from reflex import constants
from reflex.utils import format
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec
from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import FunctionStringVar, FunctionVar
@ -417,7 +419,7 @@ class FileUpload:
on_upload_progress: Optional[Union[EventHandler, Callable]] = None
@staticmethod
def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
"""Args spec for on_upload_progress event handler.
Returns:
@ -910,6 +912,20 @@ def call_event_handler(
)
def unwrap_var_annotation(annotation: GenericType):
"""Unwrap a Var annotation or return it as is if it's not Var[X].
Args:
annotation: The annotation to unwrap.
Returns:
The unwrapped annotation.
"""
if get_origin(annotation) is Var and (args := get_args(annotation)):
return args[0]
return annotation
def parse_args_spec(arg_spec: ArgsSpec):
"""Parse the args provided in the ArgsSpec of an event trigger.
@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
"""
spec = inspect.getfullargspec(arg_spec)
annotations = get_type_hints(arg_spec)
return arg_spec(
*[
Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
Var(f"_{l_arg}").to(
unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
)
for l_arg in spec.args
]
)
def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
"""Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not.
Args:
fn: The function to be validated.
arg_spec: The argument specification for the event trigger.
Returns:
The parsed arguments from the argument specification.
Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
"""
fn_args = inspect.getfullargspec(fn).args
fn_defaults_args = inspect.getfullargspec(fn).defaults
n_fn_args = len(fn_args)
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
raise EventFnArgMismatch(
"The number of mandatory arguments accepted by "
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
return parsed_args
def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
"""Call a function to a list of event specs.
The function should return a single EventSpec, a list of EventSpecs, or a
single Var. The function signature must match the passed arg_spec or
EventFnArgsMismatch will be raised.
single Var.
Args:
fn: The function to call.
@ -944,7 +994,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
The event specs from calling the function or a Var.
Raises:
EventFnArgMismatch: If the function signature doesn't match the arg spec.
EventHandlerValueError: If the lambda returns an unusable value.
"""
# Import here to avoid circular imports.
@ -952,19 +1001,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
from reflex.utils.exceptions import EventHandlerValueError
# Check that fn signature matches arg_spec
fn_args = inspect.getfullargspec(fn).args
n_fn_args = len(fn_args)
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if len(parsed_args) != n_fn_args:
raise EventFnArgMismatch(
"The number of arguments accepted by "
f"{fn} ({n_fn_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
parsed_args = check_fn_match_arg_spec(fn, arg_spec)
# Call the function with the parsed args.
out = fn(*parsed_args)

View File

@ -9,6 +9,7 @@ import sys
import types
from functools import cached_property, lru_cache, wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
@ -96,8 +97,22 @@ PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple]
# ArgsSpec = Callable[[Var], list[Var]]
ArgsSpec = Callable
if TYPE_CHECKING:
from reflex.vars.base import Var
# ArgsSpec = Callable[[Var], list[Var]]
ArgsSpec = (
Callable[[], List[Var]]
| Callable[[Var], List[Var]]
| Callable[[Var, Var], List[Var]]
| Callable[[Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]]
)
else:
ArgsSpec = Callable[..., List[Any]]
PrimitiveToAnnotation = {

View File

@ -97,7 +97,7 @@ def test_call_event_handler_partial():
test_fn_with_args.__qualname__ = "test_fn_with_args"
def spec(a2: str) -> List[str]:
def spec(a2: Var[str]) -> List[Var[str]]:
return [a2]
handler = EventHandler(fn=test_fn_with_args)