#! /usr/bin/env python

import argparse
import datetime
import time
from distutils.util import strtobool

from csubst.__init__ import __version__
from csubst import param

def command_dataset(args):
    """Run the `csubst dataset` subcommand.

    Prints UTC start/end timestamps and the elapsed wall-clock seconds
    around a call to ``main_dataset``.

    Args:
        args: Parsed command-line arguments, forwarded to
            ``param.get_global_parameters``.
    """
    utc = datetime.timezone.utc
    print('csubst dataset start:', datetime.datetime.now(utc), flush=True)
    began_at = time.time()
    # Imported here rather than at module level, so the module is loaded
    # only when this subcommand actually runs.
    from csubst.main_dataset import main_dataset
    params = param.get_global_parameters(args)
    main_dataset(params)
    elapsed_sec = int(time.time() - began_at)
    print('csubst dataset: Time elapsed: {:,} sec'.format(elapsed_sec))
    print('csubst dataset end:', datetime.datetime.now(utc), flush=True)

def command_simulate(args):
    """Run the `csubst simulate` subcommand.

    Prints UTC start/end timestamps and the elapsed wall-clock seconds
    around a call to ``main_simulate``.

    Args:
        args: Parsed command-line arguments, forwarded to
            ``param.get_global_parameters``.
    """
    utc = datetime.timezone.utc
    print('csubst simulate start:', datetime.datetime.now(utc), flush=True)
    began_at = time.time()
    # Imported here rather than at module level, so the module is loaded
    # only when this subcommand actually runs.
    from csubst.main_simulate import main_simulate
    params = param.get_global_parameters(args)
    main_simulate(params)
    elapsed_sec = int(time.time() - began_at)
    print('csubst simulate: Time elapsed: {:,} sec'.format(elapsed_sec))
    print('csubst simulate end:', datetime.datetime.now(utc), flush=True)

def command_site(args):
    """Run the `csubst site` subcommand.

    Prints UTC start/end timestamps and the elapsed wall-clock seconds
    around a call to ``main_site``. When the ``pdb`` parameter is set,
    PyMOL is shut down as the very last step.

    Args:
        args: Parsed command-line arguments, forwarded to
            ``param.get_global_parameters``.
    """
    utc = datetime.timezone.utc
    print('csubst site start:', datetime.datetime.now(utc), flush=True)
    began_at = time.time()
    # Imported here rather than at module level, so the module is loaded
    # only when this subcommand actually runs.
    from csubst.main_site import main_site
    params = param.get_global_parameters(args)
    main_site(params)
    elapsed_sec = int(time.time() - began_at)
    print('csubst site: Time elapsed: {:,} sec'.format(elapsed_sec))
    print('csubst site end:', datetime.datetime.now(utc), flush=True)
    if params['pdb'] is None:
        return
    # This should be executed at the very end, otherwise CSUBST's main process is killed.
    from csubst.parser_pymol import quit_pymol
    quit_pymol()

def command_analyze(args):
    """Run the `csubst analyze` subcommand.

    Prints UTC start/end timestamps and the elapsed wall-clock seconds
    around a call to ``main_analyze``.

    Args:
        args: Parsed command-line arguments, forwarded to
            ``param.get_global_parameters``.
    """
    utc = datetime.timezone.utc
    print('csubst analyze start:', datetime.datetime.now(utc), flush=True)
    began_at = time.time()
    # Imported here rather than at module level, so the module is loaded
    # only when this subcommand actually runs.
    from csubst.main_analyze import main_analyze
    params = param.get_global_parameters(args)
    main_analyze(params)
    elapsed_sec = int(time.time() - began_at)
    print('csubst analyze: Time elapsed: {:,} sec'.format(elapsed_sec))
    print('csubst analyze end:', datetime.datetime.now(utc), flush=True)

