# -*- coding: utf-8 -*-
from __future__ import print_function, annotations
from __future__ import absolute_import

from typing import DefaultDict, Optional, Any, Union

from past.utils import old_div

import re
import sys
import os
import logging as lg
from collections import OrderedDict, defaultdict, Counter
# import gc
import multiprocessing
from functools import partial
import itertools
import warnings
import pickle
from time import perf_counter

import numpy as np
from numpy import longdouble as xdbl
import scipy
import pandas as pd
import pysam
from scipy.sparse.csgraph import connected_components
from scipy.sparse import dok_matrix, lil_matrix

from stellarscope import StellarscopeError, AlignmentValidationError
from . import OptionsBase

from .sparse_plus import csr_matrix_plus as csr_matrix
from .sparse_plus import row_identity_matrix as rowid
from .sparse_plus import bool_inv
from .sparse_plus import divide_extp
from .statistics import FragmentInfo, AlignInfo, FitInfo, PoolInfo, ReassignInfo, UMIInfo

from .colors import c2str, D2PAL, GPAL
from .helpers import str2int, region_iter, phred

from . import alignment
from .alignment import get_tag_alignments
from .alignment import CODES as ALNCODES
from stellarscope import annotation

__author__ = 'Matthew L. Bendall, Matthew Greenig'
__copyright__ = "Copyright (C) 2022 Matthew L. Bendall, Matthew Greenig"


def process_overlap_frag(pairs, overlap_feats):
    """Pick the best alignment of one fragment for every overlapped feature.

    Alignments are grouped by the feature they overlap. Within each group the
    alignment with the largest ``alnscore + alnlen`` is tagged as primary
    (``ZT=PRI``) and the remaining ones as secondary (``ZT=SEC``); every
    alignment also receives its feature in the ``ZF`` tag. Finally all
    alignments receive a ``ZB`` tag holding the comma separated feature(s)
    whose best alignment has the highest alignment score.

    Parameters
    ----------
    pairs
        Alignments of a single fragment. Each must provide ``query_name``,
        ``alnscore``, ``alnlen`` and ``set_tag(tag, value)``.
    overlap_feats
        Feature overlapped by each alignment, parallel to `pairs`.

    Returns
    -------
    list[tuple]
        One ``(query_name, feature, alnscore, alnlen)`` tuple per feature,
        ordered by decreasing alignment score.
    """
    # All alignments must come from the same fragment
    assert all(pairs[0].query_name == p.query_name for p in pairs)

    # Group alignments by the feature they overlap
    grouped = defaultdict(list)
    for aln, fname in zip(pairs, overlap_feats):
        grouped[fname].append(aln)

    mappings = []
    for fname, members in grouped.items():
        # Rank by score + length; the sort is stable, so ties keep input order
        ranked = sorted(
            members,
            key=lambda a: a.alnscore + a.alnlen,
            reverse=True,
        )
        best, rest = ranked[0], ranked[1:]
        mappings.append((best.query_name, fname, best.alnscore, best.alnlen))

        # Feature tag (ZF) and primary/secondary flag (ZT)
        best.set_tag('ZF', fname)
        best.set_tag('ZT', 'PRI')
        for other in rest:
            other.set_tag('ZF', fname)
            other.set_tag('ZT', 'SEC')

    # Order the per-feature mappings by alignment score only
    mappings = sorted(mappings, key=lambda m: m[2], reverse=True)

    # Feature(s) tied for the highest score, comma separated
    top_score = mappings[0][2] if mappings else None
    best_feats = ','.join(m[1] for m in mappings if m[2] == top_score)

    # Best feature tag (ZB) goes on every alignment
    for aln in pairs:
        aln.set_tag('ZB', best_feats)

    return mappings


