#!/opt/conda/conda-bld/hisat2-pipeline_1755609669124/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/bin/python
import os, glob, gzip, re, argparse
# Program name and release string; printed once at start-up by the script section below.
version_string = "HISAT2-pipline v1.1.1 (2025/08) "

class Sample:
	"""Common base for one sequencing sample and the read files that belong to it.

	Subclasses decide how reads are stored and how they are passed to HISAT2;
	the two hooks defined here are placeholders that do nothing.
	"""

	def __init__(self, name):
		# Fields shared by every kind of sample. Subclasses overwrite ``type``.
		self.name = name
		self.type = ""
		self.reads = []

	def __str__(self):
		file_count = len(self.reads)
		return f"{self.name} ({file_count} files - {self.type})"

	def get_name(self):
		"""Return the sample name given at construction."""
		return self.name

	def add_read(self, read):
		"""Placeholder hook: the base class ignores the read and returns None."""
		return None

	def get_hisat_string(self):
		"""Placeholder hook: the base class returns None."""
		return None
	



class PairedReadSample(Sample):
	"""Sample sequenced as read pairs; the two mates are kept in parallel lists."""

	def __init__(self, name):
		super().__init__(name)
		self.left_reads = []   # files classified as mate 1 (_R1 / read1)
		self.right_reads = []  # files classified as mate 2 (_R2 / read2)
		self.type = "paired"

	def get_hisat_string(self):
		"""Return the ``-1 <mate 1 files> -2 <mate 2 files>`` argument string.

		HISAT2 pairs the two comma-separated lists file by file. Reads are added
		in whatever order they were discovered on disk, so both lists are sorted
		here to line the mates up (this assumes the two files of a pair differ
		only in their R1/R2 marker, e.g. across several lanes).
		"""
		left = ','.join(sorted(self.left_reads))
		right = ','.join(sorted(self.right_reads))
		return f"-1 {left} -2 {right}"

	def add_read(self, read):
		"""Register a read file and classify it as mate 1 or mate 2.

		The whole path is inspected: "_R1" / "_R2" are matched case-sensitively,
		"read1" / "read2" in any case. A file matching neither is still stored in
		``self.reads`` but is not passed to HISAT2, so a warning is printed.
		"""
		self.reads.append(read)
		upper = read.upper()
		if "_R1" in read or "READ1" in upper:
			self.left_reads.append(read)
		elif "_R2" in read or "READ2" in upper:
			self.right_reads.append(read)
		else:
			# Name the file so the user knows which one has to be renamed.
			print(f"Warning: Paired read did not contain R1 or R2: {read}")
		

		
class UnpairedReadSample(Sample):
	"""Sample made of single-end read files."""

	def __init__(self, name):
		super().__init__(name)
		self.type = "unpaired"

	def add_read(self, read):
		"""Register one more read file for this sample."""
		self.reads += [read]

	def get_hisat_string(self):
		"""Return the ``-U`` argument listing every read file, comma separated."""
		joined = ','.join(self.reads)
		return f"-U {joined}"
		


def extract_compressed(file):
	"""Extracts compressed files. Supports gzip, bzip2 and lzma/xz.

	The decompressed text is written next to the input, using the input path
	with its last extension removed (``genome.fa.gz`` -> ``genome.fa``).

	Returns the path of the extracted file, or None when the file name does not
	end in a supported compression suffix (nothing is written in that case).
	"""
	lowered = file.lower()
	if lowered.endswith(("gz", "gzip")):
		import gzip as codec
	elif lowered.endswith(("bz2", "bzip2")):
		import bz2 as codec
	elif lowered.endswith("xz"):
		import lzma as codec
	else:
		return None

	basename = file.rsplit(".", 1)[0]
	if basename == file:
		# No extension to strip: the output would overwrite the input itself.
		return None

	# Both handles are closed even if reading or writing fails half-way.
	with codec.open(file, 'rt') as compressed, open(basename, "w") as outfile:
		for line in compressed:
			outfile.write(line)
	return basename


def can_run(exe):
	'''Returns True if the command can be found as an executable on the PATH.

	Uses shutil.which instead of spawning the external "which" program, so the
	check also works where "which" itself is not installed and no exception
	has to be swallowed.
	'''
	import shutil
	return shutil.which(exe) is not None

	

