Refactor pyi_generator to use AST directly (#2034)

This commit is contained in:
Masen Furer 2023-10-26 10:56:35 -07:00 committed by GitHub
parent fe01f0cf11
commit 5d590c350e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,158 +1,27 @@
"""The pyi generator module."""
import ast
import contextlib
import importlib
import inspect
import logging
import os
import re
import sys
import textwrap
from inspect import getfullargspec
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union, get_args # NOQA
from types import ModuleType
from typing import Any, Callable, Iterable, Type, get_args
import black
import black.mode
from reflex.components.component import Component
# NOQA
from reflex.components.graphing.recharts.recharts import (
LiteralAnimationEasing,
LiteralAreaType,
LiteralComposedChartBaseValue,
LiteralDirection,
LiteralGridType,
LiteralIconType,
LiteralIfOverflow,
LiteralInterval,
LiteralLayout,
LiteralLegendAlign,
LiteralLineType,
LiteralOrientationTopBottom,
LiteralOrientationTopBottomLeftRight,
LiteralPolarRadiusType,
LiteralPosition,
LiteralScale,
LiteralShape,
LiteralStackOffset,
LiteralSyncMethod,
LiteralVerticalAlign,
)
from reflex.components.libs.chakra import (
LiteralAlertDialogSize,
LiteralAvatarSize,
LiteralChakraDirection,
LiteralColorScheme,
LiteralDrawerSize,
LiteralImageLoading,
LiteralInputVariant,
LiteralMenuOption,
LiteralMenuStrategy,
LiteralTagSize,
)
from reflex.components.radix.themes.base import (
LiteralAccentColor,
LiteralAlign,
LiteralAppearance,
LiteralGrayColor,
LiteralJustify,
LiteralPanelBackground,
LiteralRadius,
LiteralScaling,
LiteralSize,
LiteralVariant,
)
from reflex.components.radix.themes.components import (
LiteralButtonSize,
LiteralSwitchSize,
)
from reflex.components.radix.themes.layout import (
LiteralBoolNumber,
LiteralContainerSize,
LiteralFlexDirection,
LiteralFlexDisplay,
LiteralFlexWrap,
LiteralGridDisplay,
LiteralGridFlow,
LiteralSectionSize,
)
from reflex.components.radix.themes.typography import (
LiteralLinkUnderline,
LiteralTextAlign,
LiteralTextSize,
LiteralTextTrim,
LiteralTextWeight,
)
# NOQA
from reflex.event import EventChain
from reflex.style import Style
from reflex.utils import format
from reflex.utils import types as rx_types
from reflex.vars import Var
ruff_dont_remove = [
Var,
Optional,
Dict,
List,
EventChain,
Style,
LiteralInputVariant,
LiteralColorScheme,
LiteralChakraDirection,
LiteralTagSize,
LiteralDrawerSize,
LiteralMenuStrategy,
LiteralMenuOption,
LiteralAlertDialogSize,
LiteralAvatarSize,
LiteralImageLoading,
LiteralLayout,
LiteralAnimationEasing,
LiteralGridType,
LiteralPolarRadiusType,
LiteralScale,
LiteralSyncMethod,
LiteralStackOffset,
LiteralComposedChartBaseValue,
LiteralOrientationTopBottom,
LiteralAreaType,
LiteralShape,
LiteralLineType,
LiteralDirection,
LiteralIfOverflow,
LiteralOrientationTopBottomLeftRight,
LiteralInterval,
LiteralLegendAlign,
LiteralVerticalAlign,
LiteralIconType,
LiteralPosition,
LiteralAccentColor,
LiteralAlign,
LiteralAppearance,
LiteralBoolNumber,
LiteralButtonSize,
LiteralContainerSize,
LiteralFlexDirection,
LiteralFlexDisplay,
LiteralFlexWrap,
LiteralGrayColor,
LiteralGridDisplay,
LiteralGridFlow,
LiteralJustify,
LiteralLinkUnderline,
LiteralPanelBackground,
LiteralRadius,
LiteralScaling,
LiteralSectionSize,
LiteralSize,
LiteralSwitchSize,
LiteralTextAlign,
LiteralTextSize,
LiteralTextTrim,
LiteralTextWeight,
LiteralVariant,
]
logger = logging.getLogger("pyi_generator")
EXCLUDED_FILES = [
"__init__.py",
@ -177,18 +46,36 @@ EXCLUDED_PROPS = [
"valid_children",
]
DEFAULT_TYPING_IMPORTS = {"overload", "Any", "Dict", "List", "Optional", "Union"}
DEFAULT_TYPING_IMPORTS = {
"overload",
"Any",
"Dict",
"List",
"Literal",
"Optional",
"Union",
}
def _get_type_hint(value, top_level=True, no_union=False):
def _get_type_hint(value, type_hint_globals, is_optional=True) -> str:
"""Resolve the type hint for value.
Args:
value: The type annotation as a str or actual types/aliases.
type_hint_globals: The globals to use to resolving a type hint str.
is_optional: Whether the type hint should be wrapped in Optional.
Returns:
The resolved type hint as a str.
"""
res = ""
args = get_args(value)
if args:
inner_container_type_args = (
[format.wrap(arg, '"') for arg in args]
[repr(arg) for arg in args]
if rx_types.is_literal(value)
else [
_get_type_hint(arg, top_level=False)
_get_type_hint(arg, type_hint_globals, is_optional=False)
for arg in args
if arg is not type(None)
]
@ -196,42 +83,454 @@ def _get_type_hint(value, top_level=True, no_union=False):
res = f"{value.__name__}[{', '.join(inner_container_type_args)}]"
if value.__name__ == "Var":
# For Var types, Union with the inner args so they can be passed directly.
types = [res] + [
_get_type_hint(arg, top_level=False)
_get_type_hint(arg, type_hint_globals, is_optional=False)
for arg in args
if arg is not type(None)
]
if len(types) > 1 and not no_union:
if len(types) > 1:
res = ", ".join(types)
res = f"Union[{res}]"
elif isinstance(value, str):
ev = eval(value)
res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
ev = eval(value, type_hint_globals)
res = (
_get_type_hint(ev, type_hint_globals, is_optional=False)
if ev.__name__ == "Var"
else value
)
else:
res = value.__name__
if top_level and not res.startswith("Optional"):
if is_optional and not res.startswith("Optional"):
res = f"Optional[{res}]"
return res
def _get_typing_import(_module):
src = [
line
for line in inspect.getsource(_module).split("\n")
if line.startswith("from typing")
def _generate_imports(typing_imports: Iterable[str]) -> list[ast.ImportFrom]:
"""Generate the import statements for the stub file.
Args:
typing_imports: The typing imports to include.
Returns:
The list of import statements.
"""
return [
ast.ImportFrom(
module="typing", names=[ast.alias(name=imp) for imp in typing_imports]
),
*ast.parse( # type: ignore
textwrap.dedent(
"""
from reflex.vars import Var, BaseVar, ComputedVar
from reflex.event import EventChain, EventHandler, EventSpec
from reflex.style import Style"""
)
).body,
]
if len(src):
return set(src[0].rpartition("from typing import ")[-1].split(", "))
return set()
def _get_var_definition(_module, _var_name):
for node in ast.parse(inspect.getsource(_module)).body:
if isinstance(node, ast.Assign) and _var_name in [
t.id for t in node.targets if isinstance(t, ast.Name)
]:
return ast.unparse(node)
raise Exception(f"Could not find var {_var_name} in module {_module}")
def _generate_docstrings(clzs: list[Type[Component]], props: list[str]) -> str:
"""Generate the docstrings for the create method.
Args:
clzs: The classes to generate docstrings for.
props: The props to generate docstrings for.
Returns:
The docstring for the create method.
"""
props_comments = {}
comments = []
for clz in clzs:
for line in inspect.getsource(clz).splitlines():
reached_functions = re.search("def ", line)
if reached_functions:
# We've reached the functions, so stop.
break
# Get comments for prop
if line.strip().startswith("#"):
comments.append(line)
continue
# Check if this line has a prop.
match = re.search("\\w+:", line)
if match is None:
# This line doesn't have a var, so continue.
continue
# Get the prop.
prop = match.group(0).strip(":")
if prop in props:
if not comments: # do not include undocumented props
continue
props_comments[prop] = [
comment.strip().strip("#") for comment in comments
]
comments.clear()
clz = clzs[0]
new_docstring = []
for line in (clz.create.__doc__ or "").splitlines():
if "**" in line:
indent = line.split("**")[0]
for nline in [
f"{indent}{n}:{' '.join(c)}" for n, c in props_comments.items()
]:
new_docstring.append(nline)
new_docstring.append(line)
return "\n".join(new_docstring)
def _extract_func_kwargs_as_ast_nodes(
func: Callable,
type_hint_globals: dict[str, Any],
) -> list[tuple[ast.arg, ast.Constant | None]]:
"""Get the kwargs already defined on the function.
Args:
func: The function to extract kwargs from.
type_hint_globals: The globals to use to resolving a type hint str.
Returns:
The list of kwargs as ast arg nodes.
"""
spec = getfullargspec(func)
kwargs = []
for kwarg in spec.kwonlyargs:
arg = ast.arg(arg=kwarg)
if kwarg in spec.annotations:
arg.annotation = ast.Name(
id=_get_type_hint(spec.annotations[kwarg], type_hint_globals)
)
default = None
if spec.kwonlydefaults is not None and kwarg in spec.kwonlydefaults:
default = ast.Constant(value=spec.kwonlydefaults[kwarg])
kwargs.append((arg, default))
return kwargs
def _extract_class_props_as_ast_nodes(
func: Callable,
clzs: list[Type],
type_hint_globals: dict[str, Any],
extract_real_default: bool = False,
) -> list[tuple[ast.arg, ast.Constant | None]]:
"""Get the props defined on the class and all parents.
Args:
func: The function that kwargs will be added to.
clzs: The classes to extract props from.
type_hint_globals: The globals to use to resolving a type hint str.
extract_real_default: Whether to extract the real default value from the
pydantic field definition.
Returns:
The list of props as ast arg nodes
"""
spec = getfullargspec(func)
all_props = []
kwargs = []
for target_class in clzs:
# Import from the target class to ensure type hints are resolvable.
exec(f"from {target_class.__module__} import *", type_hint_globals)
for name, value in target_class.__annotations__.items():
if name in spec.kwonlyargs or name in EXCLUDED_PROPS or name in all_props:
continue
all_props.append(name)
default = None
if extract_real_default:
# TODO: This is not currently working since the default is not type compatible
# with the annotation in some cases.
with contextlib.suppress(AttributeError, KeyError):
# Try to get default from pydantic field definition.
default = target_class.__fields__[name].default
if isinstance(default, Var):
default = default._decode() # type: ignore
kwargs.append(
(
ast.arg(
arg=name,
annotation=ast.Name(
id=_get_type_hint(value, type_hint_globals)
),
),
ast.Constant(value=default),
)
)
return kwargs
def _generate_component_create_functiondef(
node: ast.FunctionDef | None,
clz: type[Component],
type_hint_globals: dict[str, Any],
) -> ast.FunctionDef:
"""Generate the create function definition for a Component.
Args:
node: The existing create functiondef node from the ast
clz: The Component class to generate the create functiondef for.
type_hint_globals: The globals to use to resolving a type hint str.
Returns:
The create functiondef node for the ast.
"""
# kwargs defined on the actual create function
kwargs = _extract_func_kwargs_as_ast_nodes(clz.create, type_hint_globals)
# kwargs associated with props defined in the class and its parents
all_classes = [c for c in clz.__mro__ if issubclass(c, Component)]
prop_kwargs = _extract_class_props_as_ast_nodes(
clz.create, all_classes, type_hint_globals
)
all_props = [arg[0].arg for arg in prop_kwargs]
kwargs.extend(prop_kwargs)
# event handler kwargs
kwargs.extend(
(
ast.arg(
arg=trigger,
annotation=ast.Name(
id="Optional[Union[EventHandler, EventSpec, List, function, BaseVar]]"
),
),
ast.Constant(value=None),
)
for trigger in sorted(clz().get_event_triggers().keys())
)
logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
create_args = ast.arguments(
args=[ast.arg(arg="cls")],
posonlyargs=[],
vararg=ast.arg(arg="children"),
kwonlyargs=[arg[0] for arg in kwargs],
kw_defaults=[arg[1] for arg in kwargs],
kwarg=ast.arg(arg="props"),
defaults=[],
)
definition = ast.FunctionDef(
name="create",
args=create_args,
body=[
ast.Expr(
value=ast.Constant(value=_generate_docstrings(all_classes, all_props))
),
ast.Expr(
value=ast.Ellipsis(),
),
],
decorator_list=[
ast.Name(id="overload"),
*(
node.decorator_list
if node is not None
else [ast.Name(id="classmethod")]
),
],
lineno=node.lineno if node is not None else None,
returns=ast.Constant(value=clz.__name__),
)
return definition
class StubGenerator(ast.NodeTransformer):
"""A node transformer that will generate the stubs for a given module."""
def __init__(self, module: ModuleType, classes: dict[str, Type[Component]]):
"""Initialize the stub generator.
Args:
module: The actual module object module to generate stubs for.
classes: The actual Component class objects to generate stubs for.
"""
super().__init__()
# Dict mapping class name to actual class object.
self.classes = classes
# Track the last class node that was visited.
self.current_class = None
# These imports will be included in the AST of stub files.
self.typing_imports = DEFAULT_TYPING_IMPORTS
# Whether those typing imports have been inserted yet.
self.inserted_imports = False
# Collected import statements from the module.
self.import_statements: list[str] = []
# This dict is used when evaluating type hints.
self.type_hint_globals = module.__dict__.copy()
@staticmethod
def _remove_docstring(
node: ast.Module | ast.ClassDef | ast.FunctionDef,
) -> ast.Module | ast.ClassDef | ast.FunctionDef:
"""Removes any docstring in place.
Args:
node: The node to remove the docstring from.
Returns:
The modified node.
"""
if (
node.body
and isinstance(node.body[0], ast.Expr)
and isinstance(node.body[0].value, ast.Constant)
):
node.body.pop(0)
return node
def visit_Module(self, node: ast.Module) -> ast.Module:
"""Visit a Module node and remove docstring from body.
Args:
node: The Module node to visit.
Returns:
The modified Module node.
"""
self.generic_visit(node)
return self._remove_docstring(node) # type: ignore
def visit_Import(
self, node: ast.Import | ast.ImportFrom
) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom]:
"""Collect import statements from the module.
If this is the first import statement, insert the typing imports before it.
Args:
node: The import node to visit.
Returns:
The modified import node(s).
"""
self.import_statements.append(ast.unparse(node))
if not self.inserted_imports:
self.inserted_imports = True
return _generate_imports(self.typing_imports) + [node]
return node
def visit_ImportFrom(
self, node: ast.ImportFrom
) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom] | None:
"""Visit an ImportFrom node.
Remove any `from __future__ import *` statements, and hand off to visit_Import.
Args:
node: The ImportFrom node to visit.
Returns:
The modified ImportFrom node.
"""
if node.module == "__future__":
return None # ignore __future__ imports
return self.visit_Import(node)
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Visit a ClassDef node.
Remove all assignments in the class body, and add a create functiondef
if one does not exist.
Args:
node: The ClassDef node to visit.
Returns:
The modified ClassDef node.
"""
exec("\n".join(self.import_statements), self.type_hint_globals)
self.current_class = node.name
self._remove_docstring(node)
self.generic_visit(node) # Visit child nodes.
if not node.body:
# We should never return an empty body.
node.body.append(ast.Expr(value=ast.Ellipsis()))
if (
not any(
isinstance(child, ast.FunctionDef) and child.name == "create"
for child in node.body
)
and self.current_class in self.classes
):
# Add a new .create FunctionDef since one does not exist.
node.body.append(
_generate_component_create_functiondef(
node=None,
clz=self.classes[self.current_class],
type_hint_globals=self.type_hint_globals,
)
)
self.current_class = None
return node
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
"""Visit a FunctionDef node.
Special handling for `.create` functions to add type hints for all props
defined on the component class.
Remove all private functions and blank out the function body of the
remaining public functions.
Args:
node: The FunctionDef node to visit.
Returns:
The modified FunctionDef node (or None).
"""
if node.name == "create" and self.current_class in self.classes:
node = _generate_component_create_functiondef(
node, self.classes[self.current_class], self.type_hint_globals
)
else:
if node.name.startswith("_"):
return None # remove private methods
# Blank out the function body for public functions.
node.body = [ast.Expr(value=ast.Ellipsis())]
return node
def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
"""Remove non-annotated assignment statements.
Args:
node: The Assign node to visit.
Returns:
The modified Assign node (or None).
"""
# Special case for assignments to `typing.Any` as fallback.
if (
node.value is not None
and isinstance(node.value, ast.Name)
and node.value.id == "Any"
):
return node
return None
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
"""Visit an AnnAssign node (Annotated assignment).
Remove private target and remove the assignment value in the stub.
Args:
node: The AnnAssign node to visit.
Returns:
The modified AnnAssign node (or None).
"""
if isinstance(node.target, ast.Name) and node.target.id.startswith("_"):
return None
if self.current_class in self.classes:
# Remove annotated assignments in Component classes (props)
return None
# Blank out assignments in type stubs.
node.value = None
return node
class PyiGenerator:
@ -244,215 +543,57 @@ class PyiGenerator:
current_module: Any = {}
default_typing_imports: set = DEFAULT_TYPING_IMPORTS
def _generate_imports(self, variables, classes):
variables_imports = {
type(_var) for _, _var in variables if isinstance(_var, Component)
}
bases = {
base
for _, _class in classes
for base in _class.__bases__
if inspect.getmodule(base) != self.current_module
} | variables_imports
bases.add(Component)
typing_imports = self.default_typing_imports | _get_typing_import(
self.current_module
)
bases = sorted(bases, key=lambda base: base.__name__)
return [
f"from typing import {','.join(sorted(typing_imports))}",
*[f"from {base.__module__} import {base.__name__}" for base in bases],
"from reflex.vars import Var, BaseVar, ComputedVar",
"from reflex.event import EventHandler, EventChain, EventSpec",
"from reflex.style import Style",
]
def _generate_pyi_class(self, _class: type[Component]):
create_spec = getfullargspec(_class.create)
lines = [
"",
f"class {_class.__name__}({', '.join([base.__name__ for base in _class.__bases__])}):",
]
definition = f" @overload\n @classmethod\n def create( # type: ignore\n cls, *children, "
for kwarg in create_spec.kwonlyargs:
if kwarg in create_spec.annotations:
definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, "
else:
definition += f"{kwarg}, "
all_classes = [c for c in _class.__mro__ if issubclass(c, Component)]
all_props = []
for target_class in all_classes:
for name, value in target_class.__annotations__.items():
if (
name in create_spec.kwonlyargs
or name in EXCLUDED_PROPS
or name in all_props
):
continue
all_props.append(name)
definition += f"{name}: {_get_type_hint(value)} = None, "
for trigger in sorted(_class().get_event_triggers().keys()):
definition += f"{trigger}: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, "
definition = definition.rstrip(", ")
definition += f", **props) -> '{_class.__name__}':\n"
definition += self._generate_docstrings(all_classes, all_props)
lines.append(definition)
lines.append(" ...")
return lines
def _generate_docstrings(self, _classes, _props):
props_comments = {}
comments = []
for _class in _classes:
for _i, line in enumerate(inspect.getsource(_class).splitlines()):
reached_functions = re.search("def ", line)
if reached_functions:
# We've reached the functions, so stop.
break
# Get comments for prop
if line.strip().startswith("#"):
comments.append(line)
continue
# Check if this line has a prop.
match = re.search("\\w+:", line)
if match is None:
# This line doesn't have a var, so continue.
continue
# Get the prop.
prop = match.group(0).strip(":")
if prop in _props:
if not comments: # do not include undocumented props
continue
props_comments[prop] = "\n".join(
[comment.strip().strip("#") for comment in comments]
)
comments.clear()
continue
if prop in EXCLUDED_PROPS:
comments.clear() # throw away comments for excluded props
_class = _classes[0]
new_docstring = []
for i, line in enumerate(_class.create.__doc__.splitlines()):
if i == 0:
new_docstring.append(" " * 8 + '"""' + line)
else:
new_docstring.append(line)
if "*children" in line:
for nline in [
f"{line.split('*')[0]}{n}:{c}" for n, c in props_comments.items()
]:
new_docstring.append(nline)
new_docstring += ['"""']
return "\n".join(new_docstring)
def _generate_pyi_variable(self, _name, _var):
return _get_var_definition(self.current_module, _name)
def _generate_function(self, _name, _func):
import textwrap
# Don't generate indented functions.
source = inspect.getsource(_func)
if textwrap.dedent(source) != source:
return []
definition = "".join([line for line in source.split(":\n")[0].split("\n")])
return [f"{definition}:", " ..."]
def _write_pyi_file(self, variables, functions, classes):
def _write_pyi_file(self, module_path: Path, source: str):
pyi_content = [
f'"""Stub file for {self.current_module_path}.py"""',
f'"""Stub file for {module_path}"""',
"# ------------------- DO NOT EDIT ----------------------",
"# This file was generated by `scripts/pyi_generator.py`!",
"# ------------------------------------------------------",
"",
]
pyi_content.extend(self._generate_imports(variables, classes))
for _name, _var in variables:
pyi_content.append(self._generate_pyi_variable(_name, _var))
for _fname, _func in functions:
pyi_content.extend(self._generate_function(_fname, _func))
for _, _class in classes:
pyi_content.extend(self._generate_pyi_class(_class))
pyi_filename = f"{self.current_module_path}.pyi"
pyi_path = os.path.join(self.root, pyi_filename)
with open(pyi_path, "w") as pyi_file:
pyi_file.write("\n".join(pyi_content))
black.format_file_in_place(
src=Path(pyi_path),
for formatted_line in black.format_file_contents(
src_contents=source,
fast=True,
mode=black.FileMode(),
write_back=black.WriteBack.YES,
)
mode=black.mode.Mode(is_pyi=True),
).splitlines():
# Bit of a hack here, since the AST cannot represent comments.
if formatted_line == " def create(":
pyi_content.append(" def create( # type: ignore")
else:
pyi_content.append(formatted_line)
def _scan_file(self, file):
self.current_module_path = os.path.splitext(file)[0]
module_import = os.path.splitext(os.path.join(self.root, file))[0].replace(
"/", "."
)
pyi_path = module_path.with_suffix(".pyi")
pyi_path.write_text("\n".join(pyi_content))
logger.info(f"Wrote {pyi_path}")
self.current_module = importlib.import_module(module_import)
def _scan_file(self, module_path: Path):
module_import = str(module_path.with_suffix("")).replace("/", ".")
module = importlib.import_module(module_import)
local_variables = []
for node in ast.parse(inspect.getsource(self.current_module)).body:
if isinstance(node, ast.Assign):
for t in node.targets:
if not isinstance(t, ast.Name):
# Skip non-var assignment statements
continue
if t.id.startswith("_"):
# Skip private vars
continue
obj = getattr(self.current_module, t.id, None)
if inspect.isclass(obj) or inspect.isfunction(obj):
continue
local_variables.append((t.id, obj))
functions = [
(name, obj)
for name, obj in vars(self.current_module).items()
if not name.startswith("__")
and (
not inspect.getmodule(obj)
or inspect.getmodule(obj) == self.current_module
)
and inspect.isfunction(obj)
]
class_names = [
(name, obj)
for name, obj in vars(self.current_module).items()
class_names = {
name: obj
for name, obj in vars(module).items()
if inspect.isclass(obj)
and issubclass(obj, Component)
and obj != Component
and inspect.getmodule(obj) == self.current_module
]
and inspect.getmodule(obj) == module
}
if not class_names:
return
print(f"Parsed {file}: Found {[n for n, _ in class_names]}")
self._write_pyi_file(local_variables, functions, class_names)
new_tree = StubGenerator(module, class_names).visit(
ast.parse(inspect.getsource(module))
)
self._write_pyi_file(module_path, ast.unparse(new_tree))
def _scan_folder(self, folder):
for root, _, files in os.walk(folder):
self.root = root
for file in files:
if file in EXCLUDED_FILES:
continue
if file.endswith(".py"):
self._scan_file(file)
self._scan_file(Path(root) / file)
def scan_all(self, targets):
"""Scan all targets for class inheriting Component and generate the .pyi files.
@ -462,14 +603,16 @@ class PyiGenerator:
"""
for target in targets:
if target.endswith(".py"):
self.root, _, file = target.rpartition("/")
self._scan_file(file)
self._scan_file(Path(target))
else:
self._scan_folder(target)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.INFO)
targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
print(f"Running .pyi generator for {targets}")
logger.info(f"Running .pyi generator for {targets}")
gen = PyiGenerator()
gen.scan_all(targets)