class TelescopeLikelihood(object):
    """EM model that resolves fragments aligned to multiple transcripts.

    The model is built from an N x K score matrix (N fragments, K
    transcripts). `em` alternates `estep` and `mstep` until the change between
    iterations drops below ``opts.em_epsilon`` or ``opts.max_iter`` iterations
    have run. `reassign` then turns the fitted assignment weights (`z`), or
    the initial scores, into final fragment-to-transcript assignments.

    Sparse matrices are instances of the project's ``csr_matrix_plus``
    (imported here as ``csr_matrix``), which supplies the ``scale``, ``norm``,
    ``binmax`` and ``choose_random`` helpers used below.
    """
    # Reassignment modes accepted by `reassign`
    REASSIGN_MODES = [
        'best_exclude',
        'best_conf',
        'best_random',
        'best_average',
        'initial_unique',
        'initial_random',
        'total_hits'
    ]

    def __init__(self, score_matrix, opts):
        """Set up matrices, initial parameters and precomputed constants.

        Parameters
        ----------
        score_matrix
            Fragment x transcript alignment scores; anything accepted by the
            ``csr_matrix`` constructor.
        opts
            Run options. This method reads ``em_epsilon``, ``max_iter``,
            ``pi_prior`` and ``theta_prior``; `em` reads ``use_likelihood``
            and `reassign` reads ``rng``.
        """
        lg.debug('CALL: TelescopeLikelihood.__init__()')

        # Store program options
        self.opts = opts
        self.epsilon = opts.em_epsilon
        self.max_iter = opts.max_iter

        # Raw scores. Explicit zeros are removed so that the number of stored
        # entries in a row equals the number of alignments for that fragment.
        self.raw_scores = csr_matrix(score_matrix)
        self.raw_scores.eliminate_zeros()
        self.max_score = self.raw_scores.max()

        # N fragments x K transcripts
        self.N, self.K = self.raw_scores.shape

        # Q[i,] is the set of mapping qualities for fragment i, where Q[i,j]
        # represents the evidence for fragment i being generated by
        # transcript j. The evidence is an alignment score, which is greater
        # when there are more matches and is penalized for mismatches. The
        # raw alignment score is scaled by the maximum alignment score,
        # multiplied by a scale factor, and passed through expm1.
        self.scale_factor = 100.
        self.Q = self.raw_scores.scale().multiply(self.scale_factor).expm1()

        # Indicator variables (N x 1) for ambiguous (Y_amb), unique (Y_uni),
        # and unassigned (Y_0) fragments. In the mathematical model the
        # ambiguity indicator is Y[i], where Y[i]=1 if fragment i is aligned
        # to multiple transcripts and Y[i]=0 otherwise.
        _nnz_byrow = np.asmatrix(self.raw_scores.getnnz(1)).T
        self.Y_amb = csr_matrix(_nnz_byrow > 1, dtype=np.ubyte)
        self.Y_uni = csr_matrix(_nnz_byrow == 1, dtype=np.ubyte)
        self.Y_0 = csr_matrix(_nnz_byrow == 0, dtype=np.ubyte)

        # Pre-computed mapping quality matrices restricted to ambiguous and
        # unique fragments
        self._ambQ = self.Q.multiply(self.Y_amb)
        self._uniQ = self.Q.multiply(self.Y_uni)

        # z[i,] is the partial assignment weights for fragment i, where
        # z[i,j] is the expected value for fragment i originating from
        # transcript j. It is set by `em`.
        self.z = None

        # pi[j] is the proportion of fragments that originate from
        # transcript j. Initial value assumes that all transcripts contribute
        # equal proportions of fragments.
        self.pi = csr_matrix(np.repeat(1. / self.K, self.K))
        self.pi_init = None

        # theta[j] is the proportion of non-unique fragments that need to be
        # reassigned to transcript j. Initial value assumes that all
        # transcripts are reassigned an equal proportion of fragments.
        self.theta = np.repeat(1. / self.K, self.K)
        self.theta_init = None

        # Log-likelihood score
        self.lnl = float('inf')

        # Prior values
        self.pi_prior = opts.pi_prior
        self.theta_prior = opts.theta_prior

        # Precalculated values
        self._weights = self.Q.max(1)  # Weight assigned to each fragment
        self._total_wt = self._weights.sum()  # Total weight
        self._ambig_wt = self._weights.multiply(self.Y_amb).sum()
        self._unique_wt = self._weights.multiply(self.Y_uni).sum()

        # Weighted prior values
        self._pi_prior_wt = self.pi_prior * self._weights.max()
        self._theta_prior_wt = self.theta_prior * self._weights.max()
        self._pisum0 = self._uniQ.sum(0)

        # Denominators of the M-step estimates; constant across iterations
        self._theta_denom = self._ambig_wt + self._theta_prior_wt * self.K
        self._pi_denom = self._total_wt + self._pi_prior_wt * self.K

        # Information about the fit
        self.fitinfo = FitInfo(self)

        lg.debug('EXIT: TelescopeLikelihood.__init__()')
        return

    def estep(self, pi, theta):
        """Calculate the expected values of z.

            E(z[i,j]) = ( pi[j] * theta[j]**Y[i] * Q[i,j] ) /
                        sum_k( pi[k] * theta[k]**Y[i] * Q[i,k] )

        Ambiguous rows are weighted by ``pi * theta`` and unique rows by
        ``pi`` only; the sum is then passed through ``norm(1)`` (presumably
        row normalization, so that each row sums to 1).

        Parameters
        ----------
        pi
            Sparse 1 x K matrix of transcript proportions.
        theta
            Array of K reassignment proportions.

        Returns
        -------
        csr_matrix
            N x K matrix of assignment weights. Double precision is tried
            first; on ``FloatingPointError`` the computation is repeated in
            extended precision.
        """
        lg.debug('CALL: TelescopeLikelihood.estep()')

        def _estep_dp():
            _amb = self._ambQ.multiply(pi).multiply(theta)
            _uni = self._uniQ.multiply(pi)
            return (_amb + _uni).norm(1)

        def _estep_xp():
            lg.debug('estep: using extended precision')
            # xdbl is np.longdouble, which exists on every platform;
            # np.float128 is only defined where long double is 128 bits wide.
            _xpi = pi.astype(xdbl)
            _xtheta = theta.astype(xdbl)
            _xambQ = self._ambQ.astype(xdbl)
            _xuniQ = self._uniQ.astype(xdbl)

            _amb = _xambQ.multiply(_xpi).multiply(_xtheta)
            _uni = _xuniQ.multiply(_xpi)
            return (_amb + _uni).norm(1)

        try:
            return _estep_dp()
        except FloatingPointError:
            return _estep_xp()

    def mstep(self, z):
        """Calculate the maximum a posteriori (MAP) estimates for pi and theta.

        Parameters
        ----------
        z
            N x K matrix of assignment weights from `estep`.

        Returns
        -------
        tuple
            ``(pi_hat, theta_hat)`` where ``pi_hat`` is a sparse 1 x K matrix
            and ``theta_hat`` is a flat array of length K.

        Raises
        ------
        StellarscopeError
            If the double precision computation fails with anything other
            than an underflow in the division.
        """
        lg.debug('CALL: TelescopeLikelihood.mstep()')

        def _mstep_dp():
            # The expected values of z weighted by mapping score
            _weighted = z.multiply(self._weights)

            # Estimate theta_hat
            _thetasum = _weighted.multiply(self.Y_amb).sum(0)
            _theta_hat = (_thetasum + self._theta_prior_wt) / self._theta_denom

            # Estimate pi_hat
            _pisum = self._pisum0 + _thetasum
            _pi_hat = (_pisum + self._pi_prior_wt) / self._pi_denom

            return csr_matrix(_pi_hat), _theta_hat.A1

        def _mstep_xp():
            lg.debug('mstep: using extended precision')
            # The expected values of z weighted by mapping score
            _weighted = z.multiply(self._weights)

            # Estimate theta_hat
            _thetasum = _weighted.multiply(self.Y_amb).sum(0)
            _theta_hat = (_thetasum + self._theta_prior_wt) / self._theta_denom

            # Estimate pi_hat; only this division is done in extended
            # precision
            _pisum = self._pisum0 + _thetasum
            _num = xdbl(_pisum) + xdbl(self._pi_prior_wt)
            _denom = xdbl(self._pi_denom)
            _pi_hat = divide_extp(_num, _denom)
            return csr_matrix(_pi_hat), _theta_hat.A1

        try:
            return _mstep_dp()
        except (FloatingPointError, RuntimeWarning) as e:
            # Only an underflow in the division is recoverable by retrying in
            # extended precision; anything else is reported to the caller.
            if 'underflow encountered in divide' not in e.args:
                raise StellarscopeError(e.args) from e
            return _mstep_xp()

    def calculate_lnl(self, z, pi, theta):
        """Calculate the log-likelihood at the given estimates.

        Computes ``sum(z * log(inner))`` where ``inner`` is the ambiguous
        mapping qualities weighted by ``pi * theta`` plus the unique mapping
        qualities weighted by ``pi``. Every stage is first attempted in
        double precision and repeated in extended precision if it raises
        ``FloatingPointError``.

        Parameters
        ----------
        z
            N x K matrix of assignment weights.
        pi
            Sparse 1 x K matrix of transcript proportions.
        theta
            Array of K reassignment proportions.

        Returns
        -------
        float
            The log-likelihood.
        """
        lg.debug('CALL: TelescopeLikelihood.calculate_lnl()')
        try:
            _pitheta = pi.multiply(theta)
        except FloatingPointError:
            lg.debug('using extended precision (pi*theta)')
            pi = pi.astype(xdbl)
            theta = theta.astype(xdbl)
            _pitheta = pi.multiply(theta)

        try:
            _amb = self._ambQ.multiply(_pitheta)
            _uni = self._uniQ.multiply(pi)
        except FloatingPointError:
            lg.debug('using extended precision (_amb and _uni)')
            # Promote the operands before retrying; repeating the identical
            # double precision product would only raise again.
            _amb = self._ambQ.astype(xdbl).multiply(_pitheta.astype(xdbl))
            _uni = self._uniQ.astype(xdbl).multiply(pi.astype(xdbl))

        try:
            _inner = csr_matrix(_amb + _uni)
            _log_inner = csr_matrix(_inner)
            _log_inner.data = np.log(_inner.data)
        except FloatingPointError:
            lg.debug('using extended precision (_inner)')
            _inner = csr_matrix(_amb + _uni, dtype=xdbl)
            _log_inner = csr_matrix(_inner)
            _log_inner.data = np.log(_inner.data)
        try:
            ret = z.multiply(_log_inner).sum()
        except FloatingPointError:
            lg.debug('using extended precision (z)')
            ret = z.astype(xdbl).multiply(_log_inner).sum()

        lg.debug('EXIT: TelescopeLikelihood.calculate_lnl()')
        return ret

    def em(self, loglev=lg.DEBUG):
        """Run EM until convergence or until ``max_iter`` is reached.

        Updates ``z``, ``pi``, ``theta`` and ``lnl`` in place and records the
        progress in ``fitinfo`` (``converged``, ``reached_max``,
        ``iterations``, ``final_lnl``). Convergence is judged on the change
        in log-likelihood when ``opts.use_likelihood`` is set, otherwise on
        the summed absolute change in ``pi``.

        Parameters
        ----------
        loglev
            Logging level used for progress messages.

        Raises
        ------
        StellarscopeError
            If the loop ends without being converged or terminated.
        """
        inum = 0  # Iteration number
        _useL = self.opts.use_likelihood

        msgD = 'Iteration {:d}, diff={:.5g}'
        msgL = 'Iteration {:d}, lnl= {:.5e}, diff={:.5g}'
        while not (self.fitinfo.converged or self.fitinfo.reached_max):
            is_time = perf_counter()
            _z = self.estep(self.pi, self.theta)
            _pi, _theta = self.mstep(_z)
            inum += 1
            if inum == 1:
                # Keep the estimates from the first iteration
                self.pi_init, self.theta_init = _pi, _theta

            # Calculate absolute difference between estimates
            diff_est = abs(_pi - self.pi).sum()

            if _useL:
                # Calculate likelihood
                _lnl = self.calculate_lnl(_z, _pi, _theta)
                diff_lnl = abs(_lnl - self.lnl)
                lg.log(loglev, msgL.format(inum, _lnl, diff_est))
                self.fitinfo.converged = diff_lnl < self.epsilon
                self.lnl = _lnl
            else:
                lg.log(loglev, msgD.format(inum, diff_est))
                self.fitinfo.converged = diff_est < self.epsilon

            itime = perf_counter() - is_time
            lg.log(loglev, "time: {}".format(itime))
            self.fitinfo.reached_max = inum >= self.max_iter
            self.z = _z
            self.pi, self.theta = _pi, _theta

            self.fitinfo.iterations.append((inum, itime, diff_est))

        if self.fitinfo.converged:
            lg.log(loglev, f'EM converged after {inum:d} iterations.')
        elif self.fitinfo.reached_max:
            lg.log(loglev, f'EM terminated after {inum:d} iterations.')
        else:
            raise StellarscopeError('Not converged or terminated.')

        if not _useL:
            # The likelihood was not tracked during fitting; compute it once
            self.lnl = self.calculate_lnl(self.z, self.pi, self.theta)
        self.fitinfo.final_lnl = self.lnl
        lg.log(loglev, f'Final log-likelihood: {self.lnl:f}.')

        return

    def reassign(self,
                 mode: str,
                 thresh: Optional[float] = None
                 ) -> tuple[csr_matrix, ReassignInfo]:
        """Reassign fragments to expected transcripts

        Model fitting using EM finds the expected fragment assignment weights
        - posterior probabilities (PP) - at the MAP estimates of pi and theta.
        This function reassigns each fragment so that the most likely
        originating transcript has a weight of 1. In practice, not all
        fragments have exactly one best hit, even after fitting. The "mode"
        argument defines how we deal with ties.

        In the first four modes, the alignment with the highest PP is selected.
        If multiple alignments have the same highest PP, ties are broken by:
            "best_exclude"   - the read is excluded and does not contribute to
                               the final count.
            "best_conf"      - only alignments with PP exceeding a user-defined
                               threshold are included. We require the threshold
                               to be greater than 0.5, so ties do not occur.
            "best_random"    - one of the best alignments is randomly selected.
            "best_average"   - final count is evenly divided among the best
                               alignments. This results in fractional weights
                               and the final count is not a true count.

        The final three modes do not perform reassignment or model fitting
        but are included for comparison:
            "initial_unique" - only reads that align uniquely to a single
                               locus are included, multimappers are discarded.
                               EM model optimization is not considered, similar
                               to the "unique counts" approach.
            "initial_random" - alignment is randomly chosen from among the
                               set of best scoring alignments. EM model
                               optimization is not considered, similar to the
                               "best counts" approach.
            "total_hits"     - every alignment has a weight of 1. Counts the
                               number of initial alignments to each locus.

        Parameters
        ----------
        mode
            One of `REASSIGN_MODES`.
        thresh
            Confidence threshold; required (and must be > 0.5) for
            "best_conf", ignored by the other modes.

        Returns
        -------
        tuple[csr_matrix, ReassignInfo]
            The reassignment matrix, where m[i,j] is nonzero iff read i is
            reassigned to transcript j (fractional for "best_average"), and
            a `ReassignInfo` with counts of assigned, ambiguous and unaligned
            reads.

        Raises
        ------
        ValueError
            If `mode` is not one of `REASSIGN_MODES`.
        StellarscopeError
            If "best_conf" is requested without a threshold above 0.5.
        """
        if mode not in self.REASSIGN_MODES:
            msg = f'Argument "mode" should be one of {self.REASSIGN_MODES}'
            raise ValueError(msg)

        reassign_mat = None
        rinfo = ReassignInfo(mode)

        if mode in ['best_exclude', 'best_random', 'best_average']:
            # Identify best PP(s), then count number equal to best
            bestmat = self.z.binmax(1)
            nbest = bestmat.sum(1)

            # update ReassignInfo
            rinfo.assigned = np.count_nonzero(nbest.A1 == 1)
            rinfo.ambiguous = np.count_nonzero(nbest.A1 > 1)
            rinfo.unaligned = np.count_nonzero(nbest.A1 == 0)
            rinfo.ambigous_dist = Counter(nbest.A1)

            # Apply exclude, random, or average
            if mode == 'best_exclude':
                reassign_mat = bestmat.multiply(nbest == 1)
            elif mode == 'best_random':
                reassign_mat = bestmat.choose_random(1, self.opts.rng)
            elif mode == 'best_average':
                reassign_mat = bestmat.norm(1)
        elif mode == 'best_conf':
            # Zero out all values less than threshold. Since each row must
            # sum to 1, if threshold > 0.5 then each row will have at most 1
            # nonzero element.
            if thresh is None or thresh <= 0.5:
                raise StellarscopeError(
                    f"Confidence threshold ({thresh}) must be > 0.5"
                )

            confmat = csr_matrix(self.z > thresh, dtype=np.int8)
            rowmax = self.z.max(1)

            # update ReassignInfo
            rinfo.assigned = np.count_nonzero(rowmax.data > thresh)
            rinfo.ambiguous = np.count_nonzero(rowmax.data <= thresh)
            rinfo.unaligned = rowmax.shape[0] - rowmax.nnz

            reassign_mat = confmat
        elif mode == 'initial_unique':
            # Remove ambiguous rows and set nonzero values to 1
            unimap = (self.Q.multiply(self.Y_uni) > 0).astype(np.int8)

            # update ReassignInfo
            rinfo.assigned = unimap.nnz
            rinfo.ambiguous = self.Y_amb.nnz
            rinfo.unaligned = self.Y_0.nnz

            reassign_mat = unimap
        elif mode == 'initial_random':
            # Identify best scores in initial matrix then randomly choose one
            # per row
            bestraw = self.raw_scores.binmax(1)
            nbest_raw = bestraw.sum(1)

            # update ReassignInfo
            rinfo.assigned = np.count_nonzero(nbest_raw.A1 == 1)
            rinfo.ambiguous = np.count_nonzero(nbest_raw.A1 > 1)
            rinfo.unaligned = np.count_nonzero(nbest_raw.A1 == 0)
            rinfo.ambigous_dist = Counter(nbest_raw.A1)

            reassign_mat = bestraw.choose_random(1, self.opts.rng)
        elif mode == 'total_hits':
            # Return all nonzero elements in initial matrix
            mapped = csr_matrix(self.raw_scores > 0, dtype=np.int8)
            mapped_rowsum = mapped.sum(1)

            # update ReassignInfo
            rinfo.assigned = mapped.sum()
            rinfo.ambiguous = np.count_nonzero(mapped_rowsum.A1 > 1)
            rinfo.unaligned = np.count_nonzero(mapped_rowsum.A1 == 0)

            reassign_mat = mapped

        if reassign_mat is None:
            raise StellarscopeError('reassign_mat was not set.')

        return reassign_mat, rinfo