def look_for_genome(folder="genome"):
	"""Looks for a fasta file in the folder and returns its path.

	Entries are examined in sorted order so the choice is reproducible when the
	folder holds several candidates. A file counts as fasta when its name ends in
	.fa, .fna, .fasta or .fas (any case). If none is found, the first compressed
	fasta (.gz, .bz2 or .xz on top of a fasta extension) is extracted and the
	path of the extracted file is returned. Returns None if nothing matches.
	"""
	# The leading dot keeps names that merely end in these letters from matching.
	fasta_pattern = re.compile(r'\.(fa|fna|fasta|fas)$')
	compressed_pattern = re.compile(r'\.(fa|fna|fasta|fas)\.(bz2|gz|xz)$')

	entries = sorted(os.listdir(folder))

	for file in entries:
		if fasta_pattern.search(file.lower()) is not None:
			return f"{folder}/{file}"

	#could not find a fasta file. Maybe it is compressed and needs to be extracted
	for file in entries:
		if compressed_pattern.search(file.lower()) is not None:
			print("Found a compressed genome file. Extracting ...")
			extract_compressed(f"{folder}/{file}")
			# extract_compressed writes to the input path minus its last extension.
			return f"{folder}/{file}".rsplit(".", 1)[0]

	return None




def build_index(genome):
	"""Build a HISAT2 index for the genome fasta with hisat2-build.

	The index prefix is the genome path cut at its last ".f" (for example
	``genome/x.fasta`` -> ``genome/x``), the same rule map_to_genome uses to
	locate the index. The command is run without a shell, so paths containing
	spaces or shell metacharacters are passed through unchanged. A failure is
	reported as a warning rather than raised.
	"""
	import subprocess

	basename = genome.rsplit(".f", 1)[0]
	args = ["hisat2-build", "--quiet", genome, basename]
	print(" ".join(args))
	try:
		result = subprocess.run(args)
	except OSError as err:
		# e.g. hisat2-build is not on the PATH
		print(f"Warning: could not start hisat2-build ({err})")
		return
	if result.returncode != 0:
		print(f"Warning: hisat2-build exited with status {result.returncode}")
	

def map_to_genome(samples, genome, mappingFolder):
	"""Map every sample with hisat2 and pipe the alignments into ``samtools sort``.

	For each sample a shell pipeline is printed and executed that writes
	``<mappingFolder>/<sample>_summary.txt`` and ``<mappingFolder>/<sample>.bam``.
	The index prefix is the genome path cut at its last ".f". Reads the
	module-level ``hisat_options`` and ``threads``; samtools gets at most 4
	sorting threads.
	"""
	index_prefix = str(genome.rsplit(".f", 1)[0])

	for sample in samples:
		read_args = sample.get_hisat_string()
		out_prefix = f"{mappingFolder}/{sample.get_name()}"
		pipeline = (
			f"hisat2 {hisat_options} -x {index_prefix} --threads {threads} {read_args}"
			f" --new-summary --summary-file {out_prefix}_summary.txt"
			f" | samtools sort -@ {min(threads, 4)} -o {out_prefix}.bam"
		)

		print(pipeline)
		os.system(pipeline)


def get_GFF_file(folder="genome"):
	"""Looks for an annotation file in the folder and returns its path.

	Entries are examined in sorted order so the choice is reproducible when the
	folder holds several candidates. A file counts as annotation when its name
	ends in .gff, .gtf or .gff3 (any case). If none is found, the first
	compressed annotation (.gz, .bz2 or .xz on top of one of those extensions)
	is extracted and the path of the extracted file is returned. Returns None
	if nothing matches.
	"""
	# The leading dot keeps names that merely end in these letters from matching.
	gff_pattern = re.compile(r'\.(gff|gtf|gff3)$')
	compressed_pattern = re.compile(r'\.(gff|gtf|gff3)\.(bz2|gz|xz)$')

	entries = sorted(os.listdir(folder))

	for file in entries:
		if gff_pattern.search(file.lower()) is not None:
			return f"{folder}/{file}"

	#could not find a gff file. Maybe it is compressed and needs to be extracted
	for file in entries:
		if compressed_pattern.search(file.lower()) is not None:
			print("Found a compressed gff file. Extracting ...")
			extract_compressed(f"{folder}/{file}")
			# extract_compressed writes to the input path minus its last extension.
			return f"{folder}/{file}".rsplit(".", 1)[0]

	return None


