#!/usr/bin/env python3

# Here's a good link in case you're interested in learning more
# about the current deficiencies of Rust's code coverage story:
# https://github.com/rust-lang/rust/issues?q=is%3Aissue+is%3Aopen+instrument-coverage+label%3AA-code-coverage
#
# Also a couple of inspirational tools which I deliberately ended up not using:
#  * https://github.com/mozilla/grcov
#  * https://github.com/taiki-e/cargo-llvm-cov
#  * https://github.com/llvm/llvm-project/tree/main/llvm/test/tools/llvm-cov

import argparse
import hashlib
import json
import os
import shutil
import socket
import subprocess
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import Any, Dict, Iterable, Iterator, List, Optional


def file_mtime_or_zero(path: Path) -> int:
    """Return the modification time of `path` in nanoseconds, or 0 if the file is missing."""
    try:
        stat_result = path.stat()
    except FileNotFoundError:
        # A missing file compares as older than any existing one
        return 0
    else:
        return stat_result.st_mtime_ns


def hash_strings(iterable: Iterable[str]) -> str:
    """Return the hex SHA-1 digest of all strings of `iterable` glued together (UTF-8)."""
    glued = ''.join(iterable)
    hasher = hashlib.sha1()
    hasher.update(glued.encode('utf-8'))
    return hasher.hexdigest()


def intersperse(sep: Any, iterable: Iterable[Any]) -> Iterator[Any]:
    """Lazily yield the items of `iterable` with `sep` placed between each adjacent pair."""
    for position, item in enumerate(iterable):
        # Every item but the very first one is preceded by a separator
        if position:
            yield sep
        yield item


def find_demangler(demangler: Optional[Path] = None) -> Path:
    """Pick a symbol demangler that `shutil.which` is able to locate.

    If `demangler` is given, it is the only candidate; otherwise a built-in
    list of well-known tools is probed in order.

    Raises:
        Exception: if none of the candidates can be found.
    """
    known_tools = ['c++filt', 'rustfilt', 'llvm-cxxfilt']

    # Explicit argument has precedence over (and replaces) `known_tools`
    candidates = [demangler] if demangler else [Path(tool) for tool in known_tools]

    found = next((exe for exe in candidates if shutil.which(exe)), None)
    if found is not None:
        return found

    suggestions = ', '.join(known_tools)
    raise Exception(
        'Failed to find symbol demangler. '
        'Please install it or provide another tool '
        f"(e.g. {suggestions})"
    )


class Cargo:
    """Helper for `cargo` / `rustc` invocations executed from the directory `cwd`."""

    def __init__(self, cwd: Path) -> None:
        self.cwd = cwd
        # $CARGO_TARGET_DIR wins; otherwise fall back to `<cwd>/target`
        configured_target = os.environ.get('CARGO_TARGET_DIR', cwd / 'target')
        self.target_dir = Path(configured_target).resolve()
        # Computed on first access of `rustlib_dir`
        self._rustlib_dir: Optional[Path] = None

    @property
    def rustlib_dir(self) -> Path:
        """Parent of the directory printed by `rustc --print=target-libdir` (cached)."""
        if self._rustlib_dir:
            return self._rustlib_dir

        libdir = subprocess.check_output(
            ['rustc', '--print=target-libdir'],
            cwd=self.cwd,
            text=True,
        )
        # The raw output is used as is; taking `.parent` drops its last path component
        self._rustlib_dir = Path(libdir).parent
        return self._rustlib_dir

    def binaries(self, profile: str) -> List[str]:
        """List executables known to cargo: test binaries plus existing `bin` targets.

        Test binaries are taken from the json messages of `cargo test --no-run`;
        `bin` targets come from `cargo metadata` and are only included if the
        file `<target_dir>/<profile>/<name>` exists.
        """
        # This will emit json messages containing test binaries names
        test_cmd = ['cargo', 'test', '--no-run', '--message-format=json']
        test_env = dict(os.environ, PROFILE=profile)
        messages = subprocess.check_output(test_cmd, cwd=self.cwd, env=test_env, text=True)

        found = []
        for raw_message in messages.splitlines(keepends=False):
            test_binary = json.loads(raw_message).get('executable')
            if test_binary:
                found.append(test_binary)

        # Metadata contains crate names, which can be used
        # to recover names of executables, e.g. `pageserver`
        metadata_cmd = ['cargo', 'metadata', '--format-version=1', '--no-deps']
        metadata = json.loads(subprocess.check_output(metadata_cmd, cwd=self.cwd))

        for package in metadata.get('packages', []):
            for target in package.get('targets', []):
                if 'bin' not in target['kind']:
                    continue
                candidate = self.target_dir / profile / target['name']
                if candidate.exists():
                    found.append(str(candidate))

        return found


