#!/opt/mambaforge/envs/bioconda/conda-bld/zol_1766863018713/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_/bin/python

"""
### Program: salt
### Author: Rauf Salamzade
### Kalan Lab
### UW Madison, Department of Medical Microbiology and Immunology
"""

from math import log
from Bio import SeqIO
from rich_argparse import RawTextRichHelpFormatter
from time import sleep
from zol import util, fai
import argparse
import multiprocessing
import os
import sys
import gzip
import traceback
import pickle

# BSD 3-Clause License
#
# Copyright (c) 2023-2025, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met: 
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, 
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Set OMP_NUM_THREADS to "1" in this process's environment; child processes
# started later inherit the value.
# NOTE(review): presumably intended to keep OpenMP-backed libraries
# single-threaded so that parallelism is governed by --threads. It is assigned
# after the imports above - confirm that is early enough for every library
# that reads it.
os.environ["OMP_NUM_THREADS"] = "1"
def create_parser():
    """
    Define the salt command-line interface and parse the arguments in sys.argv.

    Returns:
        argparse.Namespace with the attributes:
            fai_results_dir   - fai results directory (required).
            target_genomes_db - prepTG results directory (required).
            outdir            - output directory for salt (required).
            threads           - number of threads to use (int, default 1).

    Note:
        argparse prints a usage message and exits the process when the
        arguments are invalid or when -h/--help is requested.
    """
    # The formatter class is a "raw text" formatter, so the line breaks and
    # indentation written here are what the user sees in --help.
    parser = argparse.ArgumentParser(description = """
    Program: salt
    Author: Rauf Salamzade
    Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and Immunology

    salt - Support Assessment for Lateral Transfer

    salt performs various analyses to assess support for horizontal vs. vertical evolution
    of a gene cluster across target genomes searched using fai. It takes as input the result
    directory from fai as well as the prepTG database searched.

    salt will: (1) run codoff to assess codon usage similarities between gene clusters detected
    and their respective background genomes, (2) infer similarity between query gene cluster
    searched for in fai and detected homolog with respect to expected similarity based on
    universal ribosomal proteins, and (3) assess whether the scaffold the detected gene cluster
    is on features insertion-elements, phage proteins, or plasmid proteins (assuming
    --mge_annotation was requested for prepTG).

    Similar to other zol programs, the final result is an auto-color-formatted XLSX spreadsheet.
    """, formatter_class = RawTextRichHelpFormatter)

    parser.add_argument('-f',
        '--fai-results-dir',
        help = "Path to the results directory from the fai run\n"
               "to assess.",
        required = True)
    parser.add_argument('-tg',
        '--target-genomes-db',
        help = "Result directory from running prepTG for target\n"
               "genomes of interest.",
        required = True)
    parser.add_argument('-o',
        '--outdir',
        help = "Output directory for salt analysis.",
        required = True)
    parser.add_argument('-c',
        '--threads',
        type = int,
        help = "The number of threads to use [Default is 1].",
        required = False,
        default = 1)

    return parser.parse_args()

def _abort(log_object, msg, stream):
    """Log msg as an error, echo it to stream, and exit the process with status 1."""
    log_object.error(msg)
    stream.write(msg + '\n')
    sys.exit(1)


def _touch(path):
    """Create the checkpoint file at path (or refresh its timestamp) without a shell."""
    with open(path, 'a'):
        os.utime(path, None)


def _strip_gbk_suffix(file_name):
    """Return the part of file_name before its last '.gbk' ('' when there is none)."""
    return file_name.rpartition('.gbk')[0]


def _data_rows(tsv_file):
    """Yield the tab-split fields of every non-blank line after the header line of tsv_file."""
    with open(tsv_file) as handle:
        next(handle, None)  # skip the header line
        for line in handle:
            line = line.strip()
            if line:
                yield line.split('\t')


