# SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from typing import Any

from cuda.bindings import driver, runtime

from rmm.pylibrmm.stream import Stream
from rmm.statistics import Statistics

class DeviceMemoryResource:
    """Typing stub for the base memory-resource class.

    The other resource classes visible in this stub derive from it,
    either directly or through ``UpstreamResourceAdaptor``.
    """

    def allocate(
        self,
        nbytes: int,
        stream: Stream = ...,
    ) -> int:
        """Declare ``allocate(nbytes, stream)``, which returns an ``int``.

        ``stream`` may be omitted; the stub elides its default with ``...``.
        NOTE(review): the returned int is presumably the address of the
        new allocation -- confirm against the implementation.
        """

    def deallocate(
        self,
        ptr: int,
        nbytes: int,
        stream: Stream = ...,
    ) -> None:
        """Declare ``deallocate(ptr, nbytes, stream)``, which returns None.

        ``stream`` may be omitted; the stub elides its default with ``...``.
        NOTE(review): ``ptr``/``nbytes`` presumably identify an earlier
        ``allocate`` result and its size -- confirm.
        """

class UpstreamResourceAdaptor(DeviceMemoryResource):
    """Typing stub for resources built on another ``DeviceMemoryResource``.

    A ``DeviceMemoryResource`` is taken by the constructor, declared as
    the ``upstream_mr`` attribute, and returned by ``get_upstream()``.
    NOTE(review): presumably all three refer to the same object -- the
    stub alone does not establish that.
    """

    # Declared attribute holding a DeviceMemoryResource.
    upstream_mr: DeviceMemoryResource

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Accept the upstream resource plus arbitrary extra arguments."""

    def get_upstream(self) -> DeviceMemoryResource:
        """Return a ``DeviceMemoryResource`` (takes no arguments)."""

class CudaMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` built with no arguments.

    NOTE(review): the allocation strategy is not visible in this stub.
    """

    def __init__(self) -> None:
        """Construct the resource; no parameters are accepted."""

class CudaAsyncMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` with pool options.

    All constructor parameters are optional.
    """

    def __init__(
        self,
        initial_pool_size: int | str | None = None,
        release_threshold: int | None = None,
        enable_ipc: bool = False,
        enable_fabric: bool = False,
    ) -> None:
        """Declare the constructor.

        Parameters
        ----------
        initial_pool_size : int | str | None, default None
            NOTE(review): presumably a byte count, or a string spelling
            of one -- confirm the accepted string format.
        release_threshold : int | None, default None
            NOTE(review): semantics not visible in this stub.
        enable_ipc : bool, default False
        enable_fabric : bool, default False
        """

class CudaAsyncViewMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` built from a pool handle.

    Note that ``pool_handle`` is declared as a regular method here, not
    as a property, so it must be called to obtain the value.
    """

    def __init__(
        self,
        pool_handle: runtime.cudaMemPool_t | driver.CUmemoryPool,
    ) -> None:
        """Accept a ``cudaMemPool_t`` or ``CUmemoryPool`` handle."""

    def pool_handle(self) -> int:
        """Return an ``int``.

        NOTE(review): presumably the integer value of the handle given
        to the constructor -- confirm against the implementation.
        """

class ManagedMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` built with no arguments.

    NOTE(review): the kind of memory it hands out is not visible here.
    """

    def __init__(self) -> None:
        """Construct the resource; no parameters are accepted."""

class SystemMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` built with no arguments.

    NOTE(review): the kind of memory it hands out is not visible here.
    """

    def __init__(self) -> None:
        """Construct the resource; no parameters are accepted."""

class PinnedHostMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` built with no arguments.

    NOTE(review): the kind of memory it hands out is not visible here.
    """

    def __init__(self) -> None:
        """Construct the resource; no parameters are accepted."""

class SamHeadroomMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` taking a headroom value."""

    def __init__(
        self,
        headroom: int,
    ) -> None:
        """Accept the required integer ``headroom``.

        NOTE(review): presumably a byte count -- units are not stated in
        this stub.
        """

