#
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
# cython: language_level=3

import numpy as np

cimport cuvs.common.cydlpack

from cuvs.common.resources import auto_sync_resources

from cython.operator cimport dereference as deref
from libcpp cimport bool, cast
from libcpp.string cimport string

from cuvs.common cimport cydlpack
from cuvs.distance_type cimport cuvsDistanceType

from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray
from pylibraft.common.cai_wrapper import wrap_array
from pylibraft.common.interruptible import cuda_interruptible

from cuvs.common.device_tensor_view import DeviceTensorView
from cuvs.distance import DISTANCE_NAMES, DISTANCE_TYPES
from cuvs.neighbors.common import _check_input_array

from libc.stdint cimport (
    int8_t,
    int64_t,
    uint8_t,
    uint32_t,
    uint64_t,
    uintptr_t,
)

from cuvs.common.exceptions import check_cuvs


cdef class IndexParams:
    """
    Parameters to build index for IvfPq nearest neighbor search

    Parameters
    ----------
    n_lists : int, default = 1024
        The number of clusters used in the coarse quantizer.
    metric : str, default="sqeuclidean"
        String denoting the metric type.
        Valid values for metric: ["sqeuclidean", "inner_product",
        "euclidean", "cosine"],
        where:

            - sqeuclidean is the euclidean distance without the square root
              operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2,
            - euclidean is the euclidean distance
            - inner product distance is defined as
              distance(a, b) = \\sum_i a_i * b_i.
            - cosine distance is defined as
              distance(a, b) = 1 - \\sum_i a_i * b_i / ( ||a||_2 * ||b||_2).

    metric_arg : float, default = 2.0
        Extra argument stored next to the metric in the underlying
        parameters struct. NOTE(review): presumably only consulted by
        metrics that take an argument; confirm against the cuVS C API.
    kmeans_n_iters : int, default = 20
        The number of iterations searching for kmeans centers during index
        building.
    kmeans_trainset_fraction : float, default = 0.5
        If kmeans_trainset_fraction is less than 1, then the dataset is
        subsampled, and only n_samples * kmeans_trainset_fraction rows
        are used for training.
    pq_bits : int, default = 8
        The bit length of the vector element after quantization.
    pq_dim : int, default = 0
        The dimensionality of the vector after product quantization.
        When zero, an optimal value is selected using a heuristic. Note
        pq_dim * pq_bits must be a multiple of 8. Hint: a smaller 'pq_dim'
        results in a smaller index size and better search performance, but
        lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number,
        but multiple of 8 are desirable for good performance. If 'pq_bits'
        is not 8, 'pq_dim' should be a multiple of 8. For good performance,
        it is desirable that 'pq_dim' is a multiple of 32. Ideally,
        'pq_dim' should be also a divisor of the dataset dim.
    codebook_kind : string, default = "subspace"
        Valid values ["subspace", "cluster"]
    force_random_rotation : bool, default = False
        Apply a random rotation matrix on the input data and queries even
        if `dim % pq_dim == 0`. Note: if `dim` is not multiple of `pq_dim`,
        a random rotation is always applied to the input data and queries
        to transform the working space from `dim` to `rot_dim`, which may
        be slightly larger than the original space and is a multiple
        of `pq_dim` (`rot_dim % pq_dim == 0`). However, this transform is
        not necessary when `dim` is multiple of `pq_dim` (`dim == rot_dim`,
        hence no need in adding "extra" data columns / features). By
        default, if `dim == rot_dim`, the rotation transform is
        initialized with the identity matrix. When
        `force_random_rotation == True`, a random orthogonal transform
        matrix is generated regardless of the values of `dim` and `pq_dim`.
    add_data_on_build : bool, default = True
        After training the coarse and fine quantizers, we will populate
        the index with the dataset if add_data_on_build == True, otherwise
        the index is left empty, and the extend method can be used
        to add new vectors to the index.
    conservative_memory_allocation : bool, default = False
        By default, the algorithm allocates more space than necessary for
        individual clusters (`list_data`). This allows to amortize the cost
        of memory allocation and reduce the number of data copies during
        repeated calls to `extend` (extending the database).
        To disable this behavior and use as little GPU memory for the
        database as possible, set this flag to `True`.
    max_train_points_per_pq_code : int, default = 256
        The max number of data points to use per PQ code during PQ codebook
        training. Using more data points per PQ code may increase the
        quality of PQ codebook but may also increase the build time. The
        parameter is applied to both PQ codebook generation methods, i.e.,
        PER_SUBSPACE and PER_CLUSTER. In both cases, we will use
        pq_book_size * max_train_points_per_pq_code training points to
        train each codebook.
    codes_layout : string, default = "interleaved"
        Memory layout of the IVF-PQ list data.
        Valid values ["flat", "interleaved"]

            - flat: Codes are stored contiguously, one vector's codes after
              another.
            - interleaved: Codes are interleaved for optimized search
              performance. This is the default and recommended for search
              workloads.

    Raises
    ------
    KeyError
        If `metric` is not a key of `cuvs.distance.DISTANCE_TYPES`.
    ValueError
        If `codebook_kind` or `codes_layout` is not one of the valid values.
    """

    def __cinit__(self):
        # Surface allocation failures here instead of dereferencing an
        # invalid handle later in __init__.
        check_cuvs(cuvsIvfPqIndexParamsCreate(&self.params))

    def __dealloc__(self):
        # __dealloc__ also runs when __cinit__ raised, so the handle may
        # still be NULL.
        if self.params != NULL:
            check_cuvs(cuvsIvfPqIndexParamsDestroy(self.params))

    def __init__(self, *,
                 n_lists=1024,
                 metric="sqeuclidean",
                 metric_arg=2.0,
                 kmeans_n_iters=20,
                 kmeans_trainset_fraction=0.5,
                 pq_bits=8,
                 pq_dim=0,
                 codebook_kind="subspace",
                 force_random_rotation=False,
                 add_data_on_build=True,
                 conservative_memory_allocation=False,
                 max_train_points_per_pq_code=256,
                 codes_layout="interleaved"):
        self.params.n_lists = n_lists
        self.params.metric = <cuvsDistanceType>DISTANCE_TYPES[metric]
        self.params.metric_arg = metric_arg
        self.params.kmeans_n_iters = kmeans_n_iters
        self.params.kmeans_trainset_fraction = kmeans_trainset_fraction
        self.params.pq_bits = pq_bits
        self.params.pq_dim = pq_dim
        if codebook_kind == "subspace":
            self.params.codebook_kind = \
                cuvsIvfPqCodebookGen.CUVS_IVF_PQ_CODEBOOK_GEN_PER_SUBSPACE
        elif codebook_kind == "cluster":
            self.params.codebook_kind = \
                cuvsIvfPqCodebookGen.CUVS_IVF_PQ_CODEBOOK_GEN_PER_CLUSTER
        else:
            # f-string instead of "%": "%" formatting breaks when the
            # offending value happens to be a tuple.
            raise ValueError(f"Incorrect codebook kind {codebook_kind}")
        self.params.force_random_rotation = force_random_rotation
        self.params.add_data_on_build = add_data_on_build
        self.params.conservative_memory_allocation = \
            conservative_memory_allocation
        self.params.max_train_points_per_pq_code = \
            max_train_points_per_pq_code
        if codes_layout == "flat":
            self.params.codes_layout = \
                cuvsIvfPqListLayout.CUVS_IVF_PQ_LIST_LAYOUT_FLAT
        elif codes_layout == "interleaved":
            self.params.codes_layout = \
                cuvsIvfPqListLayout.CUVS_IVF_PQ_LIST_LAYOUT_INTERLEAVED
        else:
            raise ValueError(f"Incorrect codes layout {codes_layout}")

    def get_handle(self):
        """ Return the address of the underlying C parameters struct as an
        integer. """
        return <size_t>self.params

    @property
    def n_lists(self):
        return self.params.n_lists

    @property
    def metric(self):
        """ Name of the metric, looked up in DISTANCE_NAMES. """
        return DISTANCE_NAMES[self.params.metric]

    @property
    def metric_arg(self):
        return self.params.metric_arg

    @property
    def kmeans_n_iters(self):
        return self.params.kmeans_n_iters

    @property
    def kmeans_trainset_fraction(self):
        return self.params.kmeans_trainset_fraction

    @property
    def pq_bits(self):
        return self.params.pq_bits

    @property
    def pq_dim(self):
        return self.params.pq_dim

    @property
    def codebook_kind(self):
        """ The stored codebook kind, returned as the raw value held in the
        parameters struct (not the string passed to the constructor). """
        return self.params.codebook_kind

    @property
    def force_random_rotation(self):
        return self.params.force_random_rotation

    @property
    def add_data_on_build(self):
        return self.params.add_data_on_build

    @property
    def conservative_memory_allocation(self):
        return self.params.conservative_memory_allocation

    @property
    def max_train_points_per_pq_code(self):
        return self.params.max_train_points_per_pq_code

    @property
    def codes_layout(self):
        """ The list layout as a string: "flat" or "interleaved". """
        if (self.params.codes_layout
                == cuvsIvfPqListLayout.CUVS_IVF_PQ_LIST_LAYOUT_FLAT):
            return "flat"
        else:
            return "interleaved"

