#
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
#
from libcpp cimport bool
from pylibraft.common.handle cimport handle_t

from cuml.internals.logger cimport level_enum


# Cython view of the C++ header cuml/matrix/kernel_params.hpp (namespace
# ML::matrix). For ``cdef extern`` declarations Cython does not emit its own
# definitions; it relies on the header, so every name below must match the
# header exactly. ``nogil`` lets these declarations be used without the GIL.
cdef extern from "cuml/matrix/kernel_params.hpp" namespace "ML::matrix" nogil:

    # Scoped C++ enum (``enum class``) naming the kernel function to use.
    enum class KernelType:
        LINEAR, POLYNOMIAL, RBF, TANH, PRECOMPUTED

    # Kernel configuration; passed by (non-const) reference to the SVC/SVR
    # entry points declared in this file.
    cdef struct KernelParams:
        # Which kernel to evaluate.
        KernelType kernel
        # NOTE(review): degree/gamma/coef0 are presumably the conventional
        # polynomial/RBF/tanh hyper-parameters; which kernel reads which
        # field is decided in the C++ implementation — confirm there.
        int degree
        double gamma
        double coef0


# Cython view of the C++ header cuml/svm/svm_parameter.h (namespace ML::SVM).
# Only the members listed here are visible from Cython; the actual struct
# layout comes from the header.
cdef extern from "cuml/svm/svm_parameter.h" namespace "ML::SVM" nogil:

    # Unscoped C++ enum selecting the SVM formulation (classification vs.
    # regression variants).
    enum SvmType:
        C_SVC,
        NU_SVC,
        EPSILON_SVR,
        NU_SVR

    # Solver configuration; passed by const reference to svrFit/svrFitSparse
    # below.
    cdef struct SvmParameter:
        # NOTE(review): meanings below are inferred from the field names
        # (penalty C, kernel-cache size, iteration limits, stopping
        # tolerance, SVR epsilon); units (e.g. of cache_size) are defined
        # in the C++ header — confirm there.
        double C
        double cache_size
        int max_outer_iter
        int max_iter
        int nochange_steps
        double tol
        # Log level type, cimported from cuml.internals.logger.
        level_enum verbosity
        double epsilon
        # Which SVM formulation to solve.
        SvmType svmType


# Cython view of the C++ header cuml/svm/svm_model.h (namespace ML::SVM).
# Both classes are C++ templates over the floating type ``math_t``.
cdef extern from "cuml/svm/svm_model.h" namespace "ML::SVM" nogil:

    # Storage for the support vectors. The indptr/indices/data/nnz members
    # look like a CSR sparse layout.
    # NOTE(review): how dense storage is represented (e.g. a sentinel nnz
    # with null indptr/indices) is defined in svm_model.h — confirm there.
    cdef cppclass SupportStorage[math_t]:
        int nnz
        int* indptr
        int* indices
        math_t* data

    # Trained SVM model as exposed by the C++ library. Raw pointer members
    # are not owned by Cython.
    # NOTE(review): presumably these point to device memory released via
    # svmFreeBuffers (declared in svc.hpp) — confirm ownership there.
    cdef cppclass SvmModel[math_t]:
        # Number of support vectors and number of feature columns.
        int n_support
        int n_cols
        # Bias / intercept term.
        math_t b
        math_t *dual_coefs
        SupportStorage[math_t] support_matrix
        int *support_idx
        int n_classes
        math_t *unique_labels


# Cython view of the prediction / cleanup entry points in cuml/svm/svc.hpp
# (namespace ML::SVM). ``except +`` makes Cython translate any C++ exception
# thrown by these functions into a Python exception.
cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM" nogil:

    # Predict for a dense input matrix of n_rows x n_cols, writing results
    # into ``preds``.
    # NOTE(review): buffer_size is declared with the floating type math_t,
    # not an integer type; it must match svc.hpp exactly — confirm.
    # NOTE(review): predict_class presumably chooses class labels vs. raw
    # decision-function values — confirm in svc.hpp.
    cdef void svcPredict[math_t](
        const handle_t &handle,
        math_t* data,
        int n_rows,
        int n_cols,
        KernelParams &kernel_params,
        const SvmModel[math_t] &model,
        math_t *preds,
        math_t buffer_size,
        bool predict_class,
    ) except +

    # Sparse-input variant of svcPredict; the input is given as
    # indptr/indices/data arrays plus nnz (CSR-like layout).
    cdef void svcPredictSparse[math_t](
        const handle_t &handle,
        int* indptr,
        int* indices,
        math_t* data,
        int n_rows,
        int n_cols,
        int nnz,
        KernelParams &kernel_params,
        const SvmModel[math_t] &model,
        math_t *preds,
        math_t buffer_size,
        bool predict_class,
    ) except +

    # Takes the model by non-const reference.
    # NOTE(review): presumably frees the buffers referenced by the model's
    # pointer members — confirm in svc.hpp before relying on it.
    cdef void svmFreeBuffers[math_t](
        const handle_t &handle,
        SvmModel[math_t] &m,
    ) except +


# Cython view of the regression training entry points in cuml/svm/svr.hpp
# (namespace ML::SVM). ``except +`` makes Cython translate any C++ exception
# thrown by these functions into a Python exception. (Written with a space,
# as on every other declaration in this file; ``except+`` is equivalent.)
cdef extern from "cuml/svm/svr.hpp" namespace "ML::SVM" nogil:

    # Fit an SVR model on a dense n_rows x n_cols input with targets ``y``;
    # the trained state is written into ``model`` (non-const reference).
    # NOTE(review): the meaning of the int return value and whether
    # sample_weight may be null are defined in svr.hpp — confirm there.
    cdef int svrFit[math_t](
        const handle_t &handle,
        math_t* data,
        int n_rows,
        int n_cols,
        math_t *y,
        const SvmParameter &param,
        KernelParams &kernel_params,
        SvmModel[math_t] &model,
        const math_t *sample_weight,
    ) except +

    # Sparse-input variant of svrFit; the input is given as
    # indptr/indices/data arrays plus nnz (CSR-like layout).
    cdef int svrFitSparse[math_t](
        const handle_t &handle,
        int* indptr,
        int* indices,
        math_t* data,
        int n_rows,
        int n_cols,
        int nnz,
        math_t *y,
        const SvmParameter &param,
        KernelParams &kernel_params,
        SvmModel[math_t] &model,
        const math_t *sample_weight,
    ) except +
