#!/opt/mambaforge/envs/bioconda/conda-bld/zol_1766887886185/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_/bin/python

"""
### Program: apos
### Author: Rauf Salamzade
### Kalan Lab
### UW Madison, Department of Medical Microbiology and Immunology
"""

# BSD 3-Clause License
#
# Copyright (c) 2023-2025, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met: 
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, 
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from Bio import SeqIO
from collections import defaultdict
from rich_argparse import RawTextRichHelpFormatter
from time import sleep
from zol import util, fai
import argparse
import os
import pandas as pd
import sys


# Pin OMP_NUM_THREADS to "1" in this process's environment before any work is done.
# NOTE(review): presumably this keeps OpenMP-backed libraries from spawning extra
# threads on top of the -c/--threads setting; confirm the intent with the authors.
os.environ["OMP_NUM_THREADS"] = "1"
def create_parser():
    """
    Build the apos command-line interface and parse sys.argv.

    Returns:
        argparse.Namespace with the attributes sample_genome, mobsuite_results,
        genomad_results, target_genomes_db, fai_options, use_simple_blastp,
        simple_blastp_identity_cutoff, simple_blastp_coverage_cutoff,
        simple_blastp_evalue_cutoff, simple_blastp_sensitivity_mode, outdir
        and threads.

    Note:
        argparse itself terminates the program (SystemExit) when required
        options are missing or values cannot be converted.
    """
    parser = argparse.ArgumentParser(description = """
    Program: apos
    Author: Rauf Salamzade
    Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and
        Immunology

    apos - Assess Plasmid-Ome Similarity

    apos wraps fai to assess the conservation of a sample's plasmid-ome
    relative to a set of target genomes (e.g. genomes belonging to the same genus). Alternatively,
    it can run a simple DIAMOND BLASTp analysis to just assess the presence of plasmid genes
    individually - without the requirement they are co-located in one scaffold like in the focal sample.
    """, formatter_class = RawTextRichHelpFormatter)

    # Required option: no default is given because argparse never consults a
    # default for an argument that must be supplied.
    parser.add_argument('-i',
        '--sample-genome',
        help = "Path to sample genome in GenBank or FASTA format.",
        required = True)
    parser.add_argument('-ms',
        '--mobsuite-results',
        help = "Path to MOB-suite (mob_recon) results directory for a single sample/genome.",
        required = False,
        default = None)
    parser.add_argument('-gn',
        '--genomad-results',
        help = "Path to GeNomad results directory for a single sample/genome.",
        required = False,
        default = None)
    parser.add_argument('-tg',
        '--target-genomes-db',
        help = "prepTG database directory for target genomes of interest.",
        required = True)
    parser.add_argument('-fo',
        '--fai-options',
        help = "Provide fai options to run. Should be surrounded by quotes. [Default is\n"
               "\"-skpv -e 1e-10 -m 0.5 -dm -sct 0.0\"].",
        required = False,
        default = "-skpv -e 1e-10 -m 0.5 -dm -sct 0.0")
    parser.add_argument('-s',
        '--use-simple-blastp',
        action = 'store_true',
        help = "Use a simple DIAMOND BLASTp search with no requirement for co-localization\n"
               "of hits.",
        required = False,
        default = False)
    parser.add_argument('-si',
        '--simple-blastp-identity-cutoff',
        type = float,
        help = "If simple BLASTp mode requested: cutoff for identity between query proteins\n"
               "and matches in target genomes [Default is 40.0].",
        required = False,
        default = 40.0)
    parser.add_argument('-sc',
        '--simple-blastp-coverage-cutoff',
        type = float,
        help = "If simple BLASTp mode requested: cutoff for coverage between query proteins\n"
               "and matches in target genomes [Default is 70.0].",
        required = False,
        default = 70.0)
    parser.add_argument('-se',
        '--simple-blastp-evalue-cutoff',
        type = float,
        help = "If simple BLASTp mode requested: cutoff for E-value between query proteins\n"
               "and matches in target genomes [Default is 1e-5].",
        required = False,
        default = 1e-5)
    # The help text must quote the same spelling as the actual default below.
    parser.add_argument('-sm',
        '--simple-blastp-sensitivity-mode',
        help = "Sensitivity mode for DIAMOND BLASTp. [Default is \"very-sensitive\"].",
        required = False,
        default = "very-sensitive")
    parser.add_argument('-o', '--outdir', help="Output directory.", required = True)
    parser.add_argument('-c', '--threads',
        type = int,
        help = "The number of threads to use [Default is 1].",
        required = False,
        default = 1)

    return parser.parse_args()