cdef class Index:
    """
    IvfPq index object. This object stores the trained IvfPq index state
    which can be used to perform nearest neighbors searches.

    Accessors that expose index data raise ``ValueError`` while the index
    has not been built (``trained`` is False).
    """

    cdef cuvsIvfPqIndex_t index
    cdef bool trained

    def __cinit__(self):
        self.trained = False
        check_cuvs(cuvsIvfPqIndexCreate(&self.index))

    def __dealloc__(self):
        # __dealloc__ also runs when __cinit__ raised; in that case the
        # handle may never have been created, so do not destroy NULL.
        if self.index != NULL:
            check_cuvs(cuvsIvfPqIndexDestroy(self.index))

    @property
    def trained(self):
        """ Whether the index has been built or loaded """
        return self.trained

    def __repr__(self):
        return "Index(type=IvfPq)"

    @property
    def n_lists(self):
        """ The number of inverted lists (clusters) """
        cdef int64_t n_lists
        check_cuvs(cuvsIvfPqIndexGetNLists(self.index, &n_lists))
        return n_lists

    @property
    def dim(self):
        """ dimensionality of the cluster centers """
        cdef int64_t dim
        check_cuvs(cuvsIvfPqIndexGetDim(self.index, &dim))
        return dim

    @property
    def pq_dim(self):
        """ The dimensionality of an encoded vector after compression by PQ """
        cdef int64_t pq_dim
        check_cuvs(cuvsIvfPqIndexGetPqDim(self.index, &pq_dim))
        return pq_dim

    @property
    def pq_len(self):
        """ The dimensionality of a subspace, i.e. the number of vector
        components mapped to a subspace """
        cdef int64_t pq_len
        check_cuvs(cuvsIvfPqIndexGetPqLen(self.index, &pq_len))
        return pq_len

    @property
    def pq_bits(self):
        """ The bit length of an encoded vector element after
        compression by PQ. """
        cdef int64_t pq_bits
        check_cuvs(cuvsIvfPqIndexGetPqBits(self.index, &pq_bits))
        return pq_bits

    def __len__(self):
        """ The size reported by the index """
        cdef int64_t size
        check_cuvs(cuvsIvfPqIndexGetSize(self.index, &size))
        return size

    @property
    def centers(self):
        """ Get the cluster centers corresponding to the lists in the
        original space """
        if not self.trained:
            raise ValueError("Index needs to be built before getting centers")

        output = DeviceTensorView()
        cdef cydlpack.DLManagedTensor * tensor = \
            <cydlpack.DLManagedTensor*><size_t>output.get_handle()
        check_cuvs(cuvsIvfPqIndexGetCenters(self.index, tensor))
        # The view references this index; presumably this keeps the index
        # (and the memory the view points at) alive while the view is used.
        output.parent = self
        return output

    @property
    def centers_padded(self):
        """ Get the padded cluster centers [n_lists, dim_ext]
        where dim_ext = round_up(dim + 1, 8).
        This returns contiguous data suitable for build_precomputed. """
        if not self.trained:
            raise ValueError("Index needs to be built before getting"
                             " centers_padded")

        output = DeviceTensorView()
        cdef cydlpack.DLManagedTensor * tensor = \
            <cydlpack.DLManagedTensor*><size_t>output.get_handle()
        check_cuvs(cuvsIvfPqIndexGetCentersPadded(self.index, tensor))
        output.parent = self
        return output

    @property
    def pq_centers(self):
        """ Get the PQ cluster centers """
        if not self.trained:
            raise ValueError("Index needs to be built before getting"
                             " pq centers")

        output = DeviceTensorView()
        cdef cydlpack.DLManagedTensor * tensor = \
            <cydlpack.DLManagedTensor*><size_t>output.get_handle()
        check_cuvs(cuvsIvfPqIndexGetPqCenters(self.index, tensor))
        output.parent = self
        return output

    @property
    def centers_rot(self):
        """ Get the rotated cluster centers [n_lists, rot_dim]
        where rot_dim = pq_len * pq_dim """
        if not self.trained:
            raise ValueError("Index needs to be built before getting"
                             " centers_rot")

        output = DeviceTensorView()
        cdef cydlpack.DLManagedTensor * tensor = \
            <cydlpack.DLManagedTensor*><size_t>output.get_handle()
        check_cuvs(cuvsIvfPqIndexGetCentersRot(self.index, tensor))
        output.parent = self
        return output

    @property
    def rotation_matrix(self):
        """ Get the rotation matrix [rot_dim, dim]
        Transform matrix (original space -> rotated padded space) """
        if not self.trained:
            raise ValueError("Index needs to be built before getting"
                             " rotation_matrix")

        output = DeviceTensorView()
        cdef cydlpack.DLManagedTensor * tensor = \
            <cydlpack.DLManagedTensor*><size_t>output.get_handle()
        check_cuvs(cuvsIvfPqIndexGetRotationMatrix(self.index, tensor))
        output.parent = self
        return output

    @property
    def list_sizes(self):
        """ Get the sizes of each list """
        if not self.trained:
            raise ValueError("Index needs to be built before getting"
                             " list sizes")
        output = DeviceTensorView()
        cdef cydlpack.DLManagedTensor * tensor = \
            <cydlpack.DLManagedTensor*><size_t>output.get_handle()
        check_cuvs(cuvsIvfPqIndexGetListSizes(self.index, tensor))
        output.parent = self
        return output

    @auto_sync_resources
    def lists(self, resources=None):
        """ Iterates through the pq-encoded list data

        This function returns an iterator over each list. Every item is
        a tuple ``(indices, list_data)`` holding the indices stored in the
        list and the pq-encoded data for the entire list.

        Parameters
        ----------
        {resources_docstring}
        """
        # One host copy of all sizes up front, so the per-list calls below
        # do not each have to fetch them again.
        list_sizes = self.list_sizes.copy_to_host()
        for i, list_size in enumerate(list_sizes):
            indices = self.list_indices(i, n_rows=list_size)
            list_data = self.list_data(i, n_rows=list_size,
                                       resources=resources)
            yield indices, list_data

    @auto_sync_resources
    def list_data(self, label, n_rows=0, offset=0, out_codes=None,
                  resources=None):
        """ Gets unpacked list data for a single list (cluster)

        Parameters
        ----------
        label : int
            The cluster to get data for
        n_rows : int
            The number of rows to return for the cluster (0 is all rows)
        offset : int
            The row to start getting data at
        out_codes : CUDA array interface compliant matrix, optional
            Buffer of dtype uint8 and shape (n_rows, n_cols) to write into,
            where n_cols = ceil(pq_dim * pq_bits / 8).
            Will be created if None
        {resources_docstring}

        Returns
        -------
        out_codes
            The buffer holding the unpacked codes (the one passed in, or
            the newly created one).

        Raises
        ------
        ValueError
            If the index has not been built yet.
        """
        if not self.trained:
            raise ValueError("Index needs to be built before getting"
                             " list data")

        if n_rows == 0:
            n_rows = self.list_sizes.copy_to_host()[label]

        # Bytes per row: pq_dim codes of pq_bits bits each, rounded up to
        # whole bytes. Integer ceiling division avoids a float round trip.
        n_cols = (self.pq_dim * self.pq_bits + 7) // 8

        if out_codes is None:
            out_codes = device_ndarray.empty((n_rows, n_cols), dtype="ubyte")

        out_codes_cai = wrap_array(out_codes)
        _check_input_array(out_codes_cai, [np.dtype("ubyte")],
                           exp_rows=n_rows, exp_cols=n_cols)

        cdef cydlpack.DLManagedTensor* out_codes_dlpack = \
            cydlpack.dlpack_c(out_codes_cai)

        cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

        check_cuvs(cuvsIvfPqIndexUnpackContiguousListData(res,
                                                          self.index,
                                                          out_codes_dlpack,
                                                          label,
                                                          offset))
        return out_codes

    def list_indices(self, label, n_rows=0):
        """ Gets indices for a single cluster (list)

        Parameters
        ----------
        label : int
            The cluster to get data for
        n_rows : int, optional
            Number of rows in the list. When 0, the size is looked up from
            `list_sizes`.

        Raises
        ------
        ValueError
            If the index has not been built yet.
        """
        if not self.trained:
            raise ValueError("Index needs to be built before getting"
                             " list indices")

        output = DeviceTensorView()
        cdef cydlpack.DLManagedTensor * tensor = \
            <cydlpack.DLManagedTensor*><size_t>output.get_handle()
        check_cuvs(cuvsIvfPqIndexGetListIndices(self.index, label, tensor))
        output.parent = self

        # the indices tensor being returned here is larger than the number of
        # rows in the actual list, and the remaining values are padded out
        # with -1.
        # fix this by slicing down to the number of rows in the actual list
        if n_rows == 0:
            n_rows = self.list_sizes.copy_to_host()[label]
        return output.slice_rows(0, n_rows)


