pyi_generator: Generate stubs for SimpleNamespace classes

If the namespace assigns `__call__` to an existing component `create` function,
generate args and docstring for IDE integration.
This commit is contained in:
Masen Furer 2024-01-29 17:34:18 -08:00
parent bece5bdb44
commit 763c1c1f07
No known key found for this signature in database
GPG Key ID: B0008AD22B3B3A95

View File

@ -13,7 +13,7 @@ import typing
from inspect import getfullargspec from inspect import getfullargspec
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType, SimpleNamespace
from typing import Any, Callable, Iterable, Type, get_args from typing import Any, Callable, Iterable, Type, get_args
import black import black
@ -94,7 +94,9 @@ def _relative_to_pwd(path: Path) -> Path:
Returns: Returns:
The relative path. The relative path.
""" """
return path.relative_to(PWD) if path.is_absolute():
return path.relative_to(PWD)
return path
def _git_diff(args: list[str]) -> str: def _git_diff(args: list[str]) -> str:
@ -403,7 +405,7 @@ def _get_parent_imports(func):
def _generate_component_create_functiondef( def _generate_component_create_functiondef(
node: ast.FunctionDef | None, node: ast.FunctionDef | None,
clz: type[Component], clz: type[Component] | type[SimpleNamespace],
type_hint_globals: dict[str, Any], type_hint_globals: dict[str, Any],
) -> ast.FunctionDef: ) -> ast.FunctionDef:
"""Generate the create function definition for a Component. """Generate the create function definition for a Component.
@ -415,7 +417,13 @@ def _generate_component_create_functiondef(
Returns: Returns:
The create functiondef node for the ast. The create functiondef node for the ast.
Raises:
TypeError: If clz is not a subclass of Component.
""" """
if not issubclass(clz, Component):
raise TypeError(f"clz must be a subclass of Component, not {clz!r}")
# add the imports needed by get_type_hint later # add the imports needed by get_type_hint later
type_hint_globals.update( type_hint_globals.update(
{name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS} {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
@ -484,10 +492,58 @@ def _generate_component_create_functiondef(
return definition return definition
def _generate_namespace_call_functiondef(
clz_name: str,
classes: dict[str, type[Component] | type[SimpleNamespace]],
type_hint_globals: dict[str, Any],
) -> ast.FunctionDef | None:
"""Generate the __call__ function definition for a SimpleNamespace.
Args:
clz_name: The name of the SimpleNamespace class to generate the __call__ functiondef for.
classes: Map name to actual class definition.
type_hint_globals: The globals to use to resolving a type hint str.
Returns:
The create functiondef node for the ast.
"""
# add the imports needed by get_type_hint later
type_hint_globals.update(
{name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
)
clz = classes[clz_name]
# Determine which class is wrapped by the namespace __call__ method
component_class_name, dot, func_name = clz.__call__.__func__.__qualname__.partition(
"."
)
component_clz = classes[component_class_name]
# Only generate for create functions
if func_name != "create":
return None
definition = _generate_component_create_functiondef(
node=None,
clz=component_clz,
type_hint_globals=type_hint_globals,
)
definition.name = "__call__"
# Turn the definition into a staticmethod
del definition.args.args[0] # remove `cls` arg
definition.decorator_list = [ast.Name(id="staticmethod")]
return definition
class StubGenerator(ast.NodeTransformer): class StubGenerator(ast.NodeTransformer):
"""A node transformer that will generate the stubs for a given module.""" """A node transformer that will generate the stubs for a given module."""
def __init__(self, module: ModuleType, classes: dict[str, Type[Component]]): def __init__(
self, module: ModuleType, classes: dict[str, Type[Component | SimpleNamespace]]
):
"""Initialize the stub generator. """Initialize the stub generator.
Args: Args:
@ -528,6 +584,18 @@ class StubGenerator(ast.NodeTransformer):
node.body.pop(0) node.body.pop(0)
return node return node
def _current_class_is_component(self) -> bool:
"""Check if the current class is a Component.
Returns:
Whether the current class is a Component.
"""
return (
self.current_class is not None
and self.current_class in self.classes
and issubclass(self.classes[self.current_class], Component)
)
def visit_Module(self, node: ast.Module) -> ast.Module: def visit_Module(self, node: ast.Module) -> ast.Module:
"""Visit a Module node and remove docstring from body. """Visit a Module node and remove docstring from body.
@ -591,6 +659,27 @@ class StubGenerator(ast.NodeTransformer):
exec("\n".join(self.import_statements), self.type_hint_globals) exec("\n".join(self.import_statements), self.type_hint_globals)
self.current_class = node.name self.current_class = node.name
self._remove_docstring(node) self._remove_docstring(node)
# Define `__call__` as a real function so the docstring appears in the stub.
call_definition = None
for child in node.body[:]:
found_call = False
if isinstance(child, ast.Assign):
for target in child.targets[:]:
if isinstance(target, ast.Name) and target.id == "__call__":
child.targets.remove(target)
found_call = True
if not found_call:
continue
if not child.targets[:]:
node.body.remove(child)
call_definition = _generate_namespace_call_functiondef(
self.current_class,
self.classes,
type_hint_globals=self.type_hint_globals,
)
break
self.generic_visit(node) # Visit child nodes. self.generic_visit(node) # Visit child nodes.
if ( if (
@ -598,7 +687,7 @@ class StubGenerator(ast.NodeTransformer):
isinstance(child, ast.FunctionDef) and child.name == "create" isinstance(child, ast.FunctionDef) and child.name == "create"
for child in node.body for child in node.body
) )
and self.current_class in self.classes and self._current_class_is_component()
): ):
# Add a new .create FunctionDef since one does not exist. # Add a new .create FunctionDef since one does not exist.
node.body.append( node.body.append(
@ -608,6 +697,8 @@ class StubGenerator(ast.NodeTransformer):
type_hint_globals=self.type_hint_globals, type_hint_globals=self.type_hint_globals,
) )
) )
if call_definition is not None:
node.body.append(call_definition)
if not node.body: if not node.body:
# We should never return an empty body. # We should never return an empty body.
node.body.append(ast.Expr(value=ast.Ellipsis())) node.body.append(ast.Expr(value=ast.Ellipsis()))
@ -634,11 +725,12 @@ class StubGenerator(ast.NodeTransformer):
node, self.classes[self.current_class], self.type_hint_globals node, self.classes[self.current_class], self.type_hint_globals
) )
else: else:
if node.name.startswith("_"): if node.name.startswith("_") and node.name != "__call__":
return None # remove private methods return None # remove private methods
# Blank out the function body for public functions. if node.body[-1] != ast.Expr(value=ast.Ellipsis()):
node.body = [ast.Expr(value=ast.Ellipsis())] # Blank out the function body for public functions.
node.body = [ast.Expr(value=ast.Ellipsis())]
return node return node
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None: def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
@ -657,9 +749,11 @@ class StubGenerator(ast.NodeTransformer):
and node.value.id == "Any" and node.value.id == "Any"
): ):
return node return node
if self.current_class in self.classes:
if self._current_class_is_component():
# Remove annotated assignments in Component classes (props) # Remove annotated assignments in Component classes (props)
return None return None
return node return node
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None: def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
@ -738,7 +832,7 @@ class PyiGenerator:
name: obj name: obj
for name, obj in vars(module).items() for name, obj in vars(module).items()
if inspect.isclass(obj) if inspect.isclass(obj)
and issubclass(obj, Component) and (issubclass(obj, Component) or issubclass(obj, SimpleNamespace))
and obj != Component and obj != Component
and inspect.getmodule(obj) == module and inspect.getmodule(obj) == module
} }