solve some but not all pyright issues

This commit is contained in:
Khaleel Al-Adhami 2025-01-16 20:06:02 -08:00
parent 3d73f561b7
commit 112b2ed948
3 changed files with 53 additions and 24 deletions

View File

@ -16,7 +16,7 @@ from itertools import chain
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
from pathlib import Path from pathlib import Path
from types import ModuleType, SimpleNamespace from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin from typing import Any, Callable, Iterable, Sequence, Type, cast, get_args, get_origin
from reflex.components.component import Component from reflex.components.component import Component
from reflex.utils import types as rx_types from reflex.utils import types as rx_types
@ -229,7 +229,9 @@ def _generate_imports(
""" """
return [ return [
*[ *[
ast.ImportFrom(module=name, names=[ast.alias(name=val) for val in values]) ast.ImportFrom(
module=name, names=[ast.alias(name=val) for val in values], level=0
)
for name, values in DEFAULT_IMPORTS.items() for name, values in DEFAULT_IMPORTS.items()
], ],
ast.Import([ast.alias("reflex")]), ast.Import([ast.alias("reflex")]),
@ -428,16 +430,15 @@ def type_to_ast(typ, cls: type) -> ast.AST:
return ast.Name(id=base_name) return ast.Name(id=base_name)
# Convert all type arguments recursively # Convert all type arguments recursively
arg_nodes = [type_to_ast(arg, cls) for arg in args] arg_nodes = cast(list[ast.expr], [type_to_ast(arg, cls) for arg in args])
# Special case for single-argument types (like List[T] or Optional[T]) # Special case for single-argument types (like List[T] or Optional[T])
if len(arg_nodes) == 1: if len(arg_nodes) == 1:
slice_value = arg_nodes[0] slice_value = arg_nodes[0]
else: else:
slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load()) slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())
return ast.Subscript( return ast.Subscript(
value=ast.Name(id=base_name), slice=ast.Index(value=slice_value), ctx=ast.Load() value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
) )
@ -630,7 +631,7 @@ def _generate_component_create_functiondef(
), ),
), ),
ast.Expr( ast.Expr(
value=ast.Ellipsis(), value=ast.Constant(...),
), ),
], ],
decorator_list=[ decorator_list=[
@ -641,8 +642,14 @@ def _generate_component_create_functiondef(
else [ast.Name(id="classmethod")] else [ast.Name(id="classmethod")]
), ),
], ],
lineno=node.lineno if node is not None else None,
returns=ast.Constant(value=clz.__name__), returns=ast.Constant(value=clz.__name__),
**(
{
"lineno": node.lineno,
}
if node is not None
else {}
),
) )
return definition return definition
@ -690,13 +697,19 @@ def _generate_staticmethod_call_functiondef(
), ),
], ],
decorator_list=[ast.Name(id="staticmethod")], decorator_list=[ast.Name(id="staticmethod")],
lineno=node.lineno if node is not None else None,
returns=ast.Constant( returns=ast.Constant(
value=_get_type_hint( value=_get_type_hint(
typing.get_type_hints(clz.__call__).get("return", None), typing.get_type_hints(clz.__call__).get("return", None),
type_hint_globals, type_hint_globals,
) )
), ),
**(
{
"lineno": node.lineno,
}
if node is not None
else {}
),
) )
return definition return definition
@ -731,7 +744,12 @@ def _generate_namespace_call_functiondef(
# Determine which class is wrapped by the namespace __call__ method # Determine which class is wrapped by the namespace __call__ method
component_clz = clz.__call__.__self__ component_clz = clz.__call__.__self__
if clz.__call__.__func__.__name__ != "create": func = getattr(clz.__call__, "__func__", None)
if func is None:
raise TypeError(f"__call__ method on {clz_name} does not have a __func__")
if func.__name__ != "create":
return None return None
definition = _generate_component_create_functiondef( definition = _generate_component_create_functiondef(
@ -914,7 +932,7 @@ class StubGenerator(ast.NodeTransformer):
node.body.append(call_definition) node.body.append(call_definition)
if not node.body: if not node.body:
# We should never return an empty body. # We should never return an empty body.
node.body.append(ast.Expr(value=ast.Ellipsis())) node.body.append(ast.Expr(value=ast.Constant(...)))
self.current_class = None self.current_class = None
return node return node
@ -941,9 +959,9 @@ class StubGenerator(ast.NodeTransformer):
if node.name.startswith("_") and node.name != "__call__": if node.name.startswith("_") and node.name != "__call__":
return None # remove private methods return None # remove private methods
if node.body[-1] != ast.Expr(value=ast.Ellipsis()): if node.body[-1] != ast.Expr(value=ast.Constant(...)):
# Blank out the function body for public functions. # Blank out the function body for public functions.
node.body = [ast.Expr(value=ast.Ellipsis())] node.body = [ast.Expr(value=ast.Constant(...))]
return node return node
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None: def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:

View File

@ -69,21 +69,21 @@ else:
# Potential GenericAlias types for isinstance checks. # Potential GenericAlias types for isinstance checks.
GenericAliasTypes = [_GenericAlias] _GenericAliasTypes: list[type] = [_GenericAlias]
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
# For newer versions of Python. # For newer versions of Python.
from types import GenericAlias # type: ignore from types import GenericAlias # type: ignore
GenericAliasTypes.append(GenericAlias) _GenericAliasTypes.append(GenericAlias)
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
# For older versions of Python. # For older versions of Python.
from typing import _SpecialGenericAlias # type: ignore from typing import _SpecialGenericAlias # type: ignore
GenericAliasTypes.append(_SpecialGenericAlias) _GenericAliasTypes.append(_SpecialGenericAlias)
GenericAliasTypes = tuple(GenericAliasTypes) GenericAliasTypes = tuple(_GenericAliasTypes)
# Potential Union types for isinstance checks (UnionType added in py3.10). # Potential Union types for isinstance checks (UnionType added in py3.10).
UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,) UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,)
@ -181,7 +181,7 @@ def is_generic_alias(cls: GenericType) -> bool:
return isinstance(cls, GenericAliasTypes) return isinstance(cls, GenericAliasTypes)
def unionize(*args: GenericType) -> Type: def unionize(*args: GenericType) -> GenericType:
"""Unionize the types. """Unionize the types.
Args: Args:
@ -415,7 +415,7 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None
@lru_cache() @lru_cache()
def get_base_class(cls: GenericType) -> Type: def get_base_class(cls: GenericType) -> Type | tuple[Type, ...]:
"""Get the base class of a class. """Get the base class of a class.
Args: Args:
@ -435,7 +435,14 @@ def get_base_class(cls: GenericType) -> Type:
return type(get_args(cls)[0]) return type(get_args(cls)[0])
if is_union(cls): if is_union(cls):
return tuple(get_base_class(arg) for arg in get_args(cls)) base_classes = []
for arg in get_args(cls):
sub_base_classes = get_base_class(arg)
if isinstance(sub_base_classes, tuple):
base_classes.extend(sub_base_classes)
else:
base_classes.append(sub_base_classes)
return tuple(base_classes)
return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls return get_base_class(cls.__origin__) if is_generic_alias(cls) else cls

View File

@ -15,6 +15,7 @@ from typing import (
Sequence, Sequence,
TypeVar, TypeVar,
Union, Union,
cast,
overload, overload,
) )
@ -1102,7 +1103,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
_cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field( _cases: tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...] = dataclasses.field(
default_factory=tuple default_factory=tuple
) )
_default: Var[VAR_TYPE] = dataclasses.field( _default: Var[VAR_TYPE] = dataclasses.field( # pyright: ignore[reportAssignmentType]
default_factory=lambda: Var.create(None) default_factory=lambda: Var.create(None)
) )
@ -1170,11 +1171,14 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
The match operation. The match operation.
""" """
cond = Var.create(cond) cond = Var.create(cond)
cases = tuple(tuple(Var.create(c) for c in case) for case in cases) cases = cast(
default = Var.create(default) tuple[TUPLE_ENDS_IN_VAR[VAR_TYPE], ...],
tuple(tuple(Var.create(c) for c in case) for case in cases),
)
_default = cast(Var[VAR_TYPE], Var.create(default))
var_type = _var_type or unionize( var_type = _var_type or unionize(
*(case[-1]._var_type for case in cases), *(case[-1]._var_type for case in cases),
default._var_type, _default._var_type,
) )
return cls( return cls(
_js_expr="", _js_expr="",
@ -1182,7 +1186,7 @@ class MatchOperation(CachedVarOperation, Var[VAR_TYPE]):
_var_type=var_type, _var_type=var_type,
_cond=cond, _cond=cond,
_cases=cases, _cases=cases,
_default=default, _default=_default,
) )