#!/usr/bin/env python
"""MAGeCK file operation module
Copyright (c) 2014 Wei Li, Han Xu, Xiaole Liu lab 
This code is free software; you can redistribute it and/or modify it
under the terms of the BSD License (see the file COPYING included with
the distribution).
@status:  experimental
@version: $Revision$
@author:  Wei Li 
@contact: li.david.wei AT gmail.com
"""

from __future__ import print_function
import sys
import math
import logging
import subprocess

  
def systemcall(command, cmsg=True):
  """
  Run a command through the shell and return everything it printed
  Parameters:
    command
        The command line, as one string. It is handed to the shell unchanged (shell=True), so it must come from a trusted source
    cmsg
        If True, echo the captured output to the log (INFO level), one log record per line
  Return:
    The captured stdout and stderr (stderr is merged into stdout), decoded as UTF-8
  """
  logging.info('Running command: '+command)
  proc=subprocess.Popen(command,stdout=subprocess.PIPE,stderr=subprocess.STDOUT,shell=True)
  # communicate() waits for the process to finish; the stderr slot is None because stderr is redirected into stdout
  (rawoutput,_)=proc.communicate()
  output=rawoutput.decode("utf-8")
  if not cmsg:
    return output
  logging.info('Command message:')
  for msgline in output.split('\n'):
    logging.info('  '+msgline)
  logging.info('End command message.')
  return output


def merge_rank_summary_files(lowfile,highfile,outfile,args,lowfile_prefix='',highfile_prefix=''):
  """
  Merge two rank summary files into one table, joined on the id in the first column
  Parameters:
    lowfile
        The first summary file. Its row order defines the row order of the output
    highfile
        The second summary file. Its columns (from the third one on) are appended to the matching row of lowfile
    outfile
        The output file name
    args
        arguments (not used by this function; kept for interface compatibility)
    lowfile_prefix
        A string prepended to the header names of lowfile, starting from the third column
    highfile_prefix
        A string prepended to the header names of highfile, starting from the third column
  Both input files are whitespace-separated, start with a header line, and need at least 4 fields
  per record, the second one being an integer. A record with fewer fields is logged as an error
  and terminates the program with sys.exit(-1).
  Only ids present in both files are written; ids found in just one file are reported with a warning.
  """
  gfile={} # id -> all fields after the id column
  lowfileorder=[]
  lowfileheader=[]
  with open(lowfile) as lowfhd:
    for (nline,line) in enumerate(lowfhd,1):
      field=line.strip().split()
      if nline==1: # the header line: prefix every column name after 'id' and 'num'
        lowfileheader=field[:2]+[lowfile_prefix+t for t in field[2:]]
        continue
      if len(field)<4:
        logging.error('The number of fields in file '+lowfile+' is <4.')
        sys.exit(-1)
      gid=field[0]
      int(field[1]) # validation only: a non-integer item count raises ValueError here
      gfile[gid]=field[1:]
      lowfileorder.append(gid)
  highfileheader=[]
  # ids that received their highfile columns. Tracking them explicitly (instead of comparing
  # row lengths) keeps the join correct even when the item counts of the two files never agree.
  joined=set()
  with open(highfile) as highfhd:
    for (nline,line) in enumerate(highfhd,1):
      field=line.strip().split()
      if nline==1: # the header line
        highfileheader=field[:2]+[highfile_prefix+t for t in field[2:]]
        continue
      if len(field)<4:
        logging.error('The number of fields in file '+highfile+' is <4.')
        sys.exit(-1)
      gid=field[0]
      # validation only. Under mageck0.5.7 the item counts of the two files may legitimately
      # differ, so a mismatch with the count in lowfile is deliberately not reported.
      int(field[1])
      if gid not in gfile:
        logging.warning('Item '+gid+' appears in '+highfile+', but not in '+lowfile+'. This record will be omitted.')
        continue
      if gid in joined:
        # appending a second time would give this row more columns than the header
        logging.warning('Item '+gid+' appears more than once in '+highfile+'. Only the first record is used.')
        continue
      gfile[gid]+=field[2:] # don't repeat the item count
      joined.add(gid)
  # report the items that appear in the first file, but not in the second file
  for k in gfile:
    if k not in joined:
      logging.warning('Item '+k+' appears in '+lowfile+', but not in '+highfile+'.')
  # write to file, following the record order of lowfile
  with open(outfile,'w') as ofhd:
    print('\t'.join(lowfileheader)+'\t'+'\t'.join(highfileheader[2:]),file=ofhd)
    for k in lowfileorder:
      if k in joined:
        print('\t'.join([k]+[str(t) for t in gfile[k]]),file=ofhd)
 
