import functools
import logging
import math
import time
from functools import wraps
from typing import Callable, Optional, Any

import click
import urllib3
from lightning_cloud import env
from lightning_cloud.login import Auth
from lightning_cloud.openapi import (
    ApiClient,
    AuthServiceApi,
    CloudSpaceServiceApi,
    ClusterServiceApi,
    Configuration,
    DataConnectionServiceApi,
    LightningappInstanceServiceApi,
    LightningappV2ServiceApi,
    LightningworkServiceApi,
    ProjectsServiceApi,
    SecretServiceApi,
    SSHPublicKeyServiceApi,
    DatasetServiceApi,
    OrganizationsServiceApi,
    UserServiceApi,
    BillingServiceApi,
)
from lightning_cloud.openapi.rest import ApiException
from lightning_cloud.source_code.logs_socket_api import LightningLogsSocketAPI

# Module-level logger; _retry_wrapper emits its retry messages here at DEBUG level.
logger = logging.getLogger(__name__)


def create_swagger_client(check_context=True):
    """
    Build an authenticated ``ApiClient`` for the autogenerated OpenAPI services.

    Parameters
    ----------
    check_context: bool
        When true (the default), ``env.CONTEXT`` must be set. Pass false only
        for calls that do not need context information, e.g. login.

    Returns
    -------
    ApiClient
        Client targeting ``env.LIGHTNING_CLOUD_URL`` with its Authorization
        header and user agent already populated.

    Raises
    ------
    RuntimeError
        If ``check_context`` is true and ``env.CONTEXT`` is empty.
    """
    context_missing = check_context and not env.CONTEXT
    if context_missing:
        raise RuntimeError(
            "Default cluster is not found. Try logging in again!")

    host_url = env.LIGHTNING_CLOUD_URL
    config = Configuration()
    config.host = host_url
    config.debug = env.DEBUG
    # Custom certificates: urllib3 only trusts a CA bundle it is explicitly
    # handed (unlike the requests package), so forward the configured one.
    # NOTE(review): per the original note, artifact handling in grid-cli
    # (cli/artifacts.py) deliberately reverts this setting — confirm.
    config.ssl_ca_cert = env.SSL_CA_CERT

    client = ApiClient(config)
    client.default_headers["Authorization"] = Auth().authenticate()
    client.user_agent = f"Grid-CLI-{env.VERSION}"
    return client


class GridRestClient(
    LightningLogsSocketAPI,
    LightningappInstanceServiceApi,
    LightningappV2ServiceApi,
    AuthServiceApi,
    CloudSpaceServiceApi,
    ClusterServiceApi,
    ProjectsServiceApi,
    LightningworkServiceApi,
    SecretServiceApi,
    SSHPublicKeyServiceApi,
    DataConnectionServiceApi,
    DatasetServiceApi,
    OrganizationsServiceApi,
    UserServiceApi,
    BillingServiceApi,
):
    """One object exposing the methods of every generated service API class.

    The ``ApiClient``'s ``request`` method is wrapped with
    ``request_auth_warning_wrapper`` so that an HTTP 401 raised through it is
    reported as a ``click.ClickException`` asking the user to log in again.

    NOTE(review): ``super().__init__`` resolves to the first base in the MRO
    (``LightningLogsSocketAPI``); the remaining services' ``__init__`` only run
    if that chain is cooperative — confirm every base ends up with the client.
    """

    def __init__(self, api_client: Optional[ApiClient] = None):
        # Fall back to a freshly authenticated client when none is supplied.
        client = api_client or create_swagger_client()
        # Patch the transport once so 401 handling covers requests made with it.
        client.request = request_auth_warning_wrapper(client.request)
        super().__init__(client)


_DEFAULT_BACKOFF_MAX = 5 * 60  # seconds


def _get_next_backoff_time(num_retries: int,
                           backoff_value: float = 0.5) -> float:
    """Return the delay, in seconds, to wait before the next retry.

    The delay is ``backoff_value * 2 ** (num_retries - 1)`` (so it doubles on
    every retry and equals ``backoff_value`` for the first one), capped at
    ``_DEFAULT_BACKOFF_MAX``.

    Args:
        num_retries: Number of consecutive failures so far (1 for the first retry).
        backoff_value: Delay used for the first retry.

    Returns:
        The capped delay. Very large ``num_retries`` values, as reached when
        retrying forever, yield the cap instead of overflowing.
    """
    try:
        # ldexp(x, i) == x * 2**i without materialising a huge int that would
        # overflow on conversion to float.
        next_backoff_value = math.ldexp(backoff_value, num_retries - 1)
    except OverflowError:
        # The uncapped delay exceeds the float range, so it is far above the cap.
        next_backoff_value = math.copysign(math.inf, backoff_value)
    return min(_DEFAULT_BACKOFF_MAX, next_backoff_value)


