#!/usr/bin/env python3

import array
import argparse
import bz2
import concurrent.futures
import collections
import csv
import ctypes
import ftplib
import functools
import glob
import gzip
import hashlib
import http.client
import inspect
import io
import itertools
import json
import logging
import logging.handlers
import lzma
import math
import multiprocessing
import os
import pty
import random
import re
import shutil
import signal
import subprocess
import sys
import tarfile
import time
import tempfile
import threading
import traceback
import urllib
import urllib.error
import urllib.parse
import urllib.request
import zipfile
import zlib

# Module-wide logger handle. It starts out unset and is assigned by
# `wrap_with_globals()` / `copy_globals()` below.
LOG = None
# Pathname of this script. `find_kraken2_binary()` uses it to look for
# binaries that sit next to the wrapper.
SCRIPT_PATHNAME = None
# this is a placeholder value. The real version will
# be substituted once after calling `install_kraken.sh`.
SCRIPT_VERSION = "2.17.1"

# Host names of the remote services this script refers to.
NCBI_REST_API = "api.ncbi.nlm.nih.gov"
NCBI_SERVER = "ftp.ncbi.nlm.nih.gov"
GREENGENES_SERVER = "greengenes.microbio.me"
SILVA_SERVER = "ftp.arb-silva.de"
# GTDB_SERVER = "data.gtdb.ecogenomic.org"
GTDB_SERVER = "data.ace.uq.edu.au"
# 2**32 - 1, i.e. the largest value of an unsigned 32-bit integer.
AMBIGUOUS_TAXID = 2 ** 32 - 1

# Maps wrapper option names (attribute names of the parsed options object)
# to the single-letter flags of the underlying binaries.  Consumed by
# `wrapper_args_to_binary_args()`, which only forwards a flag when the
# target binary advertises it.
# NOTE(review): "max_db_size" and "memory_mapping" both map to "-M";
# presumably they are consumed by different binaries -- confirm.
WRAPPER_ARGS_TO_BIN_ARGS = {
    "block_size": "-B",
    "classified_out": "-C",
    "confidence": "-T",
    "fast_build": "-F",
    "interleaved": "-S",
    "kmer_len": "-k",
    "max_db_size": "-M",
    "memory_mapping": "-M",
    "minimizer_len": "-l",
    "minimum_bits_for_taxid": "-r",
    "minimum_base_quality": "-Q",
    "minimum_hit_groups": "-g",
    "output": "-O",
    "paired": "-P",
    "protein": "-X",
    "quick": "-q",
    "report": "-R",
    "report_minimizer_data": "-K",
    "report_zero_counts": "-z",
    "skip_counts": "-s",
    "sub_block_size": "-b",
    "threads": "-p",
    "unclassified_out": "-U",
    "use_mpa_style": "-m",
    "use_names": "-n",
    "use_daemon": "-D",
}


class FTP:
    """Small wrapper around ``ftplib.FTP`` with resumable binary downloads.

    The wrapper logs in anonymously, switches the connection to binary
    mode and remembers the current remote directory so that the session
    can be re-established after a transient failure.
    """

    # Socket timeout (seconds) applied to every connection, including
    # the ones opened by `reconnect()`.
    _TIMEOUT = 600

    def __init__(self, server):
        """Connect to *server*, log in anonymously and select binary mode."""
        self.connect(server)
        self.pwd = "/"
        self.server = server

    def _progress_bar(self, f, remote_size):
        """Return a ``retrbinary`` callback that writes to *f* and logs progress.

        The progress bar is seeded with the current position of *f* so a
        resumed download does not start from zero.
        """
        pb = ProgressBar(remote_size, f.tell())

        def inner(block):
            # `write` may accept fewer bytes than offered; loop until the
            # whole block has been handed to the file object.
            written = 0
            while written < len(block):
                written += f.write(block[written:])
            size_on_disk = f.tell()
            pb.progress(size_on_disk)
            LOG.debug(
                "{:s} {: >10s}\r".format(
                    pb.get_bar(), format_bytes(size_on_disk)
                )
            )

        return inner

    def download(self, remote_dir, filepaths):
        """Download *filepaths* (a path or a list of paths) from *remote_dir*.

        Files are stored locally under the same relative path.  A file
        whose local size equals the remote size is skipped, a smaller
        local file is resumed, and a larger one is downloaded again from
        scratch.  FTP errors during a transfer trigger a reconnect and a
        retry from the current file position; a keyboard interrupt
        closes the connection and exits with status 1.
        """
        if isinstance(filepaths, str):
            filepaths = [filepaths]
        number_of_files = len(filepaths)
        self.cwd(remote_dir)
        for index, filepath in enumerate(filepaths):
            mode = "ab"
            local_size = 0
            remote_size = self.size(filepath)
            if os.path.exists(filepath):
                local_size = os.stat(filepath).st_size
            else:
                if os.path.basename(filepath) != filepath:
                    os.makedirs(os.path.dirname(filepath), exist_ok=True)
            if local_size == remote_size:
                LOG.info(
                    "Already downloaded {:s}\n".format(get_abs_path(filepath))
                )
                continue
            if local_size > remote_size:
                # The local copy cannot be a prefix of the remote file;
                # truncate it and start over.
                mode = "wb"
            url_components = urllib.parse.SplitResult(
                "ftp", self.server, os.path.join(remote_dir, filepath), "", ""
            )
            url = urllib.parse.urlunsplit(url_components)
            if number_of_files == 1:
                LOG.info("Downloading {:s}\n".format(url))
            else:
                LOG.info(
                    "[{:d}/{:d}] Downloading {:s}\n".format(
                        index + 1, number_of_files, url
                    )
                )
            with open(filepath, mode) as f:
                while True:
                    try:
                        cb = self._progress_bar(f, remote_size)
                        # `rest` resumes the transfer at the number of
                        # bytes already present on disk.
                        self.ftp.retrbinary(
                            "RETR " + filepath, cb, rest=f.tell()
                        )
                        break
                    except KeyboardInterrupt:
                        f.flush()
                        self.close()
                        sys.exit(1)
                    except ftplib.all_errors:
                        f.flush()
                        self.reconnect()
                        self.cwd(remote_dir)
                        continue
            absolute_path = get_abs_path(filepath)
            local_filename, local_dirname = os.path.basename(
                absolute_path
            ), os.path.dirname(absolute_path)
            clear_console_line()
            LOG.info(
                "Saved {:s} to {:s}\n".format(local_filename, local_dirname)
            )

    def cwd(self, remote_pathname):
        """Change the remote directory and remember it for reconnects."""
        self.ftp.cwd(remote_pathname)
        self.pwd = remote_pathname

    def size(self, filepath):
        """Return the remote size of *filepath*, retrying on temporary errors."""
        size = 0
        while True:
            try:
                size = self.ftp.size(filepath)
                break
            except ftplib.error_temp:
                self.reconnect()
                continue
        return size

    def exists(self, filepath):
        """Return True if *filepath* exists on the server, False if it does not.

        A permanent error is read as "missing" when the reply carries
        code 550 (file unavailable) or the usual "No such file or
        directory" text.  Any other permanent error is re-raised.
        """
        try:
            self.size(filepath)
        except ftplib.error_perm as e:
            reply = str(e)
            # `str.find` returns -1 (truthy) when the text is absent, so
            # a substring test is required here rather than its result.
            if reply.startswith("550") or "No such file or directory" in reply:
                return False
            raise
        return True

    def connect(self, server):
        """Open a fresh anonymous, binary-mode connection to *server*."""
        self.ftp = ftplib.FTP(server, timeout=self._TIMEOUT)
        self.ftp.login()
        self.ftp.sendcmd("TYPE I")

    def reconnect(self):
        """Drop the current connection and restore it in the same directory."""
        host = self.ftp.host
        self.ftp.close()
        self.connect(host)
        self.ftp.cwd(self.pwd)

    def host(self):
        """Return the host name of the underlying connection."""
        return self.ftp.host

    def close(self):
        """Politely end the session with ``QUIT``."""
        self.ftp.quit()


class ProgressBar:
    """Fixed-width textual progress bar, e.g. `` 50% [====>-----]``.

    *stop* is the value that corresponds to 100%, *current* the starting
    value (non-zero when resuming) and *width* the number of cells
    between the brackets.
    """

    def __init__(self, stop, current=0, width=30):
        self.stop = stop
        self.width = width
        self.current = current
        self.bar = list("-" * self.width)
        self.step = stop / self.width
        # Start from an empty bar and draw up to the initial position so
        # that a resumed task shows the work that is already done.
        self.last_index = 0
        self.progress(self.current)

    def progress(self, amount=0, relative=False):
        """Move the bar to *amount*, or by *amount* when *relative* is true.

        The position is clamped to ``stop``.
        """
        if relative:
            self.current += amount
        else:
            self.current = amount
        if self.current > self.stop:
            self.current = self.stop
        index = self._calculate_index()
        for i in range(self.last_index, index):
            if i == 0:
                self.bar[i] = ">"
            else:
                # Turn the previous head into body and advance the head.
                self.bar[i - 1], self.bar[i] = "=", ">"
        self.last_index = index

    def get_bar(self):
        """Return the percentage and the bar as a single string."""
        if self.stop > 0:
            percentage = int(self.current / self.stop * 100)
        else:
            # Nothing to do counts as already complete.
            percentage = 100
        return "{:3d}% {:s}".format(percentage, "[" + "".join(self.bar) + "]")

    def _calculate_index(self):
        """Return how many cells should be filled for the current position."""
        if self.step <= 0:
            # A zero-length task has no meaningful ratio; show a full bar
            # instead of dividing by zero.
            return self.width
        return min(self.width, math.floor(self.current / self.step))


class NCBI_URI_Builder:
    """Fluent builder for ``/datasets/v2/<endpoint>/...`` request URIs.

    Positional *path_components* are percent-quoted and joined with "/";
    a list component is first joined with ",".  Every filter method
    records a query parameter when given a truthy value and returns the
    builder so calls can be chained.  `build()` renders path and query.
    """

    # Names of the methods that `set_filters_from_args()` may dispatch to.
    _FILTER_SETTERS = (
        "assembly_source",
        "assembly_levels",
        "assembly_version",
        "exclude_paired_reports",
        "has_annotation",
        "search_text",
        "reference_only",
        "page_size",
        "page_token",
        "include_annotation_type",
    )

    def __init__(self, endpoint="genome", *path_components):
        path_components = list(path_components)
        for i, component in enumerate(path_components):
            if isinstance(component, list):
                component = ",".join(component)
            path_components[i] = urllib.parse.quote(component)
        self.filters = {}
        self.path = "/datasets/v2/{}/{}".format(endpoint, "/".join(path_components))

    def assembly_source(self, source=None):
        """Set ``filters.assembly_source`` (percent-quoted)."""
        if source:
            self.filters["filters.assembly_source"] = urllib.parse.quote(source)
        return self

    def assembly_levels(self, levels):
        """Set ``filters.assembly_level``; a list yields one parameter per item."""
        if levels:
            self.filters["filters.assembly_level"] = levels
        return self

    def assembly_version(self, version=None):
        """Set ``filters.assembly_version``."""
        if version:
            self.filters["filters.assembly_version"] = version
        return self

    def exclude_paired_reports(self, exclude_pairs=False):
        """Set ``filters.exclude_paired_reports=true`` when requested."""
        if exclude_pairs:
            self.filters["filters.exclude_paired_reports"] = "true"
        return self

    def has_annotation(self, annotated=False):
        """Set ``filters.has_annotation=true`` when requested."""
        if annotated:
            self.filters["filters.has_annotation"] = "true"
        return self

    def search_text(self, text=None):
        """Set ``filters.search_text`` (percent-quoted)."""
        if text:
            self.filters["filters.search_text"] = urllib.parse.quote(text)
        return self

    def reference_only(self, reference_only=False):
        """Set ``filters.reference_only`` when requested.

        A boolean ``True`` is rendered as ``"true"`` like the other
        boolean filters; any other truthy value is stored unchanged.
        """
        if reference_only:
            self.filters["filters.reference_only"] = (
                "true" if reference_only is True else reference_only
            )
        return self

    def page_size(self, size=None):
        """Set ``page_size``."""
        if size:
            self.filters["page_size"] = size
        return self

    def page_token(self, token=None):
        """Set ``page_token``."""
        if token:
            self.filters["page_token"] = token
        return self

    def include_annotation_type(self, annotation_type=None):
        """Set ``include_annotation_type``; a list yields one parameter per item."""
        if annotation_type:
            self.filters["include_annotation_type"] = annotation_type
        return self

    def set_filters_from_args(self, args):
        """Apply every attribute of *args* whose name matches a filter method.

        Only the methods listed in `_FILTER_SETTERS` are considered, so
        attributes that merely share a name with something else on the
        builder (``path``, ``filters``, ``build``, ...) are ignored
        instead of being called.
        """
        for k, v in vars(args).items():
            if k in self._FILTER_SETTERS:
                getattr(self, k)(v)

    def build(self):
        """Return the path plus the query string assembled from the filters."""
        filters = []
        for k, v in self.filters.items():
            if isinstance(v, list):
                # Repeat the key once per value.
                for value in v:
                    filters.append("{}={}".format(k, value))
            else:
                filters.append("{}={}".format(k, v))
        query = "&".join(filters)
        split = urllib.parse.SplitResult(
            scheme="",
            netloc="",
            path=self.path,
            query=query,
            fragment=""
        )

        return urllib.parse.urlunsplit(split)

    def reset(self):
        """Forget all filters; the path is kept."""
        self.filters.clear()


def wrap_with_globals(f, log_queue, log_level, script_pathname, *args):
    """Install the module globals, then call ``f(*args)`` and return its result.

    `LOG` is rebuilt from *log_queue* and *log_level* through
    `Logger.setup_queue_logger`, and `SCRIPT_PATHNAME` is set to
    *script_pathname*, before *f* runs.
    """
    global LOG, SCRIPT_PATHNAME
    LOG = Logger.setup_queue_logger(log_queue, log_level)
    SCRIPT_PATHNAME = script_pathname
    outcome = f(*args)
    return outcome


def clear_console_line():
    """Emit, at debug level, the escape sequence ``ESC[2K`` followed by ``\\r``."""
    erase_sequence = "\33[2K\r"
    LOG.debug(erase_sequence)


def count_lines(*filenames):
    """Return the total number of lines across all *filenames*."""
    total = 0
    for path in filenames:
        with open(path, "r") as handle:
            total += sum(1 for _ in handle)
    return total


def dwk2():
    """Inspect the ``estimate_capacity -h`` output.

    Returns True when the first line starting with ``Usage:`` ends with
    ``<options>`` (after stripping whitespace), False when it does not
    or when there is no such line.
    """
    binary = find_kraken2_binary("estimate_capacity")
    help_text = subprocess.check_output(
        [binary, "-h"], stderr=subprocess.STDOUT
    )
    usage_lines = (
        text for text in help_text.split(b"\n") if text.startswith(b"Usage:")
    )
    first_usage = next(usage_lines, None)
    if first_usage is None:
        return False
    return first_usage.strip().endswith(b"<options>")


def get_binary_options(binary_pathname):
    """Return the single-letter flags a binary prints on stderr.

    *binary_pathname* is run without arguments and every stderr line is
    searched for a whitespace-delimited ``-x`` token; the first such
    token of each matching line is collected, in output order.
    """
    option_pattern = re.compile(rb"\s(-\w)\s")
    # The context manager closes the pipe and waits for the child, so no
    # zombie process or open descriptor is left behind.
    with subprocess.Popen(binary_pathname, stderr=subprocess.PIPE) as proc:
        lines = proc.stderr.readlines()
    options = []
    for line in lines:
        match = option_pattern.search(line)
        if match:
            options.append(match.group(1).decode())
    return options


def construct_seed_template(args):
    """Build the seed template string from the minimizer settings.

    The result is a run of "1" characters followed by
    ``args.minimizer_spaces`` copies of "01", with a total length of
    ``args.minimizer_len``.  Logs an error and exits with status 1 when
    more spaces are requested than a quarter of the minimizer length.
    """
    spaces = args.minimizer_spaces
    length = args.minimizer_len
    max_spaces = int(length / 4)
    if max_spaces < spaces:
        message = (
            "Number of minimizer spaces, {}, exceeds max for "
            "minimizer length, {}; max: {}\n"
        )
        LOG.error(message.format(spaces, length, max_spaces))
        sys.exit(1)
    solid_run = "1" * (length - 2 * spaces)
    spaced_run = "01" * spaces
    return solid_run + spaced_run


def copy_globals(queue, level, script_pathname):
    """Set the module globals `LOG` and `SCRIPT_PATHNAME`.

    `LOG` is created with `Logger.setup_queue_logger(queue, level)`.
    """
    global LOG, SCRIPT_PATHNAME
    queue_logger = Logger.setup_queue_logger(queue, level)
    LOG = queue_logger
    SCRIPT_PATHNAME = script_pathname


def future_raised_exception(future):
    """Return True if *future* has finished by raising an exception.

    A future that is still pending or running, was cancelled, or
    completed normally (even with a ``None`` result) yields False.  The
    stored exception is inspected with ``Future.exception()`` rather
    than ``Future.result()``, because the latter would re-raise it here.
    """
    if not future.done() or future.cancelled():
        return False
    return future.exception() is not None


def url_join(netloc, scheme="https", path="", query="", fragment=""):
    """Assemble a URL string from its five components."""
    parts = urllib.parse.SplitResult(
        scheme=scheme,
        netloc=netloc,
        path=path,
        query=query,
        fragment=fragment,
    )
    return urllib.parse.urlunsplit(parts)


def execute_in_process_pool(func, num_processes, *args):
    """Placeholder: currently does nothing and returns None."""
    return None


def download_and_process_blast_volumes(args):
    """Download the BLAST volumes in ``manifest.txt`` and build a library.

    Steps, all relative to the current directory:

    1. download every file listed in ``manifest.txt``;
    2. extract the index/header/sequence files from each tarball in a
       worker process;
    3. convert each extracted volume to FASTA;
    4. concatenate the per-volume FASTA files into ``library.fna``
       (``library.faa`` when ``args.protein`` is set) and the per-volume
       maps into ``prelim_map.txt``.

    Exits with status 1 when a converted volume is missing.
    """
    download_files_from_manifest(
        NCBI_SERVER, args.threads, resume=args.resume
    )
    extraction_futures = []
    tarballs_and_converted_volumes = []
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=1
    ) as pool:
        with open("manifest.txt", "r") as in_file:
            tarballs = in_file.readlines()
        wrapped_func = functools.partial(
            wrap_with_globals, extract_blast_db_files,
            LOG.get_queue(), LOG.get_level(),
            SCRIPT_PATHNAME
        )
        LOG.info(
            "Extracting index (.nin), header (.nhr), and"
            " sequence files (.nsq) from tarballs\n"
        )
        for tarball in tarballs:
            tarball = tarball.strip()
            if not tarball:
                # A blank manifest line is not a tarball.
                continue
            f = pool.submit(wrapped_func, os.path.abspath(tarball))
            extraction_futures.append(f)
        for future in concurrent.futures.as_completed(extraction_futures):
            extracted_tarball, extracted_volume = future.result()
            tarballs_and_converted_volumes.append(
                (extracted_tarball, extracted_volume)
            )
            # Report the tarball this future handled; futures complete
            # in arbitrary order, so the submit-loop variable is wrong.
            LOG.info(
                "Finished extracting files from {}\n"
                .format(extracted_tarball)
            )
        for tarball, volume in tarballs_and_converted_volumes:
            LOG.info("Converting BLAST volume {} to FASTA\n".format(volume))
            convert_blast_to_fasta(args, volume, tarball)
            LOG.info(
                "Finished converting BLAST volume {} to FASTA\n"
                .format(volume)
            )
    library_extension = ".faa" if args.protein else ".fna"
    library_filename = "library" + library_extension
    LOG.info("Generating {} from converted volumes\n".format(library_filename))
    with open(library_filename, "w") as lib, \
         open("prelim_map.txt", "w") as plm:
        for _, volume in sorted(tarballs_and_converted_volumes):
            if not os.path.exists(volume + library_extension):
                LOG.error(
                    "Missing volume: {}\n"
                    .format(volume + library_extension)
                )
                sys.exit(1)
            with open(volume + library_extension, "r") as in_file:
                shutil.copyfileobj(in_file, lib)
            with open(volume + "_prelim_map.txt", "r") as in_file:
                shutil.copyfileobj(in_file, plm)


def create_manifest_for_blast_db(db_name, volume_numbers, protein=False):
    """Write ``manifest.txt`` listing the volumes of BLAST database *db_name*.

    The database's metadata JSON is downloaded from `NCBI_SERVER`, and
    the URL path (without its leading "/") of every entry in its
    ``"files"`` list is written, one per line.  Entries whose name has a
    ``.<digits>.`` component are kept only if that number is in
    *volume_numbers*; entries without such a component are always kept.
    """
    if protein:
        suffix = "-prot-metadata.json"
    else:
        suffix = "-nucl-metadata.json"
    remote_json = "blast/db/" + db_name + suffix
    http_download_file2(NCBI_SERVER, [remote_json])
    local_json = os.path.abspath(remote_json)
    with open(local_json, "r") as metadata_file:
        metadata = json.load(metadata_file)
        with open("manifest.txt", "w") as manifest:
            for volume_url in metadata["files"]:
                numbered = re.search(r"\.(\d+)\.", volume_url)
                if numbered and int(numbered.group(1)) not in volume_numbers:
                    continue
                url_path = urllib.parse.urlsplit(volume_url).path
                manifest.write(url_path[1:] + "\n")


def extract_blast_db_files(tarball_pathname):
    """Extract the ``nsq``/``nin``/``nhr`` members of a gzipped tarball.

    Members are extracted next to the tarball.  Returns a tuple of the
    tarball pathname and the volume path, i.e. the extraction directory
    joined with the first matching member name minus its extension.
    """
    target_dir = os.path.dirname(tarball_pathname)
    wanted_suffixes = ("nsq", "nin", "nhr")
    with tarfile.open(tarball_pathname, "r:gz") as archive:
        selected = [
            name for name in archive.getnames()
            if name.endswith(wanted_suffixes)
        ]
        for name in selected:
            archive.extract(name, target_dir)
    volume = os.path.splitext(selected[0])[0] if selected else None
    volume_path = os.path.join(target_dir, volume)
    return (tarball_pathname, volume_path)


def convert_blast_to_fasta(args, volume, tarball):
    """Convert one extracted BLAST *volume* to FASTA and build its map.

    Runs the ``blast_to_fasta`` binary on *volume*, scans the resulting
    FASTA file to write ``<volume>_prelim_map.txt`` (recording the
    remote URL of *tarball* for every sequence) and, unless
    ``args.no_masking`` is set, masks the FASTA file in place.  Exits
    with status 1 if the conversion binary fails.
    """
    extension = ".faa" if args.protein else ".fna"
    fasta_filename = volume + extension
    tmp_fasta_filename = fasta_filename + ".tmp"
    remote_filepath = url_join(
        NCBI_SERVER, path="blast/db/" + os.path.basename(tarball)
    )
    converter = find_kraken2_binary("blast_to_fasta")
    exit_status = subprocess.call([converter, "-s", "-t", volume])
    if exit_status != 0:
        LOG.error(
            "Encountered an error while converting BLAST format to FASTA\n"
        )
        sys.exit(1)
    prelim_map_name = os.path.join(
        os.path.dirname(volume),
        os.path.basename(volume) + "_prelim_map.txt"
    )
    with open(fasta_filename, "r") as fasta_file:
        with open(prelim_map_name, "w") as map_file:
            scan_fasta_file(
                fasta_file, map_file, lenient=True,
                sequence_to_url=remote_filepath
            )
    if args.no_masking:
        return
    # Mask into the final name, using the unmasked file as a temporary.
    shutil.move(fasta_filename, tmp_fasta_filename)
    mask_files(
        [tmp_fasta_filename], fasta_filename,
        args.masker_threads, args.protein
    )
    os.remove(tmp_fasta_filename)


def wrapper_args_to_binary_args(opts, argv, binary_args):
    """Append to *argv* the binary flags that correspond to *opts*.

    For every attribute of *opts* that has an entry in
    `WRAPPER_ARGS_TO_BIN_ARGS` and whose flag is listed in
    *binary_args*: ``None`` and ``False`` values are skipped, ``True``
    adds the bare flag, and any other value adds the flag followed by
    ``str(value)``.  *argv* is modified in place.
    """
    for option_name, value in vars(opts).items():
        flag = WRAPPER_ARGS_TO_BIN_ARGS.get(option_name)
        if flag is None or flag not in binary_args:
            continue
        if value is None or value is False:
            continue
        argv.append(flag)
        if value is not True:
            argv.append(str(value))


def find_kraken2_binary(name):
    """Return the path of the helper binary *name*, or exit if not found.

    Search order: every directory in ``PATH``, the directory that holds
    this script, then the ``src`` directory of the script's parent
    directory.  Logs an error and exits with status 1 when nothing
    matches.
    """
    # search the OS PATH (split on the platform's own separator)
    if "PATH" in os.environ:
        for directory in os.environ["PATH"].split(os.pathsep):
            candidate = os.path.join(directory, name)
            if os.path.exists(candidate):
                return candidate
    # search for binary in the same directory as wrapper
    script_parent_directory = get_parent_directory(SCRIPT_PATHNAME)
    if os.path.exists(os.path.join(script_parent_directory, name)):
        return os.path.join(script_parent_directory, name)
    # if called from within kraken2 project root, search the src dir
    project_root = get_parent_directory(script_parent_directory)
    if "src" in os.listdir(project_root) and name in os.listdir(
        os.path.join(project_root, "src")
    ):
        return os.path.join(project_root, os.path.join("src", name))
    # not found in these likely places, exit
    LOG.error("Unable to find {:s}, exiting\n".format(name))
    sys.exit(1)


def get_parent_directory(pathname):
    """Return the absolute parent directory of *pathname*.

    An empty *pathname* yields None.  A trailing path separator is
    ignored, except for the root directory itself.
    """
    if len(pathname) == 0:
        return None
    resolved = os.path.abspath(pathname)
    has_trailing_sep = resolved[-1] == os.path.sep
    if has_trailing_sep and len(resolved) > 1:
        resolved = resolved[:-1]
    return os.path.dirname(resolved)


def find_database(database_name):
    """Locate a database directory and check that it is complete.

    If *database_name* is a directory, its absolute path is used.
    Otherwise the colon-separated directories of ``KRAKEN2_DB_PATH`` are
    searched when that variable is set; when it is not set, the current
    directory listing is checked instead.  Returns the database path, or
    None when nothing was found or one of ``taxo.k2d``, ``hash.k2d`` and
    ``opts.k2d`` is missing from it.
    """
    if os.path.isdir(database_name):
        database_path = os.path.abspath(database_name)
    elif "KRAKEN2_DB_PATH" in os.environ:
        database_path = None
        for directory in os.environ["KRAKEN2_DB_PATH"].split(":"):
            candidate = os.path.join(directory, database_name)
            if os.path.exists(candidate):
                database_path = candidate
                break
    elif database_name in os.listdir(os.getcwd()):
        database_path = database_name
    else:
        database_path = None
    if database_path:
        required_files = ("taxo.k2d", "hash.k2d", "opts.k2d")
        complete = all(
            os.path.exists(os.path.join(database_path, required))
            for required in required_files
        )
        if not complete:
            return None
    return database_path


def remove_files(filepaths, forked=False):
    """Remove the given files and the contents of the given directories.

    For a directory, its regular files are removed, its subdirectories
    are emptied (in parallel when there are at least four of them and
    this is not already a worker, i.e. *forked* is false) and then
    deleted.  The directory passed in *filepaths* itself is left in
    place.  Paths that do not exist are skipped.

    Returns the total size in bytes of the files removed.
    """
    total_size = 0

    for fname in filepaths:
        if not os.path.exists(fname):
            continue
        if not os.path.isdir(fname):
            LOG.info("Removing {}\n".format(fname))
            total_size += os.path.getsize(fname)
            os.remove(fname)
            continue

        directories = []
        with os.scandir(fname) as entries:
            for entry in entries:
                if entry.is_dir():
                    directories.append(entry.path)
                else:
                    total_size += os.path.getsize(entry.path)
                    LOG.info("Removing {}\n".format(entry.path))
                    os.remove(entry.path)
        if not forked and len(directories) >= 4:
            total_size += remove_files_parallel(directories)
        else:
            total_size += remove_files(directories, forked)
        # Delete the emptied subdirectories on both paths, so the result
        # does not depend on how many subdirectories there were.
        for directory in directories:
            shutil.rmtree(directory)

    return total_size


def remove_files_parallel(filepaths):
    """Run `remove_files` on each existing path in a pool of four workers.

    Every path is submitted as its own task with ``forked=True``.
    Returns the sum of the sizes reported by the tasks.
    """
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=4,
    ) as pool:
        worker = functools.partial(
            wrap_with_globals, remove_files,
            LOG.get_queue(), LOG.get_level(),
            SCRIPT_PATHNAME
        )
        pending = [
            pool.submit(worker, [path], True)
            for path in filepaths
            if os.path.exists(path)
        ]
        removed_bytes = sum(
            task.result()
            for task in concurrent.futures.as_completed(pending)
        )

    return removed_bytes


def get_taxid_from_seqid(seqid):
    """Extract a taxonomy ID or an accession-like token from *seqid*.

    Checked in order: a ``kraken:taxid|<digits>`` token, a sequence ID
    made up of digits only, then a token that looks like an accession
    number.  Returns the matching string, or None if nothing matches.
    """
    explicit = re.search(r"(?:^|\|)kraken:taxid\|(\d+)", seqid)
    if explicit:
        return explicit.group(1)
    if re.match(r"^\d+$", seqid):
        return seqid
    accession = re.search(
        r"(?:^|\|)([A-Z]+_?[A-Z0-9]+)(?:\||\b|\.)", seqid
    )
    return accession.group(1) if accession else None


def hash_string(string):
    """Return the hexadecimal MD5 digest of *string* (encoded with the default codec)."""
    return hashlib.md5(string.encode()).hexdigest()