def sanitize_path(path):
	"""Sanitizes paths: collapses repeated slashes and drops trailing slashes.

	Every run of two or more "/" becomes a single "/" (a single replace pass
	would leave "//" behind for "///"). Because trailing slashes are stripped,
	"./" becomes "." and "/" becomes the empty string; callers append
	"/<name>" to the result.
	"""
	return re.sub(r"/{2,}", "/", path).rstrip("/")


def run_stringtie(gff_file, mapping_folder="mapping", outfolder="abundance"):
	"""Run stringtie on every BAM/SAM file found in the mapping folder.

	For each alignment file ``<name>.bam`` / ``<name>.sam`` the results go to
	``<outfolder>/<name>/``: ``<name>_gene_expression.tsv`` (-A) and
	``<name>.gtf`` (-o). Files are processed in sorted order. Reads the
	module-level ``stringtie_options`` and ``threads``.
	"""
	mapping_folder = sanitize_path(mapping_folder)
	outfolder = sanitize_path(outfolder)

	# makedirs also creates missing parent folders and tolerates an existing one.
	os.makedirs(outfolder, exist_ok=True)

	for file in sorted(os.listdir(mapping_folder)):
		if not file.lower().endswith(("bam", "sam")):
			continue
		name = file.split(".bam")[0].split(".sam")[0]

		# Make sure the per-sample folder exists before stringtie writes into it.
		os.makedirs(f"{outfolder}/{name}", exist_ok=True)

		cmd = f"stringtie {stringtie_options} -e -B -p {threads} {mapping_folder}/{file} -G {gff_file}"
		cmd+= f" -A {outfolder}/{name}/{name}_gene_expression.tsv"
		cmd+= f" -o {outfolder}/{name}/{name}.gtf"

		print(cmd)
		os.system(cmd)
	



def genome_is_already_indexed(folder="genome"):
	"""Return True if the folder contains at least one HISAT2 index file.

	Both index flavours are recognised: names ending in "ht2" (small index) and
	"ht2l" (large index).
	"""
	return any(file.endswith(("ht2", "ht2l")) for file in os.listdir(folder))


def get_all_reads_in_folder(folder="reads"):
	"""Return every fastq file below the folder (searched recursively), sorted by path.

	The order glob returns depends on the file system; sorting makes the sample
	listing reproducible and keeps the two mates of multi-file samples in the
	same relative order.
	"""
	candidates = glob.glob(f"{folder}/**/*", recursive=True)
	return sorted(file for file in candidates if is_fastq_file(file))
	
	

def get_basename_of_read(read):
	"""Derive the sample name from the path of a read file.

	The directory part is dropped, then the name is cut at the first occurrence
	of each marker in turn: the extension start (".fa", ".fq"), the mate tags
	("_R1", "_R2", "_Read1", "_Read2", "_read1", "_read2") and finally the lane
	tag ("_L0").
	"""
	name = read.rpartition("/")[2]
	markers = (
		".fa", ".fq",
		"_R1", "_R2", "_Read1", "_Read2", "_read1", "_read2",
		"_L0",
	)
	for marker in markers:
		# Keep only what precedes the marker; unchanged if the marker is absent.
		name = name.partition(marker)[0]
	return name




def read_all_mapping_summaries(folder="./mapping"):
	'''Collects the HISAT2 summaries stored in a folder.

	Every file whose name ends in "summary.txt" is read completely. Returns two
	parallel lists: the sample names (the file name up to "_summary") and the
	full text of each summary file.'''
	names = []
	summaries = []

	summary_files = [entry for entry in os.listdir(folder) if entry.endswith("summary.txt")]
	for entry in summary_files:
		with open(f"{folder}/{entry}", "r") as handle:
			names.append(entry.split("_summary")[0])
			summaries.append(handle.read())

	return names, summaries






def extract_number_from_summary(some_text_before_colon, text, extractPercentage=False):
	"""Pull one value out of a HISAT2 summary text.

	some_text_before_colon: the label preceding the value, e.g. "Aligned 1 time".
	text: the complete summary text.
	extractPercentage: if True, return the first percentage following the label
		as a float (without the % sign); otherwise return the first integer on
		the label's line, before any " (" group, as an int.

	Raises ValueError if the label is missing or no number follows it.
	"""
	start = text.find(some_text_before_colon)
	if start < 0:
		raise ValueError(f"'{some_text_before_colon}' was not found in the mapping summary")
	remainder = text[start + len(some_text_before_colon):]

	if extractPercentage:
		# Digits with an optional decimal part directly in front of a % sign;
		# this also accepts integer percentages such as "5%".
		hit = re.search(r"(\d+(?:\.\d+)?)%", remainder)
		if hit is None:
			raise ValueError(f"No percentage found after '{some_text_before_colon}'")
		return float(hit[1])

	# Only the rest of the label's line, up to the bracketed percentage, is searched.
	value_part = remainder.split(" (")[0].split("\n")[0]
	hit = re.search(r"\d+", value_part)
	if hit is None:
		raise ValueError(f"No number found after '{some_text_before_colon}'")
	return int(hit[0])
	


