# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

import cudf

from cugraph.tree import minimum_spanning_tree_wrapper
from pylibcugraph import minimum_spanning_tree as pylibcugraph_minimum_spanning_tree
from pylibcugraph import ResourceHandle
from cugraph.structure.graph_classes import Graph


def _minimum_spanning_tree_subgraph(G):
    """
    Build the minimum spanning tree (or forest) of ``G`` as a new Graph.

    Parameters
    ----------
    G : cuGraph.Graph
        Undirected input graph. Its ``_plc_graph`` handle is handed to
        pylibcugraph.

    Returns
    -------
    mst_subgraph : cuGraph.Graph
        A new graph built from the edges returned by pylibcugraph. Vertex
        ids are unrenumbered first when ``G.renumbered`` is set. The edges
        carry a ``weight`` attribute only when pylibcugraph returned edge
        weights.

    Raises
    ------
    ValueError
        If ``G`` is directed.
    """
    # Validate before allocating anything so invalid input fails fast.
    if G.is_directed():
        raise ValueError("input graph must be undirected")

    # The fourth element of the result tuple is not needed here.
    sources, destinations, edge_weights, _ = pylibcugraph_minimum_spanning_tree(
        resource_handle=ResourceHandle(),
        graph=G._plc_graph,
        do_expensive_check=True,
    )

    mst_df = cudf.DataFrame()
    mst_df["src"] = sources
    mst_df["dst"] = destinations
    has_weights = edge_weights is not None
    if has_weights:
        mst_df["weight"] = edge_weights

    if G.renumbered:
        mst_df = G.unrenumber(mst_df, "src")
        mst_df = G.unrenumber(mst_df, "dst")

    mst_subgraph = Graph()
    mst_subgraph.from_cudf_edgelist(
        mst_df,
        source="src",
        destination="dst",
        # Only name the weight column when it was actually created above;
        # otherwise the edge list would reference a missing column.
        edge_attr="weight" if has_weights else None,
    )
    return mst_subgraph


def _maximum_spanning_tree_subgraph(G):
    """
    Build the maximum spanning tree (or forest) of ``G`` as a new Graph.

    The maximum spanning tree is obtained by running the minimum spanning
    tree wrapper on negated edge weights and negating the resulting
    weights back.

    Parameters
    ----------
    G : cuGraph.Graph
        Undirected input graph. Its adjacency list is computed via
        ``G.view_adj_list()`` if it is not already present. The adjacency
        list weights are temporarily replaced by their negation while the
        solver runs and are restored before this function returns or
        raises.

    Returns
    -------
    mst_subgraph : cuGraph.Graph
        A new graph built from the edges returned by the wrapper. Vertex
        ids are unrenumbered first when ``G.renumbered`` is set.

    Raises
    ------
    ValueError
        If ``G`` is directed.
    """
    # Validate before allocating or mutating anything.
    if G.is_directed():
        raise ValueError("input graph must be undirected")

    if not G.adjlist:
        G.view_adj_list()

    original_weights = G.adjlist.weights
    weighted = original_weights is not None
    if weighted:
        # Minimising over negated weights maximises over the originals.
        G.adjlist.weights = original_weights.mul(-1)

    try:
        mst_df = minimum_spanning_tree_wrapper.minimum_spanning_tree(G)
    finally:
        # Always hand the caller's graph back with its original weights,
        # even when the solver raises.
        if weighted:
            G.adjlist.weights = original_weights

    if weighted:
        # Undo the negation on the returned edge weights.
        mst_df["weight"] = mst_df["weight"].mul(-1)

    if G.renumbered:
        mst_df = G.unrenumber(mst_df, "src")
        mst_df = G.unrenumber(mst_df, "dst")

    mst_subgraph = Graph()
    mst_subgraph.from_cudf_edgelist(
        mst_df, source="src", destination="dst", edge_attr="weight"
    )
    return mst_subgraph


def minimum_spanning_tree(
    G: Graph, weight=None, algorithm="boruvka", ignore_nan=False
) -> Graph:
    """
    Compute a minimum spanning tree (MST), or a minimum spanning forest
    (MSF), of an undirected graph.

    Parameters
    ----------
    G : cuGraph.Graph
        cuGraph graph descriptor with connectivity information. Must be
        undirected.

    weight : string
        Default to None. Accepted but not read by this function; only
        ``G`` is forwarded to the solver.

    algorithm : string
        Default to 'boruvka'. Accepted but not read by this function.

    ignore_nan : bool
        Default to False. Accepted but not read by this function.

    Returns
    -------
    G_mst : cuGraph.Graph
        A graph descriptor with a minimum spanning tree or forest.

    Raises
    ------
    ValueError
        If ``G`` is directed.

    Examples
    --------
    >>> from cugraph.datasets import netscience
    >>> G = netscience.get_graph(download=True)
    >>> G_mst = cugraph.minimum_spanning_tree(G)

    """

    # Only the graph itself is handed to the helper.
    G_mst = _minimum_spanning_tree_subgraph(G)
    return G_mst


def maximum_spanning_tree(G, weight=None, algorithm="boruvka", ignore_nan=False):
    """
    Compute a maximum spanning tree (MST), or a maximum spanning forest
    (MSF), of an undirected graph. The adjacency list of G is computed as
    well when G does not already have one.

    Parameters
    ----------
    G : cuGraph.Graph
        cuGraph graph descriptor with connectivity information. Must be
        undirected.

    weight : string
        Default to None. Accepted but not read by this function; only
        ``G`` is forwarded to the solver.

    algorithm : string
        Default to 'boruvka'. Accepted but not read by this function.

    ignore_nan : bool
        Default to False. Accepted but not read by this function.

    Returns
    -------
    G_mst : cuGraph.Graph
        A graph descriptor with a maximum spanning tree or forest.

    Raises
    ------
    ValueError
        If ``G`` is directed.

    Examples
    --------
    >>> from cugraph.datasets import netscience
    >>> G = netscience.get_graph(download=True)
    >>> G_mst = cugraph.maximum_spanning_tree(G)

    """

    # Only the graph itself is handed to the helper.
    G_mst = _maximum_spanning_tree_subgraph(G)
    return G_mst
