import logging
import re
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path
from typing import Union

# Module-level logger named after the module, so the importing application
# controls its level and handlers through the standard `logging` hierarchy.
log = logging.getLogger(__name__)


# --- Helpers: NCBI bit-score colours and filename sanitisation ---
def get_ncbi_color(score: float) -> str:
    """Map a BLAST bit score to the colour used in NCBI's hit-distribution graphic.

    Bands are checked from the highest threshold down; each lower bound is
    inclusive: >= 200 red, >= 80 magenta, >= 50 green, >= 40 blue. Anything
    below 40 (and NaN, which fails every comparison) is black.
    """
    bands = (
        (200, "#CC0000"),  # Red
        (80, "#FF00CC"),  # Pink/Magenta
        (50, "#00CC00"),  # Green
        (40, "#0000FF"),  # Blue
    )
    for threshold, colour in bands:
        if score >= threshold:
            return colour
    return "#000000"  # Black


def clean_filename(s: str) -> str:
    """Return *s* converted into a filesystem-safe name.

    The value is passed through ``str()`` first, so non-string IDs are accepted.
    Every character other than ASCII letters, digits, ``_``, ``.`` and ``-``
    becomes ``_``, and the result is cut to at most 100 characters.
    """
    safe = re.sub(r"[^a-zA-Z0-9_.-]", "_", str(s))
    return safe[:100]


