import logging
from pathlib import Path
from collections import defaultdict
from typing import Tuple, Union

# Initialize logger
log = logging.getLogger(__name__)


def map_subject_coords(
    orig_q_start: int,
    orig_s_start: int,
    orig_s_end: int,
    new_q_start: int,
    new_q_end: int,
) -> Tuple[int, int]:
    """
    Project a trimmed query interval onto the subject (target) sequence.

    The subject start is moved by the same number of bases that the query
    start was moved, and the subject span is given the same width as the
    trimmed query span. The shift is applied base-for-base; alignment gaps
    are not taken into account.

    Args:
        orig_q_start: Query start of the untrimmed hit.
        orig_s_start: Subject start of the untrimmed hit.
        orig_s_end: Subject end of the untrimmed hit.
        new_q_start: Query start after trimming.
        new_q_end: Query end after trimming.

    Returns:
        A ``(new_subject_start, new_subject_end)`` tuple.
    """
    # +1 walks the subject forward (plus strand: start < end); -1 walks it
    # backward (minus strand). A hit whose subject start equals its end
    # takes the backward branch.
    direction = 1 if orig_s_start < orig_s_end else -1

    # How far the query start moved, e.g. original 200 -> new 250 gives 50.
    shift = new_q_start - orig_q_start

    # Distance between the trimmed query endpoints (end minus start).
    span = new_q_end - new_q_start

    trimmed_s_start = orig_s_start + direction * shift
    trimmed_s_end = trimmed_s_start + direction * span
    return trimmed_s_start, trimmed_s_end


def _load_shattered_segments(shattered_file: Path) -> dict:
    """
    Read the shattered BED file into per-query segment lists.

    Args:
        shattered_file: Path to the BED file to read.

    Returns:
        A mapping ``query_id -> [(start, end), ...]`` where coordinates are
        1-based closed and each list is sorted by start position.

    Raises:
        FileNotFoundError: If ``shattered_file`` does not exist.
    """
    segments = defaultdict(list)
    skipped = 0
    try:
        with shattered_file.open("r") as f:
            for line in f:
                cols = line.split()
                if len(cols) < 3:
                    continue
                try:
                    # BED is 0-based half-open. Convert to 1-based closed
                    # for comparison against BLAST coordinates.
                    seg_start = int(cols[1]) + 1
                    seg_end = int(cols[2])
                except ValueError:
                    # Header lines (e.g. "track ...") or otherwise
                    # non-numeric coordinates: skip, as the BLAST loop does
                    # for unparsable hits, instead of aborting the run.
                    skipped += 1
                    continue
                segments[cols[0]].append((seg_start, seg_end))
    except FileNotFoundError:
        log.error(f"Could not find shattered file: {shattered_file}")
        raise

    if skipped:
        log.warning(
            f"Skipped {skipped} line(s) with non-integer coordinates "
            f"in {shattered_file.name}."
        )

    # Sorted by start so the overlap scan can stop early.
    for seg_list in segments.values():
        seg_list.sort()

    return segments


def trim_blast(
    blast_file: Union[str, Path],
    shattered_file: Union[str, Path],
    outdir: Union[str, Path],
) -> Path:
    """
    Trims original BLAST hits to match the shattered BED segments.

    Every hit is intersected with each segment of its query. One output row
    is written per non-empty intersection, with the alignment length
    (column index 3), query coordinates (indices 6, 7) and subject
    coordinates (indices 8, 9) rewritten; all other columns are copied
    through unchanged. Output query coordinates are always ascending; the
    subject coordinates carry the relative orientation.

    Args:
        blast_file: Path to original BLAST output (fmt 6).
        shattered_file: Path to the 'divided.bed' file.
        outdir: Directory to save the output. Created if it does not exist.

    Returns:
        Path: The path to the new 'trimmed_blast.tsv' file.

    Raises:
        FileNotFoundError: If ``blast_file`` or ``shattered_file`` is missing.
    """
    blast_file = Path(blast_file)
    shattered_file = Path(shattered_file)
    outdir = Path(outdir)
    output_file = outdir / "trimmed_blast.tsv"

    # 1. Load the divided BED regions into memory
    log.info(f"Loading shattered segments from {shattered_file.name}...")
    segments = _load_shattered_segments(shattered_file)

    log.info(f"Processing hits from {blast_file.name}...")

    # 2. Stream the BLAST file, trim hits, and write to output file
    output_count = 0

    try:
        with blast_file.open("r") as f:
            # Create the output directory only once the input is known to
            # exist, so a missing directory is not misreported below as a
            # missing BLAST file.
            outdir.mkdir(parents=True, exist_ok=True)

            with output_file.open("w") as out:
                for line in f:
                    cols = line.strip().split("\t")
                    if len(cols) < 12:
                        continue  # Skip malformed lines

                    # .get() avoids inserting empty lists into the defaultdict.
                    query_segments = segments.get(cols[0])
                    if not query_segments:
                        continue

                    try:
                        raw_q_start = int(cols[6])
                        raw_q_end = int(cols[7])
                        s_start = int(cols[8])
                        s_end = int(cols[9])
                    except ValueError:
                        continue

                    q_start = min(raw_q_start, raw_q_end)
                    q_end = max(raw_q_start, raw_q_end)

                    # If the query was reported reversed, the subject start
                    # belongs to the larger query coordinate. Swap the
                    # subject endpoints so s_start pairs with q_start.
                    if raw_q_start > raw_q_end:
                        s_start, s_end = s_end, s_start

                    for seg_start, seg_end in query_segments:
                        # Segments are sorted by start: none of the
                        # remaining ones can overlap this hit.
                        if seg_start > q_end:
                            break

                        overlap_start = max(q_start, seg_start)
                        overlap_end = min(q_end, seg_end)
                        if overlap_start > overlap_end:
                            continue  # No intersection with this segment

                        new_s_start, new_s_end = map_subject_coords(
                            q_start, s_start, s_end, overlap_start, overlap_end
                        )

                        new_cols = list(cols)
                        # Length = count of bases in the closed interval
                        new_cols[3] = str(overlap_end - overlap_start + 1)
                        new_cols[6] = str(overlap_start)
                        new_cols[7] = str(overlap_end)
                        new_cols[8] = str(new_s_start)
                        new_cols[9] = str(new_s_end)

                        out.write("\t".join(new_cols) + "\n")
                        output_count += 1
    except FileNotFoundError:
        log.error(f"Could not read BLAST file: {blast_file}")
        raise

    log.info(
        f"Trimming complete. Wrote {output_count} split hits to {output_file.name}."
    )

    return output_file