@auto_sync_resources
def build(IndexParams index_params, dataset, resources=None):
    """
    Build the IvfPq index from the dataset for efficient search.

    The input dataset array can be either CUDA array interface compliant matrix
    or an array interface compliant matrix in host memory.

    Parameters
    ----------
    index_params : :py:class:`cuvs.neighbors.ivf_pq.IndexParams`
        Parameters on how to build the index
    dataset : Array interface compliant matrix shape (n_samples, dim)
        Supported dtype [float32, float16, int8, uint8]
    {resources_docstring}

    Returns
    -------
    index: :py:class:`cuvs.neighbors.ivf_pq.Index`
        The built index, marked as trained.

    Examples
    --------

    >>> import cupy as cp
    >>> from cuvs.neighbors import ivf_pq
    >>> n_samples = 50000
    >>> n_features = 50
    >>> n_queries = 1000
    >>> k = 10
    >>> dataset = cp.random.random_sample((n_samples, n_features),
    ...                                   dtype=cp.float32)
    >>> build_params = ivf_pq.IndexParams(metric="sqeuclidean")
    >>> index = ivf_pq.build(build_params, dataset)
    >>> distances, neighbors = ivf_pq.search(ivf_pq.SearchParams(),
    ...                                        index, dataset,
    ...                                        k)
    >>> distances = cp.asarray(distances)
    >>> neighbors = cp.asarray(neighbors)
    """

    # Reject unsupported dtypes before anything is handed to the C API.
    dataset_ai = wrap_array(dataset)
    _check_input_array(dataset_ai, [np.dtype('float32'),
                                    np.dtype('float16'),
                                    np.dtype('byte'),
                                    np.dtype('ubyte')])

    cdef Index idx = Index()
    cdef cydlpack.DLManagedTensor* dataset_dlpack = \
        cydlpack.dlpack_c(dataset_ai)
    cdef cuvsIvfPqIndexParams* params = index_params.params

    cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

    with cuda_interruptible():
        check_cuvs(cuvsIvfPqBuild(
            res,
            params,
            dataset_dlpack,
            idx.index
        ))
        # Only reached when the build call did not raise.
        idx.trained = True

    return idx


