# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

import cmath
import concurrent.futures
import contextlib
import enum
import gc
import math
import unittest
import os
import io
import subprocess
import sys
import shutil
import warnings
import tempfile
import time
import types as pytypes
from functools import cached_property
import multiprocessing as mp
import traceback

import numpy as np

from numba.cuda import types
from numba.cuda.core import errors
from numba.cuda.core import config
from numba.cuda.typing import cffi_utils
from numba.cuda.memory_management.nrt import rtsys
from numba.cuda.extending import (
    typeof_impl,
    register_model,
    NativeValue,
)
from numba.cuda.core.pythonapi import unbox
from numba.cuda.datamodel.models import OpaqueModel
from numba.cuda.np import numpy_support

from numba.cuda import HAS_NUMBA

if HAS_NUMBA:
    from numba.core.extending import (
        typeof_impl as upstream_typeof_impl,
    )
    from numba.core import types as upstream_types


class EnableNRTStatsMixin:
    """Mixin that switches the NRT statistics counters on around each test.

    ``setUp`` enables the counters via ``rtsys.memsys_enable_stats()`` and
    ``tearDown`` disables them again.  Neither method chains to ``super()``.
    """

    def setUp(self):
        rtsys.memsys_enable_stats()

    def tearDown(self):
        rtsys.memsys_disable_stats()


# Skip decorators for optional dependencies, platforms and NumPy versions.
skip_unless_cffi = unittest.skipUnless(cffi_utils.SUPPORTED, "requires cffi")

_lnx_reason = "linux only test"
linux_only = unittest.skipUnless(
    sys.platform.startswith("linux"), _lnx_reason
)

_win_reason = "Windows only test"
windows_only = unittest.skipUnless(
    sys.platform.startswith("win"), _win_reason
)

IS_NUMPY_2 = numpy_support.numpy_version >= (2, 0)
skip_if_numpy_2 = unittest.skipIf(IS_NUMPY_2, "Not supported on numpy 2.0+")

# Typeguard
has_typeguard = bool(os.environ.get("NUMBA_USE_TYPEGUARD", 0))

skip_unless_typeguard = unittest.skipUnless(
    has_typeguard,
    "Typeguard is not enabled",
)

skip_if_typeguard = unittest.skipIf(
    has_typeguard,
    "Broken if Typeguard is enabled",
)

_trashcan_dir = "numba-cuda-tests"

if os.name == "nt":
    # Under Windows, gettempdir() points to the user-local temp dir
    _trashcan_dir = os.path.join(tempfile.gettempdir(), _trashcan_dir)
else:
    # Mix the UID into the directory name to allow different users to
    # run the test suite without permission errors (issue #1586)
    _trashcan_dir = os.path.join(
        tempfile.gettempdir(), "%s.%s" % (_trashcan_dir, os.getuid())
    )

# Stale temporary directories are deleted after they are older than this value.
# The test suite probably won't ever take longer than this...
_trashcan_timeout = 24 * 3600  # 1 day


def _create_trashcan_dir():
    """Create the trashcan root directory; an existing entry is tolerated."""
    # Only the final path component is created (no parents), and a
    # FileExistsError from a concurrent or earlier creation is ignored.
    with contextlib.suppress(FileExistsError):
        os.mkdir(_trashcan_dir)


def _purge_trashcan_dir():
    """Delete trashcan entries whose mtime is older than the stale cutoff."""
    cutoff = time.time() - _trashcan_timeout
    for entry_name in sorted(os.listdir(_trashcan_dir)):
        entry_path = os.path.join(_trashcan_dir, entry_name)
        try:
            is_stale = os.stat(entry_path).st_mtime < cutoff
            if is_stale:
                shutil.rmtree(entry_path, ignore_errors=True)
        except OSError:
            # Parallel test processes may race to remove the same entry;
            # losing that race is harmless.
            pass


def _create_trashcan_subdir(prefix):
    """Purge stale trashcan entries, then create and return a new unique
    subdirectory whose name starts with ``prefix + "-"``."""
    _purge_trashcan_dir()
    return tempfile.mkdtemp(dir=_trashcan_dir, prefix=prefix + "-")


