#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2012 Tobias Marschall
# 
# This file is part of CLEVER.
# 
# CLEVER is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# CLEVER is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with CLEVER.  If not, see <http://www.gnu.org/licenses/>.

from optparse import OptionParser
import sys
import os
import shutil
import time
import subprocess

__author__ = "Tobias Marschall"

# Usage text handed to OptionParser in main(); optparse expands "%prog" to the
# program name when printing help.
usage = """%prog [options] <bam-file> <ref.fasta(.gz)> <result-directory>

This tool runs the whole workflow necessary to use CLEVER.

<bam-file>         Input BAM file. Unless option --sorted is given, all alignments for
                   the same read (pair) must be on consecutive lines. It is highly
                   recommended to allow multiple alignments per read to avoid spurious
                   predictions.
<ref.fasta(.gz)>   The reference genome in (gzipped) FASTA format. This is needed to
                   recompute alignment scores (AS tags). If your BAM file does have AS tags
                   such that 10^(AS/-10.0) can be interpreted as the probability of this
                   alignment being correct, use option -a to omit this step.
<result-directory> Directory to be created to store results in. If it already exists, abort
                   unless option -f is given."""

def locate_executeable(exe_dict, name):
	"""Find executable *name* and record its path in *exe_dict*.

	Candidates are tried in this order: every directory listed in $PATH, the
	directory containing this script, and ``../src`` relative to this script.
	The first candidate that is a regular file with execute permission wins:
	``exe_dict[name]`` is set to its path, the path is reported on stderr and
	True is returned. If none matches, an error is printed to stderr,
	*exe_dict* is left unchanged and False is returned.
	"""
	def isexecutable(f):
		return os.path.isfile(f) and os.access(f, os.X_OK)
	scriptpath = os.path.dirname(os.path.abspath(__file__))
	# An unset PATH must not crash the lookup with KeyError; fall back to
	# Python's default search path instead.
	search_path = os.environ.get("PATH", os.defpath)
	candidates = [os.path.join(path, name) for path in search_path.split(os.pathsep)]
	candidates.append(os.path.join(scriptpath, name))
	# Location of freshly built binaries in a source checkout.
	candidates.append(os.path.abspath(os.path.join(scriptpath, '..', 'src', name)))
	for f in candidates:
		if isexecutable(f):
			exe_dict[name] = f
			print('Found executable', f, file=sys.stderr)
			return True
	print('Could not locate executable \"%s\". It\'s not in your PATH.'%name, file=sys.stderr)
	return False

def execute(commandline):
	"""Run *commandline* through the shell and report whether it succeeded.

	Returns True if the command exits with status 0. Otherwise prints the
	exit code and the offending command line to stderr and returns False.
	"""
	# subprocess.call returns the child's real exit code (negative -N if it was
	# killed by signal N). os.system returns the raw wait status on POSIX,
	# which made the reported code 256 times too large.
	exitcode = subprocess.call(commandline, shell=True)
	if exitcode == 0:
		return True
	print("Error: non-zero exit code (%d) returned by commandline:"%exitcode, file=sys.stderr)
	print("\"%s\""%commandline, file=sys.stderr)
	return False


# Makefile written into the working directory by main(). "{0}" is replaced
# (via str.format) with the clever-core command line, which reads sorted priors
# on stdin. Per chromosome, priors.<chr>.aln-priors.gz is sorted by column 7
# into priors.<chr>.sorted-priors.gz, and CLEVER output goes to clever.<chr>.out.
# "find ." (with an explicit start path) is portable beyond GNU find and yields
# the "./priors..." names that the patsubst pattern expects.
makefile_template = """.DELETE_ON_ERROR:

INPUT := $(shell find . -name '*.aln-priors.gz')
RESULTS := $(patsubst ./priors.%.aln-priors.gz,clever.%.out,$(INPUT))

all: $(RESULTS)
	echo "Done."

clever.%.out: priors.%.sorted-priors.gz
	echo Starting to run CLEVER for chromosome $*
	zcat $< | {0} > $@ 2> clever.$*.log
	echo Finished running CLEVER for chromosome $*

priors.%.sorted-priors.gz: priors.%.aln-priors.gz
	echo Starting to sort priors for chromosome $*
	/usr/bin/env bash -c "time (zcat $< | sort -g -k7 | gzip > $@) 2> priors.$*.sort.log"
	echo Finished sorting priors for chromosome $*
"""

