#
# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
from cuml.internals import get_handle
from cuml.metrics.cluster.utils import prepare_cluster_metric_inputs

from libc.stdint cimport uintptr_t
from pylibraft.common.handle cimport handle_t


# Binding for the C++ completeness_score implementation declared in
# cuml/metrics/metrics.hpp (namespace ML::Metrics). `nogil` tells Cython the
# function may be called without holding the GIL; `except +` converts any C++
# exception it throws into a Python exception instead of crashing.
cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics" nogil:
    # y: ground-truth labels, y_hat: predicted labels, each an array of n ints
    # (the Python wrapper passes the `.ptr` of the prepared label arrays --
    # presumably device memory; confirm against prepare_cluster_metric_inputs).
    # lower/upper_class_range: label-range bounds as computed by
    # prepare_cluster_metric_inputs (NOTE(review): confirm whether inclusive).
    double completeness_score(const handle_t & handle, const int *y,
                              const int *y_hat, const int n,
                              const int lower_class_range,
                              const int upper_class_range) except +


def cython_completeness_score(labels_true, labels_pred, handle=None) -> float:
    """
    Completeness metric of a cluster labeling given a ground truth.

    A clustering result satisfies completeness if all the data points that are
    members of a given class are elements of the same cluster.

    This metric is independent of the absolute values of the labels:
    a permutation of the class or cluster label values won’t change the score
    value in any way.

    This metric is not symmetric: switching labels_true with labels_pred will
    return the homogeneity_score which will be different in general.

    The labels in labels_pred and labels_true are assumed to be drawn from a
    contiguous set (Ex: drawn from {2, 3, 4}, but not from {2, 4}). If your
    set of labels looks like {2, 4}, convert them to something like {0, 1}.

    Parameters
    ----------
    labels_true : array-like (device or host) shape = (n_samples,)
        The ground truth labels (ints) of the test dataset.
        Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
        ndarray, cuda array interface compliant array like CuPy
    labels_pred : array-like (device or host) shape = (n_samples,)
        The labels predicted by the model for the test dataset.
        Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
        ndarray, cuda array interface compliant array like CuPy
    handle : cuml.Handle or None, default=None

        .. deprecated:: 26.02
            The `handle` argument was deprecated in 26.02 and will be removed
            in 26.04. There's no need to pass in a handle, cuml now manages
            this resource automatically.

    Returns
    -------
    float
      The completeness of the predicted labeling given the ground truth.
      Score between 0.0 and 1.0. 1.0 stands for perfectly complete labeling.
    """
    handle = get_handle(handle=handle)
    # getHandle() yields the address of the underlying C++ handle_t as a
    # Python integer; uintptr_t is the integer type meant for round-tripping
    # pointers, matching the pointer extraction below.
    cdef handle_t *handle_ = <handle_t*> <uintptr_t> handle.getHandle()

    # The helper returns both prepared label arrays, the number of rows, and
    # the label-range bounds that the C++ implementation expects.
    (y_true, y_pred, n_rows,
     lower_class_range, upper_class_range) = prepare_cluster_metric_inputs(
        labels_true,
        labels_pred
    )

    # Raw addresses of the prepared label buffers. y_true / y_pred remain
    # referenced as locals, so the buffers stay alive for the native call.
    cdef uintptr_t ground_truth_ptr = y_true.ptr
    cdef uintptr_t preds_ptr = y_pred.ptr

    com = completeness_score(handle_[0],
                             <int*> ground_truth_ptr,
                             <int*> preds_ptr,
                             <int> n_rows,
                             <int> lower_class_range,
                             <int> upper_class_range)

    return com
