#! /usr/bin/env python
"""Using the ranges of BX molecule start/stop positions, calculates "molecular coverage" across the genome."""

import os
import sys
import gzip
import argparse
from collections import Counter

def new_intervals(contig_len, windowsize) -> list:
    """Split a contig of length ``contig_len`` into consecutive windows.

    Returns a list of half-open ``range`` objects ``[start, stop)`` (BED-style)
    that tile ``0..contig_len`` with no gaps and no overlaps. Every window is
    ``windowsize`` bp long except possibly the last, which is truncated at
    ``contig_len``. A contig of length 0 yields an empty list.

    Raises:
        ValueError: if ``windowsize`` is not greater than 0.
    """
    if windowsize <= 0:
        raise ValueError("windowsize must be greater than 0")
    # Adjacent windows share a boundary (one window's stop == the next one's start),
    # so every position -- and every molecule overlap -- lands in exactly one window.
    return [
        range(win_start, min(win_start + windowsize, contig_len))
        for win_start in range(0, contig_len, windowsize)
    ]

def quantify_overlaps(start: int, end: int, binlist: list):
    """Return ``(interval_idx, bp)`` tuples for every interval the molecule overlaps.

    The molecule is treated as the half-open span ``[start, end)`` and each
    element of ``binlist`` as a ``range`` (``[val.start, val.stop)``). ``bp`` is
    the number of bases the two share, and intervals with no overlap are
    omitted. The overlap is computed directly per interval, so coverage is
    counted correctly even when ``start`` does not fall inside any interval.

    ``binlist`` is assumed to be sorted by start position (as built by
    ``new_intervals``); the scan stops at the first interval beginning at or
    after ``end``.
    """
    result = []
    for idx, val in enumerate(binlist):
        if val.start >= end:
            # sorted bins: no later interval can overlap [start, end)
            break
        bp = min(val.stop, end) - max(start, val.start)
        if bp > 0:
            result.append((idx, bp))
    return result

def print_depth_counts(contig, counter_obj, intervals):
    """Write one tab-separated depth line per window to stdout.

    Each line holds ``contig``, the window's ``start`` and ``stop``, and the
    base-pair total stored in ``counter_obj`` under that window's index,
    divided by the window's length. Zero-length windows produce no line.
    """
    for window_idx, window in enumerate(intervals):
        total_bp = counter_obj[window_idx]
        width = len(window)
        if width == 0:
            # an empty window has no length to normalise by; skip it
            continue
        sys.stdout.write(f"{contig}\t{window.start}\t{window.stop}\t{total_bp / width}\n")


def _read_fai(fai_path):
    """Return a dict mapping each contig name in a FASTA .fai index to its length."""
    contigs = {}
    with open(fai_path, "r", encoding= "utf-8") as fai:
        for line in fai:
            splitline = line.split()
            if not splitline:
                # tolerate blank lines (e.g. a trailing empty line) instead of raising IndexError
                continue
            contigs[splitline[0]] = int(splitline[1])
    return contigs


def _find_columns(header):
    """Return the (contig, start, end) column indices of a split header line, None where absent."""
    idx_contig = idx_start = idx_end = None
    for idx, val in enumerate(header):
        name = val.strip()
        if name == "contig":
            idx_contig = idx
        elif name == "start":
            idx_start = idx
        elif name == "end":
            idx_end = idx
    return idx_contig, idx_start, idx_end


def main():
    """Command-line entry point: print windowed molecular coverage to stdout.

    For each window of every contig encountered in the gzipped stats file,
    writes ``contig<TAB>start<TAB>stop<TAB>depth``. Depth is the summed
    molecule length (bp) overlapping the window divided by the window length.
    """
    parser = argparse.ArgumentParser(
        prog = 'molecule_coverage',
        description =
        """
        Using the statsfile generated by bx_stats from Harpy,
        will calculate "molecular coverage" across the genome.
        Molecular coverage is the "effective" alignment coverage
        if you treat a molecule inferred from linked-read data as
        one contiguous alignment, even though the reads that make
        up that molecule don't cover its entire length. Requires a
        FASTA fai index (like the kind created using samtools faidx)
        to know the actual sizes of the contigs. Prints to stdout.
        """,
        usage = "molecule_coverage -w 50000 -f genome.fasta.fai statsfile > output.cov",
        exit_on_error = False
        )

    parser.add_argument('-f', '--fai', required = True, type = str, help = "FASTA index (.fai) file of genome used for alignment")
    parser.add_argument('-w', '--window', required = True, type = int, help = "Window size (in bp) to sum depths over")
    parser.add_argument('statsfile', help = "stats file produced by harpy via bx_stats script")

    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    args = parser.parse_args()
    err = [i for i in (args.statsfile, args.fai) if not os.path.exists(i)]
    if err:
        parser.error("Input files were not found:\n" + ", ".join(err))
    # a negative window would make the window ranges empty and silently print nothing
    if args.window <= 0:
        parser.error("--window must be greater than 0")

    # contig name -> contig length, used to build each contig's windows
    contigs = _read_fai(args.fai)

    with gzip.open(args.statsfile, "rt") as statsfile:
        # for safety, locate the contig, start, and end columns by name
        # in case the column order of the stats file ever changes
        header = statsfile.readline().rstrip().split()
        idx_contig, idx_start, idx_end = _find_columns(header)
        if idx_contig is None or idx_start is None or idx_end is None:
            parser.error("Required columns 'contig', 'start', or 'end' not found in header\n")

        last_contig = None
        counter = Counter()
        geno_intervals = []
        # Results for a contig are printed, and the accumulators reset, as soon as
        # the contig name changes; this assumes rows are grouped by contig (a contig
        # reappearing later would be printed again from a fresh counter).
        for line in statsfile:
            if line.startswith("#"):
                continue
            splitline = line.split()
            if not splitline:
                # skip blank lines instead of raising IndexError
                continue
            contig = splitline[idx_contig]
            if contig != last_contig:
                if last_contig is not None:
                    print_depth_counts(last_contig, counter, geno_intervals)
                if contig not in contigs:
                    parser.error(f"Contig '{contig}' from {args.statsfile} was not found in {args.fai}")
                geno_intervals = new_intervals(contigs[contig], args.window)
                counter = Counter()
                last_contig = contig
            aln_start = int(splitline[idx_start])
            aln_end = int(splitline[idx_end])
            for idx, bp in quantify_overlaps(aln_start, aln_end, geno_intervals):
                counter[idx] += bp
        # flush the final contig; nothing to print if the file had no data rows
        if last_contig is not None:
            print_depth_counts(last_contig, counter, geno_intervals)
