"""The pyi generator module."""
|
|
|
|
import importlib
|
|
import inspect
|
|
import os
|
|
import re
|
|
import sys
|
|
from inspect import getfullargspec
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union, get_args # NOQA
|
|
|
|
import black
|
|
|
|
from reflex.components.component import Component
|
|
from reflex.vars import Var
|
|
|
|
# Reference these imported names so that ruff does not flag their imports as
# unused and remove them.
# NOTE(review): presumably they must stay importable here because
# `_get_type_hint` calls eval() on string annotations, which resolves names
# against this module's globals — confirm.
ruff_dont_remove = [Var, Optional, Dict, List]
|
|
|
|
# File names skipped by `PyiGenerator._scan_folder` while walking a folder.
# Files passed explicitly to `PyiGenerator.scan_all` are not checked against
# this list.
EXCLUDED_FILES = [
    "__init__.py",
    "component.py",
    "bare.py",
    "foreach.py",
    "cond.py",
    "multiselect.py",
]

# Names imported from `typing` in every generated stub, in addition to the
# ones the scanned module imports itself. "List" is part of the defaults
# because the event-trigger annotation emitted for each component references
# it regardless of what the scanned module imports.
DEFAULT_TYPING_IMPORTS = {"overload", "Optional", "Union", "List"}
|
|
|
|
|
|
def _get_type_hint(value, top_level=True, no_union=False):
|
|
res = ""
|
|
args = get_args(value)
|
|
if args:
|
|
res = f"{value.__name__}[{', '.join([_get_type_hint(arg, top_level=False) for arg in args if arg is not type(None)])}]"
|
|
|
|
if value.__name__ == "Var":
|
|
types = [res] + [
|
|
_get_type_hint(arg, top_level=False)
|
|
for arg in args
|
|
if arg is not type(None)
|
|
]
|
|
if len(types) > 1 and not no_union:
|
|
res = ", ".join(types)
|
|
res = f"Union[{res}]"
|
|
elif isinstance(value, str):
|
|
ev = eval(value)
|
|
res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
|
|
else:
|
|
res = value.__name__
|
|
if top_level and not res.startswith("Optional"):
|
|
res = f"Optional[{res}]"
|
|
return res
|
|
|
|
|
|
def _get_typing_import(_module):
|
|
src = [
|
|
line
|
|
for line in inspect.getsource(_module).split("\n")
|
|
if line.startswith("from typing")
|
|
]
|
|
if len(src):
|
|
return set(src[0].rpartition("from typing import ")[-1].split(", "))
|
|
return set()
|
|
|
|
|
|
def _get_var_definition(_module, _var_name):
|
|
return [
|
|
line.split(" = ")[0]
|
|
for line in inspect.getsource(_module).splitlines()
|
|
if line.startswith(_var_name)
|
|
]
|
|
|
|
|
|
class PyiGenerator:
    """A .pyi file generator that scans modules for Component subclasses and
    generates the appropriate stub next to each scanned module.
    """

    # Declared but not read or written by any method of this class.
    modules: list = []
    # Directory containing the module currently being processed.
    root: str = ""
    # Module object currently being processed (set by `_scan_file`).
    current_module: Any = {}
    # File name of the current module without its extension (set by `_scan_file`).
    current_module_path: str = ""
    # Names always imported from `typing` in a generated stub.
    default_typing_imports: set = DEFAULT_TYPING_IMPORTS

    def _generate_imports(self, variables, classes):
        """Build the import statements placed at the top of the stub.

        Args:
            variables: (name, value) pairs of the module-level variables.
            classes: (name, class) pairs of the components found in the module.

        Returns:
            The import statements, one string per line.
        """
        # Classes of module-level variables that are Component instances.
        variables_imports = {
            type(_var) for _, _var in variables if isinstance(_var, Component)
        }
        # Base classes defined outside the current module must be imported.
        bases = {
            base
            for _, _class in classes
            for base in _class.__bases__
            if inspect.getmodule(base) != self.current_module
        } | variables_imports
        bases.add(Component)
        typing_imports = self.default_typing_imports | _get_typing_import(
            self.current_module
        )
        bases = sorted(bases, key=lambda base: base.__name__)
        return [
            f"from typing import {','.join(sorted(typing_imports))}",
            *[f"from {base.__module__} import {base.__name__}" for base in bases],
            "from reflex.vars import Var, BaseVar, ComputedVar",
            "from reflex.event import EventHandler, EventChain, EventSpec",
        ]

    def _generate_pyi_class(self, _class: type[Component]):
        """Build the stub of one component class.

        The stub holds a single overloaded ``create`` classmethod whose
        parameters are, in order: the keyword-only arguments of the real
        ``create``, the annotated attributes of the class, and its event
        triggers sorted by name.

        Args:
            _class: The component class to describe.

        Returns:
            The lines of the class stub.
        """
        create_spec = getfullargspec(_class.create)
        base_names = ", ".join(base.__name__ for base in _class.__bases__)
        lines = ["", f"class {_class.__name__}({base_names}):"]

        params = ["cls", "*children"]

        for kwarg in create_spec.kwonlyargs:
            if kwarg in create_spec.annotations:
                hint = _get_type_hint(create_spec.annotations[kwarg])
                params.append(f"{kwarg}: {hint} = None")
            else:
                params.append(kwarg)

        for name, value in _class.__annotations__.items():
            # Already emitted as a keyword-only argument of `create`.
            if name in create_spec.kwonlyargs:
                continue
            params.append(f"{name}: {_get_type_hint(value)} = None")

        # The triggers are read from an instance, so the class is instantiated
        # without arguments here.
        for trigger in sorted(_class().get_event_triggers().keys()):
            params.append(
                f"{trigger}: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None"
            )

        params.append("**props")

        # Four-space indentation keeps the method inside the class body; the
        # docstring and the "..." below use eight spaces to sit inside the
        # method body.
        definition = (
            "    @overload\n"
            "    @classmethod\n"
            f"    def create({', '.join(params)}) -> '{_class.__name__}':  # type: ignore\n"
        )

        definition += self._generate_docstrings(_class, _class.__annotations__.keys())
        lines.append(definition)
        lines.append("        ...")
        return lines

    def _generate_docstrings(self, _class, _props):
        """Build the docstring of the generated ``create`` overload.

        The docstring of the real ``create`` is reused, and one entry per
        documented prop is inserted after the line describing ``*children``.
        The text of an entry comes from the ``#`` comments preceding the prop
        in the class source.

        Args:
            _class: The component class being described.
            _props: The names of the props of the class.

        Returns:
            The docstring text including its triple quotes, or an empty string
            when ``create`` has no docstring.
        """
        props_comments = {}
        comments = []
        for line in inspect.getsource(_class).splitlines():
            stripped = line.strip()

            # Props are declared before the methods, so stop at the first one.
            if stripped.startswith(("def ", "async def ")):
                break

            # Accumulate the comments; they describe the next prop found.
            if stripped.startswith("#"):
                comments.append(stripped.strip("#"))
                continue

            # A prop line starts with "<name>:"; anything else is skipped.
            match = re.match(r"\s*(\w+)\s*:", line)
            if match is None:
                continue

            prop = match.group(1)
            if prop in _props:
                props_comments[prop] = "\n".join(comments)
                comments.clear()

        create_doc = _class.create.__doc__
        if not create_doc:
            # Without a docstring to extend there is nothing valid to emit.
            return ""

        new_docstring = []
        for i, line in enumerate(create_doc.splitlines()):
            if i == 0:
                new_docstring.append(" " * 8 + '"""' + line)
            else:
                new_docstring.append(line)
            if "*children" in line:
                # Reuse the indentation of the "*children" entry for the props.
                indent = line.split("*")[0]
                new_docstring.extend(
                    f"{indent}{n}:{c}" for n, c in props_comments.items()
                )
        new_docstring.append('"""')
        return "\n".join(new_docstring)

    def _generate_pyi_variable(self, _name, _var):
        """Build the stub lines of a module-level variable.

        Args:
            _name: The name of the variable.
            _var: The value of the variable (unused).

        Returns:
            The declaration lines found in the current module source.
        """
        return _get_var_definition(self.current_module, _name)

    def _generate_function(self, _name, _func):
        """Build the stub lines of a module-level function.

        Args:
            _name: The name of the function (unused).
            _func: The function object.

        Returns:
            The signature followed by an ellipsis body, or an empty list when
            the function source is indented or holds no ``def`` statement.
        """
        import textwrap

        # Don't generate indented functions.
        source = inspect.getsource(_func)
        if textwrap.dedent(source) != source:
            return []

        header_lines = source.split(":\n")[0].split("\n")
        # Skip decorator lines so they are not glued to the signature; a
        # source without any "def" line (e.g. a lambda) produces no stub.
        def_index = next(
            (
                index
                for index, header_line in enumerate(header_lines)
                if header_line.startswith(("def ", "async def "))
            ),
            None,
        )
        if def_index is None:
            return []

        # A multi-line signature is collapsed onto a single line.
        definition = "".join(header_lines[def_index:])
        return [f"{definition}:", "    ..."]

    def _write_pyi_file(self, variables, functions, classes):
        """Write the stub of the current module and format it with black.

        The file is created in ``self.root`` and named after the current
        module with a ``.pyi`` extension.

        Args:
            variables: (name, value) pairs of the module-level variables.
            functions: (name, function) pairs of the module-level functions.
            classes: (name, class) pairs of the components found in the module.
        """
        pyi_content = [
            f'"""Stub file for {self.current_module_path}.py"""',
            "# ------------------- DO NOT EDIT ----------------------",
            "# This file was generated by `scripts/pyi_generator.py`!",
            "# ------------------------------------------------------",
            "",
        ]
        pyi_content.extend(self._generate_imports(variables, classes))

        for _name, _var in variables:
            pyi_content.extend(self._generate_pyi_variable(_name, _var))

        for _fname, _func in functions:
            pyi_content.extend(self._generate_function(_fname, _func))

        for _, _class in classes:
            pyi_content.extend(self._generate_pyi_class(_class))

        pyi_filename = f"{self.current_module_path}.pyi"
        pyi_path = os.path.join(self.root, pyi_filename)

        # Explicit UTF-8 so the result does not depend on the platform's
        # default encoding when docstrings contain non-ASCII characters.
        with open(pyi_path, "w", encoding="utf-8") as pyi_file:
            pyi_file.write("\n".join(pyi_content))
        black.format_file_in_place(
            src=Path(pyi_path),
            fast=True,
            mode=black.FileMode(),
            write_back=black.WriteBack.YES,
        )

    def _scan_file(self, file):
        """Import one module and write its stub if it defines components.

        Args:
            file: The file name of the module, relative to ``self.root``.
        """
        self.current_module_path = os.path.splitext(file)[0]

        # Turn the file path into a dotted module path. The path is normalised
        # first so that a leading "./" or platform-specific separators do not
        # end up in the module name.
        module_path = os.path.normpath(
            os.path.splitext(os.path.join(self.root, file))[0]
        )
        module_import = module_path.replace(os.sep, ".").replace("/", ".")

        self.current_module = importlib.import_module(module_import)
        members = vars(self.current_module).items()

        # Public module attributes that are neither classes nor functions.
        local_variables = [
            (name, obj)
            for name, obj in members
            if not name.startswith("_")
            and not inspect.isclass(obj)
            and not inspect.isfunction(obj)
        ]

        # Functions defined in this module, or whose module cannot be resolved.
        functions = [
            (name, obj)
            for name, obj in members
            if not name.startswith("__")
            and (
                not inspect.getmodule(obj)
                or inspect.getmodule(obj) == self.current_module
            )
            and inspect.isfunction(obj)
        ]

        # Component subclasses defined in this module.
        class_names = [
            (name, obj)
            for name, obj in members
            if inspect.isclass(obj)
            and issubclass(obj, Component)
            and obj != Component
            and inspect.getmodule(obj) == self.current_module
        ]
        if not class_names:
            return
        print(f"Parsed {file}: Found {[n for n, _ in class_names]}")
        self._write_pyi_file(local_variables, functions, class_names)

    def _scan_folder(self, folder):
        """Scan every non-excluded Python file found under a folder.

        Args:
            folder: The folder to walk recursively.
        """
        for root, _, files in os.walk(folder):
            self.root = root
            for file in files:
                if file in EXCLUDED_FILES:
                    continue
                if file.endswith(".py"):
                    self._scan_file(file)

    def scan_all(self, targets):
        """Scan all targets for class inheriting Component and generate the .pyi files.

        Args:
            targets: the list of file/folders to scan.
        """
        for target in targets:
            if target.endswith(".py"):
                # os.path.split handles the separators of the current platform.
                self.root, file = os.path.split(target)
                self._scan_file(file)
            else:
                self._scan_folder(target)
|
|
|
|
|
|
if __name__ == "__main__":
    # Use the paths given on the command line, falling back to the default
    # components folder when none is given.
    targets = sys.argv[1:] or ["reflex/components"]
    print(f"Running .pyi generator for {targets}")
    PyiGenerator().scan_all(targets)
|