solve some but not all pyright issues
This commit is contained in:
parent
3d73f561b7
commit
112b2ed948
@ -16,7 +16,7 @@ from itertools import chain
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
|
||||
from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin
|
||||
|
||||
from reflex.components.component import Component
|
||||
from reflex.utils import types as rx_types
|
||||
@ -229,7 +229,9 @@ def _generate_imports(
|
||||
"""
|
||||
return [
|
||||
*[
|
||||
ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values])
|
||||
ast.ImportFrom(
|
||||
module=name, names=[ast.alias(name=val) for val in values], level=0
|
||||
)
|
||||
for name, values in DEFAULT_IMPORTS.items()
|
||||
],
|
||||
ast.Import([ast.alias("reflex")]),
|
||||
@ -428,16 +430,15 @@ def type_to_ast(typ, cls: type) -> ast.AST:
|
||||
return ast.Name(id=base_name)
|
||||
|
||||
# Convert all type arguments recursively
|
||||
arg_nodes = [type_to_ast(arg, cls) for arg in args]
|
||||
arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args])
|
||||
|
||||
# Special case for single-argument types (like List[T] or Optional[T])
|
||||
if len(arg_nodes) == 1:
|
||||
slice_value = arg_nodes[0]
|
||||
else:
|
||||
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
|
||||
|
||||
return ast.Subscript(
|
||||
value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load()
|
||||
value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
|
||||
)
|
||||
|
||||
|
||||
@ -630,7 +631,7 @@ def _generate_component_create_functiondef(
|
||||
),
|
||||
),
|
||||
ast.Expr(
|
||||
value=ast.Ellipsis(),
|
||||
value=ast.Constant(...),
|
||||
),
|
||||
],
|
||||
decorator_list=[
|
||||
@ -641,8 +642,14 @@ def _generate_component_create_functiondef(
|
||||
else [ast.Name(id="classmethod")]
|
||||
),
|
||||
],
|
||||
lineno=node.lineno if node is not None else None,
|
||||
returns=ast.Constant(value=clz.__name__),
|
||||
**(
|
||||
{
|
||||
"lineno": node.lineno,
|
||||
}
|
||||
if node is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return definition
|
||||
|
||||
@ -690,13 +697,19 @@ def _generate_staticmethod_call_functiondef(
|
||||
),
|
||||
],
|
||||
decorator_list=[ast.Name(id="staticmethod")],
|
||||
lineno=node.lineno if node is not None else None,
|
||||
returns=ast.Constant(
|
||||
value=_get_type_hint(
|
||||
typing.get_type_hints(clz.__call__).get("return", None),
|
||||
type_hint_globals,
|
||||
)
|
||||
),
|
||||
**(
|
||||
{
|
||||
"lineno": node.lineno,
|
||||
}
|
||||
if node is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return definition
|
||||
|
||||
@ -731,7 +744,12 @@ def _generate_namespace_call_functiondef(
|
||||
# Determine which class is wrapped by the namespace __call__ method
|
||||
component_clz = clz.__call__.__self__
|
||||
|
||||
if clz.__call__.__func__.__name__ != "create":
|
||||
func = getattr(clz.__call__, "__func__", None)
|
||||
|
||||
if func is None:
|
||||
raise TypeError(f"__call__ method on {clz_name} does not have a __func__")
|
||||
|
||||
if func.__name__ != "create":
|
||||
return None
|
||||
|
||||
definition = _generate_component_create_functiondef(
|
||||
@ -914,7 +932,7 @@ class StubGenerator(ast.NodeTransformer):
|
||||
node.body.append(call_definition)
|
||||
if not node.body:
|
||||
# We should never return an empty body.
|
||||
node.body.append(ast.Expr(value=ast.Ellipsis()))
|
||||
node.body.append(ast.Expr(value=ast.Constant(...)))
|
||||
self.current_class = None
|
||||
return node
|
||||
|
||||
@ -941,9 +959,9 @@ class StubGenerator(ast.NodeTransformer):
|
||||
if node.name.startswith("_") and node.name != "__call__":
|
||||
return None # remove private methods
|
||||
|
||||
if node.body[-1] != ast.Expr(value=ast.Ellipsis()):
|
||||
if node.body[-1] != ast.Expr(value=ast.Constant(...)):
|
||||
# Blank out the function body for public functions.
|
||||
node.body = [ast.Expr(value=ast.Ellipsis())]
|
||||
node.body = [ast.Expr(value=ast.Constant(...))]
|
||||
return node
|
||||
|
||||
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
|
||||
|
@ -69,21 +69,21 @@ else:
|
||||
|
||||
|
||||
# Potential GenericAlias types for isinstance checks.
|
||||
GenericAliasTypes = [_GenericAlias]
|
||||
_GenericAliasTypes: list[type] = [_GenericAlias]
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
# For newer versions of Python.
|
||||
from types import GenericAlias # type: ignore
|
||||
|
||||
GenericAliasTypes.append(GenericAlias)
|
||||
_GenericAliasTypes.append(GenericAlias)
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
# For older versions of Python.
|
||||
from typing import _SpecialGenericAlias # type: ignore
|
||||
|
||||
GenericAliasTypes.append(_SpecialGenericAlias)
|
||||
_GenericAliasTypes.append(_SpecialGenericAlias)
|
||||
|
||||
GenericAliasTypes = tuple(GenericAliasTypes)
|
||||
GenericAliasTypes = tuple(_GenericAliasTypes)
|
||||
|
||||
# Potential Union types for isinstance checks (UnionType added in py3.10).
|
||||
UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
|
||||
@ -181,7 +181,7 @@ def is_generic_alias(cls: GenericType) -> bool:
|
||||
return isinstance(cls, GenericAliasTypes)
|
||||
|
||||
|
||||
def unionize(*args: GenericType) -> Type:
|
||||
def unionize(*args: GenericType) -> GenericType:
|
||||
"""Unionize the types.
|
||||
|
||||
Args:
|
||||
@ -415,7 +415,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_base_class(cls: GenericType) -> Type:
|
||||
def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]:
|
||||
"""Get the base class of a class.
|
||||
|
||||
Args:
|
||||
@ -435,7 +435,14 @@ def get_base_class(cls: GenericType) -> Type:
|
||||
return type(get_args(cls)[0])
|
||||
|
||||
if is_union(cls):
|
||||
return tuple(get_base_class(arg) for arg in get_args(cls))
|
||||
base_classes = []
|
||||
for arg in get_args(cls):
|
||||
sub_base_classes = get_base_class(arg)
|
||||
if isinstance(sub_base_classes, tuple):
|
||||
base_classes.extend(sub_base_classes)
|
||||
else:
|
||||
base_classes.append(sub_base_classes)
|
||||
return tuple(base_classes)
|
||||
|
||||
return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls
|
||||
|
||||
|
@ -15,6 +15,7 @@ from typing import (
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
@ -1102,7 +1103,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
|
||||
_cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field(
|
||||
default_factory=tuple
|
||||
)
|
||||
_default: Var[VAR_TYPE] = dataclasses.field(
|
||||
_default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType]
|
||||
default_factory=lambda: Var.create(None)
|
||||
)
|
||||
|
||||
@ -1170,11 +1171,14 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
|
||||
The match operation.
|
||||
"""
|
||||
cond = Var.create(cond)
|
||||
cases = tuple(tuple(Var.create(c) for c in case) for case in cases)
|
||||
default = Var.create(default)
|
||||
cases = cast(
|
||||
tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
|
||||
tuple(tuple(Var.create(c) for c in case) for case in cases),
|
||||
)
|
||||
_default = cast(Var[VAR_TYPE], Var.create(default))
|
||||
var_type = _var_type or unionize(
|
||||
*(case[-1]._var_type for case in cases),
|
||||
default._var_type,
|
||||
_default._var_type,
|
||||
)
|
||||
return cls(
|
||||
_js_expr="",
|
||||
@ -1182,7 +1186,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
|
||||
_var_type=var_type,
|
||||
_cond=cond,
|
||||
_cases=cases,
|
||||
_default=default,
|
||||
_default=_default,
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user