#!/usr/bin/env python3
"""
This script parses benchmark log files generated by the lowlevel-blt-bench benchmark and
compares the results of two runs as per-benchmark speedups (second / first). It can also
export the parsed results (or, when two logs are given, the speedups) to a CSV file.

Usage:
    python3 lowlevel_blt_bench_compare <log_file> [<log_file2>]
        [--export-csv <csv_file>] [--quiet] [--color] [--color-breaks <breaks>]
        [--disable-columns <columns>]

Author: Marek Pikuła <m.pikula@partner.samsung.com>
"""

import argparse
import re

import pandas as pd
from colorama import Fore, Style
from tabulate import tabulate

# Terminal colors used by colorize_speedup: entry i colors values up to the
# i-th color break, and the last entry colors everything above the breaks.
SPEEDUP_COLORS = [
    getattr(Fore, _color_name)
    for _color_name in (
        "RED",
        "LIGHTRED_EX",
        "YELLOW",
        "LIGHTYELLOW_EX",
        "LIGHTGREEN_EX",
        "GREEN",
        "BLUE",
    )
]


def colorize_speedup(value: float, breaks: list[float]) -> str:
    """Return *value* wrapped in the ANSI color of the range it falls into.

    ``SPEEDUP_COLORS[i]`` is paired with ``breaks[i]``. The first pair whose
    break is greater than or equal to *value* supplies the color. Values above
    every paired break, or an empty *breaks*, fall back to the last color.
    """
    chosen = next(
        (shade for shade, limit in zip(SPEEDUP_COLORS, breaks) if value <= limit),
        SPEEDUP_COLORS[-1],
    )
    return f"{chosen}{value}{Style.RESET_ALL}"


def print_df(df: pd.DataFrame, color: bool, breaks: list[float], drop_cols: list[str]):
    """Print *df* as a table with an extra "Average" row.

    The caller's DataFrame is not modified. The "Average" row (column-wise
    mean of all rows) and the column removal apply only to a private copy,
    so the same frame can still be exported unchanged afterwards.

    Args:
        df: Table to print, one row per benchmark.
        color: If true, every cell is colored with ``colorize_speedup``.
        breaks: Color break points forwarded to ``colorize_speedup``.
        drop_cols: Columns to hide; names that are not present are ignored.
    """
    # Work on a copy: modifying the caller's frame would leak the "Average"
    # row and the dropped columns into whatever the caller does next (e.g. CSV
    # export).
    table_df = df.copy()
    table_df.loc["Average"] = table_df.mean(axis=0)
    table_df.drop(columns=drop_cols, errors="ignore", inplace=True)

    cells = table_df.map(colorize_speedup, breaks=breaks) if color else table_df
    table = tabulate(
        cells,
        floatfmt="2.3f",
        headers="keys",
        showindex=True,
    )

    # Print the table
    print(table)


def parse_benchmark_log(log_file: str) -> pd.DataFrame:
    """Parse a lowlevel-blt-bench log file and return a DataFrame with the results.

    Every line matching the benchmark pattern becomes one row. The row is
    indexed by the benchmark (function) name that precedes ``=``. All other
    lines are skipped.

    Args:
        log_file: Path of the log file, read as UTF-8 text.

    Returns:
        A DataFrame with float columns ``L1``, ``L2``, ``M``, ``HT``, ``VT``,
        ``R``, ``RT`` and ``Kops/s`` (in that order). It also has an ``Avg``
        column holding the row-wise mean of those eight values. If no line
        matches, the result is empty but still carries all of these columns.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a captured number is not a valid float. The number
            pattern also accepts strings such as a lone ``.``.
    """
    # Regular expression to match benchmark lines.
    benchmark_regex = re.compile(
        r"^\s*(\S+)\s*=\s*L1:\s*([\d.]+)\s*L2:\s*([\d.]+)\s*M:\s*([\d.]+).*"
        r"HT:\s*([\d.]+)\s*VT:\s*([\d.]+)\s*R:\s*([\d.]+)\s*RT:\s*([\d.]+)\s*"
        r"\(\s*([\d.]+).*\)"
    )
    columns = ("L1", "L2", "M", "HT", "VT", "R", "RT", "Kops/s")

    functions: list[str] = []
    metrics: list[tuple[float, ...]] = []
    with open(log_file, "r", encoding="utf-8") as file:
        for line in file:
            match = benchmark_regex.match(line)
            if match is None:
                continue
            functions.append(match.group(1))
            # Convert eagerly, so each row is a concrete tuple of floats
            # rather than a lazy ``map`` object left for pandas to consume.
            metrics.append(tuple(float(v) for v in match.groups()[1:]))

    # dtype=float keeps the columns numeric even when nothing matched.
    out = pd.DataFrame(metrics, index=functions, columns=columns, dtype=float)
    out["Avg"] = out.mean(axis=1)
    return out