def hash_file(filename, buf_size=8192):
    """Return the hexadecimal MD5 digest of the file *filename*.

    The file is read in chunks of *buf_size* bytes; the start and the
    result are logged at info level.
    """
    LOG.info("Calculating MD5 sum for {}\n".format(filename))
    hasher = hashlib.md5()
    with open(filename, "rb") as in_file:
        read_chunk = functools.partial(in_file.read, buf_size)
        # The sentinel form stops at the first empty read (end of file).
        for chunk in iter(read_chunk, b""):
            hasher.update(chunk)
    digest = hasher.hexdigest()
    LOG.info("MD5 sum of {} is {}\n".format(filename, digest))
    return digest


# This function is part of the Kraken 2 taxonomic sequence
# classification system.
#
# Reads multi-FASTA input and examines each sequence header.  Headers are
# OK if a taxonomy ID is found (as either the entire sequence ID or as part
# of a "kraken:taxid" token), or if something looking like an accession
# number is found.  Headers that are not "OK" are fatal errors unless
# "lenient" is used.
#
# Each sequence header results in a line with five tab-separated values:
# the first indicates whether the third column is the taxonomy ID ("TAXID")
# or an accession number ("ACCNUM") for the sequence ID listed in the second
# column; the fourth and fifth columns hold the header comment and the
# source path/URL given through `sequence_to_url`.
#
def scan_fasta_file(
    in_file, out_file, lenient=False, sequence_to_url=None
):
    """Write one preliminary-map line per sequence ID found in FASTA headers.

    Parameters:
        in_file: open text file with FASTA data; its ``name`` is logged.
        out_file: writable text file receiving the map lines.
        lenient: when true, IDs without a taxonomy ID or accession are
            skipped; otherwise the first such ID is a fatal error.
        sequence_to_url: either a string recorded as the source of
            every sequence, or a dict mapping header lines to their
            source; in the dict case the dict's keys are scanned instead
            of *in_file*.  ``None`` records an empty source.

    Each output line holds five tab-separated fields: ``TAXID`` or
    ``ACCNUM``, the sequence ID, the taxonomy ID or accession, the
    header comment and the source.  A header may carry several IDs
    separated by ``\\x01``; each one gets its own line.
    """
    LOG.info("Generating prelim_map.txt for {}.\n".format(in_file.name))
    iterator_is_dict = isinstance(sequence_to_url, dict)
    iterator = sequence_to_url if iterator_is_dict else in_file
    # The comment stops at the next \x01 so that the following ID of a
    # merged header is still seen by the scan.
    header_pattern = re.compile(r"(?:^>|\x01)(\S+)(?: ([^\x01\n]*))?")
    for line in iterator:
        if not line.startswith(">"):
            continue
        if iterator_is_dict:
            remote_filepath = sequence_to_url[line]
        else:
            remote_filepath = sequence_to_url
        if remote_filepath is None:
            # "{:s}" cannot format None; record an empty source instead.
            remote_filepath = ""
        for match in header_pattern.finditer(line):
            seqid = match.group(1)
            taxid = get_taxid_from_seqid(seqid)
            comment = match.group(2) or ""
            if not taxid:
                if lenient:
                    continue
                LOG.error(
                    "Unable to determine taxonomy ID or accession number "
                    "for sequence ID {}\n".format(seqid)
                )
                sys.exit(1)
            kind = "TAXID" if re.match(r"^\d+$", taxid) else "ACCNUM"
            out_file.write(
                "{:s}\t{:s}\t{:s}\t{:s}\t{:s}\n".format(
                    kind, seqid, taxid, comment, remote_filepath
                )
            )
    LOG.info(
        "Finished generating prelim_map.txt for {}.\n".format(in_file.name)
    )


# This function is part of the Kraken 2 taxonomic sequence
# classification system.
#
# Looks up accession numbers and reports associated taxonomy IDs
#
# `lookup_list_file` is a 2-column TSV file w/ sequence IDs and
# accession numbers, and `accession_map_files` is a list of
# accession2taxid files from NCBI.  Output is tab-delimited lines,
# with sequence IDs in first column and taxonomy IDs in second.
#
def lookup_accession_numbers(
    lookup_list_filename, out_filename, *accession_map_files
):
    """Resolve accession numbers to taxonomy IDs.

    *lookup_list_filename* holds ``sequence ID<TAB>accession`` lines.
    Each file in *accession_map_files* has a header line followed by
    four tab-separated columns, of which the first (accession) and the
    third (taxonomy ID) are used.  For every accession that is found,
    ``sequence ID<TAB>taxonomy ID`` lines are appended to
    *out_filename*.  Scanning stops as soon as every accession has been
    resolved.  Accessions that remain unresolved are written to
    ``unmapped_accessions.txt`` in the current directory and a warning
    is logged.
    """
    pending = {}
    with open(lookup_list_filename, "r") as lookup_file:
        for record in lookup_file:
            seqid, acc_num = record.strip().split("\t")
            pending.setdefault(acc_num, []).append(seqid)
    initial_target_count = len(pending)
    with open(out_filename, "a") as out_file:
        for map_filename in accession_map_files:
            with open(map_filename, "r") as map_file:
                map_file.readline()  # discard header line
                for line_number, raw in enumerate(map_file, start=1):
                    row = raw.strip()
                    columns = row.split("\t")
                    if len(columns) != 4:
                        LOG.warning(
                            "{}:{}-'{}' contains fewer than 4 fields\n"
                            .format(map_filename, line_number, row)
                        )
                        continue
                    accession, taxid = columns[0], columns[2]
                    seqids = pending.pop(accession, None)
                    if seqids is None:
                        continue
                    for seqid in seqids:
                        out_file.write(seqid + "\t" + taxid + "\n")
                    if not pending:
                        break
            if not pending:
                break
    if pending:
        LOG.warning(
            "{}/{} accession numbers remain unmapped, "
            "see unmapped_accessions.txt in {} directory\n"
            .format(len(pending), initial_target_count,
                    os.path.abspath(os.curdir))
        )
        with open("unmapped_accessions.txt", "w") as unmapped_file:
            for accession in pending:
                unmapped_file.write(accession + "\n")


def spawn_masking_subprocess(output_file, threads, protein=False):
    """Start the low-complexity masker and return its ``Popen`` object.

    The masker reads FASTA from its stdin (a pipe owned by the caller)
    and writes to *output_file*.  The binary is ``segmasker`` for
    protein and ``k2mask`` otherwise, unless the ``MASKER`` environment
    variable names another one.  ``k2mask`` is run with *threads*
    threads; any other masker is piped through ``sed`` to replace
    lower-case residues on sequence lines with ``x``.  The process runs
    in the directory that contains *output_file*.
    """
    import shlex  # only needed to quote the command line built below

    masking_binary = "segmasker" if protein else "k2mask"
    if "MASKER" in os.environ:
        masking_binary = os.environ["MASKER"]
    masking_binary = find_kraken2_binary(masking_binary)

    # The command is handed to a shell, so the binary path must be
    # quoted to survive spaces and shell metacharacters.
    quoted_binary = shlex.quote(masking_binary)
    if "k2mask" in masking_binary:
        # k2mask can run multithreaded
        argv = quoted_binary + " -outfmt fasta -threads {} -r x".format(
            shlex.quote(str(threads))
        )
    else:
        argv = quoted_binary + " -outfmt fasta | sed -e '/^>/!s/[a-z]/x/g'"

    cwd = os.path.dirname(os.path.abspath(output_file.name))
    p = subprocess.Popen(
        argv, shell=True, cwd=cwd,
        stdin=subprocess.PIPE, stdout=output_file
    )

    return p


# Mask low complexity sequences in the database
def mask_files(input_filenames, output_filename, threads, protein=False):
    """Feed *input_filenames* through one masker into *output_filename*.

    A single masking subprocess is started and every input file is
    copied to its stdin in order.  A non-zero exit status of the masker
    is logged as an error.
    """
    with open(output_filename, "wb") as masked_out:
        masker = spawn_masking_subprocess(masked_out, threads, protein)
        for input_filename in input_filenames:
            # The log wording depends on the library being processed,
            # which is taken from the working directory name unless the
            # output name marks this as a blast volume.
            library_name = os.path.basename(os.getcwd())
            if "blast" in output_filename:
                library_name = "blast"
            if library_name == "added":
                message = (
                    "Masking low-complexity regions of added "
                    "library {}\n".format(input_filename)
                )
            elif library_name == "blast":
                message = (
                    "Masking low-complexity regions for blast "
                    "volume {}\n".format(output_filename)
                )
            else:
                message = (
                    "Masking low-complexity regions of downloaded "
                    "library {:s}\n".format(library_name)
                )
            LOG.info(message)
            with open(input_filename, "rb") as unmasked_in:
                shutil.copyfileobj(unmasked_in, masker.stdin)
        masker.stdin.close()
        exit_status = masker.wait()
        if exit_status != 0:
            LOG.error("Error while masking {}\n".format(input_filename))


def add_file(args, filename, hashes):
    """Add one FASTA file to the library in the current directory.

    *hashes* maps filenames that were added before to their MD5 sums; a
    file found there is reported and left alone.  Otherwise the file is
    scanned into ``prelim_map_<md5>.txt``, copied to
    ``<basename>_<md5>.fna`` (``.faa`` for protein) and, unless
    ``args.no_masking`` is set, masked in place.

    Returns the tuple ``(filename, md5, destination)``.
    """
    already_added = filename in hashes
    filehash = hashes.get(filename) or hash_file(filename)
    extension = ".faa" if args.protein else ".fna"
    stem = os.path.splitext(os.path.basename(filename))[0]
    destination = stem + "_" + filehash + extension
    outcome = (filename, filehash, destination)
    if already_added:
        LOG.info(
            "Already added " + filename + " to library. "
            "Please remove the entry from added.md5 if this"
            " is not the case.\n"
        )
        return outcome

    LOG.info("Adding " + filename + " to library " + args.db + "\n")
    prelim_map_filename = "prelim_map_" + filehash + ".txt"
    with open(prelim_map_filename, mode="a") as map_file:
        with open(filename, "r") as fasta_file:
            scan_fasta_file(
                fasta_file, map_file, lenient=True, sequence_to_url=filename
            )
        shutil.copyfile(filename, destination)

    if not args.no_masking:
        masked_destination = destination + ".masked"
        mask_files(
            [destination],
            masked_destination,
            threads=args.masker_threads,
            protein=args.protein,
        )
        shutil.move(masked_destination, destination)
    LOG.info("Added " + filename + " to library " + args.db + "\n")

    return outcome


def add_to_library(args):
    """Add the files named by ``args.files`` to the database's library.

    The work happens in ``<args.db>/library/added``, which is created
    if needed and made the current directory.  Each entry of
    ``args.files`` is made absolute and expanded as a (recursive) glob
    pattern; every resulting file is handled by `add_file` in a process
    pool of ``args.threads`` workers.  One ``filename<TAB>md5<TAB>
    destination`` line per processed file is appended to ``added.md5``.

    Exits with status 1 if ``args.db`` is not a directory.
    """
    if not os.path.isdir(args.db):
        LOG.error("Invalid database: {:s}\n".format(args.db))
        sys.exit(1)
    library_pathname = os.path.join(args.db, "library")
    added_pathname = os.path.join(library_pathname, "added")
    os.makedirs(added_pathname, exist_ok=True)
    args.files = [os.path.abspath(f) for f in args.files]
    os.chdir(added_pathname)
    hashes = {}
    if os.path.exists("added.md5"):
        with open("added.md5", "r") as in_file:
            for line in in_file:
                line = line.rstrip("\r\n")
                # Entries are written tab-separated below; splitting on
                # tabs keeps filenames that contain spaces intact.  Lines
                # without a tab fall back to whitespace splitting.
                fields = line.split("\t") if "\t" in line else line.split()
                if len(fields) < 2:
                    continue
                hashes[fields[0]] = fields[1]
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=args.threads
    ) as pool:
        worker = functools.partial(
            wrap_with_globals, add_file,
            LOG.get_queue(), LOG.get_level(),
            SCRIPT_PATHNAME
        )
        futures = []
        matches = (glob.glob(f, recursive=True) for f in args.files)
        for filename in itertools.chain.from_iterable(matches):
            future = pool.submit(worker, args, filename, hashes)
            if future_raised_exception(future):
                LOG.error(
                    "Error while adding file to library\n"
                )
                raise future.exception()
            futures.append(future)
        with open("added.md5", "a") as out_file:
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                (filename, filehash, destination) = result
                out_file.write(
                    filename + "\t" + filehash + "\t" + destination + "\n"
                )


def make_manifest_from_assembly_summary(
        args, assembly_summary_file
):
    """Write ``manifest.txt`` from an assembly summary and map paths to taxids.

    ``assembly_summary_file`` is iterated line by line; lines starting
    with ``#`` are ignored. Columns 5, 11 and 19 (0-based, tab separated)
    are read as the taxid, assembly level and FTP path. Rows whose
    assembly level does not match one of ``args.assembly_levels``
    (underscores read as spaces, case-insensitive, anchored at the start)
    or whose FTP path is ``na`` are dropped.

    Returns a dict mapping each server-relative file path (no leading
    slash) to its taxid string; the same paths are written, one per line,
    to ``manifest.txt`` in the current directory.
    """
    level_pattern = re.compile(
        "|".join(args.assembly_levels).replace("_", " "), re.IGNORECASE
    )
    suffix = "_genomic.fna.gz"
    if args.protein:
        suffix = "_protein.faa.gz"

    manifest_to_taxid = {}
    for record in assembly_summary_file:
        if record.startswith("#"):
            continue
        columns = record.strip().split("\t")
        taxid, asm_level, ftp_path = columns[5], columns[11], columns[19]
        if level_pattern.match(asm_level) is None or ftp_path == "na":
            continue
        remote_path = "/".join(
            (ftp_path, os.path.basename(ftp_path) + suffix)
        )
        # Keep only the URL path and drop its leading slash.
        local_path = urllib.parse.urlsplit(remote_path).path.replace(
            "/", "", 1
        )
        manifest_to_taxid[local_path] = taxid

    with open("manifest.txt", "w") as manifest:
        manifest.writelines(path + "\n" for path in manifest_to_taxid)
    return manifest_to_taxid


def assign_taxids(args, filepath, manifest_to_taxid,
                  accession_to_taxid={}, filepath_to_url={}):
    """Prefix every FASTA header in ``filepath`` with ``kraken:taxid|<id>|``.

    A ``.gz`` input is read through ``gzip`` and written next to it
    without the ``.gz`` extension. Any other input is rewritten through a
    ``.tmp`` sibling that then replaces the original. Unless
    ``args.no_masking`` is set, the rewritten lines are piped to the
    process returned by ``spawn_masking_subprocess`` instead of being
    written directly.

    The taxid comes from ``manifest_to_taxid[filepath]``. When that value
    is empty, the first ``GCA_/GCF_`` accession found in the header is
    looked up in ``accession_to_taxid``; the process exits if that fails.

    Returns ``(sequences_added, ch_added, sequence_to_url)`` where
    ``sequence_to_url`` maps each rewritten header line to
    ``filepath_to_url[filepath]``.
    """
    absolute_filepath = os.path.abspath(filepath)
    is_gzipped = absolute_filepath.endswith(".gz")
    if is_gzipped:
        out_filepath = os.path.splitext(absolute_filepath)[0]
    else:
        out_filepath = absolute_filepath + ".tmp"
    opener = gzip.open if is_gzipped else open

    sequences_added = 0
    ch_added = 0
    sequence_to_url = {}
    masker = None
    os.makedirs(os.path.dirname(absolute_filepath), exist_ok=True)
    with open(out_filepath, "w") as out_file:
        if not args.no_masking:
            masker = spawn_masking_subprocess(
                out_file, args.masker_threads, False
            )
        with opener(absolute_filepath, "rt") as in_file:
            for line in in_file:
                if line.startswith(">"):
                    taxid = manifest_to_taxid[filepath]
                    if not taxid:
                        found = re.search(r"GC[AF]_[0-9]{9}\.\d+", line)
                        if found is None \
                           or found.group(0) not in accession_to_taxid:
                            LOG.error(
                                "Unable to assign taxid to sequence {} in file {}\n"
                                .format(line, in_file.name)
                            )
                            sys.exit(1)
                        taxid = accession_to_taxid[found.group(0)]
                    line = line.replace(">", ">kraken:taxid|" + taxid + "|", 1)
                    sequence_to_url[line] = filepath_to_url[filepath]
                    sequences_added += 1
                else:
                    # Count the sequence characters without the newline.
                    ch_added += len(line) - 1
                if masker:
                    masker.stdin.write(line.encode())
                else:
                    out_file.write(line)
        # The rename happens while the output handle (and the masker, if
        # any) is still open; this ordering is kept as in the original.
        if out_filepath.endswith(".tmp"):
            shutil.move(out_filepath, absolute_filepath)
        if masker:
            masker.stdin.close()
            masker.wait()

    return (sequences_added, ch_added, sequence_to_url)


def download_dataset_by_project(args, endpoint, identifiers):
    """Download every identifier into its own directory under ``library``.

    For each identifier a directory named after it (lower-cased, spaces
    replaced by underscores) is created inside ``<db>/library`` and
    ``download_and_process_accessions`` is run from within it.

    Side effect: the working directory is left at ``<db>/library``.
    """
    library_pathname = os.path.join(args.db, "library")
    os.makedirs(library_pathname, exist_ok=True)
    os.chdir(library_pathname)
    # os.path.curdir is just the string ".", so chdir-ing back to it would
    # never leave the identifier directory and later identifiers would be
    # nested inside earlier ones. Remember the absolute path instead.
    library_directory = os.getcwd()

    for identifier in identifiers:
        dirname = identifier.lower().replace(" ", "_")
        os.makedirs(dirname, exist_ok=True)
        os.chdir(dirname)
        try:
            download_and_process_accessions(args, endpoint, identifier)
        finally:
            os.chdir(library_directory)


def download_and_process_accessions(args, endpoint, identifier):
    """Fetch, download and label all assemblies matching ``identifier``.

    Steps, all relative to the current working directory:

    1. Page through the NCBI ``dataset_report`` endpoint (500 reports per
       page) and collect ``accession -> tax_id``.
    2. Resolve each accession to a download URL and write the
       server-relative paths to ``manifest.txt``.
    3. Download the files listed in the manifest.
    4. Prefix sequence headers with their taxid and build the library
       file (``library.fna`` or ``library.faa``).
    5. Write ``prelim_map.txt`` by scanning the library file.

    Exits the process when the API answers with an unexpected status or
    returns an empty result.
    """
    api = http.client.HTTPSConnection(NCBI_REST_API)
    identifier = urllib.parse.quote(identifier)
    accession_to_taxid = {}
    builder = NCBI_URI_Builder(endpoint, identifier, "dataset_report")
    builder.set_filters_from_args(args)
    builder = builder.page_size(500)
    old_page_token = ""

    while True:
        api.request("GET", builder.build())
        response = api.getresponse()
        if response.status == 429:
            LOG.warning(
                "Connection is being rate limited by NCBI, backing off\n"
            )
            # The body must be consumed before the same connection can be
            # used for another request.
            response.read()
            response.close()
            time.sleep(1)
            continue
        if response.status != 200:
            LOG.error(
                "Encountered an error while trying to gather "
                "accessions.\n"
            )
            sys.exit(1)
        # Parse the whole body: it is not guaranteed to be a single line
        # and may be empty.
        body = response.read()
        response.close()
        results = json.loads(body) if body.strip() else {}

        if not results:
            LOG.error(
                "Could not find any accessions matching the query: {}\n"
                .format(identifier)
            )
            sys.exit(1)

        for report in results["reports"]:
            accession_to_taxid[report["accession"]] =\
                report["organism"]["tax_id"]
        # Stop when there is no further page or the token repeats.
        if "next_page_token" in results \
           and results["next_page_token"] != old_page_token:
            builder = builder.page_token(results["next_page_token"])
            old_page_token = results["next_page_token"]
        else:
            break

    api.close()
    LOG.info(
        "Found {} accession(s) associated with {}\n"
        .format(len(accession_to_taxid), identifier)
    )

    accessions = list(accession_to_taxid.keys())
    accession_to_url =\
        map_accessions_to_url_parallel(args, accessions)

    filepath_to_taxid = {}
    with open("manifest.txt", "w") as fout:
        for accession, url in accession_to_url.items():
            # exclude the leading /
            filepath = urllib.parse.urlparse(url).path[1:]
            fout.write(filepath + "\n")
            taxid = accession_to_taxid[accession]
            filepath_to_taxid[filepath] = str(taxid)

    download_files_from_manifest(NCBI_SERVER, args.threads, resume=args.resume)

    filepath_to_url = {}
    for filepath in filepath_to_taxid.keys():
        accession = re.search(r"GC[AF]_\d{9}\.\d+", filepath).group()
        filepath_to_url[filepath] = accession_to_url[accession]
    sequence_to_url = assign_taxid_to_sequences(
        args, filepath_to_taxid, filepath_to_url=filepath_to_url
    )

    library_filename = "library.faa" if args.protein else "library.fna"
    with open(library_filename, "r") as in_file:
        with open("prelim_map.txt", "w") as out_file:
            out_file.write("# prelim_map for " + args.library + "\n")
            scan_fasta_file(
                in_file,
                out_file,
                sequence_to_url=sequence_to_url,
            )


def download_accessions(args, accession_to_taxid):
    """Download and unpack the FASTA files for the given accessions.

    Archives already present under ``zips/`` are inspected first:
    accessions they contain are extracted from the cached archive rather
    than downloaded again. The remaining accessions are partitioned and
    downloaded as zip archives, then extracted.

    Returns a dict mapping each extracted file path to the taxid (as a
    string) of its accession. The ``ncbi_dataset`` directory is removed
    afterwards.
    """
    accessions = sorted(accession_to_taxid)
    filepath_to_taxid = {}
    zip_filename_to_accessions = {}
    extension = ".faa" if args.protein else ".fna"

    if os.path.exists("zips"):
        # accession -> cached archive that contains it
        downloaded_accessions = {}
        for entry_name in os.listdir("zips"):
            zip_filename = os.path.join("zips", entry_name)
            if not zip_filename.endswith(".zip"):
                continue
            with zipfile.ZipFile(zip_filename) as archive:
                for member in archive.namelist():
                    match = re.search(r"GC[AF]_\d{9}\.\d+", member)
                    if match and member.endswith(extension):
                        downloaded_accessions[match.group()] = zip_filename

        if downloaded_accessions:
            unfetched_accessions = []
            for accession in accessions:
                if accession not in downloaded_accessions:
                    unfetched_accessions.append(accession)
                    continue
                LOG.info(
                    "Already downloaded accession: {}, skipping\n"
                    .format(accession)
                )
                zip_filename_to_accessions.setdefault(
                    downloaded_accessions[accession], []
                ).append(accession)
            accessions = unfetched_accessions

    partitions = []
    if accessions:
        number_of_partitions = max(
            args.threads, math.ceil(len(accessions) / 400)
        )
        partitions = partition_list(accessions, number_of_partitions)

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as pool:
        download_futures = [
            pool.submit(download_zip_from_ncbi, args, partition)
            for partition in partitions
        ]
        unzip_futures = []
        # Start extracting each archive as soon as its download finishes.
        for finished in concurrent.futures.as_completed(download_futures):
            unzip_futures.append(
                pool.submit(
                    extract_fastas_from_zip_file,
                    finished.result(), args.protein
                )
            )
        for zip_filename, cached in zip_filename_to_accessions.items():
            unzip_futures.append(
                pool.submit(
                    extract_fastas_from_zip_file,
                    zip_filename, args.protein,
                    cached
                )
            )
        for finished in concurrent.futures.as_completed(unzip_futures):
            for accession, filepath in finished.result().items():
                filepath_to_taxid[filepath] = str(accession_to_taxid[accession])

    shutil.rmtree("ncbi_dataset", ignore_errors=True)

    return filepath_to_taxid


def extract_fastas_from_zip_file(filename, protein, accessions_to_extract=None):
    """Extract FASTA members of a zip archive into a ``genomes/all`` tree.

    Members ending in ``_protein.faa`` (``protein`` true) or
    ``_genomic.fna`` are moved to
    ``genomes/all/<accession split in 4 parts>/<assembly>/<file>``
    relative to the current directory. A member is not extracted again
    when a file of the same size already exists at its destination.

    Members whose file name contains no ``GCA_/GCF_`` accession are
    skipped, as are accessions not listed in ``accessions_to_extract``
    when that argument is given and non-empty.

    Returns a dict mapping accession to the destination file path.
    """
    accession_to_filepath = {}

    suffix = "_protein.faa" if protein else "_genomic.fna"
    # Kept separately from the per-member file name so that log messages
    # name the archive rather than the member being extracted.
    zip_pathname = os.path.abspath(filename)
    with zipfile.ZipFile(filename) as archive:
        for entry in archive.namelist():
            if not entry.endswith(suffix):
                continue
            fasta_name = os.path.basename(entry)
            match = re.search(r"GC[AF]_\d{9}\.\d+", fasta_name)
            if match is None:
                # e.g. "cds_from_genomic.fna": cannot be tied to an
                # accession, so there is nothing to map it to.
                continue
            accession = match.group()
            if accessions_to_extract and\
               accession not in accessions_to_extract:
                continue

            # "GCF_000005845.2" -> "GCF000005845" -> GCF/000/005/845
            modified_accession = accession.split(".")[0].replace("_", "")
            dir_components = ["genomes", "all"]
            dir_components.extend(partition_list(modified_accession, 4))
            dir_components.append(fasta_name.replace(suffix, ""))
            dir_components.append(fasta_name)
            filepath = os.path.join(*dir_components)

            entry_size = archive.getinfo(entry).file_size
            if os.path.exists(filepath) and os.stat(filepath).st_size == entry_size:
                LOG.info(
                    "Already extracted {} from {}\n"
                    .format(entry, zip_pathname)
                )
            else:
                LOG.info(
                    "Extracting {} from {}\n".format(entry, zip_pathname)
                )
                archive.extract(entry, os.path.curdir)
                os.makedirs(os.path.dirname(filepath), exist_ok=True)
                LOG.debug(
                    "Moving {} to {}\n".format(
                        os.path.abspath(entry), os.path.abspath(filepath)
                    )
                )
                shutil.move(entry, filepath)
            accession_to_filepath[accession] = filepath

    return accession_to_filepath


def download_zip_from_ncbi(args, accessions):
    """Download one zip archive holding the FASTA files of ``accessions``.

    The archive is stored as ``zips/<hash of the joined accessions>.zip``
    relative to the current directory. When that file already exists it
    is reused and nothing is downloaded. The download goes to a ``.tmp``
    file first and is only moved into place once complete.

    Returns the archive path. Raises ``http.client.HTTPException`` when
    the server does not answer with status 200.
    """
    accessions = ",".join(accessions)
    md5 = hash_string(accessions)
    os.makedirs("zips", exist_ok=True)
    filename = os.path.join("zips", md5 + ".zip")
    tmp_filename = filename + ".tmp"
    # `filename` already includes the "zips" directory; joining it with
    # another directory name made this check fail on every call.
    if os.path.exists(filename):
        LOG.info(
            "Already downloaded {} from NCBI which contains accessions: {}\n"
            .format(os.path.basename(filename), accessions)
        )
        return filename
    LOG.info(
        "Downloading {} from NCBI containing the following accessions: {}\n"
        .format(os.path.basename(filename), accessions)
    )
    quoted_accessions = urllib.parse.quote(accessions)
    annotation_type = "PROT_FASTA" if args.protein else "GENOME_FASTA"
    builder = NCBI_URI_Builder("accession", quoted_accessions, "download")
    builder = builder.include_annotation_type(annotation_type)

    # Open the connection only when a download is actually needed and
    # always release it.
    api = http.client.HTTPSConnection(NCBI_REST_API)
    try:
        api.request("GET", builder.build())
        res = api.getresponse()
        try:
            if res.status != 200:
                # Never store an error body as a zip: it would later be
                # mistaken for a finished download.
                LOG.error(
                    "Unable to download {} from NCBI: {} {}\n"
                    .format(os.path.basename(filename), res.status, res.reason)
                )
                raise http.client.HTTPException(
                    "unexpected status {} {} while downloading {}"
                    .format(res.status, res.reason, filename)
                )
            with open(tmp_filename, "wb") as fout:
                shutil.copyfileobj(res, fout)
        finally:
            res.close()
    finally:
        api.close()

    shutil.move(tmp_filename, filename)
    LOG.info("Saved {} to {}\n".format(
        os.path.basename(filename),
        os.path.abspath(filename)
    ))

    return filename


def partition_list(list, num_partitions):
    """Split a sliceable sequence into at most ``num_partitions`` chunks.

    Every chunk except possibly the last holds
    ``ceil(len(list) / num_partitions)`` items; chunks keep the input
    order and have the same type as a slice of the input. An empty input
    yields an empty list.
    """
    length = len(list)
    if length == 0:
        return []
    step = math.ceil(length / num_partitions)
    if step == 0:
        return []
    # With this step size the number of chunks never exceeds
    # num_partitions, so plain fixed-width slicing is sufficient.
    return [list[begin:begin + step] for begin in range(0, length, step)]


def map_accessions_to_url(accessions, protein=False):
    """Resolve assembly accessions to their download URLs.

    The NCBI ``links`` endpoint is queried in batches of 100 accessions.
    For every ``FTP_LINK`` entry in the answer, the FASTA file path
    (protein or genomic, depending on ``protein``) is derived and joined
    with ``NCBI_SERVER``.

    A batch whose request does not return status 200 is retried after a
    one second pause, without a retry limit.

    Returns a dict mapping accession to URL; an empty input returns an
    empty dict without contacting the server.
    """
    accession_to_url = {}
    if not accessions:
        # An empty query would produce a malformed URI that is then
        # retried forever.
        return accession_to_url

    step = 100
    base_uri = "/datasets/v2/genome/accession/{}/links"
    headers = {'accept': 'application/json'}
    api = http.client.HTTPSConnection(NCBI_REST_API)
    try:
        for start in range(0, len(accessions), step):
            query = urllib.parse.quote(
                ",".join(accessions[start:start + step])
            )
            while True:
                api.request("GET", base_uri.format(query), headers=headers)
                res = api.getresponse()
                if res.status == 200:
                    # Read the whole body; it may span several lines.
                    body = res.read()
                    res.close()
                    break
                res.close()
                time.sleep(1)
                # Drop the old socket; the next request reconnects.
                api.close()

            results = json.loads(body)
            for entry in results.get("assembly_links", []):
                if entry["assembly_link_type"] == "FTP_LINK":
                    filepath = get_download_path(
                        entry["resource_link"], protein
                    )
                    accession_to_url[entry["accession"]] =\
                        url_join(NCBI_SERVER, path=filepath)
    finally:
        api.close()

    return accession_to_url


