import os
import json
import itertools as it
from mtsv.utils import script_path, get_database_params
from glob import glob
from mtsv.parsing import (
    parse_output_row, file_type, parse_query_id, outfile_type, make_table)
from Bio import SeqIO
# Script handed to the ``script:`` directive of rule ``extract``.
SCRIPT = script_path(
    'MTSv_extract.py')
# HTML report: the output of rule ``extract_report`` and the target
# requested by rule ``extract_all``.
REPORT = outfile_type(
    os.path.join("Reports", "extract_report.html"))

# Check whether the input file is a merge file or a signature file and load
# the parameters that were recorded for the merge step it derives from.
#
# Each output directory holds a JSON ``.params`` file.  As read here:
#   * ``['summary_file']`` maps to entries carrying a ``'signature_file'`` and
#     a ``'merge_file'`` path;
#   * ``['merge_file'][<merge file path>]`` holds the merge-step parameters.
def _read_params(directory):
    """Return the parsed ``.params`` JSON file stored in *directory*."""
    with open(os.path.join(directory, ".params"), 'r') as handle:
        return json.load(handle)


_hits_params = _read_params(os.path.dirname(config['input_hits']))

sig_prior_params = None
merge_prior_params = None
try:
    # Signature-file case: find the entry that produced ``input_hits``.
    # No early exit, so the last matching entry wins.
    for v in _hits_params['summary_file'].values():
        if v['signature_file'] == config['input_hits']:
            sig_prior_params = v

    # Resolve the merge parameters once, and only after a match was found;
    # doing this inside the loop would touch ``sig_prior_params`` before it
    # has been assigned whenever the first entry does not match.
    if sig_prior_params is not None:
        merge_prior_params = _read_params(
            os.path.dirname(sig_prior_params['merge_file'])
        )['merge_file'][sig_prior_params['merge_file']]
except KeyError:
    # Missing section/key: treat the input as a merge file below.
    pass

if merge_prior_params is None:
    # Merge-file case (also used when no signature entry matched).
    merge_prior_params = _hits_params['merge_file'][config['input_hits']]

# When the user did not supply a database config, reuse the one recorded
# in the merge step's parameters.
if "database_config" not in config:
    config['database_config'] = merge_prior_params['database_config']

# Look up the 'taxdump-path' entry for that database and store it in the
# config (passed through ``file_type`` first).
config['taxdump_path'] = file_type(
    get_database_params(config['database_config'], 'taxdump-path'))


# Build the lists of fasta/fastq files that the extract rules produce.
#
# With ``config['by_sample']`` set there is one file per (taxid, sample)
# pair, named ``<taxid>_<sample>.fasta``; otherwise one ``<taxid>.fasta``
# per taxid.  Every fasta output has a fastq twin with the same stem.
SAMPLE_NAMES = []
if config['by_sample']:
    # get number of samples from the counts of the first row of the hits file
    # NOTE(review): n_samples is not used again in this file -- confirm
    # whether anything that includes this file depends on it.
    with open(config['input_hits'], 'r') as infile:
        n_samples = len(parse_output_row(infile.readline()).counts)

    # The readprep step's ``.params`` file (stored next to the fasta) records
    # the fastq files the fasta was built from.  Read it through a context
    # manager so the handle is closed deterministically.
    _readprep_params_path = os.path.join(
        os.path.dirname(merge_prior_params['fasta']), ".params")
    with open(_readprep_params_path, 'r') as _params_file:
        _fastq_paths = json.load(_params_file)[
            'readprep'][merge_prior_params['fasta']]['fastq']

    # Sample name = fastq file name without directory or final extension.
    SAMPLE_NAMES = [
        os.path.splitext(os.path.basename(f))[0] for f in _fastq_paths]
    FASTA_OUTPUT = [
        os.path.join(
            config['extract_path'],
            "{0}_{1}.fasta".format(tax, n))
        for tax, n in it.product(config['taxids'], SAMPLE_NAMES)]
else:
    FASTA_OUTPUT = [
        os.path.join(config['extract_path'], str(tax) + ".fasta")
        for tax in config['taxids']]

