call guess type

This commit is contained in:
Khaleel Al-Adhami 2024-11-15 15:54:04 -08:00
parent 702670ff26
commit eac54d60d2

View File

@ -181,20 +181,20 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
)
@overload
def call(self: FunctionVar[ReflexCallable[[], R]]) -> VarOperationCall[[], R]: ...
def call(self: FunctionVar[ReflexCallable[[], R]]) -> Var[R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[VarWithDefault[V1]], R]],
arg1: Union[V1, Var[V1], Unset] = Unset(),
) -> VarOperationCall[[VarWithDefault[V1]], R]: ...
) -> Var[R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[VarWithDefault[V1], VarWithDefault[V2]], R]],
arg1: Union[V1, Var[V1], Unset] = Unset(),
arg2: Union[V2, Var[V2], Unset] = Unset(),
) -> VarOperationCall[[VarWithDefault[V1], VarWithDefault[V2]], R]: ...
) -> Var[R]: ...
@overload
def call(
@ -206,21 +206,19 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
arg1: Union[V1, Var[V1], Unset] = Unset(),
arg2: Union[V2, Var[V2], Unset] = Unset(),
arg3: Union[V3, Var[V3], Unset] = Unset(),
) -> VarOperationCall[
[VarWithDefault[V1], VarWithDefault[V2], VarWithDefault[V3]], R
]: ...
) -> Var[R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1], R]], arg1: Union[V1, Var[V1]]
) -> VarOperationCall[[V1], R]: ...
) -> Var[R]: ...
@overload
def call(
self: FunctionVar[ReflexCallable[[V1, VarWithDefault[V2]], R]],
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2], Unset] = Unset(),
) -> VarOperationCall[[V1, VarWithDefault[V2]], R]: ...
) -> Var[R]: ...
@overload
def call(
@ -230,7 +228,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2], Unset] = Unset(),
arg3: Union[V3, Var[V3], Unset] = Unset(),
) -> VarOperationCall[[V1, VarWithDefault[V2], VarWithDefault[V3]], R]: ...
) -> Var[R]: ...
@overload
def call(
@ -245,7 +243,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3], Unset] = Unset(),
) -> VarOperationCall[[V1, V2, VarWithDefault[V3]], R]: ...
) -> Var[R]: ...
@overload
def call(
@ -253,7 +251,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
arg1: Union[V1, Var[V1]],
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
) -> VarOperationCall[[V1, V2, V3], R]: ...
) -> Var[R]: ...
@overload
def call(
@ -262,7 +260,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
arg2: Union[V2, Var[V2]],
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
) -> VarOperationCall[[V1, V2, V3, V4], R]: ...
) -> Var[R]: ...
@overload
def call(
@ -272,7 +270,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
arg3: Union[V3, Var[V3]],
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
) -> VarOperationCall[[V1, V2, V3, V4, V5], R]: ...
) -> Var[R]: ...
@overload
def call(
@ -283,13 +281,13 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
arg4: Union[V4, Var[V4]],
arg5: Union[V5, Var[V5]],
arg6: Union[V6, Var[V6]],
) -> VarOperationCall[[V1, V2, V3, V4, V5, V6], R]: ...
) -> Var[R]: ...
# Capture Any to allow for arbitrary number of arguments
@overload
def call(self: FunctionVar[NoReturn], *args: Var | Any) -> VarOperationCall: ...
def call(self: FunctionVar[NoReturn], *args: Var | Any) -> Var: ...
def call(self, *args: Var | Any) -> VarOperationCall: # type: ignore
def call(self, *args: Var | Any) -> Var: # pyright: ignore [reportInconsistentOverload]
"""Call the function with the given arguments.
Args:
@ -315,7 +313,7 @@ class FunctionVar(Var[CALLABLE_TYPE], default_type=ReflexCallable[Any, Any]):
args = tuple(map(LiteralVar.create, args))
self._pre_check(*args)
return_type = self._return_type(*args)
return VarOperationCall.create(self, *args, _var_type=return_type)
return VarOperationCall.create(self, *args, _var_type=return_type).guess_type()
def chain(
self: FunctionVar[ReflexCallable[P, R]],