def map_accessions_to_url_parallel(args, accessions):
    """Resolve accessions to download URLs using a thread pool.

    The accessions are split into up to ``args.threads`` partitions, each
    resolved by ``map_accessions_to_url``; the partial results are merged
    into one dict mapping accession to URL.
    """
    LOG.info("Fetching download links for accessions\n")
    # We are limited to 5 requests per second, so we limit the number of
    # workers accordingly.
    workers = min(args.threads, 5)
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [
            pool.submit(map_accessions_to_url, partition, args.protein)
            for partition in partition_list(accessions, args.threads)
        ]

    # Leaving the pool context waits for every lookup to finish.
    accession_to_url = {}
    for future in concurrent.futures.as_completed(futures):
        accession_to_url.update(future.result())
    LOG.info("Finished fetching download links for accessions\n")

    return accession_to_url


def get_download_path(resource_link, protein):
    """Return the URL path of the FASTA file inside an assembly directory.

    The file name is the last component of ``resource_link`` followed by
    ``_protein.faa.gz`` when ``protein`` is true and ``_genomic.fna.gz``
    otherwise. Only the path portion of the resulting URL is returned.
    """
    assembly_name = os.path.basename(resource_link)
    suffix = "_protein.faa.gz" if protein else "_genomic.fna.gz"
    # The trailing slash makes urljoin append instead of replacing the
    # last path component.
    file_url = urllib.parse.urljoin(
        resource_link + "/", assembly_name + suffix
    )

    return urllib.parse.urlparse(file_url).path


def assign_taxid_to_sequences(args,
                              manifest_to_taxid,
                              accession_to_taxid={},
                              filepath_to_url={}):
    """Label every file in the manifest and concatenate them into the library.

    Each key of ``manifest_to_taxid`` is handed to ``assign_taxids`` in a
    process pool. Progress (files, sequences, characters) is reported
    through debug log lines ending in a carriage return. Once all files
    are processed they are concatenated, in sorted path order, into
    ``library.faa`` (``args.protein``) or ``library.fna`` in the current
    directory.

    Returns a dict merging the header-to-URL mappings of all files.
    """
    masking = not args.no_masking
    LOG.info(
        "Assigning taxonomic IDs and masking sequences\n" if masking
        else "Assigning taxonomic IDs to sequences\n"
    )
    library_filename = "library.faa" if args.protein else "library.fna"
    ch = "aa" if args.protein else "bp"
    total_projects = len(manifest_to_taxid)
    projects_added = 0
    sequences_added = 0
    ch_added = 0
    sequence_to_url = {}

    LOG.debug("{:s}\r".format(
        progress_line(
            projects_added, total_projects, sequences_added, ch_added, ch
        )
    ))
    filepaths = sorted(manifest_to_taxid)
    with concurrent.futures.ProcessPoolExecutor(
            max_workers=args.threads
    ) as pool:
        futures = []
        widest_line = 0
        for filepath in filepaths:
            worker = functools.partial(
                wrap_with_globals, assign_taxids,
                LOG.get_queue(), LOG.get_level(),
                SCRIPT_PATHNAME
            )
            future = pool.submit(
                worker, args, filepath,
                manifest_to_taxid, accession_to_taxid,
                filepath_to_url
            )
            if future_raised_exception(future):
                LOG.error(
                    "Error encountered while assigning tax IDs\n"
                )
                raise future.exception()
            futures.append(future)
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            sequences_added += result[0]
            ch_added += result[1]
            sequence_to_url.update(result[2])
            projects_added += 1
            out_line = progress_line(
                projects_added, total_projects, sequences_added, ch_added, ch
            )
            # Pad to the widest line printed so far so that a shorter
            # line fully overwrites a longer one.
            widest_line = max(widest_line, len(out_line))
            LOG.debug("{:s}\r".format(out_line.ljust(widest_line)))

    LOG.info(
        "Finished assigning taxonomic IDs and masking sequences\n" if masking
        else "Finished assigning taxonomic IDs to sequences\n"
    )

    LOG.info("Generating {:s}\n".format(library_filename))
    with open(library_filename, "w") as out_file:
        for filepath in filepaths:
            # Compressed inputs were written next to the archive without
            # the ".gz" extension.
            source = filepath
            if source.endswith(".gz"):
                source = os.path.splitext(source)[0]
            with open(source, "r") as in_file:
                shutil.copyfileobj(in_file, out_file)

    LOG.info("Finished generating {:s}\n".format(library_filename))

    return sequence_to_url


def progress_line(projects, total_projects, seqs, chars, ch):
    """Format a one-line progress summary.

    The project count is shown as ``done/total`` until both are equal.
    ``chars`` is scaled by powers of 1024 and printed with two decimals,
    followed by the matching prefix (k, M, G, T, P, E) and the unit
    ``ch``.
    """
    if projects == total_projects:
        project_count = str(projects)
    else:
        project_count = "{:d}/{:d}".format(projects, total_projects)

    unit_prefix = ""
    for candidate in ("k", "M", "G", "T", "P", "E"):
        if chars < 1024:
            break
        chars /= 1024
        unit_prefix = candidate

    return "".join((
        "Processed ",
        project_count,
        " project(s), {:d} sequence(s), ".format(seqs),
        "{:.2f} {:s}".format(chars, unit_prefix + ch),
    ))


# The following three functions have dummy return values. This is
# so that we can check whether a future stopped running as a
# result of an exception, so that we can stop the main process
# early.
# def decompress_files(compressed_filenames, out_filename=None, buf_size=8192):
#     if isinstance(compressed_filenames, str):
#         compressed_filenames = [compressed_filenames]
#     if out_filename:
#         if os.path.exists(out_filename + ".tmp"):
#             os.remove(out_filename + ".tmp")
#         with open(out_filename + ".tmp", "ab") as out_file:
#             for filename in compressed_filenames:
#                 with gzip.open(filename) as gz:
#                     decompress_file(gz, out_file)
#             os.rename(out_filename + ".tmp", out_filename)
#     else:
#         for filename in compressed_filenames:
#             out_filename, ext = os.path.splitext(filename)
#             if os.path.exists(out_filename + ".tmp"):
#                 os.remove(out_filename + ".tmp")
#             with gzip.open(filename) as gz:
#                 with open(out_filename + ".tmp", "wb") as out:
#                     decompress_file(gz, out, buf_size)
#             os.rename(out_filename + ".tmp", out_filename)

#     return True


def decompress_files(compressed_filenames, out_filename=None, buf_size=8192):
    """Decompress one or more compressed files.

    ``compressed_filenames`` may be a single path or a list of paths.
    With ``out_filename`` all inputs are decompressed, in order, into
    that single file. Without it every input is decompressed to a file
    named like the input minus its last extension.

    Output is written to ``<target>.tmp`` and renamed once complete.
    ``buf_size`` is the read size handed to ``decompress_file``.

    Always returns ``True``.
    """
    if isinstance(compressed_filenames, str):
        compressed_filenames = [compressed_filenames]
    if out_filename:
        tmp_filename = out_filename + ".tmp"
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        with open(tmp_filename, "wb") as out_file:
            for filename in compressed_filenames:
                with open(filename, "rb") as gz:
                    # Honour the caller's buffer size here as well.
                    decompress_file(gz, out_file, buf_size)
        # Rename only after the output has been flushed and closed.
        os.rename(tmp_filename, out_filename)
    else:
        for filename in compressed_filenames:
            target_filename = os.path.splitext(filename)[0]
            tmp_filename = target_filename + ".tmp"
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            with open(filename, "rb") as gz:
                with open(tmp_filename, "wb") as out:
                    decompress_file(gz, out, buf_size)
            os.rename(tmp_filename, target_filename)

    return True


def decompress_file(in_file, out_file, buf_size=8129):
    """Inflate a gzip or zlib stream from ``in_file`` into ``out_file``.

    ``in_file`` must be a binary file object with a ``name`` attribute;
    ``out_file`` must accept ``bytes``. Data is read ``buf_size`` bytes
    at a time. Inputs made of several concatenated gzip members are
    decompressed completely, and zero padding between or after members
    is ignored.

    Raises ``zlib.error`` on corrupt data, including non-zero trailing
    data that is not a valid compressed stream.

    Always returns ``True``.
    """
    pathname = os.path.join(os.getcwd(), in_file.name)
    LOG.info(
        "Decompressing {:s}\n".format(pathname)
    )
    # 15 + 32: largest window, header type (gzip or zlib) auto-detected.
    inflator = zlib.decompressobj(15 + 32)
    while True:
        data = in_file.read(buf_size)
        if not data:
            break
        while data:
            if inflator.eof:
                # The previous member is complete. A decompression object
                # ignores everything after its stream ends, so a new one
                # is needed for the next member.
                data = data.lstrip(b"\x00")
                if not data:
                    break
                inflator = zlib.decompressobj(15 + 32)
            out_file.write(inflator.decompress(data))
            # Bytes past the end of a member belong to the next one.
            data = inflator.unused_data if inflator.eof else b""
    out_file.write(inflator.flush())
    LOG.info(
        "Finished decompressing {:s}\n".format(pathname)
    )

    return True


def decompress_and_mask(filepath, masker_threads):
    """Decompress ``filepath`` and feed the data to a masking subprocess.

    The output file is ``filepath`` without its last extension. It is
    passed to ``spawn_masking_subprocess`` together with
    ``masker_threads``, and the decompressed bytes are written to the
    subprocess's standard input. The function waits for the subprocess
    to finish.

    Always returns ``True``.
    """
    masked_filepath, _ = os.path.splitext(filepath)
    with open(masked_filepath, "w") as masked_out:
        masker = spawn_masking_subprocess(masked_out, masker_threads)
        with open(filepath, "rb") as compressed:
            decompress_file(compressed, masker.stdin)
        # Closing stdin tells the masker that the input is complete.
        masker.stdin.close()
        masker.wait()

    return True


def download_log(filename, total_size=None):
    """Return a progress callback usable as a ``urlretrieve`` report hook.

    The returned function takes ``(block_number, read_size, size)``. It
    adds ``read_size`` to a running total, advances a ``ProgressBar``
    sized by ``total_size`` (or, if that is not given, by the ``size`` of
    the first call) and emits the bar as a debug log line ending in a
    carriage return.
    """
    state = {"bar": None, "received": 0}

    def inner(block_number, read_size, size):
        # The bar is created lazily because the size may only be known
        # once the first block arrives.
        if not state["bar"]:
            state["bar"] = ProgressBar(total_size or size)
        state["received"] += read_size
        bar = state["bar"]
        bar.progress(state["received"])
        LOG.debug(
            "{:s} {: >10s}\r".format(
                bar.get_bar(), format_bytes(state["received"])
            )
        )

    return inner


def http_download_file(url, local_name=None, call_back=None):
    """Download ``url`` unless a local copy of the same size exists.

    Without ``local_name`` the file is saved under the last component of
    the URL path in the current directory. Otherwise ``local_name`` is
    made absolute and its parent directory is created.

    The server's ``Content-Length`` is compared with the size of the
    local file. When they match, the download is skipped. When the
    server does not report a length, the file is always downloaded.

    ``call_back`` is passed to ``urlretrieve`` as its report hook; by
    default a progress logger from ``download_log`` is used.
    """
    if not local_name:
        local_name = urllib.parse.urlparse(url).path.split("/")[-1]
    else:
        local_name = os.path.abspath(local_name)
        os.makedirs(os.path.dirname(local_name), exist_ok=True)
    with urllib.request.urlopen(url) as conn:
        # The header is absent for e.g. chunked responses; int(None)
        # would raise TypeError.
        content_length = conn.headers["Content-Length"]
    remote_size = int(content_length) if content_length is not None else None
    local_size = (
        os.stat(local_name).st_size if os.path.exists(local_name) else 0
    )
    if remote_size is not None and local_size == remote_size:
        LOG.info(
            "Already downloaded {:s}\n".format(get_abs_path(local_name))
        )
        return

    LOG.info("Beginning download of {:s}\n".format(url))
    urllib.request.urlretrieve(
        url, local_name, reporthook=(call_back or download_log(local_name))
    )
    clear_console_line()
    LOG.info("Saved {:s} to {:s}\n".format(local_name, os.getcwd()))


def http_download_file2(server, urls, save_to=None, md5sums=None):
    """Download a list of server-relative paths over one HTTPS connection.

    Parameters:
        server: host name (a new ``HTTPSConnection`` with a 60 second
            timeout is opened) or an existing connection object, whose
            ``host`` attribute is then used in log messages.
        urls: server-relative paths; each entry is stripped of
            surrounding whitespace before use.
        save_to: directory to store files in, under their base name.
            Without it each file is stored at ``os.path.abspath(url)``,
            i.e. mirroring the remote path below the current directory.
        md5sums: optional dict mapping file base name to the expected MD5.
            When given and non-empty, this dict itself is extended with
            checksums found on the server.

    For each path the function looks up the expected checksum (fetching
    a checksum file from the remote directory if needed), skips the
    download when the local file already matches, otherwise downloads to
    ``<local>.tmp``, moves it into place and verifies the checksum. A
    mismatch or an ``HTTPException`` makes the same path be tried again.

    Exits the process on a response status other than 200 or 404.
    """
    conn = None
    if isinstance(server, str):
        conn = http.client.HTTPSConnection(server, timeout=60)
    else:
        conn = server
        server = conn.host
    md5 = md5sums if md5sums else {}
    # `i` only advances once the current URL is done, so `continue`
    # without incrementing retries the same URL.
    i = 0
    num_urls = len(urls)
    skip_md5_check = False
    while i < num_urls:
        url = urls[i].strip()
        filename = os.path.basename(url)
        local_name = os.path.abspath(url)
        if save_to:
            local_name = os.path.join(save_to, filename)
        tmp_local_name = local_name + ".tmp"
        local_directory = os.path.dirname(local_name)
        os.makedirs(local_directory, exist_ok=True)

        try:
            if filename not in md5:
                checksums = []
                remote_dirname = os.path.dirname(url)
                # The first checksum file that exists in the remote
                # directory wins.
                for md5_filename in ["md5checksums.txt", filename + ".md5", "MD5SUM.txt"]:
                    # Check if file exists before trying to download.
                    # This avoids NCBI sending BadStatusLine when
                    # making the request.
                    conn.request(
                        "HEAD", "/" + remote_dirname + "/" + md5_filename
                    )
                    response = conn.getresponse()
                    if response.status == 200:
                        response.close()
                        # If we have found the file, then go ahead
                        # and download.
                        conn.request(
                            "GET", "/" + remote_dirname + "/" + md5_filename
                        )
                        response = conn.getresponse()
                        checksums = response.readlines()
                        response.close()
                        break
                    else:
                        response.close()
                if len(checksums) > 0:
                    # Each line is expected to hold exactly two
                    # whitespace-separated fields: checksum and file name.
                    for checksum in checksums:
                        (md5sum, remote_filename) = checksum.split()
                        remote_filename = os.path.basename(
                            remote_filename.decode()
                        )
                        md5[remote_filename] = md5sum.decode()
            if not skip_md5_check and os.path.exists(local_name)\
               and filename in md5 and md5[filename] == hash_file(local_name):
                LOG.info(
                    "Already downloaded {:s}\n".format(
                        urllib.parse.urljoin(server, url)
                    )
                )
                # Server can potentially end the connection while we
                # waiting for the md5 hash to be computed. We have
                # to reconnect to avoid failures when trying to
                # retrieve the file.
                conn.connect()
                i += 1
                continue
            LOG.info("Beginning download of {:s}\n".format(server + "/" + url))
            with open(tmp_local_name, "wb") as out_file:
                conn.request("GET", "/" + url)
                response = conn.getresponse()
                if response.status == 200:
                    shutil.copyfileobj(response, out_file, 8192)
                elif response.status == 404:
                    # NOTE(review): only a warning is logged here; the
                    # (empty) temporary file is still moved into place
                    # below.
                    LOG.warning(
                        "Cannot find file: {}.\n"
                        "Please report this issue to NCBI.\n".format(url)
                    )
                else:
                    LOG.error(
                        "Error downloading file: {}.\n"
                        "Reason: {}\n".format(url, response.reason)
                    )
                    response.read()
                    sys.exit(1)
                response.close()
                # The move happens while `out_file` is still open.
                shutil.move(tmp_local_name, local_name)

        except (http.client.HTTPException, http.client.RemoteDisconnected) as e:
            LOG.warning(
                "Unable to download " + url + ". Reason: {}, will try again\n"
                .format(e)
            )
            # Reconnect and retry the same URL after a short pause.
            conn.close()
            time.sleep(0.1)
            conn.connect()
            continue

        # Check the MD5 sum of the downloaded file to make sure that it
        # was downloaded successfully. This prevents issues when
        # assigning tax IDs to files.
        local_md5sum = hash_file(local_name)
        if filename in md5:
            if md5[filename] != local_md5sum:
                LOG.warning(
                    "The MD5 sum of {} does not match the MD5 provided"
                    " by the server. The file will be downloaded again.\n"
                    .format(local_name)
                )
                # We have already confirmed that MD5 sum does not match
                # do not bother checking it again at the top of the loop.
                skip_md5_check = True
                conn.connect()
                continue
            else:
                LOG.info(
                    "The remote and local MD5 sum of {} match\n"
                    .format(local_name)
                )

        LOG.info(
            "Saved {:s} to {:s}\n".format(filename, local_directory)
        )
        # Reset the MD5 check for the next file.
        skip_md5_check = False
        i += 1


def make_file_filter(file_handle, regex):
    """Return a callback that records matching names from a listing.

    The callback takes one listing line, uses its last
    whitespace-separated field as the path and writes it, followed by a
    newline, to ``file_handle`` when the path ends with ``regex``.
    Despite its name, ``regex`` is compared as a plain suffix.
    """
    def inner(listing):
        name = listing.split()[-1]
        if not name.endswith(regex):
            return
        file_handle.write(name + "\n")

    return inner


def move(src, dst):
    """Move ``src`` to ``dst``, resolving both to absolute paths first.

    When ``src`` is a file and ``dst`` is an existing directory, the
    destination becomes ``dst/<basename of src>``.
    """
    source = os.path.abspath(src)
    target = os.path.abspath(dst)
    into_directory = os.path.isfile(source) and os.path.isdir(target)
    if into_directory:
        target = os.path.join(target, os.path.basename(source))
    shutil.move(source, target)


def get_manifest_and_md5sums(server, remote_directory, regex):
    """Build ``manifest.txt`` for a remote directory and fetch MD5 sums.

    Logs in to ``server`` over FTP, lists ``remote_directory`` and writes
    every entry whose name ends with ``regex`` to ``manifest.txt`` (as a
    path joined with ``remote_directory``). It then downloads the file in
    ``/refseq/release/release-catalog`` whose name ends with
    ``installed`` and reads the checksums from its lines that contain
    ``plasmid``.

    Returns a dict mapping file name to MD5 string.
    """
    ftp = ftplib.FTP(server)
    ftp.login()

    listing = io.StringIO()
    ftp.cwd(remote_directory)
    ftp.retrlines("LIST", callback=make_file_filter(listing, regex))

    with open("manifest.txt", "w") as out:
        for name in listing.getvalue().split():
            out.write(urllib.parse.urljoin(remote_directory, name) + "\n")

    catalog_listing = io.StringIO()
    ftp.cwd("/refseq/release/release-catalog")
    ftp.retrlines(
        "LIST", callback=make_file_filter(catalog_listing, "installed")
    )
    install_file = catalog_listing.getvalue().strip()

    chunks = []
    ftp.retrbinary("RETR " + install_file, callback=chunks.append)
    ftp.close()

    md5sums = {}
    for record in b"".join(chunks).split(b"\n"):
        if b"plasmid" not in record:
            continue
        md5, filename = record.split()
        md5sums[filename.decode()] = md5.decode()
    return md5sums


def download_files_from_manifest(
    server,
    threads=1,
    manifest_filename="manifest.txt",
    filepath_to_taxid_table=None,
    md5sums=None,
    resume=False
):
    """Download every path listed in a manifest file using a thread pool.

    Parameters:
        server: passed to ``http_download_file2`` for every download.
        threads: number of worker threads, capped at 12.
        manifest_filename: file with one server-relative path per line.
        filepath_to_taxid_table: unused; kept for compatibility.
        md5sums: optional checksum dict passed on to the downloader.
        resume: when true, paths that already exist locally are skipped.

    The paths are shuffled and split into one partition per thread. A
    worker that fails is reported with a warning. Afterwards every path
    that still does not exist locally is downloaded again, one at a
    time, in the calling thread.
    """
    threads = min(threads, 12)
    with open(manifest_filename, "r") as f:
        filepaths = f.readlines()
    if resume:
        filepaths = [
            filepath for filepath in filepaths
            if not os.path.exists(os.path.abspath(filepath.strip()))
        ]
    # We try to reduce the risk of a single thread downloading
    # many large files.
    random.shuffle(filepaths)

    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool:
        futures = [
            pool.submit(
                http_download_file2, server, partition, None, md5sums
            )
            for partition in partition_list(filepaths, threads)
        ]
        concurrent.futures.wait(futures)

    # wait() without a timeout never returns unfinished futures, so a
    # failure has to be detected through each future's exception.
    for future in futures:
        error = future.exception()
        if error is not None:
            LOG.warning(
                "Error encountered while trying to download files: {!r}. "
                "Missing files will be downloaded again.\n".format(error)
            )

    # Make sure that all files are downloaded. This has been an issue
    # with large collections like bacteria.
    for filepath in filepaths:
        if not os.path.exists(filepath.strip()):
            http_download_file2(server, [filepath], None, md5sums)


def download_and_decompress(filename):
    """Download `filename` from the NCBI server into the current directory
    and decompress the local copy.

    The local copy is addressed by the base name of `filename` resolved
    against the current working directory.
    """
    destination = os.path.abspath(os.curdir)
    http_download_file2(NCBI_SERVER, [filename], save_to=destination)
    local_copy = os.path.abspath(os.path.basename(filename))
    decompress_files(local_copy)


def download_taxonomy(args):
    """Download the NCBI taxonomy into `<args.db>/taxonomy`.

    The working directory is changed to the taxonomy directory. Unless
    `args.skip_maps` is set, the accession-to-taxid maps (nucleotide `gb`
    and `wgs`, or the protein map when `args.protein` is set) are fetched
    and decompressed in worker processes while this process downloads and
    unpacks `taxdump.tar.gz`. The function returns once the workers have
    finished.
    """
    taxonomy_dir = os.path.join(args.db, "taxonomy")
    os.makedirs(taxonomy_dir, exist_ok=True)
    os.chdir(taxonomy_dir)
    pending = []

    with concurrent.futures.ProcessPoolExecutor(max_workers=2) as pool:
        # Decide which accession2taxid maps have to be fetched.
        if args.skip_maps:
            map_files = []
        elif args.protein:
            map_files = [
                "/pub/taxonomy/accession2taxid/" + "prot.accession2taxid.gz"
            ]
        else:
            map_files = [
                "pub/taxonomy/accession2taxid/"
                + "nucl_" + subsection + ".accession2taxid.gz"
                for subsection in ["gb", "wgs"]
            ]

        for map_file in map_files:
            worker = functools.partial(
                wrap_with_globals, download_and_decompress,
                LOG.get_queue(), LOG.get_level(),
                SCRIPT_PATHNAME
            )
            future = pool.submit(worker, map_file)
            if future_raised_exception(future):
                LOG.error("Error encountered while downloading file\n")
                raise future.exception()
            pending.append(future)

        # The taxonomy tree itself is handled in this process while the
        # map downloads run in the pool.
        LOG.info("Downloading taxonomy tree data\n")
        http_download_file2(
            NCBI_SERVER,
            ["pub/taxonomy/taxdump.tar.gz"],
            save_to=os.path.abspath(os.curdir),
        )
        LOG.info("Untarring taxonomy tree data\n")
        with tarfile.open("taxdump.tar.gz", "r:gz") as archive:
            archive.extractall()
        LOG.info("Finished Untarring taxonomy tree data\n")
        concurrent.futures.wait(pending)


def download_gtdb_taxonomy(args, files, md5s):
    """Download the GTDB taxonomy files into `<args.db>/taxonomy`.

    The directory is created if needed and becomes the current working
    directory. `files` is the list of remote paths to fetch from the GTDB
    server and `md5s` is the checksum table handed to the downloader.
    """
    taxonomy_dir = os.path.join(args.db, "taxonomy")
    os.makedirs(taxonomy_dir, exist_ok=True)
    os.chdir(taxonomy_dir)

    LOG.info("Dowloading GTDB taxonomy for bacteria and archaea\n")
    destination = os.path.abspath(os.curdir)
    http_download_file2(
        GTDB_SERVER, files, save_to=destination, md5sums=md5s
    )
    LOG.info("Finished downloading GTDB taxonomy for bacteria and archaea\n")


def build_gtdb_taxonomy(in_file):
    """Convert a GTDB taxonomy table into NCBI-style taxonomy files.

    Each non-blank line of `in_file` is `<accession>\\t<taxonomy string>`,
    where the taxonomy string is a `;`-separated lineage of
    `<rank letter>__<name>` levels. Three files are written to the current
    working directory:

    * `names.dmp` and `nodes.dmp`: one entry per distinct lineage prefix,
      numbered breadth-first starting from a synthetic `root` node with
      ID 1 (which is its own parent);
    * `gtdb.accession2taxid`: one line per accession giving the accession
      without its revision, the full accession, the node ID of its lineage
      and a `-` placeholder.

    Args:
        in_file: iterable of text lines (e.g. an open file).
    """
    rank_codes = {
        "d": "domain",
        "p": "phylum",
        "c": "class",
        "o": "order",
        "f": "family",
        "g": "genus",
        "s": "species",
    }
    # Compiled once; they are applied to every line / node.
    trailing_empty_ranks = re.compile("(;[a-z]__)+$")
    last_level = re.compile("(;[a-z]__[^;]+$)")
    node_name = re.compile("([a-z])__([^;]+)$")

    accession_map = {}
    seen_it = collections.defaultdict(int)
    child_data = collections.defaultdict(lambda: collections.defaultdict(int))
    for line in in_file:
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing on the tab split.
            continue
        accession, taxonomy_string = line.split("\t")
        # Drop whatever precedes the GCA/GCF assembly accession. When
        # neither marker is present the accession is kept unchanged
        # (slicing with -1 would reduce it to its last character).
        start = accession.find("GCA")
        if start < 0:
            start = accession.find("GCF")
        if start > 0:
            accession = accession[start:]
        # Remove trailing levels that have a rank letter but no name.
        taxonomy_string = trailing_empty_ranks.sub("", taxonomy_string)
        accession_map[accession] = taxonomy_string
        seen_it[taxonomy_string] += 1
        if seen_it[taxonomy_string] > 1:
            continue
        # Walk up the lineage one level at a time, registering each level
        # as a child of its parent, until an already-seen ancestor or the
        # top of the lineage is reached.
        while True:
            match = last_level.search(taxonomy_string)
            if not match:
                break
            level = match.group(1)
            taxonomy_string = last_level.sub("", taxonomy_string)
            key = taxonomy_string + level
            child_data[taxonomy_string][key] += 1
            seen_it[taxonomy_string] += 1
            if seen_it[taxonomy_string] > 1:
                break
        # A top-level lineage seen for the first time hangs off the root.
        if seen_it[taxonomy_string] == 1:
            child_data["root"][taxonomy_string] += 1

    id_map = {}
    next_node_id = 1
    LOG.info("Generating nodes.dmp and names.dmp\n")
    with open("names.dmp", "w") as names_file:
        with open("nodes.dmp", "w") as nodes_file:
            # appendleft()/pop() gives first-in-first-out order without the
            # O(n) cost of list.insert(0, ...).
            bfs_queue = collections.deque([("root", 1)])
            while bfs_queue:
                node, parent_id = bfs_queue.pop()
                display_name = node
                rank = "no rank"
                match = node_name.search(node)
                if match:
                    rank = rank_codes[match.group(1)]
                    display_name = match.group(2)
                node_id, next_node_id = next_node_id, next_node_id + 1
                id_map[node] = node_id
                names_file.write(
                    "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format(
                        node_id, display_name
                    )
                )
                nodes_file.write(
                    "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format(
                        node_id, parent_id, rank
                    )
                )
                # Membership test first so that leaf nodes do not create
                # empty entries in the defaultdict.
                children = sorted(child_data[node]) if node in child_data else []
                for child in children:
                    bfs_queue.appendleft((child, node_id))
    with open("gtdb.accession2taxid", "w") as f:
        for accession in sorted(accession_map):
            taxid = id_map[accession_map[accession]]
            accession_without_revision = accession.split(".")[0]
            f.write("{:s}\t{:s}\t{:d}\t-\n".format(
                accession_without_revision,
                accession, taxid
            ))


