
* add eradicate rules for commented out code * remove output change * fix pyi messed up indent * fix pyi again * fix layout docstring * fix pyi_generator to remove commented out props from docs * fix pyi_generator and regenerate some pyi * fix double strip * update all pyi * try to fix stuff in pyi_gen * whatever * remove that maybe? i don't know * fix that shit? * fix more shit, idk * better not see you ever again, extra line
1207 lines
40 KiB
Python
1207 lines
40 KiB
Python
"""The pyi generator module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import contextlib
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
import re
|
|
import subprocess
|
|
import typing
|
|
from fileinput import FileInput
|
|
from inspect import getfullargspec
|
|
from itertools import chain
|
|
from multiprocessing import Pool, cpu_count
|
|
from pathlib import Path
|
|
from types import ModuleType, SimpleNamespace
|
|
from typing import Any, Callable, Iterable, Sequence, Type, get_args, get_origin
|
|
|
|
from reflex.components.component import Component
|
|
from reflex.utils import types as rx_types
|
|
from reflex.vars.base import Var
|
|
|
|
# Module-level logger used by the generator for debug/info output.
logger = logging.getLogger("pyi_generator")

# Working directory captured at import time; module paths are made relative to it.
PWD = Path.cwd()

# File names that are skipped when collecting modules to generate stubs for.
EXCLUDED_FILES = [
    "app.py",
    "component.py",
    "bare.py",
    "foreach.py",
    "cond.py",
    "match.py",
    "multiselect.py",
    "literals.py",
]

# These props exist on the base component, but should not be exposed in create methods.
EXCLUDED_PROPS = [
    "alias",
    "children",
    "event_triggers",
    "library",
    "lib_dependencies",
    "tag",
    "is_default",
    "special_props",
    "_invalid_children",
    "_memoization_mode",
    "_rename_props",
    "_valid_children",
    "_valid_parents",
    "State",
]

# Names from `typing` that are imported into every stub and injected into the
# globals used when resolving type hints.
DEFAULT_TYPING_IMPORTS = {
    "overload",
    "Any",
    "Callable",
    "Dict",
    # "List",
    "Literal",
    "Optional",
    "Union",
}

# Mapping of module name -> names imported from it at the top of each stub.
# TODO: fix import ordering and unused imports with ruff later
DEFAULT_IMPORTS = {
    "typing": sorted(DEFAULT_TYPING_IMPORTS),
    "reflex.components.core.breakpoints": ["Breakpoints"],
    "reflex.event": [
        "EventChain",
        "EventHandler",
        "EventSpec",
        "EventType",
        "BASE_STATE",
        "KeyInputInfo",
    ],
    "reflex.style": ["Style"],
    "reflex.vars.base": ["Var"],
}
|
|
|
|
|
|
def _walk_files(path):
    """Recursively yield every file found below a directory.

    This can be replaced with Path.walk() in python3.12.

    Args:
        path: The directory to traverse.

    Yields:
        The resolved path of each non-directory entry, descending into
        sub-directories as they are encountered.
    """
    for entry in Path(path).iterdir():
        if not entry.is_dir():
            yield entry.resolve()
        else:
            # Descend into the directory; the directory itself is not yielded.
            yield from _walk_files(entry)
|
|
|
|
|
|
def _relative_to_pwd(path: Path) -> Path:
    """Express a path relative to the working directory recorded in PWD.

    Args:
        path: The path to convert.

    Returns:
        The path relative to PWD when it is absolute; relative paths are
        returned unchanged.
    """
    return path.relative_to(PWD) if path.is_absolute() else path
|
|
|
|
|
|
def _get_type_hint(value, type_hint_globals, is_optional=True) -> str:
    """Resolve the type hint for value.

    Args:
        value: The type annotation as a str or actual types/aliases.
        type_hint_globals: The globals to use to resolving a type hint str.
        is_optional: Whether the type hint should be wrapped in Optional.

    Returns:
        The resolved type hint as a str.

    Raises:
        TypeError: If the value name is not visible in the type hint globals.
    """
    res = ""
    args = get_args(value)

    # NoneType is rendered as the literal `None`.
    if value is type(None):
        return "None"

    if rx_types.is_union(value):
        if type(None) in value.__args__:
            # A union containing None becomes Optional[...] of the remaining members.
            res_args = [
                _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg))
                for arg in value.__args__
                if arg is not type(None)
            ]
            # Sort so the rendered hint does not depend on member order.
            res_args.sort()
            if len(res_args) == 1:
                return f"Optional[{res_args[0]}]"
            else:
                res = f"Union[{', '.join(res_args)}]"
                return f"Optional[{res}]"

        # Union without None: render every member and join them.
        res_args = [
            _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg))
            for arg in value.__args__
        ]
        res_args.sort()
        return f"Union[{', '.join(res_args)}]"

    if args:
        # Parameterized type: Literal args are rendered with repr(), anything
        # else is resolved recursively (None args are dropped).
        inner_container_type_args = (
            sorted((repr(arg) for arg in args))
            if rx_types.is_literal(value)
            else [
                _get_type_hint(arg, type_hint_globals, is_optional=False)
                for arg in args
                if arg is not type(None)
            ]
        )

        # The container name must be resolvable inside the generated stub.
        if (
            value.__module__ not in ["builtins", "__builtins__"]
            and value.__name__ not in type_hint_globals
        ):
            raise TypeError(
                f"{value.__module__ + '.' + value.__name__} is not a default import, "
                "add it to DEFAULT_IMPORTS in pyi_generator.py"
            )

        res = f"{value.__name__}[{', '.join(inner_container_type_args)}]"

        if value.__name__ == "Var":
            # Flatten unions nested inside the Var so each member is listed once below.
            args = list(
                chain.from_iterable(
                    [get_args(arg) if rx_types.is_union(arg) else [arg] for arg in args]
                )
            )

            # For Var types, Union with the inner args so they can be passed directly.
            types = [res] + [
                _get_type_hint(arg, type_hint_globals, is_optional=False)
                for arg in args
                if arg is not type(None)
            ]
            if len(types) > 1:
                res = ", ".join(sorted(types))
                res = f"Union[{res}]"
    elif isinstance(value, str):
        # String annotation: evaluate it against the provided globals first.
        # NOTE(review): this uses eval on the annotation text; presumably the
        # annotations come from project source only — confirm they are trusted.
        ev = eval(value, type_hint_globals)
        if rx_types.is_optional(ev):
            return _get_type_hint(ev, type_hint_globals, is_optional=False)

        if rx_types.is_union(ev):
            res = [
                _get_type_hint(arg, type_hint_globals, rx_types.is_optional(arg))
                for arg in ev.__args__
            ]
            return f"Union[{', '.join(res)}]"
        # Only Var annotations are expanded; other strings are kept verbatim.
        res = (
            _get_type_hint(ev, type_hint_globals, is_optional=False)
            if ev.__name__ == "Var"
            else value
        )
    else:
        # Plain (non-generic) type: use its name.
        res = value.__name__
    # Wrap in Optional unless the hint already starts with it.
    if is_optional and not res.startswith("Optional"):
        res = f"Optional[{res}]"
    return res
|
|
|
|
|
|
def _generate_imports(
    typing_imports: Iterable[str],
) -> list[ast.ImportFrom | ast.Import]:
    """Build the import nodes placed at the top of a stub file.

    Args:
        typing_imports: The typing imports to include. Currently not consulted;
            the emitted imports come from DEFAULT_IMPORTS.

    Returns:
        One `from <module> import ...` node per DEFAULT_IMPORTS entry, followed
        by a plain `import reflex` node.
    """
    import_nodes: list[ast.ImportFrom | ast.Import] = []
    for module_name, imported_names in DEFAULT_IMPORTS.items():
        aliases = [ast.alias(name=imported) for imported in imported_names]
        import_nodes.append(ast.ImportFrom(module=module_name, names=aliases))
    import_nodes.append(ast.Import([ast.alias("reflex")]))
    return import_nodes
|
|
|
|
|
|
def _generate_docstrings(clzs: list[Type[Component]], props: list[str]) -> str:
    """Build the docstring of the generated create method.

    The source of every class is scanned up to its first function definition,
    collecting the comment lines that directly precede each prop. The collected
    text is then spliced into the docstring of the first class' create method,
    right before the line that documents ``**props``.

    Args:
        clzs: The classes to generate docstrings for.
        props: The props to generate docstrings for.

    Returns:
        The docstring for the create method.
    """
    documented = {}
    pending = []
    for component_cls in clzs:
        for src_line in inspect.getsource(component_cls).splitlines():
            # Props are declared before the methods, so stop at the first def.
            if re.search("def ", src_line):
                break

            # A blank line separates comment groups: drop what was collected so
            # commented-out props do not leak into the next prop's docs.
            if src_line == "":
                pending.clear()
                continue

            # Comment line: remember it (without any trailing noqa marker).
            if src_line.strip().startswith("#"):
                pending.append(src_line.partition(" # noqa")[0])
                continue

            found = re.search("\\w+:", src_line)
            if found is None:
                # Not a prop declaration.
                continue

            prop_name = found.group(0).strip(":")
            if prop_name in props:
                if not pending:  # do not include undocumented props
                    continue
                documented[prop_name] = [
                    text.strip().strip("#") for text in pending
                ]
            pending.clear()

    rendered = []
    for doc_line in (clzs[0].create.__doc__ or "").splitlines():
        if "**" in doc_line:
            # Insert the prop docs just above the `**props` entry, reusing its indent.
            prefix = doc_line.split("**")[0]
            rendered.extend(
                f"{prefix}{name}:{' '.join(text)}" for name, text in documented.items()
            )
        rendered.append(doc_line)
    return "\n".join(rendered)
|
|
|
|
|
|
def _extract_func_kwargs_as_ast_nodes(
    func: Callable,
    type_hint_globals: dict[str, Any],
) -> list[tuple[ast.arg, ast.Constant | None]]:
    """Collect the keyword-only arguments already declared on a function.

    Args:
        func: The function to extract kwargs from.
        type_hint_globals: The globals to use to resolving a type hint str.

    Returns:
        One ``(arg node, default node or None)`` pair per keyword-only argument,
        in declaration order.
    """
    spec = getfullargspec(func)
    declared_defaults = spec.kwonlydefaults
    extracted = []

    for kw_name in spec.kwonlyargs:
        arg_node = ast.arg(arg=kw_name)
        if kw_name in spec.annotations:
            arg_node.annotation = ast.Name(
                id=_get_type_hint(spec.annotations[kw_name], type_hint_globals)
            )
        # Arguments without a declared default keep None as their default node.
        default_node = (
            ast.Constant(value=declared_defaults[kw_name])
            if declared_defaults is not None and kw_name in declared_defaults
            else None
        )
        extracted.append((arg_node, default_node))
    return extracted
|
|
|
|
|
|
def _extract_class_props_as_ast_nodes(
    func: Callable,
    clzs: list[Type],
    type_hint_globals: dict[str, Any],
    extract_real_default: bool = False,
) -> list[tuple[ast.arg, ast.Constant | None]]:
    """Collect the props annotated on the given classes as keyword arguments.

    Args:
        func: The function that kwargs will be added to.
        clzs: The classes to extract props from.
        type_hint_globals: The globals to use to resolving a type hint str.
        extract_real_default: Whether to extract the real default value from the
            pydantic field definition.

    Returns:
        The list of props as ast arg nodes
    """
    spec = getfullargspec(func)
    seen_props = []
    prop_nodes = []
    for target_class in clzs:
        event_triggers = target_class().get_event_triggers()
        # Import from the target class to ensure type hints are resolvable.
        exec(f"from {target_class.__module__} import *", type_hint_globals)
        for name, value in target_class.__annotations__.items():
            # Skip names the function already declares, excluded/base props,
            # props seen on an earlier class, event triggers and ClassVars.
            if name in spec.kwonlyargs or name in EXCLUDED_PROPS:
                continue
            if name in seen_props or name in event_triggers:
                continue
            if isinstance(value, str) and "ClassVar" in value:
                continue
            seen_props.append(name)

            default = None
            if extract_real_default:
                # TODO: This is not currently working since the default is not type compatible
                # with the annotation in some cases.
                with contextlib.suppress(AttributeError, KeyError):
                    # Try to get default from pydantic field definition.
                    default = target_class.__fields__[name].default
                    if isinstance(default, Var):
                        default = default._decode()  # type: ignore

            annotated_arg = ast.arg(
                arg=name,
                annotation=ast.Name(id=_get_type_hint(value, type_hint_globals)),
            )
            prop_nodes.append((annotated_arg, ast.Constant(value=default)))
    return prop_nodes
|
|
|
|
|
|
def type_to_ast(typ, cls: type) -> ast.AST:
    """Converts any type annotation into its AST representation.

    Handles nested generic types, unions, etc.

    Args:
        typ: The type annotation to convert.
        cls: The class where the type annotation is used.

    Returns:
        The AST representation of the type annotation.
    """
    if typ is type(None):
        return ast.Name(id="None")

    origin = get_origin(typ)

    # Handle plain types (int, str, custom classes, etc.)
    if origin is None:
        if hasattr(typ, "__name__"):
            if typ.__module__.startswith("reflex."):
                # Types defined in the same module as `cls` are referenced by
                # bare name; other reflex types get their full dotted path.
                if typ.__module__ == cls.__module__:
                    return ast.Name(id=typ.__name__)
                return ast.Name(id=typ.__module__ + "." + typ.__name__)
            return ast.Name(id=typ.__name__)
        elif hasattr(typ, "_name"):
            return ast.Name(id=typ._name)
        return ast.Name(id=str(typ))

    # Get the base type name (List, Dict, Optional, etc.)
    base_name = origin._name if hasattr(origin, "_name") else origin.__name__

    # Get type arguments
    args = get_args(typ)

    # Handle empty type arguments
    if not args:
        return ast.Name(id=base_name)

    # Convert all type arguments recursively
    arg_nodes = [type_to_ast(arg, cls) for arg in args]

    # Special case for single-argument types (like List[T] or Optional[T])
    if len(arg_nodes) == 1:
        slice_value = arg_nodes[0]
    else:
        slice_value = ast.Tuple(elts=arg_nodes, ctx=ast.Load())

    # The slice is passed directly: the `ast.Index` wrapper is deprecated and
    # has simply returned the wrapped value since Python 3.9.
    return ast.Subscript(
        value=ast.Name(id=base_name), slice=slice_value, ctx=ast.Load()
    )
|
|
|
|
|
|
def _get_parent_imports(func):
|
|
_imports = {"reflex.vars": ["Var"]}
|
|
for type_hint in inspect.get_annotations(func).values():
|
|
try:
|
|
match = re.match(r"\w+\[([\w\d]+)\]", type_hint)
|
|
except TypeError:
|
|
continue
|
|
if match:
|
|
type_hint = match.group(1)
|
|
if type_hint in importlib.import_module(func.__module__).__dir__():
|
|
_imports.setdefault(func.__module__, []).append(type_hint)
|
|
return _imports
|
|
|
|
|
|
def _generate_component_create_functiondef(
    node: ast.FunctionDef | None,
    clz: type[Component] | type[SimpleNamespace],
    type_hint_globals: dict[str, Any],
) -> ast.FunctionDef:
    """Generate the create function definition for a Component.

    The generated signature contains, in order: the keyword-only arguments
    already declared on ``clz.create``, the props annotated on the class and its
    Component parents, and one optional keyword per event trigger.

    Args:
        node: The existing create functiondef node from the ast
        clz: The Component class to generate the create functiondef for.
        type_hint_globals: The globals to use to resolving a type hint str.

    Returns:
        The create functiondef node for the ast.

    Raises:
        TypeError: If clz is not a subclass of Component.
    """
    if not issubclass(clz, Component):
        raise TypeError(f"clz must be a subclass of Component, not {clz!r}")

    # add the imports needed by get_type_hint later
    type_hint_globals.update(
        {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
    )

    # When create is inherited from another module, make the names used in its
    # annotations resolvable too.
    if clz.__module__ != clz.create.__module__:
        _imports = _get_parent_imports(clz.create)
        for name, values in _imports.items():
            exec(f"from {name} import {','.join(values)}", type_hint_globals)

    kwargs = _extract_func_kwargs_as_ast_nodes(clz.create, type_hint_globals)

    # kwargs associated with props defined in the class and its parents
    all_classes = [c for c in clz.__mro__ if issubclass(c, Component)]
    prop_kwargs = _extract_class_props_as_ast_nodes(
        clz.create, all_classes, type_hint_globals
    )
    all_props = [arg[0].arg for arg in prop_kwargs]
    kwargs.extend(prop_kwargs)

    def union_of_prefixes(arg_sources: list) -> ast.Name:
        # A handler may accept any leading subset of the trigger's arguments, so
        # emit one EventType per prefix length (0..len) and union them.
        variants = [
            f"EventType[[{', '.join(arg_sources[:count])}], BASE_STATE]"
            for count in range(len(arg_sources) + 1)
        ]
        return ast.Name(id=f"Union[{', '.join(variants)}]")

    def figure_out_return_type(annotation: Any):
        # No return annotation on the event spec: accept any handler signature.
        if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
            return ast.Name(id="EventType[..., BASE_STATE]")

        if not isinstance(annotation, str) and get_origin(annotation) is tuple:
            arguments = get_args(annotation)

            # Unwrap Var[T] to T.
            arguments_without_var = [
                get_args(argument)[0] if get_origin(argument) == Var else argument
                for argument in arguments
            ]

            # Convert each argument type to its AST representation
            type_args = [type_to_ast(arg, cls=clz) for arg in arguments_without_var]

            return union_of_prefixes([ast.unparse(arg) for arg in type_args])

        if isinstance(annotation, str) and annotation.startswith("Tuple["):
            inside_of_tuple = annotation.removeprefix("Tuple[").removesuffix("]")

            if inside_of_tuple == "()":
                return ast.Name(id="EventType[[], BASE_STATE]")

            # Split on top-level commas only (commas inside brackets belong to
            # a nested generic).
            arguments = [""]
            bracket_count = 0

            for char in inside_of_tuple:
                if char == "[":
                    bracket_count += 1
                elif char == "]":
                    bracket_count -= 1

                if char == "," and bracket_count == 0:
                    arguments.append("")
                else:
                    arguments[-1] += char

            arguments = [argument.strip() for argument in arguments]

            arguments_without_var = [
                argument.removeprefix("Var[").removesuffix("]")
                if argument.startswith("Var[")
                else argument
                for argument in arguments
            ]

            return union_of_prefixes(arguments_without_var)
        return ast.Name(id="EventType[..., BASE_STATE]")

    event_triggers = clz().get_event_triggers()

    def trigger_annotation(trigger: str) -> ast.Subscript:
        # A trigger maps either to a single event spec or to a sequence of
        # specs; a sequence becomes a Union of the individual handler types.
        event_specs = event_triggers[trigger]
        if isinstance(event_specs, Sequence):
            handler_type = ast.Subscript(
                ast.Name("Union"),
                ast.Tuple(
                    [
                        figure_out_return_type(
                            inspect.signature(event_spec).return_annotation
                        )
                        for event_spec in event_specs
                    ]
                ),
            )
        else:
            handler_type = figure_out_return_type(
                inspect.signature(event_specs).return_annotation
            )
        return ast.Subscript(
            ast.Name("Optional"),
            ast.Name(id=ast.unparse(handler_type)),
        )

    # event handler kwargs
    kwargs.extend(
        (
            ast.arg(arg=trigger, annotation=trigger_annotation(trigger)),
            ast.Constant(value=None),
        )
        for trigger in sorted(event_triggers)
    )

    logger.debug(f"Generated {clz.__name__}.create method with {len(kwargs)} kwargs")
    create_args = ast.arguments(
        args=[ast.arg(arg="cls")],
        posonlyargs=[],
        vararg=ast.arg(arg="children"),
        kwonlyargs=[arg[0] for arg in kwargs],
        kw_defaults=[arg[1] for arg in kwargs],
        kwarg=ast.arg(arg="props"),
        defaults=[],
    )

    definition = ast.FunctionDef(
        name="create",
        args=create_args,
        body=[
            ast.Expr(
                value=ast.Constant(
                    value=_generate_docstrings(
                        all_classes, [*all_props, *event_triggers]
                    )
                ),
            ),
            # `ast.Ellipsis` was removed in Python 3.14; Constant(...) is the
            # equivalent node on every supported version.
            ast.Expr(
                value=ast.Constant(value=...),
            ),
        ],
        decorator_list=[
            ast.Name(id="overload"),
            *(
                node.decorator_list
                if node is not None
                else [ast.Name(id="classmethod")]
            ),
        ],
        lineno=node.lineno if node is not None else None,
        returns=ast.Constant(value=clz.__name__),
    )
    return definition
|
|
|
|
|
|
def _generate_staticmethod_call_functiondef(
    node: ast.FunctionDef | None,
    clz: type[Component] | type[SimpleNamespace],
    type_hint_globals: dict[str, Any],
) -> ast.FunctionDef | None:
    """Generate a static ``__call__`` functiondef mirroring ``clz.__call__``.

    Args:
        node: The existing node from the ast; only its line number is reused.
        clz: The class whose ``__call__`` signature and docstring are copied.
        type_hint_globals: The globals to use to resolving a type hint str.

    Returns:
        The ``__call__`` functiondef node, decorated with ``staticmethod``.
    """
    fullspec = getfullargspec(clz.__call__)

    # Every positional argument is expected to carry an annotation.
    positional_args = []
    for name in fullspec.args:
        anno = fullspec.annotations[name]
        hint = _get_type_hint(
            anno,
            type_hint_globals,
            is_optional=rx_types.is_optional(anno),
        )
        positional_args.append(ast.arg(name, annotation=ast.Name(id=hint)))

    if fullspec.defaults:
        default_nodes = [ast.Constant(value=default) for default in fullspec.defaults]
    else:
        default_nodes = []

    call_args = ast.arguments(
        args=positional_args,
        posonlyargs=[],
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=ast.arg(arg="props"),
        defaults=default_nodes,
    )

    # Stub body: the original docstring followed by `...`.
    stub_body = [
        ast.Expr(value=ast.Constant(value=clz.__call__.__doc__)),
        ast.Expr(
            value=ast.Constant(...),
        ),
    ]
    return_hint = _get_type_hint(
        typing.get_type_hints(clz.__call__).get("return", None),
        type_hint_globals,
    )
    return ast.FunctionDef(
        name="__call__",
        args=call_args,
        body=stub_body,
        decorator_list=[ast.Name(id="staticmethod")],
        lineno=node.lineno if node is not None else None,
        returns=ast.Constant(value=return_hint),
    )
|
|
|
|
|
|
def _generate_namespace_call_functiondef(
    node: ast.ClassDef | None,
    clz_name: str,
    classes: dict[str, type[Component] | type[SimpleNamespace]],
    type_hint_globals: dict[str, Any],
) -> ast.FunctionDef | None:
    """Generate the ``__call__`` functiondef of a SimpleNamespace class.

    Args:
        node: The existing __call__ classdef parent node from the ast
        clz_name: The name of the SimpleNamespace class to generate the __call__ functiondef for.
        classes: Map name to actual class definition.
        type_hint_globals: The globals to use to resolving a type hint str.

    Returns:
        The ``__call__`` functiondef node, or None when ``__call__`` is a bound
        method other than a component's ``create``.
    """
    # add the imports needed by get_type_hint later
    typing_names = {name: getattr(typing, name) for name in DEFAULT_TYPING_IMPORTS}
    type_hint_globals.update(typing_names)

    namespace_cls = classes[clz_name]

    # A plain function (no bound `__self__`) is stubbed from its own signature.
    if not hasattr(namespace_cls.__call__, "__self__"):
        return _generate_staticmethod_call_functiondef(node, namespace_cls, type_hint_globals)  # type: ignore

    # Determine which class is wrapped by the namespace __call__ method
    component_clz = namespace_cls.__call__.__self__

    if namespace_cls.__call__.__func__.__name__ != "create":
        return None

    call_def = _generate_component_create_functiondef(
        node=None,
        clz=component_clz,  # type: ignore
        type_hint_globals=type_hint_globals,
    )
    call_def.name = "__call__"

    # Turn the definition into a staticmethod
    del call_def.args.args[0]  # remove `cls` arg
    call_def.decorator_list = [ast.Name(id="staticmethod")]

    return call_def
|
|
|
|
|
|
class StubGenerator(ast.NodeTransformer):
    """A node transformer that will generate the stubs for a given module."""

    def __init__(
        self, module: ModuleType, classes: dict[str, Type[Component | SimpleNamespace]]
    ):
        """Initialize the stub generator.

        Args:
            module: The actual module object module to generate stubs for.
            classes: The actual Component class objects to generate stubs for.
        """
        super().__init__()
        # Dict mapping class name to actual class object.
        self.classes = classes
        # Track the last class node that was visited.
        self.current_class = None
        # These imports will be included in the AST of stub files.
        self.typing_imports = DEFAULT_TYPING_IMPORTS.copy()
        # Whether those typing imports have been inserted yet.
        self.inserted_imports = False
        # Collected import statements from the module.
        self.import_statements: list[str] = []
        # This dict is used when evaluating type hints.
        self.type_hint_globals = module.__dict__.copy()

    @staticmethod
    def _remove_docstring(
        node: ast.Module | ast.ClassDef | ast.FunctionDef,
    ) -> ast.Module | ast.ClassDef | ast.FunctionDef:
        """Removes any docstring in place.

        A leading expression statement holding a constant is treated as the
        docstring.

        Args:
            node: The node to remove the docstring from.

        Returns:
            The modified node.
        """
        if (
            node.body
            and isinstance(node.body[0], ast.Expr)
            and isinstance(node.body[0].value, ast.Constant)
        ):
            node.body.pop(0)
        return node

    def _current_class_is_component(self) -> bool:
        """Check if the current class is a Component.

        Returns:
            Whether the current class is a Component.
        """
        return (
            self.current_class is not None
            and self.current_class in self.classes
            and issubclass(self.classes[self.current_class], Component)
        )

    def visit_Module(self, node: ast.Module) -> ast.Module:
        """Visit a Module node and remove docstring from body.

        Args:
            node: The Module node to visit.

        Returns:
            The modified Module node.
        """
        self.generic_visit(node)
        return self._remove_docstring(node)  # type: ignore

    def visit_Import(
        self, node: ast.Import | ast.ImportFrom
    ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom]:
        """Collect import statements from the module.

        If this is the first import statement, insert the typing imports before it.

        Args:
            node: The import node to visit.

        Returns:
            The modified import node(s).
        """
        self.import_statements.append(ast.unparse(node))
        if not self.inserted_imports:
            self.inserted_imports = True
            default_imports = _generate_imports(self.typing_imports)
            self.import_statements.extend(ast.unparse(i) for i in default_imports)
            return [*default_imports, node]
        return node

    def visit_ImportFrom(
        self, node: ast.ImportFrom
    ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom] | None:
        """Visit an ImportFrom node.

        Remove any `from __future__ import *` statements, and hand off to visit_Import.

        Args:
            node: The ImportFrom node to visit.

        Returns:
            The modified ImportFrom node.
        """
        if node.module == "__future__":
            return None  # ignore __future__ imports
        return self.visit_Import(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a ClassDef node.

        Remove all assignments in the class body, and add a create functiondef
        if one does not exist.

        Args:
            node: The ClassDef node to visit.

        Returns:
            The modified ClassDef node.
        """
        # Execute the imports collected so far so that names used in type hints
        # can be resolved from type_hint_globals.
        exec("\n".join(self.import_statements), self.type_hint_globals)
        self.current_class = node.name
        self._remove_docstring(node)

        # Define `__call__` as a real function so the docstring appears in the stub.
        call_definition = None
        for child in node.body[:]:
            found_call = False
            if isinstance(child, ast.Assign):
                for target in child.targets[:]:
                    if isinstance(target, ast.Name) and target.id == "__call__":
                        child.targets.remove(target)
                        found_call = True
                if not found_call:
                    continue
                # Drop the assignment entirely once it has no targets left.
                if not child.targets[:]:
                    node.body.remove(child)
                call_definition = _generate_namespace_call_functiondef(
                    node,
                    self.current_class,
                    self.classes,
                    type_hint_globals=self.type_hint_globals,
                )
                break

        self.generic_visit(node)  # Visit child nodes.

        if (
            not any(
                isinstance(child, ast.FunctionDef) and child.name == "create"
                for child in node.body
            )
            and self._current_class_is_component()
        ):
            # Add a new .create FunctionDef since one does not exist.
            node.body.append(
                _generate_component_create_functiondef(
                    node=None,
                    clz=self.classes[self.current_class],
                    type_hint_globals=self.type_hint_globals,
                )
            )
        if call_definition is not None:
            node.body.append(call_definition)
        if not node.body:
            # We should never return an empty body.
            node.body.append(ast.Expr(value=ast.Constant(value=...)))
        self.current_class = None
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a FunctionDef node.

        Special handling for `.create` functions to add type hints for all props
        defined on the component class.

        Remove all private functions and blank out the function body of the
        remaining public functions.

        Args:
            node: The FunctionDef node to visit.

        Returns:
            The modified FunctionDef node (or None).
        """
        if node.name == "create" and self.current_class in self.classes:
            node = _generate_component_create_functiondef(
                node, self.classes[self.current_class], self.type_hint_globals
            )
        else:
            if node.name.startswith("_") and node.name != "__call__":
                return None  # remove private methods

            # Blank out the function body for public functions. AST nodes
            # compare by identity, so the body is always replaced (the previous
            # `!=` check against a fresh node could never be false).
            node.body = [ast.Expr(value=ast.Constant(value=...))]
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.Assign | None:
        """Remove non-annotated assignment statements.

        Args:
            node: The Assign node to visit.

        Returns:
            The modified Assign node (or None).
        """
        # Special case for assignments to `typing.Any` as fallback.
        if (
            node.value is not None
            and isinstance(node.value, ast.Name)
            and node.value.id == "Any"
        ):
            return node

        if self._current_class_is_component():
            # Remove non-annotated assignments in Component classes
            return None

        # remove dunder method assignments for lazy_loader.attach
        for target in node.targets:
            if isinstance(target, ast.Tuple):
                for name in target.elts:
                    if isinstance(name, ast.Name) and name.id.startswith("_"):
                        return None

        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign | None:
        """Visit an AnnAssign node (Annotated assignment).

        Remove private target and remove the assignment value in the stub.

        Args:
            node: The AnnAssign node to visit.

        Returns:
            The modified AnnAssign node (or None).
        """
        # skip ClassVars
        if (
            isinstance(node.annotation, ast.Subscript)
            and isinstance(node.annotation.value, ast.Name)
            and node.annotation.value.id == "ClassVar"
        ):
            return node
        if isinstance(node.target, ast.Name) and node.target.id.startswith("_"):
            return None
        if self.current_class in self.classes:
            # Remove annotated assignments in Component classes (props)
            return None
        # Blank out assignments in type stubs.
        node.value = None
        return node
|
|
|
|
|
|
class InitStubGenerator(StubGenerator):
    """A node transformer that will generate the stubs for a given init file."""

    def visit_Import(
        self, node: ast.Import | ast.ImportFrom
    ) -> ast.Import | ast.ImportFrom | list[ast.Import | ast.ImportFrom]:
        """Collect import statements from the init module.

        Unlike the parent implementation, the node is passed through as-is: no
        default imports are inserted and nothing is recorded.

        Args:
            node: The import node to visit.

        Returns:
            The modified import node(s).
        """
        return [node]
|
|
|
|
|
|
class PyiGenerator:
    """A .pyi file generator that will scan all defined Component in Reflex and
    generate the appropriate stub.
    """

    # Class-level declarations are kept for backward compatibility; __init__
    # gives every instance its own mutable containers.
    modules: list = []
    root: str = ""
    current_module: Any = {}
    written_files: list[str] = []

    def __init__(self) -> None:
        """Initialize per-instance state so instances do not share lists."""
        self.modules = []
        self.current_module = {}
        self.written_files = []

    def _write_pyi_file(self, module_path: Path, source: str):
        """Write the stub next to a module, prefixed with a generated-file header.

        Args:
            module_path: The path of the source module; the stub uses the same
                path with a `.pyi` suffix.
            source: The stub source code to write after the header.
        """
        relpath = str(_relative_to_pwd(module_path)).replace("\\", "/")
        header_lines = [
            f'"""Stub file for {relpath}"""',
            "# ------------------- DO NOT EDIT ----------------------",
            "# This file was generated by `reflex/utils/pyi_generator.py`!",
            "# ------------------------------------------------------",
            "",
        ]
        pyi_content = "\n".join(header_lines) + source

        pyi_path = module_path.with_suffix(".pyi")
        pyi_path.write_text(pyi_content)
        logger.info(f"Wrote {relpath}")

    def _get_init_lazy_imports(self, mod, new_tree):
        """Render the lazy-loaded imports of an init module as explicit imports.

        Args:
            mod: The init module object; its `_SUBMODULES`, `_SUBMOD_ATTRS` and
                `_PYRIGHT_IGNORE_IMPORTS` attributes are read when present.
            new_tree: The transformed AST of the init module, appended after
                the generated imports.

        Returns:
            The stub source text, or None when the module declares neither
            submodules nor submodule attributes.
        """
        # retrieve the _SUBMODULES and _SUBMOD_ATTRS from an init file if present.
        sub_mods = getattr(mod, "_SUBMODULES", None)
        sub_mod_attrs = getattr(mod, "_SUBMOD_ATTRS", None)
        pyright_ignore_imports = getattr(mod, "_PYRIGHT_IGNORE_IMPORTS", [])

        if not sub_mods and not sub_mod_attrs:
            return None
        sub_mods_imports = []
        sub_mod_attrs_imports = []

        if sub_mods:
            sub_mods_imports = [
                f"from . import {sub_mod} as {sub_mod}" for sub_mod in sorted(sub_mods)
            ]
            sub_mods_imports.append("")

        if sub_mod_attrs:
            # Invert {module path: [attrs]} into {attr: module path}.
            sub_mod_attrs = {
                attr: path for path, attrs in sub_mod_attrs.items() for attr in attrs
            }
            # construct the import statement and handle special cases for aliases
            # (a tuple attr is `(name, alias)`).
            sub_mod_attrs_imports = [
                f"from .{path} import {attr if not isinstance(attr, tuple) else attr[0]} as {attr if not isinstance(attr, tuple) else attr[1]}"
                + (
                    " # type: ignore"
                    if attr in pyright_ignore_imports
                    else " # noqa"  # ignore ruff formatting here for cases like rx.list.
                    if isinstance(attr, tuple)
                    else ""
                )
                for attr, path in sub_mod_attrs.items()
            ]
            sub_mod_attrs_imports.append("")

        text = "\n" + "\n".join([*sub_mods_imports, *sub_mod_attrs_imports])
        text += ast.unparse(new_tree) + "\n"
        return text

    def _scan_file(self, module_path: Path) -> str | None:
        """Import a module and write its stub when it defines relevant classes.

        Args:
            module_path: The path of the python module to scan.

        Returns:
            The resolved path of the written stub, or None when nothing was
            written.
        """
        # Turn the file path into a dotted module name relative to PWD.
        module_import = (
            _relative_to_pwd(module_path)
            .with_suffix("")
            .as_posix()
            .replace("/", ".")
            .replace("\\", ".")
        )
        module = importlib.import_module(module_import)
        logger.debug(f"Read {module_path}")
        # Component / SimpleNamespace subclasses defined in this very module.
        class_names = {
            name: obj
            for name, obj in vars(module).items()
            if inspect.isclass(obj)
            and (issubclass(obj, Component) or issubclass(obj, SimpleNamespace))
            and obj != Component
            and inspect.getmodule(obj) == module
        }
        is_init_file = _relative_to_pwd(module_path).name == "__init__.py"
        if not class_names and not is_init_file:
            return None

        if is_init_file:
            new_tree = InitStubGenerator(module, class_names).visit(
                ast.parse(inspect.getsource(module))
            )
            init_imports = self._get_init_lazy_imports(module, new_tree)
            if not init_imports:
                return None
            self._write_pyi_file(module_path, init_imports)
        else:
            new_tree = StubGenerator(module, class_names).visit(
                ast.parse(inspect.getsource(module))
            )
            self._write_pyi_file(module_path, ast.unparse(new_tree))
        return str(module_path.with_suffix(".pyi").resolve())

    def _scan_files_multiprocess(self, files: list[Path]):
        """Scan files in a process pool and record the stubs that were written.

        Args:
            files: The module paths to scan.
        """
        with Pool(processes=cpu_count()) as pool:
            self.written_files.extend(f for f in pool.map(self._scan_file, files) if f)

    def _scan_files(self, files: list[Path]):
        """Scan files one by one and record the stubs that were written.

        Args:
            files: The module paths to scan.
        """
        for file in files:
            pyi_path = self._scan_file(file)
            if pyi_path:
                self.written_files.append(pyi_path)

    def scan_all(self, targets, changed_files: list[Path] | None = None):
        """Scan all targets for class inheriting Component and generate the .pyi files.

        Args:
            targets: the list of file/folders to scan.
            changed_files (optional): the list of changed files since the last run.
        """
        file_targets = []
        for target in targets:
            target_path = Path(target)
            if (
                target_path.is_file()
                and target_path.suffix == ".py"
                and target_path.name not in EXCLUDED_FILES
            ):
                file_targets.append(target_path)
                continue
            if not target_path.is_dir():
                continue
            for file_path in _walk_files(target_path):
                relative = _relative_to_pwd(file_path)
                if relative.name in EXCLUDED_FILES or file_path.suffix != ".py":
                    continue
                if changed_files is not None and relative not in changed_files:
                    continue
                file_targets.append(file_path)

        # check if pyi changed but not the source
        if changed_files is not None:
            for changed_file in changed_files:
                if changed_file.suffix != ".pyi":
                    continue
                py_file_path = changed_file.with_suffix(".py")
                # The source module is gone: remove the orphaned stub.
                if not py_file_path.exists() and changed_file.exists():
                    changed_file.unlink()
                if py_file_path in file_targets:
                    continue
                subprocess.run(["git", "checkout", changed_file])

        if cpu_count() == 1 or len(file_targets) < 5:
            self._scan_files(file_targets)
        else:
            self._scan_files_multiprocess(file_targets)

        # Nothing was generated: skip post-processing. Calling ruff without any
        # path would make it operate on the whole working directory instead.
        if not self.written_files:
            return

        # Fix generated pyi files with ruff.
        subprocess.run(["ruff", "format", *self.written_files])
        subprocess.run(["ruff", "check", "--fix", *self.written_files])

        # For some reason, we need to format the __init__.pyi files again after fixing...
        # Compare the file name rather than a "/"-prefixed substring so this
        # also matches paths using Windows separators.
        init_files = [f for f in self.written_files if Path(f).name == "__init__.pyi"]
        if init_files:
            subprocess.run(["ruff", "format", *init_files])

        # Post-process the generated pyi files to add hacky type: ignore comments
        for file_path in self.written_files:
            with FileInput(file_path, inplace=True) as f:
                for line in f:
                    # Hack due to ast not supporting comments in the tree.
                    if (
                        "def create(" in line
                        or "Var[Figure]" in line
                        or "Var[Template]" in line
                    ):
                        line = line.rstrip() + " # type: ignore\n"
                    # With inplace=True, print writes the line back to the file.
                    print(line, end="")
|