#!/usr/bin/env python3
'''
ntSynt: Genome synteny detection using dynamic undirected minimizer graphs
Written by Lauren Coombe (@lcoombe)
'''
import argparse
import os
import shlex
import subprocess
from packaging import version

import snakemake

# Version string reported by the -v/--version command-line option
NTSYNT_VERSION = "ntSynt v1.0.4"

# ASCII-art banner printed at the start of a run; a raw string so the backslashes are kept literally
NTSYNT_ASCII = r"""
        _    ____                 _ 
 _ __  | |_ / ___|  _   _  _ __  | |_ 
| '_ \ | __|\___ \ | | | || '_ \ | __|
| | | || |_  ___) || |_| || | | || |_ 
|_| |_| \__||____/  \__, ||_| |_| \__|
                    |___/   
"""

def read_fasta_files(filename):
    """Read in fasta file names in the specified file.

    Args:
        filename: Path to a UTF-8 text file listing one fasta path per line.

    Returns:
        List of the listed paths, in file order, with leading/trailing
        whitespace removed. Blank (whitespace-only) lines are skipped so that
        stray empty lines (e.g. at the end of the file) are not treated as
        fasta file names.
    """
    with open(filename, 'r', encoding="utf-8") as fin:
        stripped_lines = (line.strip() for line in fin)
        return [fasta for fasta in stripped_lines if fasta]