@auto_sync_resources
def build_precomputed(IndexParams index_params, uint32_t dim, pq_centers, centers,
                      centers_rot, rotation_matrix, resources=None):
    """
    Create a view-type IVF-PQ index on top of precomputed centroids and
    a precomputed PQ codebook.

    The returned index does not own its data: it only references the device
    arrays passed in. Every array must have the exact extents listed below,
    and the caller has to keep the arrays alive for as long as the returned
    index is in use.

    The matrices must agree with index_params:
    - index_params.codebook_kind determines the expected shape of pq_centers
    - index_params.metric will be stored in the index
    - index_params.conservative_memory_allocation will be stored in the index

    Parameters
    ----------
    index_params : :py:class:`cuvs.neighbors.ivf_pq.IndexParams`
        Parameters that must be consistent with the provided matrices
    dim : int
        Dimensionality of the input data
    pq_centers : CUDA array interface compliant tensor
        PQ codebook on device memory with required shape:
        - codebook_kind "subspace": [pq_dim, pq_len, pq_book_size]
        - codebook_kind "cluster":  [n_lists, pq_len, pq_book_size]
        Supported dtype: float32
    centers : CUDA array interface compliant matrix
        Cluster centers in the original space [n_lists, dim_ext]
        where dim_ext = round_up(dim + 1, 8).
        Supported dtype: float32
    centers_rot : CUDA array interface compliant matrix
        Rotated cluster centers [n_lists, rot_dim]
        where rot_dim = pq_len * pq_dim.
        Supported dtype: float32
    rotation_matrix : CUDA array interface compliant matrix
        Transform matrix (original space -> rotated padded space) [rot_dim, dim].
        Supported dtype: float32
    {resources_docstring}

    Returns
    -------
    index: :py:class:`cuvs.neighbors.ivf_pq.Index`

    Examples
    --------

    >>> import cupy as cp
    >>> from cuvs.neighbors import ivf_pq
    >>> n_lists = 100
    >>> dim = 128
    >>> pq_dim = 16
    >>> pq_bits = 8
    >>> pq_len = (dim + pq_dim - 1) // pq_dim  # ceil division
    >>> pq_book_size = 1 << pq_bits
    >>> rot_dim = pq_len * pq_dim
    >>> dim_ext = ((dim + 1 + 7) // 8) * 8  # round_up(dim + 1, 8)
    >>>
    >>> # Prepare precomputed matrices (example with random data)
    >>> pq_centers = cp.random.random((pq_dim, pq_len, pq_book_size),
    ...                               dtype=cp.float32)
    >>> centers = cp.random.random((n_lists, dim_ext), dtype=cp.float32)
    >>> centers_rot = cp.random.random((n_lists, rot_dim), dtype=cp.float32)
    >>> rotation_matrix = cp.random.random((rot_dim, dim), dtype=cp.float32)
    >>>
    >>> # Build index from precomputed data
    >>> build_params = ivf_pq.IndexParams(n_lists=n_lists, pq_dim=pq_dim,
    ...                                    pq_bits=pq_bits,
    ...                                    codebook_kind="subspace")
    >>> index = ivf_pq.build_precomputed(build_params, dim, pq_centers,
    ...                                   centers, centers_rot, rotation_matrix)
    """
    # Each input is wrapped and then checked for float32, one after another,
    # in argument order.
    codebook_arr = wrap_array(pq_centers)
    _check_input_array(codebook_arr, [np.dtype('float32')])

    centers_arr = wrap_array(centers)
    _check_input_array(centers_arr, [np.dtype('float32')])

    rotated_centers_arr = wrap_array(centers_rot)
    _check_input_array(rotated_centers_arr, [np.dtype('float32')])

    rotation_arr = wrap_array(rotation_matrix)
    _check_input_array(rotation_arr, [np.dtype('float32')])

    # Empty index object that the C call fills in.
    cdef Index prebuilt = Index()

    # DLPack descriptors for the four validated arrays.
    cdef cydlpack.DLManagedTensor* codebook_dl = \
        cydlpack.dlpack_c(codebook_arr)
    cdef cydlpack.DLManagedTensor* centers_dl = \
        cydlpack.dlpack_c(centers_arr)
    cdef cydlpack.DLManagedTensor* rotated_centers_dl = \
        cydlpack.dlpack_c(rotated_centers_arr)
    cdef cydlpack.DLManagedTensor* rotation_dl = \
        cydlpack.dlpack_c(rotation_arr)

    cdef cuvsIvfPqIndexParams* c_params = index_params.params
    cdef cuvsResources_t c_res = <cuvsResources_t>resources.get_c_obj()

    with cuda_interruptible():
        check_cuvs(cuvsIvfPqBuildPrecomputed(c_res,
                                             c_params,
                                             dim,
                                             codebook_dl,
                                             centers_dl,
                                             rotated_centers_dl,
                                             rotation_dl,
                                             prebuilt.index))
        prebuilt.trained = True

    return prebuilt