def temp_directory(prefix):
    """
    Return the path of a new temporary directory named after *prefix*.

    The directory lives inside the shared trashcan directory and is not
    removed when this process exits; it is deleted by a later purge once
    it has become stale.

    Deferring deletion matters on Windows, where a DLL file cannot be
    removed while it is in use.  It also leaves the test files available
    for inspection shortly after a test suite run.
    """
    _create_trashcan_dir()
    new_dir = _create_trashcan_subdir(prefix)
    return new_dir


def import_dynamic(modname):
    """
    Import the module *modname* and return it.

    The import system's finder caches are invalidated first, so modules
    created after interpreter start-up can still be found.
    """
    from importlib import invalidate_caches

    invalidate_caches()
    __import__(modname)
    # For dotted names __import__ returns the top-level package, so look
    # up the requested module itself in the module registry.
    module = sys.modules[modname]
    return module


def ignore_internal_warnings():
    """Install filters for warnings unrelated to the code under test.

    Meant to be called inside a ``warnings.catch_warnings`` block, so the
    filters are dropped again when that block exits.  It ignores:

    * every warning issued from the ``typeguard`` module;
    * NumbaWarnings about a TBB interface version mismatch issued from
      modules matching ``numba.np.ufunc.parallel``.
    """
    warnings.filterwarnings("ignore", module="typeguard")

    tbb_mismatch_filter = dict(
        action="ignore",
        message=r".*TBB_INTERFACE_VERSION.*",
        category=errors.NumbaWarning,
        module=r"numba\.np\.ufunc\.parallel.*",
    )
    warnings.filterwarnings(**tbb_mismatch_filter)


@contextlib.contextmanager
def override_config(name, value):
    """
    Context manager temporarily setting the Numba config attribute *name*
    to *value*.

    *name* must be an existing attribute of the config module (it is read
    with ``getattr`` before being overwritten).  The previous value is
    restored on exit, including when the block raises.
    """
    previous = getattr(config, name)
    setattr(config, name, value)
    with contextlib.ExitStack() as restore:
        restore.callback(setattr, config, name, previous)
        yield


def run_in_subprocess(code, flags=None, env=None, timeout=30):
    """Run a snippet of Python code in a subprocess with flags, if any are
    given. 'env' is passed to subprocess.Popen(). 'timeout' is passed to
    popen.communicate().

    *flags* may be any iterable of interpreter command-line arguments.

    If the subprocess does not finish within *timeout* seconds, it is
    killed and reaped, and ``subprocess.TimeoutExpired`` is re-raised, so
    no orphaned child process is left behind.

    Returns the stdout and stderr (bytes) of the subprocess after its
    termination.  Raises AssertionError if the subprocess exits with a
    non-zero return code.
    """
    cmd = [sys.executable, *(flags or ()), "-c", code]
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
    ) as popen:
        try:
            out, err = popen.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout: do it here
            # and collect the remaining output before propagating.
            popen.kill()
            popen.communicate()
            raise
    if popen.returncode != 0:
        msg = "process failed with code %s: stderr follows\n%s\n"
        # errors="replace" so undecodable stderr cannot mask the failure.
        raise AssertionError(
            msg % (popen.returncode, err.decode(errors="replace"))
        )
    return out, err


def captured_stdout():
    """Return a context manager that redirects ``sys.stdout`` into a fresh
    ``io.StringIO``; the buffer is what the ``with`` statement binds:

    with captured_stdout() as stdout:
        print("hello")
    self.assertEqual(stdout.getvalue(), "hello\n")
    """
    buffer = io.StringIO()
    return contextlib.redirect_stdout(buffer)


def captured_stderr():
    """Return a context manager that redirects ``sys.stderr`` into a fresh
    ``io.StringIO``; the buffer is what the ``with`` statement binds:

    with captured_stderr() as stderr:
        print("hello", file=sys.stderr)
    self.assertEqual(stderr.getvalue(), "hello\n")
    """
    buffer = io.StringIO()
    return contextlib.redirect_stderr(buffer)