def download_gtdb_genomes(args, remote_filepath, md5s):
    """Download one GTDB sequence file and unpack it under `library/`.

    The file is downloaded into the current directory, after which the
    working directory is changed to `library`. A `.tar.gz` archive has its
    members matching the FASTA extension (`.faa` for protein databases,
    `.fna` otherwise) extracted; any other file is recorded as is.

    Returns:
        A tuple `(library, accession_to_filepath,
        filepaths_without_accession)` where `library` is the first member
        name of the archive (or the file name up to its first dot),
        `accession_to_filepath` maps assembly accessions found in member
        names to absolute paths, and `filepaths_without_accession` lists
        absolute paths of files whose name carries no accession.
    """
    for subdir in ["taxonomy", "library"]:
        os.makedirs(subdir, exist_ok=True)
    filename = os.path.basename(remote_filepath)
    filepath = os.path.abspath(filename)
    http_download_file2(
        GTDB_SERVER, [remote_filepath], save_to=os.curdir, md5sums=md5s
    )
    os.chdir("library")
    ext = ".faa" if args.protein else ".fna"
    accession_to_filepath = {}
    filepaths_without_accession = []

    # Plain (non-archive) download: nothing to extract.
    if not filename.endswith(".tar.gz"):
        library = filename.split(".")[0]
        filepaths_without_accession.append(os.path.abspath(filepath))
        return (library, accession_to_filepath, filepaths_without_accession)

    with tarfile.open(filepath, "r:gz") as archive:
        library = archive.getnames()[0]
        for member in archive.getmembers():
            if not member.isfile() or not re.search(ext, member.name):
                continue

            if re.search("GC[AF]", member.name):
                accession = re.search(
                    r"(GC[AF]_\d{9}\.\d+)", os.path.basename(member.name)
                ).group(1)
                # A file of the expected size is taken as already extracted.
                already_extracted = (
                    os.path.exists(member.name)
                    and member.size == os.stat(member.name).st_size
                )
                if already_extracted:
                    LOG.info(
                        "Already extracted {}...skipping\n"
                        .format(member.name)
                    )
                else:
                    LOG.info("Extracting {}\n".format(member.name))
                    archive.extract(member)
                    LOG.info(
                        "Finished extracting {}\n".format(member.name)
                    )
                accession_to_filepath[accession] = os.path.abspath(
                    member.name
                )
                continue

            # We do not check if a file has already been extracted here
            LOG.info("Extracting {}\n".format(member.name))
            archive.extract(member)
            LOG.info("Finished extracting {}\n".format(member.name))
            filepaths_without_accession.append(os.path.abspath(member.name))

    return (library, accession_to_filepath, filepaths_without_accession)


def identity(object):
    """Return the argument unchanged.

    Used as a no-op conversion callback, e.g. for the `parse_int` and
    `parse_float` hooks of `json.loads`.
    """
    return object


def remap_secondary_taxids(taxids):
    """Ask the NCBI REST API for the tax ID each queried tax ID resolves to.

    The tax IDs are sent in batches of at most roughly 500 per request.

    Args:
        taxids: list of tax IDs to look up.

    Returns:
        A dict mapping each queried tax ID to the `tax_id` reported by the
        service. Numbers in the response are kept as the strings they were
        sent as (`parse_int`/`parse_float` are `identity`). Queries that
        could not be answered are absent from the result.
    """
    new_taxids = {}
    if not taxids:
        # Nothing to look up; also avoids asking for zero partitions.
        return new_taxids

    partitions = math.ceil(len(taxids) / 500)
    conn = http.client.HTTPSConnection(NCBI_REST_API)
    try:
        for taxid_list in partition_list(taxids, partitions):
            endpoint = NCBI_URI_Builder(
                "taxonomy", "taxon", taxid_list
            ).build()
            conn.request("GET", endpoint)
            response = conn.getresponse()
            try:
                # Always drain the body so the connection can be reused
                # for the next batch, and read all of it rather than only
                # its first line.
                body = response.read()
            finally:
                response.close()
            if response.status != 200:
                LOG.warning(
                    "Tax ID lookup failed with HTTP status {}\n"
                    .format(response.status)
                )
                continue
            nodes = json.loads(
                body, parse_float=identity, parse_int=identity
            )["taxonomy_nodes"]
            for entry in nodes:
                taxonomy = entry.get("taxonomy")
                if not taxonomy:
                    # Presumably an entry for a query the service could
                    # not resolve; there is nothing to remap it to.
                    continue
                new_taxids[entry["query"][0]] = taxonomy["tax_id"]
    finally:
        conn.close()

    return new_taxids


def read_gtdb_metadata(filenames):
    """Collect the NCBI tax ID of every accession in GTDB metadata files.

    Each file is a gzip-compressed, tab-separated table with a header row
    containing at least the `accession` and `ncbi_taxid` columns.

    Returns:
        A dict mapping `accession` to `ncbi_taxid` (both as read from the
        file); later files override earlier ones on duplicate accessions.
    """
    accession_to_taxid = {}
    for metadata_filename in filenames:
        LOG.info("Reading NCBI tax IDs from {}\n".format(metadata_filename))
        with gzip.open(metadata_filename, "rt") as handle:
            accession_to_taxid.update(
                (row["accession"], row["ncbi_taxid"])
                for row in csv.DictReader(handle, delimiter="\t")
            )
        LOG.info(
            "Finished reading NCBI tax IDs from {}\n".format(
                metadata_filename
            )
        )

    return accession_to_taxid


def find_and_remap_secondary_taxids(metadata):
    """Replace tax IDs that are missing from `nodes.dmp`.

    Every tax ID used in `metadata` (accession -> tax ID) that does not
    appear as the first field of a line in `nodes.dmp` (read from the
    current directory) is looked up with `remap_secondary_taxids`, and the
    entries of `metadata` are updated in place with the answers.

    Returns:
        The same `metadata` dict that was passed in.
    """
    gtdb_assigned_taxids = set(metadata.values())
    with open("nodes.dmp", "r") as nodes_file:
        ncbi_taxids = {entry.split()[0] for entry in nodes_file}

    missing_taxids = list(gtdb_assigned_taxids.difference(ncbi_taxids))
    LOG.info(
        "The following tax IDs were not found in nodes.dmp"
        " and need to be remapped\n"
    )
    replacements = remap_secondary_taxids(missing_taxids)
    # Only values of existing keys are reassigned, so updating the dict
    # while iterating over it is safe.
    for accession, old_taxid in metadata.items():
        if old_taxid not in replacements:
            continue
        LOG.info(
            "Remapping {} to {}\n".format(old_taxid, replacements[old_taxid])
        )
        metadata[accession] = replacements[old_taxid]

    return metadata


def get_gtdb_latest_md5sums(path_prefix):
    """Download the `MD5SUM.txt` of the latest GTDB release.

    The file is first requested from `<path_prefix>/releases/latest/`. If
    that fails for any reason, the release number is read from
    `VERSION.txt` in the same directory and the checksum file is requested
    from `<path_prefix>/releases/release<number>/` instead.

    Args:
        path_prefix: path on the GTDB server under which `releases/` lives.
    """
    try:
        md5_url = url_join(
            GTDB_SERVER, path=path_prefix + "/releases/latest/MD5SUM.txt"
        )
        http_download_file(md5_url)
    except Exception:
        url = url_join(
            GTDB_SERVER, path=path_prefix + "/releases/latest/VERSION.txt"
        )
        # Close the HTTP response once the first line has been read.
        with urllib.request.urlopen(url) as response:
            # remove the leading 'v' from the version number
            version = response.readline().decode().strip()[1:]
        md5_url = url_join(
            GTDB_SERVER,
            path=path_prefix + "/releases/release" + version + "/MD5SUM.txt"
        )
        http_download_file(md5_url)


def get_needed_files(args, path_prefix):
    """Pick the GTDB files to download from `MD5SUM.txt`.

    `MD5SUM.txt` is read from the current directory; each line holds a
    checksum and a file path relative to the release directory. Paths are
    resolved against `<path_prefix>/releases/latest/` and have their
    `_r<digits>` release tag removed.

    Returns:
        A tuple `(md5s, files_needed)`: `md5s` maps file base names to
        checksums, and `files_needed` groups the resolved paths under the
        keys `"taxonomy"`, `"metadata"` and `"fasta"` (the latter being
        the paths matching any of the patterns in `args.gtdb_files`).

    Exits the program when the number of FASTA matches differs from the
    number of requested patterns.
    """
    md5s = {}
    files_needed = collections.defaultdict(list)

    with open("MD5SUM.txt", "r") as md5_file:
        # filename_regex = r"genomes|genes|fna"
        filename_regex = "|".join(args.gtdb_files)
        pattern = re.compile(filename_regex)
        candidate_files = []
        if args.protein:
            # NOTE: assigned after `pattern` was compiled, so this has no
            # effect on which files are matched.
            filename_regex = r"protein|faa"
        release_dir = path_prefix + "/releases/latest/"
        for entry in md5_file:
            md5sum, relative_path = entry.split()
            # remove the release tag since they do not appear in the "latest"
            # file listings
            filepath = re.sub(
                r"_r\d+", "",
                urllib.parse.urljoin(release_dir, relative_path)
            )
            basename = os.path.basename(filepath)
            if "genomic_files" in filepath:
                candidate_files.append(basename)
            if re.search(r"taxonomy.*\.tsv$", filepath):
                files_needed["taxonomy"].append(filepath)
            # A metadata file is never also treated as a FASTA match.
            if re.search(r".*metadata.*\.tsv.gz$", filepath):
                files_needed["metadata"].append(filepath)
            elif pattern.search(filepath):
                files_needed["fasta"].append(filepath)
            md5s[basename] = md5sum

    if len(files_needed["fasta"]) != len(args.gtdb_files):
        requested = ", ".join(args.gtdb_files)
        LOG.error("At least one of the files did not match: {}\n"
                  .format(requested))
        LOG.error("Here is a list of candidates:\n{}\n"
                  .format("\n".join(candidate_files)))
        sys.exit(1)

    return (md5s, files_needed)


def build_gtdb_database(args):
    """Download GTDB data and build a Kraken 2 database from it.

    Steps, all relative to `args.db`:

    1. Fetch the release checksum list and select the taxonomy, metadata
       and sequence files (`args.gtdb_files`) to use.
    2. Download the GTDB taxonomy tables and merge them into
       `taxonomy/merged_taxonomy.tsv`.
    3. Build the accession -> tax ID table, either from a taxonomy
       generated out of the GTDB lineages, or (with
       `args.gtdb_use_ncbi_taxonomy`) from the NCBI tax IDs in the GTDB
       metadata combined with the NCBI taxonomy dump.
    4. Download and unpack each sequence file in its own worker process,
       then tag the sequences of each library with tax IDs and write its
       `prelim_map.txt`.
    5. Run the regular database build.

    When `args.gtdb_server` differs from the default, the module-level
    `GTDB_SERVER` is replaced by it and no path prefix is used.
    """
    global GTDB_SERVER
    db_pathname = os.path.abspath(args.db)
    os.makedirs(db_pathname, exist_ok=True)
    os.chdir(db_pathname)

    path_prefix = "/public/gtdb/data"
    if args.gtdb_server != GTDB_SERVER:
        GTDB_SERVER = args.gtdb_server
        path_prefix = ""

    get_gtdb_latest_md5sums(path_prefix)
    md5s, files_needed = get_needed_files(args, path_prefix)
    download_gtdb_taxonomy(args, files_needed["taxonomy"], md5s)
    os.chdir(os.path.join(db_pathname, "taxonomy"))
    LOG.info("Merging Archaea and Bacteria taxonomies\n")
    with open("merged_taxonomy.tsv", "w") as file_out:
        for tax_filename in files_needed["taxonomy"]:
            tax_filename = os.path.basename(tax_filename)
            with open(tax_filename, "r") as file_in:
                shutil.copyfileobj(file_in, file_out)
    LOG.info("Finished merging Archaea and Bacteria taxonomies\n")
    accession_to_taxid = {}
    if not args.gtdb_use_ncbi_taxonomy:
        with open("merged_taxonomy.tsv", "r") as in_file:
            build_gtdb_taxonomy(in_file)
    else:
        for metadata in files_needed["metadata"]:
            url = url_join(GTDB_SERVER, path=metadata)
            http_download_file(url)
        metadata_filenames = map(os.path.basename, files_needed["metadata"])
        metadata = read_gtdb_metadata(metadata_filenames)
        # Only the taxonomy tree is needed; the accession maps are skipped.
        args.skip_maps = True
        download_taxonomy(args)
        metadata = find_and_remap_secondary_taxids(metadata)
        for accession, taxid in metadata.items():
            # Drop the first three characters of the metadata accession
            # (presumably the "RS_"/"GB_" source prefix).
            accession_to_taxid[accession[3:]] = taxid

    workers = len(files_needed["fasta"])

    # os.chdir(os.path.join(db_pathname, "taxonomy"))
    if not args.gtdb_use_ncbi_taxonomy:
        with open("gtdb.accession2taxid", "r") as in_file:
            for line in in_file:
                base_accession, accession, taxid,  gi = line.split("\t")
                accession_to_taxid[accession] = taxid

    with concurrent.futures.ProcessPoolExecutor(
            max_workers=workers
    ) as pool:
        # Remember which remote file each future belongs to; futures
        # complete in arbitrary order, so the loop variable of the submit
        # loop must not be reused below.
        future_to_remote = {}
        for remote_filepath in files_needed["fasta"]:
            os.chdir(db_pathname)
            f = functools.partial(
                wrap_with_globals, download_gtdb_genomes,
                LOG.get_queue(), LOG.get_level(),
                SCRIPT_PATHNAME
            )
            future = pool.submit(f, args, remote_filepath, md5s)
            future_to_remote[future] = remote_filepath
        for future in concurrent.futures.as_completed(future_to_remote):
            (library, accession_to_filepath,
             filepaths_without_accessions) = future.result()
            source_url = url_join(
                GTDB_SERVER, path=future_to_remote[future]
            )

            library_pathname = os.path.join(db_pathname,
                                            os.path.join("library", library))
            os.makedirs(library_pathname, exist_ok=True)
            os.chdir(library_pathname)

            # Both tables cover only the files of this library, so that
            # files of a previously finished library are not added again.
            filepath_to_url = {}
            filepath_to_taxid_table = {}
            for accession, filepath in accession_to_filepath.items():
                filepath_to_url[filepath] = source_url
                filepath_to_taxid_table[filepath] =\
                    accession_to_taxid[accession]
            # These files are not tied to a single accession, but
            # instead each sequence in the FASTA has its own accession.
            for filepath in filepaths_without_accessions:
                filepath_to_url[filepath] = source_url
                filepath_to_taxid_table[filepath] = ""

            sequence_to_url = assign_taxid_to_sequences(
                args, filepath_to_taxid_table,
                accession_to_taxid, filepath_to_url
            )
            with open("library.fna", "r") as in_file:
                with open("prelim_map.txt", "w") as out_file:
                    out_file.write("# prelim_map for {:s}\n".format(library))
                    scan_fasta_file(
                        in_file,
                        out_file,
                        sequence_to_url=sequence_to_url,
                    )

    os.chdir(db_pathname)
    build_kraken2_db(args)



def download_genomic_library(args):
    """Download the library named by `args.library` into `<args.db>/library`.

    The library name selects one of five strategies:

    * an NCBI genome collection (`archaea`, `bacteria`, `human`, ...): the
      assembly summary of `args.assembly_source` (`refseq`, `genbank`, or
      `all` for both merged) is downloaded, turned into a manifest, the
      listed files are downloaded and `prelim_map.txt` is written;
    * `plasmid`, `plastid`, `mitochondrion`: the RefSeq files are
      downloaded, decompressed (and masked unless `args.no_masking`) in a
      process pool and concatenated into the library file;
    * BLAST databases (`nt`, `core_nt`, ...): handled by the BLAST volume
      helpers;
    * `UniVec` / `UniVec_Core` (nucleotide only): every sequence is given
      the fixed tax ID 28384 and the library is masked unless
      `args.no_masking`;
    * anything else is treated as an assembly accession (`GCA`/`GCF`), a
      BioProject (`PRJ`) or a taxon and fetched with
      `download_dataset_by_project`.

    The working directory is changed to the library's directory for all
    but the last strategy.
    """
    library_filename = "library.faa" if args.protein else "library.fna"
    library_pathname = os.path.join(args.db, "library")
    LOG.info("Adding {:s} to {:s}\n".format(args.library, args.db))
    if args.library in [
            "archaea",
            "bacteria",
            "viral",
            "fungi",
            "invertebrate",
            "plant",
            "human",
            "protozoa",
            "vertebrate_mammalian",
            "vertebrate_other"
    ]:
        library_pathname = os.path.join(library_pathname, args.library)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        # Always start from a fresh assembly summary.
        try:
            os.remove("assembly_summary.txt")
        except FileNotFoundError:
            pass
        remote_dir_name = args.library
        if args.library == "human":
            remote_dir_name = "vertebrate_mammalian/Homo_sapiens"
        try:
            if args.assembly_source == "all":
                # Download and merge assembly summaries from both RefSeq and GenBank
                refseq_url = "genomes/refseq/{:s}/assembly_summary.txt".format(remote_dir_name)
                genbank_url = "genomes/genbank/{:s}/assembly_summary.txt".format(remote_dir_name)

                # Download RefSeq assembly summary
                http_download_file2(
                    NCBI_SERVER, [refseq_url], save_to=os.path.abspath(os.curdir)
                )
                os.rename("assembly_summary.txt", "assembly_summary_refseq.txt")

                # Download GenBank assembly summary
                http_download_file2(
                    NCBI_SERVER, [genbank_url], save_to=os.path.abspath(os.curdir)
                )
                os.rename("assembly_summary.txt", "assembly_summary_genbank.txt")

                # Merge the two files, keeping the header from RefSeq only
                with open("assembly_summary.txt", "w") as merged_file:
                    # Write RefSeq entries
                    with open("assembly_summary_refseq.txt", "r") as refseq_file:
                        merged_file.write(refseq_file.read())

                    # Write GenBank entries (skip header lines starting with #)
                    with open("assembly_summary_genbank.txt", "r") as genbank_file:
                        for line in genbank_file:
                            if not line.startswith("#"):
                                merged_file.write(line)

                # Clean up temporary files
                os.remove("assembly_summary_refseq.txt")
                os.remove("assembly_summary_genbank.txt")
            else:
                # Original behavior for "refseq" or "genbank"
                url = "genomes/{}/{:s}/assembly_summary.txt".format(
                    args.assembly_source, remote_dir_name
                )
                http_download_file2(
                    NCBI_SERVER, [url], save_to=os.path.abspath(os.curdir)
                )
        except urllib.error.URLError:
            LOG.error(
                "Error downloading assembly summary file for {:s}, "
                "exiting\n".format(args.library)
            )
            sys.exit(1)
        if args.library == "human":
            # Keep only the Genome Reference Consortium assemblies. Header
            # lines are kept so the summary retains its original layout.
            # (str.find() returns -1 when nothing is found, which is
            # truthy, so it cannot be used directly as the condition.)
            with open("assembly_summary.txt", "r") as f1:
                with open("grc.txt", "w") as f2:
                    for line in f1:
                        if line.startswith("#")\
                           or "Genome Reference Consortium" in line:
                            f2.write(line)
            os.rename("grc.txt", "assembly_summary.txt")
        with open("assembly_summary.txt", "r") as f:
            filepath_to_url = {}
            filepath_to_taxid_table = make_manifest_from_assembly_summary(
                args, f
            )
            for filepath in filepath_to_taxid_table:
                filepath_to_url[filepath] = url_join(NCBI_SERVER, path=filepath)
            download_files_from_manifest(
                NCBI_SERVER,
                args.threads,
                filepath_to_taxid_table=filepath_to_taxid_table,
                resume=args.resume
            )
            sequence_to_url = assign_taxid_to_sequences(
                args, filepath_to_taxid_table,
                filepath_to_url=filepath_to_url
            )
        with open(library_filename, "r") as in_file:
            with open("prelim_map.txt", "w") as out_file:
                out_file.write("# prelim_map for " + args.library + "\n")
                scan_fasta_file(
                    in_file,
                    out_file,
                    sequence_to_url=sequence_to_url,
                )
    elif args.library in ["plasmid", "plastid", "mitochondrion"]:
        library_pathname = os.path.join(args.db, "library")
        library_pathname = os.path.join(library_pathname, args.library)
        library_filename = "library.faa" if args.protein else "library.fna"
        library_filename = os.path.join(library_pathname, library_filename)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        pat = ".faa.gz" if args.protein else ".fna.gz"
        md5 = get_manifest_and_md5sums(
            NCBI_SERVER, "genomes/refseq/{}/".format(args.library), pat
        )
        download_files_from_manifest(
            NCBI_SERVER, args.threads, md5sums=md5, resume=args.resume
        )
        sequence_to_url = {}
        filenames = []
        with open("manifest.txt", "r") as manifest:
            filenames = manifest.readlines()
        filenames = sorted(filenames)
        # Decompress (and optionally mask) every downloaded file in a
        # worker process.
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.threads
        ) as pool:
            futures = []
            for filename in filenames:
                filename = filename.strip()
                filename = os.path.abspath(filename)
                if not args.no_masking:
                    f = functools.partial(
                        wrap_with_globals, decompress_and_mask,
                        LOG.get_queue(), LOG.get_level(),
                        SCRIPT_PATHNAME
                    )
                    future = pool.submit(
                        f, filename, args.masker_threads
                    )
                    if future_raised_exception(future):
                        LOG.error(
                            "Error encountered while decompressing"
                            " or masking files\n"
                        )
                        raise future.exception()
                    futures.append(future)
                else:
                    f = functools.partial(
                        wrap_with_globals, decompress_files,
                        LOG.get_queue(), LOG.get_level(),
                        SCRIPT_PATHNAME
                    )
                    future = pool.submit(f, [filename])
                    if future_raised_exception(future):
                        LOG.error(
                            "Error encountered while decompressing files\n"
                        )
                        raise future.exception()
                    futures.append(future)
            result = concurrent.futures.wait(
                futures,
                return_when=concurrent.futures.ALL_COMPLETED
            )
            if len(result.not_done) > 0:
                LOG.error(
                    "Encountered error while downloading Plasmid library\n"
                )
                sys.exit(1)
        LOG.info("Generating {}\n".format(library_filename))
        with open(library_filename, "w") as out_file:
            for filename in filenames:
                # `filename` is a raw manifest line here; splitext() drops
                # the trailing ".gz" together with the line terminator.
                in_filename = os.path.splitext(filename)[0]
                with open(os.path.abspath(in_filename), "r") as in_file:
                    for line in in_file:
                        if line.startswith(">"):
                            sequence_to_url[line.strip()] = filename
                        out_file.write(line)

        LOG.info("Finished generating {}\n".format(library_filename))
        with open(library_filename, "r") as in_file:
            with open("prelim_map.txt", "w") as out_file:
                out_file.write("# prelim_map for " + args.library + "\n")
                scan_fasta_file(
                    in_file,
                    out_file,
                    sequence_to_url=sequence_to_url,
                )
    elif args.library in ["core_nt", "nt", "env_nt", "nt_viruses", "nt_euk"]:
        library_pathname = os.path.join(library_pathname, args.library)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        create_manifest_for_blast_db(args.library, args.blast_volumes)
        download_and_process_blast_volumes(args)
    elif args.library in ["UniVec", "UniVec_Core"]:
        if args.protein:
            LOG.error(
                "{:s} is available for nucleotide databases only\n".format(
                    args.library
                )
            )
            sys.exit(1)
        library_pathname = os.path.join(library_pathname, args.library)
        os.makedirs(library_pathname, exist_ok=True)
        os.chdir(library_pathname)
        http_download_file2(
            NCBI_SERVER,
            ["pub/UniVec/" + args.library],
            save_to=os.path.abspath(os.curdir),
        )
        special_taxid = 28384
        LOG.info(
            "Assigning taxonomy ID of {:d} to all sequences\n".format(
                special_taxid
            )
        )
        # Prefix every FASTA header with the fixed tax ID.
        with open(args.library, "r") as in_file:
            with open("library.fna", "w") as out_file:
                for line in in_file:
                    if line.startswith(">"):
                        line = re.sub(
                            ">",
                            ">kraken:taxid|" + str(special_taxid) + "|",
                            line,
                        )
                    out_file.write(line)
        with open("library.fna", "r") as in_file:
            with open("prelim_map.txt", "w") as out_file:
                out_file.write("# prelim_map for " + args.library + "\n")
                scan_fasta_file(
                    in_file,
                    out_file,
                    sequence_to_url="pub/UniVec/" + args.library,
                )
    else:
        if args.library.upper().startswith("GCF")\
           or args.library.upper().startswith("GCA"):
            download_dataset_by_project(args, "genome/accession",
                                        [args.library.upper()])
        elif args.library.upper().startswith("PRJ"):
            download_dataset_by_project(args, "genome/bioproject",
                                        [args.library.upper()])
        else:
            download_dataset_by_project(args, "genome/taxon", [args.library])

    # Only the UniVec libraries are masked here; the working directory is
    # still the library's directory, so the relative file name resolves.
    if not args.no_masking\
       and args.library in ["UniVec", "UniVec_Core"]:
        mask_files(
            [library_filename],
            library_filename + ".masked",
            args.masker_threads,
            args.protein
        )
        shutil.move(library_filename + ".masked", library_filename)
    LOG.info("Added {:s} to {:s}\n".format(args.library, args.db))


def get_abs_path(filename):
    """Return the absolute, normalized form of `filename`.

    Relative names are resolved against the current working directory.
    """
    absolute = os.path.abspath(filename)
    return absolute


def is_compressed(filename):
    """Tell whether a file starts with a bzip2, gzip or xz signature.

    Only the first few bytes of the file are read.
    """
    signatures = (
        b"\x42\x5A\x68",              # bzip2
        b"\x1F\x8B",                  # gzip
        b"\xFD\x37\x7A\x58\x5A\x00",  # xz
    )
    # The xz signature is the longest one, so reading that many bytes is
    # enough to recognise any of them.
    longest = max(len(signature) for signature in signatures)
    with open(filename, "rb") as handle:
        header = handle.read(longest)
    return header.startswith(signatures)


def get_reader(filename):
    """Return the `open`-like function suited to a file's compression.

    The first bytes of the file are compared against the bzip2, gzip and
    xz signatures; the result is `bz2.open`, `gzip.open`, `lzma.open` or,
    when none matches, the built-in `open`.
    """
    openers = (
        (b"\x42\x5A\x68", bz2.open),
        (b"\x1F\x8B", gzip.open),
        (b"\xFD\x37\x7A\x58\x5A\x00", lzma.open),
    )
    # Read as many bytes as the longest signature (xz) needs.
    longest = max(len(signature) for signature, _ in openers)
    with open(filename, "rb") as handle:
        header = handle.read(longest)

    for signature, opener in openers:
        if header.startswith(signature):
            return opener
    return open


def read_from_files(filename1, filename2=None):
    """Iterate over the lines of one file, or of two files in lockstep.

    Files are opened in binary mode with the reader matching their
    compression. With a single file each line is yielded on its own; with
    two files `(line1, line2)` tuples are yielded, and the program exits
    as soon as one file runs out of lines before the other.
    """
    opener1 = get_reader(filename1)

    if filename2 is None:
        with opener1(filename1, "rb") as handle:
            yield from handle
        return

    opener2 = get_reader(filename2)
    with opener1(filename1, "rb") as first, opener2(filename2, "rb") as second:
        for record1, record2 in itertools.zip_longest(first, second):
            # zip_longest pads the exhausted file with None.
            if record1 is None:
                LOG.error(
                    "{} contains more sequences than {}".format(
                        filename1, filename2
                    )
                )
                sys.exit(1)
            if record2 is None:
                LOG.error(
                    "{} contains more sequences than {}".format(
                        filename2, filename1
                    )
                )
                sys.exit(1)
            yield (record1, record2)


def write_to_fifo(filenames, fifo1=None, fifo2=None):
    """Stream the contents of input files into one or two output paths.

    With only `fifo1`, every file in `filenames` is copied to it in order.
    With `fifo2` as well, `filenames` is read as consecutive pairs
    (first/second of each pair) whose lines go to `fifo1` and `fifo2`
    respectively.
    """
    if fifo2 is None:
        with open(fifo1, "wb") as sink:
            for source in filenames:
                sink.writelines(read_from_files(source))
        return

    with open(fifo1, "wb") as sink1, open(fifo2, "wb") as sink2:
        first_files = filenames[0::2]
        second_files = filenames[1::2]
        for first, second in zip(first_files, second_files):
            for record1, record2 in read_from_files(first, second):
                sink1.write(record1)
                sink2.write(record2)


def check_seqidmap():
    """Drop entries of `seqid2taxid.map` whose tax ID is not in the taxonomy.

    Tax IDs are read from `taxonomy/nodes.dmp`. Every line of
    `seqid2taxid.map` (sequence ID and tax ID separated by a tab) whose
    tax ID is unknown is reported with a warning and left out; the
    filtered file then replaces the original. All paths are relative to
    the current working directory.
    """
    LOG.info(
        "Checking if there are invalid taxid in seqid2taxid.map. "
        "These taxids will be logged if found and removed from the file\n"
    )
    nodes_path = os.path.join("taxonomy", "nodes.dmp")
    taxonomy_nodes = {}
    with open(nodes_path, "r") as nodes_file:
        for record in nodes_file:
            fields = record.split("\t|")
            taxid, parent_taxid = fields[:2]
            taxonomy_nodes[taxid.strip()] = parent_taxid.strip()

    with open("seqid2taxid.map.new", "w") as filtered, \
            open("seqid2taxid.map", "r") as current:
        for mapping in current:
            seqid, taxid = mapping.split("\t")
            taxid = taxid.strip()
            if taxid not in taxonomy_nodes:
                LOG.warning(
                    "There is no entry for taxid, '{}', contained in, {},"
                    "in nodes.dmp. Please contact NCBI about this\n"
                    .format(taxid, seqid)
                )
                continue
            filtered.write(mapping)
    shutil.move("seqid2taxid.map.new", "seqid2taxid.map")


def suffix_to_multiplier(suffix):
    """Return the number of bytes denoted by a storage-unit suffix.

    Abbreviations (``B``, ``KB``, ``KiB`` ...) are matched exactly.
    Spelled-out unit names are matched case-insensitively, with or
    without a trailing plural "s".  An unrecognised suffix is logged and
    terminates the program with exit status 1.
    """
    name_to_size = {
        'byte': 1,
        'kibibyte': 2 ** 10,
        'mebibyte': 2 ** 20,
        'gibibyte': 2 ** 30,
        'tebibyte': 2 ** 40,
        # Misspellings accepted by earlier versions; kept so existing
        # invocations continue to work.
        'kebibyte': 2 ** 10,
        'gebibyte': 2 ** 30,
        'kilobyte': 10 ** 3,
        'megabyte': 10 ** 6,
        'gigabyte': 10 ** 9,
        'terabyte': 10 ** 12,
    }

    unit_to_size = {
        'B': 1,
        'KiB': 2 ** 10,
        'KB':  10 ** 3,
        'MiB': 2 ** 20,
        'MB':  10 ** 6,
        'GiB': 2 ** 30,
        'GB':  10 ** 9,
        'TiB': 2 ** 40,
        'TB':  10 ** 12,
    }

    if suffix in unit_to_size:
        return unit_to_size[suffix]

    # Normalise case unconditionally so the singular "Gigabyte" is
    # treated the same way as the plural "Gigabytes".
    unit_name = suffix.lower()
    if unit_name.endswith("s"):
        unit_name = unit_name[:-1]

    if unit_name in name_to_size:
        return name_to_size[unit_name]

    LOG.error("Unable to convert {} to a storage unit\n".format(suffix))
    sys.exit(1)