def _match_contig_to_scaffold(contig_id, scaffold_ids):
    """
    Resolve a MOB-suite contig_id to a single scaffold ID of the sample's genome.

    An exact match on the first whitespace-delimited token of contig_id is
    preferred, so that e.g. 'contig_10' is not treated as ambiguous merely
    because the scaffold 'contig_1' is also a prefix of it. Otherwise a scaffold
    ID is accepted only if it is the one and only ID that contig_id starts with.

    Args:
        contig_id: value of the contig identifier column in contig_report.txt.
        scaffold_ids: set of record IDs found in the sample's FASTA.

    Returns:
        The matching scaffold ID, or None if there is no unambiguous match.
    """
    tokens = contig_id.split()
    if tokens and tokens[0] in scaffold_ids:
        return tokens[0]
    prefix_matches = [sid for sid in scaffold_ids if contig_id.startswith(sid)]
    if len(prefix_matches) == 1:
        return prefix_matches[0]
    return None


def _get_report_field(ls, column_index, column_name, fallback_position):
    """
    Fetch one field of a tab-split report row, preferring lookup by header name.

    Args:
        ls: list of fields of the current row.
        column_index: dict mapping header names to their positions.
        column_name: header name to look up.
        fallback_position: position to use when the header does not carry
            column_name, or None if there is no sensible positional fallback.

    Returns:
        The field value, or 'NA' if the column cannot be located in the row.
    """
    position = column_index.get(column_name, fallback_position)
    if position is None or position >= len(ls):
        return 'NA'
    return ls[position]


def _parse_mobsuite_contig_report(report_file, scaffold_ids, log_object):
    """
    Collect plasmid predictions from a MOB-suite contig_report.txt file.

    Only rows whose second column equals 'plasmid' are used. Descriptive
    attributes are located by header name where possible.
    NOTE(review): the header names used below follow the MOB-suite README as
    recalled by the reviewer; confirm them against the MOB-suite version in
    use. When a name is absent from the header the positions used by the
    previous implementation are applied instead.

    Args:
        report_file: path to contig_report.txt.
        scaffold_ids: set of record IDs found in the sample's FASTA.
        log_object: logger used to record matching failures.

    Returns:
        List of [software, plasmid_id, scaffold_id, plasmid_length,
        additional_info] records.

    Exits the program with status 1 if a plasmid contig cannot be matched to
    exactly one scaffold of the sample's genome.
    """
    records = []
    column_index = {}
    with open(report_file) as report_handle:
        for i, line in enumerate(report_handle):
            # Only strip line terminators so that empty trailing fields keep
            # their positions.
            line = line.rstrip('\r\n')
            if i == 0:
                column_index = {name.strip(): idx for idx, name in enumerate(line.split('\t'))}
                continue
            if not line.strip():
                continue
            ls = line.split('\t')
            if ls[1] != 'plasmid':
                continue
            primary_cluster_id = ls[2]
            secondary_cluster_id = ls[3]
            plasmid_id = ls[4]
            scaff = _match_contig_to_scaffold(plasmid_id, scaffold_ids)
            if scaff is None:
                msg = f'Issues with matching MOB-suite contig_id {plasmid_id} to a single record in the sample\'s genome.'
                sys.stderr.write(msg + '\n')
                log_object.error(msg)
                sys.exit(1)
            plasmid_length = ls[5]
            attributes = [
                ('primary_cluster_id', primary_cluster_id),
                ('secondary_cluster_id', secondary_cluster_id),
                ('gc', _get_report_field(ls, column_index, 'gc', 6)),
                # Previously read from the same position as gc; there is no
                # trustworthy positional fallback, hence None -> 'NA'.
                ('circularity_status', _get_report_field(ls, column_index, 'circularity_status', None)),
                ('replicon_types', _get_report_field(ls, column_index, 'rep_type(s)', 7)),
                ('relaxase_types', _get_report_field(ls, column_index, 'relaxase_type(s)', 9)),
                ('mpf_types', _get_report_field(ls, column_index, 'mpf_type', 11)),
                ('orit_types', _get_report_field(ls, column_index, 'orit_type(s)', 13)),
                ('predicted_mobility', _get_report_field(ls, column_index, 'predicted_mobility', 15)),
            ]
            additional_info = '; '.join(f'{key} = {value}' for key, value in attributes)
            records.append(['MOB-suite',
                            'MS-' + plasmid_id,
                            scaff,
                            plasmid_length,
                            additional_info])
    return records