def _build_option_parser():
	"""Return the OptionParser describing this script's options and usage."""
	parser = OptionParser(usage=usage)
	parser.add_option("--sorted", action="store_true", dest="sorted", default=False,
					help="Input BAM file is sorted by position. Note that this requires alternative alignments to be given as XA tags (like produced by BWA, stampy, etc.).")
	parser.add_option("--use_xa", action="store_true", dest="use_xa", default=False,
					help="Interpret XA tags in input BAM file. This option SHOULD be given for mappers writing XA tags like BWA and stampy.")
	parser.add_option("--use_mapq", action="store_true", dest="use_mapq", default=False,
					help="Use MAPQ value instead of re-computing posteriors.")
	parser.add_option("-T", action="store", dest="threads", default=1, type=int,
					help="Number of threads to use (default=1).")
	parser.add_option("-f", action="store_true", dest="force", default=False,
					help="Delete old result and working directory first (if present).")
	parser.add_option("-w", action="store", dest="work_dir", default=None,
					help="Working directory (default: <result-directory>/work).")
	parser.add_option("-a", action="store_true", dest="as_tags", default=False,
					help="Do not (re-)compute AS tags. If given, the argument <ref.fasta(.gz)> is ignored.")
	parser.add_option("-k", action="store_true", dest="keep_workdir", default=False,
					help="Keep working directory (default: delete directory when finished).")
	parser.add_option("-r", action="store_true", dest="read_groups", default=False,
					help="Take read groups into account (multi sample mode).")
	parser.add_option("-C", action="store", dest="add_clever_params", default="",
					help="Additional parameters to be passed to the CLEVER core algorithm. Call \"clever-core\" without parameters for a list of options.")
	parser.add_option("-P", action="store", dest="add_post_params", default="",
					help="Additional parameters for postprocessing results. Call \"postprocess-predictions\" without parameters for a list of options.")
	parser.add_option("-I", action="store_true", dest="plot_distribution", default=False,
					help="Create a plot of internal segment size distribution (=fragment size - 2x read length). Also displays the estimated normal distribution (requires NumPy and matplotlib).")
	parser.add_option("--chromosome", action="store", dest="chromosome", default=None,
					help="Only process given chromosome (default: all).")
	return parser


def _read_mean_and_sd(filename):
	"""Return (mean, stddev) parsed from the first line of *filename*.

	The file is the one written by bam-to-alignment-priors via its "-m" option.
	Its first line must contain exactly two whitespace-separated numbers;
	otherwise ValueError is raised.
	"""
	with open(filename) as f:
		mean, stddev = [float(x) for x in f.readline().split()]
	return mean, stddev