def parse_db_size(input):
    """Convert a size such as ``"16GB"`` or ``"4096"`` to a byte count.

    Spaces are ignored.  A bare number is returned unchanged; otherwise
    the leading digits are multiplied by the unit that follows them (see
    ``suffix_to_multiplier``).  Input without a leading number is logged
    and terminates the program with exit status 1.
    """
    text = input.strip().replace(" ", "")
    # str.isdecimal matches exactly the characters int() can parse.
    if text.isdecimal():
        return int(text)

    number = "".join(itertools.takewhile(str.isdecimal, text))
    if not number:
        LOG.error("Unable to find a size in '{}'\n".format(input))
        sys.exit(1)
    suffix = "".join(itertools.takewhile(str.isalpha, text[len(number):]))

    return int(number) * suffix_to_multiplier(suffix)


def build_kraken2_db(args):
    """Build the Kraken 2 hash table and taxonomy files inside ``args.db``.

    The work is done relative to the database directory (the function
    changes the working directory to ``args.db``):

    1. Concatenate the per-library ``prelim_map*.txt`` files and derive
       ``seqid2taxid.map`` from them, unless a map newer than every
       prelim map already exists.
    2. Obtain the capacity estimate from the ``estimate_capacity`` binary,
       or reuse the ``estimated_capacity`` file when it is newer than
       ``seqid2taxid.map``.
    3. Unless ``hash.k2d`` exists, stream the library FASTA files into the
       ``build_db`` binary and move the ``*.k2d.tmp`` outputs into place.

    Missing inputs are logged and terminate the program with exit status
    1.  If the build itself fails the function logs the error and returns
    without installing the ``*.k2d`` files.
    """
    if not os.path.isdir(get_abs_path(args.db)):
        LOG.error('Cannot find Kraken 2 database: "{:s}"\n'.format(args.db))
        sys.exit(1)
    os.chdir(args.db)
    if not os.path.isdir("taxonomy"):
        LOG.error("Cannot find taxonomy subdirectory in database\n")
        sys.exit(1)
    if not os.path.isdir("library"):
        LOG.error("Cannot find library subdirectory in database\n")
        sys.exit(1)

    # Newest modification time among the per-library prelim maps.
    prelim_map_filepaths = glob.glob(
        os.path.join(os.path.join("library", "*"), "prelim_map*.txt")
    )
    prelim_map_mtime = 0
    for prelim_map_filepath in prelim_map_filepaths:
        mtime = os.path.getmtime(prelim_map_filepath)
        if mtime > prelim_map_mtime:
            prelim_map_mtime = mtime

    if os.path.exists("seqid2taxid.map") and \
       os.path.getmtime("seqid2taxid.map") > prelim_map_mtime:
        LOG.info(
            "A seqid2taxid.map already present and newer"
            " than any of the prelim_map.txt files, skipping\n"
        )
    else:
        LOG.info("Concatenating prelim_map.txt files\n")
        with open("prelim_map.txt", "w") as out_file:
            for prelim_map_filepath in prelim_map_filepaths:
                with open(prelim_map_filepath, "r") as in_file:
                    shutil.copyfileobj(in_file, out_file)
        if os.path.getsize("prelim_map.txt") == 0:
            os.remove("prelim_map.txt")
            LOG.error(
                "No preliminary seqid/taxid mapping files found, aborting\n"
            )
            sys.exit(1)
        LOG.info("Finished concatenating prelim_map.txt files\n")
        LOG.info("Creating sequence ID to taxonomy ID map\n")
        # TAXID rows already carry a taxid; ACCNUM rows still need an
        # accession lookup and are collected separately.
        with open("prelim_map.txt", "r") as in_file:
            with open("seqid2taxid.map.tmp", "w") as seqid2taxid_file:
                with open("accmap.tmp", "w") as accmap_file:
                    for line in in_file:
                        if line.startswith("#"):
                            continue
                        line = line.strip()
                        new_line = "\t".join(line.split("\t")[1:3]) + "\n"
                        if line.startswith("TAXID"):
                            seqid2taxid_file.write(new_line)
                        elif line.startswith("ACCNUM"):
                            accmap_file.write(new_line)
        if os.path.getsize("accmap.tmp") > 0:
            accession2taxid_filenames = glob.glob("taxonomy/*.accession2taxid")
            if accession2taxid_filenames:
                lookup_accession_numbers(
                    "accmap.tmp",
                    "seqid2taxid.map.tmp",
                    *accession2taxid_filenames
                )
            else:
                LOG.error(
                    "Accession to taxid map files are required to"
                    " build this database.\n"
                )
                LOG.error(
                    "Run k2 download-taxonomy --db {:s} again\n".format(
                        args.db
                    )
                )
                sys.exit(1)
        os.remove("accmap.tmp")
        move("seqid2taxid.map.tmp", "seqid2taxid.map")
        LOG.info("Created sequence ID to taxonomy ID map\n")
        check_seqidmap()

    estimate_capacity_binary = find_kraken2_binary("estimate_capacity")
    argv = [estimate_capacity_binary, "-S", construct_seed_template(args)]
    if args.protein:
        argv.append("-X")
    wrapper_args_to_binary_args(
        args, argv, get_binary_options(estimate_capacity_binary)
    )
    fasta_filenames = glob.glob(
        os.path.join("library", os.path.join("*", "*.f[an]a")),
        recursive=False
    )
    estimate = ""
    total_sequences = 0
    if os.path.exists("estimated_capacity"):
        estimated_capacity_mtime = \
            os.path.getmtime("estimated_capacity")
        seqid_to_taxid_map_mtime = os.path.getmtime("seqid2taxid.map")
        if estimated_capacity_mtime > seqid_to_taxid_map_mtime:
            LOG.info(
                "An estimated_capacity file exists and is newer "
                "than seqid2taxid.map , reading the estimated "
                "capacity from estimated_capacity file.\n"
            )
            # Line 1 is the estimate; an optional line 2 is the number of
            # mapped sequences counted when the estimate was produced.
            with open("estimated_capacity", "r") as in_file:
                lines = in_file.readlines()
                if len(lines) == 1:
                    estimate = lines[0].strip()
                elif len(lines) == 2:
                    estimate = lines[0].strip()
                    total_sequences = int(lines[1].strip())

    mapped_sequences = {}
    with open("seqid2taxid.map", "rb") as in_file:
        for line in in_file:
            sequence_name = line.split()[0]
            mapped_sequences[sequence_name] = True

    if estimate == "":
        if not dwk2():
            argv.extend(fasta_filenames)
        LOG.info("Running: " + " ".join(argv) + "\n")
        proc = subprocess.Popen(
            argv, stdin=subprocess.PIPE, stdout=subprocess.PIPE
        )
        if dwk2():
            # Feed the sequences through stdin, counting the headers that
            # have a taxid mapping along the way.
            for filename in fasta_filenames:
                with open(filename, "rb") as in_file:
                    for line in in_file:
                        if line.startswith(b'>'):
                            sequence_name = line.split()[0]
                            if sequence_name[1:] in mapped_sequences:
                                total_sequences += 1
                        proc.stdin.write(line)
        # communicate() flushes and closes stdin before collecting stdout.
        estimate = proc.communicate()[0].decode()
        with open("estimated_capacity", "w") as out_file:
            out_file.write(estimate + str(total_sequences) + "\n")

    required_capacity = (int(estimate.strip()) + 8192) / args.load_factor
    LOG.info(
        "Estimated hash table requirement: {:s}\n".format(
            format_bytes(required_capacity * 4)
        )
    )
    if args.max_db_size:
        args.max_db_size = parse_db_size(args.max_db_size)
        if args.max_db_size < required_capacity * 4:
            args.max_db_size = int(args.max_db_size / 4)
            LOG.info(
                "Maximum hash table size of {}, specified and is"
                " lower than the calculated estimated capacity of {}\n"
                .format(
                    format_bytes(args.max_db_size * 4),
                    format_bytes(required_capacity * 4))
            )
    if os.path.isfile("hash.k2d"):
        LOG.info("Hash table already present, skipping build\n")
        return

    LOG.info("Starting database build\n")
    build_db_bin = find_kraken2_binary("build_db")
    argv = [
        build_db_bin,
        "-H",
        "hash.k2d.tmp",
        "-t",
        "taxo.k2d.tmp",
        "-o",
        "opts.k2d.tmp",
        "-n",
        "taxonomy",
        "-m",
        "seqid2taxid.map",
        "-c",
        # The capacity is a cell count; pass it as an integer rather than
        # the float produced by the load-factor division.
        str(int(required_capacity)),
        "-S",
        construct_seed_template(args),
    ]
    if args.protein:
        argv.append("-X")
    wrapper_args_to_binary_args(
        args, argv, get_binary_options(build_db_bin)
    )

    LOG.info("Running: " + " ".join(argv) + "\n")
    # The sequence count is only known when the estimate step counted
    # mapped headers; without it no progress bar can be drawn.
    show_progress = total_sequences > 0
    # The library is always streamed into build_db through `cat`.  stdin
    # is detached so that `cat` cannot block on the terminal when the
    # file list is empty.
    cat_proc = subprocess.Popen(
        ["cat"] + fasta_filenames,
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE
    )
    m_err = s_err = thread = None
    if show_progress:
        # A pty makes build_db treat its output as a terminal; the reader
        # thread parses that output to drive the progress bar.
        m_err, s_err = pty.openpty()
        build_proc = subprocess.Popen(
            argv, stdin=cat_proc.stdout,
            stdout=s_err,
            stderr=s_err,
        )
        thread = threading.Thread(
            target=read_from_stderr,
            args=(m_err, total_sequences)
        )
        thread.start()
    else:
        build_proc = subprocess.Popen(argv, stdin=cat_proc.stdout)
    # The parent must drop its copy of the pipe so that `cat` receives
    # SIGPIPE if build_db exits early.
    cat_proc.stdout.close()

    build_status = build_proc.wait()
    cat_status = cat_proc.wait()
    if show_progress:
        os.close(s_err)
        thread.join()

    if build_status != 0:
        LOG.error(
            "Encountered error while building database: "
            "build process died unexpectedly\n"
        )
    elif cat_status != 0:
        LOG.error(
            "Encountered error while building database: "
            "unable to read the library sequence files\n"
        )
    if build_status != 0 or cat_status != 0:
        if show_progress:
            # read_from_stderr normally closes the descriptor itself;
            # this only matters if the reader thread died early.
            try:
                os.close(m_err)
            except OSError:
                pass
        return

    move("hash.k2d.tmp", "hash.k2d")
    move("taxo.k2d.tmp", "taxo.k2d")
    move("opts.k2d.tmp", "opts.k2d")
    LOG.info("Finished building database\n")


def decompress_with_zlib(filename):
    """Inflate ``filename`` chunk by chunk and discard the result.

    The decompressor is created with ``wbits=15 + 32``, which lets zlib
    detect a zlib or gzip header automatically.  Nothing is returned;
    corrupt input surfaces as ``zlib.error``.
    """
    inflator = zlib.decompressobj(15 + 32)
    with open(filename, "rb") as infile:
        read_chunk = functools.partial(infile.read, 8196)
        # iter() with a sentinel stops at the first empty read (EOF).
        for chunk in iter(read_chunk, b""):
            inflator.decompress(chunk)


def read_from_stderr(fd, total_sequences):
    """Turn the build process output read from ``fd`` into a progress bar.

    Lines starting with ``Processed`` update the progress bar with their
    second field; other newline-terminated lines are forwarded to the
    debug log.  Reading stops at end of input, on a read error, or once
    the bar reaches ``total_sequences``.  ``fd`` is always closed before
    returning.
    """
    pb = ProgressBar(total_sequences)
    buffer = b""
    processing = True
    try:
        while processing:
            try:
                data = os.read(fd, 1024)
            except OSError:
                # On Linux, reading a pty master after the slave side has
                # been closed fails with EIO instead of returning b"".
                data = b""
            at_eof = len(data) == 0
            if at_eof:
                processing = False
            data = buffer + data
            buffer = b""
            for line in data.splitlines(True):
                fields = line.split()
                # A chunk boundary may fall inside the counter; wait for
                # the line terminator before trusting the number.
                complete = at_eof or line.endswith((b"\r", b"\n"))
                if (
                    line.startswith(b"Processed")
                    and len(fields) > 1
                    and complete
                ):
                    buffer = b""
                    progress = int(fields[1])
                    pb.progress(progress)
                    eol = '\r'
                    if pb.current == total_sequences:
                        processing = False
                        eol = '\n'
                    LOG.debug(
                        "Processed:" +
                        pb.get_bar() + " {}/{}{}"
                        .format(progress, total_sequences, eol)
                    )
                elif line.endswith(b"\n"):
                    LOG.debug(line.decode())
                else:
                    # Incomplete line: keep it for the next read.
                    buffer = line
    finally:
        os.close(fd)


# Parses RDP sequence data to create Kraken taxonomy
# and sequence ID -> taxonomy ID mapping
def build_rdp_taxonomy(f):
    """Create a Kraken taxonomy and sequence map from RDP FASTA headers.

    ``f`` is an iterable of text lines.  Only header lines (starting with
    ">") are used; each must hold the sequence label and a
    ``Lineage=Root;rootrank;...`` string separated by a tab.  The lineage
    is a flat ``name;rank;name;rank`` list.

    Writes ``names.dmp``, ``nodes.dmp`` and ``seqid2taxid.map`` to the
    current working directory.  Node IDs are assigned breadth first from
    the root (ID 1), visiting children in sorted order.
    """
    seqid_map = {}
    seen_it = {}
    child_data = {"root;no rank": {}}

    for line in f:
        if not line.startswith(">"):
            continue
        line = line.strip()
        seq_label, taxonomy_string = line.split("\t")
        # Drop the FASTA ">" marker: it is not part of the sequence ID.
        seqid = seq_label.split(" ")[0][1:]
        taxonomy_string = re.sub(
            "^Lineage=Root;rootrank;", "root;no rank;", taxonomy_string
        )
        taxonomy_string = re.sub(";$", ";no rank", taxonomy_string)
        seqid_map[seqid] = taxonomy_string
        seen_it.setdefault(taxonomy_string, 0)
        seen_it[taxonomy_string] += 1
        if seen_it[taxonomy_string] > 1:
            continue
        # Walk up the lineage one "name;rank" pair at a time, registering
        # each node under its parent until a known ancestor is reached.
        while True:
            match = re.search("(;[^;]+;[^;]+)$", taxonomy_string)
            if match is None:
                break
            level = match.group(1)
            taxonomy_string = re.sub(";[^;]+;[^;]+$", "", taxonomy_string)
            key = taxonomy_string + level
            child_data.setdefault(taxonomy_string, {})
            seen_it.setdefault(taxonomy_string, 0)
            child_data[taxonomy_string].setdefault(key, 0)
            child_data[taxonomy_string][key] += 1
            seen_it[taxonomy_string] += 1
            if seen_it[taxonomy_string] > 1:
                break
    id_map = {}
    next_node_id = 1
    with open("names.dmp", "w") as names_file:
        with open("nodes.dmp", "w") as nodes_file:
            # deque gives O(1) operations at both ends of the queue.
            bfs_queue = collections.deque([("root;no rank", 1)])
            while bfs_queue:
                node, parent_id = bfs_queue.popleft()
                match = re.search("([^;]+);([^;]+)$", node)
                if match is None:
                    LOG.error(
                        'BFS processing encountered formatting eror, "{:s}"\n'
                        .format(node)
                    )
                    sys.exit(1)
                display_name, rank = match.group(1), match.group(2)
                if rank == "domain":
                    rank = "superkingdom"
                node_id, next_node_id = next_node_id, next_node_id + 1
                id_map[node] = node_id
                names_file.write(
                    "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format(
                        node_id, display_name
                    )
                )
                nodes_file.write(
                    "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format(
                        node_id, parent_id, rank
                    )
                )
                for child in sorted(child_data.get(node, ())):
                    bfs_queue.append((child, node_id))
    with open("seqid2taxid.map", "w") as map_file:
        for seqid in sorted(seqid_map):
            taxid = id_map[seqid_map[seqid]]
            map_file.write("{:s}\t{:d}\n".format(seqid, taxid))


# Build the standard Kraken database
def build_standard_database(args):
    """Download the standard RefSeq libraries, then build the database.

    ``args`` is updated in place: the assembly source, assembly levels
    and resume flag are fixed, and ``args.library`` is set to each
    library in turn before it is downloaded.  ``UniVec_Core`` is left out
    when ``args.protein`` is set.
    """
    # download_taxonomy(args)
    args.assembly_source = "refseq"
    args.assembly_levels = ["chromosome", "complete_genome"]
    args.resume = True

    standard_libraries = (
        "archaea",
        "bacteria",
        "viral",
        "plasmid",
        "human",
        "UniVec_Core",
    )
    for name in standard_libraries:
        skip = name == "UniVec_Core" and args.protein
        if not skip:
            args.library = name
            download_genomic_library(args)

    build_kraken2_db(args)


# Parses Silva taxonomy file to create Kraken taxonomy
def build_silva_taxonomy(in_file):
    """Create names.dmp and nodes.dmp from a SILVA taxonomy listing.

    ``in_file`` is an iterable of text lines whose first three
    tab-separated fields are the taxonomy path (``A;B;C;``), the node ID
    and the rank.  A parent must be listed before its children.  Both
    output files are written to the current working directory.  Blank
    lines are ignored.

    A path whose parent has not been seen, or a path that does not end in
    ``name;``, is logged and terminates the program with exit status 1.
    """
    id_map = {"root": 1}
    with open("names.dmp", "w") as names_file:
        with open("nodes.dmp", "w") as nodes_file:
            names_file.write("1\t|\troot\t|\t-\t|\tscientific name\t|\n")
            nodes_file.write("1\t|\t1\t|\tno rank\t|\t-\t|\n")
            for line in in_file:
                line = line.strip()
                if not line:
                    continue
                taxonomy_string, node_id, rank = line.split("\t")[:3]
                id_map[taxonomy_string] = node_id
                match = re.search("^(.+;|)([^;]+);$", taxonomy_string)
                if not match:
                    LOG.error('strange input: "{:s}"\n'.format(line))
                    sys.exit(1)
                parent_name = match.group(1)
                display_name = match.group(2)
                if parent_name == "":
                    parent_name = "root"
                # .get() so that an unknown parent reaches the orphan
                # report below instead of raising KeyError.
                parent_id = id_map.get(parent_name)
                if not parent_id:
                    LOG.error('orphan error: "{:s}"\n'.format(line))
                    sys.exit(1)
                if rank == "domain":
                    rank = "superkingdom"
                names_file.write(
                    "{:s}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format(
                        node_id, display_name
                    )
                )
                nodes_file.write(
                    "{:s}\t|\t{:s}\t|\t{:s}\t|\t-\t|\n".format(
                        node_id, str(parent_id), rank
                    )
                )


# Build a 16S database from Silva data
def build_16S_silva(args):
    """Download SILVA 138.2 SSU data and build a 16S database from it.

    Creates ``data``, ``taxonomy`` and ``library`` under ``args.db``
    (made absolute), fetches the FASTA and taxonomy exports from the
    SILVA FTP server, converts the taxonomy, rewrites the sequences with
    U replaced by T into ``library/silva/library.fna``, optionally masks
    them and finally runs the regular database build.
    """
    args.db = os.path.abspath(args.db)
    os.makedirs(args.db, exist_ok=True)
    os.chdir(args.db)
    for subdir in ("data", "taxonomy", "library"):
        os.makedirs(subdir, exist_ok=True)
    os.chdir("data")

    exports_dir = "/release_138_2/Exports"
    taxonomy_dir = exports_dir + "/taxonomy"
    fasta_gz = "SILVA_138.2_SSURef_NR99_tax_silva.fasta.gz"
    prefix = "tax_slv_ssu_138.2"
    acc_taxid_gz = prefix + ".acc_taxid.gz"
    taxonomy_gz = prefix + ".txt.gz"

    ftp = FTP(SILVA_SERVER)
    ftp.download(exports_dir, fasta_gz)
    ftp.download(taxonomy_dir, acc_taxid_gz)
    decompress_files([acc_taxid_gz])
    ftp.download(taxonomy_dir, taxonomy_gz)
    with gzip.open(taxonomy_gz, "rt") as taxonomy_file:
        build_silva_taxonomy(taxonomy_file)

    # Back in the database root: put the generated files in place.
    os.chdir(os.path.pardir)
    for dump in ("names.dmp", "nodes.dmp"):
        move(os.path.join("data", dump), "taxonomy")
    move(os.path.join("data", prefix + ".acc_taxid"), "seqid2taxid.map")

    # The archive is opened before the directory changes, so its relative
    # path is resolved from the database root.
    with gzip.open(os.path.join("data", fasta_gz), "rt") as rna_fasta:
        os.chdir("library")
        os.makedirs("silva", exist_ok=True)
        os.chdir("silva")
        with open("library.fna", "w") as dna_fasta:
            for line in rna_fasta:
                if line.startswith(">"):
                    dna_fasta.write(line)
                else:
                    dna_fasta.write(line.replace("U", "T"))

    if not args.no_masking:
        library = "library.fna"
        masked = library + ".masked"
        mask_files([library], masked, args.threads)
        shutil.move(masked, library)

    os.chdir(args.db)
    build_kraken2_db(args)


# Parses Greengenes taxonomy file to create Kraken taxonomy
# and sequence ID -> taxonomy ID mapping
# Input: gg_13_5_taxonomy.txt
def build_gg_taxonomy(in_file):
    """Create a Kraken taxonomy and sequence map from Greengenes data.

    ``in_file`` is an iterable of ``seqid<TAB>lineage`` text lines, the
    lineage being a "; "-separated list of ``x__Name`` levels in which
    trailing empty levels (``g__; s__``) are ignored.  Sequence IDs must
    be numeric because the map is sorted with ``key=int``.

    Writes ``names.dmp``, ``nodes.dmp`` and ``seqid2taxid.map`` to the
    current working directory.  Node IDs are assigned breadth first from
    the root (ID 1), visiting children in sorted order.
    """
    rank_codes = {
        "k": "superkingdom",
        "p": "phylum",
        "c": "class",
        "o": "order",
        "f": "family",
        "g": "genus",
        "s": "species",
    }
    seqid_map = {}
    seen_it = {}
    child_data = {"root": {}}
    for line in in_file:
        line = line.strip()
        seqid, taxonomy_string = line.split("\t")
        taxonomy_string = re.sub("(; [a-z]__)+$", "", taxonomy_string)
        seqid_map[seqid] = taxonomy_string
        seen_it.setdefault(taxonomy_string, 0)
        seen_it[taxonomy_string] += 1
        if seen_it[taxonomy_string] > 1:
            continue
        # Strip one level at a time, registering each node under its
        # parent until an already known ancestor is reached.
        while True:
            match = re.search("(; [a-z]__[^;]+$)", taxonomy_string)
            if not match:
                break
            level = match.group(1)
            taxonomy_string = re.sub("(; [a-z]__[^;]+$)", "", taxonomy_string)
            child_data.setdefault(taxonomy_string, {})
            key = taxonomy_string + level
            seen_it.setdefault(taxonomy_string, 0)
            child_data[taxonomy_string].setdefault(key, 0)
            child_data[taxonomy_string][key] += 1
            seen_it[taxonomy_string] += 1
            if seen_it[taxonomy_string] > 1:
                break
        # A top-level name seen for the first time hangs off the root.
        if seen_it[taxonomy_string] == 1:
            child_data["root"].setdefault(taxonomy_string, 0)
            child_data["root"][taxonomy_string] += 1
    id_map = {}
    next_node_id = 1
    with open("names.dmp", "w") as names_file:
        with open("nodes.dmp", "w") as nodes_file:
            # deque gives O(1) operations at both ends of the queue.
            bfs_queue = collections.deque([("root", 1)])
            while bfs_queue:
                node, parent_id = bfs_queue.popleft()
                display_name = node
                rank = None
                match = re.search("g__([^;]+); s__([^;]+)$", node)
                if match:
                    genus, species = match.group(1), match.group(2)
                    rank = "species"
                    if re.search(" endosymbiont ", species):
                        display_name = species
                    else:
                        display_name = genus + " " + species
                else:
                    match = re.search("([a-z])__([^;]+)$", node)
                    if match:
                        rank = rank_codes[match.group(1)]
                        display_name = match.group(2)
                rank = rank or "no rank"
                node_id, next_node_id = next_node_id, next_node_id + 1
                id_map[node] = node_id
                names_file.write(
                    "{:d}\t|\t{:s}\t|\t-\t|\tscientific name\t|\n".format(
                        node_id, display_name
                    )
                )
                nodes_file.write(
                    "{:d}\t|\t{:d}\t|\t{:s}\t|\t-\t|\n".format(
                        node_id, parent_id, rank
                    )
                )
                for child in sorted(child_data.get(node, ())):
                    bfs_queue.append((child, node_id))
    with open("seqid2taxid.map", "w") as map_file:
        for seqid in sorted(seqid_map, key=int):
            taxid = id_map[seqid_map[seqid]]
            map_file.write("{:s}\t{:d}\n".format(seqid, taxid))


# Build a 16S database from Greengenes data
def build_16S_gg(args):
    """Download Greengenes 13_5 and build a 16S database from it.

    Creates ``data``, ``taxonomy`` and ``library`` under ``args.db``
    (made absolute), fetches and decompresses the FASTA and taxonomy
    files, converts the taxonomy, places the sequences in
    ``library/greengenes/library.fna``, optionally masks them and finally
    runs the regular database build.
    """
    args.db = os.path.abspath(args.db)
    os.makedirs(args.db, exist_ok=True)
    release = "gg_13_5"
    release_dir = "/greengenes_release/" + release
    fasta_gz = release + ".fasta.gz"
    taxonomy_gz = release + "_taxonomy.txt.gz"

    os.chdir(args.db)
    for subdir in ("data", "taxonomy", "library"):
        os.makedirs(subdir, exist_ok=True)
    os.chdir("data")

    ftp = FTP(GREENGENES_SERVER)
    ftp.download(release_dir, fasta_gz)
    decompress_files([fasta_gz])
    ftp.download(release_dir, taxonomy_gz)
    decompress_files([taxonomy_gz])
    with open(release + "_taxonomy.txt", "r") as taxonomy_file:
        build_gg_taxonomy(taxonomy_file)

    # Back in the database root: put the generated files in place.
    os.chdir(os.path.abspath(os.path.pardir))
    for dump in ("names.dmp", "nodes.dmp"):
        move(os.path.join("data", dump), "taxonomy")
    move(os.path.join("data", "seqid2taxid.map"), os.getcwd())
    move(
        os.path.join("data", release + ".fasta"),
        os.path.join("library", "library.fna"),
    )

    os.chdir("library")
    os.makedirs("greengenes", exist_ok=True)
    move("library.fna", "greengenes")
    os.chdir("greengenes")
    if not args.no_masking:
        library = "library.fna"
        masked = library + ".masked"
        mask_files([library], masked, args.threads)
        move(masked, library)

    os.chdir(args.db)
    build_kraken2_db(args)


# Build a 16S database from RDP data
def build_16S_rdp(args):
    """Download the RDP unaligned 16S sets and build a database from them.

    Creates ``data``, ``taxonomy`` and ``library`` under ``args.db``
    (made absolute, as the other 16S builders do), downloads and
    decompresses the Bacteria and Archaea FASTA files, derives one
    taxonomy from both, moves the sequences to ``library/rdp``,
    optionally masks them and finally runs the regular database build.
    """
    # build_kraken2_db() changes into args.db again, which only works
    # from inside the database directory when the path is absolute.
    args.db = os.path.abspath(args.db)
    os.makedirs(args.db, exist_ok=True)
    os.chdir(args.db)
    for directory in ["data", "taxonomy", "library"]:
        os.makedirs(directory, exist_ok=True)
    os.chdir("data")
    http_download_file(
        "http://rdp.cme.msu.edu/download/current_Bacteria_unaligned.fa.gz"
    )
    http_download_file(
        "http://rdp.cme.msu.edu/download/current_Archaea_unaligned.fa.gz"
    )
    decompress_files(glob.glob("*gz"))

    def _lines_of(filenames):
        # Yield the lines of every file as one continuous stream.
        for filename in filenames:
            with open(filename, "r") as fasta_file:
                for line in fasta_file:
                    yield line

    # build_rdp_taxonomy() rewrites its output files on every call, so
    # all inputs have to be handed over in a single call.
    build_rdp_taxonomy(_lines_of(sorted(glob.glob("current_*_unaligned.fa"))))
    os.chdir(os.pardir)
    move(os.path.join("data", "names.dmp"), "taxonomy")
    move(os.path.join("data", "nodes.dmp"), "taxonomy")
    move(os.path.join("data", "seqid2taxid.map"), os.getcwd())

    # build_kraken2_db() only picks up FASTA files located one directory
    # below library/, so give the RDP sequences their own subdirectory.
    library_dir = os.path.join("library", "rdp")
    os.makedirs(library_dir, exist_ok=True)
    for filename in glob.glob(os.path.join("data", "*.fa")):
        new_filename = os.path.join(
            library_dir, os.path.basename(re.sub(r"\.fa$", ".fna", filename))
        )
        shutil.move(filename, new_filename)
        if not args.no_masking:
            mask_files(
                [new_filename], new_filename + ".masked", args.threads
            )
            shutil.move(new_filename + ".masked", new_filename)

    build_kraken2_db(args)


# Reads multi-FASTA input and examines each sequence header.  In quiet
# mode headers are OK if a taxonomy ID is found (as either the entire
# sequence ID or as part of a "kraken:taxid" token), or if something
# looking like a GI or accession number is found.  In normal mode, the
# taxonomy ID will be looked up (if not explicitly specified in the
# sequence ID) and reported if it can be found.  Output is
# tab-delimited lines, with sequence IDs in first column and taxonomy
# IDs in second.