class Assigner:
    """Build the function that assigns an aligned fragment to a feature.

    The options supply the overlap mode, the overlap threshold, the stranded
    mode and the key returned for fragments that overlap no feature.
    """
    def __init__(self, annotation: 'annotation.BaseAnnotation',
                 opts: 'OptionsBase') -> None:
        """Store the annotation and the assignment-related options.

        Parameters
        ----------
        annotation
            Annotation object; must provide ``intersect_blocks``.
        opts
            Options providing ``no_feature_key``, ``overlap_mode``,
            ``overlap_threshold`` and ``stranded_mode``.
        """
        self.annotation = annotation
        self.no_feature_key = opts.no_feature_key
        self.overlap_mode = opts.overlap_mode
        self.overlap_threshold = opts.overlap_threshold
        self.stranded_mode = opts.stranded_mode
        return

    def assign_func(self):
        """Return the assignment function for the configured overlap mode.

        Returns
        -------
        callable
            Function taking an aligned pair and returning a feature name or
            ``no_feature_key``.

        Raises
        ------
        ValueError
            If ``overlap_mode`` is not one of "threshold",
            "intersection-strict" or "union".
        """
        def _assign_pair_threshold(pair):
            """Assign to the most overlapped feature if overlap is sufficient.

            The feature with the largest overlap is returned when that
            overlap exceeds ``alnlen * overlap_threshold``.
            """
            _ref = pair.ref_name
            _blocks = pair.refblocks
            if self.stranded_mode is None:
                f = self.annotation.intersect_blocks(_ref, _blocks)
            else:
                # Stranded lookups are keyed by (reference, strand)
                _strand = pair.fragstrand(self.stranded_mode)
                f = self.annotation.intersect_blocks((_ref, _strand), _blocks)

            if not f:
                return self.no_feature_key

            # Calculate the percentage of fragment mapped
            # NOTE(review): `f` presumably is a Counter of overlap per
            # feature, as it is queried with most_common() - confirm
            fname, overlap = f.most_common()[0]
            if overlap > pair.alnlen * self.overlap_threshold:
                return fname
            return self.no_feature_key

        def _assign_pair_intersection_strict(pair):
            # NOTE(review): unimplemented placeholder; returns None for
            # every pair
            pass

        def _assign_pair_union(pair):
            # NOTE(review): unimplemented placeholder; returns None for
            # every pair
            pass

        # Return function depending on overlap mode
        if self.overlap_mode == 'threshold':
            return _assign_pair_threshold
        elif self.overlap_mode == 'intersection-strict':
            return _assign_pair_intersection_strict
        elif self.overlap_mode == 'union':
            return _assign_pair_union
        else:
            # Raise rather than assert: asserts are stripped under `-O`
            raise ValueError(
                f'Unknown overlap mode: {self.overlap_mode!r}'
            )


