import os
from glob import glob
from snakemake.utils import makedirs

# Make sure the "logs" directory exists before any rule appends to its log file.
makedirs(["logs"])

def get_threshold(config):
    """Split the CONCOCT length threshold out of the configured options.

    Looks in ``config['concoct_opts']`` (a list of command-line tokens) for
    ``-l`` or, failing that, ``--length_threshold`` and separates that flag
    and the token following it from the remaining options.

    Args:
        config: Mapping that may hold a ``concoct_opts`` list. The list is
            never modified.

    Returns:
        A ``(threshold, options)`` tuple: the configured threshold value
        (``'1000'`` when no threshold flag is present) and the remaining
        options joined by single spaces.

    Raises:
        ValueError: If the threshold flag is the last token, i.e. it has no
            value after it.
    """
    # Work on a copy so the caller's config keeps its threshold flag; a
    # missing or empty option list simply means "no extra options".
    opts = list(config.get('concoct_opts') or [])

    # "-l" takes precedence over the long spelling when both are present.
    for flag in ("-l", "--length_threshold"):
        if flag in opts:
            idx = opts.index(flag)
            break
    else:
        # Use default value
        return ('1000', " ".join(opts))

    if idx + 1 >= len(opts):
        raise ValueError(
            "concoct_opts: '{}' must be followed by a value".format(flag))

    threshold = opts[idx + 1]
    # Drop the flag together with its value.
    del opts[idx:idx + 2]
    return (threshold, " ".join(opts))
    
def get_samples(unaligned_queries):
    """Derive one sample name per input path.

    The name is the file's base name without its last extension and with
    every occurrence of ``unaligned_queries_`` removed. Order follows the
    input.
    """
    samples = []
    for path in unaligned_queries:
        stem, _extension = os.path.splitext(os.path.basename(path))
        samples.append(stem.replace("unaligned_queries_", ""))
    return samples




# Sample bookkeeping derived from the workflow configuration.
INPUT_FILES = config['unaligned_queries']
SAMPLES = get_samples(INPUT_FILES)
# Sample name -> input file; both sequences are in the same order.
SAMPLE_DICT = dict(zip(SAMPLES, INPUT_FILES))
# CONCOCT length threshold and the remaining CONCOCT options as one string.
(THRESHOLD, CONCOCT_OPTS) = get_threshold(config)

def get_input_files(wildcards):
    """Return the input file registered in SAMPLE_DICT for ``wildcards.sample``."""
    sample = wildcards.sample
    return SAMPLE_DICT[sample]


# Aggregate target: requests the per-sample ".done" marker (the output of
# rule fasta_bins) for every sample in SAMPLES.
# NOTE(review): this is the first rule in the file, so it is presumably the
# default target -- confirm nothing is included ahead of it.
rule concoct_all:
    input: expand("concoct_{sample}/{sample}.done", sample=SAMPLES)

        
rule megahit:
    """Generate metagenomic assembly using megahit."""
    # The input path is looked up per sample through get_input_files.
    input: get_input_files
        # expand(
        #     "{path}/unaligned_queries_{{sample}}.fasta",
        #     path=config['unaligned_queries'])
    output: "megahit_{sample}/final.contigs.fa"
    # Optional extra megahit arguments from config['megahit_opts'], joined into
    # one string; empty when the key is absent.
    params: " ".join(config['megahit_opts']) if 'megahit_opts' in config else ""
    log: "logs/{sample}_megahit.log"
    threads: config['threads']
    message:
        """
        Getting metagenomic assembly (megahit) from unaligned queries {input}.
        Writing to {output}.
        Additional Params: {params}
        Logging to {log}.
        Using {threads} thread(s).
        """
    # The megahit_{sample} directory is deleted before megahit is started --
    # presumably because megahit will not write into an existing output
    # directory (TODO confirm). Start/end markers and all megahit output are
    # appended to the log.
    shell:
        """
        echo MEGAHIT {wildcards.sample} >> {log}
        rm -rf megahit_{wildcards.sample}
        megahit -r {input} -o megahit_{wildcards.sample} -t {threads} {params} >> {log} 2>&1 
        echo END MEGAHIT {wildcards.sample} >> {log}
        """

rule cut_up:
    """Cut assembled contigs into smaller parts"""
    input: "megahit_{sample}/final.contigs.fa"
    output:
        bed="concoct_{sample}/{sample}.bed",
        cut="concoct_{sample}/{sample}_cut.fasta"
    # Optional extra cut_up_fasta.py arguments from config['cutup_opts'],
    # joined into one string; empty when the key is absent.
    params: " ".join(config['cutup_opts']) if 'cutup_opts' in config else ""
    threads: 1
    message:
        """
        Cutting assembled contings into smaller parts using {input}.
        Writing to {output.bed} and {output.cut}.
        Additional Params: {params}
        Using {threads} thread(s).
        """
    # The BED file is written by the script itself (-b); the cut-up fasta is
    # whatever the script prints to stdout, redirected into output.cut.
    shell:
        """
        mkdir -p concoct_{wildcards.sample};
        cut_up_fasta.py {input} -b {output.bed} {params} > {output.cut}
        """