def summarize_paired_mapping(folder="./mapping"):
	"""Combine the HISAT2 summaries of paired-end samples into one table.

	Reads every *_summary.txt in the folder and writes ``mapping_summary.tsv``
	(and, if possible, ``mapping_summary.xlsx``) into the same folder, one row
	per sample. Does nothing except print a warning when pandas is missing.
	"""
	try:
		import pandas as pd
	except ImportError:
		print("Warning: Could not create a summary of the mapping. Is pandas installed?")
		return

	# (label in the summary, whether a "<label> %" column accompanies the count)
	counted_labels = [
		("Total pairs", False),
		("Aligned concordantly or discordantly 0 time", True),
		("Aligned concordantly 1 time", True),
		("Aligned concordantly >1 times", True),
		("Aligned discordantly 1 time", True),
		("Total unpaired reads", False),
		("Aligned 0 time", True),
		("Aligned 1 time", True),
		("Aligned >1 times", True),
	]

	names, summaries = read_all_mapping_summaries(folder)

	# Insertion order of the dict determines the column order of the table.
	data = {"Sample": list(names)}
	for label, has_percentage in counted_labels:
		data[label] = [extract_number_from_summary(label, summary) for summary in summaries]
		if has_percentage:
			data[f"{label} %"] = [extract_number_from_summary(label, summary, True) for summary in summaries]
	# The overall rate is reported as a percentage only.
	data["Overall alignment rate"] = [
		extract_number_from_summary("Overall alignment rate", summary, True) for summary in summaries
	]

	df = pd.DataFrame(data)
	df.set_index("Sample", inplace=True)

	tsv_path = f"{folder}/mapping_summary.tsv"
	try:
		df.to_csv(tsv_path, sep="\t", lineterminator='\n') #lineterminator was specified due to a bug under MacOS
	except TypeError:
		# older versions of pandas only know the keyword as line_terminator
		df.to_csv(tsv_path, sep="\t", line_terminator='\n')

	try:
		df.to_excel(f"{folder}/mapping_summary.xlsx")
	except Exception as err:
		# The Excel copy is optional (it needs an extra writer package).
		print(f"Note: Skipped writing mapping_summary.xlsx ({err})")

	return


def summarize_unpaired_mapping(folder="./mapping"):
	"""Combine the HISAT2 summaries of single-end samples into one table.

	Reads every *_summary.txt in the folder and writes ``mapping_summary.tsv``
	(and, if possible, ``mapping_summary.xlsx``) into the same folder, one row
	per sample. Does nothing except print a warning when pandas is missing.
	"""
	try:
		import pandas as pd
	except ImportError:
		print("Warning: Could not create a summary of the mapping. Is pandas installed?")
		return

	# (label in the summary, whether a "<label> %" column accompanies the count)
	counted_labels = [
		("Total reads", False),
		("Aligned 0 time", True),
		("Aligned 1 time", True),
		("Aligned >1 times", True),
	]

	names, summaries = read_all_mapping_summaries(folder)

	# Insertion order of the dict determines the column order of the table.
	data = {"Sample": list(names)}
	for label, has_percentage in counted_labels:
		data[label] = [extract_number_from_summary(label, summary) for summary in summaries]
		if has_percentage:
			data[f"{label} %"] = [extract_number_from_summary(label, summary, True) for summary in summaries]
	# The overall rate is reported as a percentage only.
	data["Overall alignment rate"] = [
		extract_number_from_summary("Overall alignment rate", summary, True) for summary in summaries
	]

	df = pd.DataFrame(data)
	df.set_index("Sample", inplace=True)

	tsv_path = f"{folder}/mapping_summary.tsv"
	try:
		# Explicit line terminator, as in the paired summary.
		df.to_csv(tsv_path, sep="\t", lineterminator='\n')
	except TypeError:
		# older versions of pandas only know the keyword as line_terminator
		df.to_csv(tsv_path, sep="\t", line_terminator='\n')

	try:
		df.to_excel(f"{folder}/mapping_summary.xlsx")
	except Exception as err:
		# The Excel copy is optional (it needs an extra writer package).
		print(f"Note: Skipped writing mapping_summary.xlsx ({err})")

	return
	
	
	

