import os
from mtsv.utils import track_file_params
from mtsv.parsing import (
    parse_query_id, format_cml_params, outfile_type, make_table)
from Bio import SeqIO
import numpy as np
from operator import add

# Every shell command in this workflow is prefixed so that it aborts on the
# first failing command, unset variable, or failing pipeline stage.
shell.prefix("set -euo pipefail;")

# Location of the HTML summary written by the readprep_report rule.
REPORT = outfile_type(os.path.join("Reports", "readprep_report.html"))

# Command-line arguments handed to mtsv-readprep.  The trimming flag is
# chosen from config['trim_mode']: 'segment' additionally carries the kmer
# size, every other mode is passed through as a bare '--<mode>' switch.
# NOTE(review): how format_cml_params treats the listed config keys is
# defined in mtsv.parsing and not visible here.
CML_PARAMS = format_cml_params(
    'READPREP',
    config,
    ['fasta', 'kmer', 'trim_mode'],
    (
        ['--segment', str(config['kmer'])]
        if config['trim_mode'] == 'segment'
        else ['--{}'.format(config['trim_mode'])]
    ),
)

def get_unique_counts(counts):
    """Return an integer array holding 1 where ``counts`` is truthy, else 0.

    The input is first cast to booleans (any non-zero count becomes True)
    and then back to integers, so each position contributes at most one.
    """
    presence = np.array(counts, dtype=bool)
    return presence.astype(int)


# Aggregate target: its only input is the HTML report, so requesting this
# rule triggers readprep_report (and whatever that rule depends on).
rule readprep_all:
    input: REPORT

# Summarise the deduplicated query FASTA as an HTML report: number and size
# of queries, copy-number statistics, and per-sample totals.
rule readprep_report:
    input: config['fasta']
    output:
        REPORT
    params: config['fastq']
    message:
        """
        Running readprep report. Writing to {output}.
        Snakemake scheduler assuming {threads} thread(s) """
    # Memory estimate based on the FASTA size; currently disabled.
    # resources:
    #     mem_mb=lambda wildcards, attempt, input: max(1, attempt * int(
    #         os.path.getsize(input[0]) * 0.0000005))
    run:
        from snakemake.utils import report

        # NOTE: the names file_string, total_queries, query_size, stats and
        # data are referenced as {placeholders} in the report text below and
        # must keep exactly these names.
        sample_totals = []      # summed counts per sample over all queries
        sample_presence = []    # number of queries seen at least once per sample
        copies_per_query = []   # total copies of each query across samples
        query_size = 0

        with open(input[0]) as handle:
            records = SeqIO.parse(handle, "fasta")

            # The first record seeds the accumulators and fixes the query
            # length that is reported (and written back to the config).
            first = next(records, None)
            if first is not None:
                query_size = len(first.seq)
                sample_totals = parse_query_id(first.id)
                sample_presence = get_unique_counts(sample_totals)
                copies_per_query.append(np.sum(sample_totals))

            # Every remaining record is added onto the running totals.
            # presumably parse_query_id returns a numeric array so that "+"
            # adds element-wise -- TODO confirm against mtsv.parsing
            for record in records:
                per_sample = parse_query_id(record.id)
                sample_totals = sample_totals + per_sample
                sample_presence = sample_presence + get_unique_counts(
                    per_sample)
                copies_per_query.append(np.sum(per_sample))

        total_queries = len(copies_per_query)
        file_string = "\n".join(params[0])

        # Min / mean / median / max copies per query; all zero when the
        # FASTA contained no records.
        if copies_per_query:
            summary = [
                np.min(copies_per_query),
                np.around(np.mean(copies_per_query), decimals=2),
                np.median(copies_per_query),
                np.max(copies_per_query),
            ]
        else:
            summary = [0, 0, 0, 0]
        stats = make_table(
            [[value] for value in summary],
            ["Min", "Mean", "Median", "Max"])

        # One column per header: source file name, 1-based sample index,
        # total queries and unique queries for that sample.
        data = make_table(
            [
                [os.path.basename(path) for path in params[0]],
                range(1, len(sample_totals) + 1),
                sample_totals,
                sample_presence,
            ],
            ["Sample File",
            "Sample ID", "No. of Queries", "No. Unique Queries"])

        # Record the observed query length as the kmer size before the
        # parameters for this FASTA are tracked.
        config['kmer'] = query_size
        track_file_params('readprep', config['fasta'], config)
        report("""
        Readprep Report
        ===================================

        **Queries were generated from:** \n
        {file_string}\n
        This resulted in **{total_queries}** unique query sequences
        of size **{query_size}**.\n
        **Written to:**\n
        {input}\n
        Statistics for the number of copies per unique query:\n
        {stats}

        By Sample:\n
        {data}
        """, output[0])

rule readprep:
    """Read fragment quality control and deduplication (FASTQ -> FASTA)"""
    input: expand("{fastq}", fastq=config['fastq'])
    output: config['fasta']
    log: config['log_file']
    threads: config['threads']
    params:
        args=CML_PARAMS
    # Requested memory grows with the combined size (in bytes) of the input
    # files and with the retry attempt number, and is never below 1.
    resources:
        mem_mb=lambda wildcards, attempt, input: max(
            1,
            attempt * int(
                1.5e-6 * sum(os.path.getsize(path) for path in input)))
    message:
        """
        Running Readprep on {input}.
        Writing to {output}
        Logging to {log}
        Snakemake scheduler assuming {threads} thread(s) and 
        {resources.mem_mb} Mb of memory. """
    # stdout and stderr of mtsv-readprep are both appended to the log file.
    shell:
        '''mtsv-readprep {input} -o {output} {params.args} >> {log} 2>&1'''