@dataclass
class LLVM:
    """Runs `llvm-profdata` and `llvm-cov`, locating them via `cargo.rustlib_dir` or $PATH."""

    cargo: Cargo

    def resolve_tool(self, name: str) -> str:
        """Return the command to use for the LLVM tool `name`.

        The copy at `<rustlib_dir>/bin/<name>` is preferred if it exists;
        otherwise the bare name is returned, provided it is found in $PATH.

        Raises:
            Exception: if the tool is available in neither location.
        """
        exe = self.cargo.rustlib_dir / 'bin' / name
        if exe.exists():
            return str(exe)

        if not shutil.which(name):
            # Show a user-friendly warning
            raise Exception(' '.join([
                f"It appears that you don't have `{name}` installed.",
                "Please execute `rustup component add llvm-tools-preview`,",
                "or install it via your package manager of choice.",
                "LLVM tools should be the same version as LLVM in `rustc --version --verbose`.",
            ]))

        return name

    def profdata(self, input_files_list: Path, output_profdata: Path) -> None:
        """Run `llvm-profdata merge -sparse` over the files listed in `input_files_list`.

        Args:
            input_files_list: file passed via `-input-files=`.
            output_profdata: destination passed via `-output=`.

        Raises:
            subprocess.CalledProcessError: if the tool exits with a non-zero status.
        """
        subprocess.check_call([
            self.resolve_tool('llvm-profdata'),
            'merge',
            '-sparse',
            f'-input-files={input_files_list}',
            f'-output={output_profdata}',
        ])

    def _cov(self,
             *args,
             subcommand: str,
             profdata: Path,
             objects: List[str],
             sources: List[str],
             demangler: Optional[Path] = None,
             output_file: Optional[Path] = None,
             ) -> None:
        """Run `llvm-cov <subcommand>` from the cargo working directory.

        Args:
            *args: extra command-line options placed before the objects.
            subcommand: `llvm-cov` subcommand name.
            profdata: value of `-instr-profile`.
            objects: binaries; all but the first are prefixed with `-object`.
            sources: trailing source arguments; an empty list adds none.
            demangler: if given, passed as `-Xdemangler=`.
            output_file: if given, the tool's stdout is written to this file
                (its parent directory is created when missing).

        Raises:
            subprocess.CalledProcessError: if the tool exits with a non-zero status.
        """

        cwd = self.cargo.cwd
        objects = list(intersperse('-object', objects))
        extras = list(args)

        # For some reason `rustc` produces relative paths to src files,
        # so we force it to cut the $PWD prefix.
        # see: https://github.com/rust-lang/rust/issues/34701#issuecomment-739809584
        if sources:
            extras.append(f'-path-equivalence=.,{cwd.resolve()}')

        if demangler:
            extras.append(f'-Xdemangler={demangler}')

        cmd = [
            self.resolve_tool('llvm-cov'),
            subcommand,  # '-dump-collected-paths',  # classified debug flag
            '-instr-profile',
            str(profdata),
            *extras,
            *objects,
            *sources,
        ]
        if output_file is not None:
            # Nothing else guarantees that the destination directory exists
            # (e.g. a fresh report dir), so create it instead of failing in open()
            output_file.parent.mkdir(parents=True, exist_ok=True)
            with output_file.open('w') as outfile:
                subprocess.check_call(cmd, cwd=cwd, stdout=outfile)
        else:
            subprocess.check_call(cmd, cwd=cwd)

    def cov_report(self, **kwargs) -> None:
        """Run `llvm-cov report`; `kwargs` are forwarded to `_cov`."""
        self._cov(subcommand='report', **kwargs)

    def cov_export(self, *, kind: str, output_file: Optional[Path], **kwargs) -> None:
        """Run `llvm-cov export -format=<kind>`, sending stdout to `output_file` if given."""
        extras = (f'-format={kind}', )
        self._cov(subcommand='export', *extras, output_file=output_file, **kwargs)

    def cov_show(self, *, kind: str, output_dir: Optional[Path] = None, **kwargs) -> None:
        """Run `llvm-cov show -format=<kind>`, adding `-output-dir=` when `output_dir` is set."""
        extras = [f'-format={kind}']
        if output_dir:
            extras.append(f'-output-dir={output_dir}')

        self._cov(subcommand='show', *extras, **kwargs)