def _read_abundance_table(abundance_file, pd):
	"""Load one stringtie gene abundance table, indexed by a synthetic per-gene key."""
	df = pd.read_table(abundance_file)
	#occasionally, stringtie has trouble identifying the gene ID or gene Name from the GFF files. Lets create an
	#unmistakable ID. Note: If missing, the Gene Name is sometimes assigned a dot, sometimes a dash.
	df['Gene Name'] = [str(n) if n != '.' and n != '-' else "." for n in df['Gene Name']]
	df['Gene ID'] = [str(n) if n != '.' and n != '-' else "." for n in df['Gene ID']]
	df["temp_id"] = (
		df["Gene ID"].astype('str') + df["Gene Name"].astype('str') + df['Reference'].astype('str')
		+ df["Start"].astype('str') + df['End'].astype('str')
	)
	df.set_index("temp_id", inplace=True)
	return df


def merge_abundances(abundance_folder):
	"""Merge the per-sample stringtie gene abundances into two tables.

	Every ``<abundance_folder>/<sample>/*_gene_expression.tsv`` is read (in
	sorted order). The first six columns are taken as the gene description; one
	FPKM and one TPM column per sample, named after the sample folder, are
	joined to it. The results are written to ``merged_FPKM.tsv`` and
	``merged_TPM.tsv`` (plus .xlsx copies if possible) in the abundance folder.

	Returns (merged_FPKM, merged_TPM), or (None, None) when pandas is missing
	or no abundance file was found.
	"""
	try:
		import pandas as pd
	except ImportError:
		print("Warning: Pandas is not installed. Cannot combine the FPKM's into a single file.")
		return None, None

	abundance_files = sorted(glob.glob(f"{abundance_folder}/*/*_gene_expression.tsv"))
	if not abundance_files:
		print(f"Warning: Found no gene abundance files in {abundance_folder}. Nothing to merge.")
		return None, None

	# Each file is parsed once; the sample name is the folder holding the file.
	tables = [(path.split("/")[-2], _read_abundance_table(path, pd)) for path in abundance_files]

	#stringtie result files omit entries that have no expression. Let's first gather
	#the gene information of every gene that occurs in any of the samples
	merged_gene_info = None
	for _, table in tables:
		gene_info = table[table.columns[0:6]] #keep only the gene info part
		if merged_gene_info is None:
			merged_gene_info = gene_info
		else:
			unseen = gene_info[~gene_info.index.isin(merged_gene_info.index)]
			merged_gene_info = pd.concat([merged_gene_info, unseen])

	merged_FPKM = merged_gene_info
	merged_TPM = merged_gene_info

	for sample_name, table in tables:
		merged_TPM = merged_TPM.join(table[["TPM"]], how='outer')
		merged_FPKM = merged_FPKM.join(table[["FPKM"]], how='outer')

		merged_TPM.rename(columns={"TPM": sample_name}, inplace=True)
		merged_FPKM.rename(columns={"FPKM": sample_name}, inplace=True)

	# The synthetic key was only needed for joining; the output is indexed by Gene ID.
	merged_TPM = merged_TPM.reset_index().drop("temp_id", axis=1).set_index("Gene ID")
	merged_FPKM = merged_FPKM.reset_index().drop("temp_id", axis=1).set_index("Gene ID")

	merged_FPKM.to_csv(f"{abundance_folder}/merged_FPKM.tsv", sep="\t")
	merged_TPM.to_csv(f"{abundance_folder}/merged_TPM.tsv", sep="\t")
	try:
		merged_FPKM.to_excel(f"{abundance_folder}/merged_FPKM.xlsx")
		merged_TPM.to_excel(f"{abundance_folder}/merged_TPM.xlsx")
	except Exception as err:
		# The Excel copies are optional (they need an extra writer package).
		print(f"Note: Skipped writing the merged .xlsx files ({err})")

	return merged_FPKM, merged_TPM

			

