import logging
from pathlib import Path
from collections import defaultdict
from typing import List, Tuple, Set, Union

# Module-level logger, named after this module via __name__
log = logging.getLogger(__name__)


def get_shattered_segments(
    intervals: Set[Tuple[int, int]], min_len: int
) -> List[Tuple[int, int]]:
    """
    Split possibly overlapping intervals at every interval boundary.

    Every start and end coordinate acts as a cut point. Each stretch between
    two adjacent cut points is kept when it lies inside at least one input
    interval (treated as half-open, ``start <= x < end``) and is at least
    ``min_len`` long.

    Args:
        intervals: A set of (start, end) tuples.
        min_len: Minimum length of a segment to retain.

    Returns:
        A list of (start, end) tuples in ascending coordinate order. Empty
        when ``intervals`` is empty or nothing passes the filters.
    """
    # Net change in coverage depth at each cut point: +1 where an interval
    # opens, -1 where it closes. The keys double as the set of cut points.
    depth_change = {}
    for start, end in intervals:
        # Register both boundaries as cut points unconditionally, so that
        # degenerate (start >= end) tuples still split their neighbours.
        depth_change.setdefault(start, 0)
        depth_change.setdefault(end, 0)
        # Only well-formed intervals can cover anything.
        if start < end:
            depth_change[start] += 1
            depth_change[end] -= 1

    sorted_cuts = sorted(depth_change)

    final_segments = []
    depth = 0  # number of intervals covering the stretch right after seg_start

    # Walk adjacent cut points once, carrying the running coverage depth.
    # No boundary lies strictly between two adjacent cuts, so a stretch is
    # either entirely covered or entirely uncovered.
    for seg_start, seg_end in zip(sorted_cuts, sorted_cuts[1:]):
        depth += depth_change[seg_start]

        # Gaps between hits are never reported.
        if depth <= 0:
            continue

        # The length filter is applied last, on the already-shattered pieces.
        if seg_end - seg_start < min_len:
            continue

        final_segments.append((seg_start, seg_end))

    return final_segments


def parse_blast_and_shatter(
    blast_file: Union[str, Path],
    outdir: Union[str, Path, None] = None,
    min_len: int = 500,
    identity_cutoff: float = 90.0,
) -> Path:
    """
    Parse BLAST hits, group them by query ID, shatter overlapping hits,
    and write every resulting segment to ONE output file.

    Each input line is split on tabs. Lines with fewer than 10 columns, or
    whose identity/coordinate columns are not numeric, are skipped. The
    1st, 3rd, 7th and 8th columns are read as query ID, percent identity,
    query start and query end.

    Args:
        blast_file: Path to the input BLAST tab-delimited file.
        outdir: Directory to save the output BED file. Defaults to the
            parent of ``blast_file`` when None. Created if it is missing.
        min_len: Minimum length of segments to keep.
        identity_cutoff: Minimum percent identity for a hit to be used.

    Returns:
        Path: The path to the generated ``divided.bed`` file. An existing
        file of that name in ``outdir`` is overwritten.

    Raises:
        FileNotFoundError: If ``blast_file`` does not exist (logged, then
            re-raised).
    """
    blast_file = Path(blast_file)
    outdir = blast_file.parent if outdir is None else Path(outdir)

    # Hits per query: { "query_id": {(start, end), (start, end)} }
    # A set collapses hits that share identical query coordinates.
    query_hits = defaultdict(set)
    skipped = 0

    log.info(f"Reading BLAST results from {blast_file.name}...")

    try:
        with blast_file.open("r") as f:
            for line in f:
                cols = line.strip().split("\t")
                if len(cols) < 10:
                    skipped += 1
                    continue

                qseqid = cols[0]
                try:
                    pident = float(cols[2])
                    qstart = int(cols[6])
                    qend = int(cols[7])
                except ValueError:
                    # Header/comment rows or corrupt records: skip, but count.
                    skipped += 1
                    continue

                # EARLY FILTER: only identity is checked here; the length
                # filter is applied later, on the shattered segments.
                # Kept as a positive comparison so a NaN identity is rejected.
                if pident >= identity_cutoff:
                    # Normalize so start <= end regardless of hit orientation,
                    # then convert 1-based inclusive coordinates to 0-based
                    # half-open (BED standard).
                    lo = min(qstart, qend)
                    hi = max(qstart, qend)
                    query_hits[qseqid].add((lo - 1, hi))
    except FileNotFoundError:
        log.error(f"Could not find BLAST file: {blast_file}")
        raise

    if skipped:
        log.debug(f"Skipped {skipped} line(s) that could not be parsed as BLAST hits.")

    # Create the destination only after the input was read successfully, so
    # a missing BLAST file does not leave an empty directory behind.
    outdir.mkdir(parents=True, exist_ok=True)
    output_file = outdir / "divided.bed"

    log.info(f"Calculating shattered segments and writing to {output_file.name}...")

    with output_file.open("w") as out:
        # Sorted query IDs give a deterministic, organized output order.
        for qseqid in sorted(query_hits):
            # min_len is applied to the shattered result, not to raw hits.
            segments = get_shattered_segments(query_hits[qseqid], min_len)
            out.writelines(f"{qseqid}\t{s}\t{e}\n" for s, e in segments)

    log.info("Shattering complete.")

    # IMPORTANT: Return the path so cli.py can pass it to the next step
    return output_file
