/**
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 * SPDX-License-Identifier: Apache-2.0
 */
#pragma once

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <type_traits>

#include <nvtx3/nvtx3.hpp>

#include <rapidsmpf/utils.hpp>

/**
 * @brief Helper function to convert an integral value to a 64-bit signed integer.
 *
 * @tparam T Integral type of the input value.
 * @param value The value to convert.
 * @return `value` converted to `std::int64_t`.
 *
 * @throws std::overflow_error If `value` is larger than the maximum value of
 * `std::int64_t` (only possible when `T` has a wider positive range, e.g.
 * `std::uint64_t`).
 */
template <typename T>
    requires std::is_integral_v<T>
[[nodiscard]] constexpr std::int64_t convert_to_64bit(T value) {
    constexpr auto int64_max = std::numeric_limits<std::int64_t>::max();
    // The range check is only compiled for types that can exceed std::int64_t.
    if constexpr (std::numeric_limits<T>::max() > int64_max) {
        // In this branch T can represent int64_max, so the cast is lossless and the
        // comparison is performed in T without an implicit signed/unsigned conversion.
        if (value > static_cast<T>(int64_max)) {
            throw std::overflow_error(
                "convert_to_64bit(x): x too large to fit std::int64_t"
            );
        }
    }
    return static_cast<std::int64_t>(value);
}

/**
 * @brief Helper function to convert a floating-point value to a 64-bit float.
 *
 * @tparam T Floating-point type of the input value.
 * @param value The value to convert.
 * @return `value` converted to `double`.
 */
template <typename T>
    requires std::is_floating_point_v<T>
[[nodiscard]] constexpr double convert_to_64bit(T value) {
    return static_cast<double>(value);
}

/**
 * @brief Tag type identifying rapidsmpf's NVTX domain.
 *
 * Passed as the domain template argument to the NVTX3 types and functions used by
 * the macros in this file.
 */
struct rapidsmpf_domain {
    /// NVTX domain name.
    static constexpr const char* name = "rapidsmpf";
};

// Creates a static string registered in the `rapidsmpf` NVTX domain and yields a
// reference to it. The registered string lives in a function-local static inside an
// immediately-invoked lambda, so it cannot clash with any other name in the scope
// where the macro is used.
//
// Because the string is a static, it is constructed from `msg` only the first time
// the expansion is executed; `msg` is expected to be the same on every execution.
#define RAPIDSMPF_REGISTER_STRING(msg)                                          \
    [](char const* text) -> nvtx3::registered_string_in<rapidsmpf_domain>& {    \
        static nvtx3::registered_string_in<rapidsmpf_domain> registered{text};  \
        return registered;                                                      \
    }(msg)

// Implementation of the function-range macro with a payload value.
//
// Expands to a compile-time check that `val` is arithmetic, followed by the
// declaration of an `nvtx3::scoped_range_in<rapidsmpf_domain>` whose message is the
// enclosing function's `__func__` and whose payload is `convert_to_64bit(val)`.
// The variable name is built with RAPIDSMPF_CONCAT (defined outside this file) from
// `_rapidsmpf_nvtx_range` and `__LINE__`.
//
// The reference is stripped from `decltype(val)` before the check: for expressions
// such as a `T const&` parameter, `vec[i]` or `*ptr`, `decltype` yields a reference
// type, for which `std::is_arithmetic_v` is false even though the value is a valid
// payload.
#define RAPIDSMPF_NVTX_FUNC_RANGE_IMPL_WITH_VAL(val)                       \
    static_assert(                                                         \
        std::is_arithmetic_v<std::remove_reference_t<decltype(val)>>,      \
        "Value must be integral or floating point type"                    \
    );                                                                     \
    nvtx3::scoped_range_in<rapidsmpf_domain> RAPIDSMPF_CONCAT(             \
        _rapidsmpf_nvtx_range, __LINE__                                    \
    ) {                                                                    \
        nvtx3::event_attributes {                                          \
            RAPIDSMPF_REGISTER_STRING(__func__), nvtx3::payload {          \
                convert_to_64bit(val)                                      \
            }                                                              \
        }                                                                  \
    }

// Implementation of the function-range macro without a payload value: forwards to
// NVTX3's NVTX3_FUNC_RANGE_IN macro with the `rapidsmpf` domain tag.
#define RAPIDSMPF_NVTX_FUNC_RANGE_IMPL_WITHOUT_VAL() NVTX3_FUNC_RANGE_IN(rapidsmpf_domain)

// Macro selector for 0 vs 1 arguments: ignores its first two arguments and expands
// to the third. The caller arranges the argument list so that the third position
// holds the name of the implementation macro matching the number of user arguments.
#define RAPIDSMPF_GET_MACRO_FUNC(_0, _1, NAME, ...) NAME

