#!/opt/mambaforge/envs/bioconda/conda-bld/skder_1757269308801/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

### Program: granet
### Author: Rauf Salamzade
### Kalan Lab
### UW Madison, Department of Medical Microbiology and Immunology

# BSD 3-Clause License
#
# Copyright (c) 2023, Rauf Salamzade
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import os
import sys
from time import sleep
from skDER import util
import argparse
import traceback
import igraph as ig
import matplotlib.pyplot as plt
import random

# Version reported by the -v/--version flag, as returned by skDER's util module.
version = util.get_version()

# Set of method names; nothing in the visible part of this script reads it.
VALID_METHODS = {'phispy', 'genomad'}

def create_parser():
	"""
	Define granet's command-line options and parse them from sys.argv.

	Returns:
		The argparse.Namespace produced by parse_args(). Option names map to
		attributes with dashes replaced by underscores (e.g. --skani-results
		becomes skani_results).
	"""
	# Shown verbatim in --help because RawTextHelpFormatter keeps the layout as typed.
	description_text = """
	Program: granet
	Author: Rauf Salamzade
	Affiliation: Kalan Lab, UW Madison, McMaster University
								  
	granet: graph network - plots a network with representative genomes marked in red.
	
	E.g. granet -s greedy_skder_results/Skani_Triangle_Edge_Output.txt \\
	            -d greedy_skder_results/Dereplicated_Representative_Genomes/ \\
	            -ns 10 -w 5 -ew 2 -o greedy_skder_granet.png
	"""
	cli = argparse.ArgumentParser(
		description=description_text,
		formatter_class=argparse.RawTextHelpFormatter)

	# Required inputs.
	cli.add_argument('-s', '--skani-results', required=True, help="Path to skani clustering results - for CiDDER this can be generated via requesting -ns.")
	cli.add_argument('-d', '--dereplicated_results', required=True, help="Path to Dereplicated_Representative_Genomes/ folder in skDER or CiDDER results folder.")

	# Edge filtering thresholds.
	cli.add_argument('-i', '--ani-cutoff', type=float, required=False, default=99.0, help="The ANI cutoff for edges to be shown in the network [Default is 99.0].")
	cli.add_argument('-f', '--af-cutoff', type=float, required=False, default=90.0, help="The AF cutoff for edges to be shown in the network [Default is 90.0].")

	# Output location and plot appearance.
	cli.add_argument('-o', '--output-plot', required=False, default="dereplication_network_view.png", help="Path to output plot in PNG format [Default is \"dereplication_network_view.png\"].")
	cli.add_argument('-w', '--width', type=int, required=False, default=50, help="Width of the plotting space [Default is 50].")
	cli.add_argument('-wm', '--weak-mode', action='store_true', required=False, default=False, help="Apply weak connection mode - to space things out more if too much overlap - not compatible with layout.")
	cli.add_argument('-l', '--layout', required=False, default="fruchterman_reingold", help="The layout mode to use - see: https://python.igraph.org/en/stable/tutorial.html#layout-algorithms [Default is \"fruchterman_reingold\"].")
	cli.add_argument('-ns', '--node-size', type=float, required=False, default=3, help="The node size [Default is 3].")
	cli.add_argument('-ew', '--edge-width', type=float, required=False, default=0.5, help="The edge width [Default is 0.5].")
	cli.add_argument('-rs', '--random-seed', type=int, required=False, default=1234, help="Random seed to control layout reproducability [Default is 1234].")

	# Miscellaneous.
	cli.add_argument('-v', '--version', action='store_true', required=False, default=False, help="Report version of granet.")

	return cli.parse_args()

