import logging
import re
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Union

# Module-level logger shared by every function in this file.
log = logging.getLogger(__name__)

# Okabe-Ito palette (colour-blind friendly), kept as (name, hex) pairs so the
# human-readable colour name travels with its code. Order is significant:
# consumers index into CB_PALETTE and wrap around with a modulo.
_OKABE_ITO = (
    ("orange", "#E69F00"),
    ("sky blue", "#56B4E9"),
    ("bluish green", "#009E73"),
    ("yellow", "#F0E442"),
    ("blue", "#0072B2"),
    ("vermilion", "#D55E00"),
    ("reddish purple", "#CC79A7"),
    ("black", "#000000"),
)
CB_PALETTE = [hex_code for _colour_name, hex_code in _OKABE_ITO]


def clean_filename(s: str) -> str:
    """Return *s* converted into a string that is safe to use as a filename.

    The value is first passed through ``str()``. Every character other than
    ASCII letters, digits, ``_``, ``.`` and ``-`` is replaced by ``_``. Runs
    of underscores are then squeezed to a single one, and the result is
    truncated to at most 100 characters.
    """
    replaced = re.sub(r"[^a-zA-Z0-9_.-]", "_", str(s))
    squeezed = re.sub(r"_+", "_", replaced)
    max_len = 100
    return squeezed[:max_len]


def _sequence_extents(df: pd.DataFrame) -> dict:
    """Estimate each sequence's length as the furthest coordinate seen for it.

    Both roles (query and subject) and both ends of every alignment are
    considered. Reverse-strand hits report ``sstart > send``, and some
    searches can report ``qstart > qend``. Looking only at the ``*end``
    column would therefore cut such alignments off the plot. If the input
    includes a full-length self-hit for a sequence, this estimate equals
    that sequence's length.

    Returns a dict mapping sequence ID -> maximum coordinate.
    """
    as_query = pd.DataFrame(
        {"seqid": df["qseqid"], "coord": df[["qstart", "qend"]].max(axis=1)}
    )
    as_subject = pd.DataFrame(
        {"seqid": df["sseqid"], "coord": df[["sstart", "send"]].max(axis=1)}
    )
    both = pd.concat([as_query, as_subject], ignore_index=True)
    return both.groupby("seqid")["coord"].max().to_dict()


def _segments_xy(hits: pd.DataFrame) -> tuple:
    """Flatten alignments into NaN-separated x/y lists for one ``ax.plot`` call.

    Each hit contributes ``(start, end, NaN)`` on both axes. Matplotlib breaks
    a line at NaN, so every alignment is still drawn as its own segment, but
    the whole set becomes a single artist instead of one artist per row.
    """
    nan = float("nan")
    xs, ys = [], []
    for q_start, q_end, s_start, s_end in zip(
        hits["qstart"], hits["qend"], hits["sstart"], hits["send"]
    ):
        xs.extend((q_start, q_end, nan))
        ys.extend((s_start, s_end, nan))
    return xs, ys


def _claim_path(directory: Path, stem: str, used: set) -> Path:
    """Return ``directory/<stem>.png``, suffixing ``_2``, ``_3``, ... on clashes.

    Distinct sequence IDs can sanitize to the same stem (e.g. ``a|b`` and
    ``a/b`` both become ``a_b``). Without this check, later plots would
    silently overwrite earlier ones. The comparison is case-insensitive so
    names also stay distinct on case-insensitive filesystems. *used* is
    updated in place.
    """
    candidate, n = stem, 2
    while candidate.lower() in used:
        candidate = f"{stem}_{n}"
        n += 1
    if candidate != stem:
        log.warning(f"Filename '{stem}.png' already used; saving as '{candidate}.png'.")
    used.add(candidate.lower())
    return directory / f"{candidate}.png"