def _load_gene_cluster_table(fai_indiv_spreadsheet_tsv):
    """
    Map each gene cluster listed in the fai spreadsheet (second column: path to the
    gene cluster GenBank) to its genome name and to its GenBank path.

    Returns:
        (gc_to_genome, gc_gbks) - two dicts keyed by gene cluster name.
    """
    gc_to_genome = {}
    gc_gbks = {}
    for ls in _data_rows(fai_indiv_spreadsheet_tsv):
        gc_gbk = ls[1]
        gc = _strip_gbk_suffix(gc_gbk.split('/')[-1])
        # The genome name is the gene cluster name up to the '_fai-gene-cluster'
        # marker (or the whole name when the marker is absent).
        gc_to_genome[gc] = gc.split('_fai-gene-cluster')[0]
        gc_gbks[gc] = gc_gbk
    return gc_to_genome, gc_gbks


def _find_best_fai_hit(fai_indiv_spreadsheet_tsv):
    """
    Find the row of the fai spreadsheet with the highest value in the sixth column.

    Only values strictly greater than 0 qualify; on ties the first row wins.

    Returns:
        (best_tg_hit, best_tg_gc_gbk_file) - first and second column of the winning
        row, or (None, None) when no row qualifies.
    """
    best_tg_hit = None
    best_tg_gc_gbk_file = None
    best_tg_score = 0
    for ls in _data_rows(fai_indiv_spreadsheet_tsv):
        agg_score = float(ls[5])
        if agg_score > best_tg_score:
            best_tg_score = agg_score
            best_tg_hit = ls[0]
            best_tg_gc_gbk_file = ls[1]
    return best_tg_hit, best_tg_gc_gbk_file


def _write_gene_cluster_proteins(gc_gbk_file, faa_file):
    """
    Write the translation of every CDS feature in gc_gbk_file (plain or gzipped
    GenBank) to faa_file in FASTA format, with headers
    '>Gene_Cluster_Protein|<locus_tag>'.

    Raises:
        FileNotFoundError if gc_gbk_file does not exist; other exceptions may
        propagate from parsing (e.g. a CDS lacking locus_tag or translation).
    """
    if not os.path.isfile(gc_gbk_file):
        raise FileNotFoundError(gc_gbk_file)
    opener = gzip.open if gc_gbk_file.endswith('.gz') else open
    with opener(gc_gbk_file, 'rt') as gbk_handle, open(faa_file, 'w') as faa_handle:
        for rec in SeqIO.parse(gbk_handle, 'genbank'):
            for feat in rec.features:
                if feat.type != 'CDS':
                    continue
                lt = feat.qualifiers.get('locus_tag')[0]
                prot_seq = feat.qualifiers.get('translation')[0]
                faa_handle.write('>Gene_Cluster_Protein|' + lt + '\n' + prot_seq + '\n')