rule fasta2fastq:
    """
    Convert unaligned fasta queries to a fastq adding duplicates
    back and supplying fake quality scores
    """
    # The input path is looked up per sample through get_input_files.
    input: get_input_files
        # expand(
        #     "{path}/unaligned_queries_{{sample}}.fasta",
        #     path=config['unaligned_queries'])
    output: "concoct_{sample}/unaligned_queries_{sample}.fastq"
    threads: 1
    message:
        """
        Converting unaligned fasta queries {input}
        to fastq {output}.
        Using {threads} thread(s).
        """
    run:
        from Bio import SeqIO
        os.makedirs(os.path.dirname(output[0]), exist_ok=True)
        with open(input[0], 'r') as handle, open(output[0], 'w') as out:
            for record in SeqIO.parse(handle, 'fasta'):
                # Fasta carries no qualities: give every base the same
                # placeholder phred score so the record can be written as fastq.
                record.letter_annotations["phred_quality"] = (
                    [40] * len(record))
                # The integer after the last "_" of the record id is used as
                # the number of copies to write (raises ValueError when that
                # suffix is not an integer).
                copies = int(record.id.split("_")[-1])
                for _ in range(copies):
                    SeqIO.write(record, out, "fastq")


rule make_bam:
    """Generate a bam file by mapping queries back to the contigs"""
    input:
        ref="megahit_{sample}/final.contigs.fa", 
        fastq="concoct_{sample}/unaligned_queries_{sample}.fastq"
    output: 
        bam="concoct_{sample}/{sample}.bam",
        bai="concoct_{sample}/{sample}.bam.bai"
    # Optional extra "bwa mem" arguments from config['bwa_opts'], joined into
    # one string; empty when the key is absent.
    params: " ".join(config['bwa_opts']) if 'bwa_opts' in config else ""
    log: "logs/{sample}_bam.log"
    threads: config['threads']
    message:
        """
        Mapping reads (BWA MEM) and indexing (Samtools) 
        bam file using from reference {input.ref} 
        and reads {input.fastq}.
        Writing to {output.bam} and {output.bai}.
        Additional Params: {params}
        Logging to {log}.
        Using {threads} thread(s).
        """
    # Only a non-empty assembly ([[ -s ... ]]) is indexed and mapped: bwa mem
    # output is piped through samtools view into samtools sort, then the sorted
    # bam is indexed. For an empty assembly both outputs are created as empty
    # placeholder files, which rule coverage_table detects with its own -s test.
    shell:
        """
        echo MAKE BAM {wildcards.sample} >> {log}
        if [[ -s {input.ref} ]]
        then
            bwa index {input.ref} &>> {log}
            bwa mem -t {threads} {params} {input.ref} {input.fastq} 2>> {log} | \
            samtools view -uS | samtools sort -@ {threads} -o {output.bam} &>> {log}
            samtools index {output.bam} &>> {log}
        else
            touch {output.bam} {output.bai}
            echo "NO ASSEMBLY: DETECTED" >> {log}
        fi
        echo END MAKE BAM {wildcards.sample} >> {log}
        """

rule coverage_table:
    """Generate table with coverage depth information per sample and subcontig."""
    # The .bai index is listed as an input but is not passed on the command
    # line; presumably it only has to exist next to the bam -- TODO confirm.
    input:
        bed="concoct_{sample}/{sample}.bed",
        bai="concoct_{sample}/{sample}.bam.bai",
        bam="concoct_{sample}/{sample}.bam"
    output: "concoct_{sample}/{sample}_coverage_table.tsv"
    log: "logs/{sample}_coverage_table.log"
    threads: 1
    message:
        """
        Generating coverage table using {input.bed}, {input.bai} and {input.bam}.
        Writing to {output}.
        Logging to {log}.
        Using {threads} thread(s).
        """ 
    # An empty bam (the placeholder rule make_bam creates when there is no
    # assembly) is not processed; an empty table is created instead so that the
    # skip carries on to rule concoct. The script's stdout becomes the table.
    shell:
        """
        echo COVERAGE_TABLE {wildcards.sample} >> {log}
        if [[ -s {input.bam} ]]
        then
            concoct_coverage_table.py {input.bed} {input.bam} > {output}
        else
            touch {output}
            echo "SKIPPING COVERAGE" >> {log}
        fi
        echo END COVERAGE_TABLE {wildcards.sample} >> {log}
        """