class TestCase(unittest.TestCase):
    """
    Base class for Numba tests.

    Extends ``unittest.TestCase`` with precise numeric comparison
    (``assertPreciseEqual``), reference-count and NRT allocation checks,
    helpers to run tests in a subprocess, and dummy type creation.
    """

    longMessage = True

    # A random state yielding the same random numbers for any test case.
    # Use as `self.random.<method name>`
    @cached_property
    def random(self):
        return np.random.RandomState(42)

    def reset_module_warnings(self, module):
        """
        Reset the warnings registry of a module.  This can be necessary
        as the warnings module is buggy in that regard.
        See http://bugs.python.org/issue4180

        *module* may be a module object or the name of an already imported
        module.
        """
        if isinstance(module, str):
            module = sys.modules[module]
        try:
            del module.__warningregistry__
        except AttributeError:
            pass

    @contextlib.contextmanager
    def assertTypingError(self):
        """
        A context manager that asserts the enclosed code block fails
        compiling in nopython mode.

        LoweringError, TypingError, TypeError and NotImplementedError are
        all accepted.  The ``assertRaises`` context is yielded so that the
        caught exception can be inspected.
        """
        _accepted_errors = (
            errors.LoweringError,
            errors.TypingError,
            TypeError,
            NotImplementedError,
        )
        with self.assertRaises(_accepted_errors) as cm:
            yield cm

    @contextlib.contextmanager
    def assertRefCount(self, *objects):
        """
        A context manager that asserts the given objects have the
        same reference counts before and after executing the
        enclosed block.

        A garbage collection is run before the counts are taken again.
        """
        old_refcounts = [sys.getrefcount(x) for x in objects]
        yield
        gc.collect()
        new_refcounts = [sys.getrefcount(x) for x in objects]
        for old, new, obj in zip(old_refcounts, new_refcounts, objects):
            if old != new:
                self.fail(
                    "Refcount changed from %d to %d for object: %r"
                    % (old, new, obj)
                )

    def assertRefCountEqual(self, *objects):
        """
        Assert that all *objects* currently have the same reference count,
        measured after a garbage collection.  Passing fewer than two
        objects is a no-op.
        """
        gc.collect()
        refcounts = [sys.getrefcount(x) for x in objects]
        if not refcounts:
            return
        rc_0 = refcounts[0]
        for i, rc_i in enumerate(refcounts[1:], start=1):
            if rc_0 != rc_i:
                self.fail(
                    f"Refcount for objects does not match. "
                    f"#0({rc_0}) != #{i}({rc_i})."
                )

    @contextlib.contextmanager
    def assertNoNRTLeak(self):
        """
        A context manager that asserts no NRT leak was created during
        the execution of the enclosed block.

        Data allocations must equal data frees, and meminfo allocations
        must equal meminfo frees, as reported by
        ``rtsys.get_allocation_stats()``.
        """
        old = rtsys.get_allocation_stats()
        yield
        new = rtsys.get_allocation_stats()
        total_alloc = new.alloc - old.alloc
        total_free = new.free - old.free
        total_mi_alloc = new.mi_alloc - old.mi_alloc
        total_mi_free = new.mi_free - old.mi_free
        self.assertEqual(
            total_alloc,
            total_free,
            "number of data allocs != number of data frees",
        )
        self.assertEqual(
            total_mi_alloc,
            total_mi_free,
            "number of meminfo allocs != number of meminfo frees",
        )

    # Type families used by _detect_family().  Order matters there: bools
    # are checked (as "exact") before the approximate types are considered.
    _bool_types = (bool, np.bool_)
    _exact_typesets = [
        _bool_types,
        (int,),
        (str,),
        (np.integer,),
        (bytes, np.bytes_),
    ]
    _approx_typesets = [(float,), (complex,), (np.inexact,)]
    _sequence_typesets = [(tuple, list)]
    _float_types = (float, np.floating)
    _complex_types = (complex, np.complexfloating)

    def _detect_family(self, numeric_object):
        """
        This function returns a string description of the type family
        that the object in question belongs to.  Possible return values
        are: "ndarray", "enum", "sequence", "exact", "complex",
        "approximate" and "unknown".
        """
        if isinstance(numeric_object, np.ndarray):
            return "ndarray"

        if isinstance(numeric_object, enum.Enum):
            return "enum"

        for tp in self._sequence_typesets:
            if isinstance(numeric_object, tp):
                return "sequence"

        for tp in self._exact_typesets:
            if isinstance(numeric_object, tp):
                return "exact"

        if isinstance(numeric_object, self._complex_types):
            return "complex"

        for tp in self._approx_typesets:
            if isinstance(numeric_object, tp):
                return "approximate"

        return "unknown"

    def _fix_dtype(self, dtype):
        """
        Fix the given *dtype* for comparison.
        """
        # Under 64-bit Windows, Numpy may return either int32 or int64
        # arrays depending on the function.
        if (
            sys.platform == "win32"
            and sys.maxsize > 2**32
            and dtype == np.dtype("int32")
        ):
            return np.dtype("int64")
        else:
            return dtype

    def _fix_strides(self, arr):
        """
        Return the strides of the given array, fixed for comparison.
        Strides are expressed in units of the itemsize, and strides for
        0- or 1-sized dimensions are ignored.  For an empty array, a list
        of ``arr.ndim`` zeros is returned.
        """
        if arr.size == 0:
            return [0] * arr.ndim
        else:
            return [
                stride / arr.itemsize
                for (stride, shape) in zip(arr.strides, arr.shape)
                if shape > 1
            ]

    def assertStridesEqual(self, first, second):
        """
        Test that two arrays have the same shape and strides.
        """
        self.assertEqual(first.shape, second.shape, "shapes differ")
        self.assertEqual(first.itemsize, second.itemsize, "itemsizes differ")
        self.assertEqual(
            self._fix_strides(first),
            self._fix_strides(second),
            "strides differ",
        )

    def assertPreciseEqual(
        self,
        first,
        second,
        prec="exact",
        ulps=1,
        msg=None,
        ignore_sign_on_zero=False,
        abs_tol=None,
    ):
        """
        Versatile equality testing function with more built-in checks than
        standard assertEqual().

        For arrays, test that layout, dtype, shape are identical, and
        recursively call assertPreciseEqual() on the contents.

        For other sequences, recursively call assertPreciseEqual() on
        the contents.

        For scalars, test that two scalars have similar types and are
        equal up to a computed precision.
        If the scalars are instances of exact types or if *prec* is
        'exact', they are compared exactly.
        If the scalars are instances of inexact types (float, complex)
        and *prec* is not 'exact', then the number of significant bits
        is computed according to the value of *prec*: 53 bits if *prec*
        is 'double', 24 bits if *prec* is single.  This number of bits
        can be lowered by raising the *ulps* value.
        ignore_sign_on_zero can be set to True if zeros are to be considered
        equal regardless of their sign bit.
        abs_tol if this is set to a float value its value is used in the
        following. If, however, this is set to the string "eps" then machine
        precision of the type(first) is used in the following instead. This
        kwarg is used to check if the absolute difference in value between first
        and second is less than the value set, if so the numbers being compared
        are considered equal. (This is to handle small numbers typically of
        magnitude less than machine precision).

        Any value of *prec* other than 'exact', 'single' or 'double'
        will raise an error.
        """
        try:
            self._assertPreciseEqual(
                first, second, prec, ulps, msg, ignore_sign_on_zero, abs_tol
            )
        except AssertionError as exc:
            failure_msg = str(exc)
            # Fall off of the 'except' scope to avoid Python 3 exception
            # chaining.
        else:
            return
        # Decorate the failure message with more information
        self.fail("when comparing %s and %s: %s" % (first, second, failure_msg))

    def _assertPreciseEqual(
        self,
        first,
        second,
        prec="exact",
        ulps=1,
        msg=None,
        ignore_sign_on_zero=False,
        abs_tol=None,
    ):
        """Recursive workhorse for assertPreciseEqual()."""

        def _assertNumberEqual(first, second, delta=None):
            # Exact comparison (plus a sign-bit check for zeros) when no
            # delta is given, for two zeros, or for infinities; otherwise an
            # approximate comparison within *delta*.
            if (
                delta is None
                or first == second == 0.0
                or math.isinf(first)
                or math.isinf(second)
            ):
                self.assertEqual(first, second, msg=msg)
                # For signed zeros
                if not ignore_sign_on_zero:
                    try:
                        if math.copysign(1, first) != math.copysign(1, second):
                            self.fail(
                                self._formatMessage(
                                    msg, "%s != %s" % (first, second)
                                )
                            )
                    except TypeError:
                        pass
            else:
                self.assertAlmostEqual(first, second, delta=delta, msg=msg)

        first_family = self._detect_family(first)
        second_family = self._detect_family(second)

        assertion_message = "Type Family mismatch. (%s != %s)" % (
            first_family,
            second_family,
        )
        if msg:
            assertion_message += ": %s" % (msg,)
        self.assertEqual(first_family, second_family, msg=assertion_message)

        # We now know they are in the same comparison family
        compare_family = first_family

        # For recognized sequences, recurse
        if compare_family == "ndarray":
            dtype = self._fix_dtype(first.dtype)
            self.assertEqual(dtype, self._fix_dtype(second.dtype))
            self.assertEqual(
                first.ndim, second.ndim, "different number of dimensions"
            )
            self.assertEqual(first.shape, second.shape, "different shapes")
            self.assertEqual(
                first.flags.writeable,
                second.flags.writeable,
                "different mutability",
            )
            # itemsize is already checked by the dtype test above
            self.assertEqual(
                self._fix_strides(first),
                self._fix_strides(second),
                "different strides",
            )
            if first.dtype != dtype:
                first = first.astype(dtype)
            if second.dtype != dtype:
                second = second.astype(dtype)
            for a, b in zip(first.flat, second.flat):
                self._assertPreciseEqual(
                    a, b, prec, ulps, msg, ignore_sign_on_zero, abs_tol
                )
            return

        elif compare_family == "sequence":
            self.assertEqual(len(first), len(second), msg=msg)
            for a, b in zip(first, second):
                self._assertPreciseEqual(
                    a, b, prec, ulps, msg, ignore_sign_on_zero, abs_tol
                )
            return

        elif compare_family == "exact":
            exact_comparison = True

        elif compare_family in ["complex", "approximate"]:
            exact_comparison = False

        elif compare_family == "enum":
            self.assertIs(first.__class__, second.__class__)
            self._assertPreciseEqual(
                first.value,
                second.value,
                prec,
                ulps,
                msg,
                ignore_sign_on_zero,
                abs_tol,
            )
            return

        elif compare_family == "unknown":
            # Assume these are non-numeric types: we will fall back
            # on regular unittest comparison.
            self.assertIs(first.__class__, second.__class__)
            exact_comparison = True

        else:
            assert 0, "unexpected family"

        # If a Numpy scalar, check the dtype is exactly the same too
        # (required for datetime64 and timedelta64).
        if hasattr(first, "dtype") and hasattr(second, "dtype"):
            self.assertEqual(first.dtype, second.dtype)

        # Mixing bools and non-bools should always fail
        if isinstance(first, self._bool_types) != isinstance(
            second, self._bool_types
        ):
            assertion_message = "Mismatching return types (%s vs. %s)" % (
                first.__class__,
                second.__class__,
            )
            if msg:
                assertion_message += ": %s" % (msg,)
            self.fail(assertion_message)

        try:
            if cmath.isnan(first) and cmath.isnan(second):
                # The NaNs will compare unequal, skip regular comparison
                return
        except TypeError:
            # Not floats.
            pass

        # If an absolute tolerance is set, values closer than it are equal.
        if abs_tol is not None:
            if abs_tol == "eps":
                tolerance = np.finfo(type(first)).eps
            elif isinstance(abs_tol, float):
                tolerance = abs_tol
            else:
                raise ValueError(
                    'abs_tol is not "eps" or a float, found %s' % abs_tol
                )
            if abs(first - second) < tolerance:
                return

        exact_comparison = exact_comparison or prec == "exact"

        if exact_comparison:
            delta = None
        else:
            if prec == "single":
                bits = 24
            elif prec == "double":
                bits = 53
            else:
                raise ValueError("unsupported precision %r" % (prec,))
            # Allowed difference: 2**(ulps - bits - 1) times the sum of the
            # operand magnitudes, i.e. relative to their size.
            k = 2 ** (ulps - bits - 1)
            delta = k * (abs(first) + abs(second))
        if isinstance(first, self._complex_types):
            _assertNumberEqual(first.real, second.real, delta)
            _assertNumberEqual(first.imag, second.imag, delta)
        elif isinstance(first, (np.timedelta64, np.datetime64)):
            # Since Np 1.16 NaT == NaT is False, so special comparison needed
            if np.isnat(first):
                self.assertEqual(np.isnat(first), np.isnat(second))
            else:
                _assertNumberEqual(first, second, delta)
        else:
            _assertNumberEqual(first, second, delta)

    def subprocess_test_runner(
        self,
        test_module,
        test_class=None,
        test_name=None,
        envvars=None,
        timeout=60,
        flags=None,
        _subproc_test_env="1",
    ):
        """
        Runs named unit test(s) as specified in the arguments as:
        test_module.test_class.test_name. test_module must always be supplied
        and if no further refinement is made with test_class and test_name then
        all tests in the module will be run. The tests will be run in a
        subprocess with environment variables specified in `envvars`.
        If given, envvars must be a map of form:
            environment variable name (str) -> value (str)
        If given, flags must be a map of form:
            flag including the `-` (str) -> value (str)
        It is most convenient to use this method in conjunction with
        @needs_subprocess as the decorator will cause the decorated test to be
        skipped unless the `SUBPROC_TEST` environment variable is set to
        the value of ``_subproc_test_env``
        (this special environment variable is set by this method such that the
        specified test(s) will not be skipped in the subprocess).

        Following execution in the subprocess this method will check the test(s)
        executed without error. The timeout kwarg can be used to allow more time
        for longer running tests, it defaults to 60 seconds.

        Returns the ``subprocess.CompletedProcess`` of the run.
        """
        parts = (test_module, test_class, test_name)
        fully_qualified_test = ".".join(x for x in parts if x is not None)
        flags_args = []
        if flags is not None:
            for flag, value in flags.items():
                flags_args.append(f"{flag}")
                flags_args.append(f"{value}")
        cmd = [
            sys.executable,
            *flags_args,
            "-m",
            "numba.runtests",
            fully_qualified_test,
        ]
        env_copy = os.environ.copy()
        env_copy["SUBPROC_TEST"] = _subproc_test_env
        try:
            env_copy["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"]
        except KeyError:
            pass  # ignored
        envvars = pytypes.MappingProxyType({} if envvars is None else envvars)
        env_copy.update(envvars)
        status = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
            env=env_copy,
            universal_newlines=True,
        )
        streams = (
            f"\ncaptured stdout: {status.stdout}\n"
            f"captured stderr: {status.stderr}"
        )
        self.assertEqual(status.returncode, 0, streams)
        # Python 3.12.1 report
        no_tests_ran = "NO TESTS RAN"
        if no_tests_ran in status.stderr:
            self.skipTest(no_tests_ran)
        else:
            self.assertIn("OK", status.stderr)
        return status

    def run_test_in_subprocess(maybefunc=None, timeout=60, envvars=None):
        """Runs the decorated test in a subprocess via invoking numba's test
        runner. kwargs timeout and envvars are passed through to
        subprocess_test_runner.

        This has no ``self`` parameter: it is a decorator defined in the
        class body, usable bare or with keyword arguments.
        """

        def wrapper(func):
            def inner(self, *args, **kwargs):
                # The subprocess marks itself by setting SUBPROC_TEST to the
                # test function's name.
                if os.environ.get("SUBPROC_TEST", None) != func.__name__:
                    # Not in a subprocess test env, so stage the call to run the
                    # test in a subprocess which will set the env var.
                    class_name = self.__class__.__name__
                    self.subprocess_test_runner(
                        test_module=self.__module__,
                        test_class=class_name,
                        test_name=func.__name__,
                        timeout=timeout,
                        envvars=envvars,
                        _subproc_test_env=func.__name__,
                    )
                else:
                    # env var is set, so we're in the subprocess, run the
                    # actual test.  Note: extra *args/**kwargs are not
                    # forwarded.
                    func(self)

            return inner

        # Used bare (@run_test_in_subprocess) the function itself is passed;
        # used with arguments, return the decorator.
        if isinstance(maybefunc, pytypes.FunctionType):
            return wrapper(maybefunc)
        else:
            return wrapper

    def make_dummy_type(self):
        """Use to generate a dummy type unique to this test. Returns a python
        Dummy class and a corresponding Numba type DummyType."""

        # Use test_id to make sure no collision is possible.
        test_id = self.id()
        DummyType = type("DummyTypeFor{}".format(test_id), (types.Opaque,), {})

        dummy_type = DummyType("my_dummy")
        register_model(DummyType)(OpaqueModel)

        class Dummy(object):
            pass

        @typeof_impl.register(Dummy)
        def typeof_dummy(val, c):
            return dummy_type

        # Dual registration for cross-target tests
        if HAS_NUMBA:
            UpstreamDummyType = type(
                "DummyTypeFor{}".format(test_id), (upstream_types.Opaque,), {}
            )
            upstream_dummy_type = UpstreamDummyType("my_dummy")

            @upstream_typeof_impl.register(Dummy)
            def typeof_dummy_core(val, c):
                return upstream_dummy_type

        @unbox(DummyType)
        def unbox_dummy(typ, obj, c):
            return NativeValue(c.context.get_dummy_value())

        return Dummy, DummyType