class PoolMemoryResource(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor with pool-size options."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        initial_pool_size: int | str | None = None,
        maximum_pool_size: int | str | None = None,
    ) -> None:
        """Declare the constructor.

        Parameters
        ----------
        upstream_mr : DeviceMemoryResource
            Required upstream resource.
        initial_pool_size : int | str | None, default None
        maximum_pool_size : int | str | None, default None
            NOTE(review): both sizes are presumably byte counts or string
            spellings of one -- confirm the accepted string format.
        """

    def pool_size(self) -> int:
        """Return an ``int`` (a method call, not a property)."""

class ArenaMemoryResource(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor with an arena-size option."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        arena_size: int | str | None = None,
        dump_log_on_failure: bool = False,
    ) -> None:
        """Declare the constructor.

        Parameters
        ----------
        upstream_mr : DeviceMemoryResource
            Required upstream resource.
        arena_size : int | str | None, default None
            NOTE(review): presumably a byte count or a string spelling of
            one -- confirm.
        dump_log_on_failure : bool, default False
        """

class FixedSizeMemoryResource(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor configured by block size/count."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        block_size: int = 1048576,
        blocks_to_preallocate: int = 128,
    ) -> None:
        """Declare the constructor.

        Parameters
        ----------
        upstream_mr : DeviceMemoryResource
            Required upstream resource.
        block_size : int, default 1048576
            The default equals 2**20. NOTE(review): presumably bytes.
        blocks_to_preallocate : int, default 128
        """

class BinningMemoryResource(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor that exposes bin resources.

    ``add_bin`` takes an allocation size with an optional resource, and
    the ``bin_mrs`` property is typed as a list of resources.
    NOTE(review): how requests map onto bins is not visible here.
    """

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        min_size_exponent: int = -1,
        max_size_exponent: int = -1,
    ) -> None:
        """Declare the constructor.

        Parameters
        ----------
        upstream_mr : DeviceMemoryResource
            Required upstream resource.
        min_size_exponent : int, default -1
        max_size_exponent : int, default -1
            NOTE(review): -1 presumably means "not specified" for both
            exponents -- confirm against the implementation.
        """

    def add_bin(
        self,
        allocation_size: int,
        bin_resource: DeviceMemoryResource | None = None,
    ) -> None:
        """Declare ``add_bin``; ``bin_resource`` is optional (None)."""

    @property
    def bin_mrs(self) -> list[DeviceMemoryResource]:
        """Read-only property typed as ``list[DeviceMemoryResource]``."""

class CallbackMemoryResource(DeviceMemoryResource):
    """Typing stub for a ``DeviceMemoryResource`` built from two callables."""

    def __init__(
        self,
        allocate_func: Callable[[int, Stream], int],
        deallocate_func: Callable[[int, int, Stream], None],
    ) -> None:
        """Declare the constructor.

        Parameters
        ----------
        allocate_func : Callable[[int, Stream], int]
            Takes an int and a ``Stream`` and returns an int.
        deallocate_func : Callable[[int, int, Stream], None]
            Takes two ints and a ``Stream`` and returns nothing.

        NOTE(review): presumably these are invoked for each allocate and
        deallocate request, with the ints being size / pointer values --
        confirm argument meaning against the implementation.
        """

class LimitingResourceAdaptor(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor taking an allocation limit."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        allocation_limit: int,
    ) -> None:
        """Accept the upstream resource and a required integer limit.

        NOTE(review): the limit is presumably a byte count -- confirm.
        """

    def get_allocated_bytes(self) -> int:
        """Return an ``int`` (takes no arguments)."""

    def get_allocation_limit(self) -> int:
        """Return an ``int`` (takes no arguments)."""

class LoggingResourceAdaptor(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor with an optional log file name."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        log_file_name: str | None = None,
    ) -> None:
        """Accept the upstream resource and an optional ``log_file_name``.

        NOTE(review): what happens when ``log_file_name`` is None is not
        visible in this stub.
        """

    def flush(self) -> None:
        """Declare ``flush()``; takes no arguments and returns None."""

    def get_file_name(self) -> str:
        """Return a ``str`` (takes no arguments)."""

class StatisticsResourceAdaptor(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor that reports ``Statistics``.

    ``allocation_counts`` is a property, while ``pop_counters`` and
    ``push_counters`` are methods; all three are typed as ``Statistics``.
    """

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
    ) -> None:
        """Accept the upstream resource; there are no other parameters."""

    @property
    def allocation_counts(self) -> Statistics:
        """Read-only property typed as ``Statistics``."""

    def pop_counters(self) -> Statistics:
        """Return a ``Statistics`` object (takes no arguments).

        NOTE(review): presumably the counters being popped -- confirm.
        """

    def push_counters(self) -> Statistics:
        """Return a ``Statistics`` object (takes no arguments).

        NOTE(review): which counters are returned is not visible here.
        """