class Rank_Obj:
  """
  One record of the rank file generated by RRA
  The values below are class-level defaults; merge_rank_files overrides them per instance.
  """
  # item id (first column of the rank file)
  name=""
  # integer fields: item count (second column), count from the sixth column, 1-based position in the file
  sgrna=goodsgrna=rank=0
  # numeric fields: score (third column), p value, FDR, log fold change
  lo=pval=fdr=lfc=0.0
  # set to True by merge_rank_files when the FDR column cannot be parsed as a number
  isbad=False

def _read_rra_rank_file(filename,gene_lfc):
  """
  Parse one RRA output file into a list of Rank_Obj records
  Parameters:
    filename
        The RRA output file; whitespace-separated, the first line is a header and is skipped
    gene_lfc
        {gene:lfc}; a gene missing from it gets lfc 0.0
  Return:
    (records, nline): the records in file order, and the number of lines read (header included)
  """
  records=[]
  nline=0
  with open(filename) as fhd:
    for line in fhd:
      nline+=1
      if nline==1: # skip the first line
        continue
      field=line.strip().split()
      # columns 0-5 are all read below, so require six of them up front
      # instead of failing later with an IndexError
      if len(field)<6:
        logging.error('The number of fields in file '+filename+' is <6.')
        sys.exit(-1)
      r_o=Rank_Obj()
      r_o.name=field[0]
      r_o.sgrna=int(field[1])
      r_o.lo=float(field[2])
      r_o.pval=float(field[3])
      r_o.rank=nline-1 # 1-based position among the data lines
      try:
        r_o.fdr=float(field[4])
      except ValueError:
        # a non-numeric FDR: keep the record, but exclude it from FDR re-adjustment
        r_o.fdr='NA'
        r_o.isbad=True
      r_o.goodsgrna=int(field[5])
      if r_o.name in gene_lfc:
        r_o.lfc="{:.5g}".format(gene_lfc[r_o.name])
      else:
        r_o.lfc=0.0
      records.append(r_o)
  return (records,nline)

def _missing_rank_obj(known,rank):
  """
  Build the placeholder record for an item that is absent from one of the two files
  The id and item count are copied from the record found in the other file, so the output row
  is still labelled. Score, p value and FDR are set to 1.0 (no evidence of selection).
  """
  r_o2=Rank_Obj()
  r_o2.name=known.name
  r_o2.sgrna=known.sgrna
  r_o2.lo=1.0
  r_o2.pval=1.0
  r_o2.fdr=1.0
  r_o2.rank=rank
  return r_o2