# Sequence IDs with a kraken:taxid token will use that to assign taxonomy
# ID, e.g.:
# >gi|32499|ref|NC_021949.2|kraken:taxid|562|
#
# Sequence IDs that are completely numeric are assumed to be the taxonomy
# ID for that sequence.
#
# Otherwise, an accession number is searched for; if not found, a GI
# number is searched for.  Failure to find any of the above is a fatal error.
# Without `quiet`, a comma-separated file list specified by -A (for both accession
# numbers and GI numbers) is examined; failure to find a
# taxonomy ID that maps to a provided accession/GI number is non-fatal and
# will emit a warning.
#
# With -q, does not print any output, and will die w/ nonzero exit instead
# of warning when unable to find a taxid, accession #, or GI #.
#
def make_seqid_to_taxid_map(
    in_file, quiet, accession_map_filenames=False, library_map_filename=None
):
    """Print ``seqid<TAB>taxid`` lines for the FASTA headers in ``in_file``.

    Each sequence ID is classified by the first rule that applies:

    1. it contains a ``kraken:taxid|N`` token: N is the taxid;
    2. it is entirely numeric: the ID itself is the taxid;
    3. it contains an accession number;
    4. it contains a ``gi|N`` token.

    IDs matching rule 3 or 4 are resolved through ``library_map_filename``
    (``seqid<TAB>taxid`` lines) and then through the accession2taxid
    files in ``accession_map_filenames``; IDs that cannot be resolved are
    silently left out.  An ID matching none of the rules is logged and
    terminates the program with exit status 1.

    The function calls ``sys.exit(0)`` when there is nothing left to look
    up (always in quiet mode) and exits with status 1 when a lookup is
    needed but no map was supplied.  It returns None only after scanning
    the accession maps.
    """
    # search() rather than match(): the "(?:^|\|)" prefix is meant to
    # find a token at the start or after any "|" separator.
    id_patterns = [
        re.compile(r"(?:^|\|)kraken:taxid\|(\d+)"),
        re.compile(r"^\d+$"),
        re.compile(r"(?:^|\|)([A-Z]+_?[A-Z0-9]+)(?:\||\b|\.)"),
        re.compile(r"(?:^|\|)gi\|(\d+)"),
    ]
    target_lists = {}
    for line in in_file:
        header = re.match(r">(\S+)", line)
        if header is None:
            continue
        seqid = header.group(1)
        output = None
        match = None
        index = None
        for i, pattern in enumerate(id_patterns):
            match = pattern.search(seqid)
            if match:
                index = i
                break
        if index == 0:
            output = seqid + "\t" + match.group(1) + "\n"
        elif index == 1:
            output = seqid + "\t" + seqid + "\n"
        elif index in [2, 3]:
            if not quiet:
                capture = match.group(1)
                target_lists.setdefault(capture, [])
                target_lists[capture].insert(0, seqid)
        else:
            LOG.error(
                "Unable to determine taxonomy ID for sequence {:s}\n".format(
                    seqid
                )
            )
            sys.exit(1)
        if output and not quiet:
            # output already ends with a newline
            print(output, end="")
    if quiet:
        # NOTE(review): nothing is collected in quiet mode, so this
        # message is always logged -- confirm the intended condition.
        if len(target_lists) == 0:
            LOG.error("External map required\n")
        sys.exit(0)
    if len(target_lists) == 0:
        sys.exit(0)
    if not accession_map_filenames and library_map_filename is None:
        LOG.error(
            "Found sequence ID without explicit taxonomy ID, but no map used\n"
        )
        sys.exit(1)
    # Remove targets where we've already handled the mapping
    if library_map_filename:
        with open(library_map_filename, "r") as f:
            for line in f:
                line = line.strip()
                seqid, taxid = line.split("\t")
                if seqid in target_lists:
                    print("{:s}\t{:s}".format(seqid, taxid))
                    del target_lists[seqid]
    if len(target_lists) == 0:
        sys.exit(0)
    # The default value (False) means "no accession maps".
    for filename in accession_map_filenames or []:
        with open(filename, "r") as f:
            f.readline()  # skip the column header
            for line in f:
                line = line.strip()
                accession, with_version, taxid, gi = line.split("\t")
                if accession in target_lists:
                    for seqid in target_lists.pop(accession):
                        print("{:s}\t{:s}".format(seqid, taxid))
                if gi != "na" and gi in target_lists:
                    for seqid in target_lists.pop(gi):
                        print("{:s}\t{:s}".format(seqid, taxid))


def wait_for_files(*fifos):
    """Block until every path in ``fifos`` exists.

    The paths are checked in order.  The check sleeps briefly between
    polls so that waiting does not keep a CPU core busy.
    """
    for fifo in fifos:
        while not os.path.exists(fifo):
            time.sleep(0.01)


def cleanup_fifos():
    """Delete the classifier FIFOs and pid file left behind in /tmp.

    Removes every entry named ``classify_stdin``, ``classify_stdout``,
    ``classify_<digits>_stdin`` or ``classify_<digits>_stdout`` as well
    as ``classify.pid``.
    """
    LOG.info("Cleaning up fifos and pid files\n")
    fifo_name = re.compile(r"^classify_(?:\d+_)?(?:stdin|stdout)$")
    for entry in os.listdir("/tmp"):
        if fifo_name.match(entry) or entry == "classify.pid":
            os.remove(os.path.join("/tmp", entry))


def copy_file_obj(dst):
    """Copy standard input to ``dst`` in 8194-character reads.

    ``dst`` is closed once standard input is exhausted.
    """
    chunk = sys.stdin.read(8194)
    while chunk:
        dst.write(chunk)
        chunk = sys.stdin.read(8194)
    dst.close()