// Dispatches on the number of arguments (0 or 1) and invokes the matching
// implementation macro with those arguments.
//
// A placeholder token is always passed first; `__VA_OPT__` adds the separating comma
// only when an argument is present. With no argument the selector therefore picks
// RAPIDSMPF_NVTX_FUNC_RANGE_IMPL_WITHOUT_VAL, and with one argument it picks
// RAPIDSMPF_NVTX_FUNC_RANGE_IMPL_WITH_VAL.
#define RAPIDSMPF_NVTX_FUNC_RANGE_IMPL(...)        \
    RAPIDSMPF_GET_MACRO_FUNC(                      \
        dummy __VA_OPT__(, ) __VA_ARGS__,          \
        RAPIDSMPF_NVTX_FUNC_RANGE_IMPL_WITH_VAL,   \
        RAPIDSMPF_NVTX_FUNC_RANGE_IMPL_WITHOUT_VAL \
    )                                              \
    (__VA_ARGS__)

/**
 * @brief Convenience macro for generating an NVTX range in the `rapidsmpf` domain
 * from the lifetime of a function.
 *
 * The name of the immediately enclosing function returned by `__func__` is used as
 * the message.
 *
 * Usage:
 * - `RAPIDSMPF_NVTX_FUNC_RANGE()` - Annotate with function name only
 * - `RAPIDSMPF_NVTX_FUNC_RANGE(payload)` - Annotate with function name and payload
 *
 * The optional argument is the payload to annotate (integral or floating-point value).
 * It is converted with `convert_to_64bit`, so an integral payload larger than the
 * maximum `std::int64_t` throws `std::overflow_error`.
 *
 * The macro expands to declarations, so it must be used as a statement inside a
 * function body.
 *
 * Example:
 * ```
 * void some_function(){
 *    RAPIDSMPF_NVTX_FUNC_RANGE();        // `some_function` is used as the message
 *    RAPIDSMPF_NVTX_FUNC_RANGE(42);      // With payload
 *    ...
 * }
 * ```
 */
#define RAPIDSMPF_NVTX_FUNC_RANGE(...) RAPIDSMPF_NVTX_FUNC_RANGE_IMPL(__VA_ARGS__)

// Implementation of the scoped-range macro with a payload value.
//
// Declares an `nvtx3::scoped_range_in<rapidsmpf_domain>` whose message is the
// registered string for `msg` and whose payload is `convert_to_64bit(val)`. The
// variable name is built with RAPIDSMPF_CONCAT (defined outside this file) from
// `_rapidsmpf_nvtx_range` and `__LINE__`.
#define RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL_WITH_VAL(msg, val)                   \
    nvtx3::scoped_range_in<rapidsmpf_domain>                                  \
        RAPIDSMPF_CONCAT(_rapidsmpf_nvtx_range, __LINE__){                    \
            nvtx3::event_attributes{                                          \
                RAPIDSMPF_REGISTER_STRING(msg),                               \
                nvtx3::payload{convert_to_64bit(val)}                         \
            }                                                                 \
        }

// Implementation of the scoped-range macro without a payload value.
//
// Declares an `nvtx3::scoped_range_in<rapidsmpf_domain>` whose only attribute is the
// registered string for `msg`. The variable name is built with RAPIDSMPF_CONCAT
// (defined outside this file) from `_rapidsmpf_nvtx_range` and `__LINE__`.
#define RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL_WITHOUT_VAL(msg)                     \
    nvtx3::scoped_range_in<rapidsmpf_domain>                                  \
        RAPIDSMPF_CONCAT(_rapidsmpf_nvtx_range, __LINE__){                    \
            nvtx3::event_attributes{RAPIDSMPF_REGISTER_STRING(msg)}           \
        }

// Macro to detect number of arguments (1 or 2): ignores its first two arguments and
// expands to the third. The caller appends the candidate implementation macros after
// the user arguments so that the third position holds the matching one.
#define RAPIDSMPF_GET_MACRO(_1, _2, NAME, ...) NAME

// Dispatches on the number of arguments (1 or 2) and invokes the matching
// implementation macro with those arguments: one argument selects
// RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL_WITHOUT_VAL, two arguments select
// RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL_WITH_VAL.
#define RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL(...)                                   \
    RAPIDSMPF_GET_MACRO(__VA_ARGS__, RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL_WITH_VAL, \
                        RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL_WITHOUT_VAL)           \
    (__VA_ARGS__)

