# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
import numpy as np

from cuml.internals import get_handle, reflect
from cuml.internals.array import CumlArray
from cuml.internals.input_utils import input_to_cuml_array

from libc.stdint cimport uintptr_t
from libcpp cimport bool as boolcpp
from pylibraft.common.handle cimport handle_t


cdef extern from "cuml/tsa/stationarity.h" namespace "ML" nogil:
    # Two overloaded declarations of the C++ entry point
    # ML::Stationarity::kpss_test, exposed to Cython as ``cpp_kpss``: one for
    # single-precision and one for double-precision data. Cython selects the
    # overload from the pointer type passed as ``d_y`` (and the type of
    # ``pval_threshold``). ``except +`` makes Cython translate a C++ exception
    # thrown by the callee into a Python exception.
    #
    # Arguments as used by the ``kpss_test`` wrapper in this file:
    #   handle          cuML/RAFT resource handle to run on
    #   d_y             pointer to the input data buffer (the CumlArray built
    #                   from ``y``); presumably device memory -- confirm in
    #                   the C++ header
    #   results         pointer to ``batch_size`` booleans that receive the
    #                   per-series stationarity outcome
    #   batch_size      number of series (columns of the input)
    #   n_obs           number of observations per series (rows of the input)
    #   d, D, s         simple differencing order, seasonal differencing
    #                   order, and seasonal period
    #   pval_threshold  p-value above which a series is deemed stationary
    # NOTE(review): the ``int`` return value is ignored by the Python
    # wrapper; confirm its meaning against the C++ header.
    int cpp_kpss "ML::Stationarity::kpss_test" (
        const handle_t& handle,
        const float* d_y,
        boolcpp* results,
        int batch_size,
        int n_obs,
        int d, int D, int s,
        float pval_threshold) except +

    # Double-precision overload; same contract as the float version above.
    int cpp_kpss "ML::Stationarity::kpss_test" (
        const handle_t& handle,
        const double* d_y,
        boolcpp* results,
        int batch_size,
        int n_obs,
        int d, int D, int s,
        double pval_threshold) except +


@reflect
def kpss_test(y, d=0, D=0, s=0, pval_threshold=0.05,
              handle=None, convert_dtype=True) -> CumlArray:
    """
    Perform the KPSS stationarity test on the data differenced according
    to the given order.

    Parameters
    ----------
    y : dataframe or array-like (device or host)
        The time series data, assumed to have each time series in columns.
        Acceptable formats: cuDF DataFrame, cuDF Series, NumPy ndarray,
        Numba device ndarray, cuda array interface compliant array like CuPy.
    d : int, default=0
        Order of simple differencing. Must be non-negative.
    D : int, default=0
        Order of seasonal differencing. Must be non-negative.
    s : int, default=0
        Seasonal period. Only used when ``D > 0``, in which case it must be
        at least 1.
    pval_threshold : float, default=0.05
        The p-value threshold above which a series is considered stationary.
    handle : cuml.Handle or None, default=None

        .. deprecated:: 26.02
            The `handle` argument was deprecated in 26.02 and will be removed
            in 26.04. There's no need to pass in a handle, cuml now manages
            this resource automatically.

    convert_dtype : bool, default=True
        If True, the input is converted to float32 before running the test.
        If False, the input must already be float32 or float64 and the test
        runs in that precision; other dtypes are rejected by the input dtype
        check.

    Returns
    -------
    stationarity : array-like of bool, shape (batch_size,)
        The stationarity test result for each series (column) in the batch.
        Built as a ``CumlArray``; the ``reflect`` decorator may convert it to
        the configured output type.

    Raises
    ------
    ValueError
        If ``d`` or ``D`` is negative, if ``D > 0`` and ``s < 1``, or if the
        series are too short to be differenced (``n_obs <= d + D * s``).
    TypeError
        If the converted input is neither float32 nor float64.
    """
    # Validate the differencing orders before any conversion or device work.
    # Otherwise the values would be cast to C ints and handed to the C++
    # kernel unchecked.
    if d < 0 or D < 0:
        raise ValueError(
            f"Differencing orders must be non-negative, got d={d}, D={D}")
    if D > 0 and s < 1:
        raise ValueError(
            f"Seasonal period s must be >= 1 when D > 0, got s={s}")

    # Rows are observations and columns are individual series, so the
    # second dimension is the batch size.
    d_y, n_obs, batch_size, dtype = \
        input_to_cuml_array(y,
                            convert_to_dtype=(np.float32 if convert_dtype
                                              else None),
                            check_dtype=[np.float32, np.float64])

    # Each differencing step consumes observations (d for simple, s per
    # seasonal order). At least one observation must remain for the test.
    if n_obs - d - D * s <= 0:
        raise ValueError(
            f"Series of length n_obs={n_obs} are too short for differencing "
            f"with d={d}, D={D}, s={s}")

    cdef uintptr_t d_y_ptr = d_y.ptr

    handle = get_handle(handle=handle)
    cdef handle_t* handle_ = <handle_t*><size_t>handle.getHandle()

    # One boolean outcome per series, filled in by the C++ routine.
    results = CumlArray.empty(batch_size, dtype=bool)
    cdef uintptr_t d_results = results.ptr

    # Dispatch to the C++ overload matching the input precision.
    if dtype == np.float32:
        cpp_kpss(handle_[0],
                 <float*> d_y_ptr,
                 <boolcpp*> d_results,
                 <int> batch_size,
                 <int> n_obs,
                 <int> d, <int> D, <int> s,
                 <float> pval_threshold)
    elif dtype == np.float64:
        cpp_kpss(handle_[0],
                 <double*> d_y_ptr,
                 <boolcpp*> d_results,
                 <int> batch_size,
                 <int> n_obs,
                 <int> d, <int> D, <int> s,
                 <double> pval_threshold)
    else:
        # Guard against silently returning the uninitialised results buffer.
        raise TypeError(f"Unsupported dtype for kpss_test: {dtype}")

    return results