def prepare_reads(folder=".", arePaired=False):
	"""Group the read files found in the folder into sample objects.

	Files sharing the same basename (see get_basename_of_read) form one sample.
	Returns a list of PairedReadSample objects if arePaired is true, otherwise
	of UnpairedReadSample objects, in the order the basenames were first seen.
	"""
	files_by_sample = dict()
	for read_file in get_all_reads_in_folder(folder):
		files_by_sample.setdefault(get_basename_of_read(read_file), []).append(read_file)

	samples = []
	for basename, read_files in files_by_sample.items():
		sample = PairedReadSample(basename) if arePaired else UnpairedReadSample(basename)
		for read_file in read_files:
			sample.add_read(read_file)
		samples.append(sample)

	return samples
	

def is_fastq_file(file):
	"""Return True if the path looks like a (possibly compressed) fastq file.

	Accepted are names ending in .fq or .fastq, optionally followed by .gz or
	.bz2, in any case. Directories never count: the recursive searches in this
	script also yield folders, and a folder such as ``reads/fastq`` must not be
	passed on as a read file.
	"""
	# The leading dot keeps names that merely end in "fq"/"fastq" from matching.
	fastq_pattern = re.compile(r'\.(fq|fastq)(\.gz|\.bz2)?$')
	if fastq_pattern.search(file.lower()) is None:
		return False
	return not os.path.isdir(file)
		

		
def are_paired_reads(folder="."):
	"""Guess from the file names whether the folder holds paired-end reads.

	Returns True if, among the fastq files found recursively below the folder,
	at least one path carries a mate-1 marker ("_R1" or, in any case, "read1")
	and at least one other path carries only a mate-2 marker ("_R2" / "read2").
	"""
	def is_left(path):
		return "_R1" in path or "READ1" in path.upper()

	def is_right(path):
		# A path that already counts as mate 1 is never counted as mate 2.
		return (not is_left(path)) and ("_R2" in path or "READ2" in path.upper())

	fastq_files = [path for path in glob.glob(f"{folder}/**/*", recursive=True) if is_fastq_file(path)]
	saw_left = any(is_left(path) for path in fastq_files)
	saw_right = any(is_right(path) for path in fastq_files)
	return saw_left and saw_right



def check_if_all_prerequisites_installed():
	"""Verify that hisat2, stringtie and samtools can be run.

	Prints one line per missing program. A missing pandas only produces a
	warning. If any of the three programs is missing, the process exits with
	status 130.
	"""
	import sys

	print("Checking if the required software is installed...")

	# (executable name, name shown to the user)
	required_tools = (("hisat2", "HISAT2"), ("stringtie", "stringtie"), ("samtools", "samtools"))
	missing = [label for exe, label in required_tools if not can_run(exe)]
	for label in missing:
		print(f"{label} does not seem to be installed...")

	try:
		import pandas  # only imported to test its availability
	except ImportError:
		print("Warning: Pandas is not installed. Your results won't be summarized.")

	if missing:
		print("")
		print("Some required software was not found.")
		print("Make sure it is installed, and that you're")
		print("running this script in the same environment.")
		print("Quitting...")
		# sys.exit works even when the interactive exit() helper is unavailable.
		sys.exit(130)

	print("... no problems found.")
	print("")


# ---------------------------------------------------------------------------
# Script section: parse the command line, then run the pipeline step by step
# (prerequisite check -> sample detection -> index -> mapping -> stringtie ->
# merged abundance tables). Everything below runs at import time.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--reads_folder", help="The folder where the reads are located (default= ./reads)", default="./reads")
parser.add_argument("--genome_folder", help="The folder where the genome fasta and gff file is located (default= ./genome)", default="./genome")
parser.add_argument("--outfolder", help="The folder where the results will be written to (default=./)", default="./")
parser.add_argument("--skip_mapping", help="Skip mapping to genome", action="store_true")
parser.add_argument("--threads", help=f"The number of threads used (default={os.cpu_count()})", default=f"{os.cpu_count()}")
parser.add_argument("--hisat_options", nargs=1, help=f'May be used to pass additional options to HISAT2. Provided parameters must be surrounded by quotation marks, e.g. --hisat_options "--very-sensitive --no-spliced-alignment"',required=False, default=[])
parser.add_argument("--stringtie_options", nargs=1, help=f'May be used to pass additional options to Stringtie. Provided parameters must be surrounded by quotation marks, e.g. --stringtie_options "-m 150 -t"',required=False, default=[])
parser.add_argument("--yes", help="Answer all questions with 'yes'", action='store_true')
parser.add_argument("--version", help="Prints the current version'", action='store_true')