class MemoryLeak(object):
    """
    Mixin asserting that a test does not leak NRT allocations.

    ``memory_leak_setup()`` snapshots ``rtsys.get_allocation_stats()`` and
    ``memory_leak_teardown()`` compares the current statistics against that
    snapshot, unless ``disable_leak_check()`` was called.  It relies on
    ``self.assertEqual``, so it must be combined with a ``unittest.TestCase``.
    """

    # Name-mangled to ``_MemoryLeak__enable_leak_check``.  This class-level
    # value is the default; disable_leak_check() overrides it per instance.
    __enable_leak_check = True

    def memory_leak_setup(self):
        # Clean up any NRT-backed objects hanging in a dead reference cycle
        gc.collect()
        self.__init_stats = rtsys.get_allocation_stats()

    def memory_leak_teardown(self):
        if self.__enable_leak_check:
            self.assert_no_memory_leak()

    def assert_no_memory_leak(self):
        """Assert that allocations and frees since memory_leak_setup()
        balance, both for data buffers and for meminfo structures."""
        old = self.__init_stats
        new = rtsys.get_allocation_stats()
        total_alloc = new.alloc - old.alloc
        total_free = new.free - old.free
        total_mi_alloc = new.mi_alloc - old.mi_alloc
        total_mi_free = new.mi_free - old.mi_free
        # Same diagnostic messages as TestCase.assertNoNRTLeak, so a failure
        # says which counter pair is unbalanced.
        self.assertEqual(
            total_alloc,
            total_free,
            "number of data allocs != number of data frees",
        )
        self.assertEqual(
            total_mi_alloc,
            total_mi_free,
            "number of meminfo allocs != number of meminfo frees",
        )

    def disable_leak_check(self):
        # For per-test use when MemoryLeakMixin is injected into a TestCase
        self.__enable_leak_check = False


