#!/opt/conda/conda-bld/epic2_1764599019127/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/bin/python

from __future__ import print_function

from sys import argv
import sys
import operator
from collections import defaultdict

import numpy as np
from scipy.stats import poisson, rankdata, spearmanr, pearsonr

from epic2.main import _main

from epic2.version import __version__
from epic2.src.find_islands import differential_count_reads_on_islands

from collections import OrderedDict
from natsort import natsorted
import argparse
import os

import pyranges as pr
import pandas as pd

# The input/output options are only made mandatory when none of the
# "informational" flags (version / example) is present on the command line.
required_args = not any(
    flag in sys.argv for flag in ("--version", "-v", "-ex", "--example"))

# Command-line interface of epic2-df.  Options are registered in the order
# they should be listed by --help.
parser = argparse.ArgumentParser(
    prog=os.path.basename(__file__),
    description="""epic2-df, version: {}
(Visit github.com/endrebak/epic2 for examples and help. Run epic2-example for a simple example command.)
    """.format(__version__))

# --- Input libraries (knockout / wildtype, treatment / control) -------------
parser.add_argument(
    "--treatment-knockout", "-tk",
    type=str, nargs="+", required=required_args,
    help='''Treatment (pull-down) file(s) for knockout in one of these formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. Mixing file formats is allowed.''')

parser.add_argument(
    "--control-knockout", "-ck",
    type=str, nargs="+", required=False,
    help='''Control (input) file(s) for knockout in one of these formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. Mixing file formats is allowed.''')

parser.add_argument(
    "--treatment-wildtype", "-tw",
    type=str, nargs="+", required=required_args,
    help='''Treatment (pull-down) file(s) for wildtype in one of these formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. Mixing file formats is allowed.''')

parser.add_argument(
    "--control-wildtype", "-cw",
    type=str, nargs="+", required=False,
    help='''Control (input) file(s) for wildtype in one of these formats: bed, bedpe, bed.gz, bedpe.gz or (single-end) bam, sam. Mixing file formats is allowed.''')

# --- Genome and read handling -------------------------------------------------
parser.add_argument(
    "--genome", "-gn",
    type=str, default="hg19", required=False,
    help='''Which genome to analyze. Default: hg19. If --chromsizes and --egf flag is given, --genome is not required.''')

parser.add_argument(
    "--keep-duplicates", "-kd",
    action="store_true", default=False, required=False,
    help='''Keep reads mapping to the same position on the same strand within a library. Default: False.
                   ''')

parser.add_argument(
    "--original-algorithm", "-oa",
    action="store_true", default=False, required=False,
    help='''Use the original SICER algorithm, without the epic2 fix. This will use all reads in your files to compute the p-values, including those falling outside the genome boundaries.''')

# --- Island-calling parameters ------------------------------------------------
parser.add_argument(
    "--bin-size", "-bin",
    type=int, default=200, required=False,
    help='''Size of the windows to scan the genome. BIN-SIZE is the smallest possible island. Default 200.
                   ''')

parser.add_argument(
    "--gaps-allowed", "-g",
    type=int, default=3, required=False,
    help='''This number is multiplied by the window size to determine the number of gaps
                   (ineligible windows) allowed between two eligible windows.
                   Must be an integer. Default: 3. ''')

parser.add_argument(
    "--fragment-size", "-fs",
    type=int, default=150, required=False,
    help='''(Single end reads only) Size of the sequenced fragment. Each read is extended half the fragment size from the 5' end. Default 150 (i.e. extend by 75).''')

parser.add_argument(
    "--false-discovery-rate-cutoff", "-fdr",
    type=float, default=0.05, required=False,
    help='''Remove all islands with an FDR above cutoff. Default 0.05.
                   ''')

parser.add_argument(
    "--false-discovery-rate-comparison", "-fdrc",
    type=float, default=0.05, required=False,
    help='''Remove all merged islands with an FDR above this value for both up- and downregulation. Default 0.05.''')

parser.add_argument(
    "--effective-genome-fraction", "-egf",
    type=float, required=False,
    help='''Use a different effective genome fraction than the one included in epic2. The default value depends on the genome and readlength, but is a number between 0 and 1.''')

parser.add_argument(
    "--chromsizes", "-cs",
    type=str, required=False,
    help='''Set the chromosome lengths yourself in a file with two columns: chromosome names and sizes. Useful to analyze custom genomes, assemblies or simulated data. Only chromosomes included in the file will be analyzed.''')

parser.add_argument(
    "--e-value", "-e",
    type=int, default=1000, required=False,
    help='''The E-value controls the genome-wide error rate of identified islands under the random background assumption. Should be used when not using a control library. Default: 1000.''')

# --- bam-only read filters ----------------------------------------------------
parser.add_argument(
    "--required-flag", "-f",
    type=int, default=0, required=False,
    help='''(bam only.) Keep reads with these bits set in flag. Same as `samtools view -f`. Default 0
                   ''')

parser.add_argument(
    "--filter-flag", "-F",
    type=int, default=1540, required=False,
    help='''(bam only.) Discard reads with these bits set in flag. Same as `samtools view -F`. Default 1540 (hex: 0x604).
See https://broadinstitute.github.io/picard/explain-flags.html for more info.
                   ''')

parser.add_argument(
    "--mapq", "-m",
    type=int, default=5, required=False,
    help='''(bam only.) Discard reads with mapping quality lower than this. Default 5.
                   ''')

parser.add_argument(
    "--autodetect-chroms", "-a",
    action="store_true", default=False, required=False,
    help='''(bam only.) Autodetect chromosomes from bam file. Use with --discard-chromosomes flag to avoid non-canonical chromosomes.''')

parser.add_argument(
    "--discard-chromosomes-pattern", "-d",
    type=str, default="_", required=False,
    help='''(bam only.) Discard reads from chromosomes matching this pattern. Default
                   '_'. Note that if you are not interested in the results from
                   non-canonical chromosomes, you should ensure they are
                   removed with this flag, otherwise they will make the
                   statistical analysis too stringent.''')

# --- Statistics, output files and miscellaneous switches ----------------------
parser.add_argument(
    "--original-statistics",
    action="store_true", required=False,
    help='''Use the original SICER way of computing the statistics. Like SICER itself, this method raises an error on large datasets. Only included for debugging-purposes.''')

parser.add_argument(
    "--output-knockout", "-ok",
    type=str, required=required_args,
    help='''File to write knockout results to.''')

parser.add_argument(
    "--output-wildtype", "-ow",
    type=str, required=required_args,
    help='''File to write wildtype results to.''')

parser.add_argument(
    "--guess-bampe",
    action="store_true", default=False, required=False,
    help='''Autodetect bampe file format based on flags from the first 100 reads. If all of them are paired, then the format is bampe. Only properly paired reads are processed by default (0x1 and 0x2 samtools flags).''')

parser.add_argument(
    "--quiet", "-q",
    action="store_true", default=False, required=False,
    help='''Do not write output messages to stderr.''')

parser.add_argument(
    "--example", "-ex",
    action="store_true", default=False, required=False,
    help='''Show the paths of the example data and an example command.''')

parser.add_argument("--version", "-v", action="version", version=__version__)

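# Scalar (one island at a time) reference version of the vectorised `_pvalue`
# defined below; kept commented out only to document the intended logic.
# def _pvalue(chip_read_count, control_read_count, scaling_factor):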
#     if control_read_count > 0:
#         average = control_read_count * scaling_factor
#     else:
#         average = 1 * scaling_factor

#     if chip_read_count > average:
#         pvalue = poisson.sf(chip_read_count, average)
#     else:
#         pvalue = 1

#     return pvalue


def _pvalue(chip, control, scaling_factor):
    """Return one-sided Poisson p-values for enrichment of chip over control.

    For every position the expected count is ``control * scaling_factor``;
    a control count that is not positive is replaced by a pseudo-count of
    one, so the expectation becomes ``scaling_factor``.  When ``chip``
    exceeds that expectation the p-value is ``poisson.sf(chip, expected)``,
    otherwise it is 1.

    The computation is positional, so the inputs may carry any index (or be
    plain arrays) as long as they have the same length.

    Parameters
    ----------
    chip : pandas.Series or array-like
        Read counts of the library tested for enrichment.
    control : pandas.Series or array-like
        Read counts of the library used as background, same length as chip.
    scaling_factor : float
        Factor that rescales control counts to the chip library size.

    Returns
    -------
    pandas.Series
        Float p-values.  Carries the index of ``chip`` when ``chip`` is a
        Series, otherwise a default ``RangeIndex``.
    """
    index = chip.index if isinstance(chip, pd.Series) else None

    # Work on plain arrays so that pandas label alignment cannot interfere.
    chip_counts = np.asarray(chip, dtype=float)
    control_counts = np.asarray(control, dtype=float)

    # Expected chip count under the null hypothesis.
    average = np.where(control_counts > 0, control_counts, 1.0) * scaling_factor

    pvalues = np.ones(chip_counts.shape, dtype=float)
    enriched = chip_counts > average
    if enriched.any():
        pvalues[enriched] = poisson.sf(chip_counts[enriched],
                                       average[enriched])

    return pd.Series(pvalues, index=index)


if __name__ == "__main__":
    # Script entry point: run epic2 once for the knockout and once for the
    # wildtype libraries, then compare the read counts of both conditions on
    # the union of the islands that were found.

    args = vars(parser.parse_args())
    # Extra key handed to `_main`.
    # NOTE(review): presumably this makes `_main` return the per-bin counts
    # used below -- confirm in epic2.main.
    args["df"] = True

    if args["example"]:
        # Only report where the bundled example data lives, then quit.
        import pkg_resources
        treatment = pkg_resources.resource_filename("epic2",
                                                    "examples/test.bed.gz")
        control = pkg_resources.resource_filename("epic2",
                                                  "examples/control.bed.gz")
        print("Knockout: " + treatment)
        print("Wildtype: " + control)
        print(
            "Example command: epic2-df -fdrc 1 -tk {} -tw {} -ok deleteme_ko.txt -ow deleteme_wt.txt > deleteme.txt"
            .format(treatment, control))
        sys.exit(0)

    # Point the generic treatment/control/output keys at the knockout files
    # before handing the arguments to `_main` ...
    args["treatment"] = args["treatment_knockout"]
    args["control"] = args["control_knockout"]
    args["output"] = args["output_knockout"]

    print("Running epic2 on KO.", file=sys.stderr)
    bins_counts_ko = _main(args)

    # ... and at the wildtype files for the second run.
    args["treatment"] = args["treatment_wildtype"]
    args["control"] = args["control_wildtype"]
    args["output"] = args["output_wildtype"]

    print("Running epic2 on WT.", file=sys.stderr)
    bins_counts_wt = _main(args)

    print("Comparing islands.", file=sys.stderr)

    # Total counts per condition; each value of the returned mappings is a
    # pair whose second element holds the counts.
    counts_wildtype = sum(sum(counts) for _, counts in bins_counts_wt.values())
    counts_knockout = sum(sum(counts) for _, counts in bins_counts_ko.values())

    # Read back the first five columns of the island files written by the two
    # runs above (the header line is skipped and replaced by `names`).
    names = ["Chromosome", "Start", "End", "ChIPCount", "Score"]
    read_kwargs = dict(
        sep="\t",
        names=names,
        usecols=[0, 1, 2, 3, 4],
        skiprows=1,
        dtype={"Chromosome": "category", "Start": int, "End": int},
        header=None)

    ko = pr.PyRanges(pd.read_csv(args["output_knockout"], **read_kwargs))
    wt = pr.PyRanges(pd.read_csv(args["output_wildtype"], **read_kwargs))

    # Merge the islands of both conditions and count reads of each condition
    # on the merged intervals.
    union = ko.set_union(wt, strandedness=False)

    ko_counts = differential_count_reads_on_islands(union.dfs, bins_counts_ko)
    wt_counts = differential_count_reads_on_islands(union.dfs, bins_counts_wt)

    # Concatenate the per-key arrays in natural sort order of their keys.
    ko_counts = pd.concat(
        [pd.Series(arr) for _, arr in natsorted(ko_counts.items())])
    wt_counts = pd.concat(
        [pd.Series(arr) for _, arr in natsorted(wt_counts.items())])

    union.KO = ko_counts
    union.WT = wt_counts

    df = union.df.reset_index(drop=True)

    # Library-size ratio KO/WT: multiplying a WT count by it gives the count
    # expected in the KO library if there were no difference.
    scaling_factor = counts_knockout / counts_wildtype

    # Fold changes with a pseudo-count of one, normalised by library size.
    fc_ko = ((df.KO + 1) / (df.WT + 1)) / scaling_factor
    fc_ko.name = "FC_KO"
    fc_wt = ((df.WT + 1) / (df.KO + 1)) * scaling_factor
    fc_wt.name = "FC_WT"

    p_ko = _pvalue(df.KO, df.WT, scaling_factor)
    p_ko.name = "P_KO"
    # Testing WT against KO needs the inverse ratio (WT/KO) to rescale the KO
    # counts to the WT library size, exactly as fc_wt uses the inverse of the
    # factor applied in fc_ko.
    p_wt = _pvalue(df.WT, df.KO, 1 / scaling_factor)
    p_wt.name = "P_WT"

    # FDR estimate: p * n / rank(p), capped at 1.
    fdr_ko = p_ko * len(p_ko) / rankdata(p_ko)
    fdr_wt = p_wt * len(p_wt) / rankdata(p_wt)

    fdr_ko[fdr_ko > 1] = 1
    fdr_ko.name = "FDR_KO"
    fdr_wt[fdr_wt > 1] = 1
    fdr_wt.name = "FDR_WT"

    # Append the statistics as new right-most columns.
    for c in [fc_ko, fc_wt, p_ko, p_wt, fdr_ko, fdr_wt]:
        df.insert(df.shape[1], c.name, c)

    # Correlations are computed on all merged islands, before FDR filtering.
    p = pearsonr(df.KO, df.WT)
    s = spearmanr(df.KO, df.WT)

    # Keep islands that are significant in at least one direction.
    diff_fdr = args["false_discovery_rate_comparison"]
    df = df[(df.FDR_KO <= diff_fdr) | (df.FDR_WT <= diff_fdr)]

    # Result table goes to stdout, diagnostics to stderr.
    print(df.to_csv(sep="\t", index=False), end="")

    print(
        "Pearson's correlation is:",
        p[0],
        "and the p-value is",
        p[1],
        file=sys.stderr)
    print(
        "Spearman's correlation is:",
        s[0],
        "and the p-value is",
        s[1],
        file=sys.stderr)
