EventFnArgMismatch fix to support defaults args (#4004)
* EventFnArgMismatch fix to support defaults args * fixing type hint and docstring raises * enforce stronger type checking * unwrap var annotations :( --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
54c7b5a261
commit
60276cf1ff
@ -18,10 +18,12 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from reflex import constants
|
||||
from reflex.utils import format
|
||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
||||
from reflex.utils.types import ArgsSpec
|
||||
from reflex.utils.types import ArgsSpec, GenericType
|
||||
from reflex.vars import VarData
|
||||
from reflex.vars.base import LiteralVar, Var
|
||||
from reflex.vars.function import FunctionStringVar, FunctionVar
|
||||
@ -417,7 +419,7 @@ class FileUpload:
|
||||
on_upload_progress: Optional[Union[EventHandler, Callable]] = None
|
||||
|
||||
@staticmethod
|
||||
def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
|
||||
def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
|
||||
"""Args spec for on_upload_progress event handler.
|
||||
|
||||
Returns:
|
||||
@ -910,6 +912,20 @@ def call_event_handler(
|
||||
)
|
||||
|
||||
|
||||
def unwrap_var_annotation(annotation: GenericType):
|
||||
"""Unwrap a Var annotation or return it as is if it's not Var[X].
|
||||
|
||||
Args:
|
||||
annotation: The annotation to unwrap.
|
||||
|
||||
Returns:
|
||||
The unwrapped annotation.
|
||||
"""
|
||||
if get_origin(annotation) is Var and (args := get_args(annotation)):
|
||||
return args[0]
|
||||
return annotation
|
||||
|
||||
|
||||
def parse_args_spec(arg_spec: ArgsSpec):
|
||||
"""Parse the args provided in the ArgsSpec of an event trigger.
|
||||
|
||||
@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
||||
"""
|
||||
spec = inspect.getfullargspec(arg_spec)
|
||||
annotations = get_type_hints(arg_spec)
|
||||
|
||||
return arg_spec(
|
||||
*[
|
||||
Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
|
||||
Var(f"_{l_arg}").to(
|
||||
unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
|
||||
)
|
||||
for l_arg in spec.args
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
|
||||
"""Ensures that the function signature matches the passed argument specification
|
||||
or raises an EventFnArgMismatch if they do not.
|
||||
|
||||
Args:
|
||||
fn: The function to be validated.
|
||||
arg_spec: The argument specification for the event trigger.
|
||||
|
||||
Returns:
|
||||
The parsed arguments from the argument specification.
|
||||
|
||||
Raises:
|
||||
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
|
||||
"""
|
||||
fn_args = inspect.getfullargspec(fn).args
|
||||
fn_defaults_args = inspect.getfullargspec(fn).defaults
|
||||
n_fn_args = len(fn_args)
|
||||
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
|
||||
if isinstance(fn, types.MethodType):
|
||||
n_fn_args -= 1 # subtract 1 for bound self arg
|
||||
parsed_args = parse_args_spec(arg_spec)
|
||||
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
|
||||
raise EventFnArgMismatch(
|
||||
"The number of mandatory arguments accepted by "
|
||||
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
|
||||
"does not match the arguments passed by the event trigger: "
|
||||
f"{[str(v) for v in parsed_args]}\n"
|
||||
"See https://reflex.dev/docs/events/event-arguments/"
|
||||
)
|
||||
return parsed_args
|
||||
|
||||
|
||||
def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||
"""Call a function to a list of event specs.
|
||||
|
||||
The function should return a single EventSpec, a list of EventSpecs, or a
|
||||
single Var. The function signature must match the passed arg_spec or
|
||||
EventFnArgsMismatch will be raised.
|
||||
single Var.
|
||||
|
||||
Args:
|
||||
fn: The function to call.
|
||||
@ -944,7 +994,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||
The event specs from calling the function or a Var.
|
||||
|
||||
Raises:
|
||||
EventFnArgMismatch: If the function signature doesn't match the arg spec.
|
||||
EventHandlerValueError: If the lambda returns an unusable value.
|
||||
"""
|
||||
# Import here to avoid circular imports.
|
||||
@ -952,19 +1001,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||
from reflex.utils.exceptions import EventHandlerValueError
|
||||
|
||||
# Check that fn signature matches arg_spec
|
||||
fn_args = inspect.getfullargspec(fn).args
|
||||
n_fn_args = len(fn_args)
|
||||
if isinstance(fn, types.MethodType):
|
||||
n_fn_args -= 1 # subtract 1 for bound self arg
|
||||
parsed_args = parse_args_spec(arg_spec)
|
||||
if len(parsed_args) != n_fn_args:
|
||||
raise EventFnArgMismatch(
|
||||
"The number of arguments accepted by "
|
||||
f"{fn} ({n_fn_args}) "
|
||||
"does not match the arguments passed by the event trigger: "
|
||||
f"{[str(v) for v in parsed_args]}\n"
|
||||
"See https://reflex.dev/docs/events/event-arguments/"
|
||||
)
|
||||
parsed_args = check_fn_match_arg_spec(fn, arg_spec)
|
||||
|
||||
# Call the function with the parsed args.
|
||||
out = fn(*parsed_args)
|
||||
|
@ -9,6 +9,7 @@ import sys
|
||||
import types
|
||||
from functools import cached_property, lru_cache, wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
@ -96,8 +97,22 @@ PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
|
||||
StateVar = Union[PrimitiveType, Base, None]
|
||||
StateIterVar = Union[list, set, tuple]
|
||||
|
||||
# ArgsSpec = Callable[[Var], list[Var]]
|
||||
ArgsSpec = Callable
|
||||
if TYPE_CHECKING:
|
||||
from reflex.vars.base import Var
|
||||
|
||||
# ArgsSpec = Callable[[Var], list[Var]]
|
||||
ArgsSpec = (
|
||||
Callable[[], List[Var]]
|
||||
| Callable[[Var], List[Var]]
|
||||
| Callable[[Var, Var], List[Var]]
|
||||
| Callable[[Var, Var, Var], List[Var]]
|
||||
| Callable[[Var, Var, Var, Var], List[Var]]
|
||||
| Callable[[Var, Var, Var, Var, Var], List[Var]]
|
||||
| Callable[[Var, Var, Var, Var, Var, Var], List[Var]]
|
||||
| Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]]
|
||||
)
|
||||
else:
|
||||
ArgsSpec = Callable[..., List[Any]]
|
||||
|
||||
|
||||
PrimitiveToAnnotation = {
|
||||
|
@ -97,7 +97,7 @@ def test_call_event_handler_partial():
|
||||
|
||||
test_fn_with_args.__qualname__ = "test_fn_with_args"
|
||||
|
||||
def spec(a2: str) -> List[str]:
|
||||
def spec(a2: Var[str]) -> List[Var[str]]:
|
||||
return [a2]
|
||||
|
||||
handler = EventHandler(fn=test_fn_with_args)
|
||||
|
Loading…
Reference in New Issue
Block a user