def check_daemon():
    """Report whether the process recorded in /tmp/classify.pid is alive.

    Returns False when the pid file does not exist.  Otherwise the first
    line of the file is passed to ``ps`` (output discarded) and the
    result is True exactly when ``ps`` exits with status 0.
    """
    pid_path = "/tmp/classify.pid"
    if not os.path.exists(pid_path):
        return False

    with open(pid_path, "r") as pid_file:
        pid = pid_file.readline().strip()
        status = subprocess.call(
            ["ps", pid],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

    return status == 0


def message_daemon(message):
    """Send ``message`` (bytes) to the classifier daemon, if it is running.

    The message is written to ``/tmp/classify_stdin``; after a short
    pause up to three bytes of the reply are read from
    ``/tmp/classify_stdout`` without blocking.  The reply is not
    interpreted and nothing is returned.  All descriptors opened here are
    closed again, even when an operation fails.
    """
    if not check_daemon():
        return

    opened = []
    try:
        # A non-blocking reader is opened first so that the write-only
        # open below cannot block waiting for a reader on the FIFO.
        fd_rd_1 = os.open("/tmp/classify_stdin", os.O_RDONLY | os.O_NONBLOCK)
        opened.append(fd_rd_1)
        fd_wr_1 = os.open("/tmp/classify_stdin", os.O_WRONLY)
        opened.append(fd_wr_1)
        fd_rd_2 = os.open(
            "/tmp/classify_stdout", os.O_RDONLY | os.O_NONBLOCK
        )
        opened.append(fd_rd_2)

        os.set_blocking(fd_rd_1, True)

        try:
            os.write(fd_wr_1, message)
            time.sleep(0.1)
            # Drain the acknowledgement, if one has arrived.
            os.read(fd_rd_2, 3)
        except BlockingIOError:
            # No reply available yet; nothing else to do.
            pass
    finally:
        for fd in opened:
            os.close(fd)


def classify_using_daemon(args, argv):
    """Run a classification job through the background classifier.

    If no live daemon is found, stale FIFOs are removed and ``argv`` is
    executed to start one; otherwise the space-joined ``argv`` is written
    to the daemon's ``/tmp/classify_stdin`` FIFO.  The function then
    reads ``/tmp/classify_stdout`` for the PID of the job, attaches to
    the per-job FIFOs ``/tmp/classify_<pid>_stdin`` and
    ``/tmp/classify_<pid>_stdout``, relays data, and finally waits for
    the line ``DONE`` on ``/tmp/classify_stdout``.

    ``args`` supplies ``filenames`` and is tested for an ``output``
    entry; ``argv`` is the full classifier command line.
    """
    alive = check_daemon()
    if not alive:
        cleanup_fifos()
        LOG.info("Starting backgroud classifier process\n")
        # NOTE(review): presumably the classifier detaches itself so that
        # this call returns while the daemon keeps running -- confirm.
        subprocess.call(argv)
        # The daemon announces readiness by creating its pid file and
        # its two control FIFOs.
        wait_for_files(
            "/tmp/classify.pid",
            "/tmp/classify_stdin",
            "/tmp/classify_stdout"
        )
        with open("/tmp/classify.pid", "r") as in_file:
            pid = in_file.readline().strip()
            LOG.info(
                "Started background classifier process with PID: {}\n"
                .format(pid)
            )
            LOG.info("Run k2 clean --stop-daemon to stop it.\n")
    else:
        # Opened read/write so that the open does not block on the FIFO.
        fd = os.open("/tmp/classify_stdin", os.O_RDWR)
        with os.fdopen(fd, 'w') as out_file:
            out_file.write(" ".join(argv) + "\n")
            # the daemon will return the pid of the subprocess
            # doing the work
    # Echo whatever the daemon prints until it reports "PID: <n>".
    # NOTE(review): if no PID line arrives, `pid` keeps the daemon's own
    # pid (fresh start) or is undefined (existing daemon).
    with open("/tmp/classify_stdout", "r") as in_file:
        for line in in_file:
            line = line.strip()
            if line.startswith("PID"):
                pid = line.split(":")[1].strip()
                break
            else:
                print(line)
    proc_in = "/tmp/classify_{}_stdin".format(pid)
    proc_out = "/tmp/classify_{}_stdout".format(pid)
    wait_for_files(proc_in, proc_out)
    thread = None
    # NOTE(review): this handle is only closed by copy_file_obj(), i.e.
    # when input comes from stdin; otherwise it is never closed here.
    proc_fd = open(proc_in, "w")
    if len(args.filenames) == 0:
        # No input files: forward our own stdin to the job in a thread.
        thread = threading.Thread(
            target=copy_file_obj, args=(proc_fd,)
        )
        thread.start()
    if "output" not in args:
        # No output entry in args: relay the job's output to stdout.
        with open(proc_out, "r") as in_file:
            for line in in_file:
                line = line.strip()
                print(line)
                # wait for the classification job to complete
    if thread:
        thread.join()

    # Wait for the daemon to report completion of the job.
    with open("/tmp/classify_stdout", "r") as in_file:
        try:
            for line in in_file:
                line = line.strip()
                if line == "DONE":
                    break
        except KeyboardInterrupt:
            # Forward Ctrl-C to the job, then keep waiting for DONE.
            os.kill(int(pid), signal.SIGINT)
            for line in in_file:
                line = line.strip()
                if line == "DONE":
                    break


# TODO: modify to include the scientific name?
def write_raw_sequence_to_file(in_file, out_file, header, taxid=None):
    """Copy one sequence record from ``in_file`` to ``out_file``.

    ``header`` is the record's already-read first line (bytes). A header
    starting with ``@`` is handled as FASTQ and the record is taken to be
    exactly four lines long; any other header is handled as FASTA and the
    record runs until the next line starting with ``>`` or end of input.

    When ``taxid`` is truthy the header is stripped and
    `` kraken:taxid|<taxid>`` is appended before it is written. When
    ``out_file`` is falsy nothing is written, but the record's lines are
    still consumed from ``in_file``.

    ``in_file`` must provide ``peek()`` and ``readline()``.
    """
    is_fastq = chr(header[0]) == '@'
    if taxid:
        tag = (" kraken:taxid|" + str(taxid) + "\n").encode()
        header = header.strip() + tag
    if out_file:
        out_file.write(header)

    # The header counts as the first line of the record.
    lines_seen = 1
    while not (is_fastq and lines_seen == 4):
        upcoming = in_file.peek()
        if len(upcoming) == 0:
            # end of input
            break
        if not is_fastq and chr(upcoming[0]) == '>':
            # next FASTA record starts here; leave it unread
            break
        body_line = in_file.readline()
        if out_file:
            out_file.write(body_line)
        lines_seen += 1


def process_unpaired(input_filenames, classified_out_filename,
                     unclassified_out_filename, classified_headers):
    """Split single-end reads into classified and unclassified files.

    A record's name is the first space-delimited token of its header
    line with the leading marker character removed. Records whose name
    is a key of ``classified_headers`` are written to
    ``classified_out_filename`` with the mapped taxid appended to the
    header; all other records are written to
    ``unclassified_out_filename``. A falsy output filename disables that
    output; the matching records are still consumed from the input.

    Output files are closed even if reading or writing fails part-way.
    """
    classified_out_file = None
    unclassified_out_file = None
    try:
        if classified_out_filename:
            classified_out_file = open(classified_out_filename, "wb")
        if unclassified_out_filename:
            unclassified_out_file = open(unclassified_out_filename, "wb")

        for filename in input_filenames:
            with open(filename, "rb") as in_file:
                for line in in_file:
                    read_name = line.split(b' ', 1)[0][1:].strip().decode()
                    # A None destination makes the helper skip over the
                    # record without writing it.
                    if read_name in classified_headers:
                        write_raw_sequence_to_file(
                            in_file, classified_out_file, line,
                            classified_headers[read_name]
                        )
                    else:
                        write_raw_sequence_to_file(
                            in_file, unclassified_out_file, line
                        )
    finally:
        if classified_out_file is not None:
            classified_out_file.close()
        if unclassified_out_file is not None:
            unclassified_out_file.close()


def process_paired(input_filenames, classified_out_filename,
                   unclassified_out_filename, classified_headers):
    """Split paired-end reads into classified and unclassified files.

    ``input_filenames`` lists the inputs as consecutive mate pairs
    (``[a_1, a_2, b_1, b_2, ...]``). Both files of a pair are read in
    lockstep; the name of the first mate (first space-delimited token of
    its header, leading marker removed) decides where both mates go.

    Every ``#`` in an output filename is replaced by ``_1`` / ``_2`` to
    form the per-mate output paths. A falsy output filename disables
    that output; the matching records are still consumed.

    Output files are closed even if reading or writing fails part-way.
    """
    opened = []
    classified_out_file1 = classified_out_file2 = None
    unclassified_out_file1 = unclassified_out_file2 = None
    try:
        if classified_out_filename:
            classified_out_file1 = open(
                classified_out_filename.replace('#', "_1"), "wb"
            )
            opened.append(classified_out_file1)
            classified_out_file2 = open(
                classified_out_filename.replace('#', "_2"), "wb"
            )
            opened.append(classified_out_file2)
        if unclassified_out_filename:
            unclassified_out_file1 = open(
                unclassified_out_filename.replace('#', "_1"), "wb"
            )
            opened.append(unclassified_out_file1)
            unclassified_out_file2 = open(
                unclassified_out_filename.replace('#', "_2"), "wb"
            )
            opened.append(unclassified_out_file2)

        # Even positions hold the first mates, odd positions the second.
        # (Slicing with [::1] here would pair a file with itself.)
        mate_pairs = zip(input_filenames[::2], input_filenames[1::2])
        for filename1, filename2 in mate_pairs:
            with open(filename1, "rb") as in_file1, \
                    open(filename2, "rb") as in_file2:
                for line1, line2 in zip(in_file1, in_file2):
                    read_name = line1.split(b' ', 1)[0][1:].strip().decode()
                    if read_name in classified_headers:
                        taxid = classified_headers[read_name]
                        out1 = classified_out_file1
                        out2 = classified_out_file2
                    else:
                        taxid = None
                        out1 = unclassified_out_file1
                        out2 = unclassified_out_file2
                    # A None destination makes the helper skip over the
                    # record without writing it.
                    write_raw_sequence_to_file(in_file1, out1, line1, taxid)
                    write_raw_sequence_to_file(in_file2, out2, line2, taxid)
    finally:
        for out_file in opened:
            out_file.close()


def process_interleaved(
        input_filenames, classified_out_filename,
        unclassified_out_filename, classified_headers):
    """Split interleaved paired reads into classified/unclassified files.

    Each input file holds mates as consecutive records: the first mate's
    full record followed by the second mate's full record. The name of
    the first mate (first space-delimited token of its header, leading
    marker removed) decides where both records go.

    A falsy output filename disables that output; the matching records
    are still consumed. Output files are closed even if reading or
    writing fails part-way.
    """
    classified_out_file = None
    unclassified_out_file = None
    try:
        if classified_out_filename:
            classified_out_file = open(classified_out_filename, "wb")
        if unclassified_out_filename:
            unclassified_out_file = open(unclassified_out_filename, "wb")

        for filename in input_filenames:
            with open(filename, "rb") as in_file:
                for line in in_file:
                    read_name = line.split(b' ', 1)[0][1:].strip().decode()
                    if read_name in classified_headers:
                        taxid = classified_headers[read_name]
                        destination = classified_out_file
                    else:
                        taxid = None
                        destination = unclassified_out_file
                    # Consume the whole first record before looking for
                    # the mate: the line right after the first header is
                    # that record's sequence, not the mate's header.
                    write_raw_sequence_to_file(
                        in_file, destination, line, taxid
                    )
                    mate_header = in_file.readline()
                    if mate_header:
                        write_raw_sequence_to_file(
                            in_file, destination, mate_header, taxid
                        )
    finally:
        if classified_out_file is not None:
            classified_out_file.close()
        if unclassified_out_file is not None:
            unclassified_out_file.close()


def write_fasta_sequences(
        args, input_filenames, classified_out_filename,
        unclassified_out_filename, classified_headers):
    """Write the (un)classified reads using the layout selected in args.

    ``"paired" in args`` selects the paired-file writer, otherwise
    ``"interleaved" in args`` selects the interleaved writer, otherwise
    the unpaired writer is used. The remaining arguments are forwarded
    unchanged.
    """
    if "paired" in args:
        writer = process_paired
    elif "interleaved" in args:
        writer = process_interleaved
    else:
        writer = process_unpaired

    writer(
        input_filenames, classified_out_filename,
        unclassified_out_filename, classified_headers
    )


class TaxonomyStruct(ctypes.Structure):
    """ctypes structure with no declared fields.

    Used as the pointee type of the handle returned by the taxonomy
    library's ``init_taxonomy``.
    """


class Taxonomy:
    """ctypes wrapper around the taxonomy shared library.

    The library handle is kept in ``self.dll``. ``load_taxonomy`` stores
    the handle returned by ``init_taxonomy`` in ``self.taxonomy``; every
    query method passes that handle as the first argument.
    """

    def __init__(self, dll_pathname):
        """Load the shared library and declare the known return types."""
        self.dll = ctypes.CDLL(dll_pathname)
        # Entry points not listed here keep ctypes' default return type.
        # NOTE(review): get_parent_id has no restype declared, so its
        # result is read as a C int; confirm that matches the library.
        declared_restypes = (
            ("init_taxonomy", ctypes.POINTER(TaxonomyStruct)),
            ("get_lca", ctypes.c_uint64),
            ("get_internal_taxid", ctypes.c_uint64),
            ("is_ancestor_of", ctypes.c_bool),
            ("get_rank", ctypes.c_char_p),
            ("get_child_count", ctypes.c_uint64),
            ("taxid_to_name", ctypes.c_char_p),
        )
        for symbol, restype in declared_restypes:
            getattr(self.dll, symbol).restype = restype

    def generate_taxonomy(self, names, nodes, seqid2taxid, taxonomy_pathname):
        """Call the library's generate_taxonomy with four encoded paths."""
        c_paths = [
            ctypes.c_char_p(pathname.encode())
            for pathname in (names, nodes, seqid2taxid, taxonomy_pathname)
        ]
        self.dll.generate_taxonomy(*c_paths)

    def load_taxonomy(self, taxonomy_pathname):
        """Initialise ``self.taxonomy`` from the given taxonomy file."""
        self.taxonomy = self.dll.init_taxonomy(
            ctypes.c_char_p(taxonomy_pathname.encode())
        )

    def get_lca(self, taxid1, taxid2):
        """Return the library's LCA of two taxids (coerced with int())."""
        return self.dll.get_lca(self.taxonomy, int(taxid1), int(taxid2))

    def get_internal_taxid(self, taxid):
        """Return the library's internal taxid for ``taxid``."""
        return self.dll.get_internal_taxid(self.taxonomy, taxid)

    def is_ancestor_of(self, parent, child):
        """Return the library's is_ancestor_of(parent, child) result."""
        return self.dll.is_ancestor_of(self.taxonomy, parent, child)

    def destroy_taxonomy(self):
        """Pass the loaded handle to the library's destroy_taxonomy."""
        self.dll.destroy_taxonomy(self.taxonomy)

    def get_rank(self, taxid):
        """Return the rank of ``taxid`` as bytes."""
        return self.dll.get_rank(self.taxonomy, taxid)

    def taxid_to_name(self, taxid):
        """Return the name of ``taxid`` as bytes."""
        return self.dll.taxid_to_name(self.taxonomy, taxid)

    def get_child_count(self, taxid):
        """Return the number of children the library reports for taxid."""
        return self.dll.get_child_count(self.taxonomy, taxid)

    def get_parent_id(self, taxid):
        """Return the parent id the library reports for ``taxid``."""
        return self.dll.get_parent_id(self.taxonomy, taxid)

    def get_child_taxids(self, taxid):
        """Return a ctypes uint64 array filled in by get_child_taxids.

        The array is sized from ``get_child_count`` and zero-initialised
        before being handed to the library by reference.
        """
        num_children = self.get_child_count(taxid)
        child_taxids = (ctypes.c_uint64 * num_children)()
        self.dll.get_child_taxids(
            self.taxonomy, taxid, ctypes.byref(child_taxids), num_children
        )
        return child_taxids


class ReadCounts:
    """Pair of tallies (reads and k-mers) kept for one taxon."""

    def __init__(self):
        # Both tallies start at zero.
        self.n_reads = 0
        self.n_kmers = 0

    def get_kmer_count(self):
        """Return the k-mer tally."""
        return self.n_kmers

    def get_read_count(self):
        """Return the read tally."""
        return self.n_reads

    def increment_read_count(self):
        """Add one to the read tally."""
        self.n_reads = self.n_reads + 1

    def __iadd__(self, other):
        """Add ``other``'s tallies into this object and return it."""
        self.n_reads = self.n_reads + other.n_reads
        self.n_kmers = self.n_kmers + other.n_kmers
        return self


class TaxonCounters:
    """Mapping from taxid to ReadCounts, created on first access.

    Backed by a ``defaultdict``: reading a taxid that has not been seen
    inserts a fresh ``ReadCounts`` for it.
    """

    def __init__(self):
        self.counter = collections.defaultdict(ReadCounts)

    def __setitem__(self, taxid, c):
        self.counter[taxid] = c

    def __getitem__(self, taxid):
        return self.counter[taxid]

    def keys(self):
        """Return the view of stored taxids."""
        return self.counter.keys()

    def items(self):
        """Return the view of (taxid, counts) pairs."""
        return self.counter.items()


def get_clade_counters(taxonomy, call_counters):
    """Accumulate per-taxon counts up the tree into clade counts.

    For every (taxid, counts) pair in ``call_counters`` the counts are
    added to that taxid and to each taxid reached by repeatedly calling
    ``taxonomy.get_parent_id`` until 0 is returned. Entries whose taxid
    is 0 contribute nothing.
    """
    clade_totals = TaxonCounters()
    for start_taxid, own_counts in call_counters.items():
        node = start_taxid
        while node != 0:
            clade_totals[node] += own_counts
            node = taxonomy.get_parent_id(node)

    return clade_totals


def print_kraken_style_report(out_file, report_kmer_data,
                              total_seqs, clade_counter, taxon_counter,
                              rank_string, taxid, scientific_name, depth):
    """Write one tab-separated report line to ``out_file``.

    Columns: percentage of ``total_seqs`` covered by the clade (width 6,
    two decimals), clade read count, taxon read count, rank string,
    taxid, and the scientific name indented by ``depth`` spaces.

    ``report_kmer_data`` is accepted but not used. When ``total_seqs``
    is 0 the percentage is reported as 0.00 instead of raising
    ZeroDivisionError.
    """
    clade_reads = clade_counter.get_read_count()
    if total_seqs:
        percentage = 100.0 * clade_reads / total_seqs
    else:
        # Nothing was classified at all; avoid dividing by zero.
        percentage = 0.0
    columns = [
        "{:6.2f}".format(percentage),
        str(clade_reads),
        str(taxon_counter.get_read_count()),
        rank_string,
        str(taxid),
        " " * depth + scientific_name,
    ]
    out_file.write("\t".join(columns) + "\n")


def kraken_report_dfs(out_file, taxid, report_zeros, report_kmer_data,
                      taxonomy, clade_counters, call_counters, total_seqs,
                      rank_code, rank_depth, depth):
    """Write the report line for ``taxid`` and recurse into its children.

    A taxon whose clade read count is 0 is skipped, together with its
    whole subtree, unless ``report_zeros`` is true.

    ``rank_code`` / ``rank_depth`` carry the code of the closest named
    rank seen on the path so far and the distance from it: a taxon at
    one of the eight named ranks resets the code and sets the distance
    to 0; any other rank increments the distance. The distance is
    appended to the code only when it is positive.

    ``report_kmer_data`` is only forwarded to the recursive calls; the
    line writer is always called with False.
    """
    clade_counter = clade_counters[taxid]
    call_counter = call_counters[taxid]
    if not report_zeros and clade_counter.get_read_count() == 0:
        return

    named_rank_codes = {
        "superkingdom": 'D',
        "kingdom": 'K',
        "phylum": 'P',
        "class": 'C',
        "order": 'O',
        "family": 'F',
        "genus": 'G',
        "species": 'S',
    }
    rank_name = taxonomy.get_rank(taxid).decode()
    if rank_name in named_rank_codes:
        rank_code, rank_depth = named_rank_codes[rank_name], 0
    else:
        rank_depth += 1

    if rank_depth > 0:
        rank_string = rank_code + str(rank_depth)
    else:
        rank_string = rank_code

    print_kraken_style_report(
        out_file, False, total_seqs, clade_counter,
        call_counter, rank_string, taxid,
        taxonomy.taxid_to_name(taxid).decode(), depth
    )

    def clade_reads(child):
        return clade_counters[child].get_read_count()

    # NOTE(review): children are visited in ascending order of clade
    # read count; confirm this is the intended ordering for the report.
    for child_taxid in sorted(
            taxonomy.get_child_taxids(taxid), key=clade_reads):
        kraken_report_dfs(
            out_file, child_taxid, report_zeros,
            report_kmer_data, taxonomy, clade_counters,
            call_counters, total_seqs, rank_code, rank_depth,
            depth + 1
        )


def report_kraken_style(filename, report_zeros, report_kmer_data,
                        taxonomy, call_counters, total_seqs):
    """Write a Kraken-style report for ``call_counters`` to ``filename``.

    Clade counts are derived from ``call_counters``. If any reads were
    assigned to taxid 0, an "unclassified" line (rank code "U") is
    written first; the tree is then reported depth-first from taxid 1
    with rank code "R".

    ``report_kmer_data`` is accepted but not used; False is always
    passed down.
    """
    clade_counters = get_clade_counters(taxonomy, call_counters)
    total_unclassified = call_counters[0].get_read_count()
    rank_code = "R"

    with open(filename, "w") as out_file:
        if total_unclassified > 0:
            print_kraken_style_report(
                out_file, False, total_seqs, call_counters[0],
                call_counters[0], "U", 0, "unclassified", 0
            )
        # Start one below zero: the traversal increments the distance for
        # every taxon that is not at a named rank, the root included, so
        # this makes the root itself come out as "R" and its first
        # unranked descendant as "R1".
        kraken_report_dfs(
            out_file, 1, report_zeros, False, taxonomy, clade_counters,
            call_counters, total_seqs, rank_code, -1, 0
        )


# Taken from ResolveTree function in classify.cc
def resolve_taxa_tree(hit_counts, taxonomy, total_kmers, args):
    """Pick the taxid to call from per-taxid hit counts.

    Phase 1: every taxid in ``hit_counts`` is scored with the sum of the
    counts of all taxids for which ``taxonomy.is_ancestor_of(other,
    taxid)`` holds. The highest score wins; a candidate that ties the
    current best replaces it with the LCA of the two.

    Phase 2: the required score is ``ceil(args.confidence *
    total_kmers)``. Starting from the hits recorded at the winner
    itself, while the score is too low the score is recomputed as the
    sum of counts inside the winner's clade; if that is still too low
    the winner is replaced by its parent. Returns 0 if the walk runs
    off the top of the tree.
    """
    best_taxid, best_score = 0, 0
    threshold = math.ceil(args.confidence * total_kmers)

    for candidate in hit_counts.keys():
        path_score = sum(
            hits for other, hits in hit_counts.items()
            if taxonomy.is_ancestor_of(other, candidate)
        )
        if path_score > best_score:
            best_score, best_taxid = path_score, candidate
        elif path_score == best_score:
            best_taxid = taxonomy.get_lca(best_taxid, candidate)

    # From here on the score only counts hits at the called taxon.
    support = hit_counts[best_taxid]
    while best_taxid != 0 and support < threshold:
        support = sum(
            hits for member, hits in hit_counts.items()
            if taxonomy.is_ancestor_of(best_taxid, member)
        )
        if support >= threshold:
            return best_taxid
        best_taxid = taxonomy.get_parent_id(best_taxid)

    return best_taxid


def parse_taxid_counts(string, counts):
    """Parse a ``taxid:count ...`` string into a flat unsigned-int array.

    ``counts`` is an ``array.array('I')`` to reuse, or None to allocate
    one. The array is grown if it is shorter than needed but never
    shrunk, so only the first ``counts_len`` slots are meaningful.

    Entries are stored as alternating taxid, count values. A taxid of
    ``A`` is stored as AMBIGUOUS_TAXID; the ``|:|`` entry is stored as
    the pair (0, 0).

    Returns ``(counts, counts_len)`` where ``counts_len`` is twice the
    number of ``:`` characters in ``string``.
    """
    counts_len = 2 * string.count(':')
    if counts is None:
        counts = array.array('I', range(counts_len))
    shortfall = counts_len - len(counts)
    if shortfall > 0:
        counts.extend(range(shortfall))

    position = 0
    for token in string.split():
        raw_taxid, raw_count = token.split(':')
        if raw_taxid == 'A':
            pair = (AMBIGUOUS_TAXID, int(raw_count))
        elif raw_taxid == '|':
            # separator for paired reads, set the taxid and count
            # to an invalid value
            pair = (0, 0)
        else:
            pair = (int(raw_taxid), int(raw_count))
        counts[position], counts[position + 1] = pair
        position += 2

    return counts, counts_len


def get_lca(taxonomy, cache, taxid1, taxid2):
    """Return ``taxonomy.get_lca`` for the two taxids, memoised in cache.

    The pair is ordered smaller-first before it is used as the cache key
    and before it is passed to the taxonomy, so (a, b) and (b, a) share
    one entry.
    """
    if taxid1 < taxid2:
        key = (taxid1, taxid2)
    else:
        key = (taxid2, taxid1)
    if key in cache:
        return cache[key]
    lca = taxonomy.get_lca(key[0], key[1])
    cache[key] = lca
    return lca


def merge_counts(
        taxonomy, lca_cache, counts1, counts1_len, counts2, counts2_len):
    """Merge two flat taxid/count sequences into one.

    ``counts1`` and ``counts2`` hold alternating taxid, count values
    (slot ``i`` is a taxid, slot ``i + 1`` its count). The two sequences
    are walked in parallel; each step emits the smaller of the two
    current counts under a combined taxid and subtracts it from the
    larger one, so both inputs are MODIFIED IN PLACE.

    The combined taxid is AMBIGUOUS_TAXID if either side is ambiguous,
    otherwise the cached LCA of the two taxids.

    The walk stops once ``index1`` reaches ``counts1_len``;
    ``counts2_len`` is not consulted.
    NOTE(review): this presumably relies on both sequences summing to
    the same total -- confirm against the callers.

    Returns a tuple ``(counts_str, counts_map, total_minimizers)``:
    the merged entries rendered as ``taxid:count`` text (adjacent equal
    taxids collapsed, ambiguous taxid rendered as ``A``), a defaultdict
    of summed counts per combined taxid, and the sum of all emitted
    counts.
    """
    index1 = 0
    index2 = 0
    total_minimizers = 0
    # Placeholder only; rebound to an io.StringIO further down.
    counts_str = ""
    counts_map = collections.defaultdict(int)
    final_counts = []
    while True:
        taxid1, count1 = counts1[index1], counts1[index1 + 1]
        taxid2, count2 = counts2[index2], counts2[index2 + 1]
        final_count = 0
        # final_taxid = taxonomy.get_lca(taxid1, taxid2)
        if taxid1 == AMBIGUOUS_TAXID or taxid2 == AMBIGUOUS_TAXID:
            final_taxid = AMBIGUOUS_TAXID
        else:
            final_taxid = get_lca(taxonomy, lca_cache, taxid1, taxid2)
        if count1 < count2:
            # Side 1 is used up first: emit its count and leave the
            # remainder on side 2.
            counts2[index2 + 1] -= count1
            index1 += 2
            final_count = count1
            if counts2[index2 + 1] == 0:
                index2 += 2
        elif count2 < count1:
            # Mirror image: side 2 is used up first.
            counts1[index1 + 1] -= count2
            index2 += 2
            final_count = count2
            if counts1[index1 + 1] == 0:
                index1 += 2
        else:
            # Equal counts: both entries are consumed together.
            index1 += 2
            index2 += 2
            final_count = count1

        total_minimizers += final_count
        counts_map[final_taxid] += final_count
        final_counts.append((final_taxid, final_count))

        if index1 >= counts1_len:
            break

    def taxid_to_string(taxid):
        # Render the ambiguous marker the same way it appears in the input.
        if taxid == AMBIGUOUS_TAXID:
            return "A"
        else:
            return str(taxid)

    # Render final_counts as text, collapsing runs of the same taxid.
    counts_str = io.StringIO()
    previous_taxid, previous_count = final_counts[0]
    if len(final_counts) > 1:
        for taxid, count in final_counts[1:]:
            if taxid == 0 and count == 0:
                # A (0, 0) entry is the paired-read separator: flush the
                # pending run and write "|:|".
                counts_str.write(taxid_to_string(previous_taxid))
                counts_str.write(":")
                counts_str.write(str(previous_count))
                counts_str.write(" ")
                counts_str.write("|")
                counts_str.write(":")
                counts_str.write("|")
                counts_str.write(" ")
                # NOTE(review): the pending run becomes (0, 0) here, so a
                # "0:0" entry is written when the next differing taxid (or
                # the end of the list) flushes it -- confirm intended.
                previous_taxid = taxid
                previous_count = count
            if previous_taxid == taxid:
                previous_count += count
            else:
                # counts_str += str(previous_taxid) + ":" + str(previous_count)
                # counts_str += " "
                counts_str.write(taxid_to_string(previous_taxid))
                counts_str.write(":")
                counts_str.write(str(previous_count))
                counts_str.write(" ")
                previous_taxid = taxid
                previous_count = count
    # Flush the last pending run (no trailing space).
    # counts_str += str(previous_taxid) + ":" + str(previous_count)
    counts_str.write(taxid_to_string(previous_taxid))
    counts_str.write(":")
    counts_str.write(str(previous_count))

    return (counts_str.getvalue(), counts_map, total_minimizers)


def merge_classification_output(
        taxonomy, in_filename1, in_filename2,
        out_filename, use_names, args, final=False):
    """Merge two classification outputs line by line into one file.

    Both inputs are read in lockstep; each line is split on tabs into
    status, sequence name, taxid, length and per-taxid counts. For every
    pair of lines:

    * both classified: the counts are merged and the taxid is
      re-resolved from the merged counts;
    * exactly one classified: that line's taxid and counts are kept;
    * neither classified: status "U", taxid 0, counts "0:0".

    The sequence name and length are taken from the first file. With
    ``use_names`` the taxid column is written as
    ``<scientific name> (taxid <taxid>)`` in a single column, the same
    layout ``merge_classification_output2`` produces.

    Returns ``(call_counters, total_seqs)`` when ``final`` is true,
    otherwise ``(None, total_seqs)``.
    """
    call_counters = TaxonCounters()
    total_seqs = 0
    counts_array1 = None
    counts_array2 = None
    lca_cache = {}
    with open(in_filename1) as file1, open(in_filename2) as file2:
        with open(out_filename, "w") as out_file:
            for (line1, line2) in zip(file1, file2):
                (status1, name1, taxid1, len1, counts1) =\
                    line1.strip().split('\t', 5)
                (status2, name2, taxid2, len2, counts2) =\
                    line2.strip().split('\t', 5)
                if status1 == "C" and status2 == "C":
                    status = "C"
                    # The count arrays are reused across lines to avoid
                    # reallocating them for every read.
                    counts_array1, counts_len1 =\
                        parse_taxid_counts(counts1, counts_array1)
                    counts_array2, counts_len2 =\
                        parse_taxid_counts(counts2, counts_array2)
                    counts, counts_map, total_minimizers =\
                        merge_counts(
                            taxonomy, lca_cache, counts_array1,
                            counts_len1, counts_array2, counts_len2
                        )
                    taxid = resolve_taxa_tree(
                        counts_map, taxonomy, total_minimizers, args
                    )
                elif status1 == "C" and status2 == "U":
                    status = "C"
                    taxid = taxid1
                    counts = counts1
                elif status1 == "U" and status2 == "C":
                    status = "C"
                    taxid = taxid2
                    counts = counts2
                else:
                    status = "U"
                    taxid = 0
                    counts = "0:0"

                if final:
                    call_counters[int(taxid)].increment_read_count()
                total_seqs += 1
                if use_names:
                    scientific_name = taxonomy.taxid_to_name(int(taxid))
                    # Name and taxid share one column so the output keeps
                    # the same number of fields with or without names.
                    taxon_field = (
                        scientific_name.decode()
                        + " " + "(taxid " + str(taxid) + ")"
                    )
                else:
                    taxon_field = str(taxid)
                records = "\t".join([status, name1, taxon_field, len1])
                out_file.write(records)
                out_file.write("\t")
                out_file.write(counts)
                out_file.write("\n")
    return (call_counters, total_seqs) if final else (None, total_seqs)


def merge_classification_output2(
        taxonomy_pathname, lines, job_number, use_names, args,
        save_seq_names, final):
    """Merge one partition of paired classification lines (worker job).

    Loads the taxonomy from ``taxonomy_pathname``, merges every
    ``(left, right)`` pair in ``lines`` and writes the result to a newly
    created temporary file whose name starts with ``k2_job<job_number>_``.
    The caller is responsible for deleting that file.

    For every pair of lines:

    * both classified: the counts are merged and the taxid is
      re-resolved from the merged counts;
    * exactly one classified: that line's taxid and counts are kept;
    * neither classified: status "U", taxid 0, counts "0:0".

    With ``use_names`` the taxid column is written as
    ``<scientific name> (taxid <taxid>)``.

    Returns ``(out_filename, call_counters, total_seqs,
    classified_headers)`` when ``final`` is true, otherwise
    ``(out_filename, None, None, None)``. ``classified_headers`` maps
    the name of each classified sequence to its taxid and is only
    filled when both ``final`` and ``save_seq_names`` are true.
    """
    taxonomy_dll_pathname = find_kraken2_binary("libtax.so")
    taxonomy = Taxonomy(taxonomy_dll_pathname)
    taxonomy.load_taxonomy(taxonomy_pathname)

    call_counters = TaxonCounters()
    total_seqs = 0
    counts_array1 = None
    classified_headers = {}
    counts_array2 = None
    lca_cache = {}
    # Create the file atomically instead of reserving a name with the
    # deprecated tempfile.mktemp and opening it afterwards.
    with tempfile.NamedTemporaryFile(
            mode="w", prefix="k2_job" + str(job_number) + "_",
            delete=False) as out_file:
        out_filename = out_file.name
        for (left, right) in lines:
            (status1, seq_name1, taxid1, len1, counts1) =\
                left.strip().split('\t', 5)
            (status2, seq_name2, taxid2, len2, counts2) =\
                right.strip().split('\t', 5)
            if status1 == "C" and status2 == "C":
                status = "C"
                # The count arrays are reused across lines to avoid
                # reallocating them for every read.
                counts_array1, counts_len1 =\
                    parse_taxid_counts(counts1, counts_array1)
                counts_array2, counts_len2 =\
                    parse_taxid_counts(counts2, counts_array2)
                counts, counts_map, total_minimizers =\
                    merge_counts(
                        taxonomy, lca_cache, counts_array1,
                        counts_len1, counts_array2, counts_len2
                    )
                taxid = resolve_taxa_tree(
                    counts_map, taxonomy, total_minimizers, args
                )
            elif status1 == "C" and status2 == "U":
                status = "C"
                taxid = taxid1
                counts = counts1
            elif status1 == "U" and status2 == "C":
                status = "C"
                taxid = taxid2
                counts = counts2
            else:
                status = "U"
                taxid = 0
                counts = "0:0"

            if final:
                call_counters[int(taxid)].increment_read_count()
                if save_seq_names:
                    if status == "C":
                        classified_headers[seq_name1] = taxid
            total_seqs += 1
            if use_names:
                scientific_name = taxonomy.taxid_to_name(int(taxid))
                taxon_field = (
                    scientific_name.decode()
                    + " " + "(taxid " + str(taxid) + ")"
                )
            else:
                taxon_field = str(taxid)
            records = "\t".join([status, seq_name1, taxon_field, len1])
            out_file.write(records)
            out_file.write("\t")
            out_file.write(counts)
            out_file.write("\n")
    if final:
        return (
            out_filename, call_counters, total_seqs, classified_headers
        )
    else:
        return (out_filename, None, None, None)


def merge_classification_output_parallel(
        pool, taxonomy_pathname, in_filename1, in_filename2,
        out_filename, use_names, args, save_seq_names, final):
    """Merge two classification outputs using the worker pool.

    The two inputs are paired line by line, split into consecutive
    partitions of roughly ``len / args.threads`` pairs, and each
    partition is merged by ``merge_classification_output2`` in
    ``pool``. The partial outputs are then concatenated into
    ``out_filename`` in the same order as the input lines, and deleted.

    Returns ``(call_counters, total_seqs, classified_headers)`` when
    ``final`` is true (``classified_headers`` is a ChainMap over the
    per-job dictionaries), otherwise ``(None, total_seqs, None)``.
    Counts are only accumulated when ``final`` is true.
    """
    call_counters = TaxonCounters()
    total_seqs = 0
    with open(in_filename1) as file1, open(in_filename2) as file2:
        paired_lines = list(zip(file1, file2))
    n_pairs = len(paired_lines)
    # At least one pair per partition: with fewer pairs than threads the
    # integer division gives 0, which is not a valid range step.
    chunk_size = max(1, n_pairs // args.threads)
    partition_ranges = list(range(0, n_pairs, chunk_size))
    partition_ranges.append(n_pairs)

    wrapped_func = functools.partial(
        wrap_with_globals, merge_classification_output2,
        LOG.get_queue(), LOG.get_level(),
        SCRIPT_PATHNAME
    )
    futures = []
    boundaries = zip(partition_ranges, partition_ranges[1:])
    for job_number, (start, end) in enumerate(boundaries):
        futures.append(
            pool.submit(
                wrapped_func, taxonomy_pathname,
                paired_lines[start:end], job_number, use_names, args,
                save_seq_names, final
            )
        )
    concurrent.futures.wait(futures)

    # Collect results in submission order. Sorting the temporary file
    # names instead would place job 10 before job 2 and scramble the
    # line order of the merged output.
    filenames = []
    classified_headers = []
    for future in futures:
        filename, counters, total, classified_set = future.result()
        filenames.append(filename)
        if final:
            classified_headers.append(classified_set)
            total_seqs += total
            for key, value in counters.items():
                call_counters[key] += value

    with open(out_filename, "w") as out_file:
        for filename in filenames:
            with open(filename, "r") as in_file:
                shutil.copyfileobj(in_file, out_file)
            os.remove(filename)

    if final:
        classified_headers = collections.ChainMap(*classified_headers)
        return (call_counters, total_seqs, classified_headers)
    else:
        return (None, total_seqs, None)


def sanity_check_taxonomies(db_pathnames):
    """Check that all databases agree on every taxid's parent.

    For each database directory ``nodes.dmp`` is looked up directly in
    the directory and, failing that, in its ``taxonomy`` subdirectory.
    The first two ``|``-separated fields of every line are taken as
    taxid and parent taxid. The process exits with status 1 if a
    ``nodes.dmp`` cannot be found or if a taxid maps to different
    parents.
    """
    LOG.info("Sanity checking taxonomies\n")
    parent_by_taxid = {}
    for db_pathname in db_pathnames:
        top_level = os.path.join(db_pathname, "nodes.dmp")
        nested = os.path.join(
            db_pathname, os.path.join("taxonomy", "nodes.dmp")
        )
        if os.path.exists(top_level):
            nodes_pathname = top_level
        elif os.path.exists(nested):
            nodes_pathname = nested
        else:
            LOG.error("Cannot find nodes.dmp file in {}\n".format(db_pathname))
            sys.exit(1)

        with open(nodes_pathname, "r") as in_file:
            for raw_line in in_file:
                taxid, parent_taxid = (
                    field.strip()
                    for field in raw_line.strip().split('|')[0:2]
                )
                # An unseen taxid compares equal to itself and passes.
                known_parent = parent_by_taxid.get(taxid, parent_taxid)
                if known_parent != parent_taxid:
                    LOG.error(
                        "taxid {} does not map to the same parent "
                        "taxid in some of the nodes file\n".format(taxid)
                    )
                    sys.exit(1)
                parent_by_taxid[taxid] = parent_taxid
    LOG.info("Finished sanity checking taxonomies\n")


def classify_multi_dbs(args):
    """Classify against every database in ``args.db`` and merge results.

    Steps:

    1. Reject options that are not supported with multiple databases.
    2. Capture the options that apply to the merged result (report,
       output, classified/unclassified output, names) and blank them on
       ``args`` so the per-database runs do not act on them.
    3. Build a merged taxonomy from the concatenated ``seqid2taxid.map``
       files and the names/nodes files of the first database.
    4. Run ``classify`` once per database, each into its own file.
    5. Merge the per-database outputs pairwise until one remains; the
       last merge is the "final" one that collects counts and names.
    6. Emit the merged output, then the optional sequence files and
       report.

    All intermediate files live in a private temporary directory that
    is removed on exit, including on error. The process exits with
    status 1 if a database has no ``seqid2taxid.map``.
    """
    dbs = args.db
    use_names = False
    report_filename = None
    report_zeros = False
    classified_out_filename = None
    unclassified_out_filename = None
    if "use_mpa_style" in args:
        LOG.error(
            "--use-mpa-style not supported when using multiple dbs\n"
        )
        sys.exit(1)
    if "report_minimizer_data" in args:
        LOG.error(
            "--report-minimizer-data not supported when using "
            "multiple dbs\n"
        )
        sys.exit(1)
    if "report" in args:
        report_filename, args.report = (args.report, None)
    output = args.output if "output" in args else None
    # NOTE(review): the binary-option table at the top of the file uses
    # the key "report_zero_counts"; confirm "report_zeros" is really the
    # attribute name set on args, otherwise this branch never runs.
    if "report_zeros" in args:
        report_zeros = True
        args.report_zeros = None
    if "classified_out" in args:
        classified_out_filename = args.classified_out
        args.classified_out = None
    if "unclassified_out" in args:
        unclassified_out_filename = args.unclassified_out
        args.unclassified_out = None
    if "use_names" in args:
        use_names = True
        args.use_names = None

    LOG.info("Creating merged taxonomy\n")
    seqid2taxid_maps = []
    for db in dbs:
        seqid2taxid_map = os.path.join(db, "seqid2taxid.map")
        if not os.path.exists(seqid2taxid_map):
            LOG.error(
                "Unable to find seqid2taxid.map for database {}\n".format(db)
            )
            LOG.error(
                "seqid2taxid.map files are needed to create a merged taxonomy"
                " for mulit-database classification\n"
            )
            # Without the map the merged taxonomy cannot be built; stop
            # here rather than failing later while opening the file.
            sys.exit(1)
        seqid2taxid_maps.append(seqid2taxid_map)
    seqid2taxid_maps = sorted(seqid2taxid_maps)

    # A private directory avoids the name race of tempfile.mktemp and
    # gives a single place to clean up.
    work_dir = tempfile.mkdtemp(prefix="k2_")
    try:
        seqid2taxid_map_pathname = os.path.join(work_dir, "seqid2taxid.map")
        with open(seqid2taxid_map_pathname, "wb") as out_file:
            for seqid2taxid_map in seqid2taxid_maps:
                with open(seqid2taxid_map, "rb") as in_file:
                    shutil.copyfileobj(in_file, out_file)
        sanity_check_taxonomies(dbs)
        if os.path.exists(os.path.join(dbs[0], "taxonomy")):
            taxonomy_dir = os.path.join(dbs[0], "taxonomy")
            names_pathname = os.path.join(taxonomy_dir, "names.dmp")
            nodes_pathname = os.path.join(taxonomy_dir, "nodes.dmp")
        else:
            names_pathname = os.path.join(dbs[0], "names.dmp")
            nodes_pathname = os.path.join(dbs[0], "nodes.dmp")
        taxonomy_dll_pathname = find_kraken2_binary("libtax.so")
        taxonomy = Taxonomy(taxonomy_dll_pathname)
        taxonomy_pathname = os.path.join(work_dir, "taxo.k2d")
        taxonomy.generate_taxonomy(
            names_pathname, nodes_pathname, seqid2taxid_map_pathname,
            taxonomy_pathname
        )
        LOG.info("Finished creating and loading merged taxonomy\n")

        tmp_filenames = []
        for db_index, db in enumerate(dbs):
            args.db = db
            pathname = os.path.join(
                work_dir, "classify_{}.out".format(db_index)
            )
            tmp_filenames.append(pathname)
            args.output = pathname

            LOG.info(
                "Running classification job for database {}\n".format(db)
            )
            classify(args)
            LOG.info(
                "Finished running classification job for database {}\n"
                .format(db)
            )

        LOG.info("Merging output files\n")
        out_filename = os.path.join(work_dir, "merged.out")
        tmp_filenames_copy = tmp_filenames.copy()
        tmp_filenames.append(out_filename)
        final = False
        progress = ProgressBar(len(tmp_filenames) - 1, 0)
        pool = concurrent.futures.ProcessPoolExecutor(
            max_workers=args.threads
        )
        try:
            while True:
                LOG.debug("{}\r".format(progress.get_bar()))
                if len(tmp_filenames_copy) == 2:
                    final = True
                filename1, filename2 = (
                    tmp_filenames_copy.pop(), tmp_filenames_copy.pop()
                )
                save_seq_names = final and (
                    classified_out_filename is not None
                    or unclassified_out_filename is not None
                )
                call_counters, total_seqs, classified_headers =\
                    merge_classification_output_parallel(
                        pool, taxonomy_pathname, filename1, filename2,
                        out_filename, use_names and final, args,
                        save_seq_names, final
                    )
                progress.progress(1, relative=True)
                LOG.debug("{}\r".format(progress.get_bar()))
                if len(tmp_filenames_copy) == 0:
                    break
                # Queue the merged file for the next round and reuse an
                # already-consumed input as the next destination.
                tmp_filenames_copy.insert(0, out_filename)
                out_filename = filename1
        finally:
            pool.shutdown()
        progress.progress(1, relative=True)
        LOG.debug("{}\r".format(progress.get_bar()))
        # Add space to overwrite progress bar
        LOG.info("Finished merging output files              \n")
        if output:
            shutil.move(out_filename, os.path.abspath(output))
            if out_filename in tmp_filenames:
                tmp_filenames.remove(out_filename)
        else:
            with open(out_filename, "r") as out_file:
                shutil.copyfileobj(out_file, sys.stdout)
        # Free the intermediate outputs before the remaining steps.
        for tmp_filename in tmp_filenames:
            os.remove(tmp_filename)

        if classified_out_filename or unclassified_out_filename:
            LOG.info("Writing (un)classified sequences to file\n")
            write_fasta_sequences(
                args, args.filenames, classified_out_filename,
                unclassified_out_filename, classified_headers
            )
            LOG.info("Finished writing (un)classified sequences to file\n")
        if report_filename is not None:
            LOG.info("Generating report file\n")
            taxonomy.load_taxonomy(taxonomy_pathname)
            report_kraken_style(
                report_filename, report_zeros, False,
                taxonomy, call_counters, total_seqs
            )
            LOG.info("Finished generating report file\n")
    finally:
        # Removes the merged taxonomy, the merged seqid map and anything
        # left behind by an early exit.
        shutil.rmtree(work_dir, ignore_errors=True)


def classify(args):
    """Run the kraken2 ``classify`` binary against a single database.

    Validates the database and a few options, builds the argument vector for
    the binary and runs it, either directly with ``subprocess.call`` or via
    ``classify_using_daemon`` when ``args.use_daemon`` is set.  If at least
    one input file is compressed, the inputs are handed to the classifier
    through named pipes (FIFOs) created in a temporary directory instead of
    being passed by filename.

    Parameters:
        args: parsed ``classify`` namespace.  Reads ``db``, ``filenames``,
            ``confidence`` and ``use_daemon``; ``paired`` is only consulted
            when the attribute is present.  ``args.filenames`` is rewritten
            in place with absolute paths in the uncompressed case.

    Exits the process with status 1 when the database cannot be located,
    when ``--paired`` is given an odd number of files, when ``--confidence``
    is outside [0, 1], or when a FIFO cannot be created.
    """
    classify_bin = find_kraken2_binary("classify")
    database_path = find_database(args.db)
    if database_path is None:
        LOG.error("{:s} is not a valid database... exiting\n".format(args.db))
        sys.exit(1)
    # --paired is declared with default=argparse.SUPPRESS, so the attribute
    # is only present on the namespace when the flag was actually given.
    if "paired" in args and len(args.filenames) % 2 != 0:
        LOG.error("--paired requires an even number of file names\n")
        sys.exit(1)
    if args.confidence < 0 or args.confidence > 1:
        LOG.error(
            "--confidence, {:f}, must be between 0 and 1 inclusive\n".format(
                args.confidence
            )
        )
        sys.exit(1)
    # The three database files are always passed explicitly to the binary.
    argv = [
        classify_bin,
        "-H",
        os.path.join(database_path, "hash.k2d"),
        "-t",
        os.path.join(database_path, "taxo.k2d"),
        "-o",
        os.path.join(database_path, "opts.k2d"),
    ]
    # Presumably appends the binary flags corresponding to the wrapper
    # options to argv in place (argv is inspected for "-P" below).
    wrapper_args_to_binary_args(args, argv, get_binary_options(classify_bin))
    if any([is_compressed(filename) for filename in args.filenames]):
        # Compressed input: the classifier reads from FIFOs that live in a
        # temporary directory which is removed when the block exits.
        with tempfile.TemporaryDirectory() as temp_dir_name:
            fifo1_pathname = os.path.join(temp_dir_name, "fifo1")
            fifo2_pathname = None
            try:
                os.mkfifo(fifo1_pathname, 0o600)
            except OSError:
                LOG.error(
                    "Unable to create FIFO for processing compressed files\n"
                )
                sys.exit(1)
            # "-P" is the binary flag for paired input: a second FIFO is
            # needed so that each mate file stream has its own pipe.
            if "-P" in argv:
                fifo2_pathname = os.path.join(temp_dir_name, "fifo2")
                try:
                    os.mkfifo(fifo2_pathname, 0o600)
                except OSError:
                    LOG.error(
                        "Unable to create FIFO for processing compressed files\n"
                    )
                    sys.exit(1)
                argv.extend([fifo1_pathname, fifo2_pathname])
            else:
                argv.append(fifo1_pathname)
            # The classifier is started in its own thread *before* anything
            # is written, so it consumes the FIFO(s) while they are filled.
            if args.use_daemon:
                thread = threading.Thread(
                    target=classify_using_daemon, args=(args, argv)
                )
            else:
                thread = threading.Thread(target=subprocess.call, args=(argv,))
            thread.start()
            if "-P" in argv:
                # Even-indexed files feed the first FIFO and odd-indexed
                # files the second; both writers run concurrently.
                writer_thread1 = threading.Thread(
                    target=write_to_fifo,
                    args=(args.filenames[0::2], fifo1_pathname)
                )
                writer_thread2 = threading.Thread(
                    target=write_to_fifo,
                    args=(args.filenames[1::2], fifo2_pathname)
                )
                writer_thread1.start()
                writer_thread2.start()
                writer_thread1.join()
                writer_thread2.join()
            else:
                # write_to_fifo presumably decompresses each file into the
                # FIFO -- TODO confirm against its definition.
                write_to_fifo(args.filenames, fifo1_pathname)
            # Wait for the classifier before the temporary directory (and
            # the FIFOs in it) is cleaned up.
            thread.join()
    else:
        # Uncompressed input: pass absolute pathnames straight to the binary.
        for i, filename in enumerate(args.filenames):
            args.filenames[i] = os.path.abspath(filename)
        argv.extend(args.filenames)
        if args.use_daemon:
            classify_using_daemon(args, argv)
        else:
            # NOTE(review): the exit status of the binary is discarded here.
            subprocess.call(argv)


def inspect_db(args):
    """Dump a Kraken 2 database with the ``dump_table`` binary.

    The output of ``dump_table`` is copied either to standard output (when
    ``args.output`` is ``"-"``) or to the file named by ``args.output``.

    Parameters:
        args: parsed ``inspect`` namespace.  Reads ``db`` and ``output``;
            ``args.output`` is reset to ``None`` as a side effect.

    Exits the process with status 1 when the database cannot be found or
    when any of ``taxo.k2d``, ``hash.k2d`` or ``opts.k2d`` is missing.
    """
    database_pathname = find_database(args.db)
    if not database_pathname:
        LOG.error("{:s} database does not exist\n".format(args.db))
        sys.exit(1)
    missing_files = [
        database_file
        for database_file in ["taxo.k2d", "hash.k2d", "opts.k2d"]
        if not os.path.isfile(os.path.join(database_pathname, database_file))
    ]
    for database_file in missing_files:
        LOG.error("{:s} does not exist\n".format(database_file))
    if missing_files:
        # All three files are passed to dump_table below, so there is no
        # point in launching it once any of them is known to be absent.
        sys.exit(1)
    dump_table_bin = find_kraken2_binary("dump_table")
    argv = [
        dump_table_bin,
        "-H",
        os.path.join(database_pathname, "hash.k2d"),
        "-t",
        os.path.join(database_pathname, "taxo.k2d"),
        "-o",
        os.path.join(database_pathname, "opts.k2d"),
    ]
    # dump_table does not save the table header when it writes to a file
    # itself.  As a workaround the output option is cleared before the
    # wrapper options are translated, and the complete standard output of
    # the binary is captured and written out here instead.
    output_filename, args.output = args.output, None
    wrapper_args_to_binary_args(args, argv, get_binary_options(dump_table_bin))
    # The context manager closes the pipe and waits for the child on exit,
    # including when copying the output raises.
    with subprocess.Popen(argv, stdout=subprocess.PIPE) as process:
        if output_filename == "-":
            shutil.copyfileobj(process.stdout, sys.stdout.buffer)
        else:
            with open(output_filename, "wb") as fout:
                shutil.copyfileobj(process.stdout, fout)


def format_bytes(size):
    """Return *size* (a number of bytes) as a short human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    known unit ("EB") is reached, and is rendered with two decimals followed
    directly by the unit, e.g. ``1536`` -> ``"1.50kB"``.
    """
    units = ("B", "kB", "MB", "GB", "TB", "PB", "EB")
    unit_index = 0
    largest_index = len(units) - 1
    while size >= 1024 and unit_index < largest_index:
        size /= 1024
        unit_index += 1
    return "{:.2f}{:s}".format(size, units[unit_index])


def clean_up(filenames):
    """Remove *filenames* and log how much space was reclaimed.

    The value returned by ``remove_files`` is treated as a byte count and
    rendered with ``format_bytes`` for the log message.
    """
    LOG.info("Removing the following files: {}\n".format(filenames))
    # remove_files is expected to report the total size of what it deleted
    # (summed over the individual files it walks).
    bytes_removed = remove_files(filenames)
    space_freed = format_bytes(bytes_removed)
    LOG.info("Cleaned up {} of space\n".format(space_freed))


def range_parser(input):
    """Parse a comma separated list of numbers and ranges into a set of ints.

    Each item is either a plain decimal number or an inclusive range written
    as ``start..end``, ``start...end``, ``start-end`` or ``start:end``.
    Either end of a range may be omitted: a missing start means 0 and a
    missing end means 1000.  The special value ``"all"`` is equivalent to
    ``".."`` (0 through 1000).  Spaces are ignored.

    Parameters:
        input: the raw option string.

    Returns:
        A ``set`` of the integers described by *input*.

    Raises:
        argparse.ArgumentTypeError: if any item is not a number or a
            well-formed range.  The whole item must match, so trailing
            garbage such as ``"1-3x"`` or ``"1-2-3"`` is rejected rather
            than silently truncated.
    """
    if input == "all":
        input = ".."
    range_regex = re.compile(r"(\d+)?(?:\.{2,3}|\-|:)(\d+)?")
    volumes = set()
    for volume in input.replace(' ', '').split(','):
        if volume.isdecimal():
            volumes.add(int(volume))
            continue
        # fullmatch (not match) so that only a complete range is accepted.
        match = range_regex.fullmatch(volume)
        if not match:
            raise argparse.ArgumentTypeError(input)
        start = int(match.group(1) or 0)
        end = int(match.group(2) or 1000)
        # Ranges are inclusive on both ends.
        volumes.update(range(start, end + 1))
    return volumes


def clean_db(args):
    """Handle the ``clean`` subcommand.

    With ``args.stop_daemon`` set, the background classifier is told to stop
    and its FIFOs are cleaned up.  Otherwise the working directory is
    changed to ``args.db`` and either the files matching ``args.pattern``
    or a fixed list of intermediate build artefacts is removed.
    """
    if args.stop_daemon:
        message_daemon(b"STOP\n")
        LOG.info("Stopped background classifier process\n")
        cleanup_fifos()
        return

    # All names below are relative to the database directory.
    os.chdir(args.db)
    if args.pattern:
        # NOTE(review): with recursive=False a "**" in the pattern behaves
        # like "*", although the --pattern help text describes "**" as
        # matching nested directories -- confirm which is intended.
        targets = glob.glob(args.pattern, recursive=False)
    else:
        targets = [
            "data",
            "library",
            "taxonomy",
            "seqid2taxid.map",
            "prelim_map.txt",
        ]
    clean_up(targets)


def make_build_parser(subparsers):
    """Register the ``build`` subcommand and all of its options.

    Parameters:
        subparsers: the object returned by ``add_subparsers()`` on the
            top-level parser; the new sub-parser is added to it.

    ``--kmer-len``, ``--minimizer-len`` and ``--minimizer-spaces`` have no
    default here, so they are ``None`` unless given on the command line.
    """
    build_parser = subparsers.add_parser(
        "build",
        help="Build a database from library\
              (requires taxonomy which can be downloading\
              via download-taxonomy subcommand, and at least one library\
              which can be added via the download-library or\
              add-to-library subcommands).",
    )
    build_parser.add_argument(
        "--db", type=str, metavar="PATHNAME", required=True,
        help="Pathname to database folder where building will take place.",
    )

    # Options shown under the "special" heading in the help output.
    special_group = build_parser.add_argument_group("special")
    # --standard and --special cannot be combined.
    database_kind = special_group.add_mutually_exclusive_group()
    database_kind.add_argument(
        "--standard",
        action="store_true",
        help="Make standard database which includes: archaea,\
               bacteria, human, plasmid, UniVec_Core, and viral."
    )
    database_kind.add_argument(
        "--special",
        type=str,
        choices=["greengenes", "rdp", "silva", "gtdb"],
        help="Build special database. RDP is currently unavailable\
              as URLs no longer work.",
    )
    special_group.add_argument(
        "--gtdb-files",
        type=str,
        nargs="+",
        help="A list of files or regex matching the files needed to build\
              the special database."
    )
    special_group.add_argument(
        "--gtdb-use-ncbi-taxonomy", action="store_true",
        help="Use NCBI tax IDs and taxonomy tree when building GTDB database"
    )
    special_group.add_argument(
        "--gtdb-server", type=str, default=GTDB_SERVER,
        help="The GTDB server to use (default: {})".format(GTDB_SERVER)
    )
    special_group.add_argument(
        "--no-masking",
        action="store_true",
        help="Avoid masking low-complexity sequences prior to\
              building database.",
    )
    special_group.add_argument(
        "--masker-threads",
        type=int,
        default=4,
        metavar="K2MASK_THREADS",
        help="Number of threads used by k2mask during masking\
              process (default: 4)"
    )

    # General build options.
    build_parser.add_argument(
        "--kmer-len", type=int, metavar="INT", help="K-mer length in bp/aa"
    )
    build_parser.add_argument(
        "--minimizer-len", type=int, metavar="INT",
        help="Minimizer length in bp/aa",
    )
    build_parser.add_argument(
        "--minimizer-spaces",
        type=int,
        metavar="INT",
        help="Number of characters in minimizer that are\
              ignored in comparisons",
    )
    build_parser.add_argument(
        "--threads", type=int, metavar="INT",
        # A string taken from the environment is converted by type=int.
        default=os.environ.get("KRAKEN2_NUM_THREADS") or 1,
        help="Number of threads",
    )
    build_parser.add_argument(
        "--load-factor", type=float, metavar="FLOAT (0,1]", default=0.7,
        help="Proportion of the hash table to be populated (default: 0.7)",
    )
    build_parser.add_argument(
        "--fast-build",
        action="store_true",
        help="Do not require database to be deterministically\
              built when using multiple threads. This is faster, but\
              does introduce variability in minimizer/LCA pairs.",
    )
    build_parser.add_argument(
        "--max-db-size",
        # No type= conversion is applied: the value is kept as a string.
        metavar="SIZE",
        help="Maximum number of bytes for Kraken 2 hash table;\
              if the estimator determines more would normally be\
              needed, the reference library will be downsampled to fit",
    )
    build_parser.add_argument(
        "--skip-maps", action="store_true",
        help="Avoids downloading accession number to taxid maps",
    )
    build_parser.add_argument(
        "--protein", action="store_true",
        help="Build a protein database for translated search",
    )
    build_parser.add_argument(
        "--block-size", type=int, metavar="INT", default=16384,
        help="Read block size (default: 16384)",
    )
    build_parser.add_argument(
        "--sub-block-size", type=int, metavar="INT", default=0,
        help="Read subblock size",
    )
    build_parser.add_argument(
        "--minimum-bits-for-taxid", type=int, metavar="INT", default=0,
        help="Bit storage requested for taxid",
    )
    build_parser.add_argument(
        "--log", type=str, metavar="FILENAME", default=None,
        help="Specify a log file (default: stderr)",
    )


def make_download_taxonomy_parser(subparsers):
    """Register the ``download-taxonomy`` subcommand on *subparsers*.

    Options: ``--db`` (required), ``--protein``, ``--skip-maps`` and
    ``--log``.
    """
    taxonomy_parser = subparsers.add_parser(
        "download-taxonomy", help="Download NCBI taxonomic information"
    )
    # A --source {GTDB,NCBI} selector is currently disabled (not registered).
    option_table = (
        ("--db", {
            "type": str,
            "metavar": "PATHNAME",
            "required": True,
            "help": "Pathname to Kraken2 database",
        }),
        ("--protein", {
            "action": "store_true",
            "help": "Files being added are for a protein database",
        }),
        ("--skip-maps", {
            "action": "store_true",
            "help": "Avoids downloading accession number to taxid maps",
        }),
        ("--log", {
            "type": str,
            "metavar": "FILENAME",
            "default": None,
            "help": "Specify a log filename (default: stderr)",
        }),
    )
    # Registration order is the order shown in the help output.
    for flag, settings in option_table:
        taxonomy_parser.add_argument(flag, **settings)


def make_download_library_parser(subparsers):
    """Register the ``download-library`` subcommand (alias ``download``).

    Parameters:
        subparsers: the object returned by ``add_subparsers()`` on the
            top-level parser.

    ``--library`` can also be spelled ``--taxid``, ``--project`` or
    ``--accession``; all four store into ``library``.  ``--blast-volumes``
    is converted by ``range_parser`` (registered below as the "range" type),
    and its string default ``"all"`` goes through the same conversion.
    ``--no-masking`` and ``--masker-threads`` are mutually exclusive.
    """
    parser = subparsers.add_parser(
        "download-library", aliases=["download"],
        help="Download and build a special database"
    )
    # Make type="range" resolve to range_parser for this parser.
    parser.register("type", "range", range_parser)
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Pathname to Kraken2 database",
    )
    parser.add_argument(
        "--library",
        "--taxid",
        "--project",
        "--accession",
        type=str,
        dest="library",
        required=True,
        # The value is deliberately not restricted with choices=; the list
        # below is kept for reference only.
        # choices=[
        #     "archaea",
        #     "bacteria",
        #     "plasmid",
        #     "plastid",
        #     "viral",
        #     "human",
        #     "invertebrate",
        #     "fungi",
        #     "plant",
        #     "protozoa",
        #     "vertebrate_other",
        #     "vertebrate_mammalian",
        #     "mitochondrion",
        #     "nr",
        #     "nt",
        #     "UniVec",
        #     "UniVec_Core",
        # ],
        help="Name of library to download",
    )
    parser.add_argument(
        "--assembly-source",
        type=str,
        required=False,
        choices=["refseq", "genbank", "all"],
        default="refseq",
        help="Download RefSeq (GCF_) or GenBank (GCA_) genome assemblies\
              or both (default RefSeq)",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume fetching the files needed for a library, skipping files\
        that have already been downloaded",
    )
    parser.add_argument(
        "--assembly-levels",
        type=str,
        nargs="+",
        choices=["chromosome", "complete_genome", "scaffold", "contig"],
        default=["chromosome", "complete_genome"],
        help="Only return genome assemblies that have one of the specified\
              assembly levels (default chromosome and complete genome)"
    )
    parser.add_argument(
        "--has-annotation",
        action="store_true",
        help="Return only annotated genome assemblies (default: false)"
    )
    parser.add_argument(
        "--blast-volumes",
        type="range",
        default="all",
        help="A comma separated list of the blast volume numbers to download.\
        Ranges are also accepted in the forms start..end, start-end, start:end,\
        ranges are inclusive (default: all volumes)"
    )
    parser.add_argument(
        "--protein",
        action="store_true",
        help="Files being added are for a protein database",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        help="Specify a log filename (default: stderr)",
    )
    parser.add_argument(
        "--threads",
        type=int,
        metavar="THREADS",
        default=1,
        help="The number of threads/processes k2 uses when downloading\
              and processing library files.",
    )
    # Asking for masker threads makes no sense when masking is disabled.
    masking_parser = parser.add_mutually_exclusive_group()
    masking_parser.add_argument(
        "--no-masking",
        action="store_true",
        help="Avoid masking low-complexity sequences prior to\
              building; masking requires k2mask or segmasker to be\
              installed",
    )
    masking_parser.add_argument(
        "--masker-threads",
        type=int,
        default=4,
        metavar="K2MASK_THREADS",
        help="Number of threads used by k2mask during masking\
              process (default: 4)"
    )


def make_add_to_library_parser(subparsers):
    """Register the ``add-to-library`` subcommand on *subparsers*.

    ``--file`` / ``--files`` both store a list of one or more pathnames or
    patterns into ``files``.  ``--no-masking`` and ``--masker-threads`` are
    mutually exclusive.
    """
    parser = subparsers.add_parser(
        "add-to-library", help="Add file(s) to library"
    )
    parser.add_argument(
        "--db",
        type=str,
        metavar="PATHNAME",
        required=True,
        help="Pathname to Kraken2 database",
    )
    parser.add_argument(
        "--threads",
        type=int,
        metavar="THREADS",
        default=1,
        help="The number of threads/processes k2 uses when\
              adding library files."
    )
    parser.add_argument(
        "--file",
        "--files",
        type=str,
        nargs="+",
        required=True,
        dest="files",
        help="""Pathname or patterns of file(s) to be added to library.
                Supported pattern are as follows:
              ? - A question-mark is a pattern that shall match any
                  character.
              * - An asterisk is a pattern that shall match multiple
                  characters.
              [ - The open bracket shall introduce a pattern bracket
                  expression.
             ** - will match any files and zero or more directories,
                  subdirectories and symbolic links to directories.
        """,
    )
    parser.add_argument(
        "--protein",
        action="store_true",
        help="Files being added are for a protein database",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        help="Specify a log filename (default: stderr)",
    )
    # A --skip-md5 option is currently disabled (not registered).
    # Asking for masker threads makes no sense when masking is disabled.
    masking_parser = parser.add_mutually_exclusive_group()
    masking_parser.add_argument(
        "--no-masking",
        action="store_true",
        help="Avoid masking low-complexity sequences prior to\
              building; masking requires k2mask or segmasker to be\
              installed",
    )
    masking_parser.add_argument(
        "--masker-threads",
        type=int,
        metavar="K2MASK_THREADS",
        default=4,
        help="Number of threads used by k2mask during masking process\
              (default: 4)"
    )


def make_classify_parser(subparsers):
    """Register the ``classify`` subcommand on *subparsers*.

    ``--db`` is split on commas, so ``args.db`` is always a list of
    pathnames.  Many options use ``default=argparse.SUPPRESS``: their
    attribute is absent from the parsed namespace unless the option was
    given, so callers must test for presence rather than truthiness.
    ``--paired`` and ``--interleaved`` are mutually exclusive.
    """
    parser = subparsers.add_parser(
        "classify", help="Classify a set of sequences"
    )
    parser.add_argument(
        "--db",
        type=lambda x: x.split(","),
        metavar="PATHNAME",
        # nargs="+",
        required=True,
        help="Pathname to Kraken2 database(s).\
        Multiple databases are specified as a comma-\
        separated list with no spaces.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        metavar="INT",
        # A string taken from the environment is converted by type=int.
        default=os.environ.get("KRAKEN2_NUM_THREADS") or 1,
        help="Number of threads",
    )
    parser.add_argument(
        "--use-daemon",
        action="store_true",
        help="Spawn a background process that keeps any loaded indexes\
        in memory. Subsequent invocations of classify with this option will\
        skip the index loading process and immediately start classifying\
        reads. If a new index is specified that index will also be persisted.\
        Use k2 clean --stop-daemon to stop the background process."
    )
    parser.add_argument(
        "--quick",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Quick operation (use first hit or hits)",
    )
    parser.add_argument(
        "--unclassified-out",
        type=str,
        default=argparse.SUPPRESS,
        metavar="FILENAME",
        help="Print unclassified sequences to filename",
    )
    parser.add_argument(
        "--classified-out",
        type=str,
        metavar="FILENAME",
        default=argparse.SUPPRESS,
        help="Print classified sequences to filename",
    )
    parser.add_argument(
        "--output",
        type=str,
        metavar="FILENAME",
        default=argparse.SUPPRESS,
        help='Print output to file (default: stdout) "-" will \
              suppress normal output',
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.0,
        help="confidence score threshold (default: 0.0); must be in [0,1]",
    )
    parser.add_argument(
        "--minimum-base-quality",
        type=int,
        metavar="INT",
        default=0,
        help="Minimum base quality used in classification",
    )
    parser.add_argument(
        "--report",
        type=str,
        default=argparse.SUPPRESS,
        help="Print a report with aggregate counts/clade to file",
    )
    parser.add_argument(
        "--use-mpa-style",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, format report output like Kraken 1's\
              kraken-mpa-report",
    )
    parser.add_argument(
        "--report-zero-counts",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, report counts for ALL taxa, even if\
              counts are zero",
    )
    parser.add_argument(
        "--report-minimizer-data",
        action="store_true",
        default=argparse.SUPPRESS,
        help="With --report, report minimizer and distinct minimizer\
              count information in addition to normal Kraken report",
    )
    parser.add_argument(
        "--memory-mapping",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Avoids loading entire database into RAM",
    )
    # Paired input comes either as separate mate files or interleaved in
    # the same file, never both.
    paired_group = parser.add_mutually_exclusive_group()
    paired_group.add_argument(
        "--paired",
        action="store_true",
        default=argparse.SUPPRESS,
        help="The filenames provided have paired-end reads",
    )
    paired_group.add_argument(
        "--interleaved",
        action="store_true",
        default=argparse.SUPPRESS,
        help="The filenames provided have interleaved paired-end reads",
    )
    parser.add_argument(
        "--use-names",
        action="store_true",
        default=argparse.SUPPRESS,
        help="Print scientific names instead of just taxids",
    )
    parser.add_argument(
        "--minimum-hit-groups",
        type=int,
        metavar="INT",
        default=2,
        help="Minimum number of hit groups (overlapping k-mers\
              sharing the same minimizer) needed to make a call\
              (default 2)",
    )
    parser.add_argument(
        "--log",
        type=str,
        metavar="FILENAME",
        default=None,
        help="Specify a log filename (default: stderr)",
    )
    parser.add_argument(
        "filenames",
        nargs="*",
        type=str,
        help="Filenames to be classified, supports bz2, gzip, and xz",
    )


def make_inspect_parser(subparsers):
    """Register the ``inspect`` subcommand on *subparsers*.

    ``--output`` (also ``--out``) defaults to ``"-"``.  ``--memory-mapping``
    uses ``argparse.SUPPRESS`` and is therefore absent from the namespace
    unless given.
    """
    inspect_parser = subparsers.add_parser(
        "inspect", help="Inspect Kraken 2 database"
    )
    # A --threads option is currently disabled (not registered).
    option_table = (
        (("--db",), {
            "type": str,
            "metavar": "PATHNAME",
            "required": True,
            "help": "Pathname to Kraken2 database",
        }),
        (("--skip-counts",), {
            "action": "store_true",
            "help": "Only print database summary statistics",
        }),
        (("--use-mpa-style",), {
            "action": "store_true",
            "help": "Format output like Kraken 1's kraken-mpa-report",
        }),
        (("--report-zero-counts",), {
            "action": "store_true",
            "help": "Report counts for ALL taxa, even if counts are zero",
        }),
        (("--log",), {
            "type": str,
            "metavar": "FILENAME",
            "default": None,
            "help": "Specify a log filename (default: stderr)",
        }),
        (("--output", "--out"), {
            "type": str,
            "metavar": "FILENAME",
            "default": "-",
            "help": "Write inspect output to FILENAME (default: stdout)",
        }),
        (("--memory-mapping",), {
            "action": "store_true",
            "default": argparse.SUPPRESS,
            "help": "Avoids loading entire database into RAM",
        }),
    )
    # Registration order is the order shown in the help output.
    for flags, settings in option_table:
        inspect_parser.add_argument(*flags, **settings)


def make_clean_parser(subparsers):
    """Register the ``clean`` subcommand on *subparsers*.

    Exactly one of ``--stop-daemon`` and ``--db`` must be supplied (they
    form a required, mutually exclusive group).  ``--log`` and ``--pattern``
    are optional.
    """
    clean_parser = subparsers.add_parser(
        "clean", help="Removes unwanted files from database"
    )

    # Either stop the daemon or clean a database directory -- never both.
    required_group = clean_parser.add_argument_group(
        "required", "Arguments required by the cleaner"
    )
    mode = required_group.add_mutually_exclusive_group(required=True)
    mode.add_argument(
        "--stop-daemon", action="store_true",
        help="Stop a running background process",
    )
    mode.add_argument(
        "--db", type=str, metavar="PATHNAME",
        help="Pathname to Kraken2 database",
    )

    optional_group = clean_parser.add_argument_group(
        "options", "options for cleaning temporary files"
    )
    optional_group.add_argument(
        "--log", type=str, metavar="FILENAME", default=None,
        help="Specify a log filename (default: stderr)",
    )
    optional_group.add_argument(
        "--pattern",
        type=str,
        metavar="SHELL_REGEX",
        default=None,
        help="""Files that match this regular expression will be deleted.
              ? - A question-mark is a pattern that shall match any
                  character.
              * - An asterisk is a pattern that shall match multiple
                  characters.
              [ - The open bracket shall introduce a pattern bracket
                  expression.
             ** - will match any files and zero or more directories,
                  subdirectories and symbolic links to directories.
        """
    )


class HelpAction(argparse._HelpAction):
    """Help action that also prints the help of every subcommand.

    After the parser's own help, each registered sub-parser is printed
    under an underlined heading carrying its name.  Everything is written
    to standard output, and the process then exits with status 0.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parser.print_help()
        subparsers = None
        for action in parser._actions:
            # Only a sub-parsers action maps names to parsers; other
            # actions may have plain ``choices`` lists.
            if isinstance(action, argparse._SubParsersAction):
                subparsers = action.choices
        if not subparsers:
            sys.exit(0)
        already_printed = set()
        for name, arg_parser in subparsers.items():
            # Aliases map to the same parser object; print it only once,
            # under the first (primary) name.
            if id(arg_parser) in already_printed:
                continue
            already_printed.add(id(arg_parser))
            # Headings go to the same stream as print_help() so that the
            # output stays ordered when it is piped or redirected.
            sys.stdout.write("\n\n" + name + "\n" + "-" * len(name) + "\n")
            arg_parser.print_help()
        sys.exit(0)


def make_cmdline_parser():
    """Build and return the top-level ``k2`` argument parser.

    The built-in help is disabled and replaced with ``HelpAction``; a
    ``-v/--version`` option reports ``SCRIPT_VERSION``.  Every subcommand
    is then registered on a shared sub-parsers object.
    """
    top_level = argparse.ArgumentParser("k2", add_help=False)
    top_level.add_argument("-h", "--help", action=HelpAction)
    top_level.add_argument(
        "-v", "--version", action="version", version=SCRIPT_VERSION
    )
    commands = top_level.add_subparsers()
    # Registration order determines the order in the help listing.
    subcommand_builders = (
        make_add_to_library_parser,
        make_download_library_parser,
        make_download_taxonomy_parser,
        make_build_parser,
        make_classify_parser,
        make_inspect_parser,
        make_clean_parser,
    )
    for register_subcommand in subcommand_builders:
        register_subcommand(commands)
    return top_level


class Logger:
    """Thin wrapper around the ``"kraken2"`` logger with a record queue.

    Messages logged through ``debug``/``info``/``warning``/``error`` go
    straight to the underlying logger.  In addition, a manager-backed queue
    is drained by a daemon thread which hands every record it receives to
    the same logger; ``setup_queue_logger`` configures the root logger of
    the calling process to put its records onto such a queue.
    """

    def __init__(self, filename):
        """Create the logger.

        Parameters:
            filename: path of the log file.  When falsy, a stream handler
                is used instead and the level is DEBUG rather than INFO.
        """
        self.queue = multiprocessing.Manager().Queue(-1)
        # Log messages carry their own line ending ("\n" or "\r"), so the
        # handlers must not append one.  This is set on the class and thus
        # affects every StreamHandler (and subclass) in the process.
        logging.StreamHandler.terminator = ""
        self.logger = logging.getLogger("kraken2")
        if filename:
            self.logger.setLevel(logging.INFO)
            handler = logging.FileHandler(filename)
        else:
            self.logger.setLevel(logging.DEBUG)
            handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("[%(levelname)s - %(asctime)s]: %(message)s")
        )
        self.logger.addHandler(handler)

        # Daemon thread: it must never keep the interpreter alive.
        self.thread = threading.Thread(
            target=Logger.process_thread, args=(self,),
            daemon=True
        )
        self.thread.start()

    def debug(self, log):
        """Log *log* at DEBUG level."""
        self.logger.debug(log)

    def info(self, log):
        """Log *log* at INFO level."""
        self.logger.info(log)

    def warning(self, log):
        """Log *log* at WARNING level."""
        self.logger.warning(log)

    def error(self, log):
        """Log *log* at ERROR level."""
        self.logger.error(log)

    def process_thread(self):
        """Forward queued records to the logger until told to stop.

        The loop ends when ``None`` is received (see ``stop_thread``) or
        when reading from the queue fails for any reason other than the
        queue being empty.
        """
        import queue
        while True:
            try:
                record = self.queue.get()
                if record is None:
                    break
            except queue.Empty:
                continue
            except Exception:
                # The queue is gone (e.g. the manager shut down); there is
                # nothing left to forward.
                break
            self.logger.handle(record)

    def stop_thread(self):
        """Ask the forwarding thread to stop by queueing the sentinel."""
        self.queue.put(None)
        # self.thread.join()

    def get_queue(self):
        """Return the queue that feeds the forwarding thread."""
        return self.queue

    def get_level(self):
        """Return the level configured on the underlying logger."""
        return self.logger.level

    @staticmethod
    def setup_queue_logger(queue, level):
        """Route the root logger of the current process to *queue*.

        Declared as a static method so that it can be called both as
        ``Logger.setup_queue_logger(...)`` and on an instance.

        Parameters:
            queue: queue that receives the log records.
            level: level to set on the root logger.

        Returns:
            The root logger, with a ``QueueHandler`` for *queue* attached.
        """
        queue_handler = logging.handlers.QueueHandler(queue)
        logger = logging.getLogger()
        logger.setLevel(level)
        logger.addHandler(queue_handler)

        return logger

    def __del__(self):
        # Best effort only: during interpreter shutdown the queue or its
        # manager may already be unusable.
        try:
            self.stop_thread()
        except Exception:
            pass


def k2_main():
    """Parse the command line, set up logging and run the requested task.

    The first command-line argument selects the task: ``download-taxonomy``,
    ``classify``, ``download-library``, ``add-to-library``, ``inspect``,
    ``clean`` or ``build``.  When the script is invoked without arguments the
    help text is printed and the process exits with status 1.

    Side effects:
        Sets the module globals ``SCRIPT_PATHNAME`` and ``LOG``, and rewrites
        ``args.db`` to absolute path(s).

    For ``build``, unset k-mer / minimizer options are filled in with the
    protein or nucleotide defaults and then validated; a failed validation
    logs an error and exits with status 1.
    """
    global SCRIPT_PATHNAME
    global LOG

    SCRIPT_PATHNAME = os.path.realpath(inspect.getsourcefile(k2_main))

    parser = make_cmdline_parser()
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args(sys.argv[1:])
    LOG = Logger(args.log)
    task = sys.argv[1]
    # if task not in ["classify", "inspect"]:
    #     args.db = os.path.abspath(args.db)
    # args.db may be a list of paths or a single path; normalise either form.
    if isinstance(args.db, list):
        args.db = list(map(os.path.abspath, args.db))
    elif isinstance(args.db, str):
        args.db = os.path.abspath(args.db)
    if task == "download-taxonomy":
        download_taxonomy(args)
    elif task == "classify":
        if len(args.db) > 1:
            classify_multi_dbs(args)
        else:
            # Single database: unwrap the one-element list.
            args.db = args.db[0]
            classify(args)
    elif task == "download-library":
        download_genomic_library(args)
    elif task == "add-to-library":
        add_to_library(args)
    elif task == "inspect":
        inspect_db(args)
    elif task == "clean":
        clean_db(args)
    elif task == "build":
        # Protein defaults
        default_aa_minimizer_length = 12
        default_aa_kmer_length = 15
        default_aa_minimizer_spaces = 0
        # Nucleotide defaults
        default_nt_minimizer_length = 31
        default_nt_kmer_length = 35
        default_nt_minimizer_spaces = 7

        # A sub-block size of 0 means "derive it": split the block evenly
        # across the worker threads.
        if args.sub_block_size == 0:
            args.sub_block_size = math.ceil(args.block_size / args.threads)
        if not args.kmer_len:
            args.kmer_len = (
                default_aa_kmer_length
                if args.protein
                else default_nt_kmer_length
            )
        if not args.minimizer_len:
            args.minimizer_len = (
                default_aa_minimizer_length
                if args.protein
                else default_nt_minimizer_length
            )
        # NOTE(review): a falsy value (including an explicit 0) is treated as
        # unset here, so a nucleotide build given 0 still gets the default.
        if not args.minimizer_spaces:
            args.minimizer_spaces = (
                default_aa_minimizer_spaces
                if args.protein
                else default_nt_minimizer_spaces
            )
        # Log messages carry their own trailing newline.
        if args.minimizer_len > args.kmer_len:
            LOG.error(
                "Minimizer length ({}) must not be greater than kmer "
                "length {}\n".format(args.minimizer_len, args.kmer_len)
            )
            sys.exit(1)
        if args.load_factor <= 0 or args.load_factor > 1:
            LOG.error(
                "Load factor must be greater than 0 but no more than 1\n"
            )
            sys.exit(1)
        if args.minimizer_len <= 0 or args.minimizer_len > 31:
            LOG.error(
                "Minimizer length must be a positive integer "
                "and cannot exceed 31.\n"
            )
            sys.exit(1)
        if args.standard:
            build_standard_database(args)
        elif args.special:
            if args.special == "greengenes":
                build_16S_gg(args)
            elif args.special == "silva":
                build_16S_silva(args)
            elif args.special == "gtdb":
                if not args.gtdb_files:
                    # Adjacent literals are used instead of a backslash
                    # continuation inside the string, which would embed the
                    # source indentation in the message.
                    LOG.error(
                        "Please specify a list of files or pattern of "
                        "the files needed to build a GTDB database.\n"
                    )
                    sys.exit(1)
                build_gtdb_database(args)
            else:
                # build_16S_rdp(args)
                LOG.error("RDP database no longer supported.\n")
                sys.exit(1)
        else:
            if args.no_masking:
                LOG.warning(
                    "--no-masking only affects the `--standard` and "
                    "`--special` flags. Its effect will be ignored.\n"
                )
            build_kraken2_db(args)


# Script entry point: run the CLI and turn any unexpected exception into a
# logged traceback plus exit status 1.  Ctrl-C is swallowed silently.
if __name__ == "__main__":
    try:
        k2_main()
    except KeyboardInterrupt:
        pass
    except Exception:
        if LOG is None:
            # The failure happened before k2_main() created the logger, so
            # report straight to stderr instead of failing on ``None``.
            sys.stderr.write(traceback.format_exc())
        else:
            LOG.stop_thread()
            LOG.error(traceback.format_exc())
        sys.exit(1)
    # else:
    #     LOG.stop_thread()
