EventFnArgMismatch fix to support defaults args (#4004)
* EventFnArgMismatch fix to support defaults args * fixing type hint and docstring raises * enforce stronger type checking * unwrap var annotations :( --------- Co-authored-by: Khaleel Al-Adhami <khaleel.aladhami@gmail.com>
This commit is contained in:
parent
54c7b5a261
commit
60276cf1ff
@ -18,10 +18,12 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import get_args, get_origin
|
||||||
|
|
||||||
from reflex import constants
|
from reflex import constants
|
||||||
from reflex.utils import format
|
from reflex.utils import format
|
||||||
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
|
||||||
from reflex.utils.types import ArgsSpec
|
from reflex.utils.types import ArgsSpec, GenericType
|
||||||
from reflex.vars import VarData
|
from reflex.vars import VarData
|
||||||
from reflex.vars.base import LiteralVar, Var
|
from reflex.vars.base import LiteralVar, Var
|
||||||
from reflex.vars.function import FunctionStringVar, FunctionVar
|
from reflex.vars.function import FunctionStringVar, FunctionVar
|
||||||
@ -417,7 +419,7 @@ class FileUpload:
|
|||||||
on_upload_progress: Optional[Union[EventHandler, Callable]] = None
|
on_upload_progress: Optional[Union[EventHandler, Callable]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
|
def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
|
||||||
"""Args spec for on_upload_progress event handler.
|
"""Args spec for on_upload_progress event handler.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -910,6 +912,20 @@ def call_event_handler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def unwrap_var_annotation(annotation: GenericType):
|
||||||
|
"""Unwrap a Var annotation or return it as is if it's not Var[X].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotation: The annotation to unwrap.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The unwrapped annotation.
|
||||||
|
"""
|
||||||
|
if get_origin(annotation) is Var and (args := get_args(annotation)):
|
||||||
|
return args[0]
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
|
||||||
def parse_args_spec(arg_spec: ArgsSpec):
|
def parse_args_spec(arg_spec: ArgsSpec):
|
||||||
"""Parse the args provided in the ArgsSpec of an event trigger.
|
"""Parse the args provided in the ArgsSpec of an event trigger.
|
||||||
|
|
||||||
@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
|
|||||||
"""
|
"""
|
||||||
spec = inspect.getfullargspec(arg_spec)
|
spec = inspect.getfullargspec(arg_spec)
|
||||||
annotations = get_type_hints(arg_spec)
|
annotations = get_type_hints(arg_spec)
|
||||||
|
|
||||||
return arg_spec(
|
return arg_spec(
|
||||||
*[
|
*[
|
||||||
Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
|
Var(f"_{l_arg}").to(
|
||||||
|
unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
|
||||||
|
)
|
||||||
for l_arg in spec.args
|
for l_arg in spec.args
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
|
||||||
|
"""Ensures that the function signature matches the passed argument specification
|
||||||
|
or raises an EventFnArgMismatch if they do not.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The function to be validated.
|
||||||
|
arg_spec: The argument specification for the event trigger.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parsed arguments from the argument specification.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
|
||||||
|
"""
|
||||||
|
fn_args = inspect.getfullargspec(fn).args
|
||||||
|
fn_defaults_args = inspect.getfullargspec(fn).defaults
|
||||||
|
n_fn_args = len(fn_args)
|
||||||
|
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
|
||||||
|
if isinstance(fn, types.MethodType):
|
||||||
|
n_fn_args -= 1 # subtract 1 for bound self arg
|
||||||
|
parsed_args = parse_args_spec(arg_spec)
|
||||||
|
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
|
||||||
|
raise EventFnArgMismatch(
|
||||||
|
"The number of mandatory arguments accepted by "
|
||||||
|
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
|
||||||
|
"does not match the arguments passed by the event trigger: "
|
||||||
|
f"{[str(v) for v in parsed_args]}\n"
|
||||||
|
"See https://reflex.dev/docs/events/event-arguments/"
|
||||||
|
)
|
||||||
|
return parsed_args
|
||||||
|
|
||||||
|
|
||||||
def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
||||||
"""Call a function to a list of event specs.
|
"""Call a function to a list of event specs.
|
||||||
|
|
||||||
The function should return a single EventSpec, a list of EventSpecs, or a
|
The function should return a single EventSpec, a list of EventSpecs, or a
|
||||||
single Var. The function signature must match the passed arg_spec or
|
single Var.
|
||||||
EventFnArgsMismatch will be raised.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fn: The function to call.
|
fn: The function to call.
|
||||||
@ -944,7 +994,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|||||||
The event specs from calling the function or a Var.
|
The event specs from calling the function or a Var.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
EventFnArgMismatch: If the function signature doesn't match the arg spec.
|
|
||||||
EventHandlerValueError: If the lambda returns an unusable value.
|
EventHandlerValueError: If the lambda returns an unusable value.
|
||||||
"""
|
"""
|
||||||
# Import here to avoid circular imports.
|
# Import here to avoid circular imports.
|
||||||
@ -952,19 +1001,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
|
|||||||
from reflex.utils.exceptions import EventHandlerValueError
|
from reflex.utils.exceptions import EventHandlerValueError
|
||||||
|
|
||||||
# Check that fn signature matches arg_spec
|
# Check that fn signature matches arg_spec
|
||||||
fn_args = inspect.getfullargspec(fn).args
|
parsed_args = check_fn_match_arg_spec(fn, arg_spec)
|
||||||
n_fn_args = len(fn_args)
|
|
||||||
if isinstance(fn, types.MethodType):
|
|
||||||
n_fn_args -= 1 # subtract 1 for bound self arg
|
|
||||||
parsed_args = parse_args_spec(arg_spec)
|
|
||||||
if len(parsed_args) != n_fn_args:
|
|
||||||
raise EventFnArgMismatch(
|
|
||||||
"The number of arguments accepted by "
|
|
||||||
f"{fn} ({n_fn_args}) "
|
|
||||||
"does not match the arguments passed by the event trigger: "
|
|
||||||
f"{[str(v) for v in parsed_args]}\n"
|
|
||||||
"See https://reflex.dev/docs/events/event-arguments/"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call the function with the parsed args.
|
# Call the function with the parsed args.
|
||||||
out = fn(*parsed_args)
|
out = fn(*parsed_args)
|
||||||
|
@ -9,6 +9,7 @@ import sys
|
|||||||
import types
|
import types
|
||||||
from functools import cached_property, lru_cache, wraps
|
from functools import cached_property, lru_cache, wraps
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
@ -96,8 +97,22 @@ PrimitiveType = Union[int, float, bool, str, list, dict, set, tuple]
|
|||||||
StateVar = Union[PrimitiveType, Base, None]
|
StateVar = Union[PrimitiveType, Base, None]
|
||||||
StateIterVar = Union[list, set, tuple]
|
StateIterVar = Union[list, set, tuple]
|
||||||
|
|
||||||
# ArgsSpec = Callable[[Var], list[Var]]
|
if TYPE_CHECKING:
|
||||||
ArgsSpec = Callable
|
from reflex.vars.base import Var
|
||||||
|
|
||||||
|
# ArgsSpec = Callable[[Var], list[Var]]
|
||||||
|
ArgsSpec = (
|
||||||
|
Callable[[], List[Var]]
|
||||||
|
| Callable[[Var], List[Var]]
|
||||||
|
| Callable[[Var, Var], List[Var]]
|
||||||
|
| Callable[[Var, Var, Var], List[Var]]
|
||||||
|
| Callable[[Var, Var, Var, Var], List[Var]]
|
||||||
|
| Callable[[Var, Var, Var, Var, Var], List[Var]]
|
||||||
|
| Callable[[Var, Var, Var, Var, Var, Var], List[Var]]
|
||||||
|
| Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ArgsSpec = Callable[..., List[Any]]
|
||||||
|
|
||||||
|
|
||||||
PrimitiveToAnnotation = {
|
PrimitiveToAnnotation = {
|
||||||
|
@ -97,7 +97,7 @@ def test_call_event_handler_partial():
|
|||||||
|
|
||||||
test_fn_with_args.__qualname__ = "test_fn_with_args"
|
test_fn_with_args.__qualname__ = "test_fn_with_args"
|
||||||
|
|
||||||
def spec(a2: str) -> List[str]:
|
def spec(a2: Var[str]) -> List[Var[str]]:
|
||||||
return [a2]
|
return [a2]
|
||||||
|
|
||||||
handler = EventHandler(fn=test_fn_with_args)
|
handler = EventHandler(fn=test_fn_with_args)
|
||||||
|
Loading…
Reference in New Issue
Block a user