args = parser.parse_args()

		
# The version banner is always printed, also during a normal run.
print(version_string)
print("")

if args.version:
	exit(0) #exit immediately after printing the version

# Normalised folder arguments (trailing slashes removed, see sanitize_path).
folder = sanitize_path(args.outfolder)

reads_folder = sanitize_path(args.reads_folder)

genome_folder = sanitize_path(args.genome_folder)

# --threads arrives as a string (its default is a string as well).
# NOTE: threads, hisat_options and stringtie_options are module-level names that
# map_to_genome and run_stringtie read as globals.
threads = int(args.threads)

# nargs=1 yields a one-element list; the empty default list means "no extra options".
hisat_options = "" if len(args.hisat_options) == 0 else args.hisat_options[0]
stringtie_options = "" if len(args.stringtie_options) == 0 else args.stringtie_options[0]


check_if_all_prerequisites_installed()

all_read_files = [f for f in glob.glob(f"{reads_folder}/**/*", recursive=True) if is_fastq_file(f)]

if len(all_read_files)==0:
	print(f"Found no reads in reads folder ({reads_folder})")
	print("Quitting...")
	exit(130)

# Guess the library layout from the file names and let the user confirm it.
# Any answer containing the letter "n" (in any case) is treated as "no".
PAIRED = are_paired_reads(reads_folder)
if PAIRED:
	print("I found PAIRED reads in the folder. ")
	paired_correct_answer = "yes" if args.yes else input("Is this correct? (yes/no) ")
	if "N" in paired_correct_answer.upper():
		# The user disagrees: treat the reads as unpaired instead.
		PAIRED = not PAIRED
else:
	print("I found UNPAIRED reads in the folder. ")
	paired_correct_answer = "yes" if args.yes else input("Is this correct? (yes/no) ")
	if "N" in paired_correct_answer.upper():
		print("Please rename the file so that they contain _R1 and _R2 respectively (or _Read1 and _Read2). Otherwise I won't able to distinguish them.")
		print("Exiting")
		exit(0)	
	
if PAIRED:
	print("Running pipeline with PAIRED reads")
else:
	print("Running pipeline with UNPAIRED reads")

print()

# Group the read files into samples and show them for confirmation.
samples = prepare_reads(reads_folder, PAIRED) 
print("Found the following samples:")
for sample in samples:
	print(sample)

print()

samples_correct = "yes" if args.yes else input("Is this correct? (yes/no) ")
if "n" in samples_correct.lower():
		print("Exiting")
		exit(130)	


print()
# The annotation is required by the stringtie step, so its absence is fatal.
gff_file = get_GFF_file(genome_folder)
if gff_file is None:
	print("Could not find the GFF file for the genome. Are you sure it's in the genome folder?")
	print("Quitting..")
	exit(130)


genome = look_for_genome(genome_folder)
if genome is not None:
	if not args.skip_mapping:
		if not genome_is_already_indexed(genome_folder):
			print()
			print("Building the genome index:")
			
			build_index(genome)
		else:
			print("I found a genome index in the genome folder. ")
			# The question asks whether to SKIP indexing: an answer containing "n"
			# rebuilds the index, while --yes keeps the existing one.
			reindex =  "yes" if args.yes else input("Do you want to skip building the index? (yes/no) ")
			if "n" in reindex.lower():
				print()
				print("Building the genome index.")
				build_index(genome)

else:
	print("Could not find a genome file in the genomes folder. Exiting")
	exit(130)
	
	
print()
# NOTE: this message is printed even when --skip_mapping is given.
print("Mapping reads using HISAT2")

# Created unconditionally because run_stringtie lists this folder.
os.makedirs(f"{folder}/mapping", exist_ok=True)

if not args.skip_mapping:
	map_to_genome(samples, genome, f"{folder}/mapping")
	if PAIRED:
		summarize_paired_mapping(f"{folder}/mapping")
	else:
		summarize_unpaired_mapping(f"{folder}/mapping")


print("Estimating expression abundance with stringtie")
run_stringtie(gff_file, mapping_folder=f"{folder}/mapping", outfolder=f"{folder}/abundance")	
print()

# Combine the per-sample stringtie tables into merged FPKM / TPM files.
merge_abundances(f"{folder}/abundance")


	
	
	
	
	
	
