import logging
from pathlib import Path
from collections import defaultdict
from itertools import cycle  # <--- Added to cycle through colors
from typing import List, Tuple, Dict, Set, Optional, Union

# Module-level logger used for this module's progress and error messages.
log = logging.getLogger(__name__)

# The "9th & 9th Whale" Palette.
# hit_grouping hands these hex colors out to output groups in list order
# (through itertools.cycle), so the palette repeats once there are more
# groups than colors. Reordering entries changes which color each Group_N gets.
WHALE_COLORS = [
    "#acdbef",  # Sky Blue
    "#f6ac49",  # Orange
    "#9e6586",  # Mauve
    "#9285b3",  # Purple
    "#265984",  # Dark Blue
    "#e87451",  # Salmon
    "#357cb0",  # Medium Blue
    "#e3d371",  # Yellow
    "#eb5c4b",  # Red-Orange
    "#eba9c3",  # Pink
    "#5b7c46",  # Green
    "#d6cb4d",  # Mustard
    "#b68386",  # Dusty Rose
    "#4fa4a5",  # Teal
    "#dcc5a5",  # Beige
]


# Union-Find Data Structure for efficient grouping
class UnionFind:
    """Disjoint-set forest over arbitrary hashable items.

    Items are registered lazily: the first ``find`` (or ``union``) that sees
    an item makes it a singleton set. ``find`` is iterative with full path
    compression, and ``union`` hangs the smaller tree under the larger one
    (union by size). Lookups therefore never recurse, so long chains of
    unions cannot hit Python's recursion limit.
    """

    def __init__(self):
        # item -> parent item; an item that maps to itself is a set root.
        self.parent = {}
        # root -> number of items in that root's set (meaningful for roots only).
        self._size = {}

    def find(self, i):
        """Return the representative (root) of the set containing ``i``.

        An item never seen before is added as a new singleton set and
        returned as its own root.
        """
        parent = self.parent
        if i not in parent:
            parent[i] = i
            self._size[i] = 1
            return i

        # First pass: walk up to the root without recursion.
        root = i
        while parent[root] != root:
            root = parent[root]

        # Second pass: path compression, so every node on the walked path
        # points directly at the root.
        while i != root:
            nxt = parent[i]
            parent[i] = root
            i = nxt
        return root

    def union(self, i, j):
        """Merge the sets containing ``i`` and ``j`` (no-op if already joined)."""
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i == root_j:
            return

        size = self._size
        # Union by size keeps trees shallow. On ties, i's tree goes under
        # j's tree, as in the original implementation.
        if size.get(root_i, 1) > size.get(root_j, 1):
            root_i, root_j = root_j, root_i
        self.parent[root_i] = root_j
        size[root_j] = size.get(root_j, 1) + size.pop(root_i, 1)


def load_bed_regions(
    bed_file: Union[str, Path],
) -> Tuple[Dict[str, List[Tuple[int, int, str]]], Set[str]]:
    """
    Loads the shattered BED file to define the 'nodes' of our graph.

    Blank lines, lines with fewer than three columns, and UCSC-style header
    lines (``#`` comments, ``track`` and ``browser`` lines) are skipped.
    Columns are whitespace-separated; only chrom, start and end are read.

    Args:
        bed_file: Path to the BED file (a ``str`` path is accepted too).

    Returns:
        1. Dictionary mapping Chromosome -> List of (start, end, region_id),
           each list sorted by (start, end) so find_subject_region can stop
           scanning early.
        2. Set of all valid region_ids, formatted as ``"chrom:start-end"``.

    Raises:
        ValueError: if a data line has a non-integer start or end column.
            The message names the file and the 1-based line number.
    """
    bed_file = Path(bed_file)
    regions_by_chrom = defaultdict(list)
    all_regions = set()

    log.info("Loading regions from %s...", bed_file.name)

    with bed_file.open("r") as f:
        for line_no, line in enumerate(f, start=1):
            cols = line.split()
            if len(cols) < 3:
                continue
            chrom = cols[0]
            # Header/comment lines would otherwise crash the int() parse below.
            if chrom.startswith("#") or chrom in ("track", "browser"):
                continue

            try:
                start = int(cols[1])
                end = int(cols[2])
            except ValueError as err:
                raise ValueError(
                    f"{bed_file}:{line_no}: non-integer BED coordinates "
                    f"{cols[1]!r}, {cols[2]!r}"
                ) from err

            # Unique ID for this segment: "seqID:start-end"
            region_id = f"{chrom}:{start}-{end}"

            # Store tuple for overlap searching
            regions_by_chrom[chrom].append((start, end, region_id))
            all_regions.add(region_id)

    # Sort by start position for the early-exit scan in find_subject_region
    for chrom_regions in regions_by_chrom.values():
        chrom_regions.sort()

    return regions_by_chrom, all_regions