FASTQ_OUTPUT = [os.path.splitext(fasta)[0] + ".fastq" for fasta in FASTA_OUTPUT]
FASTA = file_type(merge_prior_params['fasta'])

# Aggregate target: asking for the report pulls in every extraction output.
# It is the first rule in this file, so it is the default target when this
# file is run on its own.
rule extract_all:
    input: REPORT

rule extract_report:
    """Report the number of extracted queries in each fasta/fastq file"""
    input:
        fasta=FASTA_OUTPUT,
        fastq=FASTQ_OUTPUT
    output:
        REPORT
    # Disabled alternative: memory estimate derived from the input file sizes.
    # resources:
    #     mem_mb=lambda wildcards, input, attempt: max(1, attempt * int(
    #         os.path.getsize(input[0]) * 0.000001 +
    #         os.path.getsize(input[1]) * 0.000001))
    resources:
        mem_mb=200
    message:
        """
        Generating extract report from {input.fasta} and {input.fastq}
        Writing to {output}
        Snakemake scheduler assuming {threads} threads and
        {resources.mem_mb} Mb of memory. """

    params:
        path=config['extract_path'],
        source=config['input_hits'],
        taxa=config['taxids']
    run:
        from snakemake.utils import report
        # Count records by streaming over the parser instead of building a
        # list of every record, so memory use does not grow with file size.
        total_fasta_queries = [
            sum(1 for _ in SeqIO.parse(fasta, "fasta"))
            for fasta in input.fasta]
        total_fastq_queries = [
            sum(1 for _ in SeqIO.parse(fastq, "fastq"))
            for fastq in input.fastq]
        fasta_files = [os.path.basename(out) for out in input.fasta]
        fastq_files = [os.path.basename(out) for out in input.fastq]
        # Two parallel columns: file names (fasta first, then fastq) and
        # the matching record counts.
        _files = fasta_files + fastq_files
        total_queries = total_fasta_queries + total_fastq_queries
        data = [_files, total_queries]
        data = make_table(data, ["File", "No. Queries"])
        # data = "\n".join(["{0:<26} {1:<11}".format(f, q) for f, q in zip(
        #     _files, total_queries)])
        # ``{data}`` and ``{params.*}`` in the text below are filled in by
        # ``report`` from the variables in scope here.
        report(
            """ 
            Extract Report
            ============================
            Hits from **{params.source}** for taxids **{params.taxa}**
            were extracted to fasta and fastq files at:\n
            **{params.path}**\n
            Number of queries per file:\n
            {data}

            """, output[0])



rule extract:
    """Get query sequences associated with a given taxid"""
    # Inputs: the query fasta recorded by the merge step and the hits file.
    input:
        fasta=FASTA,
        clp=config['input_hits']
    # Outputs: every fasta file and its fastq twin.
    output:
        fasta_out=FASTA_OUTPUT,
        fastq_out=FASTQ_OUTPUT
    # Settings made available to the script.
    params:
        taxdump=config['taxdump_path'],
        taxids=config['taxids'],
        outpath=config['extract_path'],
        descendants=config['descendants'],
        by_sample=config['by_sample'],
        sample_names=SAMPLE_NAMES
    threads:
        config['threads']
    resources:
        mem_mb=2000
    # Disabled alternative: memory estimate derived from the input file
    # sizes, raised on retry attempts.
    # resources:
    #     mem_mb=lambda wildcards, input, attempt: max(1, (attempt * int(
    #         os.path.getsize(input[0]) * 0.000001 +
    #         os.path.getsize(input[1]) * 0.000001)) if attempt == 1
    #         else ((2000 * attempt) + int(
    #         os.path.getsize(input[0]) * 0.000001 +
    #         os.path.getsize(input[1]) * 0.000001)))
    log:
        config['log_file']
    message:
        """
        Extracting query sequences from {input.fasta}
        for taxids: {params.taxids}
        that were hits in {input.clp}
        Writing to {output}
        Logging to {log}
        Snakemake scheduler assuming {threads} threads and 
        {resources.mem_mb} Mb of memory. """
    script:
        SCRIPT