def _plot_query_summary(
    q_id: object, group: pd.DataFrame, query_len: float, filename: Path
) -> None:
    """Render and save one NCBI-style hit-distribution figure for a single query.

    Args:
        q_id: Query identifier, shown in the title.
        group: BLAST rows for this query (self-hits already removed) with
            numeric ``qstart``, ``qend`` and ``bitscore`` columns.
        query_len: Estimated query length; sets the ruler width and x-limit.
        filename: Destination PNG path.
    """
    # One row per subject. The best-scoring subject comes first in the list and
    # is drawn on the highest row, directly below the query ruler.
    subj_scores = group.groupby("sseqid")["bitscore"].max()
    sorted_subjects = subj_scores.sort_values(ascending=False).index.tolist()
    total_rows = len(sorted_subjects)
    row_of = {subj: total_rows - 1 - i for i, subj in enumerate(sorted_subjects)}

    # Dynamic height: at least 5 inches, otherwise 0.4 inch per subject row
    # plus 2.5 inches for title, ruler and legend.
    fig_height = max(5, total_rows * 0.4 + 2.5)
    fig, ax = plt.subplots(figsize=(12, fig_height))
    try:
        # 1. Query "ruler", one row above the first subject.
        ruler_y = total_rows
        ax.barh(
            y=ruler_y,
            width=query_len,
            left=0,
            height=0.6,
            color="#40E0D0",
            edgecolor="none",
            label="Query Sequence",
        )
        ax.text(
            query_len / 2,
            ruler_y,
            "Query Sequence",
            ha="center",
            va="center",
            color="black",
            fontweight="bold",
            fontsize=10,
        )

        # 2. All HSPs in a single call (bars keep the row order of `group`).
        # Query coordinates may come in either order, so normalise to start/end.
        coords = group[["qstart", "qend"]]
        starts = coords.min(axis=1)
        ends = coords.max(axis=1)
        ax.barh(
            y=group["sseqid"].map(row_of).to_numpy(),
            width=(ends - starts).to_numpy(),
            left=starts.to_numpy(),
            height=0.4,
            color=[get_ncbi_color(score) for score in group["bitscore"]],
            edgecolor="none",
        )

        # 3. Styling
        ax.set_title(
            f"Distribution of Blast Hits on Query: {q_id}", fontweight="bold", pad=20
        )
        ax.set_xlabel("Query Position (bp)")
        ax.set_xlim(0, query_len)
        ax.set_ylim(-1, ruler_y + 1)

        # Y-axis labels are the subject names, one per row.
        ax.set_yticks(list(row_of.values()))
        ax.set_yticklabels(list(row_of), fontsize=9)

        # NCBI colour legend, placed below the axes.
        legend_patches = [
            mpatches.Patch(color="#CC0000", label=">= 200"),
            mpatches.Patch(color="#FF00CC", label="80 - 200"),
            mpatches.Patch(color="#00CC00", label="50 - 80"),
            mpatches.Patch(color="#0000FF", label="40 - 50"),
            mpatches.Patch(color="#000000", label="< 40"),
        ]
        legend = ax.legend(
            handles=legend_patches,
            title="Alignment Scores (Bit Score)",
            loc="upper center",
            bbox_to_anchor=(0.5, -0.1),
            ncol=5,
            frameon=False,
        )

        fig.tight_layout()
        # `bbox_extra_artists` is only honoured together with bbox_inches="tight".
        # Without it the legend anchored below the axes can be clipped off.
        fig.savefig(
            filename, dpi=150, bbox_inches="tight", bbox_extra_artists=(legend,)
        )
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def generate_summary_plots(
    blast_file: Union[str, Path], outdir: Union[str, Path]
) -> None:
    """
    Generate NCBI-style hit-distribution plots, one PNG per query sequence.

    The input is tab-separated BLAST output with the 12 standard ``-outfmt 6``
    columns. Rows whose ``qstart``, ``qend`` or ``bitscore`` are not numeric are
    dropped, and self-hits (``qseqid == sseqid``) are not plotted. Each query
    is written to ``<outdir>/summary_plots/Summary_<sanitised query id>.png``.
    If two query IDs sanitise to the same name, a numeric suffix is added so
    that no plot overwrites another.

    Args:
        blast_file: Path to the BLAST tabular result file.
        outdir: Output directory; a ``summary_plots`` sub-directory is created.

    Returns:
        None. Problems are logged rather than raised: a missing or unreadable
        file, no usable hits, or a failure while plotting an individual query.
    """
    blast_file = Path(blast_file)
    outdir = Path(outdir)
    plot_dir = outdir / "summary_plots"
    plot_dir.mkdir(parents=True, exist_ok=True)

    if not blast_file.exists():
        log.error("File '%s' not found.", blast_file)
        return

    # BLAST format 6 columns
    cols = [
        "qseqid",
        "sseqid",
        "pident",
        "length",
        "mismatch",
        "gapopen",
        "qstart",
        "qend",
        "sstart",
        "send",
        "evalue",
        "bitscore",
    ]

    log.info("Generating summary plots from %s...", blast_file.name)

    try:
        df = pd.read_csv(blast_file, sep="\t", names=cols, on_bad_lines="warn")
    except Exception as e:  # best-effort reporting: log and bail out
        log.error("Error reading BLAST file: %s", e)
        return

    # Force numeric columns; unparsable values become NaN and the row is dropped.
    numeric_cols = ["qstart", "qend", "bitscore"]
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.dropna(subset=numeric_cols)

    if df.empty:
        log.warning("No valid BLAST data found for summary plots.")
        return

    # --- 1. Estimate query lengths ---
    # Use the furthest query coordinate reached by any HSP. Self-hits are still
    # included here because they usually span the whole query. Both ends are
    # considered because qstart > qend can occur (e.g. minus-frame hits).
    query_lengths = (
        df[["qstart", "qend"]].max(axis=1).groupby(df["qseqid"]).max().to_dict()
    )

    # --- 2. Filter self-hits (and rows without a subject ID) ---
    df = df[df["sseqid"].notna() & (df["qseqid"] != df["sseqid"])]

    if df.empty:
        log.info("No non-self hits found. Skipping summary plots.")
        return

    log.info("Saving summary plots to: %s", plot_dir)

    used_stems = set()
    for q_id, group in df.groupby("qseqid"):
        # Sanitising can map distinct IDs to the same name; disambiguate so no
        # plot silently overwrites another.
        stem = f"Summary_{clean_filename(q_id)}"
        unique_stem, suffix = stem, 2
        while unique_stem in used_stems:
            unique_stem = f"{stem}_{suffix}"
            suffix += 1
        if unique_stem != stem:
            log.warning(
                "Query '%s' maps to an existing file name; saving as %s.png",
                q_id,
                unique_stem,
            )
        used_stems.add(unique_stem)

        try:
            _plot_query_summary(
                q_id,
                group,
                query_lengths.get(q_id, 100),
                plot_dir / f"{unique_stem}.png",
            )
        except Exception:
            # One bad query (e.g. a figure too large to render) must not abort
            # the remaining plots.
            log.exception("Failed to create summary plot for query '%s'.", q_id)
