# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import os

import numpy as np
import cupy
import pandas

import pylibcugraph

from pylibcugraph.comms import cugraph_comms_get_raft_handle

from cugraph_pyg.utils.imports import import_optional, MissingModule
from cugraph_pyg.tensor import DistTensor, DistMatrix
from cugraph_pyg.tensor.utils import has_nvlink_network, is_empty

from typing import Union, Optional, List, Dict, Tuple

# cudf is an optional dependency.  Besides typing, it is referenced by the
# cudf.Series isinstance checks in GraphStore._put_edge_index.
cudf = import_optional("cudf")

# Have to use import_optional even though these are required
# dependencies in order to build properly.
torch_geometric = import_optional("torch_geometric")
torch = import_optional("torch")
# Union of array-like types.  "torch.Tensor" and "cudf.Series" are written as
# strings so the optionally-imported modules are not dereferenced when this
# module is loaded.
TensorType = Union[
    "torch.Tensor", cupy.ndarray, np.ndarray, "cudf.Series", pandas.Series
]


class GraphStore(
    object
    if isinstance(torch_geometric, MissingModule)
    else torch_geometric.data.GraphStore
):
    """
    cuGraph-backed PyG GraphStore implementation that distributes
    the graph across workers.  This object uses lazy graph creation.
    Users can repeatedly call put_edge_index, and the tensors won't
    be converted into a cuGraph graph until one is needed
    (i.e. when creating a loader). Supports
    single-node/single-GPU, single-node/multi-GPU, and
    multi-node/multi-GPU graph storage.

    Each worker should have a slice of the graph locally, and
    call put_edge_index with its slice.
    """

    def __init__(self):
        """
        Constructs a new, empty GraphStore object.  This object
        represents one slice of a graph on a particular worker.

        The tensor backend is "vmm" when the LOCAL_WORLD_SIZE environment
        variable equals the torch.distributed world size, or when
        has_nvlink_network() reports true; otherwise it is "nccl".

        Raises
        ------
        KeyError
            If the LOCAL_WORLD_SIZE environment variable is not set.
        """
        # Per-edge-type storage, filled in by _put_edge_index.
        self.__edge_indices = {}
        self.__sizes = {}

        # Created on first access of _resource_handle.
        self.__handle = None

        self.__clear_graph()

        local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        all_workers_local = local_world_size == torch.distributed.get_world_size()
        # has_nvlink_network() is only consulted when the worker counts differ.
        self.__backend = (
            "vmm" if all_workers_local or has_nvlink_network() else "nccl"
        )

        super().__init__()

    def __clear_graph(self):
        """
        Drops the cached graph and every value stored alongside it
        (vertex offsets, weight/time attributes, numeric edge types).
        """
        self.__graph = self.__vertex_offsets = None
        self.__weight_attr = self.__etime_attr = None
        self.__numeric_edge_types = None

    def _put_edge_index(
        self,
        edge_index: "torch_geometric.typing.EdgeTensorType",
        edge_attr: "torch_geometric.data.EdgeAttr",
    ) -> bool:
        """
        Stores this worker's slice of the edge index for one edge type.

        Parameters
        ----------
        edge_index
            The local edges.  cupy/cudf inputs are converted to CUDA tensors
            and numpy/pandas inputs to CPU tensors.  A DistMatrix is stored
            as-is, and a pair of DistTensors is used directly as the row and
            column storage of a new DistMatrix.
        edge_attr
            Edge attributes; the layout must be COO.  edge_attr.edge_type is
            the storage key and edge_attr.size is recorded alongside it.

        Returns
        -------
        bool
            Always True.

        Raises
        ------
        ValueError
            If the layout is not COO, if a non-empty edge index does not
            have length 2, or if a DistTensor pair has mismatched lengths.
        """
        if edge_attr.layout != torch_geometric.data.graph_store.EdgeLayout.COO:
            raise ValueError("Only COO format supported")

        # cudf is optional: only reference cudf.Series when the module was
        # actually imported.
        gpu_array_types = (
            (cupy.ndarray,)
            if isinstance(cudf, MissingModule)
            else (cupy.ndarray, cudf.Series)
        )

        if isinstance(edge_index, gpu_array_types):
            edge_index = torch.as_tensor(edge_index, device="cuda")
        elif isinstance(edge_index, np.ndarray):
            edge_index = torch.as_tensor(edge_index, device="cpu")
        elif isinstance(edge_index, pandas.Series):
            edge_index = torch.as_tensor(edge_index.values, device="cpu")

        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        if isinstance(edge_index, torch.Tensor) and is_empty(edge_index):
            # Normalize an empty tensor so it still has two (empty) rows.
            edge_index = torch.tensor([[], []], device="cuda", dtype=torch.int64)
        elif len(edge_index) != 2:
            raise ValueError("Edge index must be of length 2")

        # Every worker contributes its local edge count; the total sizes the
        # distributed matrix and the counts of lower ranks give this worker's
        # starting offset.
        local_size = torch.tensor(
            0 if is_empty(edge_index[1]) else edge_index[1].shape[0],
            device="cuda",
            dtype=torch.int64,
        )
        sizes = torch.empty((world_size,), device="cuda", dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(sizes, local_size)
        size = int(sizes.sum())

        offset = sizes[:rank].sum() if rank > 0 else 0

        if isinstance(edge_index, DistMatrix):
            self.__edge_indices[edge_attr.edge_type] = edge_index
        else:
            self.__edge_indices[edge_attr.edge_type] = DistMatrix(
                shape=(size, size), dtype=torch.long, backend=self.__backend
            )

            if isinstance(edge_index[0], DistTensor) and isinstance(
                edge_index[1], DistTensor
            ):
                if edge_index[0].shape[0] != edge_index[1].shape[0]:
                    raise ValueError(
                        "Only COO format is supported for construction "
                        "from DistTensor tuples."
                    )
                # Reuse the distributed tensors directly as the matrix storage.
                self.__edge_indices[edge_attr.edge_type]._row = edge_index[0]
                self.__edge_indices[edge_attr.edge_type]._col = edge_index[1]
            else:
                if isinstance(edge_index, list):
                    edge_index = torch.stack(edge_index)
                # Write the local edges into this worker's slot.
                self.__edge_indices[edge_attr.edge_type][
                    offset : offset + local_size
                ] = edge_index

        self.__sizes[edge_attr.edge_type] = edge_attr.size

        # invalidate the graph
        self.__clear_graph()
        return True

    def _get_edge_index(
        self, edge_attr: "torch_geometric.data.EdgeAttr"
    ) -> Optional["torch_geometric.typing.EdgeTensorType"]:
        """
        Returns the local COO edges stored for edge_attr.edge_type, wrapped
        in a torch_geometric.EdgeIndex, or the CSR/CSC form derived from it
        when edge_attr.layout requests one.

        Raises KeyError if nothing is stored for the edge type.
        """
        # TODO Return WG edge index as duck-type for torch_geometric.EdgeIndex
        # (rapidsai/cugraph-gnn#188)
        stored = self.__edge_indices[edge_attr.edge_type]
        edge_index = torch_geometric.EdgeIndex(stored.local_coo)

        # NOTE(review): layout is compared against plain strings here, while
        # _put_edge_index compares against the EdgeLayout enum -- confirm both
        # comparisons behave as intended.
        if edge_attr.layout == "csr":
            by_row = edge_index.sort_by("row")
            return by_row.values.get_csr()
        if edge_attr.layout == "csc":
            by_col = edge_index.sort_by("col")
            return by_col.values.get_csc()

        return edge_index

    def _remove_edge_index(self, edge_attr: "torch_geometric.data.EdgeAttr") -> bool:
        """
        Removes the edge index stored for edge_attr.edge_type together with
        the size recorded for it, and invalidates the cached graph.

        Returns True on success; raises KeyError if the edge type is not
        stored.
        """
        del self.__edge_indices[edge_attr.edge_type]
        # Keep the size bookkeeping in step with the stored edge indices.
        self.__sizes.pop(edge_attr.edge_type, None)

        # invalidate the graph
        self.__clear_graph()
        return True

    def get_all_edge_attrs(self) -> List["torch_geometric.data.EdgeAttr"]:
        """
        Returns one unsorted, COO-layout EdgeAttr per stored edge type,
        carrying the size that was recorded when the edge index was put.
        """
        return [
            torch_geometric.data.EdgeAttr(
                edge_type=etype,
                layout="coo",
                is_sorted=False,
                size=self.__sizes[etype],
            )
            for etype in self.__edge_indices
        ]

    @property
    def is_multi_gpu(self):
        """Whether the torch.distributed world size is greater than one."""
        world_size = torch.distributed.get_world_size()
        return world_size > 1

    @property
    def _resource_handle(self):
        """
        The pylibcugraph ResourceHandle, created on first access and then
        reused.  In the multi-GPU case it wraps the handle obtained from
        cugraph_comms_get_raft_handle().
        """
        if self.__handle is not None:
            return self.__handle

        if self.is_multi_gpu:
            raft_handle = cugraph_comms_get_raft_handle().getHandle()
            self.__handle = pylibcugraph.ResourceHandle(raft_handle)
        else:
            self.__handle = pylibcugraph.ResourceHandle()

        return self.__handle

    @property
    def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
        """
        The pylibcugraph graph built from the stored edge indices.

        The graph is built on first access and cached until __clear_graph
        resets it.  An MGGraph is created when is_multi_gpu is true and an
        SGGraph otherwise.  Weight and edge-time arrays are passed only when
        the edge list contains them.
        """
        props = pylibcugraph.GraphProperties(is_multigraph=True, is_symmetric=False)

        if self.__graph is not None:
            return self.__graph

        edges = self.__get_edgelist()
        multi_gpu = self.is_multi_gpu

        vertices = cupy.arange(sum(self._num_vertices().values()), dtype="int64")
        if multi_gpu:
            # Each rank passes only its own chunk of the global vertex range.
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            vertices = cupy.array_split(vertices, world_size)[rank]

        def optional(key):
            # "wgt" / "etime" are present only when the attribute was set.
            return cupy.asarray(edges[key]) if key in edges else None

        def per_worker(array):
            # MGGraph receives single-element lists; SGGraph bare arrays.
            if multi_gpu and array is not None:
                return [array]
            return array

        graph_cls = pylibcugraph.MGGraph if multi_gpu else pylibcugraph.SGGraph
        self.__graph = graph_cls(
            self._resource_handle,
            props,
            per_worker(cupy.asarray(edges["src"]).astype("int64")),
            per_worker(cupy.asarray(edges["dst"]).astype("int64")),
            vertices_array=per_worker(vertices),
            edge_id_array=per_worker(cupy.asarray(edges["eid"])),
            edge_type_array=per_worker(cupy.asarray(edges["etp"])),
            weight_array=per_worker(optional("wgt")),
            edge_start_time_array=per_worker(optional("etime")),
        )

        return self.__graph

    def _num_vertices(self) -> Dict[str, int]:
        """
        Returns the number of vertices of each vertex type.

        For edge types with an explicit size, the count of each endpoint type
        is the largest size seen so far.  For edge types without a size, a
        count is inferred as (largest local id + 1), but only for vertex
        types that have no count yet.  In the multi-GPU case every count is
        then max-reduced across workers.

        Returns
        -------
        Dict[str, int]
            Mapping of vertex type name to vertex count.
        """
        num_vertices = {}
        for edge_attr in self.get_all_edge_attrs():
            edge_type = edge_attr.edge_type
            # edge_type[0] is paired with the column ids and edge_type[2] with
            # the row ids, matching __get_edgelist.
            col_vtype, row_vtype = edge_type[0], edge_type[2]

            if edge_attr.size is not None:
                for vtype, count in (
                    (col_vtype, edge_attr.size[0]),
                    (row_vtype, edge_attr.size[1]),
                ):
                    num_vertices[vtype] = max(num_vertices.get(vtype, count), count)
                continue

            local_edges = self.__edge_indices[edge_type]
            if col_vtype != row_vtype:
                if col_vtype not in num_vertices:
                    num_vertices[col_vtype] = int(local_edges.local_col.max() + 1)
                if row_vtype not in num_vertices:
                    # Key on the vertex type (edge_type[2]), not the relation
                    # name (edge_type[1]).
                    num_vertices[row_vtype] = int(local_edges.local_row.max() + 1)
            elif col_vtype not in num_vertices:
                num_vertices[col_vtype] = int(local_edges.local_coo.max() + 1)

        if self.is_multi_gpu:
            # Reduce in sorted key order so every rank issues the collectives
            # in the same sequence.
            for vtype in sorted(num_vertices):
                sz = torch.tensor(num_vertices[vtype], device="cuda")
                torch.distributed.all_reduce(sz, op=torch.distributed.ReduceOp.MAX)
                num_vertices[vtype] = int(sz)
        return num_vertices

    @property
    def _vertex_offsets(self) -> Dict[str, int]:
        """
        Mapping of vertex type to the starting offset of its ids.  Vertex
        types are laid out in sorted order, each starting where the previous
        type ends.  The mapping is computed once and cached; a copy is
        returned on every access.
        """
        if self.__vertex_offsets is None:
            counts = self._num_vertices()
            offsets = {}
            running_total = 0
            for vtype in sorted(counts):
                offsets[vtype] = running_total
                running_total += counts[vtype]
            self.__vertex_offsets = offsets

        return dict(self.__vertex_offsets)

    @property
    def _vertex_offset_array(self) -> "torch.Tensor":
        """
        int64 CUDA tensor holding the offset of each vertex type (in sorted
        vertex-type order) followed by the total number of vertices.
        """
        offsets = self._vertex_offsets
        starts = torch.tensor(
            [offsets[vtype] for vtype in sorted(offsets)],
            dtype=torch.int64,
            device="cuda",
        )

        counts = torch.tensor(
            list(self._num_vertices().values()),
            device="cuda",
            dtype=torch.int64,
        )
        # The total vertex count closes the offset array.
        total = counts.sum().reshape((1,))

        return torch.concat([starts, total])

    @property
    def is_homogeneous(self) -> bool:
        """Whether the graph has exactly one vertex type."""
        vertex_types = self._vertex_offsets
        return len(vertex_types) == 1

    def _set_etime_attr(self, attr: Tuple["torch_geometric.data.FeatureStore", str]):
        """
        Sets the (feature store, attribute name) pair used for edge times.
        If it differs from the current one, the cached graph is cleared;
        the weight attribute is kept.
        """
        if attr != self.__etime_attr:
            # __clear_graph resets both attributes, so carry the weight over.
            kept_weight_attr = self.__weight_attr
            self.__clear_graph()
            self.__weight_attr, self.__etime_attr = kept_weight_attr, attr

    def _set_weight_attr(self, attr: Tuple["torch_geometric.data.FeatureStore", str]):
        """
        Sets the (feature store, attribute name) pair used for edge weights.
        If it differs from the current one, the cached graph is cleared;
        the edge time attribute is kept.
        """
        if attr != self.__weight_attr:
            # __clear_graph resets both attributes, so carry the time over.
            kept_etime_attr = self.__etime_attr
            self.__clear_graph()
            self.__etime_attr, self.__weight_attr = kept_etime_attr, attr

    def __get_etime_tensor(
        self,
        sorted_keys: List[Tuple[str, str, str]],
        start_offsets: "torch.Tensor",
        num_edges_t: "torch.Tensor",
    ):
        """
        Reads the edge time attribute from the feature store for each edge
        type in sorted_keys and concatenates the results in that order.

        For the i-th edge type the entries
        [start_offsets[i], start_offsets[i] + num_edges_t[i]) are fetched,
        using a CPU index tensor.

        Raises
        ------
        ValueError
            If the lookup for an edge type yields None.
        """
        feature_store, attr_name = self.__etime_attr

        per_type_times = []
        for pos, etype in enumerate(sorted_keys):
            first = start_offsets[pos]
            edge_ids = torch.arange(
                first,
                first + num_edges_t[pos],
                dtype=torch.int64,
                device="cpu",
            )
            times = feature_store[etype, attr_name][edge_ids]

            if times is None:
                raise ValueError("Time property must be present for all edge types.")
            per_type_times.append(times)

        return torch.concat(per_type_times)

    def __get_weight_tensor(
        self,
        sorted_keys: List[Tuple[str, str, str]],
        start_offsets: "torch.Tensor",
        num_edges_t: "torch.Tensor",
    ):
        """
        Reads the edge weight attribute from the feature store for each edge
        type in sorted_keys and concatenates the results in that order.

        For the i-th edge type the entries
        [start_offsets[i], start_offsets[i] + num_edges_t[i]) are fetched,
        using a CPU index tensor.
        """
        feature_store, attr_name = self.__weight_attr

        def edge_ids(pos):
            # Index range of this worker's edges for the pos-th edge type.
            first = start_offsets[pos]
            return torch.arange(
                first,
                first + num_edges_t[pos],
                dtype=torch.int64,
                device="cpu",
            )

        per_type_weights = [
            feature_store[etype, attr_name][edge_ids(pos)]
            for pos, etype in enumerate(sorted_keys)
        ]

        return torch.concat(per_type_weights)

    @property
    def _numeric_edge_types(
        self,
    ) -> Tuple[List[Tuple[str, str, str]], "torch.Tensor", "torch.Tensor"]:
        """
        Returns the canonical edge types in sorted order (position i is
        numeric edge type i) together with two int32 CUDA tensors giving,
        for each edge type, the numeric vertex type of its first and of its
        last element.  Vertex types are numbered in sorted order.

        The result is cached until the graph is cleared.
        """
        if self.__numeric_edge_types is None:
            canonical_types = sorted(self.__edge_indices)

            vtype_ids = {
                vtype: ix for ix, vtype in enumerate(sorted(self._vertex_offsets))
            }

            first_ids = [vtype_ids[etype[0]] for etype in canonical_types]
            last_ids = [vtype_ids[etype[2]] for etype in canonical_types]

            self.__numeric_edge_types = (
                canonical_types,
                torch.tensor(first_ids, device="cuda", dtype=torch.int32),
                torch.tensor(last_ids, device="cuda", dtype=torch.int32),
            )

        return self.__numeric_edge_types

    def __get_edgelist(self) -> Dict[str, "torch.Tensor"]:
        """
        Builds this worker's combined edge list across all stored edge types,
        with vertex ids shifted by the per-type vertex offsets.

        Returns
        -------
        Dict[str, torch.Tensor] with the following keys:
            src: source vertices (int64)
                Note that src is the 2nd element of the PyG edge index.
            dst: destination vertices (int64)
                Note that dst is the 1st element of the PyG edge index.
            eid: edge ids for each edge (int64)
                Note that these are numbered separately for each edge type.
                On a single GPU they start from 0; with multiple GPUs each
                rank's ids start after the edges held by lower ranks.
            etp: edge types for each edge (int32)
                Note that these are in lexicographic order.
            wgt: edge weights (only present if a weight attribute is set)
            etime: edge times (only present if a time attribute is set)
        """
        sorted_keys = sorted(list(self.__edge_indices.keys()))

        # note that this still follows the PyG convention of (dst, rel, src)
        # i.e. (author, writes, paper): [[0,1,2],[2,0,1]] is referring to a
        # cuGraph graph where (paper 2) -> (author 0), (paper 0) -> (author 1),
        # and (paper 1) -> (author 2)
        edge_index = torch.concat(
            [
                torch.stack(
                    [
                        self.__edge_indices[dst_type, rel_type, src_type].local_col
                        + self._vertex_offsets[dst_type],
                        self.__edge_indices[dst_type, rel_type, src_type].local_row
                        + self._vertex_offsets[src_type],
                    ]
                )
                for (dst_type, rel_type, src_type) in sorted_keys
            ],
            axis=1,
        ).cuda()

        # Numeric edge type i is repeated once per local edge of the i-th
        # edge type in sorted order.
        edge_type_array = torch.arange(
            len(sorted_keys), dtype=torch.int32, device="cuda"
        ).repeat_interleave(
            torch.tensor(
                [self.__edge_indices[et].local_row.numel() for et in sorted_keys],
                device="cuda",
                dtype=torch.int64,
            )
        )

        # Number of local edges of each edge type, in sorted edge type order.
        num_edges_t = torch.tensor(
            [self.__edge_indices[et].local_row.numel() for et in sorted_keys],
            device="cuda",
        )

        if self.is_multi_gpu:
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()

            # Gather every rank's per-type edge counts (one row per rank).
            num_edges_all_t = torch.empty(
                world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda"
            )
            torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t)

            # Per edge type, this rank's ids begin after the edges of all
            # lower ranks.
            start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
        else:
            rank = 0
            start_offsets = torch.zeros(
                (len(sorted_keys),), dtype=torch.int64, device="cuda"
            )
            num_edges_all_t = num_edges_t.reshape((1, num_edges_t.numel()))

        # One contiguous id range per edge type, concatenated in sorted order.
        edge_id_array = torch.concat(
            [
                torch.arange(
                    start_offsets[i],
                    start_offsets[i] + num_edges_all_t[rank][i],
                    dtype=torch.int64,
                    device="cuda",
                )
                for i in range(len(sorted_keys))
            ]
        )

        d = {
            "dst": edge_index[0],
            "src": edge_index[1],
            "etp": edge_type_array,
            "eid": edge_id_array,
        }

        # The feature store lookups are given CPU offsets/counts; the fetched
        # values are moved to the GPU afterwards.
        if self.__weight_attr is not None:
            d["wgt"] = self.__get_weight_tensor(
                sorted_keys, start_offsets.cpu(), num_edges_t.cpu()
            ).cuda()

        if self.__etime_attr is not None:
            d["etime"] = self.__get_etime_tensor(
                sorted_keys, start_offsets.cpu(), num_edges_t.cpu()
            ).cuda()

        return d
