#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2025 EMBL - European Bioinformatics Institute
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from collections import Counter
import csv
import logging

# Maps each taxonomic rank name to the prefix that is prepended to the taxon
# name at that rank. The insertion order (superkingdom first, species last)
# matters: convert_to_official_names iterates this dict positionally to build
# the output lineage from the highest rank down to the lowest one found.
RANK_PREFIXES = {
    "superkingdom": "sk__",
    "kingdom": "k__",
    "phylum": "p__",
    "class": "c__",
    "order": "o__",
    "family": "f__",
    "genus": "g__",
    "species": "s__",
}

# Configure the root logger once at import time: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] - %(levelname)s - %(message)s")


def import_nodes(nodes_dmp):
    """Build a taxid -> rank lookup from a pipe-delimited nodes dump file.

    Every line is split on ``|`` and each part is stripped of surrounding
    whitespace. The first part is used as the taxid and the third as its rank;
    both are kept as strings.

    Args:
        nodes_dmp: Path to the nodes dump file.

    Returns:
        dict mapping taxid (str) to rank (str).

    Raises:
        ValueError: If a line does not split into exactly 14 parts.
    """
    logging.info(f"Loading file {nodes_dmp}")
    expected_parts = 14
    rank_by_taxid = {}

    with open(nodes_dmp) as handle:
        for record in handle:
            columns = list(map(str.strip, record.split("|")))
            if len(columns) == expected_parts:
                rank_by_taxid[columns[0]] = columns[2]
                continue
            # Any other column count means the file is not in the expected layout.
            raise ValueError(f"Unexpected number of columns in line: {record}")

    return rank_by_taxid


def import_names(names_dmp):
    """Build a taxid -> scientific name lookup from a pipe-delimited names dump file.

    Every line is split on ``|`` and each part is stripped of surrounding
    whitespace. Only lines whose fourth part equals ``"scientific name"`` are
    kept; for those, the first part is the taxid and the second is the name.

    Args:
        names_dmp: Path to the names dump file.

    Returns:
        dict mapping taxid (str) to scientific name (str).

    Raises:
        ValueError: If a line does not split into exactly 5 parts.
    """
    logging.info(f"Loading file {names_dmp}")
    name_by_taxid = {}

    with open(names_dmp, newline="") as handle:
        for record in handle:
            columns = tuple(piece.strip() for piece in record.split("|"))
            if len(columns) != 5:
                raise ValueError(f"Unexpected number of columns in line: {record}")
            taxid, name, _, name_class, _ = columns
            if name_class != "scientific name":
                # Synonyms, common names, etc. are ignored.
                continue
            name_by_taxid[taxid] = name

    return name_by_taxid


def convert_to_official_names(lineage, taxid2rank, taxid2name):
    """Translate a taxid lineage into rank-prefixed names ordered by RANK_PREFIXES.

    Trailing ``*`` characters are stripped from each taxid before lookup. For
    every rank in RANK_PREFIXES (in its insertion order) the first taxid in
    ``lineage`` having that rank contributes ``prefix + name``; ranks that do
    not occur in the lineage contribute the bare prefix. The result is cut
    after the last rank that was actually found, so it is empty when none of
    the ranks occur.

    Args:
        lineage: Sequence of taxid strings, optionally suffixed with ``*``.
        taxid2rank: Mapping of taxid to rank name.
        taxid2name: Mapping of taxid to taxon name.

    Returns:
        list of prefixed names, one per rank down to the lowest rank found.

    Raises:
        KeyError: If a taxid is missing from ``taxid2rank``, or a selected
            taxid is missing from ``taxid2name``.
    """
    clean_taxids = [raw.rstrip("*") for raw in lineage]
    # Resolve every rank up front so unknown taxids fail before any name lookup.
    ranks_in_lineage = [taxid2rank[taxid] for taxid in clean_taxids]

    labels = []
    keep = 0  # number of leading labels to return
    for position, (rank, prefix) in enumerate(RANK_PREFIXES.items(), start=1):
        try:
            hit = ranks_in_lineage.index(rank)
        except ValueError:
            # Rank absent from this lineage: placeholder is the bare prefix.
            labels.append(prefix)
            continue
        labels.append(prefix + taxid2name[clean_taxids[hit]])
        keep = position

    return labels[:keep]


def main():
    """Convert a CAT_pack contigs classification TSV into Krona ktImportText input.

    Reads the NCBI taxonomy nodes/names dumps, then walks the classification
    file (first row is treated as a header). For each data row, column 2 holds
    the classification status and column 4 holds the ``;``-separated taxid
    lineage. Rows marked ``"no taxid assigned"``, or whose lineage contains
    none of the ranks in RANK_PREFIXES, are counted as ``"unclassified"``.
    The output file has one line per distinct lineage, most frequent first,
    formatted as ``count<TAB>rank1<TAB>rank2...``.

    All four command-line options are required; argparse reports a usage
    error if any is missing.
    """
    parser = argparse.ArgumentParser(
        description="Process TSV classification generated by CAT_pack contigs and write input file for Krona ktImportText"
    )
    parser.add_argument(
        "-i", "--input", required=True, help="Path to the input TSV file from CAT_pack contigs"
    )
    parser.add_argument("-o", "--output", required=True, help="Name of the output Krona TXT file")
    # The help strings for the two taxonomy dumps were previously swapped.
    parser.add_argument(
        "-n", "--names_dmp", required=True, help="Path to the names.dmp file from NCBI taxonomy"
    )
    parser.add_argument(
        "-r", "--nodes_dmp", required=True, help="Path to the nodes.dmp file from NCBI taxonomy"
    )
    args = parser.parse_args()

    taxid2rank = import_nodes(args.nodes_dmp)
    taxid2name = import_names(args.names_dmp)

    logging.info(f"Begin parsing of CAT_pack classification file {args.input}")
    lineage_counter = Counter()
    # newline="" is what the csv module expects from the underlying file object.
    with open(args.input, newline="") as infile:
        reader = csv.reader(infile, delimiter="\t")
        next(reader, None)  # Skip the header row; an empty file yields no rows
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. trailing newline).
                continue
            if row[1] == "no taxid assigned":
                lineage = "unclassified"
            else:
                taxid_lineage = row[3].split(";")
                names_lineage = convert_to_official_names(taxid_lineage, taxid2rank, taxid2name)
                lineage = "\t".join(names_lineage) if names_lineage else "unclassified"
            lineage_counter[lineage] += 1

    logging.info(f"Writing output to {args.output}")
    with open(args.output, "w") as outfile:
        for lineage, count in lineage_counter.most_common():
            outfile.write(f"{count}\t{lineage}\n")

    logging.info("Done")


if __name__ == "__main__":
    main()
