import logging
from pathlib import Path
from collections import defaultdict
from itertools import cycle
from typing import Dict, List, Tuple, Union
import matplotlib.pyplot as plt
from pycirclize import Circos

# Module-level logger shared by the loading and plotting functions in this module.
log = logging.getLogger(__name__)


def get_tick_interval(name: str, length: int) -> int:
    """
    Choose a round tick spacing (in bp) for a sequence axis.

    A sequence counts as plasmid-like when its name contains "plasmid"
    (case-insensitive) or it is shorter than 300 kb.

    Args:
        name: Sequence name.
        length: Sequence length in bp.

    Returns:
        1 kb for plasmid-like sequences shorter than 20 kb, 10 kb for other
        plasmid-like sequences, 1 Mb for remaining sequences longer than
        5 Mb, and 100 kb otherwise.
    """
    plasmid_like = "plasmid" in name.lower() or length < 300_000
    if plasmid_like:
        return 1_000 if length < 20_000 else 10_000
    return 1_000_000 if length > 5_000_000 else 100_000


def load_groups(
    group_file: Path,
) -> Tuple[Dict[str, List[Tuple[int, int, str]]], Dict[str, str]]:
    """
    Parse a tab-separated group file.

    Each data line is expected to hold at least four columns::

        group_id <TAB> chrom <TAB> start <TAB> end [<TAB> color]

    Lines starting with ``#`` and lines with fewer than four columns are
    skipped. Lines whose start/end are not integers (e.g. a header row) are
    skipped with a warning instead of aborting the whole load.

    Args:
        group_file: Path to the group file.

    Returns:
        1. segments_by_chrom: Dict[Chrom -> List[(start, end, group_id)]],
           in file order.
        2. color_map: Dict[group_id -> color string]. An explicit (non-empty)
           fifth column wins; groups that never supply one get "#cccccc".

    Raises:
        FileNotFoundError: if ``group_file`` does not exist (logged, then re-raised).
    """
    segments_by_chrom = defaultdict(list)
    color_map = {}

    log.info(f"Loading groups and colors from {group_file.name}...")

    try:
        handle = group_file.open("r")
    except FileNotFoundError:
        log.error(f"Could not find group file: {group_file}")
        raise

    with handle:
        for line_no, line in enumerate(handle, start=1):
            if line.startswith("#"):
                continue
            cols = line.strip().split("\t")
            if len(cols) < 4:
                continue

            group_id = cols[0]
            chrom = cols[1]
            try:
                start = int(cols[2])
                end = int(cols[3])
            except ValueError:
                # Typically a header row; skip it rather than failing the whole file.
                log.warning(
                    f"Skipping line {line_no} of {group_file.name}: "
                    f"non-integer coordinates {cols[2]!r}, {cols[3]!r}"
                )
                continue

            segments_by_chrom[chrom].append((start, end, group_id))

            if len(cols) >= 5 and cols[4]:
                # Explicit color column (column 5); the last one seen wins.
                color_map[group_id] = cols[4]
            else:
                # Fallback for files generated without a color column. Never
                # clobber a color already supplied by an earlier line.
                color_map.setdefault(group_id, "#cccccc")

    return segments_by_chrom, color_map


def get_seq_lengths(
    segments_by_chrom: Dict[str, List[Tuple[int, int, str]]],
    padding: float = 1.02,
) -> Dict[str, int]:
    """
    Estimate each sequence's length from the furthest coordinate in its groups.

    Both the start and the end of every segment are considered, so segments
    written with start > end still extend the estimate. The furthest
    coordinate is multiplied by ``padding`` and truncated to an int, leaving
    a little room after the last segment.

    Args:
        segments_by_chrom: Mapping of chrom -> list of (start, end, group_id).
        padding: Multiplicative head-room applied to the furthest coordinate
            (default 1.02, i.e. +2%).

    Returns:
        Mapping of chrom -> estimated length. A chrom with no segments maps to 0.
    """
    seq_lens: Dict[str, int] = {}
    for chrom, segs in segments_by_chrom.items():
        furthest = max((max(start, end) for start, end, _ in segs), default=0)
        # Floor at 0, as the original running maximum did.
        seq_lens[chrom] = int(max(furthest, 0) * padding)
    return seq_lens