def merge_rank_files(lowfile,highfile,outfile,args,cutoffinfo):
  """
  Merge neg. and pos. selected files (generated by RRA) into one
  Parameters:
    lowfile
        RRA neg. selection output
    highfile
        RRA pos. selection output
    outfile
        The output file name
    args
        arguments. Two optional attributes are read: sort_criteria ('pos' sorts the output by pos. selection rank, anything else by neg. selection rank) and adjust_method (anything other than 'fdr' recomputes the FDR columns with mageck.fdr_calculation.pFDR)
    cutoffinfo
        The return value of crispr_test. Include (low_p_threshold, high_p_threshold, lower_gene_lfc,higher_gene_lfc), where lower_gene_lfc={gene:lfc} is the log fold change of sgRNAs
  An input record with fewer than 6 fields is logged as an error and terminates the program with sys.exit(-1).
  An item found in only one of the two files is reported with a warning and still written; its
  missing half is filled with a placeholder (score, p value and FDR of 1.0, ranked last).
  """
  lower_gene_lfc=cutoffinfo[2]
  higher_gene_lfc=cutoffinfo[3]
  # gfile: id -> [neg. selection record, pos. selection record]
  gfile={}
  (lowrecords,maxnline)=_read_rra_rank_file(lowfile,lower_gene_lfc)
  for r_o in lowrecords:
    gfile[r_o.name]=[r_o]
  (highrecords,_)=_read_rra_rank_file(highfile,higher_gene_lfc)
  for r_o in highrecords:
    if r_o.name not in gfile:
      logging.warning('Item '+r_o.name+' appears in '+highfile+', but not in '+lowfile+'.')
      gfile[r_o.name]=[_missing_rank_obj(r_o,maxnline)]
    elif gfile[r_o.name][0].sgrna!=r_o.sgrna:
      logging.warning('Item number of '+r_o.name+' does not match previous file: '+str(r_o.sgrna)+' !='+str(gfile[r_o.name][0].sgrna)+'.')
    gfile[r_o.name]+=[r_o]
  # check whether some items appear in the first group, but not in the second group
  for v in gfile.values():
    if len(v)==1:
      logging.warning('Item '+v[0].name+' appears in '+lowfile+', but not in '+highfile+'.')
      v.append(_missing_rank_obj(v[0],maxnline))
  # order the rows by the rank of the requested selection direction
  if hasattr(args,'sort_criteria') and args.sort_criteria=='pos':
    logging.debug('Sorting the merged items by positive selection...')
    skey=sorted(gfile.items(),key=lambda x : x[1][1].rank)
  else:
    logging.debug('Sorting the merged items by negative selection...')
    skey=sorted(gfile.items(),key=lambda x : x[1][0].rank)
  # correct FDR method from RRA
  if hasattr(args,'adjust_method') and args.adjust_method!='fdr':
    from mageck.fdr_calculation import pFDR
    logging.debug('adjusting fdr using '+args.adjust_method+' method ...')
    for side in (0,1): # 0: neg. selection, 1: pos. selection
      goodobjs=[t[1][side] for t in skey if not t[1][side].isbad]
      adjusted=pFDR([obj.pval for obj in goodobjs],method=args.adjust_method)
      for (ind,obj) in enumerate(goodobjs):
        obj.fdr=adjusted[ind]
      # records whose FDR could not be parsed took no part in the adjustment
      for t in skey:
        if t[1][side].isbad:
          t[1][side].fdr='NA'
  # write to file
  with open(outfile,'w') as ofhd:
    print('\t'.join(['id','num','neg|score','neg|p-value','neg|fdr','neg|rank','neg|goodsgrna', 'neg|lfc', 'pos|score','pos|p-value','pos|fdr','pos|rank','pos|goodsgrna','pos|lfc']),file=ofhd)
    for (gid,objs) in skey:
      negobj=objs[0]
      posobj=objs[1]
      rowfields=[gid,negobj.sgrna]
      for obj in (negobj,posobj):
        rowfields+=[obj.lo,obj.pval,obj.fdr,obj.rank,obj.goodsgrna,obj.lfc]
      print('\t'.join([str(x) for x in rowfields]),file=ofhd)
    
   

def parse_sampleids(samplelabel,ids):
  """
  Parse the label id according to the given sample labels
  Parameter:
    samplelabel: a string of labels, like '0,2,3' or 'treat1,treat2,treat3'; None selects every sample in ids
    ids: a {samplelabel:index} ({string:int}); the indices are used as positions in a list of len(ids) entries
  Return:
    (a list of index, a list of index labels)
  If every comma-separated entry parses as an integer, the entries are taken as indices; otherwise
  all of them are taken as label names. An unknown label or an out-of-range index is logged as an
  error and terminates the program with sys.exit(-1).
  """
  # idsk[index] is the label stored at that index
  idsk=[""]*len(ids)
  for (k,v) in ids.items():
    idsk[v]=k
  if samplelabel is None:
    # materialise the keys, so that a list (not a dict view) is returned as documented
    groupidslabel=list(ids.keys())
    groupids=[ids[x] for x in groupidslabel]
    return (groupids,groupidslabel)

  labelparts=samplelabel.split(',')
  try:
    groupids=[int(x) for x in labelparts]
  except ValueError:
    # at least one entry is not an integer: treat all entries as label names
    groupids=[]
    groupidslabel=[]
    for gp in labelparts:
      if gp not in ids:
        logging.error('Sample label '+gp+' does not match records in your count table.')
        logging.error('Sample labels in your count table: '+','.join(idsk))
        sys.exit(-1)
      groupids+=[ids[gp]]
      groupidslabel+=[idsk[ids[gp]]]
  else:
    try:
      groupidslabel=[idsk[x] for x in groupids]
    except IndexError:
      # report a bad index the same way as a bad label, instead of an uncaught IndexError
      logging.error('Sample index in '+samplelabel+' is out of range: your count table has '+str(len(idsk))+' samples.')
      logging.error('Sample labels in your count table: '+','.join(idsk))
      sys.exit(-1)
  logging.debug('Given sample labels: '+samplelabel)
  logging.debug('Converted index: '+' '.join([str(x) for x in groupids]))
  return  (groupids,groupidslabel)