def main():
    """Run ntSynt snakemake file.

    Parses the command-line arguments, fills in any of --indel, --merge,
    --w_rounds and --block_size that were not given from the defaults for the
    requested --divergence, validates the inputs, then builds a snakemake
    command for ntsynt_run_pipeline.smk (located next to this script) and
    runs it as a subprocess.

    Raises:
        SystemExit: Via parser.error() when the arguments are invalid.
        FileNotFoundError: If one of the input fasta files does not exist.
        subprocess.SubprocessError: If snakemake exits with a non-zero status.
    """

    epilog_str = "\n".join(["Default parameter settings for divergence values:",
                  "< 1% divergence:\t--block_size 500 --indel 10000 --merge 10000 --w_rounds 100 10",
                  "1% - 10% divergence:\t--block_size 1000 --indel 50000 --merge 100000 --w_rounds 250 100",
                  "> 10% divergence:\t--block_size 10000 --indel 100000 --merge 1000000 --w_rounds 500 250",
                  "If any of these parameters are set manually, those values will override the above.",
                  "\nIf you have any questions about ntSynt, please create a GitHub issue: https://github.com/bcgsc/ntSynt"])

    parser = argparse.ArgumentParser(description="ntSynt: Multi-genome synteny detection using minimizer graphs",
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog=epilog_str)
    parser.add_argument("fastas", help="Input genome fasta files", nargs="*")
    parser.add_argument("--fastas_list", help="File listing input genome fasta files, one per line",
                        required=False, type=str)
    parser.add_argument("-d", "--divergence",
                        help="Approx. maximum percent sequence divergence between input genomes"\
                            " (Ex. -d 1 for 1%% divergence).\n"\
                            "This will be used to set --indel, --merge, --w_rounds, --block_size\n"\
                            "See below for default values - You can also set any of those parameters yourself, which will override these settings.",
                        required=True, type=float)
    parser.add_argument("-p", "--prefix", help="Prefix for ntSynt output files [ntSynt.k<k>.w<w>]",
                        required=False)
    parser.add_argument("-k", help="Minimizer k-mer size [24]", type=int, required=False, default=24)
    parser.add_argument("-w", help="Minimizer window size [1000]", type=int, required=False, default=1000)
    parser.add_argument("-t", help="Number of threads [12]", type=int, default=12)
    parser.add_argument("--fpr", help="False positive rate for Bloom filter creation [0.025]",
                        default=0.025, type=float)
    parser.add_argument("-b", "--block_size", help="Minimum synteny block size (bp)",
                        type=int, required=False)
    parser.add_argument("--merge", help="Maximum distance between collinear synteny blocks for merging (bp). \n"\
                                        "Can also specify a multiple of the window size (ex. 3w)",
                        type=str)
    parser.add_argument("--w_rounds", help="List of decreasing window sizes for synteny block refinement",
                        nargs="+", type=int)
    parser.add_argument("--indel", help="Threshold for indel detection (bp)", type=int)
    # --no-common and --no-simplify-graph are hidden from the help output
    parser.add_argument("--no-common", help=argparse.SUPPRESS,
                        action="store_true")
    parser.add_argument("--no-simplify-graph", help=argparse.SUPPRESS,
                        action="store_true")
    parser.add_argument("-n", "--dry-run", help="Print out the commands that will be executed", action="store_true")
    parser.add_argument("--benchmark", help="Store benchmarks for each step of the ntSynt pipeline",
                        action="store_true")
    parser.add_argument("-f", "--force", help="Run all ntSynt steps, regardless of existing output files",
                        action="store_true")
    parser.add_argument("--dev", help="Run in developer mode to retain intermediate files, log verbose output",
                        action="store_true")
    parser.add_argument("-v", "--version", action="version", version=NTSYNT_VERSION)

    args = parser.parse_args()

    # The snakemake file is expected to sit in the same directory as this script
    base_dir = os.path.dirname(os.path.realpath(__file__))
    if not args.prefix:
        args.prefix = f"ntSynt.k{args.k}.w{args.w}"

    # Reject out-of-range values up front (this also rejects negative values and NaN)
    if not 0 <= args.divergence <= 100:
        parser.error("--divergence must be a value between 0 and 100")

    # Divergence-dependent defaults: (indel, merge, w_rounds, block_size)
    if args.divergence < 1:
        default_indel, default_merge, default_w_rounds, default_block_size = 10000, 10000, [100, 10], 500
    elif args.divergence <= 10:
        default_indel, default_merge, default_w_rounds, default_block_size = 50000, 100000, [250, 100], 1000
    else:
        default_indel, default_merge, default_w_rounds, default_block_size = 100000, 1000000, [500, 250], 10000

    # Values given on the command line take precedence over the defaults
    args.indel = args.indel or default_indel
    args.merge = args.merge or default_merge
    args.w_rounds = args.w_rounds or default_w_rounds
    args.block_size = args.block_size or default_block_size

    # Check that the specified w_rounds are smaller than the initial window size
    for w in args.w_rounds:
        if w > args.w:
            parser.error("All values specified for --w_rounds must be smaller than -w")

    if not args.fastas and not args.fastas_list:
        parser.error("Please supply the input genome fasta files as positional arguments, " \
                     "or specify a file listing the files (one fasta per line) with --fastas_list")

    if args.fastas and args.fastas_list:
        parser.error("Please supply the input genome fasta files as positional arguments, "\
                    "or specify a single file (one fasta per line) with --fastas_list, NOT both.")

    if args.fastas_list:
        fastas_input = read_fasta_files(args.fastas_list)
    else:
        fastas_input = args.fastas

    if len(fastas_input) < 2:
        parser.error("Must supply at least two reference genomes to compare")

    print(NTSYNT_ASCII)
    intro_string = ["Running ntSynt...",
                    f"Specified percent divergence: {args.divergence}",
                    "Parameter settings:",
                    f"\tfastas {fastas_input}",
                    f"\t--divergence {args.divergence}",
                    f"\t--block_size {args.block_size}",
                    f"\t--merge {args.merge}",
                    f"\t--w_rounds {args.w_rounds}",
                    f"\t--indel {args.indel}",
                    f"\t-p {args.prefix}",
                    f"\t-k {args.k}",
                    f"\t-w {args.w}",
                    f"\t-t {args.t}",
                    f"\t--fpr {args.fpr}",
                    ]
    print("\n".join(intro_string), flush=True)

    # Check that the input fastas exist
    for fasta in fastas_input:
        if not os.path.isfile(fasta):
            raise FileNotFoundError(f"Input file {fasta} not found.")

    args.w_rounds = " ".join(map(str, args.w_rounds))

    # The command is assembled as one string and tokenized with shlex.split() below,
    # so the snakefile path and prefix are quoted to survive whitespace or other
    # shell-special characters (shlex.quote leaves ordinary values unchanged).
    snakefile = shlex.quote(f"{base_dir}/ntsynt_run_pipeline.smk")
    prefix = shlex.quote(args.prefix)
    # references is passed as the repr of the Python list wrapped in single quotes;
    # after shlex.split() the quotes are consumed, leaving e.g. references=[a.fa, b.fa]
    # as a single argument.
    # NOTE(review): fasta paths containing whitespace or quotes will not survive this
    # quoting scheme - confirm whether such paths need to be supported.
    command = f"snakemake -s {snakefile} -p --cores {args.t} " \
                f"--config references='{fastas_input}' kmer={args.k} window={args.w} threads={args.t} fpr={args.fpr} " \
                f"prefix={prefix} w_rounds='{args.w_rounds}' indel_merge={args.indel} " \
                f"collinear_merge={args.merge} block_size={args.block_size} "
    command += "common=False " if args.no_common else "common=True "
    command += "simplify_graph=False " if args.no_simplify_graph else "simplify_graph=True "
    command += "benchmark=True " if args.benchmark else "benchmark=False "
    command += "dev=True " if args.dev else "dev=False "
    command += "--resources load=2 " # For indexlr, don't want more than 2 indexlr at a time due to memory
    if version.parse(snakemake.__version__) >= version.parse("7.8.0"): # Keep behaviour consistent for smk versions
        command += "--rerun-trigger mtime "

    if args.dry_run:
        command += " -n"

    if args.force:
        command += " -F"

    print(f"Running {command}", flush=True)

    # Run without a shell: split the command string into an argument list
    command = shlex.split(command)

    ret = subprocess.call(command)
    if ret != 0:
        raise subprocess.SubprocessError("ntSynt failed - check the logs for the error.")

    print("Done ntSynt!")

# Run the pipeline only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