cdef _map_dtype_np_to_cuda(dtype, supported_dtypes=None):
    """ Translate a numpy scalar type into the matching cudaDataType_t value.

    When `supported_dtypes` is given, a `dtype` outside of it raises
    TypeError. A `dtype` that has no entry in the translation table raises
    KeyError.
    """
    has_whitelist = supported_dtypes is not None
    if has_whitelist and dtype not in supported_dtypes:
        raise TypeError("Type %s is not supported" % str(dtype))

    # Keys are the numpy scalar classes themselves, so the lookup only
    # matches those exact objects.
    np_to_cuda = {
        np.float32: cudaDataType_t.CUDA_R_32F,
        np.float16: cudaDataType_t.CUDA_R_16F,
        np.uint8: cudaDataType_t.CUDA_R_8U,
        np.int8: cudaDataType_t.CUDA_R_8I,
    }
    return np_to_cuda[dtype]


cdef class SearchParams:
    """
    Supplemental parameters to search IVF-PQ index

    Parameters
    ----------
    n_probes: int, default = 20
        The number of clusters to search.
    lut_dtype: default = np.float32
        Data type of look up table to be created dynamically at search
        time. The use of low-precision types reduces the amount of shared
        memory required at search time, so fast shared memory kernels can
        be used even for datasets with large dimensionality. Note that
        the recall is slightly degraded when low-precision type is
        selected. Possible values [np.float32, np.float16, np.uint8]
    internal_distance_dtype: default = np.float32
        Storage data type for distance/similarity computation.
        Possible values [np.float32, np.float16]
    coarse_search_dtype: default = np.float32
        [Experimental] The data type to use as the GEMM element type when
        searching the clusters to probe.
        Possible values: [np.float32, np.float16, np.int8].
        - Legacy default: np.float32
        - Recommended for performance: np.float16 (half)
        - Experimental/low-precision: np.int8
    max_internal_batch_size: default = 4096
        Set the internal batch size to improve GPU utilization at the cost
        of larger memory footprint.

    Notes
    -----
    The dtype arguments are looked up by the numpy scalar class itself
    (e.g. ``np.float16``); a value without a mapping raises ``KeyError``.
    """

    def __cinit__(self):
        # Surface allocation failures here instead of dereferencing an
        # invalid handle later in __init__.
        check_cuvs(cuvsIvfPqSearchParamsCreate(&self.params))

    def __dealloc__(self):
        # __dealloc__ also runs when __cinit__ raised, so the handle may
        # still be NULL.
        if self.params != NULL:
            check_cuvs(cuvsIvfPqSearchParamsDestroy(self.params))

    def __init__(self, *, n_probes=20, lut_dtype=np.float32,
                 internal_distance_dtype=np.float32,
                 coarse_search_dtype=np.float32,
                 max_internal_batch_size=4096):
        self.params.n_probes = n_probes
        self.params.lut_dtype = _map_dtype_np_to_cuda(lut_dtype)
        self.params.internal_distance_dtype = \
            _map_dtype_np_to_cuda(internal_distance_dtype)
        self.params.coarse_search_dtype = \
            _map_dtype_np_to_cuda(coarse_search_dtype)
        self.params.max_internal_batch_size = max_internal_batch_size

    def get_handle(self):
        """ Return the address of the underlying C parameters struct as an
        integer. """
        return <size_t>self.params

    @property
    def n_probes(self):
        return self.params.n_probes

    @property
    def lut_dtype(self):
        """ The stored look up table type, as the CUDA data type value held
        in the parameters struct (not the numpy type passed in). """
        return self.params.lut_dtype

    @property
    def internal_distance_dtype(self):
        """ The stored internal distance type, as the CUDA data type value
        held in the parameters struct. """
        return self.params.internal_distance_dtype

    @property
    def coarse_search_dtype(self):
        """ The stored coarse search type, as the CUDA data type value held
        in the parameters struct. """
        return self.params.coarse_search_dtype

    @property
    def max_internal_batch_size(self):
        return self.params.max_internal_batch_size