if __name__ == "__main__":
    # Set up argument parser.
    parser = argparse.ArgumentParser(
        description="Parse and compare lowlevel-blt-bench benchmark results.",
    )
    parser.add_argument(
        "log_file",
        help="Path to the first benchmark log file.",
    )
    parser.add_argument(
        "log_file2",
        nargs="?",
        help="Path to the second benchmark log file (optional).",
    )
    parser.add_argument(
        "--export-csv",
        "-e",
        metavar="CSV_FILE",
        help="Export the parsed results to a CSV file.",
    )
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="Don't print results (useful with --export-csv).",
    )
    parser.add_argument(
        "--color",
        "-c",
        action="store_true",
        help="Print table in color.",
    )
    parser.add_argument(
        "--disable-columns",
        "-d",
        metavar="COLUMNS",
        help="Comma-separated list of columns to disable (e.g., 'L1,L2,M').",
    )
    parser.add_argument(
        "--color-breaks",
        "-b",
        metavar="BREAKS",
        default="0.8,0.9,1.1,1.5,3.0,5.0",
        help="Comma-separated, ascending speedup values for color breaks (up to 6).",
    )
    args = parser.parse_args()

    # Parse list arguments. Surrounding whitespace is ignored so that
    # "L1, L2" behaves like "L1,L2".
    disabled_columns: list[str] = [
        col.strip() for col in (args.disable_columns or "").split(",") if col.strip()
    ]

    # Validate the color breaks up front instead of failing later with a
    # traceback or silently producing wrong colors.
    max_breaks = len(SPEEDUP_COLORS) - 1
    try:
        color_breaks: list[float] = [float(b) for b in args.color_breaks.split(",")]
    except ValueError:
        parser.error(
            f"--color-breaks: expected comma-separated numbers, got {args.color_breaks!r}"
        )
    if len(color_breaks) > max_breaks:
        parser.error(f"--color-breaks: at most {max_breaks} values are supported")
    if color_breaks != sorted(color_breaks):
        # colorize_speedup picks the first break >= value, so unsorted breaks
        # would assign the wrong colors.
        parser.error("--color-breaks: values must be in ascending order")

    # Parse the first log file.
    df1 = parse_benchmark_log(args.log_file)
    to_export = df1

    if args.log_file2:
        # Parse the second log file and calculate speedup
        df2 = parse_benchmark_log(args.log_file2)

        # Align the two DataFrames on their indices, keeping only the
        # benchmarks present in both logs.
        df1, df2 = df1.align(df2, join="inner")

        # Ratio second/first. Round (rather than truncate) to 3 decimals, so
        # that e.g. 0.9999 is reported as 1.000 and not 0.999.
        speedup = (df2 / df1).round(3)

        if not args.quiet:
            print(f'Speedup between "{args.log_file}" and "{args.log_file2}":\n')
            # Pass a copy: print_df may add rows / drop columns in place, and
            # the unmodified frame is what gets exported below.
            print_df(speedup.copy(), args.color, color_breaks, disabled_columns)

        to_export = speedup
    elif not args.quiet:
        # Print the parsed DataFrame (a copy, for the same reason as above).
        print_df(df1.copy(), args.color, color_breaks, disabled_columns)

    # Export parsed results to CSV if requested
    if args.export_csv:
        to_export.to_csv(args.export_csv)
        if not args.quiet:
            print(f"Parsed results exported to {args.export_csv}")