def find_subject_region(
    chrom: str,
    s_start: int,
    s_end: int,
    regions_by_chrom: Dict[str, List[Tuple[int, int, str]]],
) -> Optional[str]:
    """
    Finds which defined BED region overlaps the BLAST subject hit the most.

    Args:
        chrom: Subject sequence id to look up in ``regions_by_chrom``.
        s_start: 1-based, inclusive subject start, as reported by BLAST.
        s_end: 1-based, inclusive subject end. Either order is accepted:
            minus-strand hits report ``s_start > s_end`` and are normalized
            here.
        regions_by_chrom: Sequence id -> list of 0-based, half-open
            ``(start, end, region_id)`` tuples, sorted by start (as returned
            by load_bed_regions). The early exit below relies on that order.

    Returns:
        The region_id with the largest overlap (the first one on ties), or
        ``None`` if the sequence is unknown or no region overlaps the hit.
    """
    if chrom not in regions_by_chrom:
        return None

    # Minus-strand hits come back reversed; without normalizing, the interval
    # is inverted, every overlap is <= 0, and the hit is silently dropped.
    low, high = min(s_start, s_end), max(s_start, s_end)

    # Convert BLAST 1-based inclusive to BED 0-based half-open.
    hit_start = low - 1
    hit_end = high

    best_overlap = 0
    best_region = None

    for r_start, r_end, r_id in regions_by_chrom[chrom]:
        # Regions are sorted by start: once one starts at/after the hit's end,
        # no later region can overlap it.
        if r_start >= hit_end:
            break

        # Non-overlapping regions (r_end <= hit_start) yield <= 0 here and
        # can never beat best_overlap, so no separate skip is needed.
        overlap_len = min(hit_end, r_end) - max(hit_start, r_start)
        if overlap_len > best_overlap:
            best_overlap = overlap_len
            best_region = r_id

    return best_region


def _parse_region_id(region_id: str) -> Tuple[str, int, int]:
    """Split a ``"seqID:start-end"`` region id into ``(seqID, start, end)``."""
    # rsplit on the last ':' keeps sequence ids that themselves contain ':'.
    seq_id, coords = region_id.rsplit(":", 1)
    start, end = coords.split("-")
    return seq_id, int(start), int(end)


def _link_regions_from_blast(
    blast_file: Path,
    regions_by_chrom: Dict[str, List[Tuple[int, int, str]]],
    all_regions: Set[str],
    uf: "UnionFind",
) -> None:
    """Union the query and subject BED regions of every usable BLAST hit.

    Reads tab-separated BLAST tabular lines. Columns 7-10 are taken as
    qstart, qend, sstart, send (presumably the standard -outfmt 6 layout).
    Lines with fewer than 10 columns or non-integer coordinates are skipped.
    A hit adds an edge only when its query span exactly matches a BED region
    and its subject span overlaps a BED region.
    """
    try:
        with blast_file.open("r") as f:
            for line in f:
                cols = line.strip().split("\t")
                if len(cols) < 10:
                    continue

                try:
                    q_start, q_end, s_start, s_end = map(int, cols[6:10])
                except ValueError:
                    continue

                # Query: 1-based inclusive, either orientation -> BED-style id.
                query_region_id = (
                    f"{cols[0]}:{min(q_start, q_end) - 1}-{max(q_start, q_end)}"
                )
                if query_region_id not in all_regions:
                    continue

                # Subject: pass (low, high) so minus-strand hits (sstart > send)
                # still produce a valid interval.
                subject_region_id = find_subject_region(
                    cols[1],
                    min(s_start, s_end),
                    max(s_start, s_end),
                    regions_by_chrom,
                )

                if subject_region_id and subject_region_id in all_regions:
                    uf.union(query_region_id, subject_region_id)
    except FileNotFoundError:
        log.error("Could not read BLAST file: %s", blast_file)
        raise


