#!/opt/mambaforge/envs/bioconda/conda-bld/zol_1759109280278/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_/bin/python

"""
Program: atpoc
Author: Rauf Salamzade
Kalan Lab
UW Madison, Department of Medical Microbiology and Immunology
"""

# BSD 3-Clause License
#
# Copyright (c) 2023-2025, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 

from Bio import SeqIO
from collections import defaultdict
from rich_argparse import RawTextRichHelpFormatter
from time import sleep
from zol import util, fai
import argparse
import math
import os
import pandas as pd
import sys
import traceback

# Set OMP_NUM_THREADS to "1" in this process's environment when the module is
# loaded (note: this runs after the imports above have already executed).
# NOTE(review): presumably intended to keep OpenMP-backed libraries/tools
# single-threaded so that parallelism is governed by --threads — confirm.
os.environ["OMP_NUM_THREADS"] = "1"
def create_parser():
    """
    Build the atpoc command-line interface and parse the process arguments.

    The parser is populated with the sample genome, the optional phage
    predictor result directories (VIBRANT / PhiSpy / geNomad), the prepTG
    target-genomes database, gene-calling and search options, the output
    directory and the thread count. Arguments are read from sys.argv via
    ``parser.parse_args()``; argparse itself exits the process on invalid
    input or when ``-h`` is requested.

    Returns:
        argparse.Namespace: the parsed arguments, with attribute names derived
        from the long option names (e.g. ``--sample-genome`` becomes
        ``sample_genome``).
    """
    parser = argparse.ArgumentParser(description = """
    Program: atpoc
    Author: Rauf Salamzade
    Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and
        Immunology

    atpoc - Assess Temperate Phage-Ome Conservation

    atpoc wraps fai to assess the conservation of a sample's temperate/pro-phage-ome
    relative to a set of target genomes (e.g. genomes belonging to the same genus).
    Alternatively, it can run a simple DIAMOND BLASTp analysis to just assess the
    presence of prophage genes individually - without the requirement they are
    co-located like in the focal sample.
    """, formatter_class = RawTextRichHelpFormatter)

    # --- inputs -------------------------------------------------------------
    # A required option never falls back to a default, so none is given here.
    parser.add_argument('-i',
        '--sample-genome',
        help = "Path to sample genome in GenBank or FASTA format.",
        required = True)
    parser.add_argument('-vi',
        '--vibrant-results',
        help = "Path to VIBRANT results directory for a single sample/genome.",
        required = False,
        default = None)
    parser.add_argument('-ps',
        '--phispy-results',
        help = "Path to PhiSpy results directory for a single sample/genome.",
        required = False,
        default = None)
    parser.add_argument('-gn',
        '--genomad-results',
        help = "Path to GeNomad results directory for a single sample/genome.",
        required = False,
        default = None)
    parser.add_argument('-tg',
        '--target-genomes-db',
        help = "prepTG database directory for target genomes of interest.",
        required = True)

    # --- gene calling / fai options ------------------------------------------
    parser.add_argument('-up',
        '--use-pyrodigal',
        action = 'store_true',
        help = "Use default pyrodigal instead of prodigal-gv to call genes in\n"
               "phage regions to use as queries in fai/simple-blast. This is\n"
               "perhaps preferable if target genomes db was created with default\n"
               "pyrodigal/prodigal.",
        required = False,
        default = False)
    parser.add_argument('-fo',
        '--fai-options',
        help = "Provide fai options to run. Should be surrounded by quotes.\n"
               "[Default is '-skpv -e 1e-10 -m 0.5 -dm -sct 0.4'].",
        required = False,
        default = "-skpv -e 1e-10 -m 0.5 -dm -sct 0.4")

    # --- simple DIAMOND BLASTp mode ------------------------------------------
    parser.add_argument('-s',
        '--use-simple-blastp',
        action = 'store_true',
        help = "Use a simple DIAMOND BLASTp search with no requirement for\n"
               "co-localization of hits.",
        required = False,
        default = False)
    parser.add_argument('-si',
        '--simple-blastp-identity-cutoff',
        type = float,
        help = "If simple BLASTp mode requested: cutoff for identity between\n"
               "query proteins and matches in target genomes [Default is 40.0].",
        required = False,
        default = 40.0)
    parser.add_argument('-sc',
        '--simple-blastp-coverage-cutoff',
        type = float,
        help = "If simple BLASTp mode requested: cutoff for coverage between\n"
               "query proteins and matches in target genomes [Default is 70.0].",
        required = False,
        default = 70.0)
    parser.add_argument('-se',
        '--simple-blastp-evalue-cutoff',
        type = float,
        help = "If simple BLASTp mode requested: cutoff for E-value between query\n"
               "proteins and matches in target genomes [Default is 1e-5].",
        required = False,
        default = 1e-5)
    parser.add_argument('-sm',
        '--simple-blastp-sensitivity-mode',
        help = "Sensitivity mode for DIAMOND BLASTp. [Default is 'very-sensitive'].",
        required = False,
        default = "very-sensitive")

    # --- output / resources ---------------------------------------------------
    parser.add_argument('-o', '--outdir', help = "Output directory.", required = True)
    parser.add_argument('-c',
        '--threads',
        type = int,
        help = "The number of threads to use [Default is 1].",
        required = False,
        default = 1)

    return parser.parse_args()

