#!/usr/bin/env python3

# Copyright 2013 Tobias Marschall
# 
# This file is part of the CLEVER-TOOLKIT.
# 
# The CLEVER-TOOLKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The CLEVER-TOOLKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with the CLEVER-TOOLKIT.  If not, see <http://www.gnu.org/licenses/>.

from optparse import OptionParser
import sys
import os
import subprocess
import tempfile
import re
import shutil
import itertools

__author__ = 'Tobias Marschall'

# Usage line handed to OptionParser in main(); the four positional arguments
# are the reference, the two FASTQ files and the prefix for all output files.
usage = (
	"Usage: %prog [options] "
	"<reference.fasta(.gz)> <1.fastq.gz> <2.fastq.gz> <output-prefix>"
)

def locate_executeable(exe_dict, name):
	"""Locate the executable `name` and store its path in `exe_dict[name]`.

	Candidate locations are tried in this order:
	  1. every directory listed in the PATH environment variable,
	  2. the directory containing this script,
	  3. the directory '../src' relative to this script.
	The first candidate that is a regular file with execute permission wins.

	Every outcome is reported on stderr.

	Returns True if the executable was found (and exe_dict was updated),
	False otherwise (exe_dict is left untouched).
	"""
	def isexecutable(f):
		return os.path.isfile(f) and os.access(f, os.X_OK)
	def candidates():
		# An unset PATH simply contributes no directories; indexing
		# os.environ directly would raise KeyError in that case.
		search_path = os.environ.get("PATH")
		if search_path is not None:
			for directory in search_path.split(os.pathsep):
				yield os.path.join(directory, name)
		scriptpath = os.path.dirname(os.path.abspath(__file__))
		yield os.path.join(scriptpath, name)
		yield os.path.abspath(os.path.join(scriptpath, '..', 'src', name))
	for f in candidates():
		if isexecutable(f):
			exe_dict[name] = f
			print('Found executable', f, file=sys.stderr)
			return True
	print('Could not locate executable \"%s\". It\'s not in your PATH.'%name, file=sys.stderr)
	return False

def count_lines(filename):
	"""Return the number of lines in the text file `filename`.

	The file is opened in text mode and closed again before returning, so
	no file handle is left behind for the garbage collector.
	"""
	with open(filename) as f:
		return sum( 1 for line in f )

def bwa_index_present(basename):
	"""Tell whether all five BWA index files exist for the prefix `basename`.

	The files checked are basename.amb, .ann, .bwt, .pac and .sa; the result
	is True only if every one of them is a regular file.
	"""
	required_suffixes = ('amb', 'ann', 'bwt', 'pac', 'sa')
	return all(os.path.isfile(basename + '.' + ext) for ext in required_suffixes)

def locate_bwa_index(ref_filename):
	"""Find the prefix of the BWA index that belongs to a reference genome file.

	Starting with the full reference filename, the current candidate and the
	candidate with '.64' appended are tested with bwa_index_present(). If
	neither has a complete index, everything from the last dot onwards is
	removed from the candidate and the test is repeated.

	Returns the first prefix for which an index was found. Returns None,
	after printing a message to stderr, if the reference file does not exist
	or no index could be located.
	"""
	if not os.path.isfile(ref_filename):
		print('Reference genome \'%s\' not found.'%ref_filename, file=sys.stderr)
		return None
	candidate = ref_filename
	while True:
		for prefix in (candidate, candidate + '.64'):
			if bwa_index_present(prefix):
				return prefix
		# No dot after the first character of the basename: nothing left to strip.
		if os.path.basename(candidate).find('.') <= 0:
			print('Failed to locate BWA index. Please run \'bwa index %s\'.'%ref_filename, file=sys.stderr)
			return None
		candidate = candidate.rpartition('.')[0]

class SilentError(Exception):
	"""Raised after an error message has already been printed.

	The script's entry point catches this exception and exits with status 1
	without printing anything further.
	"""

	def __str__(self):
		return 'SilentError'

