add typing to function vars (#4372)

* add typing to function vars

* import ParamSpec from typing_extensions

* remove ellipsis as they are not supported in 3.9

* try importing everything from extensions

* special case 3.9

* don't use Any from extensions

* get typevar from extensions
This commit is contained in:
Khaleel Al-Adhami 2024-11-12 20:00:02 -08:00 committed by GitHub
parent 5d88263cd8
commit 27c1a7e94d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 309 additions and 47 deletions

View File

@ -45,6 +45,8 @@ from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import ( from reflex.vars.function import (
ArgsFunctionOperation, ArgsFunctionOperation,
ArgsFunctionOperationBuilder,
BuilderFunctionVar,
FunctionArgs, FunctionArgs,
FunctionStringVar, FunctionStringVar,
FunctionVar, FunctionVar,
@ -797,8 +799,7 @@ def scroll_to(elem_id: str, align_to_top: bool | Var[bool] = True) -> EventSpec:
get_element_by_id = FunctionStringVar.create("document.getElementById") get_element_by_id = FunctionStringVar.create("document.getElementById")
return run_script( return run_script(
get_element_by_id(elem_id) get_element_by_id.call(elem_id)
.call(elem_id)
.to(ObjectVar) .to(ObjectVar)
.scrollIntoView.to(FunctionVar) .scrollIntoView.to(FunctionVar)
.call(align_to_top), .call(align_to_top),
@ -1580,7 +1581,7 @@ class LiteralEventVar(VarOperationCall, LiteralVar, EventVar):
) )
class EventChainVar(FunctionVar, python_types=EventChain): class EventChainVar(BuilderFunctionVar, python_types=EventChain):
"""Base class for event chain vars.""" """Base class for event chain vars."""
@ -1592,7 +1593,7 @@ class EventChainVar(FunctionVar, python_types=EventChain):
# Note: LiteralVar is second in the inheritance list allowing it act like a # Note: LiteralVar is second in the inheritance list allowing it act like a
# CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the # CachedVarOperation (ArgsFunctionOperation) and get the _js_expr from the
# _cached_var_name property. # _cached_var_name property.
class LiteralEventChainVar(ArgsFunctionOperation, LiteralVar, EventChainVar): class LiteralEventChainVar(ArgsFunctionOperationBuilder, LiteralVar, EventChainVar):
"""A literal event chain var.""" """A literal event chain var."""
_var_value: EventChain = dataclasses.field(default=None) # type: ignore _var_value: EventChain = dataclasses.field(default=None) # type: ignore

View File

@ -51,7 +51,8 @@ def get_python_version() -> str:
Returns: Returns:
The Python version. The Python version.
""" """
return platform.python_version() # Remove the "+" from the version string in case user is using a pre-release version.
return platform.python_version().rstrip("+")
def get_reflex_version() -> str: def get_reflex_version() -> str:

View File

@ -361,21 +361,29 @@ class Var(Generic[VAR_TYPE]):
return False return False
def __init_subclass__( def __init_subclass__(
cls, python_types: Tuple[GenericType, ...] | GenericType = types.Unset, **kwargs cls,
python_types: Tuple[GenericType, ...] | GenericType = types.Unset(),
default_type: GenericType = types.Unset(),
**kwargs,
): ):
"""Initialize the subclass. """Initialize the subclass.
Args: Args:
python_types: The python types that the var represents. python_types: The python types that the var represents.
default_type: The default type of the var. Defaults to the first python type.
**kwargs: Additional keyword arguments. **kwargs: Additional keyword arguments.
""" """
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
if python_types is not types.Unset: if python_types or default_type:
python_types = ( python_types = (
python_types if isinstance(python_types, tuple) else (python_types,) (python_types if isinstance(python_types, tuple) else (python_types,))
if python_types
else ()
) )
default_type = default_type or (python_types[0] if python_types else Any)
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
@ -388,7 +396,7 @@ class Var(Generic[VAR_TYPE]):
default=Var(_js_expr="null", _var_type=None), default=Var(_js_expr="null", _var_type=None),
) )
_default_var_type: ClassVar[GenericType] = python_types[0] _default_var_type: ClassVar[GenericType] = default_type
ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation' ToVarOperation.__name__ = f'To{cls.__name__.removesuffix("Var")}Operation'
@ -588,6 +596,12 @@ class Var(Generic[VAR_TYPE]):
output: type[list] | type[tuple] | type[set], output: type[list] | type[tuple] | type[set],
) -> ArrayVar: ... ) -> ArrayVar: ...
@overload
def to(
self,
output: type[dict],
) -> ObjectVar[dict]: ...
@overload @overload
def to( def to(
self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE] self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]

