Refactor pyi_generator to use AST directly (#2034)
This commit is contained in:
parent
fe01f0cf11
commit
5d590c350e
@ -1,158 +1,27 @@
|
||||
"""The pyi generator module."""
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
from inspect import getfullargspec
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union, get_args # NOQA
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Iterable, Type, get_args
|
||||
|
||||
import black
|
||||
import black.mode
|
||||
|
||||
from reflex.components.component import Component
|
||||
|
||||
# NOQA
|
||||
from reflex.components.graphing.recharts.recharts import (
|
||||
LiteralAnimationEasing,
|
||||
LiteralAreaType,
|
||||
LiteralComposedChartBaseValue,
|
||||
LiteralDirection,
|
||||
LiteralGridType,
|
||||
LiteralIconType,
|
||||
LiteralIfOverflow,
|
||||
LiteralInterval,
|
||||
LiteralLayout,
|
||||
LiteralLegendAlign,
|
||||
LiteralLineType,
|
||||
LiteralOrientationTopBottom,
|
||||
LiteralOrientationTopBottomLeftRight,
|
||||
LiteralPolarRadiusType,
|
||||
LiteralPosition,
|
||||
LiteralScale,
|
||||
LiteralShape,
|
||||
LiteralStackOffset,
|
||||
LiteralSyncMethod,
|
||||
LiteralVerticalAlign,
|
||||
)
|
||||
from reflex.components.libs.chakra import (
|
||||
LiteralAlertDialogSize,
|
||||
LiteralAvatarSize,
|
||||
LiteralChakraDirection,
|
||||
LiteralColorScheme,
|
||||
LiteralDrawerSize,
|
||||
LiteralImageLoading,
|
||||
LiteralInputVariant,
|
||||
LiteralMenuOption,
|
||||
LiteralMenuStrategy,
|
||||
LiteralTagSize,
|
||||
)
|
||||
from reflex.components.radix.themes.base import (
|
||||
LiteralAccentColor,
|
||||
LiteralAlign,
|
||||
LiteralAppearance,
|
||||
LiteralGrayColor,
|
||||
LiteralJustify,
|
||||
LiteralPanelBackground,
|
||||
LiteralRadius,
|
||||
LiteralScaling,
|
||||
LiteralSize,
|
||||
LiteralVariant,
|
||||
)
|
||||
from reflex.components.radix.themes.components import (
|
||||
LiteralButtonSize,
|
||||
LiteralSwitchSize,
|
||||
)
|
||||
from reflex.components.radix.themes.layout import (
|
||||
LiteralBoolNumber,
|
||||
LiteralContainerSize,
|
||||
LiteralFlexDirection,
|
||||
LiteralFlexDisplay,
|
||||
LiteralFlexWrap,
|
||||
LiteralGridDisplay,
|
||||
LiteralGridFlow,
|
||||
LiteralSectionSize,
|
||||
)
|
||||
from reflex.components.radix.themes.typography import (
|
||||
LiteralLinkUnderline,
|
||||
LiteralTextAlign,
|
||||
LiteralTextSize,
|
||||
LiteralTextTrim,
|
||||
LiteralTextWeight,
|
||||
)
|
||||
|
||||
# NOQA
|
||||
from reflex.event import EventChain
|
||||
from reflex.style import Style
|
||||
from reflex.utils import format
|
||||
from reflex.utils import types as rx_types
|
||||
from reflex.vars import Var
|
||||
|
||||
ruff_dont_remove = [
|
||||
Var,
|
||||
Optional,
|
||||
Dict,
|
||||
List,
|
||||
EventChain,
|
||||
Style,
|
||||
LiteralInputVariant,
|
||||
LiteralColorScheme,
|
||||
LiteralChakraDirection,
|
||||
LiteralTagSize,
|
||||
LiteralDrawerSize,
|
||||
LiteralMenuStrategy,
|
||||
LiteralMenuOption,
|
||||
LiteralAlertDialogSize,
|
||||
LiteralAvatarSize,
|
||||
LiteralImageLoading,
|
||||
LiteralLayout,
|
||||
LiteralAnimationEasing,
|
||||
LiteralGridType,
|
||||
LiteralPolarRadiusType,
|
||||
LiteralScale,
|
||||
LiteralSyncMethod,
|
||||
LiteralStackOffset,
|
||||
LiteralComposedChartBaseValue,
|
||||
LiteralOrientationTopBottom,
|
||||
LiteralAreaType,
|
||||
LiteralShape,
|
||||
LiteralLineType,
|
||||
LiteralDirection,
|
||||
LiteralIfOverflow,
|
||||
LiteralOrientationTopBottomLeftRight,
|
||||
LiteralInterval,
|
||||
LiteralLegendAlign,
|
||||
LiteralVerticalAlign,
|
||||
LiteralIconType,
|
||||
LiteralPosition,
|
||||
LiteralAccentColor,
|
||||
LiteralAlign,
|
||||
LiteralAppearance,
|
||||
LiteralBoolNumber,
|
||||
LiteralButtonSize,
|
||||
LiteralContainerSize,
|
||||
LiteralFlexDirection,
|
||||
LiteralFlexDisplay,
|
||||
LiteralFlexWrap,
|
||||
LiteralGrayColor,
|
||||
LiteralGridDisplay,
|
||||
LiteralGridFlow,
|
||||
LiteralJustify,
|
||||
LiteralLinkUnderline,
|
||||
LiteralPanelBackground,
|
||||
LiteralRadius,
|
||||
LiteralScaling,
|
||||
LiteralSectionSize,
|
||||
LiteralSize,
|
||||
LiteralSwitchSize,
|
||||
LiteralTextAlign,
|
||||
LiteralTextSize,
|
||||
LiteralTextTrim,
|
||||
LiteralTextWeight,
|
||||
LiteralVariant,
|
||||
]
|
||||
logger = logging.getLogger("pyi_generator")
|
||||
|
||||
EXCLUDED_FILES = [
|
||||
"__init__.py",
|
||||
@ -177,18 +46,36 @@ EXCLUDED_PROPS = [
|
||||
"valid_children",
|
||||
]
|
||||
|
||||
DEFAULT_TYPING_IMPORTS = {"overload", "Any", "Dict", "List", "Optional", "Union"}
|
||||
DEFAULT_TYPING_IMPORTS = {
|
||||
"overload",
|
||||
"Any",
|
||||
"Dict",
|
||||
"List",
|
||||
"Literal",
|
||||
"Optional",
|
||||
"Union",
|
||||
}
|
||||
|
||||
|
||||
def _get_type_hint(value, top_level=True, no_union=False):
|
||||
def _get_type_hint(value, type_hint_globals, is_optional=True) -> str:
|
||||
"""Resolve the type hint for value.
|
||||
|
||||
Args:
|
||||
value: The type annotation as a str or actual types/aliases.
|
||||
type_hint_globals: The globals to use to resolving a type hint str.
|
||||
is_optional: Whether the type hint should be wrapped in Optional.
|
||||
|
||||
Returns:
|
||||
The resolved type hint as a str.
|
||||
"""
|
||||
res = ""
|
||||
args = get_args(value)
|
||||
if args:
|
||||
inner_container_type_args = (
|
||||
[format.wrap(arg, '"') for arg in args]
|
||||
[repr(arg) for arg in args]
|
||||
if rx_types.is_literal(value)
|
||||
else [
|
||||
_get_type_hint(arg, top_level=False)
|
||||
_get_type_hint(arg, type_hint_globals, is_optional=False)
|
||||
for arg in args
|
||||
if arg is not type(None)
|
||||
]
|
||||
@ -196,42 +83,454 @@ def _get_type_hint(value, top_level=True, no_union=False):
|
||||
res = f"{value.__name__}[{', '.join(inner_container_type_args)}]"
|
||||
|
||||
if value.__name__ == "Var":
|
||||
# For Var types, Union with the inner args so they can be passed directly.
|
||||
types = [res] + [
|
||||
_get_type_hint(arg, top_level=False)
|
||||
_get_type_hint(arg, type_hint_globals, is_optional=False)
|
||||
for arg in args
|
||||
if arg is not type(None)
|
||||
]
|
||||
if len(types) > 1 and not no_union:
|
||||
if len(types) > 1:
|
||||
res = ", ".join(types)
|
||||
res = f"Union[{res}]"
|
||||
elif isinstance(value, str):
|
||||
ev = eval(value)
|
||||
res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
|
||||
ev = eval(value, type_hint_globals)
|
||||
res = (
|
||||
_get_type_hint(ev, type_hint_globals, is_optional=False)
|
||||
if ev.__name__ == "Var"
|
||||
else value
|
||||
)
|
||||
else:
|
||||
res = value.__name__
|
||||
if top_level and not res.startswith("Optional"):
|
||||
if is_optional and not res.startswith("Optional"):
|
||||
res = f"Optional[{res}]"
|
||||
return res
|
||||
|
||||
|
||||
def _get_typing_import(_module):
|
||||
src = [
|
||||
line
|
||||
for line in inspect.getsource(_module).split("\n")
|
||||
if line.startswith("from typing")
|
||||
def _generate_imports(typing_imports: Iterable[str]) -> list[ast.ImportFrom]:
|
||||
"""Generate the import statements for the stub file.
|
||||
|
||||
Args:
|
||||
typing_imports: The typing imports to include.
|
||||
|
||||
Returns:
|
||||
The list of import statements.
|
||||
"""
|
||||
return [
|
||||
ast.ImportFrom(
|
||||
module="typing", names=[ast.alias(name=imp) for imp in typing_imports]
|
||||
),
|
||||
*ast.parse( # type: ignore
|
||||
textwrap.dedent(
|
||||
"""
|
||||
from reflex.vars import Var, BaseVar, ComputedVar
|
||||
from reflex.event import EventChain, EventHandler, EventSpec
|
||||
from reflex.style import Style"""
|
||||
)
|
||||
).body,
|
||||
]
|
||||
if len(src):
|
||||
return set(src[0].rpartition("from typing import ")[-1].split(", "))
|
||||
return set()
|
||||
|
||||
|
||||
def _get_var_definition(_module, _var_name):
|
||||
for node in ast.parse(inspect.getsource(_module)).body:
|
||||
if isinstance(node, ast.Assign) and _var_name in [
|
||||
t.id for t in node.targets if isinstance(t, ast.Name)
|
||||
]:
|
||||
return ast.unparse(node)
|
||||
raise Exception(f"Could not find var {_var_name} in module {_module}")
|
||||
def _generate_docstrings(clzs: list[Type[Component]], props: list[str]) -> str:
|
||||
"""Generate the docstrings for the create method.
|
||||
|
||||
Args:
|
||||
clzs: The classes to generate docstrings for.
|
||||
props: The props to generate docstrings for.
|
||||
|
||||
Returns:
|
||||
The docstring for the create method.
|
||||
"""
|
||||
props_comments = {}
|
||||
comments = []
|
||||
for clz in clzs:
|
||||
for line in inspect.getsource(clz).splitlines():
|
||||
reached_functions = re.search("def ", line)
|
||||
if reached_functions:
|
||||
# We've reached the functions, so stop.
|
||||
break
|
||||
|
||||
# Get comments for prop
|
||||
if line.strip().startswith("#"):
|
||||
comments.append(line)
|
||||
continue
|
||||
|
||||
# Check if this line has a prop.
|
||||
match = re.search("\\w+:", line)
|
||||
if match is None:
|
||||
# This line doesn't have a var, so continue.
|
||||
continue
|
||||
|
||||
# Get the prop.
|
||||
prop = match.group(0).strip(":")
|
||||
if prop in props:
|
||||
if not comments: # do not include undocumented props
|
||||
continue
|
||||
props_comments[prop] = [
|
||||
comment.strip().strip("#") for comment in comments
|
||||
]
|
||||
comments.clear()
|
||||
clz = clzs[0]
|
||||
new_docstring = []
|
||||
for line in (clz.create.__doc__ or "").splitlines():
|
||||
if "**" in line:
|
||||
indent = line.split("**")[0]
|
||||
for nline in [
|
||||
f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items()
|
||||
]:
|
||||
new_docstring.append(nline)
|
||||
new_docstring.append(line)
|
||||
return "\n".join(new_docstring)
|
||||
|
||||
|
||||
def _extract_func_kwargs_as_ast_nodes(
|
||||
func: Callable,
|
||||
type_hint_globals: dict[str, Any],
|
||||
) -> list[tuple[ast.arg, ast.Constant | None]]:
|
||||
"""Get the kwargs already defined on the function.
|
||||
|
||||
Args:
|
||||
func: The function to extract kwargs from.
|
||||
type_hint_globals: The globals to use to resolving a type hint str.
|
||||
|
||||
Returns:
|
||||
The list of kwargs as ast arg nodes.
|
||||
"""
|
||||
spec = getfullargspec(func)
|
||||
kwargs = []
|
||||
|
||||
for kwarg in spec.kwonlyargs:
|
||||
arg = ast.arg(arg=kwarg)
|
||||
if kwarg in spec.annotations:
|
||||
arg.annotation = ast.Name(
|
||||
id=_get_type_hint(spec.annotations[kwarg], type_hint_globals)
|
||||
)
|
||||
default = None
|
||||
if spec.kwonlydefaults is not None and kwarg in spec.kwonlydefaults:
|
||||
default = ast.Constant(value=spec.kwonlydefaults[kwarg])
|
||||
kwargs.append((arg, default))
|
||||
return kwargs
|
||||
|
||||
|
||||
def _extract_class_props_as_ast_nodes(
|
||||
func: Callable,
|
||||
clzs: list[Type],
|
||||
type_hint_globals: dict[str, Any],
|
||||
extract_real_default: bool = False,
|
||||
) -> list[tuple[ast.arg, ast.Constant | None]]:
|
||||
"""Get the props defined on the class and all parents.
|
||||
|
||||
Args:
|
||||
func: The function that kwargs will be added to.
|
||||
clzs: The classes to extract props from.
|
||||
type_hint_globals: The globals to use to resolving a type hint str.
|
||||
extract_real_default: Whether to extract the real default value from the
|
||||
pydantic field definition.
|
||||
|
||||
Returns:
|
||||
The list of props as ast arg nodes
|
||||
"""
|
||||
spec = getfullargspec(func)
|
||||
all_props = []
|
||||
kwargs = []
|
||||
for target_class in clzs:
|
||||
# Import from the target class to ensure type hints are resolvable.
|
||||
exec(f"from {target_class.__module__} import *", type_hint_globals)
|
||||
for name, value in target_class.__annotations__.items():
|
||||
if name in spec.kwonlyargs or name in EXCLUDED_PROPS or name in all_props:
|
||||
continue
|
||||
all_props.append(name)
|
||||
|
||||
default = None
|
||||
if extract_real_default:
|
||||
# TODO: This is not currently working since the default is not type compatible
|
||||
# with the annotation in some cases.
|
||||
with contextlib.suppress(AttributeError, KeyError):
|
||||
# Try to get default from pydantic field definition.
|
||||
default = target_class.__fields__[name].default
|
||||
if isinstance(default, Var):
|
||||
default = default._decode() # type: ignore
|
||||
|
||||
kwargs.append(
|
||||
(
|
||||
ast.arg(
|
||||
arg=name,
|
||||
annotation=ast.Name(
|
||||
id=_get_type_hint(value, type_hint_globals)
|
||||
),
|
||||
),
|
||||
ast.Constant(value=default),
|
||||
)
|
||||
)
|
||||
return kwargs
|
||||
|
||||
|
||||
def _generate_component_create_functiondef(
|
||||
node: ast.FunctionDef | None,
|
||||
clz: type[Component],
|
||||
type_hint_globals: dict[str, Any],
|
||||
) -> ast.FunctionDef:
|
||||
"""Generate the create function definition for a Component.
|
||||
|
||||
Args:
|
||||
node: The existing create functiondef node from the ast
|
||||
clz: The Component class to generate the create functiondef for.
|
||||
type_hint_globals: The globals to use to resolving a type hint str.
|
||||
|
||||
Returns:
|
||||
The create functiondef node for the ast.
|
||||
"""
|
||||
# kwargs defined on the actual create function
|
||||
kwargs = _extract_func_kwargs_as_ast_nodes(clz.create, type_hint_globals)
|
||||
|
||||
# kwargs associated with props defined in the class and its parents
|
||||
all_classes = [c for c in clz.__mro__ if issubclass(c, Component)]
|
||||
prop_kwargs = _extract_class_props_as_ast_nodes(
|
||||
clz.create, all_classes, type_hint_globals
|
||||
)
|
||||
all_props = [arg[0].arg for arg in prop_kwargs]
|
||||
kwargs.extend(prop_kwargs)
|
||||
|
||||
# event handler kwargs
|
||||
kwargs.extend(
|
||||
(
|
||||
ast.arg(
|
||||
arg=trigger,
|
||||
annotation=ast.Name(
|
||||
id="Optional[Union[EventHandler, EventSpec, List, function, BaseVar]]"
|
||||
),
|
||||
),
|
||||
ast.Constant(value=None),
|
||||
)
|
||||
for trigger in sorted(clz().get_event_triggers().keys())
|
||||
)
|
||||
logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
|
||||
create_args = ast.arguments(
|
||||
args=[ast.arg(arg="cls")],
|
||||
posonlyargs=[],
|
||||
vararg=ast.arg(arg="children"),
|
||||
kwonlyargs=[arg[0] for arg in kwargs],
|
||||
kw_defaults=[arg[1] for arg in kwargs],
|
||||
kwarg=ast.arg(arg="props"),
|
||||
defaults=[],
|
||||
)
|
||||
definition = ast.FunctionDef(
|
||||
name="create",
|
||||
args=create_args,
|
||||
body=[
|
||||
ast.Expr(
|
||||
value=ast.Constant(value=_generate_docstrings(all_classes, all_props))
|
||||
),
|
||||
ast.Expr(
|
||||
value=ast.Ellipsis(),
|
||||
),
|
||||
],
|
||||
decorator_list=[
|
||||
ast.Name(id="overload"),
|
||||
*(
|
||||
node.decorator_list
|
||||
if node is not None
|
||||
else [ast.Name(id="classmethod")]
|
||||
),
|
||||
],
|
||||
lineno=node.lineno if node is not None else None,
|
||||
returns=ast.Constant(value=clz.__name__),
|
||||
)
|
||||
return definition
|
||||
|
||||
|
||||
class StubGenerator(ast.NodeTransformer):
|
||||
"""A node transformer that will generate the stubs for a given module."""
|
||||
|
||||
def __init__(self, module: ModuleType, classes: dict[str, Type[Component]]):
|
||||
"""Initialize the stub generator.
|
||||
|
||||
Args:
|
||||
module: The actual module object module to generate stubs for.
|
||||
classes: The actual Component class objects to generate stubs for.
|
||||
"""
|
||||
super().__init__()
|
||||
# Dict mapping class name to actual class object.
|
||||
self.classes = classes
|
||||
# Track the last class node that was visited.
|
||||
self.current_class = None
|
||||
# These imports will be included in the AST of stub files.
|
||||
self.typing_imports = DEFAULT_TYPING_IMPORTS
|
||||
# Whether those typing imports have been inserted yet.
|
||||
self.inserted_imports = False
|
||||
# Collected import statements from the module.
|
||||
self.import_statements: list[str] = []
|
||||
# This dict is used when evaluating type hints.
|
||||
self.type_hint_globals = module.__dict__.copy()
|
||||
|
||||
@staticmethod
|
||||
def _remove_docstring(
|
||||
node: ast.Module | ast.ClassDef | ast.FunctionDef,
|
||||
) -> ast.Module | ast.ClassDef | ast.FunctionDef:
|
||||
"""Removes any docstring in place.
|
||||
|
||||
Args:
|
||||
node: The node to remove the docstring from.
|
||||
|
||||
Returns:
|
||||
The modified node.
|
||||
"""
|
||||
if (
|
||||
node.body
|
||||
and isinstance(node.body[0], ast.Expr)
|
||||
and isinstance(node.body[0].value, ast.Constant)
|
||||
):
|
||||
node.body.pop(0)
|
||||
return node
|
||||
|
||||
def visit_Module(self, node: ast.Module) -> ast.Module:
|
||||
"""Visit a Module node and remove docstring from body.
|
||||
|
||||
Args:
|
||||
node: The Module node to visit.
|
||||
|
||||
Returns:
|
||||
The modified Module node.
|
||||
"""
|
||||
self.generic_visit(node)
|
||||
return self._remove_docstring(node) # type: ignore
|
||||
|
||||
def visit_Import(
|
||||
self, node: ast.Import | ast.ImportFrom
|
||||
) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom]:
|
||||
"""Collect import statements from the module.
|
||||
|
||||
If this is the first import statement, insert the typing imports before it.
|
||||
|
||||
Args:
|
||||
node: The import node to visit.
|
||||
|
||||
Returns:
|
||||
The modified import node(s).
|
||||
"""
|
||||
self.import_statements.append(ast.unparse(node))
|
||||
if not self.inserted_imports:
|
||||
self.inserted_imports = True
|
||||
return _generate_imports(self.typing_imports) + [node]
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(
|
||||
self, node: ast.ImportFrom
|
||||
) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom] | None:
|
||||
"""Visit an ImportFrom node.
|
||||
|
||||
Remove any `from __future__ import *` statements, and hand off to visit_Import.
|
||||
|
||||
Args:
|
||||
node: The ImportFrom node to visit.
|
||||
|
||||
Returns:
|
||||
The modified ImportFrom node.
|
||||
"""
|
||||
if node.module == "__future__":
|
||||
return None # ignore __future__ imports
|
||||
return self.visit_Import(node)
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||
"""Visit a ClassDef node.
|
||||
|
||||
Remove all assignments in the class body, and add a create functiondef
|
||||
if one does not exist.
|
||||
|
||||
Args:
|
||||
node: The ClassDef node to visit.
|
||||
|
||||
Returns:
|
||||
The modified ClassDef node.
|
||||
"""
|
||||
exec("\n".join(self.import_statements), self.type_hint_globals)
|
||||
self.current_class = node.name
|
||||
self._remove_docstring(node)
|
||||
self.generic_visit(node) # Visit child nodes.
|
||||
if not node.body:
|
||||
# We should never return an empty body.
|
||||
node.body.append(ast.Expr(value=ast.Ellipsis()))
|
||||
if (
|
||||
not any(
|
||||
isinstance(child, ast.FunctionDef) and child.name == "create"
|
||||
for child in node.body
|
||||
)
|
||||
and self.current_class in self.classes
|
||||
):
|
||||
# Add a new .create FunctionDef since one does not exist.
|
||||
node.body.append(
|
||||
_generate_component_create_functiondef(
|
||||
node=None,
|
||||
clz=self.classes[self.current_class],
|
||||
type_hint_globals=self.type_hint_globals,
|
||||
)
|
||||
)
|
||||
self.current_class = None
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||
"""Visit a FunctionDef node.
|
||||
|
||||
Special handling for `.create` functions to add type hints for all props
|
||||
defined on the component class.
|
||||
|
||||
Remove all private functions and blank out the function body of the
|
||||
remaining public functions.
|
||||
|
||||
Args:
|
||||
node: The FunctionDef node to visit.
|
||||
|
||||
Returns:
|
||||
The modified FunctionDef node (or None).
|
||||
"""
|
||||
if node.name == "create" and self.current_class in self.classes:
|
||||
node = _generate_component_create_functiondef(
|
||||
node, self.classes[self.current_class], self.type_hint_globals
|
||||
)
|
||||
else:
|
||||
if node.name.startswith("_"):
|
||||
return None # remove private methods
|
||||
|
||||
# Blank out the function body for public functions.
|
||||
node.body = [ast.Expr(value=ast.Ellipsis())]
|
||||
return node
|
||||
|
||||
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
|
||||
"""Remove non-annotated assignment statements.
|
||||
|
||||
Args:
|
||||
node: The Assign node to visit.
|
||||
|
||||
Returns:
|
||||
The modified Assign node (or None).
|
||||
"""
|
||||
# Special case for assignments to `typing.Any` as fallback.
|
||||
if (
|
||||
node.value is not None
|
||||
and isinstance(node.value, ast.Name)
|
||||
and node.value.id == "Any"
|
||||
):
|
||||
return node
|
||||
return None
|
||||
|
||||
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
|
||||
"""Visit an AnnAssign node (Annotated assignment).
|
||||
|
||||
Remove private target and remove the assignment value in the stub.
|
||||
|
||||
Args:
|
||||
node: The AnnAssign node to visit.
|
||||
|
||||
Returns:
|
||||
The modified AnnAssign node (or None).
|
||||
"""
|
||||
if isinstance(node.target, ast.Name) and node.target.id.startswith("_"):
|
||||
return None
|
||||
if self.current_class in self.classes:
|
||||
# Remove annotated assignments in Component classes (props)
|
||||
return None
|
||||
# Blank out assignments in type stubs.
|
||||
node.value = None
|
||||
return node
|
||||
|
||||
|
||||
class PyiGenerator:
|
||||
@ -244,215 +543,57 @@ class PyiGenerator:
|
||||
current_module: Any = {}
|
||||
default_typing_imports: set = DEFAULT_TYPING_IMPORTS
|
||||
|
||||
def _generate_imports(self, variables, classes):
|
||||
variables_imports = {
|
||||
type(_var) for _, _var in variables if isinstance(_var, Component)
|
||||
}
|
||||
bases = {
|
||||
base
|
||||
for _, _class in classes
|
||||
for base in _class.__bases__
|
||||
if inspect.getmodule(base) != self.current_module
|
||||
} | variables_imports
|
||||
bases.add(Component)
|
||||
typing_imports = self.default_typing_imports | _get_typing_import(
|
||||
self.current_module
|
||||
)
|
||||
bases = sorted(bases, key=lambda base: base.__name__)
|
||||
return [
|
||||
f"from typing import {','.join(sorted(typing_imports))}",
|
||||
*[f"from {base.__module__} import {base.__name__}" for base in bases],
|
||||
"from reflex.vars import Var, BaseVar, ComputedVar",
|
||||
"from reflex.event import EventHandler, EventChain, EventSpec",
|
||||
"from reflex.style import Style",
|
||||
]
|
||||
|
||||
def _generate_pyi_class(self, _class: type[Component]):
|
||||
create_spec = getfullargspec(_class.create)
|
||||
lines = [
|
||||
"",
|
||||
f"class {_class.__name__}({', '.join([base.__name__ for base in _class.__bases__])}):",
|
||||
]
|
||||
definition = f" @overload\n @classmethod\n def create( # type: ignore\n cls, *children, "
|
||||
|
||||
for kwarg in create_spec.kwonlyargs:
|
||||
if kwarg in create_spec.annotations:
|
||||
definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, "
|
||||
else:
|
||||
definition += f"{kwarg}, "
|
||||
|
||||
all_classes = [c for c in _class.__mro__ if issubclass(c, Component)]
|
||||
all_props = []
|
||||
for target_class in all_classes:
|
||||
for name, value in target_class.__annotations__.items():
|
||||
if (
|
||||
name in create_spec.kwonlyargs
|
||||
or name in EXCLUDED_PROPS
|
||||
or name in all_props
|
||||
):
|
||||
continue
|
||||
all_props.append(name)
|
||||
definition += f"{name}: {_get_type_hint(value)} = None, "
|
||||
|
||||
for trigger in sorted(_class().get_event_triggers().keys()):
|
||||
definition += f"{trigger}: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, "
|
||||
|
||||
definition = definition.rstrip(", ")
|
||||
definition += f", **props) -> '{_class.__name__}':\n"
|
||||
|
||||
definition += self._generate_docstrings(all_classes, all_props)
|
||||
lines.append(definition)
|
||||
lines.append(" ...")
|
||||
return lines
|
||||
|
||||
def _generate_docstrings(self, _classes, _props):
|
||||
props_comments = {}
|
||||
comments = []
|
||||
for _class in _classes:
|
||||
for _i, line in enumerate(inspect.getsource(_class).splitlines()):
|
||||
reached_functions = re.search("def ", line)
|
||||
if reached_functions:
|
||||
# We've reached the functions, so stop.
|
||||
break
|
||||
|
||||
# Get comments for prop
|
||||
if line.strip().startswith("#"):
|
||||
comments.append(line)
|
||||
continue
|
||||
|
||||
# Check if this line has a prop.
|
||||
match = re.search("\\w+:", line)
|
||||
if match is None:
|
||||
# This line doesn't have a var, so continue.
|
||||
continue
|
||||
|
||||
# Get the prop.
|
||||
prop = match.group(0).strip(":")
|
||||
if prop in _props:
|
||||
if not comments: # do not include undocumented props
|
||||
continue
|
||||
props_comments[prop] = "\n".join(
|
||||
[comment.strip().strip("#") for comment in comments]
|
||||
)
|
||||
comments.clear()
|
||||
continue
|
||||
if prop in EXCLUDED_PROPS:
|
||||
comments.clear() # throw away comments for excluded props
|
||||
_class = _classes[0]
|
||||
new_docstring = []
|
||||
for i, line in enumerate(_class.create.__doc__.splitlines()):
|
||||
if i == 0:
|
||||
new_docstring.append(" " * 8 + '"""' + line)
|
||||
else:
|
||||
new_docstring.append(line)
|
||||
if "*children" in line:
|
||||
for nline in [
|
||||
f"{line.split('*')[0]}{n}:{c}" for n, c in props_comments.items()
|
||||
]:
|
||||
new_docstring.append(nline)
|
||||
new_docstring += ['"""']
|
||||
return "\n".join(new_docstring)
|
||||
|
||||
def _generate_pyi_variable(self, _name, _var):
|
||||
return _get_var_definition(self.current_module, _name)
|
||||
|
||||
def _generate_function(self, _name, _func):
|
||||
import textwrap
|
||||
|
||||
# Don't generate indented functions.
|
||||
source = inspect.getsource(_func)
|
||||
if textwrap.dedent(source) != source:
|
||||
return []
|
||||
|
||||
definition = "".join([line for line in source.split(":\n")[0].split("\n")])
|
||||
return [f"{definition}:", " ..."]
|
||||
|
||||
def _write_pyi_file(self, variables, functions, classes):
|
||||
def _write_pyi_file(self, module_path: Path, source: str):
|
||||
pyi_content = [
|
||||
f'"""Stub file for {self.current_module_path}.py"""',
|
||||
f'"""Stub file for {module_path}"""',
|
||||
"# ------------------- DO NOT EDIT ----------------------",
|
||||
"# This file was generated by `scripts/pyi_generator.py`!",
|
||||
"# ------------------------------------------------------",
|
||||
"",
|
||||
]
|
||||
pyi_content.extend(self._generate_imports(variables, classes))
|
||||
|
||||
for _name, _var in variables:
|
||||
pyi_content.append(self._generate_pyi_variable(_name, _var))
|
||||
|
||||
for _fname, _func in functions:
|
||||
pyi_content.extend(self._generate_function(_fname, _func))
|
||||
|
||||
for _, _class in classes:
|
||||
pyi_content.extend(self._generate_pyi_class(_class))
|
||||
|
||||
pyi_filename = f"{self.current_module_path}.pyi"
|
||||
pyi_path = os.path.join(self.root, pyi_filename)
|
||||
|
||||
with open(pyi_path, "w") as pyi_file:
|
||||
pyi_file.write("\n".join(pyi_content))
|
||||
black.format_file_in_place(
|
||||
src=Path(pyi_path),
|
||||
for formatted_line in black.format_file_contents(
|
||||
src_contents=source,
|
||||
fast=True,
|
||||
mode=black.FileMode(),
|
||||
write_back=black.WriteBack.YES,
|
||||
)
|
||||
mode=black.mode.Mode(is_pyi=True),
|
||||
).splitlines():
|
||||
# Bit of a hack here, since the AST cannot represent comments.
|
||||
if formatted_line == " def create(":
|
||||
pyi_content.append(" def create( # type: ignore")
|
||||
else:
|
||||
pyi_content.append(formatted_line)
|
||||
|
||||
def _scan_file(self, file):
|
||||
self.current_module_path = os.path.splitext(file)[0]
|
||||
module_import = os.path.splitext(os.path.join(self.root, file))[0].replace(
|
||||
"/", "."
|
||||
)
|
||||
pyi_path = module_path.with_suffix(".pyi")
|
||||
pyi_path.write_text("\n".join(pyi_content))
|
||||
logger.info(f"Wrote {pyi_path}")
|
||||
|
||||
self.current_module = importlib.import_module(module_import)
|
||||
def _scan_file(self, module_path: Path):
|
||||
module_import = str(module_path.with_suffix("")).replace("/", ".")
|
||||
module = importlib.import_module(module_import)
|
||||
|
||||
local_variables = []
|
||||
for node in ast.parse(inspect.getsource(self.current_module)).body:
|
||||
if isinstance(node, ast.Assign):
|
||||
for t in node.targets:
|
||||
if not isinstance(t, ast.Name):
|
||||
# Skip non-var assignment statements
|
||||
continue
|
||||
if t.id.startswith("_"):
|
||||
# Skip private vars
|
||||
continue
|
||||
obj = getattr(self.current_module, t.id, None)
|
||||
if inspect.isclass(obj) or inspect.isfunction(obj):
|
||||
continue
|
||||
local_variables.append((t.id, obj))
|
||||
|
||||
functions = [
|
||||
(name, obj)
|
||||
for name, obj in vars(self.current_module).items()
|
||||
if not name.startswith("__")
|
||||
and (
|
||||
not inspect.getmodule(obj)
|
||||
or inspect.getmodule(obj) == self.current_module
|
||||
)
|
||||
and inspect.isfunction(obj)
|
||||
]
|
||||
|
||||
class_names = [
|
||||
(name, obj)
|
||||
for name, obj in vars(self.current_module).items()
|
||||
class_names = {
|
||||
name: obj
|
||||
for name, obj in vars(module).items()
|
||||
if inspect.isclass(obj)
|
||||
and issubclass(obj, Component)
|
||||
and obj != Component
|
||||
and inspect.getmodule(obj) == self.current_module
|
||||
]
|
||||
and inspect.getmodule(obj) == module
|
||||
}
|
||||
if not class_names:
|
||||
return
|
||||
print(f"Parsed {file}: Found {[n for n, _ in class_names]}")
|
||||
self._write_pyi_file(local_variables, functions, class_names)
|
||||
|
||||
new_tree = StubGenerator(module, class_names).visit(
|
||||
ast.parse(inspect.getsource(module))
|
||||
)
|
||||
self._write_pyi_file(module_path, ast.unparse(new_tree))
|
||||
|
||||
def _scan_folder(self, folder):
|
||||
for root, _, files in os.walk(folder):
|
||||
self.root = root
|
||||
for file in files:
|
||||
if file in EXCLUDED_FILES:
|
||||
continue
|
||||
if file.endswith(".py"):
|
||||
self._scan_file(file)
|
||||
self._scan_file(Path(root) / file)
|
||||
|
||||
def scan_all(self, targets):
|
||||
"""Scan all targets for class inheriting Component and generate the .pyi files.
|
||||
@ -462,14 +603,16 @@ class PyiGenerator:
|
||||
"""
|
||||
for target in targets:
|
||||
if target.endswith(".py"):
|
||||
self.root, _, file = target.rpartition("/")
|
||||
self._scan_file(file)
|
||||
self._scan_file(Path(target))
|
||||
else:
|
||||
self._scan_folder(target)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.INFO)
|
||||
|
||||
targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
|
||||
print(f"Running .pyi generator for {targets}")
|
||||
logger.info(f"Running .pyi generator for {targets}")
|
||||
gen = PyiGenerator()
|
||||
gen.scan_all(targets)
|
||||
|
Loading…
Reference in New Issue
Block a user