# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

import re
import cffi

import numpy as np

from io import StringIO
from numba import cuda
from numba.cuda import float32, float64, int32, intp
from numba.cuda.types import float16, CPointer
from numba.cuda import declare_device
from numba.cuda.testing import unittest, CUDATestCase
from numba.cuda.testing import (
    skip_on_cudasim,
    skip_with_nvdisasm,
    skip_without_nvdisasm,
    skip_if_nvjitlink_missing,
)


@skip_on_cudasim("Simulator does not generate code to be inspected")
class TestInspect(CUDATestCase):
    """Tests for the inspection methods of CUDA-jitted kernels.

    The methods exercised here are ``inspect_types``, ``inspect_llvm``,
    ``inspect_asm``, ``inspect_lto_ptx``, ``inspect_sass`` and
    ``inspect_sass_cfg``.

    NOTE(review): the ``// LLVM...`` / ``// ASM...`` lines inside the kernel
    docstrings look like FileCheck-style directives that are matched by
    ``assertFileCheckLLVM`` / ``assertFileCheckAsm`` (defined outside this
    file). Treat them as test data rather than as documentation.
    """

    @property
    def cc(self):
        # Compute capability of the device owning the current context
        return cuda.current_context().device.compute_capability

    def test_monotyped(self):
        # Kernel compiled eagerly for exactly one signature
        sig = (float32, int32)

        @cuda.jit(sig)
        def foo(x, y):
            """
            // LLVM: define void
            // LLVM-SAME: foo
            // LLVM-LABEL: entry:
            // LLVM-NEXT:         br label %"[[VAL_0:.*]]"
            // LLVM-NEXT:       [[VAL_0]]:
            // LLVM-NEXT:         ret void

            // ASM: Generated by NVIDIA NVVM Compiler
            // ASM: foo
            """
            pass

        # The context manager releases the buffer even when inspect_types()
        # raises; the assertions run on the captured text afterwards.
        with StringIO() as file:
            foo.inspect_types(file=file)
            typeanno = file.getvalue()

        # Function name in annotation
        self.assertIn("foo", typeanno)
        # Signature in annotation
        self.assertIn("(float32, int32)", typeanno)

        self.assertFileCheckLLVM(foo, sig)
        self.assertFileCheckAsm(foo, sig)

    def test_polytyped(self):
        # Kernel compiled lazily; the two launches below use integer and
        # floating point arguments respectively, and the checks that follow
        # expect one specialization for each.
        @cuda.jit
        def foo(x, y):
            """
            // LLVM: define void
            // LLVM-SAME: foo
            // LLVM_INT-SAME: i64
            // LLVM_INT-SAME: i64
            // LLVM_FLOAT-SAME: double
            // LLVM_FLOAT-SAME: double

            // ASM: Generated by NVIDIA NVVM Compiler
            // ASM: .visible
            // ASM-SAME: .entry
            // ASM-SAME: foo
            """
            pass

        foo[1, 1](1, 1)
        foo[1, 1](1.2, 2.4)

        int_sig = (intp, intp)
        float_sig = (float64, float64)

        self.assertFileCheckLLVM(
            foo, int_sig, check_prefixes=["LLVM", "LLVM_INT"]
        )
        self.assertFileCheckAsm(foo, int_sig, check_prefixes=["ASM"])
        self.assertFileCheckLLVM(
            foo, float_sig, check_prefixes=["LLVM", "LLVM_FLOAT"]
        )
        self.assertFileCheckAsm(foo, float_sig, check_prefixes=["ASM"])

        with StringIO() as file:
            foo.inspect_types(file=file)
            typeanno = file.getvalue()

        # Signature in annotation
        self.assertIn(f"({intp}, {intp})", typeanno)
        self.assertIn("(float64, float64)", typeanno)

        # Signature in LLVM dict
        llvmirs = foo.inspect_llvm()
        self.assertEqual(2, len(llvmirs))
        self.assertIn((intp, intp), llvmirs)
        self.assertIn((float64, float64), llvmirs)

        # Signature in assembly dict
        asmdict = foo.inspect_asm()
        self.assertEqual(2, len(asmdict))
        self.assertIn((intp, intp), asmdict)
        self.assertIn((float64, float64), asmdict)

    def _test_inspect_sass(self, kernel, name, sass):
        # Shared checks for SASS text. `kernel` is not used by the checks; it
        # is kept so that existing call sites remain valid.

        # Ensure function appears in output: some whitespace-delimited token
        # must mention both ".text" and the function name.
        seen_function = any(
            ".text" in token and name in token for token in sass.split()
        )
        self.assertTrue(seen_function)

        # Source line information must refer to this test file (the dot in
        # the file name is escaped so that it only matches a literal ".")
        self.assertRegex(
            sass, r'//## File ".*/test_inspect\.py", line [0-9]'
        )

        # Some instructions common to all supported architectures that should
        # appear in the output
        self.assertIn("S2R", sass)  # Special register to register
        self.assertIn("BRA", sass)  # Branch
        self.assertIn("EXIT", sass)  # Exit program

    @skip_on_cudasim("Simulator does not generate code to be inspected")
    @skip_if_nvjitlink_missing("nvJitLink is required for LTO")
    def test_inspect_lto_asm(self):
        # Links an external CUDA C++ device function into a kernel and
        # compares the output of inspect_asm() with that of inspect_lto_ptx().
        ffi = cffi.FFI()

        ext = cuda.CUSource("""
            #include <cuda_fp16.h>
            extern "C"
            __device__ int add_f2_f2(__half * res, __half * a, __half *b) {
                *res = *a + *b;
                return 0;
            }
            """)

        add = declare_device(
            "add_f2_f2",
            float16(CPointer(float16), CPointer(float16)),
            link=ext,
        )

        @cuda.jit
        def k(arr):
            local_arr = cuda.local.array(shape=1, dtype=np.float16)
            local_arr2 = cuda.local.array(shape=1, dtype=np.float16)
            local_arr[0] = 1
            local_arr2[0] = 2

            ptr = ffi.from_buffer(local_arr)
            ptr2 = ffi.from_buffer(local_arr2)

            arr[0] = add(ptr, ptr2)

        arr = np.array([0], dtype=np.float16)

        k[1, 1](arr)

        # Only one specialization was compiled, so take the first (and only)
        # entry of each dict.
        allasms = k.inspect_asm()
        asm = next(iter(allasms.values()))

        # The regular assembly is expected to contain a call to the external
        # function
        regex = re.compile(r"call(.|\n)*add_f2_f2")
        self.assertRegex(asm, regex)

        all_ext_asms = k.inspect_lto_ptx()
        lto_asm = next(iter(all_ext_asms.values()))

        # The LTO PTX is expected to contain the half-precision add itself
        # and no call instruction
        self.assertIn("add.f16", lto_asm)
        self.assertNotIn("call", lto_asm)

        # The kernel must still have computed the right answer
        np.testing.assert_equal(arr[0], np.float16(1) + np.float16(2))

    def skip_on_cuda_version_issues(self):
        # Unconditionally skips the calling test.
        # FIXME: This should be unskipped once the cause of certain nvdisasm
        # versions failing to dump SASS with certain driver / nvJitLink
        # versions is understood
        self.skipTest(
            "Relocation information required for analysis not preserved"
        )

    @skip_without_nvdisasm("nvdisasm needed for inspect_sass()")
    def test_inspect_sass_eager(self):
        # SASS inspection of a kernel compiled eagerly for a given signature
        self.skip_on_cuda_version_issues()

        sig = (float32[::1], int32[::1])

        @cuda.jit(sig, lineinfo=True)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        self._test_inspect_sass(add, "add", add.inspect_sass(sig))

    @skip_without_nvdisasm("nvdisasm needed for inspect_sass()")
    def test_inspect_sass_lazy(self):
        # SASS inspection of a kernel compiled on its first launch
        self.skip_on_cuda_version_issues()

        @cuda.jit(lineinfo=True)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        x = np.arange(10).astype(np.int32)
        y = np.arange(10).astype(np.float32)
        add[1, 10](x, y)

        # Must match the dtypes of the arrays passed to the launch above
        signature = (int32[::1], float32[::1])
        self._test_inspect_sass(add, "add", add.inspect_sass(signature))

    @skip_with_nvdisasm(
        "Missing nvdisasm exception only generated when it is not present"
    )
    def test_inspect_sass_nvdisasm_missing(self):
        # Without nvdisasm, inspect_sass() must raise a RuntimeError with a
        # message saying that nvdisasm was not found
        @cuda.jit((float32[::1],))
        def f(x):
            x[0] = 0

        with self.assertRaises(RuntimeError) as raises:
            f.inspect_sass()

        self.assertIn("nvdisasm has not been found", str(raises.exception))

    @skip_without_nvdisasm("nvdisasm needed for inspect_sass_cfg()")
    def test_inspect_sass_cfg(self):
        # The control flow graph output must look like a "digraph" block
        self.skip_on_cuda_version_issues()

        sig = (float32[::1], int32[::1])

        @cuda.jit(sig)
        def add(x, y):
            i = cuda.grid(1)
            if i < len(x):
                x[i] += y[i]

        self.assertRegex(
            add.inspect_sass_cfg(signature=sig), r"digraph\s*\w\s*{(.|\n)*\n}"
        )


# Run the tests in this module when the file is executed directly as a script
if __name__ == "__main__":
    unittest.main()
