import os
import json
import pandas as pd
import numpy as np
from mtsv.parsing import (
    parse_query_id,
    format_cml_params,
    file_type, outfile_type,
    make_table)
from mtsv.utils import (
    script_path,
    track_file_params, get_database_params)


# Prefix every shell command with strict-mode flags so a failing command,
# an unset variable, or a failure inside a pipeline aborts the job.
shell.prefix("set -euo pipefail;")


# Helper scripts used by the rules below, keyed by the rule that runs them.
BINNING_SCRIPT = {
    "unaligned_queries": script_path('MTSv_unaligned_queries.py'),
}

# Final HTML report requested by the binning_all target rule.
REPORT = outfile_type(os.path.join("Reports", "binning_report.html"))

# FM-index paths are read from the database config and stored back into
# the workflow config, replacing any value that was already there.
config['fm_index_paths'] = [
    file_type(index_file)
    for index_file in get_database_params(
        config['database_config'], 'fm-index-paths')
]


# Directory of every FM-index file.
INDEX_PATH = [
    os.path.dirname(index_file) for index_file in config['fm_index_paths']]
# Index name: the file's basename up to (not including) its first dot.
INDEX = [
    os.path.basename(index_file).split(".")[0]
    for index_file in config['fm_index_paths']
]
# Index name -> FM-index file; consulted by the binning rule's input function.
INDEX_DICT = dict(zip(INDEX, config['fm_index_paths']))
# One binner result file and one log file per index, in the output directory.
INDEX_OUT = [
    os.path.join(config['binning_outpath'], "{}.bn".format(name))
    for name in INDEX
]
LOG = [
    os.path.join(config['binning_outpath'], "{}.log".format(name))
    for name in INDEX
]

# Seed settings applied for each named binning mode.
MODE_DICT = {
    'fast': {'seed_size': 17, 'min_seeds': 5, 'seed_gap': 2},
    'efficient': {'seed_size': 14, 'min_seeds': 4, 'seed_gap': 2},
    'sensitive': {'seed_size': 11, 'min_seeds': 3, 'seed_gap': 1},
}

BIN_MODE = MODE_DICT[config['binning_mode']]
for key, value in BIN_MODE.items():
    # snakemake config does not allow dashes, but dashes are required for
    # these mtsv params: drop the underscore spelling (if present) and
    # store the setting under the dashed name instead.
    # NOTE(review): the mode's value always wins here, even when the
    # underscore key was already set in config -- confirm this is intended.
    config.pop(key, None)
    config[key.replace("_", "-")] = value

# Command-line argument string handed to mtsv-binner by the binning rule.
# NOTE(review): the key list looks like config entries to leave out of the
# generated arguments -- confirm against format_cml_params.
CML_PARAMS = format_cml_params(
    'BINNING',
    config,
    [
        'binning_mode', 'fm_index_paths', 'fasta', 'database_config',
        'merge_file', 'binning_outpath',
    ],
    [],
)

# Work out which FASTQ files the query FASTA was built from.  The preferred
# source is the ".params" JSON file that sits next to the FASTA (its
# ['readprep'][<fasta>]['fastq'] entry); when that file does not exist, fall
# back to config['fastq'], and finally to an empty list.
try:
    # Use a context manager so the file handle is closed immediately rather
    # than being left open until garbage collection.
    with open(
            os.path.join(os.path.dirname(config['fasta']), ".params"),
            'r') as _params_handle:
        SAMPLES = json.load(
            _params_handle)['readprep'][config['fasta']]['fastq']
except FileNotFoundError:
    SAMPLES = config.get('fastq', [])

# Strip directories, then the final extension, to get per-sample names.
SAMPLES = [os.path.basename(s) for s in SAMPLES]
SAMPLE_NAMES = [os.path.splitext(s)[0] for s in SAMPLES]
# One FASTA of unaligned queries per sample, written to the output directory.
UNALIGNED_QUERIES = [
    os.path.join(
        config['binning_outpath'],
        "unaligned_queries_{}.fasta".format(s))
    for s in SAMPLE_NAMES
]


rule binning_all:
    """Target rule: requesting it asks for the binning report."""
    input:
        REPORT