def granet():
	"""
	Plot a genome similarity network with representative genomes marked in red.

	Workflow:
	  1. If -v/--version is on the command line, write the version to stderr
	     and exit before the required arguments are enforced.
	  2. Parse the remaining options with create_parser().
	  3. Treat every file name in the dereplicated-results directory as a
	     representative genome.
	  4. Read the tab-separated skani results file (first line is a header;
	     columns 0/1 are the two genome paths, 2 is ANI, 3 and 4 are the two
	     AF values). Genomes are identified by the basename of their path.
	  5. Build an igraph graph with one vertex per genome and one edge per
	     distinct genome pair with ANI >= the ANI cutoff and at least one AF
	     >= the AF cutoff.
	  6. Draw the graph (representatives red, others black) and save it to
	     the requested output path.

	Raises:
		SystemExit: always; status 0 after the version report or after the
			plot has been saved.
	"""
	# Handled before argparse so that -v works without the required -s/-d options.
	if any(flag in sys.argv[1:] for flag in ('-v', '--version')):
		sys.stderr.write('Version of granet being used is: ' + str(version) + '\n')
		sys.exit(0)

	# Parse arguments
	myargs = create_parser()

	skani_results_file = os.path.abspath(myargs.skani_results)
	dereplicated_results_dir = os.path.abspath(myargs.dereplicated_results) + '/'
	output_plot = os.path.abspath(myargs.output_plot)
	ani_cutoff = myargs.ani_cutoff
	af_cutoff = myargs.af_cutoff
	width = myargs.width
	weak_mode_flag = myargs.weak_mode
	node_size_val = myargs.node_size
	edge_width_val = myargs.edge_width
	layout_setting = myargs.layout
	random_seed = myargs.random_seed

	# Every entry in the dereplication folder counts as a representative genome.
	reps = set(os.listdir(dereplicated_results_dir))

	# Single pass over the skani results: collect every genome name and remember
	# the pairs that pass the cutoffs. Indices are assigned afterwards because
	# they depend on the sorted order of the complete genome set.
	all_genomes = set()
	passing_pairs = []
	with open(skani_results_file) as osrf:
		for i, line in enumerate(osrf):
			if i == 0:
				continue  # header line
			line = line.strip()
			if not line:
				continue  # tolerate blank/trailing lines instead of raising IndexError
			ls = line.split('\t')
			ref = ls[0].split('/')[-1]
			que = ls[1].split('/')[-1]
			all_genomes.add(ref)
			all_genomes.add(que)
			if ref == que:
				continue  # self-comparisons contribute a vertex but no edge
			ani = float(ls[2])
			af1 = float(ls[3])
			af2 = float(ls[4])
			if ani >= ani_cutoff and (af1 >= af_cutoff or af2 >= af_cutoff):
				passing_pairs.append((ref, que))

	sorted_genomes = sorted(all_genomes)
	genome_to_index = {genome: idx for idx, genome in enumerate(sorted_genomes)}
	genome_rep_list = [genome in reps for genome in sorted_genomes]
	edges = [(genome_to_index[ref], genome_to_index[que]) for ref, que in passing_pairs]

	# The vertex count is the number of genomes, not the largest index. Passing
	# the largest index left the graph one vertex short of genome_rep_list
	# whenever the last genome (in sorted order) had no edge passing the cutoffs.
	n_vertices = len(sorted_genomes)

	graph = ig.Graph(n_vertices, edges)

	# NOTE(review): python-igraph is documented to draw its randomness from
	# Python's `random` module by default, which is why seeding it here is
	# expected to make the layout reproducible - confirm if the RNG is replaced.
	random.seed(random_seed)

	graph.vs["is_rep"] = genome_rep_list
	vertex_colors = ["red" if is_rep else "black" for is_rep in graph.vs["is_rep"]]

	fig, ax = plt.subplots(figsize=(width, width))

	# Styling shared by both plotting modes.
	plot_kwargs = dict(
		target=ax,
		vertex_size=node_size_val,
		vertex_color=vertex_colors,
		vertex_frame_width=0,
		edge_width=edge_width_val,
		edge_color="#aba9a9")

	if weak_mode_flag:
		# Weak mode plots the connected-components clustering and does not
		# pass the user's --layout choice.
		components = graph.connected_components(mode='weak')
		ig.plot(components, **plot_kwargs)
	else:
		ig.plot(graph, layout=layout_setting, **plot_kwargs)

	# Save the graph as an image file
	fig.savefig(output_plot)

	# success!
	sys.exit(0)

# Run the command-line program only when this file is executed as a script.
if __name__ == '__main__':
	granet()