@auto_sync_resources
@auto_convert_output
def search(SearchParams search_params,
           Index index,
           queries,
           k,
           neighbors=None,
           distances=None,
           resources=None):
    """
    Find the k nearest neighbors for each query.

    Parameters
    ----------
    search_params : :py:class:`cuvs.neighbors.ivf_pq.SearchParams`
        Parameters on how to search the index
    index : :py:class:`cuvs.neighbors.ivf_pq.Index`
        Trained IvfPq index.
    queries : CUDA array interface compliant matrix shape (n_queries, dim)
        Supported dtype [float32, float16, int8, uint8]
    k : int
        The number of neighbors.
    neighbors : Optional CUDA array interface compliant matrix shape
                (n_queries, k), dtype int64_t. If supplied, neighbor
                indices will be written here in-place. (default None)
    distances : Optional CUDA array interface compliant matrix shape
                (n_queries, k), dtype float32. If supplied, the distances
                to the neighbors will be written here in-place.
                (default None)
    {resources_docstring}

    Returns
    -------
    distances, neighbors
        The distances and the neighbor indices, in that order. These are
        the arrays passed in, or newly allocated device arrays when the
        corresponding argument was None. NOTE(review): the
        auto_convert_output decorator may convert the returned arrays;
        see pylibraft.

    Raises
    ------
    ValueError
        If the index has not been built yet.

    Examples
    --------
    >>> import cupy as cp
    >>> from cuvs.neighbors import ivf_pq
    >>> n_samples = 50000
    >>> n_features = 50
    >>> n_queries = 1000
    >>> dataset = cp.random.random_sample((n_samples, n_features),
    ...                                   dtype=cp.float32)
    >>> # Build the index
    >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset)
    >>>
    >>> # Search using the built index
    >>> queries = cp.random.random_sample((n_queries, n_features),
    ...                                   dtype=cp.float32)
    >>> k = 10
    >>> search_params = ivf_pq.SearchParams(n_probes=20)
    >>>
    >>> distances, neighbors = ivf_pq.search(search_params, index, queries,
    ...                                     k)
    """
    if not index.trained:
        raise ValueError("Index needs to be built before calling search.")

    queries_cai = wrap_array(queries)
    _check_input_array(queries_cai, [np.dtype('float32'),
                                     np.dtype('float16'),
                                     np.dtype('byte'),
                                     np.dtype('ubyte')])

    cdef uint32_t n_queries = queries_cai.shape[0]

    # Output buffers are allocated on demand; user supplied ones must
    # already have the (n_queries, k) shape and the expected dtype.
    if neighbors is None:
        neighbors = device_ndarray.empty((n_queries, k), dtype='int64')

    neighbors_cai = wrap_array(neighbors)
    _check_input_array(neighbors_cai, [np.dtype('int64')],
                       exp_rows=n_queries, exp_cols=k)

    if distances is None:
        distances = device_ndarray.empty((n_queries, k), dtype='float32')

    distances_cai = wrap_array(distances)
    _check_input_array(distances_cai, [np.dtype('float32')],
                       exp_rows=n_queries, exp_cols=k)

    cdef cuvsIvfPqSearchParams* params = search_params.params
    cdef cydlpack.DLManagedTensor* queries_dlpack = \
        cydlpack.dlpack_c(queries_cai)
    cdef cydlpack.DLManagedTensor* neighbors_dlpack = \
        cydlpack.dlpack_c(neighbors_cai)
    cdef cydlpack.DLManagedTensor* distances_dlpack = \
        cydlpack.dlpack_c(distances_cai)
    cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

    with cuda_interruptible():
        check_cuvs(cuvsIvfPqSearch(
            res,
            params,
            index.index,
            queries_dlpack,
            neighbors_dlpack,
            distances_dlpack
        ))

    return (distances, neighbors)