def _parse_genomad_plasmid_summary(summary_file):
    """
    Collect plasmid predictions from a geNomad '*_plasmid_summary.tsv' file.

    Every data row must consist of exactly eleven tab-separated fields;
    otherwise the tuple unpacking raises ValueError. '|' characters in the
    plasmid identifier are replaced by '_'.

    Args:
        summary_file: path to the plasmid summary table.

    Returns:
        List of [software, plasmid_id, scaffold_id, plasmid_length,
        additional_info] records.
    """
    records = []
    with open(summary_file) as summary_handle:
        for i, line in enumerate(summary_handle):
            if i == 0:
                continue
            line = line.strip('\n')
            if not line.strip():
                continue
            (plasmid_id, plasmid_length, topology, _n_genes, genetic_code,
             plasmid_score, _fdr, n_hallmarks, marker_enrichment,
             conjugation_genes, amr_genes) = line.split('\t')
            plasmid_id = plasmid_id.replace('|', '_')
            additional_info = '; '.join(['plasmid_score = ' + plasmid_score,
                                         'topology = ' + topology,
                                         'genetic_code = ' + genetic_code,
                                         'n_hallmarks = ' + n_hallmarks,
                                         'marker_enrichment = ' + marker_enrichment,
                                         'conjugation_genes = ' + conjugation_genes,
                                         'amr_genes = ' + amr_genes])
            records.append(['geNomad',
                            'GN-' + plasmid_id,
                            plasmid_id,
                            plasmid_length,
                            additional_info])
    return records