class MemoryLeakMixin(EnableNRTStatsMixin, MemoryLeak):
    """Enable NRT statistics for each test and check for NRT leaks
    when it finishes."""

    def setUp(self):
        super().setUp()
        # Take the baseline snapshot once the superclass setUp has run
        # (which enables the NRT statistics counters).
        self.memory_leak_setup()

    def tearDown(self):
        # Collect dead reference cycles first so that NRT-backed objects
        # they hold are presumably freed before the leak check runs.
        gc.collect()
        self.memory_leak_teardown()
        super().tearDown()


class CheckWarningsMixin(object):
    """Mixin providing ``check_warnings()`` for ``unittest.TestCase``
    subclasses (it relies on ``self.assertEqual``)."""

    @contextlib.contextmanager
    def check_warnings(self, messages, category=RuntimeWarning):
        """
        Assert that the enclosed block emits the expected warnings.

        Every warning raised in the block is recorded (the filter is set to
        "always").  Each entry of *messages* must be a substring of exactly
        one recorded warning message, and every warning matching an entry
        must be of *category*.

        Matches are counted per message.  A single total count could let one
        message matched twice hide another that was never matched.
        """
        messages = list(messages)
        with warnings.catch_warnings(record=True) as catch:
            warnings.simplefilter("always")
            yield
        match_counts = [0] * len(messages)
        for w in catch:
            text = str(w.message)
            for i, m in enumerate(messages):
                if m in text:
                    self.assertEqual(w.category, category)
                    match_counts[i] += 1
        for m, count in zip(messages, match_counts):
            self.assertEqual(
                count,
                1,
                f"expected exactly one warning containing {m!r}, "
                f"found {count}",
            )


