#!/usr/bin/env python3

# Copyright 2013 Tobias Marschall
# 
# This file is part of the CLEVER-TOOLKIT.
# 
# The CLEVER-TOOLKIT is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The CLEVER-TOOLKIT is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with the CLEVER-TOOLKIT.  If not, see <http://www.gnu.org/licenses/>.

from optparse import OptionParser
import sys
import os
import subprocess
import tempfile
import re
import shutil
from collections import defaultdict

__author__ = "Tobias Marschall"

usage = """Usage: %prog [options] <reference.fasta(.gz)> <dataset-list> <result-dir>

Assumes that CLEVER has already been run on all datasets under study. Based on BAM files
and VCFs produced by CLEVER, split read alignments are performed on regions of interest
and all events are genotyped.

File format for <dataset-list>, one row for each dataset:
<name> <bam-file> <clever-vcf> <role>
where <role> is one of {none, mother, father, child, monozygotic_twin, dizygotic_twin}.
"""

def locate_executeable(exe_dict, name):
	"""Search for an executable called *name* and record its absolute path.

	Directories are probed in this order: every entry of $PATH, then the
	directory containing this script, then ../src relative to the script
	(for running from a source checkout). On success, stores the path in
	exe_dict[name], reports it on stderr, and returns True; otherwise
	prints an error to stderr and returns False.
	"""
	def isexecutable(f):
		# A usable tool must be a regular file with the execute bit set.
		return os.path.isfile(f) and os.access(f, os.X_OK)
	scriptpath = os.path.dirname(os.path.abspath(__file__))
	candidate_dirs = os.environ["PATH"].split(os.pathsep)
	candidate_dirs.append(scriptpath)
	candidate_dirs.append(os.path.abspath(os.path.join(scriptpath, '..', 'src')))
	for directory in candidate_dirs:
		f = os.path.join(directory, name)
		if isexecutable(f):
			exe_dict[name] = f
			print('Found executable', f, file=sys.stderr)
			return True
	print('Could not locate executable \"%s\". It\'s not in your PATH.'%name, file=sys.stderr)
	return False

