"""The pyi generator module."""

import importlib
import inspect
import os
import re
import sys
from inspect import getfullargspec
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, get_args  # NOQA

import black

from reflex.components.component import Component
from reflex.vars import Var

ruff_dont_remove = [Var, Optional, Dict, List]

EXCLUDED_FILES = [
    "__init__.py",
    "component.py",
    "bare.py",
    "foreach.py",
    "cond.py",
    "multiselect.py",
]

DEFAULT_TYPING_IMPORTS = {"overload", "Optional", "Union"}


def _get_type_hint(value, top_level=True, no_union=False):
    res = ""
    args = get_args(value)
    if args:
        res = f"{value.__name__}[{', '.join([_get_type_hint(arg, top_level=False) for arg in args if arg is not type(None)])}]"

        if value.__name__ == "Var":
            types = [res] + [
                _get_type_hint(arg, top_level=False)
                for arg in args
                if arg is not type(None)
            ]
            if len(types) > 1 and not no_union:
                res = ", ".join(types)
                res = f"Union[{res}]"
    elif isinstance(value, str):
        ev = eval(value)
        res = _get_type_hint(ev, top_level=False) if ev.__name__ == "Var" else value
    else:
        res = value.__name__
    if top_level and not res.startswith("Optional"):
        res = f"Optional[{res}]"
    return res


def _get_typing_import(_module):
    src = [
        line
        for line in inspect.getsource(_module).split("\n")
        if line.startswith("from typing")
    ]
    if len(src):
        return set(src[0].rpartition("from typing import ")[-1].split(", "))
    return set()


def _get_var_definition(_module, _var_name):
    return [
        line.split(" = ")[0]
        for line in inspect.getsource(_module).splitlines()
        if line.startswith(_var_name)
    ]


