#!/opt/mambaforge/envs/bioconda/conda-bld/concoct_1758281759540/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh/bin/python
from __future__ import division
DESC="""A script that iterates over concoct results and reruns the concoct algorithm 
for clusters where the median SCG presence is at least 2."""


import vbgmm
import numpy as np 
import argparse
import pandas as p

from sklearn.decomposition import PCA

def main(args):
    """Re-run the clustering inside every cluster whose median SCG count exceeds 1.

    Reads three CSV files, refines the selected clusters and writes the new
    assignment to ``clustering_refine.csv`` in the current working directory.

    Args:
        args: Namespace-like object providing the attributes
            ``cluster_file`` (CSV without header: id, cluster label),
            ``original_data`` (CSV with header: id, feature columns),
            ``scg_file`` (CSV with header: one row per cluster),
            ``expansion_factor``, ``seed`` and ``threads``.

    Notes:
        * Row ``k`` of the SCG file is used for cluster label ``k`` and the
          rows of the cluster file are matched to the rows of the data file
          purely by position -- assumes both files are ordered consistently;
          TODO confirm against the producers of these files.
        * Within a refined cluster, sub-cluster 0 keeps the original label;
          every further sub-cluster receives a fresh label appended after the
          highest label handed out so far.
    """
    clusters = p.read_csv(args.cluster_file, header=None, index_col=0)
    original_data = p.read_csv(args.original_data, header=0, index_col=0)
    scg_freq = p.read_csv(args.scg_file, header=0, index_col=0)

    # ``DataFrame.values`` is a property (calling it raises TypeError) and
    # ``DataFrame.as_matrix()`` no longer exists in pandas >= 1.0, so the
    # property is used for all three conversions.
    original_data_matrix = original_data.values
    scg_freq_matrix = scg_freq.values
    clusters_matrix = clusters.values

    # Median SCG count per row of the SCG table (one row per cluster).
    med_scgs = np.median(scg_freq_matrix, axis=1)

    # Length of the bincount is (highest cluster label + 1).
    cluster_freq = np.bincount(clusters_matrix[:, 0])
    K = cluster_freq.shape[0]

    new_clusters_matrix = np.copy(clusters_matrix, order='C')

    # Highest label assigned so far; new labels are allocated above it.
    last_label = K - 1
    for k in range(K):
        if not med_scgs[k] > 1:
            continue

        member_mask = clusters_matrix[:, 0] == k
        index_k = np.where(member_mask)[0]
        slice_k = original_data_matrix[member_mask, :]

        # Keep as many principal components as needed to explain 90 % of the
        # variance of this cluster's rows.
        pca_object = PCA(n_components=0.90).fit(slice_k)
        transform_k = pca_object.transform(slice_k)

        NK = med_scgs[k] * args.expansion_factor
        print("Run CONCOCT for {0} with {1} clusters using {2} threads".format(
            k, NK, args.threads))
        # vbgmm.fit is expected to return one sub-cluster label per input row
        # (the result is used below to index ``index_k``).
        assigns = vbgmm.fit(np.copy(transform_k, order='C'), int(NK),
                            args.seed, args.threads)
        n_sub = np.max(assigns) + 1

        # Sub-cluster 0 keeps label k; the others get last_label + a.
        for a in range(1, n_sub):
            index_a = index_k[assigns == a]
            new_clusters_matrix[index_a] = last_label + a

        last_label = last_label + n_sub - 1

    new_assign_df = p.DataFrame(new_clusters_matrix, index=original_data.index)
    new_assign_df.to_csv("clustering_refine.csv")

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description=DESC)

    # Required positional inputs, registered in command-line order.
    for positional_name, help_text in (
        ("cluster_file", "string specifying cluster file"),
        ("original_data", "string original but transformed data file"),
        ("scg_file", "string specifying scg frequency file"),
    ):
        cli.add_argument(positional_name, help=help_text)

    # Optional integer settings: (short flag, long flag, default, help text).
    for short_flag, long_flag, default_value, help_text in (
        ('-e', '--expansion_factor', 2,
         "number of clusters to expand by"),
        ('-s', '--seed', 11,
         "The seed used for algorithm result reproducibility."),
        ('-t', '--threads', 1,
         "number of threads to use defaults to one"),
    ):
        cli.add_argument(short_flag, long_flag, default=default_value,
                         type=int, help=help_text)

    args = cli.parse_args()

    main(args)