/**
 * @brief Convenience macro for generating an NVTX scoped range in the `rapidsmpf` domain
 * to annotate a time duration.
 *
 * Usage:
 * - `RAPIDSMPF_NVTX_SCOPED_RANGE(message)` - Annotate with message only
 * - `RAPIDSMPF_NVTX_SCOPED_RANGE(message, payload)` - Annotate with message and payload
 *
 * The first argument is the message to annotate (const char*). The message is
 * registered once per use of the macro (it is stored in a function-local static), so
 * it should be the same on every execution, e.g. a string literal.
 * The second argument (optional) is the payload to annotate (integral or floating-point
 * value). It is converted with `convert_to_64bit`, so an integral payload larger than
 * the maximum `std::int64_t` throws `std::overflow_error`.
 *
 * The macro expands to a variable declaration, so it must be used as a statement.
 *
 * Example:
 * ```
 * void some_function(){
 *    RAPIDSMPF_NVTX_SCOPED_RANGE("my function");        // Without payload
 *    RAPIDSMPF_NVTX_SCOPED_RANGE("my function", 42);    // With payload
 *    ...
 * }
 * ```
 */
#define RAPIDSMPF_NVTX_SCOPED_RANGE(...) RAPIDSMPF_NVTX_SCOPED_RANGE_IMPL(__VA_ARGS__)

/**
 * @brief Convenience macro for generating an NVTX scoped range in the `rapidsmpf` domain
 * that is only active when RAPIDSMPF_VERBOSE_INFO is enabled.
 *
 * This macro behaves identically to RAPIDSMPF_NVTX_SCOPED_RANGE, but only creates
 * the NVTX range when the RAPIDSMPF_VERBOSE_INFO compile-time flag evaluates to a
 * non-zero value in the preprocessor (`#if RAPIDSMPF_VERBOSE_INFO`). An undefined flag
 * or a flag defined as `0` disables it.
 *
 * When disabled, the macro expands to nothing, so its arguments are not evaluated.
 *
 * Usage:
 * - `RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE(message)` - Annotate with message only
 * - `RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE(message, payload)` - Annotate with message and
 * payload
 *
 * Example:
 * ```
 * void some_function(){
 *    RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("detailed operation");
 *    RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE("detailed operation", count);
 *    ...
 * }
 * ```
 */
#if RAPIDSMPF_VERBOSE_INFO
#define RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE(...) RAPIDSMPF_NVTX_SCOPED_RANGE(__VA_ARGS__)
#else
#define RAPIDSMPF_NVTX_SCOPED_RANGE_VERBOSE(...)
#endif

// Implementation of the marker macros: emits an NVTX marker in the `rapidsmpf`
// domain whose message is the registered string for `msg` and whose payload is
// `convert_to_64bit(val)`.
#define RAPIDSMPF_NVTX_MARKER_IMPL(msg, val)                \
    nvtx3::mark_in<rapidsmpf_domain>(                       \
        nvtx3::event_attributes{                            \
            RAPIDSMPF_REGISTER_STRING(msg),                 \
            nvtx3::payload{convert_to_64bit(val)}           \
        }                                                   \
    )

/**
 * @brief Convenience macro for generating an NVTX marker in the `rapidsmpf` domain to
 * annotate a certain time point.
 *
 * @param message The message to annotate (const char*). It is registered once per use
 * of the macro, so it should be the same on every execution, e.g. a string literal.
 * @param payload The payload to annotate (integral or floating-point value). It is
 * converted with `convert_to_64bit`, so an integral payload larger than the maximum
 * `std::int64_t` throws `std::overflow_error`.
 *
 * Use this macro to annotate asynchronous operations.
 */
#define RAPIDSMPF_NVTX_MARKER(message, payload) \
    RAPIDSMPF_NVTX_MARKER_IMPL(message, payload)

/**
 * @brief Convenience macro for generating an NVTX marker in the `rapidsmpf` domain to
 * annotate a certain time point, that is only active when RAPIDSMPF_VERBOSE_INFO is
 * enabled.
 *
 * The marker is only emitted when the RAPIDSMPF_VERBOSE_INFO compile-time flag
 * evaluates to a non-zero value in the preprocessor (`#if RAPIDSMPF_VERBOSE_INFO`).
 * Otherwise the macro expands to nothing, so its arguments are not evaluated.
 *
 * @param message The message to annotate.
 * @param payload The payload to annotate.
 *
 * Use this macro to annotate asynchronous operations.
 */
#if RAPIDSMPF_VERBOSE_INFO
#define RAPIDSMPF_NVTX_MARKER_VERBOSE(message, payload) \
    RAPIDSMPF_NVTX_MARKER_IMPL(message, payload)
#else
#define RAPIDSMPF_NVTX_MARKER_VERBOSE(message, payload)
#endif