def draw_circular(
    chrom: str,
    length: int,
    segments: List[Tuple[int, int, str]],
    color_map: Dict[str, str],
    outdir: Path,
) -> None:
    """
    Generate a circular (Circos-style) plot for a single chromosome.

    Layout, outer to inner: a tick axis (r 95-100), one colored block per
    segment (r 85-95), and staggered group labels with dotted leader lines
    (r 30-85). The chromosome name is written at the bottom of the figure
    and a two-column legend lists every group drawn.

    Args:
        chrom: Sequence name; used as sector name, title and file prefix.
        length: Sector length in bp.
        segments: (start, end, group_id) tuples; not modified.
        color_map: group_id -> color; unknown groups are drawn in "#cccccc".
        outdir: Directory receiving ``<chrom>_circular.png``.
    """
    default_color = "#cccccc"
    label_r_lim = (30, 85)

    # Initialize Circos sector
    circos = Circos({chrom: length}, space=10)
    sector = circos.get_sector(chrom)

    # Track 1: Outer Axis (95-100)
    track_axis = sector.add_track((95, 100))
    track_axis.axis(fc="lightgrey")
    track_axis.xticks_by_interval(
        interval=get_tick_interval(chrom, length), label_orientation="vertical"
    )

    # Track 2: The Groups (Color Blocks) (85-95)
    track_groups = sector.add_track((85, 95))

    # Track 3: Labels (Inward Space) (30-85)
    track_labels = sector.add_track(label_r_lim)
    track_labels.axis(fc="none", ec="none")

    seen_groups = set()
    depth_cycle = cycle([80, 72, 64])

    # Stagger labels in positional order (as the linear map does) without
    # mutating the caller's list.
    for start, end, group_id in sorted(segments, key=lambda seg: seg[0]):
        track_groups.rect(start, end, color=color_map.get(group_id, default_color))

        midpoint = (start + end) / 2
        text_r = next(depth_cycle)

        # Track.line rescales y into the track's radius range using vmin/vmax
        # (vmax defaults to max(y)). Pin them to the track limits so the y
        # values below are literal radii, matching r= in text() below.
        track_labels.line(
            [midpoint, midpoint],
            [text_r + 1, 84],
            vmin=label_r_lim[0],
            vmax=label_r_lim[1],
            color="#666666",
            lw=0.5,
            ls=":",
        )
        track_labels.text(
            group_id,
            midpoint,
            r=text_r,
            orientation="vertical",
            va="top",
            ha="center",
            color="black",
            size=5,
        )

        seen_groups.add(group_id)

    fig = circos.plotfig()
    try:
        # Title at the bottom of the figure.
        fig.text(0.5, 0.02, chrom, ha="center", va="bottom", fontsize=14, color="black")

        legend_groups = sorted(seen_groups)
        if legend_groups:
            handles = [
                plt.Rectangle((0, 0), 1, 1, color=color_map.get(g_id, default_color))
                for g_id in legend_groups
            ]
            # Target this figure's axes explicitly rather than pyplot's global state.
            fig.gca().legend(
                handles,
                legend_groups,
                loc="upper right",
                bbox_to_anchor=(1.45, 1.0),
                title="Groups",
                ncol=2,
                fontsize="small",
            )

        out_name = outdir / f"{chrom}_circular.png"
        fig.savefig(out_name, bbox_inches="tight", dpi=300)
        log.info(f"Saved circular plot: {out_name.name}")
    finally:
        # Always release the figure, even if drawing/saving fails, so failures
        # in a batch run do not accumulate open figures.
        plt.close(fig)