@auto_sync_resources
def save(filename, Index index, bool include_dataset=True, resources=None):
    """
    Serialize an index into a file.

    Saving / loading the index is experimental. The serialization format is
    subject to change.

    Parameters
    ----------
    filename : string
        Name of the file. It is encoded as UTF-8 before being handed to the
        serializer.
    index : Index
        Trained IVF-PQ index.
    include_dataset : bool, default = True
        Accepted as part of the signature, but not used by this function.
    {resources_docstring}

    Examples
    --------
    >>> import cupy as cp
    >>> from cuvs.neighbors import ivf_pq
    >>> n_samples = 50000
    >>> n_features = 50
    >>> dataset = cp.random.random_sample((n_samples, n_features),
    ...                                   dtype=cp.float32)
    >>> # Build index
    >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset)
    >>> # Serialize and deserialize the ivf_pq index built
    >>> ivf_pq.save("my_index.bin", index)
    >>> index_loaded = ivf_pq.load("my_index.bin")
    """
    cdef string encoded_path = filename.encode('utf-8')
    cdef cuvsResources_t c_res = <cuvsResources_t>resources.get_c_obj()

    status = cuvsIvfPqSerialize(c_res, encoded_path.c_str(), index.index)
    check_cuvs(status)


@auto_sync_resources
def load(filename, resources=None):
    """
    Deserialize an IVF-PQ index from a file.

    Saving / loading the index is experimental. The serialization format is
    subject to change, therefore loading an index saved with a previous
    version of cuvs is not guaranteed to work.

    Parameters
    ----------
    filename : string
        Name of the file to read. It is UTF-8 encoded before being handed
        to the C API.
    {resources_docstring}

    Returns
    -------
    index : Index
        A newly created Index filled in by the deserialization call and
        flagged as trained.

    """
    cdef Index loaded = Index()
    cdef cuvsResources_t handle = <cuvsResources_t>resources.get_c_obj()
    # Keep the encoded path in a std::string local so the pointer returned
    # by c_str() stays valid for the duration of the C call below.
    cdef string encoded_path = filename.encode('utf-8')
    cdef cuvsError_t status = cuvsIvfPqDeserialize(
        handle,
        encoded_path.c_str(),
        loaded.index,
    )
    check_cuvs(status)

    # Only reached when check_cuvs did not raise.
    loaded.trained = True
    return loaded


