#!/opt/mambaforge/envs/bioconda/conda-bld/concoct_1758281759540/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeh/bin/python
from __future__ import print_function

import sys
import logging
import vbgmm
import numpy as np


from concoct.output import Output
from concoct.parser import arguments
from concoct.input import load_data
from concoct.transform import perform_pca


def main(args):
    """Run the CONCOCT workflow: load, join, PCA-transform, cluster, write.

    Args:
        args: Parsed command-line namespace. It is passed as a whole to
            ``Output`` and ``load_data``; this function itself reads
            ``basename``, ``pca_components``, ``seed``, ``no_original_data``,
            ``length_threshold``, ``clusters``, ``threads`` and
            ``iterations``. ``args.pca_components`` is overwritten in place
            when it holds the sentinel string ``"All"``.

    Returns:
        None. Results are emitted through the ``Output`` writers.

    Raises:
        SystemExit: with status -1 when fewer than two contigs were loaded.
    """
    # Initialize output handling
    Output(args.basename, args)

    composition, cov, cov_range = load_data(args)

    # With zero or one contig left there is nothing to cluster.
    if len(composition) < 2:
        logging.error('Not enough contigs pass the threshold filter. Exiting!')
        sys.exit(-1)

    if cov is not None:
        # Restrict coverage to the columns between the two labels in cov_range
        # and combine it with the composition table via an inner join
        # (assumes pandas-style DataFrames -- TODO confirm against load_data).
        joined = composition.join(
            cov.loc[:, cov_range[0]:cov_range[1]], how="inner"
        )
    else:
        joined = composition

    # "All" is a sentinel meaning: keep as many components as there are
    # columns in the joined table.
    if args.pca_components == "All":
        args.pca_components = joined.shape[1]

    # Transform the joined data; perform_pca hands back the transformed
    # matrix together with the fitted PCA object.
    transform_filter, pca = perform_pca(
        joined,
        args.pca_components,
        args.seed
        )

    logging.info(
        'Performed PCA, resulted in %s dimensions', transform_filter.shape[1]
    )

    if not args.no_original_data:
        Output.write_original_data(
            joined,
            args.length_threshold
            )

    Output.write_pca(
        transform_filter,
        args.length_threshold,
        joined.index,
        )

    Output.write_pca_components(
        pca.components_,
        args.length_threshold
        )

    logging.info('PCA transformed data.')

    # NOTE(review): the values logged here are not exactly the arguments
    # handed to vbgmm.fit below (the seed is passed but not logged, while
    # CONCOCT_PATH and length_threshold are logged but not passed).
    logging.info(
        'Will call vbgmm with parameters: %s, %s, %s, %s, %s',
        Output.CONCOCT_PATH,
        args.clusters,
        args.length_threshold,
        args.threads,
        args.iterations,
    )

    # vbgmm.fit receives a fresh C-ordered copy of the transformed data; its
    # return value is the assignment written out below, so no result buffer
    # needs to be pre-allocated here.
    assign = vbgmm.fit(
        np.copy(transform_filter, order='C'),
        int(args.clusters),
        int(args.seed),
        int(args.threads),
        int(args.iterations),
    )

    Output.write_assign(
        assign,
        args.length_threshold,
        joined.index,
        )

    logging.info("CONCOCT Finished")


if __name__ == "__main__":
    # Script entry point: parse the command line, derive the PCA setting and
    # run the workflow.
    args = arguments()

    # Translate the percentage option into what main() expects: the sentinel
    # "All" for exactly 100, otherwise a fraction (percentage / 100).
    if args.total_percentage_pca == 100:
        args.pca_components = "All"
    else:
        args.pca_components = args.total_percentage_pca/100.0

    # NOTE(review): this warning is emitted on the root logger before main()
    # initialises Output. If Output configures logging via
    # logging.basicConfig, this earlier call may already have installed a
    # default handler -- confirm the log still ends up where intended.
    if args.threads == 1:
        logging.warning("CONCOCT is running in single threaded mode. Please, consider adjusting the --threads parameter.")

    # main() has no return statement, so there is no result to keep.
    main(args)

    logging.info("CONCOCT Finished, the log shows how it went.")