@dataclass
class ProfDir:
    """A directory of coverage data files (`*.profraw` / `*.profdata`)."""

    cwd: Path
    llvm: LLVM

    def __post_init__(self) -> None:
        # The directory is created as soon as the object is constructed
        self.cwd.mkdir(parents=True, exist_ok=True)

    @property
    def files(self) -> List[Path]:
        """Entries of the directory having a `.profraw` or `.profdata` suffix."""
        wanted_suffixes = ('.profraw', '.profdata')
        return [entry for entry in self.cwd.iterdir() if entry.suffix in wanted_suffixes]

    @property
    def file_names_hash(self) -> str:
        """Hash of the paths currently returned by `files`."""
        return hash_strings(str(entry) for entry in self.files)

    def merge(self, output_profdata: Path) -> bool:
        """Merge the directory's coverage files into `output_profdata`.

        Returns False (doing nothing) if there are no coverage files, True
        otherwise. The merge tool is skipped when `output_profdata` is newer
        than every input file.
        """
        inputs = self.files
        if not inputs:
            return False

        output_mtime = file_mtime_or_zero(output_profdata)

        # The tool reads its inputs from a list file, one path per line
        files_list = self.cwd / 'files.list'
        with open(files_list, 'w') as stream:
            for entry in inputs:
                print(entry, file=stream)

        newest_input_mtime = max((file_mtime_or_zero(entry) for entry in inputs), default=0)

        # An obvious make-ish optimization
        if newest_input_mtime >= output_mtime:
            self.llvm.profdata(files_list, output_profdata)

        return True

    def clean(self) -> None:
        """Remove every entry of the directory (the directory itself is kept)."""
        for entry in self.cwd.iterdir():
            entry.unlink()

    def __truediv__(self, other):
        # Lets `prof_dir / 'name'` build a path inside the directory
        return self.cwd / other

    def __str__(self):
        return str(self.cwd)


# Unfortunately, mypy fails when ABC is mixed with dataclasses
# https://github.com/python/mypy/issues/5374#issuecomment-568335302
@dataclass
class ReportData:
    """Common properties of a coverage report.

    These fields are what `Report._common_kwargs` forwards to the
    `LLVM.cov_*` methods (all of them except `llvm` itself).
    """

    # Wrapper whose `cov_*` methods render the report
    llvm: LLVM
    # Symbol demangler, forwarded as the `demangler` argument
    demangler: Path
    # Coverage data file, forwarded as the `profdata` argument
    profdata: Path
    # Binaries to take coverage mappings from
    objects: List[str]
    # Source files/directories to restrict the report to; empty means no restriction
    sources: List[str]