class Parameters:
	"""Accumulates command line options for an external program.

	Options are split into flags ("one part") and options with a value
	("two part"). Adding an option with a value again overwrites the
	earlier value, so later add() calls take precedence. Short options are
	translated to their long form when the equivalence is known from
	read_usage(), so that '-T 4' and '--threads 8' refer to the same entry.
	"""
	def __init__(self):
		# option name -> value, for options followed by an argument
		self.two_part_params = dict()
		# option names given without an argument
		self.one_part_params = set()
		# short option (e.g. '-T') -> long option (e.g. '--threads')
		self.equivalence_map = dict()
	def read_usage(self,executable):
		"""Reads usage information to know which options are equivalent.

		Runs `executable` without arguments and scans its stderr for lines
		of the form '  -x [ --long_name ]'; each match is recorded in
		equivalence_map. The child's stdout is discarded.
		"""
		usage_line = re.compile(r'  (-[A-Za-z]) \[ (--[A-Za-z_]+) \]')
		# The context manager closes the pipe and reaps the child process.
		with subprocess.Popen([executable], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, universal_newlines=True) as usage_process:
			for line in usage_process.stderr:
				m = usage_line.match(line)
				if m is not None:
					self.equivalence_map[m.group(1)] = m.group(2)
	def add(self, parameters):
		"""Add a list of command line tokens to the collected options.

		Tokens consisting of one dash, letters and optional trailing digits
		are expanded: '-vzT' becomes '-v -z -T' and '-A12' becomes '-A 12'.
		Any other token is taken verbatim. An option followed by a token
		that does not start with '-' is stored with that token as its
		value; otherwise it is stored as a flag.

		Raises ValueError if a token in option position does not start
		with '-'.
		"""
		# expand clusters of parameters (like -vzT)
		expanded_fields = []
		for s in parameters:
			# The whole token must be a cluster; a partial match would
			# silently drop the remainder (e.g. the '.5' of '-a0.5').
			m = re.fullmatch('-([a-zA-Z]+)([0-9]*)',s)
			if m is None:
				expanded_fields.append(s)
			else:
				for c in m.group(1):
					expanded_fields.append('-'+c)
				if len(m.group(2)) > 0:
					expanded_fields.append(m.group(2))
		i = 0
		fields = expanded_fields
		while i < len(fields):
			# Explicit check instead of assert, which is skipped under -O.
			if not fields[i].startswith('-'):
				raise ValueError('Invalid parameter string: %s'%parameters)
			if fields[i] in self.equivalence_map:
				fields[i] = self.equivalence_map[fields[i]]
			if (i == len(fields) - 1) or (fields[i+1].startswith('-')):
				self.one_part_params.add(fields[i])
				i += 1
			else:
				self.two_part_params[fields[i]] = fields[i+1]
				i += 2
	def get(self):
		"""Return all options as a flat argument list.

		Flags come first in sorted order, followed by the option/value
		pairs sorted by option name.
		"""
		one_part_list = sorted(self.one_part_params)
		two_part_items = sorted(self.two_part_params.items())
		return one_part_list + list(itertools.chain.from_iterable(two_part_items))
	def __str__(self):
		return ' '.join(self.get())

def _terminate_running(*processes):
	"""Terminate every given child process that has no recorded exit status yet.

	Entries that are None (process was never started) are skipped.
	"""
	for process in processes:
		if (process is not None) and (process.returncode is None):
			process.terminate()