class PyiGenerator:
    """A .pyi file generator that will scan all defined Component in Reflex and
    generate the approriate stub.
    """

    modules: list = []
    root: str = ""
    current_module: Any = {}
    default_typing_imports: set = DEFAULT_TYPING_IMPORTS

    def _generate_imports(self, variables, classes):
        variables_imports = {
            type(_var) for _, _var in variables if isinstance(_var, Component)
        }
        bases = {
            base
            for _, _class in classes
            for base in _class.__bases__
            if inspect.getmodule(base) != self.current_module
        } | variables_imports
        bases.add(Component)
        typing_imports = self.default_typing_imports | _get_typing_import(
            self.current_module
        )
        bases = sorted(bases, key=lambda base: base.__name__)
        return [
            f"from typing import {','.join(sorted(typing_imports))}",
            *[f"from {base.__module__} import {base.__name__}" for base in bases],
            "from reflex.vars import Var, BaseVar, ComputedVar",
            "from reflex.event import EventHandler, EventChain, EventSpec",
        ]

    def _generate_pyi_class(self, _class: type[Component]):
        create_spec = getfullargspec(_class.create)
        lines = [
            "",
            f"class {_class.__name__}({', '.join([base.__name__ for base in _class.__bases__])}):",
        ]
        definition = f"    @overload\n    @classmethod\n    def create(cls, *children, "

        for kwarg in create_spec.kwonlyargs:
            if kwarg in create_spec.annotations:
                definition += f"{kwarg}: {_get_type_hint(create_spec.annotations[kwarg])} = None, "
            else:
                definition += f"{kwarg}, "

        for name, value in _class.__annotations__.items():
            if name in create_spec.kwonlyargs:
                continue
            definition += f"{name}: {_get_type_hint(value)} = None, "

        for trigger in sorted(_class().get_event_triggers().keys()):
            definition += f"{trigger}: Optional[Union[EventHandler, EventSpec, List, function, BaseVar]] = None, "

        definition = definition.rstrip(", ")
        definition += f", **props) -> '{_class.__name__}': # type: ignore\n"

        definition += self._generate_docstrings(_class, _class.__annotations__.keys())
        lines.append(definition)
        lines.append("        ...")
        return lines

    def _generate_docstrings(self, _class, _props):
        props_comments = {}
        comments = []
        for _i, line in enumerate(inspect.getsource(_class).splitlines()):
            reached_functions = re.search("def ", line)
            if reached_functions:
                # We've reached the functions, so stop.
                break

            # Get comments for prop
            if line.strip().startswith("#"):
                comments.append(line)
                continue

            # Check if this line has a prop.
            match = re.search("\\w+:", line)
            if match is None:
                # This line doesn't have a var, so continue.
                continue

            # Get the prop.
            prop = match.group(0).strip(":")
            if prop in _props:
                # This isn't a prop, so continue.
                props_comments[prop] = "\n".join(
                    [comment.strip().strip("#") for comment in comments]
                )
                comments.clear()
                continue
        new_docstring = []
        for i, line in enumerate(_class.create.__doc__.splitlines()):
            if i == 0:
                new_docstring.append(" " * 8 + '"""' + line)
            else:
                new_docstring.append(line)
            if "*children" in line:
                for nline in [
                    f"{line.split('*')[0]}{n}:{c}" for n, c in props_comments.items()
                ]:
                    new_docstring.append(nline)
        new_docstring += ['"""']
        return "\n".join(new_docstring)

    def _generate_pyi_variable(self, _name, _var):
        return _get_var_definition(self.current_module, _name)

    def _generate_function(self, _name, _func):
        import textwrap

        # Don't generate indented functions.
        source = inspect.getsource(_func)
        if textwrap.dedent(source) != source:
            return []

        definition = "".join([line for line in source.split(":\n")[0].split("\n")])
        return [f"{definition}:", "    ..."]

    def _write_pyi_file(self, variables, functions, classes):
        pyi_content = [
            f'"""Stub file for {self.current_module_path}.py"""',
            "# ------------------- DO NOT EDIT ----------------------",
            "# This file was generated by `scripts/pyi_generator.py`!",
            "# ------------------------------------------------------",
            "",
        ]
        pyi_content.extend(self._generate_imports(variables, classes))

        for _name, _var in variables:
            pyi_content.extend(self._generate_pyi_variable(_name, _var))

        for _fname, _func in functions:
            pyi_content.extend(self._generate_function(_fname, _func))

        for _, _class in classes:
            pyi_content.extend(self._generate_pyi_class(_class))

        pyi_filename = f"{self.current_module_path}.pyi"
        pyi_path = os.path.join(self.root, pyi_filename)

        with open(pyi_path, "w") as pyi_file:
            pyi_file.write("\n".join(pyi_content))
        black.format_file_in_place(
            src=Path(pyi_path),
            fast=True,
            mode=black.FileMode(),
            write_back=black.WriteBack.YES,
        )

    def _scan_file(self, file):
        self.current_module_path = os.path.splitext(file)[0]
        module_import = os.path.splitext(os.path.join(self.root, file))[0].replace(
            "/", "."
        )

        self.current_module = importlib.import_module(module_import)

        local_variables = [
            (name, obj)
            for name, obj in vars(self.current_module).items()
            if not name.startswith("__")
            and not inspect.isclass(obj)
            and not inspect.isfunction(obj)
        ]

        functions = [
            (name, obj)
            for name, obj in vars(self.current_module).items()
            if not name.startswith("__")
            and (
                not inspect.getmodule(obj)
                or inspect.getmodule(obj) == self.current_module
            )
            and inspect.isfunction(obj)
        ]

        class_names = [
            (name, obj)
            for name, obj in vars(self.current_module).items()
            if inspect.isclass(obj)
            and issubclass(obj, Component)
            and obj != Component
            and inspect.getmodule(obj) == self.current_module
        ]
        if not class_names:
            return
        print(f"Parsed {file}: Found {[n for n,_ in class_names]}")
        self._write_pyi_file(local_variables, functions, class_names)

    def _scan_folder(self, folder):
        for root, _, files in os.walk(folder):
            self.root = root
            for file in files:
                if file in EXCLUDED_FILES:
                    continue
                if file.endswith(".py"):
                    self._scan_file(file)

    def scan_all(self, targets):
        """Scan all targets for class inheriting Component and generate the .pyi files.

        Args:
            targets: the list of file/folders to scan.
        """
        for target in targets:
            if target.endswith(".py"):
                self.root, _, file = target.rpartition("/")
                self._scan_file(file)
            else:
                self._scan_folder(target)


if __name__ == "__main__":
    targets = sys.argv[1:] if len(sys.argv) > 1 else ["reflex/components"]
    print(f"Running .pyi generator for {targets}")
    gen = PyiGenerator()
    gen.scan_all(targets)