class Report(ABC, ReportData):
    """Base class of every coverage report format."""

    def _common_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments shared by all `LLVM.cov_*` calls of this report."""
        return {
            'profdata': self.profdata,
            'objects': self.objects,
            'sources': self.sources,
            'demangler': self.demangler,
        }

    @abstractmethod
    def generate(self) -> None:
        """Render the report."""

    def open(self) -> None:
        """Show the rendered report to the user; does nothing unless overridden."""
        return None


class SummaryReport(Report):
    """Report rendered by `LLVM.cov_report`."""

    def generate(self) -> None:
        options = self._common_kwargs()
        self.llvm.cov_report(**options)


class TextReport(Report):
    """Report rendered by `LLVM.cov_show` in `text` format, without an output directory."""

    def generate(self) -> None:
        options = self._common_kwargs()
        self.llvm.cov_show(kind='text', **options)


@dataclass
class LcovReport(Report):
    """Coverage data exported in `lcov` format into `output_file`."""

    output_file: Path

    def generate(self) -> None:
        options = self._common_kwargs()
        self.llvm.cov_export(kind='lcov', output_file=self.output_file, **options)


@dataclass
class HtmlReport(Report):
    """HTML report rendered into the directory `output_dir`."""

    output_dir: Path

    def generate(self) -> None:
        options = self._common_kwargs()
        self.llvm.cov_show(kind='html', output_dir=self.output_dir, **options)
        print(f'HTML report is located at `{self.output_dir}`')

    def open(self) -> None:
        """Open `index.html` of the report with the platform's opener tool.

        Raises:
            Exception: if no opener is known for `sys.platform`.
        """
        openers = {'linux': 'xdg-open', 'darwin': 'open'}
        tool = openers.get(sys.platform)
        if not tool:
            raise Exception(f'Unknown platform {sys.platform}')

        index_page = self.output_dir / 'index.html'
        # The opener's own output is of no interest
        subprocess.check_call(
            [tool, index_page],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )


@dataclass
class GithubPagesReport(HtmlReport):
    """HTML report consisting of a landing page that links to two renderings.

    `generate` produces, inside `output_dir`:
      * `local.html` - index of the report restricted to the selected sources;
      * `all.html`   - index of a report rendered without any source filter;
      * `index.html` - a small page linking to both and to `commit_url`.
    """

    output_dir: Path
    commit_url: str = 'https://local/deadbeef'

    def generate(self) -> None:
        # Function-scope import: only this method needs it
        from html import escape

        def index_path(path):
            return path / 'index.html'

        common = self._common_kwargs()
        # Provide default sources if there's none.
        # The 'sources' key is always present, so `setdefault` would never
        # apply the default; an empty list has to be checked explicitly.
        if not common['sources']:
            common['sources'] = ['.']

        self.llvm.cov_show(kind='html', output_dir=self.output_dir, **common)
        shutil.copy(index_path(self.output_dir), self.output_dir / 'local.html')

        # The unfiltered report is rendered elsewhere; only its index page is kept
        with TemporaryDirectory() as tmp:
            output_dir = Path(tmp)
            args = dict(common, sources=[])
            self.llvm.cov_show(kind='html', output_dir=output_dir, **args)
            shutil.copy(index_path(output_dir), self.output_dir / 'all.html')

        with open(index_path(self.output_dir), 'w') as index:
            # Last path component of the URL, shortened to 10 characters.
            # Both values end up in markup, hence the escaping.
            commit_sha = escape(self.commit_url.rsplit('/', maxsplit=1)[-1][:10])
            commit_url = escape(self.commit_url)

            page = f"""
                <!DOCTYPE html>
                <html>
                    <head>
                        <title>Coverage ({commit_sha})</title>
                    </head>
                    <body>
                        <h1>
                            Coverage report for commit
                                <a href="{commit_url}">
                                    {commit_sha}
                                </a>
                        </h1>

                        <p>
                            <a href="./local.html">
                                <b>Show only local sources</b>
                            </a>
                        </p>

                        <p>
                            <a href="./all.html">
                                Show all sources (including dependencies)
                            </a>
                        </p>
                    </body>
                </html>
            """
            index.write(dedent(page))

        print(f'HTML report is located at `{self.output_dir}`')


class State:
    """Shared context of all subcommands: tool wrappers, coverage directories, build env.

    Constructing it creates the profraw/profdata directories and sets
    `LLVM_PROFILE_FILE` and `RUSTFLAGS` in `os.environ`.
    """

    def __init__(self, cwd: Path, top_dir: Optional[Path], profraw_prefix: Optional[str]) -> None:
        self.cwd = cwd
        self.cargo = Cargo(self.cwd)
        self.llvm = LLVM(self.cargo)

        # Use hostname by default
        self.profraw_prefix = profraw_prefix or socket.gethostname()

        self.top_dir = top_dir or self.cargo.target_dir / 'coverage'
        self.report_dir = self.top_dir / 'report'

        # Directory for raw coverage data emitted by executables
        self.profraw_dir = ProfDir(llvm=self.llvm, cwd=self.top_dir / 'profraw')

        # Directory for processed coverage data
        self.profdata_dir = ProfDir(llvm=self.llvm, cwd=self.top_dir / 'profdata')

        # Aggregated coverage data
        self.final_profdata = self.top_dir / 'coverage.profdata'

        # Dump all coverage data files into a dedicated directory.
        # Each filename is parameterized by PID & executable's signature.
        profile_file = self.profraw_dir / f'{self.profraw_prefix}-%p-%m.profraw'
        os.environ['LLVM_PROFILE_FILE'] = str(profile_file)

        rustflags = [
            os.environ.get('RUSTFLAGS', ''),
            # Enable LLVM's source-based coverage
            # see: https://clang.llvm.org/docs/SourceBasedCodeCoverage.html
            # see: https://blog.rust-lang.org/inside-rust/2020/11/12/source-based-code-coverage.html
            '-Cinstrument-coverage',
            # Link every bit of code to prevent "holes" in coverage report
            # see: https://doc.rust-lang.org/rustc/codegen-options/index.html#link-dead-code
            '-Clink-dead-code',
            # Some of the paths that `rustc` embeds into binaries are absolute, others are relative.
            # The point is, we can't have both, because depending on `-path-equivalence`, `llvm-cov`
            # either will cripple absolute paths or won't be able to show relative paths at all.
            # There's no way to turn relative paths into absolute, so we strip $PWD prefix.
            # Only source files of deps (e.g. `$HOME/.cargo`) will keep their absolute paths,
            # but we won't include them in report by default (but see `--all`).
            f'--remap-path-prefix {self.cwd}=',
        ]
        os.environ['RUSTFLAGS'] = ' '.join(rustflags)

    def _merge_profraw(self) -> bool:
        """Merge the raw profiles into a new file in `profdata_dir`, then wipe `profraw_dir`.

        Returns True if `profraw_dir` contained any coverage files.
        """
        file_name = f'{self.profraw_prefix}-{self.profdata_dir.file_names_hash}.profdata'
        profdata_path = self.profdata_dir / file_name
        print(f'* Merging profraw files (into {profdata_path.name})')
        merged = self.profraw_dir.merge(profdata_path)

        # We no longer need those profraws
        self.profraw_dir.clean()

        return merged

    def _merge_profdata(self) -> bool:
        """Merge raw profiles first, then everything in `profdata_dir` into `final_profdata`.

        Returns True if `profdata_dir` contained any coverage files.
        """
        self._merge_profraw()
        print(f'* Merging profdata files (into {self.final_profdata.name})')
        return self.profdata_dir.merge(self.final_profdata)

    def do_run(self, args) -> None:
        """Run the user's command (`args.command` followed by `args.args`)."""
        command_line = [*args.command, *args.args]
        subprocess.check_call(command_line)

    def do_merge(self, args) -> None:
        """Merge coverage files of the kind selected by `args.kind`."""
        if args.kind == 'profraw':
            self._merge_profraw()
        elif args.kind == 'profdata':
            self._merge_profdata()
        else:
            raise KeyError(args.kind)

    def do_report(self, args) -> None:
        """Merge all coverage data and render a report in `args.format`.

        Raises:
            Exception: on conflicting/missing options, or if there is no coverage data.
        """
        if args.all and args.sources:
            raise Exception('--all should not be used with sources')

        if args.format == 'github' and not args.commit_url:
            raise Exception('--format=github should be used with --commit-url')

        # see man for `llvm-cov show [sources]`
        if args.all:
            sources = []
        else:
            sources = args.sources or ['.']

        have_data = self._merge_profdata()
        if not have_data:
            raise Exception(f'No coverage data files found at {self.top_dir}')

        objects = []
        if args.input_objects:
            print('* Collecting object files using --input-objects')
            with open(args.input_objects) as listing:
                objects.extend(listing.read().splitlines(keepends=False))

        # 'auto' means: ask cargo only if no explicit list has been provided
        ask_cargo = args.cargo_objects == 'true' or (
            args.cargo_objects == 'auto' and not args.input_objects
        )
        if ask_cargo:
            print('* Collecting object files using cargo')
            objects.extend(self.cargo.binaries(args.profile))

        params: Dict[str, Any] = {
            'llvm': self.llvm,
            'demangler': find_demangler(args.demangler),
            'profdata': self.final_profdata,
            'objects': objects,
            'sources': sources,
        }

        report: Report
        if args.format == 'html':
            report = HtmlReport(**params, output_dir=self.report_dir)
        elif args.format == 'text':
            report = TextReport(**params)
        elif args.format == 'lcov':
            report = LcovReport(**params, output_file=self.report_dir / 'lcov.info')
        elif args.format == 'summary':
            report = SummaryReport(**params)
        elif args.format == 'github':
            report = GithubPagesReport(
                **params, output_dir=self.report_dir, commit_url=args.commit_url)
        else:
            raise KeyError(args.format)

        print(f'* Rendering coverage report ({args.format})')
        report.generate()

        if args.open:
            print('* Opening the report')
            report.open()

    def do_clean(self, args: Any) -> None:
        """Remove coverage artifacts, optionally limited by `args.report` / `args.prof`."""
        # Wipe everything if no filters have been provided
        if not args.report and not args.prof:
            shutil.rmtree(self.top_dir, ignore_errors=True)
            return

        if args.report:
            shutil.rmtree(self.report_dir, ignore_errors=True)
        if args.prof:
            self.profraw_dir.clean()
            self.profdata_dir.clean()
            self.final_profdata.unlink(missing_ok=True)