@contextlib.contextmanager
def override_env_config(name, value):
    """
    Return a context manager that temporarily sets the Numba config
    environment variable *name* to *value* and reloads the config.

    On exit, the variable is restored to its previous value, or removed if
    it was not set before.  The config is then reloaded again.  Restoration
    also happens if reloading the config fails on entry.
    """
    old = os.environ.get(name)
    try:
        os.environ[name] = value
        config.reload_config()
        yield
    finally:
        if old is None:
            # It wasn't set originally: remove it.  pop() tolerates the
            # enclosed block having already deleted the variable.
            os.environ.pop(name, None)
        else:
            # Otherwise, restore to the old value
            os.environ[name] = old
        # Always reload config so it reflects the restored environment
        config.reload_config()


def run_in_new_process_in_cache_dir(func, cache_dir, verbose=True):
    """Run *func* in a freshly spawned process, with ``NUMBA_CACHE_DIR``
    set to *cache_dir* for the duration of the call.

    The child's stdout and stderr are captured.  When *verbose* is true,
    they are echoed, under a banner, to this process's stdout and stderr.

    Similar to ``run_in_new_process_caching()`` but the ``cache_dir`` is a
    directory path instead of a name prefix for the directory path.

    Returns
    -------
    ret : dict
        exitcode: 0 for success. 1 for exception-raised.
        stdout: str
        stderr: str
    """
    with override_env_config("NUMBA_CACHE_DIR", cache_dir):
        spawn_ctx = mp.get_context("spawn")
        with concurrent.futures.ProcessPoolExecutor(mp_context=spawn_ctx) as pool:
            future = pool.submit(_remote_runner, func)

        stdout, stderr, exitcode = future.result()
        if verbose:
            sections = (
                ("STDOUT", stdout, sys.stdout),
                ("STDERR", stderr, sys.stderr),
            )
            for title, text, stream in sections:
                if not text:
                    continue
                print(file=stream)
                print(title.center(80, "-"), file=stream)
                print(text, file=stream)
    return {"exitcode": exitcode, "stdout": stdout, "stderr": stderr}


def _remote_runner(fn, qout=None):
    """Run *fn* with stdout/stderr captured.

    This is the task that ``run_in_new_process_in_cache_dir()`` submits to
    its process pool as ``submit(_remote_runner, func)``, i.e. with a single
    argument.

    Parameters
    ----------
    fn : callable
        Zero-argument callable to run.
    qout : optional
        Unused.  It has a default so the one-argument call above works;
        with ``qout`` required, that call failed with ``TypeError``.

    Returns
    -------
    (stdout, stderr, exitcode) : tuple
        The stripped captured output, and an exit code of 0 on success or
        1 if *fn* raised an ``Exception``.  In the failure case the
        traceback is written to the captured stderr.
    """
    with captured_stderr() as stderr, captured_stdout() as stdout:
        try:
            fn()
        except Exception:
            # sys.stderr is the capture buffer here, so the traceback ends
            # up in the returned stderr text.
            traceback.print_exc(file=sys.stderr)
            exitcode = 1
        else:
            exitcode = 0
    return stdout.getvalue().strip(), stderr.getvalue().strip(), exitcode