rule binning_report:
    """Write the HTML binning report from the query statistics JSON.

    Reads query_stats.json, builds a per-sample table of query and hit
    counts, records the merge file via track_file_params, and renders the
    report text to the REPORT file.
    """
    input:
        os.path.join(config['binning_outpath'],
            'query_stats.json')
    output:
        rep=REPORT
    params:
        header = SAMPLE_NAMES,
        merge_file = config['merge_file']
    message: 
        """
        Running binning report.
        Writing report to {output.rep} 
        Snakemake scheduler assuming {threads} thread(s)
        """ 
    # Disabled memory estimate.  NOTE: it refers to input.clps and
    # input.fasta, which this rule's (unnamed) input does not define, so it
    # cannot be re-enabled as written.
    # resources:
    #     mem_mb=lambda wildcards, input, attempt: max(1,
    #         attempt * int(
    #             os.path.getsize(input.clps) * 0.000001 + 
    #             os.path.getsize(input.fasta) * 0.000002))
    run:
        from snakemake.utils import report

        # Read the stats inside a context manager so the file handle is
        # closed as soon as parsing finishes instead of being leaked.
        with open(input[0], 'r') as stats_handle:
            summary_data = json.load(stats_handle)

        # Per-sample columns followed by their column headings.
        table = make_table(
            [params.header, summary_data['total_queries_by_sample'],
            summary_data['total_unique_queries_by_sample'],
            summary_data['total_hits_by_sample'],
            summary_data['total_unique_hits_by_sample'],
            summary_data['total_unaligned_queries_by_sample'],
            summary_data['total_unique_unaligned_queries_by_sample']],
            ["Sample", "Total Queries", "Total Unique Queries",
            "Total Hits", "Total Unique Hits", "Unaligned Queries",
            "Unique Unaligned Queries"])

        # The names below appear as {placeholders} in the report text and
        # are presumably filled in by report() from this scope, so they
        # must stay as local variables with exactly these names.
        file_str = "\n".join(UNALIGNED_QUERIES)
        total_hits = summary_data['total_hits']
        total_unique_hits = summary_data['total_unique_hits']
        total_queries = summary_data['total_queries']
        total_unique_queries = summary_data['total_unique_queries']
        total_unaligned_queries = summary_data['total_unaligned_queries']
        total_unique_unaligned_queries = summary_data[
            'total_unique_unaligned_queries']

        track_file_params('merge_file', config['merge_file'], config)

        report("""
        Binning Report
        ================================
        Hits are reported in:\n
        **{params.merge_file}**\n
        There were **{total_hits}** total hits
        and **{total_unique_hits}** unique hits out 
        of **{total_queries}** total queries and
        **{total_unique_queries}** unique queries.\n
        There were **{total_unaligned_queries}** total
        queries and **{total_unique_unaligned_queries}**
        unique queries with no hits.\n
        Unaligned query sequences are reported in:\n
        {file_str}.\n
        Query Stats:\n
        {table}
        """, output['rep'])

rule unaligned_queries:
    """Extract the queries that produced no hits.

    Delegates to the MTSv_unaligned_queries script.  The declared outputs
    are one FASTA per sample plus the query_stats.json summary that the
    binning report consumes.
    """
    input:
        clps = config['merge_file'],
        fasta = config['fasta']
    output:
        out = UNALIGNED_QUERIES,
        summary = os.path.join(
            config['binning_outpath'], 'query_stats.json')
    params:
        samples = SAMPLE_NAMES,
        outpath = config['binning_outpath']
    log:
        config['log_file']
    threads:
        config['threads']
    message:
        """
        Finding unaligned queries.
        Writing unaligned queries to {output}.
        Logging to {log}.
        Snakemake scheduler assuming {threads} thread(s)
        """
    script:
        BINNING_SCRIPT['unaligned_queries']



            
rule binning:
    """Run mtsv-binner on the query FASTA against one FM-index.

    The {index} wildcard selects the FM-index file through INDEX_DICT;
    results go to <binning_outpath>/{index}.bn and the binner's output is
    appended to the matching {index}.log.
    """
    input:
        reads = config['fasta'],
        # Look up the index file that belongs to the requested {index}.
        fm_index = lambda wildcard: INDEX_DICT[wildcard.index]
        # Earlier, disabled ways of locating the index:
        # fm_index = expand(
        #     "{path}/{index}.index", zip, path=INDEX_PATH, index=INDEX )
        # fm_index = os.path.join(INDEX_PATH, "{index}.index")
    output:
        os.path.join(config['binning_outpath'], "{index}.bn")
    log:
        os.path.join(config['binning_outpath'], "{index}.log")
    threads:
        config['threads']
    # Disabled memory estimate based on the sizes of the input files.
    # resources:
    #     mem_mb=lambda wildcards, attempt, input: max(1,
    #         int((os.path.getsize(input.reads) * 0.0000012 + 
    #         os.path.getsize(input.fm_index) * 0.0000011)) * 
    #         attempt)
    params:
        args=CML_PARAMS
    message:
        """
        Executing Binner on queries in {input.reads}.
        Using index {input.fm_index}
        Writing files to directory {output}
        Logging to {log}
        Snakemake scheduler assuming {threads} thread(s)"""
    shell:
        """mtsv-binner --index {input.fm_index} --threads {threads} --results {output} --fasta {input.reads} {params.args} >> {log} 2>&1 """


rule collapse:
    """Merge the per-index binner outputs into the single merge file.

    Passes every file in INDEX_OUT to mtsv-collapse and appends the
    command's output to the shared workflow log.
    """
    input:
        INDEX_OUT
    output:
        config['merge_file']
    log:
        config['log_file']
    # Disabled memory estimate based on the total size of the inputs.
    # resources:
    #     mem_mb=lambda wildcards, attempt, input: max(1,
    #         attempt * int(
    #             sum([os.path.getsize(i) for i in input]) * 0.0000012))
    message:
        """
        Merging binning output into {output}.
        Logging to {log}
        Snakemake scheduler assuming {threads} thread(s)"""
    shell:
        """mtsv-collapse {input} -o {output} >> {log} 2>&1"""