""" 
Stellarscope model
"""


def select_umi_representatives(
        umi_feat_scores: list[tuple[str, dict[int, int]]],
        best_score: bool = False,
        weighted: bool = False,
) -> tuple[list[int], list[bool]]:
    """ Select best representative(s) among reads with same BC+UMI

    Reads that share at least one feature are connected in a graph. Each
    connected component is a duplicate set, and exactly one read per
    component (the best according to `read_stats`) is kept.

    Parameters
    ----------
    umi_feat_scores
        One ``(read label, {feature index: alignment score})`` tuple per read
        sharing the barcode and UMI.
    best_score
        If True, only the top scoring feature(s) of each read are considered
        when looking for shared features.
    weighted
        If True, graph edges are weighted by the summed scores of the shared
        features instead of by the number of shared features.

    Returns
    -------
    tuple[list[int], list[bool]]
        The component number of each read, and a flag per read that is True
        when the read is a duplicate to be excluded and False when it is the
        representative of its component. Both lists are empty for empty
        input.
    """
    def read_stats(vec: dict[int, int]) -> tuple[float, int, int, int]:
        """ Stats used to select best representative in duplicate set

        Reads connected in a duplicate set are ranked according to these
        statistics in descending order; the first read is selected.

        Parameters
        ----------
        vec : dict[int, int]
            Dictionary mapping feature indexes to alignment scores

        Returns
        -------
        tuple[float, int, int, int]
            Tuple with summary statistics

        """
        _scores = list(vec.values())
        return (
            max(_scores) / sum(_scores),  # maximum normalized score
            -len(_scores),               # number of features (ambiguity)
            max(_scores),                # maximum score
            sum(_scores),                # total score
        )

    def subset_topscore(d):
        """ Return copy of dictionary with only max values """
        # The maximum is computed once, not once per item
        _top = max(d.values(), default=None)
        return {ft: sc for ft, sc in d.items() if sc == _top}

    # Nothing to deduplicate
    if not umi_feat_scores:
        return [], []

    # Only the score vectors are needed; read labels are not used
    _score_vecs = [vec for _label, vec in umi_feat_scores]

    # Subset each vector for top scores (optional)
    if best_score:
        _subset_vecs = [subset_topscore(vec) for vec in _score_vecs]
    else:
        _subset_vecs = _score_vecs

    n = len(_subset_vecs)

    # Skip building adjacency matrix if all reads share at least one
    # feature: they then form a single component. Checking the intersection
    # of all feature sets should be faster than building the adjacency
    # matrix and finding connected components.
    if set.intersection(*map(set, _subset_vecs)):
        # The row index makes every rank tuple unique, so max() picks the
        # same read as sorting in descending order and taking the first.
        rep_index = max(
            (*read_stats(v), r) for r, v in enumerate(_subset_vecs)
        )[-1]
        return [0] * n, [r != rep_index for r in range(n)]

    # Calculate adjacency matrix
    graph = dok_matrix((n, n), dtype=np.uint64)
    for i, j in itertools.combinations(range(n), 2):
        _inter = _subset_vecs[i].keys() & _subset_vecs[j].keys()
        if not _inter:
            # No shared feature: the edge weight is zero either way
            continue
        if weighted:
            graph[i, j] = sum(_subset_vecs[i][x] for x in _inter)
            graph[j, i] = sum(_subset_vecs[j][x] for x in _inter)
        else:
            graph[i, j] = len(_inter)
            graph[j, i] = len(_inter)
    ncomp, comps = connected_components(graph, directed=False)
    comp_labels = comps.tolist()

    # Choose best read for each component: track, in a single pass over the
    # reads, the highest ranking (stats..., row) tuple seen per component.
    best_by_comp = {}
    for row, c_i in enumerate(comp_labels):
        _rank = (*read_stats(_subset_vecs[row]), row)
        if c_i not in best_by_comp or _rank > best_by_comp[c_i]:
            best_by_comp[c_i] = _rank

    # Everything except the representative of each component is excluded
    is_excluded = [True] * n
    for _rank in best_by_comp.values():
        is_excluded[_rank[-1]] = False

    return comp_labels, is_excluded


def _em_wrapper(tl: TelescopeLikelihood):
    """Fit one model with EM and hand back its results.

    Module-level function so that it can be used both directly and as the
    worker function of a multiprocessing pool.

    Parameters
    ----------
    tl
        Model to fit; its `em` method is run with default arguments.

    Returns
    -------
    tuple
        The fitted assignment matrix (``tl.z``) and the fit summary
        (``tl.fitinfo``).
    """
    tl.em()
    posterior, info = tl.z, tl.fitinfo
    return posterior, info

def _fit_pooling_model(
        st: Stellarscope,
        opts: 'StellarscopeAssignOptions',
        processes: int = 1,
        progress: int = 100

) -> tuple[TelescopeLikelihood, PoolInfo]:
    """ Fit model using different pooling modes

    In "pseudobulk" mode a single model is fitted to the full score matrix.
    In "individual" and "celltype" modes one model is fitted per cell barcode
    or per cell type, each on the rows belonging to that group, and the
    fitted posterior matrices are summed into the returned model.

    Parameters
    ----------
    st : Stellarscope
        Stellarscope object
    opts : StellarscopeAssignOptions
        Stellarscope run options
    processes : int
        Number of processes to run. NOTE: this argument is currently
        overridden by `opts.nproc`; it is kept for interface compatibility.
    progress : int
        Frequency of progress output. Set to 0 to disable. Any nonzero value
        is replaced by 100 when more than 100 models are fitted, otherwise
        by 10.

    Returns
    -------
    tuple[TelescopeLikelihood, PoolInfo]
        TelescopeLikelihood object containing the fitted posterior probability
        matrix (`TelescopeLikelihood.z`), and a PoolInfo object holding the
        fit information of every model.

    Raises
    ------
    StellarscopeError
        If the UMI corrected matrix is required but missing or of the wrong
        shape, or if the pooling mode is unknown.
    """
    def _submodel(_rows):
        """ Build a model restricted to the given rows of the full matrix """
        # The subset keeps the shape of the full matrix so that fitted z
        # matrices can be summed (rowid presumably builds a row selector)
        _I = rowid(_rows, _fullmat.shape[0])
        return TelescopeLikelihood(_fullmat.multiply(_I), opts)

    def _tl_generator_celltype():
        """ Generate score matrices subset by celltype

        Yields
        -------
        TelescopeLikelihood
            A TelescopeLikelihood object with subset score_matrix

        """
        for _ctype in st.celltypes:
            # Get read indexes for each barcode in cell type
            _rows = []
            for _bcode in st.ctype_bcode_map[_ctype]:
                if _bcode in st.bcode_ridx_map:
                    _rows.extend(st.bcode_ridx_map[_bcode])

            # No reads for this celltype
            if not _rows:
                lg.info(f'        ...not fitting "{_ctype}" -> no reads found')
                continue

            yield _submodel(_rows)

    def _tl_generator_individual():
        """ Generate score matrices by cell barcode

        Yields
        -------
        TelescopeLikelihood
            A TelescopeLikelihood object with subset score matrix

        """
        for _bcode, _rowset in st.bcode_ridx_map.items():
            _rows = sorted(_rowset)

            # No reads for this cell barcode
            if not _rows:
                lg.info(f'        ...not fitting "{_bcode}" -> no reads found')
                continue

            yield _submodel(_rows)

    # Fit pooling model
    poolinfo = PoolInfo(opts.pooling_mode)

    # Select UMI corrected or raw score matrix
    if opts.ignore_umi:
        if st.corrected is not None:
            lg.warning('Ignoring UMI corrected matrix')
        _fullmat = st.raw_scores
    else:
        if st.corrected is None:
            raise StellarscopeError("UMI corrected matrix not found")
        if st.corrected.shape != st.shape:
            raise StellarscopeError("UMI corrected matrix shape mismatch")
        _fullmat = st.corrected

    ret_model = TelescopeLikelihood(_fullmat, opts)

    if opts.pooling_mode == 'pseudobulk':
        lg.info('Models to fit: 1')
        poolinfo.nmodels = 1
        ret_model.em()
        poolinfo.models_info['pseudobulk'] = ret_model.fitinfo
        return ret_model, poolinfo

    # Initialize z for return model
    ret_model.z = csr_matrix(_fullmat.shape, dtype=np.float64)
    if opts.pooling_mode == 'individual':
        poolinfo.nmodels = len(st.bcode_ridx_map)
        tl_generator = _tl_generator_individual()
    elif opts.pooling_mode == 'celltype':
        poolinfo.nmodels = len(st.celltypes)
        tl_generator = _tl_generator_celltype()
    else:
        # Fail clearly instead of hitting an unbound `tl_generator` below
        raise StellarscopeError(
            f'Unknown pooling mode "{opts.pooling_mode}"'
        )

    lg.info(f'Models to fit: {poolinfo.nmodels}')
    if progress:
        # Same cutoff as log10(nmodels) > 2, but also valid for 0 models
        progress = 100 if poolinfo.nmodels > 100 else 10

    processes = opts.nproc
    if processes == 1:
        for i, tl in enumerate(tl_generator):
            _z, _fitinfo = _em_wrapper(tl)
            if progress and (i + 1) % progress == 0:
                lg.info(f'        ...{i + 1} models fitted')
            ret_model.z += _z
            poolinfo.models_info[i] = _fitinfo
    else:
        with multiprocessing.Pool(processes) as pool:
            lg.info(f'  (Using pool of {processes} workers)')
            # Models are sent to the workers in chunks of 10; imap yields
            # the results in submission order
            imap_it = pool.imap(_em_wrapper, tl_generator, 10)
            for i, (_z, _fitinfo) in enumerate(imap_it):
                if progress and (i + 1) % progress == 0:
                    lg.info(f'        ...{i + 1} models fitted')
                ret_model.z += _z
                poolinfo.models_info[i] = _fitinfo

    ret_model.lnl = poolinfo.total_lnl
    return ret_model, poolinfo