def _retry_wrapper(self,
                   func: Callable,
                   max_tries: Optional[int] = None) -> Callable:
    """Returns the function decorated by a wrapper that retries the call several times if a connection error occurs.

    ``func`` is called as ``func(self, *args, **kwargs)``, so it should be an
    unbound function whose first parameter is the instance; callers of the
    returned wrapper pass only the remaining arguments.

    A failure is retried when it is a ``urllib3.exceptions.HTTPError`` or an
    ``ApiException`` whose status is 408 (Request Timeout), 409 (Conflict), or
    anything outside the 4xx range. Other 4xx responses are re-raised
    unchanged. The retries follow an exponential backoff computed by
    ``_get_next_backoff_time``.

    Args:
        self: Object passed as the first positional argument to ``func``.
        func: Function to wrap.
        max_tries: Give up when the number of consecutive retryable failures
            reaches this value. ``None`` (or any value below 1, e.g. ``-1``,
            which the count never equals) retries forever.

    Returns:
        The wrapping callable. When it gives up, it raises ``Exception`` with
        the last underlying error attached as ``__cause__``.
    """

    @wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        consecutive_errors = 0

        while True:
            try:
                return func(self, *args, **kwargs)
            except (ApiException, urllib3.exceptions.HTTPError) as ex:
                is_transport_error = isinstance(
                    ex, urllib3.exceptions.HTTPError)
                # Retry transport errors and every non-4xx status; among 4xx,
                # only 408 (Request Timeout) and 409 (Conflict) are retried.
                retryable = (is_transport_error or ex.status in (408, 409)
                             or not str(ex.status).startswith("4"))
                if not retryable:
                    raise

                consecutive_errors += 1
                msg = (f"error: {str(ex)}"
                       if is_transport_error else f"response: {ex.status}")

                # Check the limit before logging so the final failure is not
                # reported as "Retrying after ...".
                if max_tries is not None and consecutive_errors == max_tries:
                    raise Exception(
                        f"The {func.__name__} request failed to reach the server, {msg}."
                    ) from ex

                backoff_time = _get_next_backoff_time(consecutive_errors)
                logger.debug(
                    f"The {func.__name__} request failed to reach the server, {msg}."
                    f" Retrying after {backoff_time} seconds.")
                time.sleep(backoff_time)

    return wrapped


class LightningClient(GridRestClient):
    """The LightningClient is a wrapper around the GridRestClient.

    It wraps all methods to monitor connection exceptions and employs a retry strategy.

    Args:
        retry: Whether API calls should follow a retry mechanism with exponential backoff.
        max_tries: Maximum number of attempts. ``None`` (the default) or ``-1``
            retries forever.

    """

    def __init__(self,
                 retry: bool = True,
                 max_tries: Optional[int] = None) -> None:
        super().__init__(api_client=create_swagger_client())
        if retry:
            # Names already claimed by a more-derived class in the MRO.
            claimed = set()
            for base_class in GridRestClient.__mro__:
                # object's slot methods (__repr__, __setattr__, ...) are not
                # API calls and must not be shadowed on the instance.
                if base_class is object:
                    continue
                for name, attribute in base_class.__dict__.items():
                    # Honour MRO precedence: the first class defining a name
                    # owns it, so a less specific base cannot override it.
                    if name in claimed:
                        continue
                    claimed.add(name)
                    if callable(
                            attribute) and attribute.__name__ != "__init__":
                        # Instance attributes shadow the class methods, so
                        # self.<name>(...) goes through the retry wrapper.
                        setattr(
                            self,
                            name,
                            _retry_wrapper(self,
                                           attribute,
                                           max_tries=max_tries),
                        )


def request_auth_warning_wrapper(func):
    """Wrap ``func`` so an HTTP 401 ``ApiException`` becomes a CLI login hint.

    Args:
        func: Callable to wrap (used on ``ApiClient.request``).

    Returns:
        A callable accepting the same arguments as ``func`` and returning its
        result. If ``func`` raises an ``ApiException`` with status 401, the
        wrapper raises ``click.ClickException`` chained to that error. Any
        other ``ApiException`` is re-raised unchanged.
    """

    @functools.wraps(func)
    def wrap(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ApiException as err:
            if err.status == 401:
                raise click.ClickException(
                    "Authentication failed. Please run `lightning login`."
                ) from err
            raise

    return wrap