rule concoct:
    """Run concoct binning"""
    input:
        cov="concoct_{sample}/{sample}_coverage_table.tsv",
        cut="concoct_{sample}/{sample}_cut.fasta"
    # The output name embeds THRESHOLD; the doubled braces keep {sample} as a
    # wildcard while expand() fills in {param}.
    # NOTE(review): presumably this matches the file name concoct itself
    # derives from "-b concoct_{sample}/" and "-l <threshold>" -- confirm.
    output: 
        expand(
            "concoct_{{sample}}/clustering_gt{param}.csv",
            param=THRESHOLD)
    # threshold/opts are the two values get_threshold() split out of
    # config['concoct_opts'].
    params:
        threshold=THRESHOLD,
        opts=CONCOCT_OPTS
    log: "logs/{sample}_concoct.log"
    threads: config['threads']
    message:
        """
        Running concoct using coverage table {input.cov} and fasta {input.cut}.
        Writing to {output}.
        Additional Params: -l {params.threshold} {params.opts} 
        Logging to {log}.
        Using {threads} thread(s).
        """
    # An empty coverage table (the placeholder rule coverage_table creates when
    # it skips) is not binned; an empty clustering file is created instead.
    shell:
        """
        echo CONCOCT {wildcards.sample} >> {log}
        if [[ -s {input.cov} ]]
        then
            concoct -l {params.threshold} -t {threads} --composition_file {input.cut} \
            --coverage_file {input.cov} -b concoct_{wildcards.sample}/ {params.opts} &>> {log}
        else
            touch {output}
            echo "SKIPPING CONCOCT" >> {log}
        fi
        echo END CONCOCT {wildcards.sample} >> {log}
        """



rule merge:
    """Merge subcontig clustering into original contig clustering."""
    # Same THRESHOLD-dependent file name that rule concoct declares as output;
    # the doubled braces keep {sample} as a wildcard.
    input: 
        expand(
            "concoct_{{sample}}/clustering_gt{param}.csv",
            param=THRESHOLD)
    output: "concoct_{sample}/clustering_merged.csv"
    threads: 1
    log: "logs/{sample}_merge_clusters.log"
    message:
        """
        Merge clustering from {input}.
        Writing to {output}.
        Logging to {log}.
        Using {threads} thread(s).
        """
    # The script's stdout becomes the merged clustering; its stderr is
    # appended to the log.
    # NOTE(review): unlike the neighbouring rules there is no -s guard here, so
    # an empty clustering file is handed to the script as is -- confirm the
    # script copes with that.
    shell:
        """
        echo MERGE CLUSTERS {wildcards.sample} >> {log}
        merge_cutup_clustering.py {input} > {output} 2>> {log}
        echo END MERGE CLUSTERS {wildcards.sample} >> {log}
        """

rule fasta_bins:
    """Extract bins as individual fastas"""
    input: 
        clustered="concoct_{sample}/clustering_merged.csv",
        assembly="megahit_{sample}/final.contigs.fa"
    # The only declared output is a ".done" marker (the file rule concoct_all
    # asks for); touch() lets Snakemake create it once the job has succeeded.
    output: 
        touch("concoct_{sample}/{sample}.done")
    # The directory receiving the bin fastas is passed as a param rather than
    # declared as an output, so Snakemake does not track its contents.
    params:
        outpath="concoct_{sample}/fasta_bins"
    log: "logs/{sample}_bin_fastas.log"
    threads: 1
    message:
        """
        Extracting bins into individual fasta files.
        Using cluster file {input.clustered} and assemblies {input.assembly}.
        Writing to {params.outpath}.
        Logging to {log}.
        Using {threads} thread(s).
        """
    # Extraction runs only for a non-empty merged clustering; otherwise just
    # the marker is touched and the skip is logged.
    shell:
        """
        echo FASTA_BINS {wildcards.sample} >> {log}
        mkdir -p {params.outpath}
        if [[ -s {input.clustered} ]]
        then
            extract_fasta_bins.py {input.assembly} {input.clustered} --output_path {params.outpath} &>> {log}
        else
            touch {output}
            echo CLUSTER FILE IS EMPTY: SKIPPING EXTRACTION >> {log}
        fi
        echo END FASTA_BINS {wildcards.sample} >> {log}
        """

# Shell snippet that appends every per-rule log under logs/ to the combined log
# file named by config['log_file'], grouped by pipeline step in the order
# listed in the loop.
# Each glob match is tested before it is read: when a step never ran (typical
# after a failure) its pattern matches nothing and stays unexpanded, and a
# failing `cat` would abort the whole snippet under the strict bash mode used
# for shell commands, cutting the aggregation short and raising in the handler.
# Keep the snippet free of braces other than {log}: it is passed through
# str.format here.
LOG_OUT = """
        for VAR in megahit bam coverage_table concoct merge_clusters bin_fastas
            do
            for FILE in "logs/"*$VAR".log"
                do
                if [[ -f "$FILE" ]]
                then
                    cat "$FILE" >> {log}
                fi
                done
            done
        """.format(log=config['log_file'])

# Collect the per-rule logs whether the workflow finished or failed.
onsuccess:
    shell(LOG_OUT)

onerror:
    shell(LOG_OUT)