def main():
	"""Run the complete MATE-CLEVER post-processing pipeline.

	Steps: verify all external tools are available, read and validate the
	dataset list, compute regions of interest from the CLEVER VCFs, extract
	and realign reads from those regions with LASER, merge and filter the
	putative deletions, recalibrate the alignments, and genotype the
	high-confidence deletions.

	Returns None on success and 1 on any error (the __main__ guard passes
	this to sys.exit, where None maps to exit status 0). Exits with status 1
	directly on a usage error.
	"""
	parser = OptionParser(usage=usage)
	parser.add_option("-T", action="store", dest="threads", type=int, default=4,
			help="Threads.")
	parser.add_option("-M", action="store", dest="max_del_length", type=int, default=10000,
			help="Maximum deletion length to look for (default: 10000).")
	parser.add_option("-f", action="store_true", dest="force", default=False,
			help="Delete old result and working directory first (if present).")
	parser.add_option("-w", action="store", dest="work_dir", default=None,
			help="Working directory (default: <result-directory>/work).")
	parser.add_option("-k", action="store_true", dest="keep_workdir", default=False,
			help="Keep working directory (default: delete directory when finished).")
	parser.add_option("-W", action="store", dest="snp_weight_cutoff", default=3.0, type=float,
			help="Minimum expected support for a SNP in order to ignore mismatches at that position (default: 3.0).")
	parser.add_option("-o", action="store", dest="max_offset", default=100, type=int,
			help="Maximum center distance between split-read and read-pair deletion to be considered identical (default: 100).")
	parser.add_option("-z", action="store", dest="max_length_diff", default=20, type=int,
			help="Maximum length difference between split-read and read-pair deletion to be considered identical (default: 20).")
	(options, args) = parser.parse_args()
	if len(args) != 3:
		parser.print_help()
		sys.exit(1)
	ref_filename = args[0]
	dataset_list_filename = args[1]
	result_dir = args[2]
	# Raw string avoids invalid-escape warnings for '\.'; the reference must
	# end in .fa/.fasta with an optional .gz suffix.
	ref_match = re.search(r'^(.*)\.(fasta|fa)(\.gz|)$', ref_filename)
	if ref_match is None:
		print("Error: reference filename must end on .(fasta|fa)(.gz).", file=sys.stderr)
		return 1
	exe_dict = dict()
	print('===== Checking dependencies =====', file=sys.stderr)
	# Fail fast before doing any work if a required tool is missing.
	required_tools = [
		'ctk-version', 'samtools', 'genotyper', 'mateclever-compute-rois',
		'bedtools', 'vcf-to-deletionlist', 'extract-bad-reads', 'laser',
		'laser-recalibrate', 'merge-putative-variations', 'filter-variations',
		'insert-length-histogram',
	]
	for tool in required_tools:
		if not locate_executeable(exe_dict, tool): return 1
	print('===== Determining CTK version =====', file=sys.stderr)
	ctk_version = subprocess.Popen([exe_dict['ctk-version']], stdout=subprocess.PIPE, universal_newlines=True).stdout.readline().strip()
	print('Version: ', ctk_version, file=sys.stderr)
	print('===== Reading list of datasets =====', file=sys.stderr)
	datasets = []
	# Only these four roles are accepted; the datasets must form either a
	# single unrelated sample or a complete mother/father/child trio
	# (validated below).
	allowed_roles = set(['none','mother','father','child'])
	present_roles = set()
	if not os.path.isfile(dataset_list_filename):
		print('Error opening "%s"'%dataset_list_filename,  file=sys.stderr)
		return 1
	for line in open(dataset_list_filename):
		fields = line.split()
		if len(fields) != 4:
			print('Error parsing "%s"'%dataset_list_filename,  file=sys.stderr)
			return 1
		name, bam_filename, clever_vcf_filename, role = fields
		if role not in allowed_roles:
			print('Error parsing "%s": invalid role: "%s"'%(dataset_list_filename,role),  file=sys.stderr)
			return 1
		if not os.path.isfile(bam_filename):
			print('Error: File "%s" not found'%bam_filename,  file=sys.stderr)
			return 1
		if not os.path.isfile(clever_vcf_filename):
			print('Error: File "%s" not found'%clever_vcf_filename,  file=sys.stderr)
			return 1
		present_roles.add(role)
		datasets.append((name, bam_filename, clever_vcf_filename, role))
	single_sample = (len(datasets) == 1) and (present_roles == set(['none']))
	trio = (len(datasets) == 3) and (present_roles == set(['mother','father','child']))
	if not (single_sample or trio):
		print('Error: Illegal (combination of) roles in file "%s"'%dataset_list_filename, file=sys.stderr)
		return 1
	for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
		print('Dataset %d: %s %s %s'%(i, bam_filename, clever_vcf_filename, role), file=sys.stderr)
	print('===== Checking directories =====', file=sys.stderr)
	if os.path.isdir(result_dir):
		if options.force:
			try:
				shutil.rmtree(result_dir)
			except OSError as e:
				print("Error deleting previous result directory:", e, file=sys.stderr)
				return 1
		else:
			print("Error: directory \"%s\" already exists. Move it out of the way or use option -f."%result_dir, file=sys.stderr)
			return 1
	try:
		os.makedirs(result_dir)
	except OSError as e:
		print("Error creating result directory:", e, file=sys.stderr)
		return 1
	if options.work_dir is not None:
		work_dir = os.path.abspath(options.work_dir)
		if os.path.exists(work_dir):
			if options.force:
				shutil.rmtree(work_dir)
			else:
				print("Error: directory \"%s\" already exists. Move it out of the way or use option -f."%work_dir, file=sys.stderr)
				return 1
	else:
		work_dir = os.path.join(result_dir, 'work')
	os.makedirs(work_dir)
	print('Result directory:', result_dir, file=sys.stderr)
	print('Working directory:', work_dir, file=sys.stderr)
	# Record the exact invocation for reproducibility.
	with open(work_dir + '/commandline.log', 'w') as commandline_logfile:
		print('\n'.join(sys.argv), file=commandline_logfile)
	print('===== Preparing list of regions to be considered =====', file=sys.stderr)
	for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
		convert_to_list_call = [exe_dict['vcf-to-deletionlist'], '-i', '-m', str(options.max_del_length), clever_vcf_filename]
		filename = work_dir + '/clever-deletions-%d.txt'%i
		print('Writing', filename, file=sys.stderr)
		convert_to_list_output = open(filename, 'w')
		convert_to_list = subprocess.Popen(convert_to_list_call, stdout=convert_to_list_output)
		convert_to_list_output.close()
		if convert_to_list.wait() != 0:
			print('Error executing "vcf-to-deletionlist".', file=sys.stderr)
			return 1
	bedtools_sort = None
	bedtools_merge = None
	compute_rois = None
	regions_filename = work_dir + '/regions.bed'
	# Progress messages belong on stderr like everywhere else in this script
	# (this one was accidentally printed to stdout).
	print("regions_filename: %s" % (regions_filename), file=sys.stderr)
	try:
		all_rois = open(work_dir + '/all_rois.txt', 'w')
		for i in range(len(datasets)):
			filename = work_dir + '/clever-deletions-%d.txt'%i
			print("working on file %s"%(filename), file=sys.stderr)
			compute_rois = subprocess.Popen([exe_dict['mateclever-compute-rois']], stdout=all_rois, stdin=open(filename))
			if compute_rois.wait() != 0:
				print('Error executing "mateclever-compute-rois".', file=sys.stderr)
				return 1
			else: compute_rois = None
		all_rois.close()
		all_rois = open(work_dir + '/all_rois.txt', 'r')
		regions_file = open(regions_filename, 'w')
		# Sort the pooled ROIs, then merge regions closer than 500 bp.
		bedtools_sort = subprocess.Popen([exe_dict['bedtools'], 'sort', '-i', 'stdin'], stdout=subprocess.PIPE, stdin=all_rois)
		bedtools_merge = subprocess.Popen([exe_dict['bedtools'], 'merge', '-i', 'stdin', '-d', '500'], stdout=regions_file, stdin=bedtools_sort.stdout)
		if bedtools_sort.wait() != 0:
			print('Error executing "bedtools sort".', file=sys.stderr)
			return 1
		else: bedtools_sort = None
		if bedtools_merge.wait() != 0:
			print('Error executing "bedtools merge".', file=sys.stderr)
			return 1
		else: bedtools_merge = None
		regions_file.close()
	except:
		# Clean up still-running children, then propagate. Previously the
		# exception was silently swallowed and the pipeline continued with a
		# possibly incomplete regions file.
		if compute_rois is not None: compute_rois.terminate()
		if bedtools_sort is not None: bedtools_sort.terminate()
		if bedtools_merge is not None: bedtools_merge.terminate()
		raise
	print('===== Determining insert size distribution for all datasets =====', file=sys.stderr)
	for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
		insert_size_filename = work_dir + '/insert-size-dist-%d.txt'%i
		insert_size_log_filename = work_dir + '/insert-size-dist-%d.log'%i
		print('Creating', insert_size_filename, file=sys.stderr)
		# Estimate internal segment size distribution from the first 1M reads.
		subprocess.call(
			[exe_dict['insert-length-histogram'], '--sorted', '--count', '1000000'],
			stdin = open(bam_filename),
			stdout = open(insert_size_filename, 'w'),
			stderr = open(insert_size_log_filename, 'w')
		)
	print('===== Extracting reads from regions of interest =====', file=sys.stderr)
	# Each BED line (chrom, start, end) becomes a samtools region "chrom:start-end".
	regions = ['%s:%s-%s'%tuple(s.split()) for s in open(regions_filename)]
	samtools_view = None
	samtools_unview = None
	bam_to_fastq = None
	for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
		fastq1_filename = work_dir + '/dataset-%d.1.fastq'%i
		fastq2_filename = work_dir + '/dataset-%d.2.fastq'%i
		try:
			samtools_view = subprocess.Popen([exe_dict['samtools'], 'view', '-h', bam_filename] + regions, stdout=subprocess.PIPE)
			# -F 0xF00 excludes reads with any of the SAM flag bits 0x100,
			# 0x200, 0x400, 0x800 set (secondary / QC-fail / duplicate /
			# supplementary alignments).
			samtools_unview = subprocess.Popen([exe_dict['samtools'], 'view', '-Sb', '-F', '0xF00', '-'], stdin=samtools_view.stdout, stdout=subprocess.PIPE)
			bam_to_fastq = subprocess.Popen([exe_dict['extract-bad-reads'], '-a', fastq1_filename, fastq2_filename], stdin=samtools_unview.stdout)
			if samtools_view.wait() != 0:
				print('Error executing "samtools view".', file=sys.stderr)
				return 1
			else: samtools_view = None
			if samtools_unview.wait() != 0:
				print('Error executing "samtools view -Sb".', file=sys.stderr)
				return 1
			else: samtools_unview = None
			if bam_to_fastq.wait() != 0:
				print('Error executing "extract-bad-reads".', file=sys.stderr)
				return 1
			else: bam_to_fastq = None
		except:
			# Clean up the pipeline's children, then propagate the error
			# instead of silently continuing with truncated FASTQ files.
			if samtools_view is not None: samtools_view.terminate()
			if samtools_unview is not None: samtools_unview.terminate()
			if bam_to_fastq is not None: bam_to_fastq.terminate()
			raise
	print('===== Aligning extracted reads using LASER =====', file=sys.stderr)
	laser = None
	for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
		print('Processing dataset', i, file=sys.stderr)
		fastq1_filename = work_dir + '/dataset-%d.1.fastq'%i
		fastq2_filename = work_dir + '/dataset-%d.2.fastq'%i
		try:
			laser_call = [exe_dict['laser'], '--extra-sensitive', '--dont-recalibrate', '--keep_raw_bam', '-w', '0.1', '-T', str(options.threads), '-M', str(options.max_del_length), ref_filename, fastq1_filename, fastq2_filename, work_dir + '/laser-%d'%i]
			log_file = work_dir + '/laser-%d.log'%i
			laser = subprocess.Popen(laser_call, stderr=open(log_file, 'w'))
			if laser.wait() != 0:
				print('Error executing "laser", see %s for error messages.'%log_file, file=sys.stderr)
				return 1
			else: laser = None
		except:
			# Terminate a still-running LASER, then propagate the error.
			if laser is not None: laser.terminate()
			raise
	print('===== Merging lists of putative variants =====', file=sys.stderr)
	clever_lists = []
	laser_lists = []
	snp_lists = []
	for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
		clever_lists.append(work_dir + '/clever-deletions-%d.txt'%i)
		laser_lists.append(work_dir + '/laser-%d.putative-indels'%i)
		snp_lists.append(work_dir + '/laser-%d.putative-snps'%i)
	if subprocess.call([exe_dict['merge-putative-variations'], '-w', '-m', '0'] + clever_lists, stdout=open(work_dir + '/clever-deletions-merged.txt', 'w')) != 0:
		print('Error executing merge-putative-variations!', file=sys.stderr)
		return 1
	if subprocess.call([exe_dict['merge-putative-variations'], '-w', '-m', '0.5'] + laser_lists, stdout=open(work_dir + '/laser-deletions-merged.txt', 'w')) != 0:
		print('Error executing merge-putative-variations!', file=sys.stderr)
		return 1
	if subprocess.call([exe_dict['merge-putative-variations'], '-m', str(options.snp_weight_cutoff)] + snp_lists, stdout=open(work_dir + '/laser-snps-merged.txt', 'w')) != 0:
		print('Error executing merge-putative-variations!', file=sys.stderr)
		return 1
	# Intersect read-pair (CLEVER) and split-read (LASER) deletion calls.
	x = [exe_dict['filter-variations'], '-o', str(options.max_offset), '-z', str(options.max_length_diff), '-l', '10', work_dir + '/laser-deletions-merged.txt']
	consensus_deletions_filename = work_dir + '/consensus-deletions.txt'
	if subprocess.call(x, stdin = open(work_dir + '/clever-deletions-merged.txt'), stdout = open(consensus_deletions_filename, 'w'), stderr = open(work_dir + '/consensus-deletions.log', 'w')) != 0:
		print('Error executing filter-variations!', file=sys.stderr)
		return 1
	print('Done.', file=sys.stderr)
	print('===== Recalibrating LASER alignments =====', file=sys.stderr)
	for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
		laser_raw_bam = work_dir + '/laser-%d.raw.bam'%i
		insert_size_filename = work_dir + '/insert-size-dist-%d.txt'%i
		laser_recal_bam = work_dir + '/laser-%d.recal.bam'%i
		print('Running laser-recalibrate to create', laser_recal_bam, file=sys.stderr)
		x = [exe_dict['laser-recalibrate'], '-csM', '--default_readgroup', '--snp', work_dir + '/laser-snps-merged.txt', '--variations',  work_dir + '/consensus-deletions.txt', insert_size_filename]
		if subprocess.call(x, stdin = open(laser_raw_bam), stdout = open(laser_recal_bam, 'w'), stderr = open(work_dir + '/laser-%d.recal.log'%i, 'w')) != 0:
			print('Error executing LASER, see %s/laser-%d.recal.log for error messages.'%(work_dir,i), file=sys.stderr)
			return 1
		print('Sorting and indexing', laser_recal_bam, file=sys.stderr)
		if subprocess.call([exe_dict['samtools'], 'sort', '-o', work_dir + '/laser-%d.recal.sorted.bam'%i, laser_recal_bam]) != 0:
			print('Error running samtools sort!', file=sys.stderr)
			return 1
		if subprocess.call([exe_dict['samtools'], 'index', work_dir + '/laser-%d.recal.sorted.bam'%i]) != 0:
			print('Error running samtools index!', file=sys.stderr)
			return 1
	print('===== Genotyping high-confidence deletions =====', file=sys.stderr)
	datasets_filename = work_dir + '/genotyper.datasets'
	with open(datasets_filename, 'w') as datasets_file:
		for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
			laser_recal_bam = work_dir + '/laser-%d.recal.sorted.bam'%i
			print(name, laser_recal_bam, role, file=datasets_file)
	readgroups_filename = work_dir + '/genotyper.readgroups'
	with open(readgroups_filename, 'w') as readgroups_file:
		for i, (name, bam_filename, clever_vcf_filename, role) in enumerate(datasets):
			print(name, 'default', work_dir + '/insert-size-dist-%d.txt'%i, file=readgroups_file)
	print('Running genotyper and writing results to', result_dir + '/deletions.vcf', file=sys.stderr)
	x = [exe_dict['genotyper'], '--min_phys_cov', '5', '--min_gq', '10', '--denovo_threshold', '1e-5', '--variant_prior', '0.1', '--mapq', '30', datasets_filename, readgroups_filename, ref_filename, consensus_deletions_filename]
	if subprocess.call(x, stdout = open(result_dir + '/deletions.vcf', 'w'), stderr = open(work_dir + '/genotyper.log', 'w')) != 0:
		print('Error executing genotyper, see %s/genotyper.log for error messages.'%work_dir, file=sys.stderr)
		return 1
	if not options.keep_workdir:
		print('===== Deleting work directory =====', file=sys.stderr)
		shutil.rmtree(work_dir, ignore_errors=True)
	print('Done.', file=sys.stderr)

# Script entry point: main() returns None on success (exit status 0) or 1 on
# error; sys.exit forwards that as the process exit code.
if __name__ == '__main__':
	sys.exit(main())