def _rank_multi_sequence_groups(
    uf: "UnionFind",
    all_regions: Set[str],
) -> List[Tuple[int, List[Tuple[str, int, int]]]]:
    """Return connected components spanning >= 2 sequences, largest first.

    Each entry is ``(total_bp, members)``. ``members`` holds
    ``(seqID, start, end)`` tuples sorted by sequence id, then numeric start
    and end. Regions are visited in sorted order and the length sort is
    stable, so equal-sized groups get the same order on every run. Iterating
    the raw set would not, because string hashing is randomized per process.
    """
    components = defaultdict(list)
    for region_id in sorted(all_regions):
        components[uf.find(region_id)].append(_parse_region_id(region_id))

    ranked = []
    for members in components.values():
        # FILTER: must involve at least 2 DIFFERENT fasta sequences.
        if len({seq_id for seq_id, _, _ in members}) < 2:
            continue
        total_bp = sum(end - start for _, start, end in members)
        members.sort()  # numeric coordinate order, not lexicographic strings
        ranked.append((total_bp, members))

    ranked.sort(key=lambda item: item[0], reverse=True)
    return ranked


def _write_groups(
    output_file: Path,
    groups: List[Tuple[int, List[Tuple[str, int, int]]]],
) -> None:
    """Write ``Group_N<TAB>seqID<TAB>start<TAB>end<TAB>color`` lines.

    Groups are numbered from 1 in the given order. Each group gets the next
    WHALE_COLORS entry, and the palette repeats when exhausted.
    """
    colors = cycle(WHALE_COLORS)
    with output_file.open("w") as out:
        for group_number, (_, members) in enumerate(groups, start=1):
            color = next(colors)
            out.writelines(
                f"Group_{group_number}\t{seq_id}\t{start}\t{end}\t{color}\n"
                for seq_id, start, end in members
            )


def hit_grouping(
    bed_file: Union[str, Path],
    blast_file: Union[str, Path],
    outdir: Union[str, Path],
    max_groups: Optional[int] = 20,
) -> Path:
    """
    Builds a graph where Nodes=BED regions and Edges=BLAST hits.

    Finds connected components (groups) and keeps only groups that span at
    least two different sequences. It retains the ``max_groups`` largest by
    total bp, assigns one palette color per group, and writes them to
    ``<outdir>/final_groups.txt``.

    Args:
        bed_file: BED file defining the regions (graph nodes).
        blast_file: Tab-separated BLAST tabular results (graph edges).
        outdir: Output directory; created if it does not exist.
        max_groups: Number of largest groups to keep; ``None`` keeps all.

    Returns:
        Path of the written groups file.

    Raises:
        ValueError: if ``max_groups`` is negative.
        FileNotFoundError: if the BLAST (or BED) file cannot be found.
    """
    if max_groups is not None and max_groups < 0:
        raise ValueError(f"max_groups must be >= 0 or None, got {max_groups}")

    bed_file = Path(bed_file)
    blast_file = Path(blast_file)
    outdir = Path(outdir)
    output_file = outdir / "final_groups.txt"

    # 1. Load Nodes (Regions)
    regions_by_chrom, all_regions = load_bed_regions(bed_file)

    # 2-3. Build Edges. UnionFind.find registers regions lazily; isolated
    # regions become singleton sets during harvesting.
    uf = UnionFind()
    log.info("Building graph from %s...", blast_file.name)
    _link_regions_from_blast(blast_file, regions_by_chrom, all_regions, uf)

    # 4-5. Harvest, filter and rank groups
    log.info("Grouping components...")
    log.info(
        "Filtering groups (keeping largest %s multi-sequence groups)...", max_groups
    )
    ranked = _rank_multi_sequence_groups(uf, all_regions)
    top_groups = ranked if max_groups is None else ranked[:max_groups]

    # 6. Write Output with Colors
    log.info("Writing %d largest groups to %s...", len(top_groups), output_file.name)
    outdir.mkdir(parents=True, exist_ok=True)
    _write_groups(output_file, top_groups)

    log.info("Grouping complete. Saved %d groups.", len(top_groups))

    return output_file