def _build_parser(app: str) -> argparse.ArgumentParser:
    """Construct the command-line parser; `app` is the program name used in the examples."""
    epilog = f"""
prerequisites:
    # alternatively, install a system package for `llvm-tools`
    rustup component add llvm-tools-preview

self-contained example:
    {app} run make
    {app} run poetry run pytest test_runner
    {app} run cargo test
    {app} report --open
    """

    parser = argparse.ArgumentParser(
        description='Coverage report builder',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=epilog,
    )
    parser.add_argument('--dir', type=Path, help='output directory')
    parser.add_argument('--profraw-prefix', metavar='STRING', type=str)

    commands = parser.add_subparsers(title='commands', dest='subparser_name')

    # run
    run_cmd = commands.add_parser('run', help='run a command with magic env')
    run_cmd.add_argument('command', nargs=1)
    run_cmd.add_argument('args', nargs=argparse.REMAINDER)

    # merge
    merge_cmd = commands.add_parser('merge', help='save disk space by merging cov files')
    merge_cmd.add_argument(
        '--kind', default='profraw', choices=('profraw', 'profdata'), help='which files to merge')

    # report
    report_cmd = commands.add_parser('report', help='generate a coverage report')
    report_cmd.add_argument(
        '--profile', default='debug', choices=('debug', 'release'), help='cargo build profile')
    report_cmd.add_argument(
        '--format',
        default='html',
        choices=('html', 'text', 'summary', 'lcov', 'github'),
        help='report format',
    )
    report_cmd.add_argument(
        '--input-objects', metavar='FILE', type=Path, help='file containing list of binaries')
    report_cmd.add_argument(
        '--cargo-objects',
        default='auto',
        choices=('auto', 'true', 'false'),
        help='use cargo for auto discovery of binaries',
    )
    report_cmd.add_argument(
        '--commit-url', metavar='URL', type=str, help='required for --format=github')
    report_cmd.add_argument('--demangler', metavar='BIN', type=Path, help='symbol name demangler')
    report_cmd.add_argument('--open', action='store_true', help='open report in a default app')
    report_cmd.add_argument('--all', action='store_true', help='show everything, e.g. deps')
    report_cmd.add_argument('sources', nargs='*', type=Path, help='source file or directory')

    # clean
    clean_cmd = commands.add_parser('clean', help='wipe coverage artifacts')
    clean_cmd.add_argument('--report', action='store_true', help='pick generated report')
    clean_cmd.add_argument('--prof', action='store_true', help='pick *.profdata & *.profraw')

    return parser


def main() -> None:
    """Parse the command line and dispatch to the matching `State.do_*` handler.

    Prints the help text when no subcommand has been given.
    """
    parser = _build_parser(sys.argv[0])
    args = parser.parse_args()

    # NB: the state is set up before the subcommand is looked at
    state = State(cwd=Path.cwd(), top_dir=args.dir, profraw_prefix=args.profraw_prefix)

    dispatch = {
        'run': state.do_run,
        'merge': state.do_merge,
        'report': state.do_report,
        'clean': state.do_clean,
    }

    handler = dispatch.get(args.subparser_name)
    if not handler:
        parser.print_help()
        return

    handler(args)

# Run the command-line interface only when this file is executed as a script
if __name__ == '__main__':
    main()
