From 763c1c1f078c6561690caf856a711efb04e1f881 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 29 Jan 2024 17:34:18 -0800 Subject: [PATCH] pyi_generator: Generate stubs for `SimpleNamespace` classes If the namespace assigns `__call__` to an existing component `create` function, generate args and docstring for IDE integration. --- scripts/pyi_generator.py | 114 +++++++++++++++++++++++++++++++++++---- 1 file changed, 104 insertions(+), 10 deletions(-) diff --git a/scripts/pyi_generator.py b/scripts/pyi_generator.py index a6f55281d..d3b7d2e0a 100644 --- a/scripts/pyi_generator.py +++ b/scripts/pyi_generator.py @@ -13,7 +13,7 @@ import typing from inspect import getfullargspec from multiprocessing import Pool, cpu_count from pathlib import Path -from types import ModuleType +from types import ModuleType, SimpleNamespace from typing import Any, Callable, Iterable, Type, get_args import black @@ -94,7 +94,9 @@ def _relative_to_pwd(path: Path) -> Path: Returns: The relative path. """ - return path.relative_to(PWD) + if path.is_absolute(): + return path.relative_to(PWD) + return path def _git_diff(args: list[str]) -> str: @@ -403,7 +405,7 @@ def _get_parent_imports(func): def _generate_component_create_functiondef( node: ast.FunctionDef | None, - clz: type[Component], + clz: type[Component] | type[SimpleNamespace], type_hint_globals: dict[str, Any], ) -> ast.FunctionDef: """Generate the create function definition for a Component. @@ -415,7 +417,13 @@ def _generate_component_create_functiondef( Returns: The create functiondef node for the ast. + + Raises: + TypeError: If clz is not a subclass of Component. """ + if not issubclass(clz, Component): + raise TypeError(f"clz must be a subclass of Component, not {clz!r}") + # add the imports needed by get_type_hint later type_hint_globals.update( {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS} @@ -484,10 +492,58 @@ def _generate_component_create_functiondef( return definition +def _generate_namespace_call_functiondef( + clz_name: str, + classes: dict[str, type[Component] | type[SimpleNamespace]], + type_hint_globals: dict[str, Any], +) -> ast.FunctionDef | None: + """Generate the __call__ function definition for a SimpleNamespace. + + Args: + clz_name: The name of the SimpleNamespace class to generate the __call__ functiondef for. + classes: Map name to actual class definition. + type_hint_globals: The globals to use to resolving a type hint str. + + Returns: + The create functiondef node for the ast. + """ + # add the imports needed by get_type_hint later + type_hint_globals.update( + {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS} + ) + + clz = classes[clz_name] + + # Determine which class is wrapped by the namespace __call__ method + component_class_name, dot, func_name = clz.__call__.__func__.__qualname__.partition( + "." + ) + component_clz = classes[component_class_name] + + # Only generate for create functions + if func_name != "create": + return None + + definition = _generate_component_create_functiondef( + node=None, + clz=component_clz, + type_hint_globals=type_hint_globals, + ) + definition.name = "__call__" + + # Turn the definition into a staticmethod + del definition.args.args[0] # remove `cls` arg + definition.decorator_list = [ast.Name(id="staticmethod")] + + return definition + + class StubGenerator(ast.NodeTransformer): """A node transformer that will generate the stubs for a given module.""" - def __init__(self, module: ModuleType, classes: dict[str, Type[Component]]): + def __init__( + self, module: ModuleType, classes: dict[str, Type[Component | SimpleNamespace]] + ): """Initialize the stub generator. Args: @@ -528,6 +584,18 @@ class StubGenerator(ast.NodeTransformer): node.body.pop(0) return node + def _current_class_is_component(self) -> bool: + """Check if the current class is a Component. + + Returns: + Whether the current class is a Component. + """ + return ( + self.current_class is not None + and self.current_class in self.classes + and issubclass(self.classes[self.current_class], Component) + ) + def visit_Module(self, node: ast.Module) -> ast.Module: """Visit a Module node and remove docstring from body. @@ -591,6 +659,27 @@ class StubGenerator(ast.NodeTransformer): exec("\n".join(self.import_statements), self.type_hint_globals) self.current_class = node.name self._remove_docstring(node) + + # Define `__call__` as a real function so the docstring appears in the stub. + call_definition = None + for child in node.body[:]: + found_call = False + if isinstance(child, ast.Assign): + for target in child.targets[:]: + if isinstance(target, ast.Name) and target.id == "__call__": + child.targets.remove(target) + found_call = True + if not found_call: + continue + if not child.targets[:]: + node.body.remove(child) + call_definition = _generate_namespace_call_functiondef( + self.current_class, + self.classes, + type_hint_globals=self.type_hint_globals, + ) + break + self.generic_visit(node) # Visit child nodes. if ( @@ -598,7 +687,7 @@ class StubGenerator(ast.NodeTransformer): isinstance(child, ast.FunctionDef) and child.name == "create" for child in node.body ) - and self.current_class in self.classes + and self._current_class_is_component() ): # Add a new .create FunctionDef since one does not exist. node.body.append( @@ -608,6 +697,8 @@ class StubGenerator(ast.NodeTransformer): type_hint_globals=self.type_hint_globals, ) ) + if call_definition is not None: + node.body.append(call_definition) if not node.body: # We should never return an empty body. node.body.append(ast.Expr(value=ast.Ellipsis())) @@ -634,11 +725,12 @@ class StubGenerator(ast.NodeTransformer): node, self.classes[self.current_class], self.type_hint_globals ) else: - if node.name.startswith("_"): + if node.name.startswith("_") and node.name != "__call__": return None # remove private methods - # Blank out the function body for public functions. - node.body = [ast.Expr(value=ast.Ellipsis())] + if node.body[-1] != ast.Expr(value=ast.Ellipsis()): + # Blank out the function body for public functions. + node.body = [ast.Expr(value=ast.Ellipsis())] return node def visit_Assign(self, node: ast.Assign) -> ast.Assign | None: @@ -657,9 +749,11 @@ class StubGenerator(ast.NodeTransformer): and node.value.id == "Any" ): return node - if self.current_class in self.classes: + + if self._current_class_is_component(): # Remove annotated assignments in Component classes (props) return None + return node def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None: @@ -738,7 +832,7 @@ class PyiGenerator: name: obj for name, obj in vars(module).items() if inspect.isclass(obj) - and issubclass(obj, Component) + and (issubclass(obj, Component) or issubclass(obj, SimpleNamespace)) and obj != Component and inspect.getmodule(obj) == module }