# SPDX-FileCopyrightText: 2009-2022 the scikit-image team
# SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause

import cupy as cp

from .._shared import utils


def _match_cumulative_cdf(source, template):
    """
    Return modified source array so that the cumulative density function of
    its values matches the cumulative density function of the template.
    """
    if source.dtype.kind == "u":
        src_lookup = source.reshape(-1)
        src_counts = cp.bincount(src_lookup)
        tmpl_counts = cp.bincount(template.reshape(-1))

        # omit values where the count was 0
        tmpl_values = cp.nonzero(tmpl_counts)[0]
        tmpl_counts = tmpl_counts[tmpl_values]
    else:
        src_values, src_lookup, src_counts = cp.unique(
            source.reshape(-1), return_inverse=True, return_counts=True
        )
        tmpl_values, tmpl_counts = cp.unique(
            template.reshape(-1), return_counts=True
        )

    # calculate normalized quantiles for each array
    src_quantiles = cp.cumsum(src_counts) / source.size
    tmpl_quantiles = cp.cumsum(tmpl_counts) / template.size

    interp_a_values = cp.interp(src_quantiles, tmpl_quantiles, tmpl_values)
    return interp_a_values[src_lookup].reshape(source.shape)


@utils.channel_as_last_axis(channel_arg_positions=(0, 1))
def match_histograms(image, reference, *, channel_axis=None):
    """Adjust an image so that its cumulative histogram matches that of another.

    The adjustment is applied separately for each channel.

    Parameters
    ----------
    image : ndarray
        Input image. Can be gray-scale or in color.
    reference : ndarray
        Image to match histogram of. Must have the same number of dimensions
        and the same number of channels as `image`; the lengths of the other
        axes may differ.
    channel_axis : int or None, optional
        If None, the image is assumed to be a grayscale (single channel) image.
        Otherwise, this parameter indicates which axis of the array corresponds
        to channels.

    Returns
    -------
    matched : ndarray
        Transformed input image. Floating-point inputs give a floating-point
        result (float16 is promoted to float32). Without `channel_axis` the
        result is always floating point; with `channel_axis` and a
        non-floating input, the matched values are cast back to
        ``image.dtype``.

    Raises
    ------
    ValueError
        Thrown when the input image and the reference differ in number of
        dimensions or in number of channels.

    References
    ----------
    .. [1] http://paulbourke.net/miscellaneous/equalisation/

    """
    if image.ndim != reference.ndim:
        raise ValueError(
            "Image and reference must have the same number of channels."
        )

    if channel_axis is not None:
        # The body works with channels on the last axis (see the ``[...,
        # channel]`` indexing below), which `channel_as_last_axis` arranges.
        # Index the channel count with -1 rather than `channel_axis` so the
        # check agrees with the loop. NOTE(review): presumably the decorator
        # can forward a one-element tuple such as ``(-1,)`` unchanged, which
        # ``shape[channel_axis]`` could not index — confirm.
        n_channels = image.shape[-1]
        if reference.shape[-1] != n_channels:
            raise ValueError(
                "Number of channels in the input image and "
                "reference image must match!"
            )

        # Store floating-point results at the final output precision so a
        # float16 image is not rounded to float16 before being promoted to
        # float32 below. Other dtypes keep the original behaviour.
        if image.dtype.kind == "f":
            work_dtype = utils._supported_float_type(image.dtype)
        else:
            work_dtype = image.dtype
        matched = cp.empty(image.shape, dtype=work_dtype)
        for channel in range(n_channels):
            matched[..., channel] = _match_cumulative_cdf(
                image[..., channel], reference[..., channel]
            )
    else:
        # _match_cumulative_cdf returns the floating-point output of cp.interp
        matched = _match_cumulative_cdf(image, reference)

    if matched.dtype.kind == "f":
        # output a float32 result when the input is float16 or float32
        out_dtype = utils._supported_float_type(image.dtype)
        matched = matched.astype(out_dtype, copy=False)
    return matched