def main():
	"""Run the complete CLEVER workflow and return a process exit status.

	Steps: validate arguments, locate helper executables, create the result
	and working directories, compute alignment priors split per chromosome,
	sort priors and run clever-core per chromosome via make, collect raw
	predictions and logs into the result directory, and postprocess them
	into predictions.vcf. Returns 1 on any detected error; returns None
	(exit status 0 via sys.exit) on success.
	"""
	import glob
	import shlex
	parser = _build_option_parser()
	(options, args) = parser.parse_args()
	if len(args) != 3:
		parser.print_help()
		return 1
	bam_filename = os.path.abspath(args[0])
	ref_filename = os.path.abspath(args[1])
	result_dir = os.path.abspath(args[2])
	if options.sorted and not (options.use_xa or options.use_mapq):
		print('Error: Must use "--use_xa" or "--use_mapq" when "--sorted" is given.', file=sys.stderr)
		return 1
	# find needed executables
	exe_dict = dict()
	print('===== Checking dependencies =====', file=sys.stderr)
	if not os.path.isfile(bam_filename):
		print("Error: file \"%s\" not found."%bam_filename, file=sys.stderr)
		return 1
	# NOTE(review): options.as_tags (-a) is parsed but never used. The reference
	# is always required here and always passed to bam-to-alignment-priors,
	# although the -a help text says it is ignored. Confirm intended behavior.
	if not os.path.isfile(ref_filename):
		print("Error: file \"%s\" not found."%ref_filename, file=sys.stderr)
		return 1
	required = ['ctk-version', 'clever-core', 'split-priors-by-chromosome', 'postprocess-predictions', 'make']
	if options.read_groups:
		required.append('samtools')
	required += ['bam-to-alignment-priors', 'tee']
	if options.plot_distribution:
		required.append('plot-insert-size-distribution')
	for exe_name in required:
		if not locate_executeable(exe_dict, exe_name):
			return 1
	print('===== Determining CTK version =====', file=sys.stderr)
	# communicate() reads all output and reaps the child (no zombie left behind);
	# only the first output line is the version.
	ctk_output = subprocess.Popen([exe_dict['ctk-version']], stdout=subprocess.PIPE, universal_newlines=True).communicate()[0]
	ctk_version = ctk_output.split('\n', 1)[0].strip()
	print('Version: ', ctk_version, file=sys.stderr)
	print('===== Checking directories =====', file=sys.stderr)
	if os.path.isdir(result_dir):
		if not options.force:
			print("Error: directory \"%s\" already exists. Move it out of the way or use option -f."%result_dir, file=sys.stderr)
			return 1
		try:
			shutil.rmtree(result_dir)
		except OSError as e:
			print("Error deleting previous result directory:", e, file=sys.stderr)
			return 1
	try:
		os.makedirs(result_dir)
	except OSError as e:
		print("Error creating result directory:", e, file=sys.stderr)
		return 1
	if options.work_dir is not None:
		work_dir = os.path.abspath(options.work_dir)
		if os.path.exists(work_dir):
			if not options.force:
				print("Error: directory \"%s\" already exists. Move it out of the way or use option -f."%work_dir, file=sys.stderr)
				return 1
			try:
				shutil.rmtree(work_dir)
			except OSError as e:
				print("Error deleting previous working directory:", e, file=sys.stderr)
				return 1
	else:
		work_dir = os.path.join(result_dir, 'work')
	try:
		os.makedirs(work_dir)
	except OSError as e:
		print("Error creating working directory:", e, file=sys.stderr)
		return 1
	print('Result directory:', result_dir, file=sys.stderr)
	print('Working directory:', work_dir, file=sys.stderr)
	# All relative file names below (priors.*, lengths.*, Makefile, ...) live in work_dir.
	os.chdir(work_dir)
	with open('commandline.log', 'w') as commandline_logfile:
		print('\n'.join(sys.argv), file=commandline_logfile)
	# check whether BAM header contains read group information
	if options.read_groups:
		# Call the located samtools directly (no shell), so the BAM path needs no quoting.
		header = subprocess.Popen([exe_dict['samtools'], 'view', '-H', bam_filename], stdout=subprocess.PIPE).communicate()[0]
		if not any(line.startswith(b'@RG') for line in header.splitlines()):
			print('Error: No read group information found in',bam_filename, file=sys.stderr)
			return 1
	print('===== Computing alignment priors from %s ====='%bam_filename, file=sys.stderr)
	bam_to_priors = None
	split_priors = None
	tee = None
	try:
		bam_to_priors_call = [exe_dict['bam-to-alignment-priors'], '-m', 'lengths.mean-and-sd', '-T2']
		if not options.sorted:
			bam_to_priors_call.append('--unsorted')
		if options.use_mapq:
			bam_to_priors_call.append('--use_mapq')
			bam_to_priors_call.append('--ignore_xa')
		elif not options.use_xa:
			bam_to_priors_call.append('--ignore_xa')
		if options.plot_distribution:
			bam_to_priors_call.append('-I')
			bam_to_priors_call.append('lengths.distribution')
		if options.chromosome:
			bam_to_priors_call.append('--chromosome')
			bam_to_priors_call.append(options.chromosome)
		bam_to_priors_call.append(ref_filename)
		bam_to_priors_call.append(bam_filename)
		# stdout (the priors) feeds split-priors-by-chromosome; stderr is copied
		# by tee to priors.log and to the terminal.
		bam_to_priors = subprocess.Popen(bam_to_priors_call, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
		tee = subprocess.Popen([exe_dict['tee'], os.path.join(work_dir, 'priors.log')], stdin=bam_to_priors.stderr)
		split_priors = subprocess.Popen([exe_dict['split-priors-by-chromosome'], '-z', 'priors'], stdin=bam_to_priors.stdout)
		# The readers now own their pipe ends. Closing ours lets bam-to-alignment-priors
		# receive SIGPIPE if a reader exits early, instead of blocking on a full pipe.
		bam_to_priors.stdout.close()
		bam_to_priors.stderr.close()
		if bam_to_priors.wait() != 0:
			print('Error executing "bam-to-alignment-priors".', file=sys.stderr)
			return 1
		if tee.wait() != 0:
			print('Error executing "tee".', file=sys.stderr)
			return 1
		if split_priors.wait() != 0:
			print('Error executing split-priors-by-chromosome.', file=sys.stderr)
			return 1
	except BaseException:
		# Also reached on KeyboardInterrupt: stop all children, then propagate.
		for proc in (bam_to_priors, split_priors, tee):
			if proc is not None:
				proc.terminate()
		raise
	# NOTE(review): the reason for this pause is not evident from the code
	# (possibly letting output files settle); kept unchanged.
	time.sleep(1)
	if options.plot_distribution:
		print('===== Plotting internal segment size distribution =====', file=sys.stderr)
		mean, stddev = _read_mean_and_sd('lengths.mean-and-sd')
		if not execute('%s -n %f,%f -o insert-size-dist.pdf lengths.distribution'%(shlex.quote(exe_dict['plot-insert-size-distribution']), mean, stddev)):
			return 1
		print('Done.', file=sys.stderr)
	print('===== Sort priors for each chromosome and run CLEVER =====', file=sys.stderr)
	# Paths are shell-quoted because the command ends up in a Makefile recipe;
	# -C parameters are passed through verbatim by design.
	read_group_parameter = ('-R %s '%shlex.quote(bam_filename)) if options.read_groups else ''
	clever_call = '%s %s%s -v %s'%(shlex.quote(exe_dict['clever-core']),read_group_parameter,options.add_clever_params,'lengths.mean-and-sd')
	with open('Makefile', 'w') as makefile:
		makefile.write(makefile_template.format(clever_call))
	if not execute('%s --quiet -j %d'%(shlex.quote(exe_dict['make']),options.threads)):
		return 1
	print('===== Aggregating results =====', file=sys.stderr)
	raw_predictions = os.path.join(result_dir, 'predictions.raw.txt')
	# Concatenate/copy in Python rather than via shell globs, so paths with
	# spaces or shell/glob metacharacters are handled correctly.
	with open(raw_predictions, 'wb') as out:
		for part in sorted(glob.glob(os.path.join(glob.escape(work_dir), '*.out'))):
			with open(part, 'rb') as src:
				shutil.copyfileobj(src, out)
	log_dir = os.path.join(result_dir,'logs')
	os.makedirs(log_dir)
	for log_file in sorted(glob.glob(os.path.join(glob.escape(work_dir), '*.log'))):
		shutil.copy(log_file, log_dir)
	if options.plot_distribution:
		shutil.move('insert-size-dist.pdf', result_dir)
	print('Done.', file=sys.stderr)
	print('===== Postprocessing predictions =====', file=sys.stderr)
	mean, stddev = _read_mean_and_sd(os.path.join(work_dir,'lengths.mean-and-sd'))
	# Enforce lower bounds (mean >= 50, stddev >= 5) before postprocessing.
	mean = max(mean, 50)
	stddev = max(stddev, 5)
	postprocess_call = '%s %s --vcf --covbal 0.333 --stddev %f %s %f > %s'%(
		shlex.quote(exe_dict['postprocess-predictions']),
		options.add_post_params,
		stddev,
		shlex.quote(raw_predictions),
		mean,
		shlex.quote(os.path.join(result_dir,'predictions.vcf')))
	if not execute(postprocess_call):
		return 1
	if not options.keep_workdir:
		print('===== Deleting work directory =====', file=sys.stderr)
		shutil.rmtree(work_dir)
	print('Done.', file=sys.stderr)

if __name__ == '__main__':
	# Script entry point: main() returns 1 on failure and None on success;
	# hand that value to sys.exit as the process exit status.
	exit_status = main()
	sys.exit(exit_status)