def draw_linear(
    chrom: str,
    length: int,
    segments: List[Tuple[int, int, str]],
    color_map: Dict[str, str],
    outdir: Path,
) -> None:
    """
    Generate a linear map for a single chromosome.

    Draws a black backbone from 0 to ``length``, one colored box per segment,
    and labels staggered over four heights (ordered by start position) with
    dotted leader lines. A two-column legend to the right lists every group.

    Args:
        chrom: Sequence name; used in the title and as file prefix.
        length: Backbone length in bp.
        segments: (start, end, group_id) tuples; not modified.
        color_map: group_id -> color; unknown groups are drawn in "#cccccc".
        outdir: Directory receiving ``<chrom>_linear.png``.
    """
    default_color = "#cccccc"

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Draw backbone line
        ax.plot([0, length], [1, 1], color="black", linewidth=2)

        seen_groups = set()

        # Cycle through 4 height levels for labels
        label_heights = cycle([1.3, 1.5, 1.7, 1.9])

        # Sort by start position for clean staggering. Use sorted() so the
        # caller's list is left untouched.
        for start, end, group_id in sorted(segments, key=lambda seg: seg[0]):
            color = color_map.get(group_id, default_color)
            rect = plt.Rectangle((start, 0.8), end - start, 0.4, color=color, alpha=0.9)
            ax.add_patch(rect)

            mid = (start + end) / 2
            txt_y = next(label_heights)

            ax.text(
                mid,
                txt_y,
                group_id,
                ha="center",
                va="bottom",
                fontsize=8,
                rotation=45,
                color="black",
            )
            ax.plot(
                [mid, mid],
                [1.2, txt_y],
                color="grey",
                linestyle=":",
                linewidth=0.5,
                alpha=0.5,
            )

            seen_groups.add(group_id)

        ax.set_xlim(-100, length + 100)
        ax.set_ylim(0, 2.5)
        ax.set_yticks([])
        ax.set_xlabel("Base Pairs")
        ax.set_title(f"Linear Map: {chrom}")

        legend_groups = sorted(seen_groups)
        if legend_groups:
            handles = [
                plt.Rectangle((0, 0), 1, 1, color=color_map.get(g_id, default_color))
                for g_id in legend_groups
            ]
            ax.legend(
                handles,
                legend_groups,
                loc="center left",
                bbox_to_anchor=(1, 0.5),
                title="Groups",
                ncol=2,
            )

        out_name = outdir / f"{chrom}_linear.png"
        # Lay out this figure explicitly rather than whatever pyplot considers current.
        fig.tight_layout()
        fig.savefig(out_name, dpi=300)
        log.info(f"Saved linear plot: {out_name.name}")
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)


def visualize_groups(group_file: Union[str, Path], outdir: Union[str, Path]) -> None:
    """
    Main entry point to visualization.

    Loads the group file and writes ``<chrom>_circular.png`` and
    ``<chrom>_linear.png`` for every sequence into ``outdir``.

    Args:
        group_file: Path to the tab-separated group file.
        outdir: Output directory; created (with parents) if it does not exist.

    Raises:
        FileNotFoundError: propagated from ``load_groups`` when the group file
            is missing.

    A failure while drawing one plot is logged with its traceback and does
    not stop the remaining plots.
    """
    group_file = Path(group_file)
    outdir = Path(outdir)

    # Load segments AND the color map
    segments_by_chrom, color_map = load_groups(group_file)

    if not segments_by_chrom:
        log.warning("No groups found to visualize.")
        return

    # Without this, every savefig() fails and is merely logged per plot.
    outdir.mkdir(parents=True, exist_ok=True)

    seq_lens = get_seq_lengths(segments_by_chrom)

    log.info(f"Generating plots for {len(segments_by_chrom)} sequences...")

    for chrom, segments in segments_by_chrom.items():
        length = seq_lens[chrom]

        for kind, draw in (("circular", draw_circular), ("linear", draw_linear)):
            try:
                draw(chrom, length, segments, color_map, outdir)
            except Exception as e:
                # Best-effort: keep going with other plots, but keep the traceback.
                log.exception(f"Error drawing {kind} plot for {chrom}: {e}")