if __name__ == "__main__":
    # Command-line entry point: builds the argument parser with one subparser
    # per subcommand (dataset / simulate / site / analyze), dispatches to the
    # matching command_* handler, and reports the total elapsed time.
    #
    # NOTE(review): `strtobool` is imported from `distutils`, which was removed
    # from the standard library in Python 3.12 (PEP 632). The import at the top
    # of this file needs a replacement before running on 3.12+.

    # Start time
    csubst_start = time.time()
    print('CSUBST start:', datetime.datetime.now(datetime.timezone.utc), flush=True)

    # Main parser
    parser = argparse.ArgumentParser(description='CSUBST - a toolkit for molecular convergence detection. For details, see https://github.com/kfuku52/csubst')
    parser.add_argument('--version', action='version', version='CSUBST version: {}'.format(__version__))
    subparsers = parser.add_subparsers()

    # shared: common
    psr_co = argparse.ArgumentParser(add_help=False)
    psr_co.add_argument('--alignment_file', metavar='PATH', default='', type=str,
                        help='default=%(default)s: PATH to in-frame codon alignment (FASTA format).')
    psr_co.add_argument('--rooted_tree_file', metavar='PATH', default='', type=str,
                        help='default=%(default)s: PATH to input rooted tree (Newick format). Tip labels should be consistent with --alignment_file.')
    psr_co.add_argument('--genetic_code', metavar='INTEGER', type=int, required=False, default=1,
                        help='default=%(default)s: NCBI codon table ID. 1 = "Standard". See here: '
                             'https://www.ncbi.nlm.nih.gov/Taxonomy/Utils/wprintgc.cgi')
    psr_co.add_argument('--infile_type', metavar='[iqtree]', default='iqtree', type=str, choices=['iqtree',],
                        help='default=%(default)s: The input file format.')
    psr_co.add_argument('--threads', metavar='INTEGER', default=1, type=int, required=False,
                        help='default=%(default)s: The number of CPUs for parallel computations.')
    psr_co.add_argument('--float_type', metavar='16|32|64', default=64, type=int, required=False,
                        help='default=%(default)s: Float data type for tensors. "16" is not recommended.')
    psr_co.add_argument('--float_digit', metavar='INT', default=4, type=int, required=False,
                        help='default=%(default)s: Number of output float digits.')

    # shared: IQ-TREE inputs
    psr_iq = argparse.ArgumentParser(add_help=False)
    psr_iq.add_argument('--iqtree_exe', metavar='PATH', default='iqtree', type=str, required=False,
                        help='default=%(default)s: PATH to the IQ-TREE executable')
    psr_iq.add_argument('--iqtree_model', metavar='STR', default='ECMK07+F+R4', type=str, required=False,
                        help='default=%(default)s: Codon substitution model for ancestral state reconstruction. '
                             'Base models of "MG", "GY", "ECMK07", and "ECMrest" are supported. '
                             'Among-site rate heterogeneity and codon frequencies can be specified. '
                             'See here for details: http://www.iqtree.org/doc/Substitution-Models')
    psr_iq.add_argument('--iqtree_redo', metavar='yes|no', default='no', type=strtobool,
                        help='default=%(default)s: Whether to rerun IQ-TREE even if all intermediate files exist.')
    psr_iq.add_argument('--iqtree_treefile', metavar='PATH', default='infer', type=str, required=False,
                        help='default=%(default)s: PATH to the IQ-TREE\'s .treefile output. "infer" from --alignment_file')
    psr_iq.add_argument('--iqtree_state', metavar='PATH', default='infer', type=str, required=False,
                        help='default=%(default)s: PATH to the IQ-TREE\'s .state output. "infer" from --alignment_file')
    psr_iq.add_argument('--iqtree_rate', metavar='PATH', default='infer', type=str, required=False,
                        help='default=%(default)s: PATH to the IQ-TREE\'s .rate output. "infer" from --alignment_file')
    psr_iq.add_argument('--iqtree_iqtree', metavar='PATH', default='infer', type=str, required=False,
                        help='default=%(default)s: PATH to the IQ-TREE\'s .iqtree output. "infer" from --alignment_file')
    psr_iq.add_argument('--iqtree_log', metavar='PATH', default='infer', type=str, required=False,
                        help='default=%(default)s: PATH to the IQ-TREE\'s .log output. "infer" from --alignment_file')

    # shared: Ancestral_state
    psr_as = argparse.ArgumentParser(add_help=False)
    psr_as.add_argument('--ml_anc', metavar='yes|no', default='no', type=strtobool,
                        help='default=%(default)s: Maximum-likelihood-like analysis by binarizing ancestral states.')
    psr_as.add_argument('--min_sub_pp', metavar='FLOAT', default=0, type=float,
                        help='default=%(default)s: The minimum posterior probability of single substitutions to count. '
                             'Set 0 for a counting without binarization. Omitted if --ml_anc is set to "yes".')

    # shared: PhyloBayes inputs, but PhyloBayes is no longer supported.
    # The empty parser is kept because it is still listed as a parent below.
    psr_pb = argparse.ArgumentParser(add_help=False)
    #psr_pb.add_argument('--phylobayes_dir', metavar='PATH', default='./', type=str, required=False,
    #                   help='default=%(default)s: PATH to the PhyloBayes output directory.')

    # shared: foreground
    psr_fg = argparse.ArgumentParser(add_help=False)
    psr_fg.add_argument('--foreground', metavar='PATH', default=None, type=str, required=False,
                        help='default=%(default)s: A text file to specify the foreground lineages. '
                             'The file should contain two columns separated by a tab: '
                             '1st column for lineage IDs and 2nd for regex-compatible leaf names. '
                             'See https://github.com/kfuku52/csubst/wiki/Foreground-specification')
    psr_fg.add_argument('--fg_format', metavar='1|2', default=1, type=int, choices=[1,2],
                        help='default=%(default)s: Table format of --foreground.')
    psr_fg.add_argument('--fg_exclude_wg', metavar='yes|no', default='yes', type=strtobool,
                        help='default=%(default)s: Set "yes" to exclude branch combinations '
                             'within individual foreground lineages.')
    psr_fg.add_argument('--fg_stem_only', metavar='yes|no', default='yes', type=strtobool,
                        help='default=%(default)s: Set "yes" to exclude non-stem branches of foreground lineages.')
    psr_fg.add_argument('--mg_parent', metavar='yes|no', default='no', type=strtobool,
                        help='default=%(default)s: Mark the parent branches of the foreground stem branches as "marginal". '
                             'They may serve as "negative controls" relative to the foreground lineages.')
    psr_fg.add_argument('--mg_sister', metavar='yes|no', default='no', type=strtobool,
                        help='default=%(default)s: Mark the sister branches of the foreground stem branches as "marginal". '
                             'They may serve as "negative controls" relative to the foreground lineages.')
    psr_fg.add_argument('--mg_sister_stem_only', metavar='yes|no', default='yes', type=strtobool,
                        help='default=%(default)s: Set "yes" to exclude non-stem branches of sister lineages.')
    psr_fg.add_argument('--fg_clade_permutation', metavar='INT', default=0, type=int,
                        help='default=%(default)s: Experimental. Randomly select the same/similar number and size of clades as foreground '
                             'and run analysis N times to obtain a permutation-based P value of convergence. '
                             'At least 1000 is recommended.')
    psr_fg.add_argument('--min_clade_bin_count', metavar='INT', default=10, type=int,
                        help='default=%(default)s: Experimental. Minimum number of branches per bin for foreground clade permutation. ')

    # dataset
    help_txt = 'generates out-of-the-box test datasets. See `csubst dataset -h`'
    dataset = subparsers.add_parser('dataset', help=help_txt, parents=[])
    dataset.add_argument('--name', metavar='STR', default='PGK', type=str, choices=['PGK','PEPC'],
                         help='default=%(default)s: Name of dataset to generate.')
    dataset.set_defaults(handler=command_dataset)

    # simulate
    help_txt = 'generates a simulated sequence alignment under a convergent evolutionary scenario. See `csubst simulate -h`'
    simulate = subparsers.add_parser('simulate', help=help_txt, parents=[psr_co,psr_iq,psr_pb,psr_fg])
    simulate.add_argument('--background_omega', metavar='FLOAT', default=0.2, type=float,
                          help='default=%(default)s: dN/dS for background branches.')
    simulate.add_argument('--foreground_omega', metavar='FLOAT', default=0.2, type=float,
                          help='default=%(default)s: dN/dS for foreground branches.')
    simulate.add_argument('--num_simulated_site', metavar='INT', default=-1, type=int,
                          help='default=%(default)s: Number of codon sites to simulate. '
                               '-1 to set the size of input alignment.')
    # The first fragment ends with a space so the concatenated sentences do
    # not run together in the rendered help.
    simulate.add_argument('--percent_convergent_site', metavar='FLOAT', default=100, type=float,
                          help='default=%(default)s: Percentage of codon sites to evolve convergently. '
                               'If --convergent_amino_acids randomN, '
                               'convergent amino acids are randomly selected within each partition. ')
    simulate.add_argument('--optimized_branch_length', metavar='yes|no', default='no', type=strtobool,
                          help='default=%(default)s: Whether to use the branch lengths optimized by IQ-TREE. '
                               'If "no", the branch lengths in the input tree are used.')
    simulate.add_argument('--tree_scaling_factor', metavar='FLOAT', default=1, type=float,
                          help='default=%(default)s: Branch lengths are multiplied by this value.')
    simulate.add_argument('--foreground_scaling_factor', metavar='FLOAT', default=1, type=float,
                          help='default=%(default)s: In the codon sites specified by --percent_convergent_site, '
                               'branch lengths in foreground lineages are multiplied by this value.')
    simulate.add_argument('--convergent_amino_acids', metavar='STR', default='random1', type=str,
                          help='default=%(default)s: Non-delimited list of amino acids the sequences converge into. '
                               'e.g, AQTS, ACQ, WDETS... '
                               '"randomN" specifies randomly selected N amino acids. '
                               '"random0" does not cause convergence.')
    simulate.add_argument('--percent_biased_sub', metavar='FLOAT', default=90, type=float,
                          help='default=%(default)s: Approximately this percentage of nonsynonymous substitutions '
                               'in the foreground branches/sites are biased toward amino acids specified by '
                               '--convergent_amino_acids, while preserving the original relative codon frequencies '
                               'among the synonymous codons.')
    simulate.set_defaults(handler=command_simulate)


    # site
    help_txt = 'calculates site-wise combinatorial substitutions on focal branch combinations and maps it onto protein structure. '
    help_txt += ' See `csubst site -h`'
    site = subparsers.add_parser('site', help=help_txt, parents=[psr_co,psr_iq,psr_as])
    site.add_argument('--branch_id', metavar='fg|INT,INT,INT,...', default=None, required=True, type=str,
                      help='default=%(default)s: Comma-delimited list of branch_ids to characterize. '
                           'Run `csubst analyze` first and select branches of interest. '
                           'If "fg", all foreground branch pairs will be analyzed.')
    site.add_argument('--mode', metavar='intersection|total', default='intersection', required=False, type=str,
                      choices=['intersection','total',],
                      help='default=%(default)s: Visualization mode. NOT SUPPORTED YET.')
    site.add_argument('--cb_file', metavar='PATH', default='csubst_cb_2.tsv', required=False, type=str,
                      help='default=%(default)s: PATH to csubst_cb_*.tsv output of csubst analyze.')
    site.add_argument('--untrimmed_cds', metavar='PATH', default=None, required=False, type=str,
                      help='default=%(default)s: PATH to fasta file containing untrimmed CDS sequence(s). '
                           'Codon positions along the sequence(s) appear in the output tsv.')
    site.add_argument('--pdb', metavar='PATH/PDB_CODE/besthit', default=None, required=False, type=str,
                      help='default=%(default)s: '
                           'One of the followings. '
                           'PATH: PATH to the downloaded pdb file. '
                           'PDB CODE: e.g. 3SV0. PDB file will be fetched from the database. '
                           'The PDB protein sequence will be aligned with mafft, '
                           'and a pymol session will be generated. '
                           'besthit: Run online sequence search and fetch the best-hit model. See --database.')
    site.add_argument('--database', metavar='DATABASE1,DATABASE2,...', default='pdb,alphafill,alphafold', required=False, type=str,
                      help='default=%(default)s: '
                           'Comma-delimited names of protein structure databases to search. '
                           'The top priority is given to the first. '
                           'pdb: Run online MMseqs2 search against the RCSB PDB database (https://www.rcsb.org/) and fetch the best-hit PDB model. '
                           'alphafill: Run online QBLAST search against the UniProt database (https://www.uniprot.org/) and fetch the best-hit AlphaFill model (https://alphafill.eu/). '
                           'alphafold: Run online QBLAST search against the UniProt database (https://www.uniprot.org/) and fetch the best-hit AlphaFold model (https://alphafold.ebi.ac.uk/).')
    site.add_argument('--database_evalue_cutoff', metavar='FLOAT', default=1.0, required=False, type=float,
                      help='default=%(default)s: E-value cutoff in the database searches. Applied to MMseqs2 and QBLAST.')
    site.add_argument('--database_minimum_identity', metavar='FLOAT', default=0.25, required=False, type=float,
                      help='default=%(default)s: The minimum sequence identity for the database searches. Applied to MMseqs2. See https://search.rcsb.org/index.html#search-api')
    site.add_argument('--user_alignment', metavar='PATH', default=None, required=False, type=str,
                      help='default=%(default)s: The user-provided alignment FASTA for the substitution mapping to protein structures.')
    site.add_argument('--remove_solvent', metavar='yes|no', default='yes', type=strtobool,
                      help='default=%(default)s: Whether to remove solvent and small non-specific ligands. '
                           'Used only with --pdb.')
    site.add_argument('--remove_ligand', metavar='CODE1,CODE2,CODE3,...', default='', type=str,
                      help='default=%(default)s: Comma-delimited list of PDB ligand codes to be removed. '
                           'e.g., "so4,po4,bme". '
                           'Used only with --pdb.')
    site.add_argument('--mask_subunit', metavar='yes|no', default='yes', type=strtobool,
                      help='default=%(default)s: Whether to mask unrelated subunits. '
                           'Used only with --pdb.')
    site.add_argument('--mafft_exe', metavar='PATH', default='mafft', required=False, type=str,
                      help='default=%(default)s: PATH to mafft executable.')
    site.add_argument('--mafft_op', metavar='FLOAT', default=-1, required=False, type=float,
                      help='default=%(default)s: mafft --op parameter. -1 to use the default value.')
    site.add_argument('--mafft_ep', metavar='FLOAT', default=-1, required=False, type=float,
                      help='default=%(default)s: mafft --ep parameter. -1 to use the default value.')
    site.add_argument('--pymol_min_combinat_prob', metavar='FLOAT', default=0.5, required=False, type=float,
                      help='default=%(default)s: Minimum probability of combinatorial substitutions to visualize.')
    site.add_argument('--pymol_min_single_prob', metavar='FLOAT', default=0.8, required=False, type=float,
                      help='default=%(default)s: Minimum probability of single substitutions to visualize.')
    site.add_argument('--pymol_gray', metavar='INT', default=80, required=False, type=int,
                      help='default=%(default)s: Gray value for no-substitution sites. 0=black, 100=white')
    site.add_argument('--pymol_transparency', metavar='FLOAT', default=0.65, required=False, type=float,
                      help='default=%(default)s: Surface transparency. 0=non-transparent, 1=completely transparent')
    site.add_argument('--pymol_img', metavar='yes|no', default='yes', type=strtobool,
                      help='default=%(default)s: Whether to generate a rendered 6-view image file of protein structure.')
    site.add_argument('--pymol_max_num_chain', metavar='INT', default=20, required=False, type=int,
                      help='default=%(default)s: Maximum number of chains to visualize. '
                           'If the number of chains in the PDB file is larger than this value, '
                           '--pymol_img will be disabled.')
    site.add_argument('--export2chimera', metavar='yes|no', default='no', required=False, type=strtobool,
                      help='default=%(default)s: Set "yes" to export files for the visualization of '
                           'convergence/divergence probabilities with UCSF Chimera. '
                           '--untrimmed_cds is required.')
    site.set_defaults(handler=command_site)


    # analyze
    help_txt = 'calculates convergence rates and other metrics on branch combinations.  See `csubst analyze -h`'
    analyze = subparsers.add_parser('analyze', help=help_txt, parents=[psr_co,psr_iq,psr_pb,psr_fg,psr_as])
    analyze.set_defaults(handler=command_analyze)
    # branch combinations
    analyze.add_argument('--max_arity', metavar='INTEGER', default=2, type=int,
                         help='default=%(default)s: The maximum combinatorial number of branches (K). '
                              'Set 2 for branch pairs. 3 or larger for higher-order combinations.')
    analyze.add_argument('--exhaustive_until', metavar='INTEGER', default=2, type=int,
                         help='default=%(default)s: Analyze all independent branch combinations until specified arity (K). '
                              'Be careful of combinatorial explosion if set to 3 or higher. '
                              'Set to 1 for foreground-only analysis.')
    analyze.add_argument('--max_combination', metavar='INTEGER', default=10000, type=int,
                         help='default=%(default)s: Maximum number of branch combinations to generate at K+1. '
                              'If possible branch combinations at K+1 are more than this number, '
                              'top N convergent combinations are selected with the thresholds specified by --cutoff_stat. '
                              'The first stat in --cutoff_stat is most prioritized.')
    analyze.add_argument('--exclude_sister_pair', metavar='yes|no', default='yes', type=strtobool,
                         help='default=%(default)s: Set to "yes" for excluding sister branches in branch combination analysis.')
    analyze.add_argument('--cutoff_stat', metavar='[STAT1,VALUE1|STAT2,VALUE2|...]',
                         default='OCNany2spe,2.0|omegaCany2spe,5.0', type=str,
                         help='default=%(default)s: Cutoff statistics for searching higher-order branch combinations. '
                              'STAT is a column in csubst_cb_N.tsv. Regular expressions are supported. '
                              'e.g., "N_sub_[0-9]+" to specify the number of branch-wise nonsynonymous substitutions '
                              'in all branches. VALUE is the minimum value, '
                              'and branch combinations with smaller values will be excluded.')
    analyze.add_argument('--branch_dist', metavar='yes|no', default='yes', type=strtobool,
                         help='default=%(default)s: Set "yes" to calculate inter-branch distance.')
    # Substitution outputs
    analyze.add_argument('--b', metavar='yes|no', default='yes', type=strtobool,
                         help='default=%(default)s: Branch output. Set "yes" to generate the output tsv.')
    analyze.add_argument('--s', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: Site output. Set "yes" to generate the output tsv.')
    analyze.add_argument('--cs', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: Combinatorial-site output. Set "yes" to generate the output tsv.')
    analyze.add_argument('--cb', metavar='yes|no', default='yes', type=strtobool,
                         help='default=%(default)s: Combinatorial-branch output. Set "yes" to generate the output tsv.')
    analyze.add_argument('--bs', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: Branch-site output. Set "yes" to generate the output tsv.')
    analyze.add_argument('--cbs', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: Combinatorial-branch-site output. Set "yes" to generate the output tsv.')
    # Plot outputs
    analyze.add_argument('--more_tree_plot', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: More tree plots to generate.')
    analyze.add_argument('--plot_state_aa', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: Tree plots with per-site ancestral amino acid states. '
                              'This option will generate many pdfs')
    analyze.add_argument('--plot_state_codon', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: Tree plots with per-site ancestral codon states. '
                              'This option will generate many pdfs')
    # Omega_C calculation
    analyze.add_argument('--omegaC_method', metavar='[submodel|modelfree]', default='submodel', type=str,
                         choices=['modelfree','submodel'],
                         help='default=%(default)s: Method to calculate omega_C. '
                              '"submodel" utilizes a codon substitution model in the ancestral state reconstruction. '
                              'In addition to the base substitution models, codon frequencies and '
                              'among-site rate heterogeneity are taken into account. '
                              'Described in Fukushima and Pollock (2023, https://doi.org/10.1038/s41559-022-01932-7). '
                              '"modelfree" (experimental) for expected values from among-site randomization '
                              '(urn sampling) of substitutions. ')
    analyze.add_argument('--calc_quantile', metavar='yes|no', default='no', type=strtobool,
                         help='default=%(default)s: Experimental. Calculate resampling-based quantiles in '
                              'Wallenius\' noncentral hypergeometric distribution for combinatorial substitutions. '
                              'This option should be used with --omegaC_method "modelfree".')
    analyze.add_argument('--asrv', metavar='no|pool|sn|each|file', default='each', type=str,
                         choices=['no', 'pool', 'sn', 'each', 'file'],
                         help='default=%(default)s: Experimental. Correct among-site rate variation in omega/quantile calculation. '
                              'This option is used in --omegaC_method modelfree but not with --omegaC_method submodel. '
                              '"no", No ASRV, meaning a uniform rate among sites. '
                              '"pool", All categories of substitutions are pooled to calculate a single set of ASRV. '
                              '"sn", Synonymous and nonsynonymous substitutions are processed individually '
                              'to calculate their respective ASRVs (2 sets). '
                              '"each", Each of 61x60 patterns of substitutions are processed individually '
                              'to calculate their respective ASRVs. '
                              '"file", ASRV is obtained from the IQ-TREE\'s .rate file (1 set). ')
    analyze.add_argument('--calibrate_longtail', metavar='yes|no', default='yes', type=strtobool,
                         help='default=%(default)s: Calibrate dS_C to match the distribution range of dS_C with dN_C '
                              'by quantile-based transformation.')

    # Handler: each subparser registers its command_* function as `handler`;
    # without a subcommand the attribute is absent and the help is shown.
    args = parser.parse_args()
    #param.set_num_thread_variables(num_thread=args.threads)
    if hasattr(args, 'handler'):
        args.handler(args)
    else:
        parser.print_help()

    # End time
    # The elapsed time is passed as a float: wrapping it in int() made the
    # '.1f' format always print a trailing '.0'.
    txt = '\nCSUBST end: {}, Elapsed time = {:,.1f} sec'
    print(txt.format(datetime.datetime.now(datetime.timezone.utc), time.time() - csubst_start), flush=True)