def atpoc():
    """
    Run the atpoc workflow end-to-end.

    Steps, in order:
      1. Parse the command line and prepare the output directory (asking for
         confirmation on stdin when it already exists).
      2. Check that at least one usable VIBRANT / PhiSpy / geNomad results
         directory was given and that the sample genome is FASTA (a GenBank
         input is converted to FASTA inside the output directory first).
      3. Collect prophage predictions from the result directories.
      4. For every predicted prophage, run either fai or the simple DIAMOND
         BLASTp search against the prepTG target-genomes database.
      5. Aggregate per-target-genome AAI / proportion-of-shared-genes values
         into 'TemperatePhageOme_to_Target_Genomes_Similarity.tsv'.
      6. Write both tables to 'ATPOC_Results.xlsx' with conditional formatting.

    The function never returns normally: it ends the process with
    sys.exit(0) on success (or when the user declines to overwrite the output
    directory) and sys.exit(1) on the error conditions handled below.
    """
    myargs = create_parser()

    sample_genome = myargs.sample_genome
    vibrant_results_dir = myargs.vibrant_results
    phispy_results_dir = myargs.phispy_results
    genomad_results_dir = myargs.genomad_results
    preptg_db_dir = myargs.target_genomes_db
    outdir = os.path.abspath(myargs.outdir) + '/'
    fai_options = myargs.fai_options
    use_simple_blastp_flag = myargs.use_simple_blastp
    use_pyrodigal_flag = myargs.use_pyrodigal
    simple_blastp_identity_cutoff = myargs.simple_blastp_identity_cutoff
    simple_blastp_coverage_cutoff = myargs.simple_blastp_coverage_cutoff
    simple_blastp_evalue_cutoff = myargs.simple_blastp_evalue_cutoff
    simple_blastp_sensitivity_mode = myargs.simple_blastp_sensitivity_mode
    threads = myargs.threads

    # create output directory if needed, or warn of over-writing
    if os.path.isdir(outdir):
        sys.stderr.write("Output directory exists. Files will be overwritten, but\n"
                       "checkpoints will be used to avoid redoing successfully\n"
                       "completed steps.\n"
                       "Do you wish to proceed? (yes/no): ")
        user_input = input().strip().lower()
        if user_input not in ['yes', 'y']:
            sys.stderr.write("Execution cancelled by user.\n")
            sys.exit(0)
    else:
        util.setup_ready_directory([outdir], delete_if_exist = False)

    # A predictor is usable only if its directory was given and exists. Only
    # fail when none of the three is usable (the previous single expression had
    # misplaced parentheses and rejected e.g. a geNomad-only invocation).
    have_vibrant = vibrant_results_dir is not None and os.path.isdir(vibrant_results_dir)
    have_phispy = phispy_results_dir is not None and os.path.isdir(phispy_results_dir)
    have_genomad = genomad_results_dir is not None and os.path.isdir(genomad_results_dir)
    if not (have_vibrant or have_phispy or have_genomad):
        sys.stderr.write("Neither PhiSpy, VIBRANT, nor geNomad sample results directory provided as input or the paths are faulty! \n")
        sys.exit(1)

    if util.is_genbank(sample_genome):
        try:
            sample_genome_fasta = outdir + '.'.join(sample_genome.split('/')[-1].split('.')[: -1]) + '.fna'
            util.convert_genome_genbank_to_fasta([sample_genome, sample_genome_fasta])
            assert(os.path.isfile(sample_genome_fasta) and util.is_fasta(sample_genome_fasta))
            sample_genome = sample_genome_fasta
        except Exception:
            sys.stderr.write("Issue converting sample's genome from GenBank to FASTA format.\n")
            sys.exit(1)

    if not util.is_fasta(sample_genome):
        sys.stderr.write("Issue with processing/validating focal sample's genome. Was it in FASTA or GenBank format?\n")
        sys.exit(1)

    # create logging object
    log_file = outdir + 'Progress.log'
    log_object = util.create_logger_object(log_file)
    version_string = util.get_version()

    sys.stdout.write(f'Running atpoc version {version_string}\n')
    log_object.info(f'Running atpoc version {version_string}')

    # log command used (appended, so reruns into the same directory accumulate)
    parameters_file = outdir + 'Command_Issued.txt'
    with open(parameters_file, 'a+') as parameters_handle:
        parameters_handle.write(' '.join(sys.argv) + '\n')

    msg = 'Parsing phage prediction results'
    sys.stdout.write(msg + '\n')
    log_object.info(msg)

    # Scaffold lengths are needed for geNomad predictions without coordinates
    # (whole-scaffold viruses). A parsing failure is reported but not fatal here.
    scaff_lengths = {}
    try:
        with open(sample_genome) as osg:
            for rec in SeqIO.parse(osg, "fasta"):
                scaff_lengths[rec.id] = len(str(rec.seq))
    except Exception:
        msg = 'Issue processing genome file.'
        sys.stderr.write(msg + '\n')
        sys.stderr.write(traceback.format_exc() + '\n')

    # Each entry: [software, phage_id, scaffold, start, end, length, additional_info]
    # with start/end/length kept as strings.
    sample_phages = []
    if have_vibrant:
        for root, _dirs, files in os.walk(vibrant_results_dir):
            for basename in files:
                if not (basename.startswith('VIBRANT_integrated_prophage_coordinates_') and basename.endswith('.tsv')):
                    continue
                filename = os.path.join(root, basename)
                with open(filename) as ofn:
                    for i, line in enumerate(ofn):
                        if i == 0: continue  # header
                        line = line.strip()
                        if not line: continue
                        scaff, phage_id, _, _, _, start_coord, end_coord, phage_length = line.split('\t')
                        scaff = scaff.split()[0]
                        id_parts = phage_id.split('_')
                        if len(id_parts) >= 2 and id_parts[-2] == 'fragment':
                            phage_id = phage_id.split()[0] + '_' + '_'.join(id_parts[-2:])
                        else:
                            # no fragment suffix: use the line number to keep IDs unique
                            phage_id = phage_id.split()[0] + '_' + str(i)
                        sample_phages.append(['VIBRANT', 'VB-' + phage_id,
                                              scaff, start_coord, end_coord, phage_length, ''])

    if have_phispy:
        for root, _dirs, files in os.walk(phispy_results_dir):
            for basename in files:
                if basename != 'prophage_coordinates.tsv':
                    continue
                filename = os.path.join(root, basename)
                with open(filename) as ofn:
                    # every line is parsed as a record (no header line is skipped)
                    for line in ofn:
                        line = line.strip('\n')
                        if not line.strip(): continue
                        phage_id, scaff, start_coord, end_coord, _, _, _, _, attL_seq, attR_seq, att_explanation = line.split('\t')
                        phage_length = abs(int(start_coord)-int(end_coord))
                        additional_info = '; '.join(['attL_sequence = ' + attL_seq,
                                                     'attR_sequence = ' + attR_seq,
                                                     'att_explanation = ' + att_explanation])
                        sample_phages.append(['PhiSpy', 'PS-' + phage_id, scaff, start_coord, end_coord, str(phage_length), additional_info])

    if have_genomad:
        for root, _dirs, files in os.walk(genomad_results_dir):
            for basename in files:
                if not basename.endswith('_virus_summary.tsv'):
                    continue
                filename = os.path.join(root, basename)
                with open(filename) as ofn:
                    for i, line in enumerate(ofn):
                        if i == 0: continue  # header
                        line = line.strip('\n')
                        if not line.strip(): continue
                        phage_id, phage_length, topology, coords, n_genes, genetic_code, virus_score, fdr, n_hallmarks, marker_enrichment, taxonomy = line.split('\t')
                        scaff = phage_id.split('|')[0]
                        phage_id = phage_id.replace('|', '_')
                        if coords != 'NA':
                            start_coord, end_coord = coords.split('-')
                        else:
                            # No coordinates reported: treat the whole scaffold as the phage.
                            if scaff not in scaff_lengths:
                                msg = f'Scaffold {scaff} from geNomad results was not found in the sample genome.'
                                sys.stderr.write(msg + '\n')
                                log_object.error(msg)
                                sys.exit(1)
                            # kept as strings so they can go straight into the fai command
                            start_coord = '1'
                            end_coord = str(scaff_lengths[scaff])
                        additional_info = '; '.join(['virus_score = ' + virus_score,
                                                    'topology = ' + topology,
                                                    'genetic_code = ' + genetic_code,
                                                    'n_hallmarks = ' + n_hallmarks,
                                                    'marker_enrichment = ' + marker_enrichment,
                                                    'taxonomy = ' + taxonomy])
                        sample_phages.append(['geNomad', 'GN-' + phage_id, scaff, start_coord, end_coord, phage_length, additional_info])

    msg = f'Found {len(sample_phages)} prophage predictions!'
    sys.stdout.write(msg + '\n')
    log_object.info(msg)

    if len(sample_phages) == 0:
        msg = '--------------------------------\n'
        msg += 'No prophage predictions found!\n'
        msg += '--------------------------------\n'
        sys.stdout.write(msg + '\n')
        log_object.info(msg)
        sys.exit(1)

    fai_results_dir = outdir + 'fai_or_blast_Results/'
    util.setup_ready_directory([fai_results_dir], delete_if_exist = True)
    phage_info_file = outdir + 'TemperatePhageOme_Info.tsv'
    phage_lengths = []
    numeric_columns_t1 = ['phage_length']
    with open(phage_info_file, 'w') as phage_info_handle:
        phage_info_handle.write('\t'.join(['phage_id',
                                        'phage_prediction_software',
                                        'scaffold_id',
                                        'start',
                                        'end',
                                        'phage_length',
                                        'additional_attributes']) + '\n')
        for phage in sample_phages:
            phage_software, phage_id, scaff, start_coord, end_coord, phage_length, additional_info = phage

            phage_fai_res_dir = fai_results_dir + phage_id + '/'

            phage_info_handle.write('\t'.join([str(x) for x in [phage_id,
                                                                phage_software,
                                                                scaff,
                                                                start_coord,
                                                                end_coord,
                                                                phage_length,
                                                                additional_info]]) + '\n')
            phage_lengths.append(float(phage_length))

            if use_simple_blastp_flag:
                # run simple blastp
                util.setup_ready_directory([phage_fai_res_dir], delete_if_exist = True)
                reference_genome_gbk = phage_fai_res_dir + 'reference.gbk'
                gene_call_cmd = ['runProdigalAndMakeProperGenbank.py', '-i', sample_genome, '-s', 'reference', '-l',
                                    'OG', '-o', phage_fai_res_dir]
                if use_pyrodigal_flag:
                    gene_call_cmd += ['-gcm', 'pyrodigal']
                else:
                    gene_call_cmd += ['-gcm', 'prodigal-gv']
                util.run_cmd_via_subprocess(gene_call_cmd, log_object=log_object,
                                            check_files = [reference_genome_gbk])
                locus_genbank = phage_fai_res_dir + 'Reference_Query.gbk'
                locus_proteome = phage_fai_res_dir + 'Reference_Query.faa'
                fai.subset_genbank_for_query_locus(reference_genome_gbk, locus_genbank, locus_proteome,
                                                   scaff, int(start_coord), int(end_coord), log_object)
                util.diamond_blast_and_get_best_hits(phage_id, locus_proteome, locus_proteome, preptg_db_dir, phage_fai_res_dir, log_object,
                                                identity_cutoff = simple_blastp_identity_cutoff, coverage_cutoff = simple_blastp_coverage_cutoff,
                                                blastp_mode = simple_blastp_sensitivity_mode, prop_key_prots_needed = 0.0,
                                                evalue_cutoff = simple_blastp_evalue_cutoff, threads = threads)
            else:
                # run fai
                # NOTE(review): fai_options is one string holding several flags;
                # this relies on util.run_cmd_via_subprocess turning the list
                # into a single shell command line — confirm.
                fai_cmd = ['fai', '-r', sample_genome, '-rc', scaff, '-rs', str(start_coord), '-re', str(end_coord),
                           '-tg', preptg_db_dir, fai_options, '-o', phage_fai_res_dir, '-c', str(threads)]
                if not use_pyrodigal_flag:
                    fai_cmd += ['-rgv']
                util.run_cmd_via_subprocess(fai_cmd, log_object=log_object)

    # target genome -> phage id -> value ('NA' when the phage had no hit there)
    tg_phage_aai = defaultdict(lambda: defaultdict(lambda: 'NA'))
    tg_phage_sgp = defaultdict(lambda: defaultdict(lambda: 'NA'))
    phages = set([])
    tg_total_phages_found = defaultdict(int)
    for phage in sample_phages:
        phage_id = phage[1]
        phages.add(phage_id)
        # Recompute the per-phage directory here; reusing the variable from the
        # search loop above would always point at the last phage processed.
        phage_fai_res_dir = fai_results_dir + phage_id + '/'
        genome_wide_tsv_result_file = phage_fai_res_dir + 'Spreadsheet_TSVs/total_gcs.tsv'
        if use_simple_blastp_flag:
            genome_wide_tsv_result_file = phage_fai_res_dir + 'total_gcs.tsv'
        if not os.path.isfile(genome_wide_tsv_result_file):
            fai_progress_file = phage_fai_res_dir + 'Progress.log'
            no_hits_found = False
            # The log may be absent (e.g. the search failed before writing it);
            # that case falls through to the error branch below.
            if os.path.isfile(fai_progress_file):
                with open(fai_progress_file) as opf:
                    for line in opf:
                        line = line.strip()
                        if "ERROR - No target genomes found that meet the minimum criteria for harboring a gene cluster of interest." in line:
                            no_hits_found = True
                            break

            if no_hits_found:
                msg = f'{genome_wide_tsv_result_file} not found because there were no hits meeting the search criteria for {phage_id}!\n'
                sys.stdout.write(msg)
                log_object.info(msg)
                continue
            else:
                msg = f'Final results file not found for {phage_id}, and the expected error message was not found in fai\'s progress log:\n{fai_progress_file}\n'
                sys.stdout.write(msg)
                log_object.info(msg)
                sys.exit(1)

        with open(genome_wide_tsv_result_file) as ogwtrf:
            for i, line in enumerate(ogwtrf):
                if i == 0: continue  # header
                line = line.strip()
                ls = line.split('\t')
                tg = ls[0]
                aai = ls[2]
                sgp = ls[4]
                tg_phage_aai[tg][phage_id] = aai
                tg_phage_sgp[tg][phage_id] = sgp
                if float(aai) > 0.0:
                    tg_total_phages_found[tg] += 1

    # os.path.join so the listing is found whether or not -tg ended with a slash
    target_genome_listing_file = os.path.join(preptg_db_dir, 'Target_Genome_Annotation_Files.txt')
    result_file = outdir + 'TemperatePhageOme_to_Target_Genomes_Similarity.tsv'
    sorted_phages = sorted(phages)
    numeric_columns_t2 = ['shared_phages',
        'aggregate_score'] + [x + ' - AAI' for x in sorted_phages] + [x + ' - Proportion Shared Genes' for x in sorted_phages]

    tgs = set([])
    all_scores = []
    with open(result_file, 'w') as result_handle, open(target_genome_listing_file) as otglf:
        result_handle.write('\t'.join(['target_genome',
            'shared_phages',
            'aggregate_score'] + [x + ' - AAI' + '\t' + x + ' - Proportion Shared Genes' for x in sorted_phages]) + '\n')
        for line in otglf:
            line = line.strip()
            if not line: continue
            ls = line.split('\t')
            tg = ls[0]
            tgs.add(tg)
            shared_phages = tg_total_phages_found[tg]
            # aggregate score: sum over phages of AAI * proportion of shared genes
            total_score = 0.0
            printlist = [tg, str(shared_phages)]
            per_phage_values = []
            for phage_id in sorted_phages:
                aai = tg_phage_aai[tg][phage_id]
                sgp = tg_phage_sgp[tg][phage_id]
                if aai != "NA":
                    total_score += float(aai)*float(sgp)
                per_phage_values.append(aai)
                per_phage_values.append(sgp)
            printlist.append(str(total_score))
            all_scores.append(total_score)
            result_handle.write('\t'.join(printlist + per_phage_values) + '\n')

    # construct spreadsheet
    atpoc_spreadsheet_file = outdir + 'ATPOC_Results.xlsx'
    writer = pd.ExcelWriter(atpoc_spreadsheet_file, engine = 'xlsxwriter')
    workbook = writer.book
    header_format = workbook.add_format({'bold': True,
        'text_wrap': True,
        'valign': 'top',
        'fg_color': '#D7E4BC',
        'border': 1})

    t1_results_df = util.load_table_in_panda_data_frame(phage_info_file, numeric_columns_t1)
    t1_results_df.to_excel(writer,
        sheet_name = 'Sample ProphageOme Overview',
        index = False,
        na_rep = "NA")

    t2_results_df = util.load_table_in_panda_data_frame(result_file, numeric_columns_t2)
    t2_results_df.to_excel(writer,
        sheet_name = 'Similarity to Target Genomes',
        index = False,
        na_rep = "NA")

    t1_worksheet = writer.sheets['Sample ProphageOme Overview']
    t2_worksheet = writer.sheets['Similarity to Target Genomes']

    # Header styling; sheet 1 has seven columns (A-G).
    t1_worksheet.conditional_format('A1:G1',
        {'type': 'cell',
        'criteria': '!=',
        'value': 'NA',
        'format': header_format})
    t2_worksheet.conditional_format('A1:HA1',
        {'type': 'cell',
        'criteria': '!=',
        'value': 'NA',
        'format': header_format})

    # last data row of each sheet (+1 for the header row)
    row_count = len(sample_phages) + 1
    t2_row_count = len(tgs) + 1

    # phage length (t1)
    t1_worksheet.conditional_format('F2:F' + str(row_count),
        {'type': '2_color_scale',
        'min_color': "#d5f0ea",
        'max_color': "#83a39c",
        'min_type': 'num',
        'max_type': 'num',
        "min_value": 0,
        "max_value": max(phage_lengths)})

    # total phage shared (t2)
    t2_worksheet.conditional_format('B2:B' + str(t2_row_count),
        {'type': '2_color_scale',
        'min_color': "#d9d9d9",
        'max_color': "#949494",
        'min_type': 'num',
        'max_type': 'num',
        "min_value": 0,
        "max_value": len(phages)})

    # aggregate score (t2); default guards against an empty target genome listing
    t2_worksheet.conditional_format('C2:C' + str(t2_row_count),
        {'type': '2_color_scale',
        'min_color': "#f7db94",
        'max_color': "#ad8f40",
        'min_type': 'num',
        'max_type': 'num',
        "min_value": 0,
        "max_value": max(all_scores, default = 0.0)})

    # Per-phage column pairs start at column index 3 (0-based): AAI, then
    # proportion of shared genes. The formatting is the same for every pair.
    phage_index = 3
    for _ in sorted_phages:
        columnid_pid = util.determine_column_name_based_on_index(phage_index)
        columnid_sgp = util.determine_column_name_based_on_index(phage_index+1)
        phage_index += 2

        # aai
        t2_worksheet.conditional_format(columnid_pid + '2:' + columnid_pid + str(t2_row_count),
         {'type': '2_color_scale',
         'min_color': "#d8eaf0",
         'max_color': "#83c1d4",
         "min_value": 0.0,
         "max_value": 100.0,
         'min_type': 'num',
         'max_type': 'num'})

        # prop genes found
        t2_worksheet.conditional_format(columnid_sgp + '2:' + columnid_sgp + str(t2_row_count),
         {'type': '2_color_scale',
         'min_color': "#e6bec0",
         'max_color': "#f26168",
         "min_value": 0.0,
         "max_value": 1.0,
         'min_type': 'num',
         'max_type': 'num'})

    # Freeze the first row of both sheets
    t1_worksheet.freeze_panes(1, 0)
    t2_worksheet.freeze_panes(1, 0)

    # close workbook
    workbook.close()

    sys.stdout.write(f'Done running atpoc!\nFinal results can be found at: \n{atpoc_spreadsheet_file}\n')
    log_object.info(f'Done running atpoc!\nFinal results can be found at: \n{atpoc_spreadsheet_file}\n')

    # Close logging object and exit
    util.close_logger_object(log_object)
    sys.exit(0)

def determine_column_name_based_on_index(index: int) -> str:
    """
    Return the spreadsheet column name for a 0-based column index.

    Index 0 maps to 'A', 25 to 'Z', 26 to 'AA', 701 to 'ZZ', 702 to 'AAA',
    and so on (bijective base-26 over the letters A-Z), so any number of
    columns is supported.

    Args:
        index: 0-based column index.

    Returns:
        The column name in uppercase letters.

    Raises:
        ValueError: if index is negative.
    """
    if index < 0:
        raise ValueError(f'Column index must be non-negative, got {index}.')

    letters = []
    # Work with the 1-based position; subtracting 1 before each divmod makes
    # the numbering bijective (there is no "zero" letter).
    position = index + 1
    while position > 0:
        position, remainder = divmod(position - 1, 26)
        letters.append(chr(ord('A') + remainder))
    return ''.join(reversed(letters))

# Script entry point: run the workflow only when this file is executed
# directly, not when it is imported. atpoc() ends the process via sys.exit().
if __name__ == '__main__': 
    atpoc()