View File

@ -4,32 +4,177 @@ from __future__ import annotations
import dataclasses import dataclasses
import sys import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload
from typing_extensions import Concatenate, Generic, ParamSpec, Protocol, TypeVar
from reflex.utils import format from reflex.utils import format
from reflex.utils.types import GenericType from reflex.utils.types import GenericType
from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock from .base import CachedVarOperation, LiteralVar, Var, VarData, cached_property_no_lock
P = ParamSpec("P")
V1 = TypeVar("V1")
V2 = TypeVar("V2")
V3 = TypeVar("V3")
V4 = TypeVar("V4")
V5 = TypeVar("V5")
V6 = TypeVar("V6")
R = TypeVar("R")
class FunctionVar(Var[Callable], python_types=Callable):
class ReflexCallable(Protocol[P, R]):
"""Protocol for a callable."""
__call__: Callable[P, R]
CALLABLE_TYPE = TypeVar("CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True)
OTHER_CALLABLE_TYPE = TypeVar(
"OTHER_CALLABLE_TYPE", bound=ReflexCallable, infer_variance=True
)
class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
"""Base class for immutable function vars.""" """Base class for immutable function vars."""
def __call__(self, *args: Var | Any) -> ArgsFunctionOperation: @overload
"""Call the function with the given arguments. def partial(self) -> FunctionVar[CALLABLE_TYPE]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, P], R]],
arg1: Union[V1, Var[V1]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[Concatenate[V1, V2, V3, V4, V5, V6, P], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
arg6: Union[V6, Var[V6]],
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(
self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
) -> FunctionVar[ReflexCallable[P, R]]: ...
@overload
def partial(self, *args: Var | Any) -> FunctionVar: ...
def partial(self, *args: Var | Any) -> FunctionVar: # type: ignore
"""Partially apply the function with the given arguments.
Args: Args:
*args: The arguments to call the function with. *args: The arguments to partially apply the function with.
Returns: Returns:
The function call operation. The partially applied function.
""" """
if not args:
return ArgsFunctionOperation.create((), self)
return ArgsFunctionOperation.create( return ArgsFunctionOperation.create(
("...args",), ("...args",),
VarOperationCall.create(self, *args, Var(_js_expr="...args")), VarOperationCall.create(self, *args, Var(_js_expr="...args")),
) )
def call(self, *args: Var | Any) -> VarOperationCall: @overload
def call(
self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
) -> VarOperationCall[[V1], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
) -> VarOperationCall[[V1, V2], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
) -> VarOperationCall[[V1, V2, V3], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, V2, V3, V4, V5, V6], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
arg6: Union[V6, Var[V6]],
) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[P, R]], *args: Var | Any
) -> VarOperationCall[P, R]: ...
@overload
def call(self, *args: Var | Any) -> Var: ...
def call(self, *args: Var | Any) -> Var: # type: ignore
"""Call the function with the given arguments. """Call the function with the given arguments.
Args: Args:
@ -38,19 +183,29 @@ class FunctionVar(Var[Callable], python_types=Callable):
Returns: Returns:
The function call operation. The function call operation.
""" """
return VarOperationCall.create(self, *args) return VarOperationCall.create(self, *args).guess_type()
__call__ = call
class FunctionStringVar(FunctionVar): class BuilderFunctionVar(
FunctionVar[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]
):
"""Base class for immutable function vars with the builder pattern."""
__call__ = FunctionVar.partial
class FunctionStringVar(FunctionVar[CALLABLE_TYPE]):
"""Base class for immutable function vars from a string.""" """Base class for immutable function vars from a string."""
@classmethod @classmethod
def create( def create(
cls, cls,
func: str, func: str,
_var_type: Type[Callable] = Callable, _var_type: Type[OTHER_CALLABLE_TYPE] = ReflexCallable[Any, Any],
_var_data: VarData | None = None, _var_data: VarData | None = None,
) -> FunctionStringVar: ) -> FunctionStringVar[OTHER_CALLABLE_TYPE]:
"""Create a new function var from a string. """Create a new function var from a string.
Args: Args:
@ -60,7 +215,7 @@ class FunctionStringVar(FunctionVar):
Returns: Returns:
The function var. The function var.
""" """
return cls( return FunctionStringVar(
_js_expr=func, _js_expr=func,
_var_type=_var_type, _var_type=_var_type,
_var_data=_var_data, _var_data=_var_data,
@ -72,10 +227,10 @@ class FunctionStringVar(FunctionVar):
frozen=True, frozen=True,
**{"slots": True} if sys.version_info >= (3, 10) else {}, **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class VarOperationCall(CachedVarOperation, Var): class VarOperationCall(Generic[P, R], CachedVarOperation, Var[R]):
"""Base class for immutable vars that are the result of a function call.""" """Base class for immutable vars that are the result of a function call."""
_func: Optional[FunctionVar] = dataclasses.field(default=None) _func: Optional[FunctionVar[ReflexCallable[P, R]]] = dataclasses.field(default=None)
_args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple) _args: Tuple[Union[Var, Any], ...] = dataclasses.field(default_factory=tuple)
@cached_property_no_lock @cached_property_no_lock
@ -103,7 +258,7 @@ class VarOperationCall(CachedVarOperation, Var):
@classmethod @classmethod
def create( def create(
cls, cls,
func: FunctionVar, func: FunctionVar[ReflexCallable[P, R]],
*args: Var | Any, *args: Var | Any,
_var_type: GenericType = Any, _var_type: GenericType = Any,
_var_data: VarData | None = None, _var_data: VarData | None = None,
@ -118,9 +273,15 @@ class VarOperationCall(CachedVarOperation, Var):
Returns: Returns:
The function call var. The function call var.
""" """
function_return_type = (
func._var_type.__args__[1]
if getattr(func._var_type, "__args__", None)
else Any
)
var_type = _var_type if _var_type is not Any else function_return_type
return cls( return cls(
_js_expr="", _js_expr="",
_var_type=_var_type, _var_type=var_type,
_var_data=_var_data, _var_data=_var_data,
_func=func, _func=func,
_args=args, _args=args,
@ -157,6 +318,33 @@ class FunctionArgs:
rest: Optional[str] = None rest: Optional[str] = None
def format_args_function_operation(
args: FunctionArgs, return_expr: Var | Any, explicit_return: bool
) -> str:
"""Format an args function operation.
Args:
args: The function arguments.
return_expr: The return expression.
explicit_return: Whether to use explicit return syntax.
Returns:
The formatted args function operation.
"""
arg_names_str = ", ".join(
[arg if isinstance(arg, str) else arg.to_javascript() for arg in args.args]
) + (f", ...{args.rest}" if args.rest else "")
return_expr_str = str(LiteralVar.create(return_expr))
# Wrap return expression in curly braces if explicit return syntax is used.
return_expr_str_wrapped = (
format.wrap(return_expr_str, "{", "}") if explicit_return else return_expr_str
)
return f"(({arg_names_str}) => {return_expr_str_wrapped})"
@dataclasses.dataclass( @dataclasses.dataclass(
eq=False, eq=False,
frozen=True, frozen=True,
@ -176,24 +364,10 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
Returns: Returns:
The name of the var. The name of the var.
""" """
arg_names_str = ", ".join( return format_args_function_operation(
[ self._args, self._return_expr, self._explicit_return
arg if isinstance(arg, str) else arg.to_javascript()
for arg in self._args.args
]
) + (f", ...{self._args.rest}" if self._args.rest else "")
return_expr_str = str(LiteralVar.create(self._return_expr))
# Wrap return expression in curly braces if explicit return syntax is used.
return_expr_str_wrapped = (
format.wrap(return_expr_str, "{", "}")
if self._explicit_return
else return_expr_str
) )
return f"(({arg_names_str}) => {return_expr_str_wrapped})"
@classmethod @classmethod
def create( def create(
cls, cls,
@ -203,7 +377,7 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
explicit_return: bool = False, explicit_return: bool = False,
_var_type: GenericType = Callable, _var_type: GenericType = Callable,
_var_data: VarData | None = None, _var_data: VarData | None = None,
) -> ArgsFunctionOperation: ):
"""Create a new function var. """Create a new function var.
Args: Args:
@ -226,8 +400,80 @@ class ArgsFunctionOperation(CachedVarOperation, FunctionVar):
) )
JSON_STRINGIFY = FunctionStringVar.create("JSON.stringify") @dataclasses.dataclass(
ARRAY_ISARRAY = FunctionStringVar.create("Array.isArray") eq=False,
PROTOTYPE_TO_STRING = FunctionStringVar.create( frozen=True,
"((__to_string) => __to_string.toString())" **{"slots": True} if sys.version_info >= (3, 10) else {},
) )
class ArgsFunctionOperationBuilder(CachedVarOperation, BuilderFunctionVar):
"""Base class for immutable function defined via arguments and return expression with the builder pattern."""
_args: FunctionArgs = dataclasses.field(default_factory=FunctionArgs)
_return_expr: Union[Var, Any] = dataclasses.field(default=None)
_explicit_return: bool = dataclasses.field(default=False)
@cached_property_no_lock
def _cached_var_name(self) -> str:
"""The name of the var.
Returns:
The name of the var.
"""
return format_args_function_operation(
self._args, self._return_expr, self._explicit_return
)
@classmethod
def create(
cls,
args_names: Sequence[Union[str, DestructuredArg]],
return_expr: Var | Any,
rest: str | None = None,
explicit_return: bool = False,
_var_type: GenericType = Callable,
_var_data: VarData | None = None,
):
"""Create a new function var.
Args:
args_names: The names of the arguments.
return_expr: The return expression of the function.
rest: The name of the rest argument.
explicit_return: Whether to use explicit return syntax.
_var_data: Additional hooks and imports associated with the Var.
Returns:
The function var.
"""
return cls(
_js_expr="",
_var_type=_var_type,
_var_data=_var_data,
_args=FunctionArgs(args=tuple(args_names), rest=rest),
_return_expr=return_expr,
_explicit_return=explicit_return,
)
if python_version := sys.version_info[:2] >= (3, 10):
JSON_STRINGIFY = FunctionStringVar.create(
"JSON.stringify", _var_type=ReflexCallable[[Any], str]
)
ARRAY_ISARRAY = FunctionStringVar.create(
"Array.isArray", _var_type=ReflexCallable[[Any], bool]
)
PROTOTYPE_TO_STRING = FunctionStringVar.create(
"((__to_string) => __to_string.toString())",
_var_type=ReflexCallable[[Any], str],
)
else:
JSON_STRINGIFY = FunctionStringVar.create(
"JSON.stringify", _var_type=ReflexCallable[Any, str]
)
ARRAY_ISARRAY = FunctionStringVar.create(
"Array.isArray", _var_type=ReflexCallable[Any, bool]
)
PROTOTYPE_TO_STRING = FunctionStringVar.create(
"((__to_string) => __to_string.toString())",
_var_type=ReflexCallable[Any, str],
)

View File

@ -928,7 +928,7 @@ def test_function_var():
== '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))' == '(((a, b) => ({ ["args"] : [a, b], ["result"] : a + b }))(1, 2))'
) )
increment_func = addition_func(1) increment_func = addition_func.partial(1)
assert ( assert (
str(increment_func.call(2)) str(increment_func.call(2))
== "(((...args) => (((a, b) => a + b)(1, ...args)))(2))" == "(((...args) => (((a, b) => a + b)(1, ...args)))(2))"