# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from enum import IntEnum

from rmm.pylibrmm.memory_resource import DeviceMemoryResource
from rmm.pylibrmm.stream import Stream

class AllocType(IntEnum):
    """Category selecting which allocations a memory statistic refers to.

    Used as the ``alloc_type`` argument of the ``ScopedMemoryRecord`` query
    and recording methods. The concrete integer values are left unspecified
    (``...``) in this stub; they are defined by the implementation.
    """

    # Presumably allocations served by the primary/upstream memory
    # resource — confirm against the implementation.
    PRIMARY = ...
    # Presumably allocations served by the optional fallback memory
    # resource — confirm against the implementation.
    FALLBACK = ...
    # Default ``alloc_type`` of the ScopedMemoryRecord queries; presumably
    # aggregates PRIMARY and FALLBACK — TODO confirm.
    ALL = ...

class ScopedMemoryRecord:
    """Allocation/deallocation statistics broken down by ``AllocType``.

    Every query method takes an ``alloc_type`` selecting which category to
    report, defaulting to ``AllocType.ALL``. The recording methods require
    ``alloc_type`` explicitly (no default). The unit of the memory-amount
    queries (``current``/``total``/``peak``) is not stated in this stub;
    presumably bytes, matching the ``nbytes`` parameter of the ``record_*``
    methods — TODO confirm.
    """

    # Number of allocations recorded for ``alloc_type`` (presumably
    # cumulative, including allocations already freed — confirm).
    def num_total_allocs(self, alloc_type: AllocType = AllocType.ALL) -> int: ...
    # Number of allocations for ``alloc_type`` presumably still outstanding
    # (recorded but not yet deallocated).
    def num_current_allocs(self, alloc_type: AllocType = AllocType.ALL) -> int: ...
    # Amount of memory presumably outstanding right now for ``alloc_type``.
    def current(self, alloc_type: AllocType = AllocType.ALL) -> int: ...
    # Amount of memory presumably allocated over the record's lifetime for
    # ``alloc_type`` (not reduced by deallocations — confirm).
    def total(self, alloc_type: AllocType = AllocType.ALL) -> int: ...
    # Presumably the high-water mark of ``current`` for ``alloc_type`` —
    # TODO confirm.
    def peak(self, alloc_type: AllocType = AllocType.ALL) -> int: ...
    # Record an allocation of ``nbytes`` under ``alloc_type``.
    def record_allocation(self, alloc_type: AllocType, nbytes: int) -> None: ...
    # Record a deallocation of ``nbytes`` under ``alloc_type``.
    def record_deallocation(self, alloc_type: AllocType, nbytes: int) -> None: ...

class RmmResourceAdaptor:
    """Memory-resource adaptor built on an RMM ``DeviceMemoryResource``.

    Constructed from a required ``upstream_mr`` and an optional keyword-only
    ``fallback_mr``. Presumably allocations go to ``upstream_mr`` and fall
    back to ``fallback_mr`` when the upstream cannot satisfy them, with usage
    tracked per ``AllocType`` — confirm against the implementation.
    """

    # upstream_mr: the wrapped device memory resource (positional).
    # fallback_mr: optional secondary resource; keyword-only, None disables it
    # (presumably — confirm).
    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        *,
        fallback_mr: DeviceMemoryResource | None = None,
    ) -> None: ...
    # NOTE(review): declared as a property despite the ``get_`` prefix, so it
    # is read as ``adaptor.get_upstream`` (no call). Confirm this matches the
    # compiled implementation; the name reads like a method.
    @property
    def get_upstream(self) -> DeviceMemoryResource: ...
    # Allocate ``nbytes`` on ``stream`` (the default stream is supplied by
    # the implementation). Returns an int, presumably the device address of
    # the new allocation — confirm.
    def allocate(self, nbytes: int, stream: Stream = ...) -> int: ...
    # Release the allocation at ``ptr`` on ``stream``; ``nbytes`` presumably
    # must equal the size passed to ``allocate`` — confirm.
    def deallocate(self, ptr: int, nbytes: int, stream: Stream = ...) -> None: ...
    # Return the ScopedMemoryRecord holding this adaptor's statistics
    # (presumably the top-level/"main" record — confirm).
    def get_main_record(self) -> ScopedMemoryRecord: ...
    # Presumably the amount of memory currently allocated through this
    # adaptor (unit presumably bytes) — confirm.
    @property
    def current_allocated(self) -> int: ...