def apos():
    """
    Run the apos workflow from the command line.

    Steps, in order:
      1. Parse arguments, prepare the output directory (asking for
         confirmation if it already exists) and validate the sample genome,
         converting GenBank input to FASTA.
      2. Gather plasmid predictions from the MOB-suite and/or geNomad result
         directories.
      3. For every predicted plasmid, either run fai or (with
         --use-simple-blastp) call genes and run a DIAMOND BLASTp search
         against the prepTG database.
      4. Read each plasmid's 'total_gcs.tsv' table and write
         'PlasmidOme_Info.tsv' and 'PlasmidOme_to_Target_Genomes_Similarity.tsv'.
         The aggregate score of a target genome is the sum over plasmids of
         AAI multiplied by the proportion of shared genes.
      5. Assemble both tables into the formatted 'APOS_Results.xlsx' workbook.

    The function always terminates the interpreter: sys.exit(0) on success or
    when the user declines to reuse an existing output directory, sys.exit(1)
    on any detected error.
    """
    myargs = create_parser()

    sample_genome = myargs.sample_genome
    mobsuite_results_dir = myargs.mobsuite_results
    genomad_results_dir = myargs.genomad_results
    preptg_db_dir = myargs.target_genomes_db
    outdir = os.path.abspath(myargs.outdir) + '/'
    fai_options = myargs.fai_options
    use_simple_blastp_flag = myargs.use_simple_blastp
    simple_blastp_identity_cutoff = myargs.simple_blastp_identity_cutoff
    simple_blastp_coverage_cutoff = myargs.simple_blastp_coverage_cutoff
    simple_blastp_evalue_cutoff = myargs.simple_blastp_evalue_cutoff
    simple_blastp_sensitivity_mode = myargs.simple_blastp_sensitivity_mode
    threads = myargs.threads

    # create output directory if needed, or warn of over-writing
    if os.path.isdir(outdir):
        sys.stderr.write("Output directory exists. Files will be overwritten, but\n"
                       "checkpoints will be used to avoid redoing successfully\n"
                       "completed steps.\n"
                       "Do you wish to proceed? (yes/no): ")
        user_input = input().strip().lower()
        if user_input not in ['yes', 'y']:
            sys.stderr.write("Execution cancelled by user.\n")
            sys.exit(0)
    else:
        util.setup_ready_directory([outdir], delete_if_exist = False)

    # The None check must come first: os.path.isdir(None) raises TypeError.
    have_mobsuite = mobsuite_results_dir is not None and os.path.isdir(mobsuite_results_dir)
    have_genomad = genomad_results_dir is not None and os.path.isdir(genomad_results_dir)
    if not have_mobsuite and not have_genomad:
        sys.stderr.write("Neither MOB-suite nor geNomad sample results directory provided as input or the paths are faulty! \n")
        sys.exit(1)

    if util.is_genbank(sample_genome):
        genome_basename = os.path.basename(sample_genome)
        sample_genome_fasta = outdir + '.'.join(genome_basename.split('.')[: -1]) + '.fna'
        try:
            util.convert_genome_genbank_to_fasta([sample_genome, sample_genome_fasta])
            # Explicit check instead of assert so it still runs under python -O.
            if not (os.path.isfile(sample_genome_fasta) and util.is_fasta(sample_genome_fasta)):
                raise RuntimeError('GenBank to FASTA conversion did not yield a valid FASTA file.')
            sample_genome = sample_genome_fasta
        except Exception:
            sys.stderr.write("Issue converting sample's genome from GenBank to FASTA format.\n")
            sys.exit(1)

    if not util.is_fasta(sample_genome):
        sys.stderr.write("Issue with processing/validating focal sample's genome. Was it in FASTA or GenBank format?\n")
        sys.exit(1)

    # create logging object
    log_file = outdir + 'Progress.log'
    log_object = util.create_logger_object(log_file)
    version_string = util.get_version()

    sys.stdout.write(f'Running apos version {version_string}\n')
    log_object.info(f'Running apos version {version_string}')

    # log command used (appended, so earlier invocations are kept)
    parameters_file = outdir + 'Command_Issued.txt'
    with open(parameters_file, 'a+') as parameters_handle:
        parameters_handle.write(' '.join(sys.argv) + '\n')

    msg = 'Parsing plasmid prediction results'
    sys.stdout.write(msg + '\n')
    log_object.info(msg)

    sample_plasmids = []
    if have_mobsuite:
        # MOB-suite contigs are mapped back to FASTA records by ID, so the IDs
        # must be unique.
        scaffold_ids = set()
        with open(sample_genome) as osg:
            for rec in SeqIO.parse(osg, 'fasta'):
                if rec.id in scaffold_ids:
                    sys.stderr.write('Issues with multiple FASTA headers (up to first whitespace) in sample\'s genome being identical.\n')
                    log_object.error('Issues with multiple FASTA headers (up to first whitespace) in sample\'s genome being identical.')
                    sys.exit(1)
                scaffold_ids.add(rec.id)

        for root, _dirs, files in os.walk(mobsuite_results_dir):
            for basename in files:
                if basename == 'contig_report.txt':
                    sample_plasmids.extend(_parse_mobsuite_contig_report(os.path.join(root, basename),
                                                                         scaffold_ids,
                                                                         log_object))

    if have_genomad:
        for root, _dirs, files in os.walk(genomad_results_dir):
            for basename in files:
                if basename.endswith('_plasmid_summary.tsv'):
                    sample_plasmids.extend(_parse_genomad_plasmid_summary(os.path.join(root, basename)))

    msg = f'Found {len(sample_plasmids)} plasmid predictions!'
    sys.stdout.write(msg + '\n')
    log_object.info(msg)

    if not sample_plasmids:
        msg = '--------------------------------\n'
        msg += 'No plasmid predictions found!\n'
        msg += '--------------------------------\n'
        sys.stdout.write(msg + '\n')
        log_object.info(msg)
        sys.exit(1)

    fai_results_dir = outdir + 'fai_or_blast_Results/'
    util.setup_ready_directory([fai_results_dir], delete_if_exist = True)
    plasmid_info_file = outdir + 'PlasmidOme_Info.tsv'
    plasmid_lengths = []
    numeric_columns_t1 = ['plasmid_length']
    with open(plasmid_info_file, 'w') as plasmid_info_handle:
        plasmid_info_handle.write('\t'.join(['plasmid_id',
            'plasmid_prediction_software',
            'scaffold_id',
            'plasmid_length',
            'additional_attributes']) + '\n')
        for plasmid in sample_plasmids:
            plasmid_software, plasmid_id, scaff, plasmid_length, additional_info = plasmid
            # Keep only the part of the identifier before the first whitespace
            # so it can be used as a directory name and column label.
            plasmid_id = plasmid_id.split()[0]
            plasmid_fai_res_dir = fai_results_dir + plasmid_id + '/'

            plasmid_info_handle.write('\t'.join([str(x) for x in [plasmid_id,
                plasmid_software,
                scaff,
                plasmid_length,
                additional_info]]) + '\n')
            plasmid_lengths.append(float(plasmid_length))

            # The query locus spans the whole predicted plasmid scaffold.
            start_coord = '1'
            end_coord = str(plasmid_length)
            if use_simple_blastp_flag:
                # run simple blastp
                util.setup_ready_directory([plasmid_fai_res_dir], delete_if_exist = True)
                reference_genome_gbk = plasmid_fai_res_dir + 'reference.gbk'
                gene_call_cmd = ['runProdigalAndMakeProperGenbank.py', '-i', sample_genome, '-s', 'reference', '-l',
                                 'OG', '-o', plasmid_fai_res_dir]
                util.run_cmd_via_subprocess(gene_call_cmd, log_object=log_object,
                                            check_files = [reference_genome_gbk])
                locus_genbank = plasmid_fai_res_dir + 'Reference_Query.gbk'
                locus_proteome = plasmid_fai_res_dir + 'Reference_Query.faa'
                fai.subset_genbank_for_query_locus(reference_genome_gbk,
                                                    locus_genbank,
                                                    locus_proteome,
                                                    scaff,
                                                    int(start_coord),
                                                    int(end_coord),
                                                    log_object)
                util.diamond_blast_and_get_best_hits(plasmid_id, locus_proteome, locus_proteome, preptg_db_dir, plasmid_fai_res_dir, log_object,
                                                identity_cutoff = simple_blastp_identity_cutoff, coverage_cutoff = simple_blastp_coverage_cutoff,
                                                blastp_mode = simple_blastp_sensitivity_mode, prop_key_prots_needed = 0.0,
                                                evalue_cutoff = simple_blastp_evalue_cutoff, threads = threads)
            else:
                # run fai
                fai_cmd = ['fai', '-r', sample_genome, '-rc', scaff, '-rs', start_coord, '-re', end_coord,
                           '-tg',
                            preptg_db_dir,
                            fai_options,
                            '-o',
                            plasmid_fai_res_dir,
                            '-c',
                            str(threads)]
                util.run_cmd_via_subprocess(fai_cmd, log_object=log_object)

    # Per target genome, per plasmid: AAI and proportion of shared genes
    # ('NA' when the plasmid was not reported for that target genome).
    tg_plasmid_aai = defaultdict(lambda: defaultdict(lambda: 'NA'))
    tg_plasmid_sgp = defaultdict(lambda: defaultdict(lambda: 'NA'))
    plasmids = set()
    tg_total_plasmids_found = defaultdict(int)
    no_hits_message = "ERROR - No target genomes found that meet the minimum criteria for harboring a gene cluster of interest."
    for plasmid in sample_plasmids:
        plasmid_id = plasmid[1].split()[0]
        plasmids.add(plasmid_id)
        # Build the directory from the current plasmid; reusing the variable of
        # the search loop above would always point at the last plasmid.
        plasmid_result_dir = fai_results_dir + plasmid_id + '/'
        if use_simple_blastp_flag:
            genome_wide_tsv_result_file = plasmid_result_dir + 'total_gcs.tsv'
        else:
            genome_wide_tsv_result_file = plasmid_result_dir + 'Spreadsheet_TSVs/total_gcs.tsv'

        if not os.path.isfile(genome_wide_tsv_result_file):
            fai_progress_file = plasmid_result_dir + 'Progress.log'
            no_hits_found = False
            # A missing log is handled like a log without the expected message.
            if os.path.isfile(fai_progress_file):
                with open(fai_progress_file) as opf:
                    no_hits_found = any(no_hits_message in line for line in opf)

            if no_hits_found:
                msg = f'{genome_wide_tsv_result_file} not found because there were no hits meeting the search criteria for {plasmid_id}!\n'
                sys.stdout.write(msg)
                log_object.info(msg)
                continue
            else:
                msg = f'Final results file not found for {plasmid_id}, and the expected error message was not found in fai\'s progress log:\n{fai_progress_file}\n'
                sys.stdout.write(msg)
                log_object.info(msg)
                sys.exit(1)

        with open(genome_wide_tsv_result_file) as ogwtrf:
            for i, line in enumerate(ogwtrf):
                line = line.strip()
                if i == 0 or not line:
                    continue
                ls = line.split('\t')
                tg = ls[0]
                aai = ls[2]
                sgp = ls[4]
                tg_plasmid_aai[tg][plasmid_id] = aai
                tg_plasmid_sgp[tg][plasmid_id] = sgp
                if float(aai) > 0.0:
                    tg_total_plasmids_found[tg] += 1

    sorted_plasmids = sorted(plasmids)
    # os.path.join copes with database paths given with or without a trailing slash.
    target_genome_listing_file = os.path.join(preptg_db_dir, 'Target_Genome_Annotation_Files.txt')
    result_file = outdir + 'PlasmidOme_to_Target_Genomes_Similarity.tsv'
    numeric_columns_t2 = ['shared_plasmids',
        'aggregate_score'] + [x + ' - AAI' for x in sorted_plasmids] + [x + ' - Proportion Shared Genes' for x in sorted_plasmids]
    tgs = set()
    all_scores = []
    with open(result_file, 'w') as result_handle, open(target_genome_listing_file) as otglf:
        result_handle.write('\t'.join(['target_genome',
            'shared_plasmids',
            'aggregate_score'] + [x + ' - AAI' + '\t' + x + ' - Proportion Shared Genes' for x in sorted_plasmids]) + '\n')
        for line in otglf:
            line = line.strip()
            if not line:
                continue
            tg = line.split('\t')[0]
            tgs.add(tg)
            shared_plasmids = tg_total_plasmids_found[tg]
            total_score = 0.0
            for plasmid in sorted_plasmids:
                if tg_plasmid_aai[tg][plasmid] != "NA":
                    total_score += float(tg_plasmid_aai[tg][plasmid])*float(tg_plasmid_sgp[tg][plasmid])
            printlist = [tg, str(shared_plasmids), str(total_score)]
            all_scores.append(total_score)
            for plasmid in sorted_plasmids:
                printlist.append(tg_plasmid_aai[tg][plasmid])
                printlist.append(tg_plasmid_sgp[tg][plasmid])
            result_handle.write('\t'.join(printlist) + '\n')

    # construct spreadsheet
    apos_spreadsheet_file = outdir + 'APOS_Results.xlsx'
    writer = pd.ExcelWriter(apos_spreadsheet_file, engine = 'xlsxwriter')
    workbook = writer.book
    header_format = workbook.add_format({'bold': True,
        'text_wrap': True,
        'valign': 'top',
        'fg_color': '#D7E4BC',
        'border': 1})

    t1_results_df = util.load_table_in_panda_data_frame(plasmid_info_file,
        numeric_columns_t1)
    t1_results_df.to_excel(writer,
        sheet_name = 'Sample PlasmidOme Overview',
        index = False,
        na_rep = "NA")

    t2_results_df = util.load_table_in_panda_data_frame(result_file, numeric_columns_t2)
    t2_results_df.to_excel(writer,
        sheet_name = 'Similarity to Target Genomes',
        index = False,
        na_rep = "NA")

    t1_worksheet = writer.sheets['Sample PlasmidOme Overview']
    t2_worksheet = writer.sheets['Similarity to Target Genomes']

    t1_worksheet.conditional_format('A1:F1',
        {'type': 'cell',
        'criteria': '!=',
        'value': 'NA',
        'format': header_format})
    t2_worksheet.conditional_format('A1:HA1',
        {'type': 'cell',
        'criteria': '!=',
        'value': 'NA',
        'format': header_format})

    # Sheet 1 holds one row per prediction written to PlasmidOme_Info.tsv,
    # sheet 2 one row per target genome; +1 accounts for the header row.
    row_count = len(plasmid_lengths) + 1
    t2_row_count = len(tgs) + 1

    # plasmid length (t1)
    t1_worksheet.conditional_format('D2:D' + str(row_count),
        {'type': '2_color_scale',
        'min_color': "#d5f0ea",
        'max_color': "#83a39c",
        'min_type': 'num',
        'max_type': 'num',
        "min_value": 0,
        "max_value": max(plasmid_lengths)})

    # total plasmid shared (t2)
    t2_worksheet.conditional_format('B2:B' + str(t2_row_count),
        {'type': '2_color_scale',
        'min_color': "#d9d9d9",
        'max_color': "#949494",
        'min_type': 'num',
        'max_type': 'num',
        "min_value": 0,
        "max_value": len(plasmids)})

    # aggregate score (t2); default guards against an empty target genome listing
    t2_worksheet.conditional_format('C2:C' + str(t2_row_count),
        {'type': '2_color_scale',
        'min_color': "#f7db94",
        'max_color': "#ad8f40",
        'min_type': 'num',
        'max_type': 'num',
        "min_value": 0,
        "max_value": max(all_scores, default = 0.0)})

    # Each plasmid owns two adjacent columns after the three leading ones:
    # AAI followed by the proportion of shared genes.
    plasmid_index = 3
    for _ in sorted_plasmids:
        columnid_pid = util.determine_column_name_based_on_index(plasmid_index)
        columnid_sgp = util.determine_column_name_based_on_index(plasmid_index+1)
        plasmid_index += 2

        # aai
        t2_worksheet.conditional_format(columnid_pid + '2:' + columnid_pid + str(t2_row_count),
         {'type': '2_color_scale',
         'min_color': "#d8eaf0",
         'max_color': "#83c1d4",
         "min_value": 0.0,
         "max_value": 100.0,
         'min_type': 'num',
         'max_type': 'num'})

        # prop genes found
        t2_worksheet.conditional_format(columnid_sgp + '2:' + columnid_sgp + str(t2_row_count),
         {'type': '2_color_scale',
         'min_color': "#e6bec0",
         'max_color': "#f26168",
         "min_value": 0.0,
         "max_value": 1.0,
         'min_type': 'num',
         'max_type': 'num'})

    # Freeze the first row of both sheets
    t1_worksheet.freeze_panes(1, 0)
    t2_worksheet.freeze_panes(1, 0)

    # Close through the pandas writer so that both the workbook and the file
    # handle opened by pandas are finalised.
    writer.close()

    sys.stdout.write(f'Done running apos!\nFinal results can be found at: \n{apos_spreadsheet_file}\n')
    log_object.info(f'Done running apos!\nFinal results can be found at: \n{apos_spreadsheet_file}\n')

    # Close logging object and exit
    util.close_logger_object(log_object)
    sys.exit(0)

# Run the apos workflow only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__': 
    apos()