class TrackingResourceAdaptor(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor with allocation reporting."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        capture_stacks: bool = False,
    ) -> None:
        """Accept the upstream resource and an optional ``capture_stacks``.

        ``capture_stacks`` defaults to False.
        """

    def get_allocated_bytes(self) -> int:
        """Return an ``int`` (takes no arguments)."""

    def get_outstanding_allocations_str(self) -> str:
        """Return a ``str`` (takes no arguments)."""

    def log_outstanding_allocations(self) -> None:
        """Declare a no-argument method that returns None."""

class FailureCallbackResourceAdaptor(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor taking a callback."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        callback: Callable[[int], bool],
    ) -> None:
        """Accept the upstream resource and an ``int -> bool`` callable.

        NOTE(review): presumably the callback receives the size of a
        failed request and its result decides whether to retry --
        confirm against the implementation.
        """

class PrefetchResourceAdaptor(UpstreamResourceAdaptor):
    """Typing stub for an upstream adaptor with no extra options."""

    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
    ) -> None:
        """Accept the upstream resource; there are no other parameters."""

def get_per_device_resource(
    device: int,
) -> DeviceMemoryResource:
    """Return a ``DeviceMemoryResource`` for the integer ``device``."""
def set_per_device_resource(
    device: int,
    mr: DeviceMemoryResource,
) -> None:
    """Take an integer ``device`` and a resource ``mr``; return None."""
def set_current_device_resource(
    mr: DeviceMemoryResource,
) -> None:
    """Take a resource ``mr``; return None.

    NOTE(review): how the "current" device is determined is not visible
    in this stub.
    """
def get_per_device_resource_type(device: int) -> type[DeviceMemoryResource]:
    """Return a ``DeviceMemoryResource`` subclass (a class, not an instance).

    ``device`` is a required integer.
    """
def get_current_device_resource() -> DeviceMemoryResource:
    """Return a ``DeviceMemoryResource``; takes no arguments.

    NOTE(review): how the "current" device is determined is not visible
    in this stub.
    """
def get_current_device_resource_type() -> type[DeviceMemoryResource]:
    """Return a ``DeviceMemoryResource`` subclass; takes no arguments.

    The result is a class object, not an instance.
    """
def is_initialized() -> bool:
    """Return a ``bool``; takes no arguments.

    NOTE(review): what counts as "initialized" is not visible in this
    stub.
    """
def enable_logging(log_file_name: str | None = None) -> None: ...
def disable_logging() -> None:
    """Take no arguments and return None."""
def get_log_filenames() -> dict[int, str | None]: ...
def available_device_memory() -> tuple[int, int]:
    """Return a pair of ints; takes no arguments.

    NOTE(review): the order and units of the two values are not visible
    in this stub -- confirm against the implementation.
    """
def _initialize(
    pool_allocator: bool = False,
    managed_memory: bool = False,
    initial_pool_size: int | str | None = None,
    maximum_pool_size: int | str | None = None,
    devices: int | list[int] = 0,
    logging: bool = False,
    log_file_name: str | None = None,
) -> None: ...
def _flush_logs() -> None:
    """Take no arguments and return None (private helper)."""