def main():
	"""Run the LASER pipeline as configured on the command line.

	Steps, in order:
	  1. Parse and validate options and the four positional arguments.
	  2. Locate all required executables and the BWA index of the reference.
	  3. Unless a split-read FASTQ was given with -S, create one in a
	     temporary directory (split-reads | gzip).
	  4. Run bwa aln | bwa samse | xa2multi.pl | samtools view | laser-core,
	     writing the raw BAM either to <output-prefix>.raw.bam
	     (--keep_raw_bam) or into the temporary directory.
	  5. Unless --dont-recalibrate was given, run laser-recalibrate
	     (optionally piped through multiline-to-xa) to produce
	     <output-prefix>.bam.
	The temporary directory is removed in all cases once it was created.

	Returns 1 if validation, the dependency check, read splitting or
	recalibration fails (a message is printed to stderr) and None on
	success. Calls sys.exit(1) after printing the help text if the number
	of positional arguments is not 4.

	Raises SilentError if a stage of the LASER-CORE pipeline exits with a
	non-zero status (the message has already been printed).
	"""
	parser = OptionParser(usage=usage)
	parser.add_option("--extra-sensitive", action="store_true", dest="sensitive", default=False,
			help="Be more sensitive (at the expense of runtime).")
	parser.add_option("-T", action="store", dest="threads", type=int, default=4,
			help="Threads.")
	parser.add_option("-M", action="store", dest="max_del_length", type=int, default=None,
			help="Maximum deletion length to look for (default: 1000 in regular mode, 10000 when using --extra-sensitive).")
	parser.add_option("-s", action="store", dest="seed_length", type=int, default=40,
			help="Length of alignment seeds to be mapped by external read mapper (BWA), default: 40.")
	parser.add_option("-S", action="store", dest="split_fastq", default=None,
			help="Filename of FASTQ file with split reads (if not given, a temporary such file be produced).")
	parser.add_option("--tmpdir", action="store", dest="tmpdir", default=None,
			help="Directory to use for temporary files (if not given, system default is used).")
	parser.add_option("--core-options", action="store", dest="core_options", default="",
			help="Additional options to pass on to LASER core algorithm. Call \"laser-core\" without parameters for a list of options.")
	parser.add_option("--recal-options", action="store", dest="recal_options", default="",
			help="Additional options to pass on to LASER's recalibration algorithm. Call \"laser-recalibrate\" without parameters for a list of options.")
	parser.add_option("-w", action="store", dest="weight_cutoff", default=3.0, type=float,
			help="Minimum expected support for a SNP/indel in order to be written to file with putative variations. These SNPs/indels will be used for recalibration of alignment scores.")
	parser.add_option("--keep_raw_bam", action="store_true", dest="keep_raw_bam", default=False,
			help="Keep BAM file produced by laser-core. Default: only keep BAM after recalibration.")
	parser.add_option("--secondary", action="store_true", dest="secondary", default=False,
			help="Include secondary alignments in the BAM file.")
	parser.add_option("--xa", action="store_true", dest="use_xa", default=False,
			help="Encode secondary alignments in XA tags (default: separate lines).")
	parser.add_option("--adv-cigar", action="store_true", dest="advanced_cigar", default=False,
			help="Use X/= in CIGAR strings instead of M.")
	parser.add_option("--dont-recalibrate", action="store_false", dest="recalibrate", default=True,
			help="Skip recalibration step.")
	parser.add_option("--interchromosomal", action="store_true", dest="interchromosomal", default=False,
			help="Look for interchromosomal read pairs and for interchromosomal split reads.")
	(options, args) = parser.parse_args()
	if len(args) != 4:
		parser.print_help()
		sys.exit(1)
	ref_filename, fastq_1, fastq_2, output_prefix = args
	# Raw string: '\.' is not a valid escape sequence in a normal string literal.
	if re.search(r'^(.*)\.(fasta|fa)(\.gz|)$', ref_filename) is None:
		print("Error: reference filename must end on .(fasta|fa)(.gz).", file=sys.stderr)
		return 1
	if (options.tmpdir is not None) and (not os.path.isdir(options.tmpdir)):
		print('Error: directory "%s" does not exist.'%options.tmpdir, file=sys.stderr)
		return 1
	if (options.split_fastq is not None) and (not os.path.isfile(options.split_fastq)):
		print('Error: file "%s" does not exist.'%options.split_fastq, file=sys.stderr)
		return 1
	if options.use_xa and (not options.secondary):
		print('Option --xa only allowed when option --secondary is given.', file=sys.stderr)
		return 1
	if not options.recalibrate:
		if not options.keep_raw_bam:
			print('Option --dont-recalibrate can only be used in connection with --keep_raw_bam.', file=sys.stderr)
			return 1
		if options.use_xa:
			print('Warning: option --xa has no effect when using --dont-recalibrate.', file=sys.stderr)
	exe_dict = dict()
	print('===== Checking dependencies =====', file=sys.stderr)
	for fastq in (fastq_1, fastq_2):
		if not os.path.isfile(fastq):
			print("Error: file \"%s\" not found."%fastq, file=sys.stderr)
			return 1
	required_executables = ['ctk-version', 'laser-core', 'laser-recalibrate', 'bwa', 'xa2multi.pl', 'samtools']
	if options.split_fastq is None:
		# only needed when the split-read FASTQ has to be produced here
		required_executables += ['split-reads', 'gzip']
	if options.use_xa:
		required_executables.append('multiline-to-xa')
	for executable in required_executables:
		if not locate_executeable(exe_dict, executable): return 1
	bwa_index = locate_bwa_index(ref_filename)
	if bwa_index is None: return 1
	print('===== Determining CTK version =====', file=sys.stderr)
	# The context manager closes the pipe and reaps the child process.
	with subprocess.Popen([exe_dict['ctk-version']], stdout=subprocess.PIPE, universal_newlines=True) as version_process:
		ctk_version = version_process.stdout.readline().strip()
	print('Version: ', ctk_version, file=sys.stderr)
	tmpdir = tempfile.mkdtemp(prefix='laser-', dir=options.tmpdir)
	try:
		# ==============================================
		if options.split_fastq is None:
			print('===== Splitting reads =====', file=sys.stderr)
			split_reads = None
			gzip = None
			fastq_split = '%s/split.fastq.gz'%tmpdir
			try:
				split_call = [exe_dict['split-reads'],'-l',str(options.seed_length), fastq_1, fastq_2]
				print('Splitting reads:',' '.join(split_call), file=sys.stderr)
				with open(fastq_split, 'wb') as split_out:
					split_reads = subprocess.Popen(split_call, stdout=subprocess.PIPE, bufsize=-1)
					gzip = subprocess.Popen([exe_dict['gzip']], stdin=split_reads.stdout, stdout=split_out, bufsize=-1)
					# Drop our copy of the pipe so that split-reads gets SIGPIPE
					# instead of blocking forever if gzip exits early.
					split_reads.stdout.close()
					split_status = split_reads.wait()
					gzip_status = gzip.wait()
				if split_status != 0:
					print('Error executing "split-reads".', file=sys.stderr)
					return 1
				if gzip_status != 0:
					print('Error executing "gzip".', file=sys.stderr)
					return 1
			except BaseException:
				_terminate_running(split_reads, gzip)
				raise
		else:
			fastq_split = options.split_fastq
		# ==============================================
		if options.keep_raw_bam:
			laser_bam_filename = '%s.raw.bam'%output_prefix
		else:
			laser_bam_filename = '%s/laser.raw.bam'%tmpdir
		snp_filename = output_prefix+'.putative-snps'
		indel_filename = output_prefix+'.putative-indels'
		insertion_dist_filename = output_prefix+'.insertion-size-dist'
		deletion_dist_filename = output_prefix+'.deletion-size-dist'
		insert_size_dist_filename = output_prefix+'.insert-size-dist'
		# Assemble the laser-core command line before any process is started.
		# Later add() calls override values set by earlier ones.
		laser_parameters = Parameters()
		laser_parameters.read_usage(exe_dict['laser-core'])
		if options.sensitive:
			laser_parameters.add(['-A12', '--anchor_distance','1000','--max_anchors','250','--max_anchor_pairs','1000'])
		if options.max_del_length is None:
			options.max_del_length = 10000 if options.sensitive else 1000
		laser_parameters.add(['--max_span', str(options.max_del_length), '--max_insert', str(500 + options.max_del_length)])
		laser_parameters.add(['--indel_weight_cutoff', str(options.weight_cutoff), '--snp_weight_cutoff', str(options.weight_cutoff)])
		laser_parameters.add(options.core_options.split())
		if options.interchromosomal:
			laser_parameters.add(['--interchromosomal'])
		laser_parameters.add(['--snp',snp_filename,'-P',indel_filename,'-L',insert_size_dist_filename,'-R',insertion_dist_filename,'-D',deletion_dist_filename])
		laser_parameters.add(['-XIS','-T',str(options.threads)])
		laser_call = [exe_dict['laser-core']] + laser_parameters.get() + [ref_filename, fastq_1, fastq_2]
		bwa_aln = None
		bwa_samse = None
		xa2multi = None
		samtools = None
		laser = None
		try:
			print('===== Running LASER-CORE =====', file=sys.stderr)
			print('Output file:', laser_bam_filename, file=sys.stderr)
			with open(output_prefix+'.bwa-aln.log', 'w') as bwa_aln_log, \
					open(output_prefix+'.bwa-samse.log', 'w') as bwa_samse_log, \
					open(laser_bam_filename, 'wb') as raw_bam:
				# After each stage is connected, our own copy of the upstream
				# pipe is closed: otherwise an upstream process would block on
				# a full pipe (instead of receiving SIGPIPE) when its consumer
				# dies, and the wait() calls below would never return.
				bwa_aln = subprocess.Popen([exe_dict['bwa'],'aln','-t',str(options.threads),bwa_index,fastq_split], stdout=subprocess.PIPE, stderr=bwa_aln_log, bufsize=-1)
				bwa_samse = subprocess.Popen([exe_dict['bwa'],'samse','-n25', bwa_index, '/dev/stdin',fastq_split], stdin=bwa_aln.stdout, stdout=subprocess.PIPE, stderr=bwa_samse_log, bufsize=-1)
				bwa_aln.stdout.close()
				xa2multi = subprocess.Popen([exe_dict['xa2multi.pl']], stdin=bwa_samse.stdout, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=-1)
				bwa_samse.stdout.close()
				samtools = subprocess.Popen([exe_dict['samtools'],'view','-bS','-'], stdin=xa2multi.stdout, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, bufsize=-1)
				xa2multi.stdout.close()
				laser = subprocess.Popen(laser_call, stdin=samtools.stdout, stdout=raw_bam, bufsize=-1)
				samtools.stdout.close()
				if bwa_aln.wait() != 0:
					print('Error executing "bwa aln", see %s.bwa-aln.log for details.'%output_prefix, file=sys.stderr)
					raise SilentError()
				if bwa_samse.wait() != 0:
					print('Error executing "bwa samse", see %s.bwa-samse.log for details.'%output_prefix, file=sys.stderr)
					raise SilentError()
				if xa2multi.wait() != 0:
					print('Error executing "xa2multi.pl".', file=sys.stderr)
					raise SilentError()
				if samtools.wait() != 0:
					print('Error executing "samtools".', file=sys.stderr)
					raise SilentError()
				if laser.wait() != 0:
					print('Error executing "laser-core".', file=sys.stderr)
					raise SilentError()
		except BaseException:
			# Also reached on KeyboardInterrupt: stop whatever is still running.
			_terminate_running(bwa_aln, bwa_samse, xa2multi, samtools, laser)
			raise
		# ==============================================
		if options.recalibrate:
			recalibrate = None
			multi_to_xa = None
			try:
				print('===== Recalibrating alignment scores =====', file=sys.stderr)
				recalibrate_params = Parameters()
				recalibrate_params.read_usage(exe_dict['laser-recalibrate'])
				recalibrate_params.add(['--omit_alt_cigars', '--distant-pairs', '--default_readgroup', '--snp', snp_filename, '--variations', indel_filename, '--max_pair_distance', str(500 + options.max_del_length)])
				if not options.secondary:
					recalibrate_params.add(['--omit_secondary_aln'])
				if not options.advanced_cigar:
					recalibrate_params.add(['--m_in_cigar'])
				# Distributions with fewer than 10 lines are not passed on.
				if count_lines(insertion_dist_filename) < 10:
					print('Empiric insertion size distribution has low complexity, using default instead', file=sys.stderr)
				else:
					recalibrate_params.add(['-I', insertion_dist_filename])
				if count_lines(deletion_dist_filename) < 10:
					print('Empiric deletion size distribution has low complexity, using default instead', file=sys.stderr)
				else:
					recalibrate_params.add(['-D', deletion_dist_filename])
				recalibrate_params.add(options.recal_options.split())
				recal_call = [exe_dict['laser-recalibrate']] + recalibrate_params.get() + [insert_size_dist_filename]
				with open(laser_bam_filename, 'rb') as raw_bam_in, \
						open(output_prefix+'.bam', 'wb') as final_bam:
					if options.use_xa:
						recalibrate = subprocess.Popen(recal_call, stdin=raw_bam_in, stdout=subprocess.PIPE, bufsize=-1)
						multi_to_xa = subprocess.Popen([exe_dict['multiline-to-xa']], stdin=recalibrate.stdout, stdout=final_bam, bufsize=-1)
						# see the pipeline above: let SIGPIPE reach laser-recalibrate
						recalibrate.stdout.close()
					else:
						recalibrate = subprocess.Popen(recal_call, stdin=raw_bam_in, stdout=final_bam, bufsize=-1)
					recalibrate_status = recalibrate.wait()
					multi_to_xa_status = multi_to_xa.wait() if multi_to_xa is not None else 0
				if recalibrate_status != 0:
					print('Error executing "laser-recalibrate".', file=sys.stderr)
					return 1
				if multi_to_xa_status != 0:
					print('Error executing "multiline-to-xa".', file=sys.stderr)
					return 1
			except BaseException:
				_terminate_running(recalibrate, multi_to_xa)
				raise
	finally:
		print('===== Removing temporary files =====', file=sys.stderr)
		print('Deleting directory', tmpdir, file=sys.stderr)
		shutil.rmtree(tmpdir)

if __name__ == '__main__':
	# Translate the outcome of main() into the exit status of the process.
	# SilentError means the error message has been printed already.
	try:
		exit_status = main()
	except KeyboardInterrupt:
		print('Abort requested by user, quitting.', file=sys.stderr)
		exit_status = 1
	except SilentError:
		exit_status = 1
	sys.exit(exit_status)