def generate_dotplots(blast_file: Union[str, Path], outdir: Union[str, Path]) -> None:
    """
    Generate pairwise and combined dotplots from a BLAST tabular file.

    For every (query, subject) pair with ``qseqid != sseqid``, a
    ``Pair_<query>_vs_<subject>.png`` is written. For every query, a
    ``Combined_<query>.png`` overlays all of its subjects, each coloured
    from ``CB_PALETTE``. All files go to ``<outdir>/dotplots``.

    Axis limits come from the furthest coordinate observed for each sequence
    (including self-hits), which approximates the sequence length.

    Args:
        blast_file: Headerless, tab-separated file with the 12 standard
            BLAST outfmt-6 columns.
        outdir: Directory under which the ``dotplots`` folder is created.

    Returns:
        None. A missing or unreadable file, no valid rows, or no non-self
        hits are logged and cause an early return rather than an exception.
    """
    blast_file = Path(blast_file)
    outdir = Path(outdir)
    plot_dir = outdir / "dotplots"

    # Validate the input before touching the filesystem so a bad path does
    # not leave an empty output directory behind.
    if not blast_file.exists():
        log.error(f"File '{blast_file}' not found.")
        return

    plot_dir.mkdir(parents=True, exist_ok=True)

    # BLAST format 6 columns
    cols = [
        "qseqid",
        "sseqid",
        "pident",
        "length",
        "mismatch",
        "gapopen",
        "qstart",
        "qend",
        "sstart",
        "send",
        "evalue",
        "bitscore",
    ]

    log.info(f"Generating dotplots from {blast_file.name}...")

    try:
        # Keep IDs as strings: numeric-looking IDs such as "007" would
        # otherwise be parsed as integers and lose their leading zeros.
        df = pd.read_csv(
            blast_file,
            sep="\t",
            names=cols,
            dtype={"qseqid": str, "sseqid": str},
            on_bad_lines="warn",
        )
    except Exception as e:
        # Best-effort boundary: report any read/parse failure and stop.
        log.error(f"Error reading BLAST file: {e}")
        return

    # Validate and clean data: unparsable coordinates become NaN and are dropped.
    numeric_cols = ["pident", "qstart", "qend", "sstart", "send"]
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")

    df = df.dropna(subset=numeric_cols)
    if df.empty:
        log.warning("No valid BLAST data found for dotplots.")
        return

    # --- STEP 1: Estimate sequence lengths ---
    # Done before self-hits are removed so full-length self alignments count.
    extents = _sequence_extents(df)

    # --- STEP 2: Filter data ---
    # Only comparisons between DIFFERENT sequences are plotted.
    df_filtered = df[df["qseqid"] != df["sseqid"]].copy()

    if df_filtered.empty:
        log.info("No non-self hits found. Skipping dotplots.")
        return

    log.info(f"Saving dotplots to: {plot_dir}")

    # Shared across both plot kinds; the Pair_/Combined_ prefixes keep them apart.
    used_names: set = set()

    # --- PART 1: Individual dot plots ---
    # Every ID in df_filtered also appears in df, so extents lookups cannot miss.
    for (q_id, s_id), group in df_filtered.groupby(["qseqid", "sseqid"]):
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            xs, ys = _segments_xy(group)
            # marker='.' ensures even single-point/short matches are visible
            ax.plot(
                xs,
                ys,
                color="black",
                linewidth=2.0,
                alpha=1.0,
                linestyle="-",
                marker=".",
                markersize=4,
            )

            ax.set_title(f"Query: {q_id} vs Subject: {s_id}")
            ax.set_xlabel(f"Position on {q_id} (bp)")
            ax.set_ylabel(f"Position on {s_id} (bp)")

            ax.set_xlim(0, extents[q_id])
            ax.set_ylim(0, extents[s_id])

            ax.grid(True, which="both", linestyle="--", alpha=0.4)

            stem = f"Pair_{clean_filename(q_id)}_vs_{clean_filename(s_id)}"
            fig.savefig(_claim_path(plot_dir, stem, used_names), dpi=150)
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    # --- PART 2: Combined plots ---
    # One query vs ALL of its subjects on one graph. Colours are assigned per
    # subject over the whole file, so a subject keeps its colour across plots.
    unique_subjects = sorted(df_filtered["sseqid"].unique())
    color_dict = {
        subj: CB_PALETTE[i % len(CB_PALETTE)] for i, subj in enumerate(unique_subjects)
    }

    for q_id, group in df_filtered.groupby("qseqid"):
        fig, ax = plt.subplots(figsize=(12, 10))
        try:
            subjects_in_query = group.groupby("sseqid")
            num_subjects = len(subjects_in_query)

            for s_id, sub_group in subjects_in_query:
                xs, ys = _segments_xy(sub_group)
                # One artist per subject -> exactly one legend entry per subject.
                ax.plot(
                    xs,
                    ys,
                    color=color_dict[s_id],
                    linewidth=2.0,
                    linestyle="-",
                    label=s_id,
                    alpha=1.0,
                    marker=".",
                    markersize=4,
                )

            ax.set_title(f"Combined Alignments for Query: {q_id}")
            ax.set_xlabel(f"Position on {q_id} (bp)")
            ax.set_ylabel("Subject Position (bp)")

            ax.set_xlim(0, extents[q_id])

            # Y-limit: the longest subject present in THIS plot.
            current_max_y = max(extents[s] for s in group["sseqid"].unique())
            if current_max_y > 0:
                ax.set_ylim(0, current_max_y)

            ax.grid(True, which="both", linestyle="--", alpha=0.4)

            # Matplotlib omits labels starting with "_" from legends, so the
            # handle list can legitimately be empty.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                if num_subjects <= 20:
                    ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
                else:
                    ax.text(
                        1.02,
                        0.5,
                        f"Too many subjects\n({num_subjects}) to legend.",
                        transform=ax.transAxes,
                        fontsize=10,
                        verticalalignment="center",
                    )

            fig.tight_layout()

            stem = f"Combined_{clean_filename(q_id)}"
            fig.savefig(_claim_path(plot_dir, stem, used_names), dpi=150)
        finally:
            plt.close(fig)