class Stellarscope(object):
    """

    """
    opts: "StellarscopeAssignOptions"
    run_info: dict
    feature_length: Optional[dict]
    read_index: dict[str, int]
    feat_index: dict[str, int]
    shape: Optional[tuple[int, int]]
    raw_scores: Optional[csr_matrix]
    other_bam: Union[str, bytes, os.PathLike]
    tmp_bam: Union[str, bytes, os.PathLike]
    has_index: bool
    ref_names: list[str]
    ref_lengths: list[int]
    read_bcode_map: dict[str, str]
    read_umi_map: dict[str, str]
    bcode_ridx_map: DefaultDict[str, set[int]]
    bcode_umi_map: DefaultDict[str, set[int]]
    filtlist: dict[str, int]
    bcode_ctype_map: dict[str, str]
    ctype_bcode_map: DefaultDict[set[str]]
    celltypes: list[str]

    corrected: Optional[csr_matrix]
    umi_dups: Optional[csr_matrix]

    reassignments: dict[str, csr_matrix]

    def __init__(self, opts: "StellarscopeAssignOptions") -> None:
        """ Create an empty Stellarscope object and inspect the input SAM/BAM

        All containers are created empty; they are filled in later by the
        loading methods of this class. The alignment file named by
        `opts.samfile` is opened once to record whether it has an index and
        to capture its reference names and lengths.

        Parameters
        ----------
        opts
            Command line options. This method reads `opts.samfile` and
            `opts.version` and calls `opts.outfile_path()`.
        """
        self.opts = opts               # Command line options
        self.run_info = dict()         # Information about the run
        self.feature_length = None     # Lengths of features
        self.read_index = {}           # {"fragment name": row_index}
        self.feat_index = {}           # {"feature_name": column_index}
        self.shape = None              # Fragments x Features
        self.raw_scores = None         # Initial alignment scores

        # BAM with non overlapping fragments (or unmapped)
        self.other_bam = opts.outfile_path('other.bam')
        # BAM with overlapping fragments
        self.tmp_bam = opts.outfile_path('tmp_tele.bam')

        # Set the version
        self.run_info['version'] = self.opts.version

        # Collect information about the SAM/BAM input file.
        # pysam verbosity is set to 0 while the file is opened and the
        # previous level is restored once the file is open.
        _pysam_verbosity = pysam.set_verbosity(0)
        with pysam.AlignmentFile(self.opts.samfile, check_sq=False) as sf:
            pysam.set_verbosity(_pysam_verbosity)
            self.has_index = sf.has_index()
            if self.has_index:
                # mapped/unmapped totals are only recorded for indexed input
                self.run_info['nmap_idx'] = sf.mapped
                self.run_info['nunmap_idx'] = sf.unmapped

            self.ref_names = sf.references
            self.ref_lengths = sf.lengths

        # read, BC, UMI tracking
        self.read_bcode_map = {}                # {read_id (str): barcode (str)}
        self.read_umi_map = {}                  # {read_id (str): umi (str)}
        self.bcode_ridx_map = defaultdict(set)  # {barcode (str): read_indexes (:obj:`set` of int)}
        self.bcode_umi_map = defaultdict(list)  # {barcode (str): umis (:obj:`list` of str)}

        self.filtlist = {}                     # {barcode (str): index (int)}

        ''' Instance variables for pooling mode = "celltype" '''
        self.bcode_ctype_map = {}               # {barcode (str): celltype (str)}
        self.ctype_bcode_map = defaultdict(set) # {celltype (str): barcodes (:obj:`set` of str)}
        self.celltypes = []                     # celltype names

        self.corrected = None          # score matrix after UMI correction
        self.umi_dups = None           # matrix marking rows excluded as UMI duplicates

        self.reassignments = {}        # {reassignment mode: reassigned matrix}

        return

    def load_filtlist(self):
        if self.filtlist:
            lg.warn(f'Filter BC list exists ({len(self.filtlist)} barcodes).')
            lg.warn(f'Provided list ({self.opts.filtered_bc}) not loaded')
            return
        with open(self.opts.filtered_bc, 'r') as fh:
            _bc_gen = (l.split('\t')[0].strip() for l in fh)
            _nskip = self.opts.filtered_bc_skip
            _headers = [next(_bc_gen) for _ in range(_nskip)]
            for _bc in _bc_gen:
                _ = self.filtlist.setdefault(_bc, len(self.filtlist))
        return


    def load_celltype_file(self):
        """ Load celltype assignments into Stellarscope object


        Sets values for instance variables:
             `self.bcode_ctype_map`
             `self.ctype_bcode_map`
             `self.celltypes`
             `self.barcode_celltypes`

        Returns
        -------

        """
        #if self.bcode_ctype_map:
        #if self.ctype_bcode_map:
        with open(self.opts.celltype_tsv, 'r') as fh:
            _gen = (tuple(map(str.strip, l.split('\t')[:2])) for l in fh)
            # Check first line is valid barcode and not column header
            _bc, _ct = next(_gen)
            if re.match('^[ACGTacgt]+$', _bc):
                _ = self.bcode_ctype_map.setdefault(_bc, _ct)
                self.ctype_bcode_map[_ct].add(_bc)
            # Add the rest without checking
            for _bc, _ct in _gen:
                _ = self.bcode_ctype_map.setdefault(_bc, _ct)
                assert _ == _ct, f'Mismatch for {_bc}, "{_}" != "{_ct}"'
                self.ctype_bcode_map[_ct].add(_bc)

        self.celltypes = sorted(set(self.bcode_ctype_map.values()))
        # self.barcode_celltypes = pd.DataFrame({
        #     'barcode': self.bcode_ctype_map.keys(),
        #     'celltype': self.bcode_ctype_map.values()
        # })
        return


    def load_alignment(self, annotation: annotation.BaseAnnotation) -> AlignInfo:
        """ Load the alignment file and build the raw score matrix

        Registers the annotated features in `self.feat_index` (the
        "no feature" key always gets column 0), loads the alignments with
        `self._load_sequential`, and converts the resulting mappings into
        `self.raw_scores`. Also sets `self.shape`, `self.feature_length` and
        `self.run_info['annotated_features']`.

        Parameters
        ----------
        annotation
            Annotation providing `loci` and `feature_length()`.

        Returns
        -------
        AlignInfo
            The alignment statistics returned by `self._load_sequential`.
        """

        def mapping_to_matrix(mappings):
            """ Fill `self.raw_scores` from the list of mapping tuples """
            n_frag = len(self.read_index)
            n_feat = len(self.feat_index)
            self.shape = (n_frag, n_feat)

            def rescale(score):
                # shift scores so that the minimum observed score becomes 1
                return (score - alninfo.minAS + 1)

            scores = scipy.sparse.dok_matrix(self.shape, dtype=np.uint16)

            for _code, qname, fname, aln_score, aln_len in mappings:
                cell = (self.read_index[qname], self.feat_index[fname])
                # keep the best (rescaled score + length) per fragment/feature
                scores[cell] = max(scores[cell], rescale(aln_score) + aln_len)

            # Every fragment row must have a nonzero value in at least one
            # real feature column (column 0 is the "no feature" column).
            assert self.feat_index[self.opts.no_feature_key] == 0
            feature_cols = scipy.sparse.csc_matrix(scores)[:, 1:]
            rows_with_feature = feature_cols.sum(1).nonzero()[0]
            assert len(rows_with_feature) == self.shape[0]

            # Store the raw score matrix
            self.raw_scores = csr_matrix(scores)
            return

        # Add feature information to object
        self.run_info['annotated_features'] = len(annotation.loci)
        self.feature_length = annotation.feature_length().copy()

        # Initialize feature index; loci are numbered in iteration order
        self.feat_index = {self.opts.no_feature_key: 0}
        for locus in annotation.loci.keys():
            self.feat_index.setdefault(locus, len(self.feat_index))

        # Load alignment sequentially using 1 CPU
        maps, alninfo = self._load_sequential(annotation)

        # Convert alignment to sparse matrix
        mapping_to_matrix(maps)
        return alninfo

    def _load_sequential(self, annotation):
        """ Load queryname sorted BAM sequentially

        Iterates over the fragments of `self.opts.samfile`. Fragments that
        are unmapped, lack a barcode, have a barcode outside the filter list
        (when one is loaded), lack a UMI (unless `self.opts.ignore_umi`), or
        overlap no annotated feature are skipped. For every remaining
        fragment the read/barcode/UMI indexes are updated and the best
        alignment per feature is recorded.

        When `self.opts.updated_sam` is set, skipped fragments are written
        to `self.other_bam` and retained fragments to `self.tmp_bam`.

        Args:
            annotation: annotation passed on to `Assigner`.

        Returns:
            tuple: `(mappings, alninfo)` where `mappings` is a list of
            `(code, query_name, feature, alnscore, alnlen)` tuples and
            `alninfo` is the `AlignInfo` accumulated over all fragments.

        Raises:
            AlignmentValidationError: if a query name is seen with two
                different barcodes or (unless UMIs are ignored) two
                different UMIs.
        """

        def skip_fragment(reason: Optional[str] = None):
            """ Record the current fragment as skipped, with optional reason """
            if reason:
                _finfo.error = reason
            alninfo.update(_finfo)
            if self.opts.updated_sam:
                for p in alns:
                    p.write(bam_u)

        def process_fragment(
                alns: list['AlignedPair'],
                overlap_feats: list[str]
        ):
            """ Find the best alignment for each locus

            Parameters
            ----------
            alns
                Mapped alignments of one fragment.
            overlap_feats
                Feature assigned to each alignment, in the same order.

            Returns
            -------
            The result of `process_overlap_frag(alns, overlap_feats)`.

            """
            return process_overlap_frag(alns, overlap_feats)

        def store_read_info(query_name, barcode, umi):
            """ Adds read query name, barcode and UMI to Stellarscope indexes

            Parameters
            ----------
            query_name
            barcode
            umi

            Returns
            -------
            None

            """
            # Add read ID, barcode, and UMI to internal data
            row = self.read_index.setdefault(query_name, len(self.read_index))

            _prev = self.read_bcode_map.setdefault(query_name, barcode)
            if _prev != barcode:
                raise AlignmentValidationError(
                    f'Barcode error ({query_name}): {_prev} != {barcode}'
                )

            self.bcode_ridx_map[barcode].add(row)

            if not self.opts.ignore_umi:
                _prev = self.read_umi_map.setdefault(query_name, umi)
                if _prev != umi:
                    raise AlignmentValidationError(
                        f'UMI error ({query_name}): {_prev} != {umi}'
                    )
                self.bcode_umi_map[barcode].append(umi)
            return

        ''' Load sequential '''
        _nfkey = self.opts.no_feature_key
        _mappings = []
        assign_func = Assigner(annotation, self.opts).assign_func()

        # Initialize variables for function
        alninfo = AlignInfo(self.opts.progress)

        # Output BAM handles; they stay None unless `updated_sam` is set
        bam_u = bam_t = None

        _pysam_verbosity = pysam.set_verbosity(0)
        try:
            with pysam.AlignmentFile(self.opts.samfile, check_sq=False) as sf:
                pysam.set_verbosity(_pysam_verbosity)

                # Create output temporary files
                if self.opts.updated_sam:
                    bam_u = pysam.AlignmentFile(self.other_bam, 'wb',
                                                template=sf)
                    bam_t = pysam.AlignmentFile(self.tmp_bam, 'wb',
                                                template=sf)

                # Iterate over fragments
                for ci, alns in alignment.fetch_fragments_seq(sf,
                                                              until_eof=True):
                    _finfo = FragmentInfo()

                    ''' Count code '''
                    _finfo.add_code(ALNCODES[ci][0])

                    ''' Check whether fragment is mapped '''
                    if not _finfo.mapped:
                        skip_fragment()
                        continue

                    ''' Get alignment barcode and UMI '''
                    _cur_qname = alns[0].query_name
                    _cur_bcode = get_tag_alignments(alns,
                                                    self.opts.barcode_tag)
                    _cur_umi = get_tag_alignments(alns, self.opts.umi_tag)

                    ''' Validate barcode and UMI '''
                    if _cur_bcode is None:
                        skip_fragment("Missing CB")
                        continue

                    if self.filtlist:
                        if _cur_bcode not in self.filtlist:
                            skip_fragment("CB not in filtered BC list")
                            continue

                    if not self.opts.ignore_umi and _cur_umi is None:
                        skip_fragment("Missing UB")
                        continue

                    ''' Fragment is ambiguous if multiple mappings'''
                    _mapped = [a for a in alns if not a.is_unmapped]
                    _finfo.ambig = len(_mapped) > 1

                    ''' Update min and max scores '''
                    _finfo.scores = [a.alnscore for a in _mapped]

                    ''' Check whether fragment overlaps annotation '''
                    overlap_feats = list(map(assign_func, _mapped))
                    _finfo.overlap = any(f != _nfkey for f in overlap_feats)

                    ''' Fragment has no overlap, skip '''
                    if not _finfo.overlap:
                        skip_fragment()
                        continue

                    ''' Add cell tags to barcode/UMI trackers '''
                    store_read_info(_cur_qname, _cur_bcode, _cur_umi)

                    ''' Fragment overlaps with annotation '''
                    alninfo.update(_finfo)

                    ''' Find the best alignment for each locus '''
                    for m in process_fragment(_mapped, overlap_feats):
                        _mappings.append((ci, m[0], m[1], m[2], m[3]))

                    if self.opts.updated_sam:
                        for p in alns:
                            p.write(bam_t)
        finally:
            ''' Loading complete (or aborted) '''
            # Close the output BAMs even when loading raises, so the handles
            # are not leaked and already written records are flushed.
            for _bam in (bam_u, bam_t):
                if _bam is not None:
                    _bam.close()

        return _mappings, alninfo

    def dedup_umi(self, output_report=True, summary=True):
        """ Exclude duplicated reads sharing the same barcode and UMI

        Reads are grouped by (barcode, UMI). For every group with more than
        one read, `select_umi_representatives` decides which reads are
        excluded. Sets `self.umi_dups` (built from the excluded row indexes)
        and `self.corrected` (`self.raw_scores` multiplied elementwise by
        `bool_inv(self.umi_dups)`).

        Parameters
        ----------
        output_report : bool, default=True
            Whether to write a per-group report to 'umi_tracking.txt' in the
            output location given by `self.opts.outfile_path`.
        summary : bool, default=True
            If True the initial UMI summary is logged at INFO level,
            otherwise at DEBUG level.

        Returns
        -------
        UMIInfo
            UMI statistics collected during deduplication.

        """
        exclude_qnames: dict[str, int] = {}  # reads to be excluded

        umiFH = None
        if output_report:
            umiFH = open(self.opts.outfile_path('umi_tracking.txt'), 'w')

        # try/finally guarantees the report file is closed even if an error
        # is raised while processing the UMI groups.
        try:
            ''' Index read names by barcode+umi '''
            bcumi_read = defaultdict(dict)
            for qname in self.read_index:
                key = (self.read_bcode_map[qname], self.read_umi_map[qname])
                # dict used as an insertion-ordered set of read names
                _ = bcumi_read[key].setdefault(qname, None)

            """ Update umiinfo """
            sumlvl = lg.INFO if summary else lg.DEBUG
            umiinfo = UMIInfo(bcumi_read)
            umiinfo.loginit(sumlvl)

            ''' Loop over all bc+umi pairs'''
            for (bc, umi), qnames in bcumi_read.items():
                ''' Unique barcode+umi '''
                if len(qnames) == 1:
                    continue

                ''' Duplicated barcode+umi '''
                umi_feat_scores = []
                ''' Construct a list of 2-tuples with (read name, alignment
                    vector) where alignment vector is a dictionary mapping
                    feature index to the alignment score of the read to the
                    feature
                '''
                for qname in qnames.keys():
                    row_m = self.raw_scores[self.read_index[qname],]
                    vec = {ft: sc for ft, sc in zip(row_m.indices, row_m.data)}
                    _ = vec.pop(0, None)  # remove "no_feature", index = 0
                    umi_feat_scores.append((qname, vec))

                ''' Select representative read(s) '''
                comps, is_excluded = select_umi_representatives(
                    umi_feat_scores
                )
                ''' Update set to exclude '''
                for ex, (qname, vec) in zip(is_excluded, umi_feat_scores):
                    if ex:
                        _ = exclude_qnames.setdefault(qname,
                                                      len(exclude_qnames))

                umiinfo.ncomps_umi[len(set(comps))] += 1
                umiinfo.nexclude = umiinfo.nexclude + sum(is_excluded)

                if umiFH is not None:
                    print(f'{bc}\t{umi}', file=umiFH)
                    _iter = zip(comps, is_excluded, umi_feat_scores)
                    for comp, ex, (qname, vec) in _iter:
                        exstr = 'EX' if ex else 'REP'
                        print(f'\t{qname}\t{comp}\t{exstr}\t{str(vec)}',
                              file=umiFH)
        finally:
            if umiFH is not None:
                umiFH.close()

        exclude_rows = [self.read_index[qname] for qname in exclude_qnames]
        self.umi_dups = rowid(exclude_rows, self.shape[0])
        # Multiplying by the inverse of the duplicate indicator leaves only
        # the rows that were not excluded.
        self.corrected = self.raw_scores.multiply(bool_inv(self.umi_dups))

        return umiinfo

    """
        # Sanity check: check excluded rows are set to zero in self.corrected
        # and included rows are the same
        for r in range(self.shape[0]):
            if r in exclude_rows:
                assert self.corrected[r,].nnz == 0
            else:
                assert self.raw_scores[r,].check_equal(self.corrected[r,])
    """

    def fit_pooling_model(self):
        """ Fit the pooling model for this object

        Delegates to the module-level `_fit_pooling_model`, passing this
        object, its options, `self.opts.nproc` as the number of processes
        and a progress value of 100.

        Returns
        -------
        Whatever `_fit_pooling_model` returns.
        """
        fit_kwargs = {'processes': self.opts.nproc, 'progress': 100}
        return _fit_pooling_model(self, self.opts, **fit_kwargs)

    def reassign(self, tl: TelescopeLikelihood) -> dict[str, ReassignInfo]:
        """

        Returns
        -------

        """
        if not hasattr(self, "reassignments"):
            self.reassignments = {}

        rmode_info = {}

        for _rmode in self.opts.reassign_mode:
            _thresh = self.opts.conf_prob if _rmode == 'best_conf' else None
            _rmat,_rinfo = tl.reassign(_rmode, _thresh)
            self.reassignments[_rmode] = _rmat
            rmode_info[_rmode] = _rinfo
        return rmode_info


    def save(self, filename: Union[str, bytes, os.PathLike]):
        """ Save current state of Stellarscope object

        Parameters
        ----------
        filename

        Returns
        -------
        bool
            True is save is successful, False otherwise
        """
        with open(filename, 'wb') as fh:
            pickle.dump(self, fh)

        return True

    @classmethod
    def load(cls, filename: Union[str, bytes, os.PathLike]):
        """ Load Stellarscope object from file

        Parameters
        ----------
        filename

        Returns
        -------

        """
        with open(filename, 'rb') as fh:
            return pickle.load(fh)

    def output_report(self, tl: 'TelescopeLikelihood'):
        """ Write barcodes, features and count matrices to the output files

        Writes 'barcodes.tsv' and 'features.tsv', then one Matrix Market file
        per mode in `self.opts.reassign_mode`: the first mode goes to
        'TE_counts.mtx', every further mode to 'TE_counts.<mode>.mtx'. Each
        matrix is features x barcodes. Barcodes follow the filter list order
        when a filter list is loaded, otherwise sorted order; features follow
        their column index.

        Parameters
        ----------
        tl
            Not used by this method.

        Returns
        -------
        None

        Raises
        ------
        StellarscopeError
            If a reassignment matrix is neither integer nor floating point,
            or if an aggregated matrix does not have the expected shape.

        """
        def count_dtype(remat):
            """ Output dtype matching the kind of the reassignment matrix """
            if np.issubdtype(remat.dtype, np.integer):
                return np.int32
            if np.issubdtype(remat.dtype, np.floating):
                return np.float64
            raise StellarscopeError("reassigned is not int or float")

        def agg_bc(remat: csr_matrix, idx_bc: dict[int, str], proc: int=1):
            """ Sum rows per barcode; returns a features x barcodes matrix """
            out_dtype = count_dtype(remat)

            # Group by barcode and calculate colsums
            grp_sums, grp_names = remat.groupby_sum(
                idx_bc, dtype=out_dtype, proc=proc
            )
            grp_row = {name: r for r, name in enumerate(grp_names)}

            # Order barcodes by `bc_list`; barcodes without reads get an
            # all-zero row
            zero_row = csr_matrix((1, numft), dtype=out_dtype)
            per_bc = [
                grp_sums[grp_row[bc], ] if bc in grp_row else zero_row
                for bc in bc_list
            ]
            return scipy.sparse.vstack(per_bc, dtype=out_dtype).transpose()

        def agg_bc_umi(remat: csr_matrix, idx_bc: dict[int,str], idx_umi: dict[int,str], proc: int=1):
            """ Like `agg_bc`, but each barcode+UMI counts once per feature """
            out_dtype = count_dtype(remat)

            # Group by BC+UMI and calculate colsums
            idx_bcumi = {i: f'{idx_bc[i]}_{idx_umi[i]}' for i in idx_bc}
            umi_sums, umi_names = remat.groupby_sum(
                idx_bcumi, dtype=out_dtype, proc=proc
            )
            umi_row = {name: r for r, name in enumerate(umi_names)}

            # Adjust so each BC+UMI contributes 1 count
            flattened = umi_sums.copy()
            flattened.data = np.ones(len(flattened.data))

            # Group by barcode (the part before the first underscore) and
            # calculate colsums
            row_bc = {r: name.split('_')[0] for name, r in umi_row.items()}
            return agg_bc(flattened, row_bc, proc)

        def mtx_meta(reassign_mode):
            """ Create metadata for mtx header """
            meta = OrderedDict()
            meta['PN'] = 'stellarscope'
            meta['VN'] = self.opts.version
            # Optional settings are only reported when the option exists
            for optname in ('samfile', 'checkpoint', 'gtffile', 'filtered_bc'):
                if hasattr(self.opts, optname):
                    meta[optname] = getattr(self.opts, optname)
            meta['pooling_mode'] = self.opts.pooling_mode
            meta['reassign_mode'] = reassign_mode
            if self.opts.ignore_umi:
                meta['ignore_umi'] = 'Yes'
                meta['umi_counts'] = 'Yes' if self.opts.umi_counts else 'No'
            meta['command'] = ' '.join(sys.argv)

            return '\n '.join(f'{k}: {v}' for k, v in meta.items())

        # Final order for barcodes and features
        if self.filtlist:
            bc_list = sorted(self.filtlist, key=self.filtlist.get)
        else:
            bc_list = sorted(self.bcode_ridx_map.keys())
        numbc = len(bc_list)

        ft_list = sorted(self.feat_index, key=self.feat_index.get)
        numft = len(ft_list)

        # Write barcodes.tsv and features.tsv, one name per line
        with open(self.opts.outfile_path('barcodes.tsv'), 'w') as bc_out:
            bc_out.write('\n'.join(bc_list) + '\n')

        with open(self.opts.outfile_path('features.tsv'), 'w') as ft_out:
            ft_out.write('\n'.join(ft_list) + '\n')

        # Row index -> barcode (and UMI) for every fragment, in row order
        nproc = self.opts.nproc
        qnames = sorted(self.read_index, key=self.read_index.get)
        ridx_bc = {i: self.read_bcode_map[qn] for i, qn in enumerate(qnames)}

        if self.opts.umi_counts:
            ridx_umi = {i: self.read_umi_map[qn] for i, qn in enumerate(qnames)}

        expected_shape = (numft, numbc)
        for mode_pos, mode in enumerate(self.opts.reassign_mode):
            remat = self.reassignments[mode]
            if self.opts.umi_counts:
                lg.info(f'Aggregating UMI counts for {mode} (nproc={nproc})')
                counts = agg_bc_umi(remat, ridx_bc, ridx_umi, nproc)
            else:
                lg.info(f'Aggregating counts for {mode} (nproc={nproc})')
                counts = agg_bc(remat, ridx_bc, nproc)

            if counts.shape != expected_shape:
                raise StellarscopeError(
                    f'Incompatible shapes: {counts.shape} != {(numft, numbc)}'
                )

            # Write output MTX; only the first mode gets the plain file name
            if mode_pos == 0:
                mtx_name = 'TE_counts.mtx'
            else:
                mtx_name = f'TE_counts.{mode}.mtx'
            out_mtx = self.opts.outfile_path(mtx_name)

            scipy.io.mmwrite(out_mtx, counts, comment=mtx_meta(mode))

        return

    def output_report_old(self, tl, stats_filename, counts_filename,
                          barcodes_filename, features_filename):
        """
        .. deprecated:: be33986
            `model.output_report_old() is replaced by `model.output_report()`
            which was implemented in be33986.
        """
        raise DeprecationWarning('''
            deprecated:: be33986
                `model.output_report_old() is replaced by `model.output_report()`
                which was implemented in be33986.
        ''')

    def update_sam(self, tl, filename, all_alns=True):
        """ Write an updated BAM annotated with the model results

        Reads the fragments stored in `self.tmp_bam` and writes them to
        `filename`, using the reassignment matrix of the first mode in
        `self.opts.reassign_mode`. Alignments tagged ZT == 'SEC' are flagged
        secondary with mapping quality 0. For every other mapped alignment
        the value `tl.z[row, feature]` (feature taken from the ZF tag) is
        stored in the XP tag and converted to a mapping quality with
        `phred`; the alignment is flagged primary when the reassignment
        matrix is positive for that fragment/feature and secondary
        otherwise. The YC tag receives a color in each case.

        Parameters
        ----------
        tl
            Fitted model; `tl.z` is indexed by (row index, feature index).
        filename
            Path of the BAM file to write.
        all_alns : bool, default=True
            Whether unmapped alignments are also written.

        Raises
        ------
        StellarscopeError
            If a mapped alignment has no ZT tag.
        """
        _rmode = self.opts.reassign_mode[0]
        mat = self.reassignments[_rmode]

        _pysam_verbosity = pysam.set_verbosity(0)
        with pysam.AlignmentFile(self.tmp_bam, check_sq=False) as sf:
            pysam.set_verbosity(_pysam_verbosity)

            # Work on a plain dict: `AlignmentHeader` objects hand out copies
            # on item access, so appending to `sf.header['PG']` directly
            # would not reach the written header. Older pysam versions
            # return a dict already.
            _hdr = sf.header
            header = _hdr.to_dict() if hasattr(_hdr, 'to_dict') else _hdr
            # setdefault also covers inputs without any @PG line
            header.setdefault('PG', []).append({
                'PN': 'telescope', 'ID': 'telescope',
                'VN': self.run_info['version'],
                'CL': ' '.join(sys.argv),
            })

            # The context manager closes the output even if an error is
            # raised while writing.
            with pysam.AlignmentFile(filename, 'wb', header=header) as outsam:
                for code, pairs in alignment.fetch_fragments_seq(
                        sf, until_eof=True
                ):
                    if len(pairs) == 0:
                        continue
                    ridx = self.read_index[pairs[0].query_name]
                    for aln in pairs:
                        if aln.is_unmapped:
                            if all_alns:
                                aln.write(outsam)
                            continue
                        if not aln.r1.has_tag('ZT'):
                            raise StellarscopeError('Missing ZT tag')
                        if aln.r1.get_tag('ZT') == 'SEC':
                            aln.set_flag(pysam.FSECONDARY)
                            aln.set_tag('YC', c2str((248, 248, 248)))
                            aln.set_mapq(0)
                        else:
                            fidx = self.feat_index[aln.r1.get_tag('ZF')]
                            prob = tl.z[ridx, fidx]
                            aln.set_mapq(phred(prob))
                            # XP is stored as a float tag
                            aln.set_tag('XP', prob, 'f')
                            if mat[ridx, fidx] > 0:
                                # selected by the reassignment
                                aln.unset_flag(pysam.FSECONDARY)
                                aln.set_tag('YC', c2str(D2PAL['vermilion']))
                            else:
                                aln.set_flag(pysam.FSECONDARY)
                                if prob >= 0.2:
                                    aln.set_tag('YC', c2str(D2PAL['yellow']))
                                else:
                                    aln.set_tag('YC', c2str(GPAL[2]))
                        aln.write(outsam)

    def print_summary(self, loglev=lg.WARNING):
        raise DeprecationWarning('''
            deprecated:: 9097515
            Replaced by statistics.AlignInfo.log()
        ''')

    def check_equal(self, other: Stellarscope, explain: bool = False):
        """ Check whether two Stellarscope objects are equal

        Parameters
        ----------
        other: Stellarscope
            Stellarscope object to compare with
        explain: bool, default=False
            Whether to return an explanation of which attributes are not equal.

        Returns
        -------
        bool
            True if `Stellarscope` objects are equivalent, False otherwise.

        """

        def check_attr_equal(v1: Any, v2: Any):
            """ Check whether two attributes are equal

            Parameters
            ----------
            v1: Any
                First value to compare
            v2: Any
                Second value to compare
            Returns
            -------
            bool
                True if values are equal, False otherwise

            """
            if v1 is None or v2 is None:
                if v1 is None and v2 is None:
                    return True, f'both are None'
                elif v2 is not None:
                    return False, f'v1 is None but v2 is {type(v2)})'
                elif v1 is not None:
                    return False, f'v2 is None but v1 is {type(v1)}'
                raise StellarscopeError('unreachable')

            if type(v1) != type(v2):
                return False, f'v1 is {type(v1)}but v2 is {type(v2)}'
            if isinstance(v1, (bool, str, int, tuple)):
                if v1 == v2:
                    return True, f'{v1} == {v2}'
                else:
                    return False, f'{v1} != {v2}'
            if isinstance(v1, list):
                if v1 == v2:
                    return True, f'lists are equal'
                else:
                    return False, f'lists are not equal'
            if isinstance(v1, dict):
                if v1 == v2:
                    return True, f'dicts are equal'
                else:
                    return False, f'dicts are not equal'
            if isinstance(v1, csr_matrix):
                if v1.check_equal(v2):
                    return True, f'sparse matrixes are equal'
                else:
                    return False, f'sparse matrixes are not equal'
            if isinstance(v1, OptionsBase):
                return True, 'no checking for options yet'
            raise StellarscopeError(f'unknown type {type(v1)}')

        ''' Check both are Stellarscope objects '''
        if not isinstance(other, self.__class__):
            if not explain:
                return False
            else:
                return False, f'other is type {type(other)}'

        ''' Check both object have the same attributes '''
        same_attrs = self.__dict__.keys() == other.__dict__.keys()
        if not same_attrs:
            if not explain:
                return False
            reason = ''
            d1 = self.__dict__.keys() - other.__dict__.keys()
            if d1:
                reason += f'self has attrs "{",".join(d1)}" not in other\n'
            d2 = other.__dict__.keys() - self.__dict__.keys()
            if d2:
                reason += f'other has attrs "{",".join(d2)}" not in self\n'
            return False, reason

        ''' Check that attributes are equal '''
        is_equal = True
        reason = ''
        for a in self.__dict__.keys():
            v1 = getattr(self, a)
            v2 = getattr(other, a)
            attr_equal, msg = check_attr_equal(v1, v2)
            is_equal &= attr_equal
            if not attr_equal:
                if not explain:
                    return is_equal
                reason += f'Difference found in "Stellarscope.{a}": {msg}\n'

        if not explain:
            return is_equal
        return is_equal, reason

    def __str__(self):
        clsn = self.__class__.__name__
        if hasattr(self.opts, 'samfile'):
            return f'<{clsn} samfile={self.opts.samfile}, gtffile={self.opts.gtffile}>'
        elif hasattr(self.opts, 'checkpoint'):
            return f'<{clsn} checkpoint={self.opts.checkpoint}>'
        else:
            return f'<{clsn}>'