@auto_sync_resources
def extend(Index index, new_vectors, new_indices, resources=None):
    """
    Add new vectors, with their ids, to an existing index.

    The input array can be either CUDA array interface compliant matrix or
    array interface compliant matrix in host memory.


    Parameters
    ----------
    index : ivf_pq.Index
        Trained ivf_pq object.
    new_vectors : array interface compliant matrix shape (n_samples, dim)
        Supported dtype [float32, float16, int8, uint8]
    new_indices : array interface compliant vector shape (n_samples)
        Supported dtype [int64]
    {resources_docstring}

    Returns
    -------
    index: py:class:`cuvs.neighbors.ivf_pq.Index`
        The same ``index`` object that was passed in.

    Examples
    --------

    >>> import cupy as cp
    >>> from cuvs.neighbors import ivf_pq
    >>> n_samples = 50000
    >>> n_features = 50
    >>> n_queries = 1000
    >>> dataset = cp.random.random_sample((n_samples, n_features),
    ...                                   dtype=cp.float32)
    >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset)
    >>> n_rows = 100
    >>> more_data = cp.random.random_sample((n_rows, n_features),
    ...                                     dtype=cp.float32)
    >>> indices = n_samples + cp.arange(n_rows, dtype=cp.int64)
    >>> index = ivf_pq.extend(index, more_data, indices)
    >>> # Search using the built index
    >>> queries = cp.random.random_sample((n_queries, n_features),
    ...                                   dtype=cp.float32)
    >>> distances, neighbors = ivf_pq.search(ivf_pq.SearchParams(),
    ...                                      index, queries,
    ...                                      k=10)
    """

    # Validate the vectors first, then the ids, so the first failing input
    # is the one reported.
    vectors_ai = wrap_array(new_vectors)
    vector_dtypes = [
        np.dtype('float32'),
        np.dtype('float16'),
        np.dtype('byte'),
        np.dtype('ubyte'),
    ]
    _check_input_array(vectors_ai, vector_dtypes)

    ids_ai = wrap_array(new_indices)
    _check_input_array(ids_ai, [np.dtype('int64')])

    cdef cuvsResources_t handle = <cuvsResources_t>resources.get_c_obj()

    # DLPack descriptors for the two wrapped arrays, as taken by the C API.
    cdef cydlpack.DLManagedTensor* vectors_dlpack = \
        cydlpack.dlpack_c(vectors_ai)
    cdef cydlpack.DLManagedTensor* ids_dlpack = \
        cydlpack.dlpack_c(ids_ai)
    cdef cuvsError_t status

    with cuda_interruptible():
        status = cuvsIvfPqExtend(
            handle,
            vectors_dlpack,
            ids_dlpack,
            index.index,
        )
        check_cuvs(status)

    return index


@auto_sync_resources
def transform(Index index, input_dataset, output_labels=None,
              output_dataset=None, resources=None):
    """
    Transform a dataset by applying pq-encoding to the vectors.


    Parameters
    ----------
    index : ivf_pq.Index
        Trained ivf_pq object.
    input_dataset : array interface compliant matrix shape (n_samples, dim)
        Supported dtype [float32, float16, int8, uint8]
    output_labels : Optional array interface compliant vector shape (n_samples)
        Supported dtype [uint32]. When omitted, a new array is created with
        ``device_ndarray.empty``.
    output_dataset : Optional array interface compliant matrix shape
        (n_samples, ceil(pq_dim * pq_bits / 8))
        Supported dtype [uint8]. When omitted, a new array is created with
        ``device_ndarray.empty``.
    {resources_docstring}

    Returns
    -------
    output_labels, output_dataset:
        The cluster that each point in the dataset belongs to, and the
        transformed dataset. If output arrays were passed in, those same
        objects are returned.
    """

    input_dataset_ai = wrap_array(input_dataset)
    _check_input_array(input_dataset_ai,
                       [np.dtype('float32'), np.dtype('float16'),
                        np.dtype('int8'), np.dtype('uint8')])

    cdef uint32_t n_samples = input_dataset_ai.shape[0]

    # One label per input row.
    if output_labels is None:
        output_labels = device_ndarray.empty(n_samples, dtype='uint32')
    output_labels_ai = wrap_array(output_labels)
    _check_input_array(output_labels_ai, [np.dtype('uint32')],
                       exp_rows=n_samples)

    # Each encoded row holds pq_dim codes of pq_bits bits each; round the
    # total bit count up to a whole number of bytes.
    n_output_cols = int(np.ceil(index.pq_dim * index.pq_bits / 8))

    if output_dataset is None:
        output_dataset = device_ndarray.empty((n_samples, n_output_cols),
                                              dtype='uint8')
    output_dataset_ai = wrap_array(output_dataset)
    _check_input_array(output_dataset_ai, [np.dtype('uint8')],
                       exp_rows=n_samples, exp_cols=n_output_cols)

    cdef cuvsResources_t res = <cuvsResources_t>resources.get_c_obj()

    # DLPack descriptors for the wrapped arrays, as taken by the C API.
    cdef cydlpack.DLManagedTensor* input_dataset_dlpack = \
        cydlpack.dlpack_c(input_dataset_ai)
    cdef cydlpack.DLManagedTensor* output_labels_dlpack = \
        cydlpack.dlpack_c(output_labels_ai)
    cdef cydlpack.DLManagedTensor* output_dataset_dlpack = \
        cydlpack.dlpack_c(output_dataset_ai)

    with cuda_interruptible():
        check_cuvs(cuvsIvfPqTransform(
            res,
            index.index,
            input_dataset_dlpack,
            output_labels_dlpack,
            output_dataset_dlpack
        ))

    return output_labels, output_dataset