def salt():
    """
    Run the salt workflow end to end.

    Parses the command line (via create_parser), validates the fai results
    directory and the prepTG database directory, and then runs five steps,
    writing everything under the requested output directory:

        1. Relate gene clusters reported by fai to their target genomes.
        2. Run codoff for every gene cluster (checkpointed).
        3. Concatenate the per-scaffold MGE annotation summaries, if prepTG
           produced them (checkpointed).
        4. Compute gene cluster vs. ribosomal protein AAI statistics and plot
           them (checkpointed).
        5. Consolidate all results in a TSV/XLSX report (always rerun).

    Side effects:
        Prompts on stdin when the output directory already exists, creates
        files and directories under the output directory, and terminates the
        process through sys.exit: status 0 when the user declines to proceed,
        status 1 when a required input is missing or a step fails.
    """
    myargs = create_parser()

    # Normalize all three directories to absolute paths with a trailing
    # separator: every path below is built by plain string concatenation, so
    # an argument given without a trailing '/' would otherwise yield paths
    # such as '<db>Genomic_Genbanks_Additional/'.
    fai_results_dir = os.path.abspath(myargs.fai_results_dir) + '/'
    target_genomes_db = os.path.abspath(myargs.target_genomes_db) + '/'
    outdir = os.path.abspath(myargs.outdir) + '/'
    threads = myargs.threads

    # Explicit checks rather than assert statements, which are removed when
    # Python runs with -O.
    if not (os.path.isdir(fai_results_dir) and os.path.isdir(target_genomes_db)):
        sys.stderr.write("One or more of the required input directories do not exist.\n")
        sys.exit(1)

    # create output directory if needed, or warn of over-writing
    if os.path.isdir(outdir):
        sys.stderr.write("Output directory exists. Files will be overwritten, but\n"
                       "checkpoints will be used to avoid redoing successfully\n"
                       "completed steps.\n"
                       "Do you wish to proceed? (yes/no): ")
        user_input = input().strip().lower()
        if user_input not in ['yes', 'y']:
            sys.stderr.write("Execution cancelled by user.\n")
            sys.exit(0)
    else:
        util.setup_ready_directory([outdir], delete_if_exist = False)

    check_dir = outdir + 'Checkpoint_Files/'
    if not os.path.isdir(check_dir):
        util.setup_ready_directory([check_dir], delete_if_exist = True)

    # create logging object
    log_file = outdir + 'Progress.log'
    log_object = util.create_logger_object(log_file)
    version_string = util.get_version()

    sys.stdout.write(f'Running salt version {version_string}\n')
    log_object.info(f'Running salt version {version_string}')

    # log command used (appended, so reruns accumulate a history)
    parameters_file = outdir + 'Command_Issued.txt'
    with open(parameters_file, 'a+') as parameters_handle:
        parameters_handle.write(' '.join(sys.argv) + '\n')

    # Step 1: Parse relationships between gene clusters detected by fai and prepTG genomes
    msg = '------------------Step 1------------------\n'
    msg += 'Processing gene cluster to target genome relationships.'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')

    tg_gbk_dir = target_genomes_db + 'Genomic_Genbanks_Additional/'
    tg_concat_dmnd_db = target_genomes_db + 'Target_Genomes_DB.dmnd'
    tg_diamond_linclust_cluster_pickle_file = target_genomes_db + 'Target_Genomes_DB_clustered.pkl'
    tg_genome_list_file = target_genomes_db + 'Target_Genome_Annotation_Files.txt'
    tg_mge_annot_dir = target_genomes_db + 'MGE_Annotations/Summary_Results/'
    fai_indiv_spreadsheet_tsv = fai_results_dir + 'Spreadsheet_TSVs/individual_gcs.tsv'

    if not os.path.isfile(tg_diamond_linclust_cluster_pickle_file):
        _abort(log_object,
               'The "Target_Genomes_DB_clustered.pkl" pickle file does not exist within the prepTG db directory specified.',
               sys.stdout)

    rep_prot_to_nonreps = None
    try:
        with open(tg_diamond_linclust_cluster_pickle_file, 'rb') as handle:
            # NOTE: unpickling can execute arbitrary code - only point salt at
            # prepTG databases from a trusted source.
            rep_prot_to_nonreps = pickle.load(handle)
    except Exception:
        # pickle.load can raise many exception types for a damaged file, so the
        # broad catch is deliberate; the traceback is reported, not swallowed.
        msg = 'The "Target_Genomes_DB_clustered.pkl" pickle file within the prepTG db directory specified could not be loaded.'
        log_object.error(msg)
        sys.stdout.write(msg + '\n')
        sys.stderr.write(traceback.format_exc())
        sys.exit(1)

    if not os.path.isfile(tg_concat_dmnd_db):
        _abort(log_object,
               'The "Target_Genomes_DB.dmnd" DIAMOND database file does not exist within the prepTG db directory specified.',
               sys.stdout)

    # Only a warning at this point: the listing file is needed by step 4 alone,
    # which is skipped when its checkpoint exists.
    tg_genome_list_missing_msg = 'The "Target_Genome_Annotation_Files.txt" listing file of target genomes does not exist within the prepTG db directory specified.'
    if not os.path.isfile(tg_genome_list_file):
        log_object.warning(tg_genome_list_missing_msg)
        sys.stdout.write(tg_genome_list_missing_msg + '\n')

    # MGE annotations are optional; their absence only disables step 3.
    mge_annots_exist = os.path.isdir(tg_mge_annot_dir)
    if not mge_annots_exist:
        msg = 'The "MGE_Annotations/Summary_Results/" subdirectory with MGE annotation information per scaffold does not exist within the prepTG db directory specified.\nNote, this is not essential but if you would like this information you will need to rerun prepTG with the "--mge_annotation" argument.'
        log_object.warning(msg)
        sys.stdout.write(msg + '\n')

    if not os.path.isdir(tg_gbk_dir):
        _abort(log_object,
               'The "Genomic_Genbanks_Additional/" subdirectory does not exist within the prepTG db directory specified.',
               sys.stdout)

    if not os.path.isfile(fai_indiv_spreadsheet_tsv):
        _abort(log_object,
               'The "Spreadsheet_TSVs/individual_gcs.tsv" file does not exist within the fai results directory specified.',
               sys.stdout)

    # genome name -> path of its GenBank file in the prepTG database
    genome_gbks = {}
    for f in os.listdir(tg_gbk_dir):
        genome_gbks[_strip_gbk_suffix(f)] = tg_gbk_dir + f

    gc_to_genome, gc_gbks = _load_gene_cluster_table(fai_indiv_spreadsheet_tsv)

    msg = '------------------Step 2------------------\nRunning codoff for all gene clusters.'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')

    codoff_res_dir = outdir + 'codoff_Results/'
    checkpoint_2_file = check_dir + 'step2.txt'
    if not os.path.isfile(checkpoint_2_file):
        util.setup_ready_directory([codoff_res_dir], delete_if_exist = True)

        codoff_cmds = []
        for gc, genome in gc_to_genome.items():
            genome_gbk = genome_gbks.get(genome)
            if genome_gbk is None:
                # Report the mismatch explicitly instead of dying with a bare KeyError.
                _abort(log_object,
                       f'Unable to find a GenBank file for genome "{genome}" (gene cluster "{gc}") within the "Genomic_Genbanks_Additional/" subdirectory of the prepTG db directory specified.',
                       sys.stderr)
            outf = codoff_res_dir + gc + '.txt'
            codoff_cmds.append(['codoff', '-g', genome_gbk, '-f', gc_gbks[gc], '-o', outf])

        # Use robust error handling for codoff analysis
        result_summary = util.robust_multiprocess_executor(
            worker_function=util.multi_process_safe,
            inputs=codoff_cmds,
            pool_size=threads,
            error_strategy="report_and_continue",
            log_object=log_object,
            description="codoff codon optimization analysis"
        )

        # Guard against a division by zero when nothing was processed.
        total_processed = result_summary['total_processed']
        if total_processed:
            success_prop = result_summary['success_count'] / total_processed
            msg = f"{success_prop*100.0}% of codoff runs were successful."
        else:
            msg = "No codoff runs were performed."
        sys.stdout.write(msg + '\n')
        log_object.info(msg)

        _touch(checkpoint_2_file)

    msg = '------------------Step 3------------------\n'
    msg += 'Getting annotation information for plasmid, phage and IS element\n'
    msg += 'associated proteins in target genomes.'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')

    mge_annots_overview_file = None
    checkpoint_3_file = check_dir + 'step3.txt'
    if mge_annots_exist and not os.path.isfile(checkpoint_3_file):
        mge_annots_overview_file = outdir + 'MGE_Annotations_Overview.txt'
        # The command appends (>>), so remove any leftover from an earlier,
        # un-checkpointed attempt to avoid duplicated content.
        if os.path.isfile(mge_annots_overview_file):
            os.remove(mge_annots_overview_file)
        concat_cmd = ['time', 'find', tg_mge_annot_dir, '-maxdepth', '1', '-type', 'f', '|', 'xargs', 'cat', ' >> ', mge_annots_overview_file]
        util.run_cmd_via_subprocess(concat_cmd, log_object=log_object,
                                    check_files = [mge_annots_overview_file])
        _touch(checkpoint_3_file)

    # Also picks up an overview file produced by an earlier, checkpointed run.
    if os.path.isfile(outdir + 'MGE_Annotations_Overview.txt'):
        mge_annots_overview_file = outdir + 'MGE_Annotations_Overview.txt'

    msg = '------------------Step 4------------------\n'
    msg += 'Identify best hits for universal ribosomal proteins and\n'
    msg += 'standardize amino acid identity of query protein hits in\n'
    msg += 'fai by rpAAI.'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')

    checkpoint_4_file = check_dir + 'step4.txt'
    ribo_norm_dir = outdir + 'Ribosomal_Normalization/'
    gc_to_ribo_aai_stat_file = outdir + 'GeneCluster_to_Ribosomal_AAI_Ratios.txt'
    if not os.path.isfile(checkpoint_4_file):
        util.setup_ready_directory([ribo_norm_dir], delete_if_exist = True)

        no_best_hit_msg = 'Issue with finding the best hit target genome in the fai results. Perhaps there was no matching gene clusters identified?'

        best_tg_hit, best_tg_gc_gbk_file = _find_best_fai_hit(fai_indiv_spreadsheet_tsv)
        if best_tg_hit is None:
            # Fail here with the relevant message rather than later on a
            # None GenBank path.
            _abort(log_object, no_best_hit_msg, sys.stderr)

        if not os.path.isfile(tg_genome_list_file):
            _abort(log_object, tg_genome_list_missing_msg, sys.stderr)

        # Resolve the full-genome GenBank of the best hit; if the genome is
        # listed more than once the last entry is used.
        best_tg_gbk_file = None
        with open(tg_genome_list_file) as otglf:
            for line in otglf:
                ls = line.strip().split('\t')
                if ls[0] == best_tg_hit:
                    best_tg_gbk_file = target_genomes_db + ls[1]

        tg_query_prots_file = ribo_norm_dir + 'Reference_Genome_Queries.faa'
        try:
            _write_gene_cluster_proteins(best_tg_gc_gbk_file, tg_query_prots_file)
        except Exception:
            msg = 'Issue processing GenBank file for gene cluster best matching query proteins.'
            log_object.error(msg)
            sys.stderr.write(msg + '\n')
            sys.stderr.write(traceback.format_exc())
            sys.exit(1)

        # The protein FASTA must hold more than 100 bytes to be considered usable.
        if (best_tg_gbk_file is None
                or not os.path.isfile(best_tg_gbk_file)
                or os.path.getsize(tg_query_prots_file) <= 100):
            _abort(log_object, no_best_hit_msg, sys.stderr)

        util.run_pyhmmer_for_ribo_prots(best_tg_gbk_file,
         tg_query_prots_file,
         ribo_norm_dir,
         log_object,
         threads = threads)

        diamond_results_file = fai.run_diamond_blastp(tg_concat_dmnd_db,
                                                        tg_query_prots_file,
                                                        ribo_norm_dir,
                                                        log_object,
                                                        diamond_sensitivity = 'very-sensitive',
                                                        evalue_cutoff = 1e-3,
                                                        threads = threads,
                                                        compute_query_coverage = True)

        util.process_diamond_for_gc_to_ribo_ratio(diamond_results_file,
                                                rep_prot_to_nonreps,
                                                tg_query_prots_file,
                                                fai_indiv_spreadsheet_tsv,
                                                gc_to_ribo_aai_stat_file,
                                                log_object)

        rscript_file = outdir + 'generateSaltGCvsRiboAAIPlot.R'
        gc_ribo_aai_plot_pdf_file = outdir + 'GC_to_RiboProt_AAI_Relationships.pdf'
        util.make_gc_vs_ribo_prot_aai_scatterplot(rscript_file,
                                                gc_to_ribo_aai_stat_file,
                                                gc_ribo_aai_plot_pdf_file,
                                                log_object)

        _touch(checkpoint_4_file)

    if not os.path.isfile(gc_to_ribo_aai_stat_file):
        _abort(log_object,
               'It appears that step 4 did not run properly but the checkpoint was created. This shouldn\'t happen so please share any info with us on GitHub Issues.',
               sys.stderr)

    msg = '------------------Step 5------------------\n'
    msg += 'Consolidating assessments together in a single report.'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')

    # will always be rerun by design - but it should be fast.
    result_tsv_file = outdir + 'SALT_Report.tsv'
    result_xlsx_file = outdir + 'SALT_Report.xlsx'
    util.consolidate_salty_spreadsheet(fai_indiv_spreadsheet_tsv,
        genome_gbks,
        codoff_res_dir,
        mge_annots_overview_file,
        gc_to_ribo_aai_stat_file,
        result_tsv_file,
        result_xlsx_file,
        log_object)

    msg = f'salt finished successfully!\nResulting spreadsheet can be found at: {result_xlsx_file}'
    log_object.info(msg)
    sys.stdout.write(msg + '\n')


# Run the salt workflow only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__': 
    salt()
