from __future__ import annotations

import asyncio
import calendar
import functools
import json
import os
import re
from dataclasses import asdict
from datetime import datetime, timezone
from json.decoder import JSONDecodeError
from types import coroutine
from typing import Any, Coroutine, Dict, List, Optional, Tuple

import httpx
from aext_panels_server.const import (
    ALLOWED_CELL_METADATA,
    ALLOWED_NB_METADATA,
    DEFAULT_HTTPX_TIMEOUT_SEC,
    PYTHONANYWHERE_PRODUCTION_SITE,
    Environment,
)
from aext_panels_server.logger import custom_logger
from aext_panels_server.schemas import RequestError, ServiceResponse


def is_local_development_env():
    """Return True when the ``LOCAL_ENV`` environment variable holds a non-empty value."""
    return os.environ.get("LOCAL_ENV", "") != ""


def is_staging_env():
    """
    Return True for a non-production PythonAnywhere site when not developing locally.

    Staging means ``PYTHONANYWHERE_SITE`` differs from ``PYTHONANYWHERE_PRODUCTION_SITE``
    while ``LOCAL_ENV`` is not set.
    """
    if is_local_development_env():
        return False
    return os.getenv("PYTHONANYWHERE_SITE") != PYTHONANYWHERE_PRODUCTION_SITE


def is_anaconda_toolbox_env():
    """
    Return True when ``PYTHONANYWHERE_SITE`` is unset or empty.

    That variable is only set in Anaconda Cloud Notebooks, so its absence is taken
    to mean the server is running inside Anaconda Toolbox.
    """
    pythonanywhere_site = os.getenv("PYTHONANYWHERE_SITE")
    return pythonanywhere_site is None or pythonanywhere_site == ""


def get_environment() -> Environment:
    """
    Detect which deployment this server runs in, log it and return it.

    Checks run in priority order:
    1. Anaconda Toolbox (``PYTHONANYWHERE_SITE`` unset/empty) -> PRODUCTION
    2. a non-production PythonAnywhere site outside local development -> STAGING
    3. ``LOCAL_ENV`` set -> LOCAL
    Anything else falls back to PRODUCTION.
    """
    if is_anaconda_toolbox_env():
        detected = Environment.PRODUCTION
    elif is_staging_env():
        detected = Environment.STAGING
    elif is_local_development_env():
        detected = Environment.LOCAL
    else:
        detected = Environment.PRODUCTION

    custom_logger.info(f"****************** ENVIRONMENT {detected} ******************")
    return detected


def _update_value(dct: dict[str, Any], key: str, value: Any):
    """Overwrite ``dct[key]`` with ``value`` only when the key already exists and the value differs.

    Missing keys are never added.
    """
    if key not in dct:
        return
    if dct[key] != value:
        dct[key] = value


def _clean_cell(cell):
    """
    Strip execution artefacts from a single notebook cell, in place.

    An existing ``outputs`` entry is emptied and an existing ``execution_count`` is
    reset to ``None``. An existing ``metadata`` entry is reduced to the keys listed
    in ``ALLOWED_CELL_METADATA``. Keys the cell does not already have are not added.
    """
    _update_value(cell, "outputs", [])
    _update_value(cell, "execution_count", None)
    kept_metadata = {
        name: entry for name, entry in cell.get("metadata", {}).items() if name in ALLOWED_CELL_METADATA
    }
    _update_value(cell, "metadata", kept_metadata)


def clean_notebook(data: str) -> str:
    """
    Strip outputs and non-allowlisted metadata from a serialized notebook.

    Every cell is cleaned with ``_clean_cell``. Notebook-level metadata is rebuilt
    from a default ``language_info`` entry plus any keys listed in
    ``ALLOWED_NB_METADATA`` (which may override that default). The rebuilt metadata
    is only written back when the notebook already has a ``metadata`` key.

    Parameters
    ----------
    data: notebook JSON document as a string

    Returns
    -------
    The cleaned notebook re-serialized as JSON, with non-ASCII characters kept as-is.

    Raises ``json.JSONDecodeError`` for invalid JSON and ``KeyError`` when the
    document has no ``cells`` entry.
    """
    notebook = json.loads(data)
    for nb_cell in notebook["cells"]:
        _clean_cell(nb_cell)

    nb_metadata = {"language_info": {"name": "python", "pygments_lexer": "ipython3"}}
    nb_metadata.update(
        (name, entry) for name, entry in notebook.get("metadata", {}).items() if name in ALLOWED_NB_METADATA
    )
    _update_value(notebook, "metadata", nb_metadata)
    return json.dumps(notebook, ensure_ascii=False)


def to_json(response: ServiceResponse):
    """
    Serialize a dataclass response to a JSON string, omitting fields whose value is ``None``.

    ``asdict`` uses the factory for nested dataclasses as well, so their ``None``
    fields are dropped too.
    """

    def _without_nones(pairs):
        return {field_name: field_value for field_name, field_value in pairs if field_value is not None}

    as_dict = asdict(response, dict_factory=_without_nones)
    return json.dumps(as_dict)


def extract_file_name(file_path, remove_extension: bool = True):
    """
    Return the last ``/``-separated component of ``file_path``.

    Parameters
    ----------
    file_path: relative or absolute path using ``/`` as the separator
    remove_extension: when True, drop a trailing ``.ipynb`` extension

    Returns
    -------
    file_name: the base name, optionally without its ``.ipynb`` extension
    """
    _, _, base_name = file_path.rpartition("/")
    if not remove_extension:
        return base_name
    return re.sub(r"\.ipynb$", "", base_name)


async def call_api(
    api_call: Coroutine, always_return_response: bool = False
) -> Tuple[RequestError, Optional[httpx.Response]]:
    """
    Await a coroutine built from one of the PythonAnywhere API class methods and classify the outcome.

    Networking exceptions are checked first. If the request completed, its HTTP
    status code decides whether it succeeded.

    Parameters
    ----------
    api_call: awaitable created from one of the PythonAnywhere API methods
    always_return_response: when True, error results also carry the ``httpx.Response``
        (if one was received). Otherwise error results carry ``None``.

    Returns
    -------
    A ``(RequestError, response)`` tuple:

    - 2xx status: ``RequestError(status=False, message="")`` and the response.
    - every other outcome (exception while awaiting, 404, other 4xx, any other status,
      undecodable JSON body, unexpected error): a ``RequestError`` with ``status=True``
      and a descriptive message. The response is included only when
      ``always_return_response`` is set, and is ``None`` if the await itself raised.
    """
    response = None

    def _get_response() -> Optional[httpx.Response]:
        return response if always_return_response else None

    # Fire the HTTP request
    try:
        response = await api_call
    except httpx.TimeoutException as ex:
        # Base class of ConnectTimeout/ReadTimeout/WriteTimeout/PoolTimeout; the message
        # reports the concrete subclass name.
        return RequestError(status=True, message=f"{ex.__class__.__name__}"), _get_response()
    except (httpx.NetworkError, httpx.ProxyError):
        return RequestError(status=True, message="Request Network Error"), _get_response()
    except Exception:
        return RequestError(status=True, message="Unexpected error when calling api"), _get_response()

    # HTTP request without any networking errors. Checking the status code
    status = response.status_code
    try:
        if httpx.codes.OK <= status <= httpx.codes.IM_USED:
            return RequestError(status=False, message=""), response
        if status == httpx.codes.NOT_FOUND:
            return RequestError(status=True, message="Not found"), _get_response()
        if httpx.codes.BAD_REQUEST <= status < httpx.codes.INTERNAL_SERVER_ERROR:
            return RequestError(status=True, message=f"Client error: {response.json()}"), _get_response()
        # NOTE(review): 1xx and 3xx statuses also land here and are labelled "Server error".
        return RequestError(status=True, message=f"Server error: {response.json()}"), _get_response()
    except JSONDecodeError:
        return RequestError(status=True, message="Could not decode JSON"), _get_response()
    except Exception:
        # Honour always_return_response here too, like every other error branch
        # (e.g. a non-JSON decoding failure raised by response.json()).
        return (
            RequestError(status=True, message="An unexpected error occurred while making request"),
            _get_response(),
        )


async def run_concurrently(coroutines: Dict[str, Coroutine], timeout: Optional[float] = None):
    """
    Run the given coroutines concurrently and collect the results of those that finish.

    Args:
        coroutines: dict of coroutines keyed by a unique ID. Each ID is used as the
                    task name (asyncio stores task names as ``str``).
        timeout: seconds to wait for all tasks; ``None`` waits indefinitely. Tasks still
                 running when it expires are cancelled and left out of the result.

    Returns: dict mapping each finished coroutine's ID to its result. An empty input
             yields an empty dict.

    Raises: if a finished task raised, ``Task.result()`` re-raises that exception here
            (an arbitrary one of them if several failed).
    """
    if not coroutines:
        # asyncio.wait() raises ValueError when given an empty collection.
        return {}

    tasks = [asyncio.create_task(coro, name=key) for key, coro in coroutines.items()]
    done, pending = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
    # Tasks that did not finish in time are cancelled (not awaited) rather than left running.
    for task in pending:
        task.cancel()

    return {task.get_name(): task.result() for task in done}


def get_vdate():
    """
    Return the current local year and zero-padded month as ``"<year>.<MM>"``.

    The clock is read once so that year and month come from the same instant. Two
    separate ``datetime.now()`` calls could straddle a year boundary and yield
    e.g. the old year with month ``01``.
    """
    now = datetime.now()
    return f"{now.year}.{now.month:0>2}"


def to_snake_case(names: List):
    """
    Normalize each string in ``names`` into a lowercase, hyphen-separated slug.

    Despite the function name, words are joined with ``-`` rather than ``_``.
    Characters other than ASCII letters, digits, spaces and hyphens are deleted, and
    the text is lowercased. Every remaining run of spaces/hyphens then becomes a single
    ``-``; leading and trailing runs are kept, e.g. ``" a "`` -> ``"-a-"``.

    Parameters
    ----------
    names: list of strings

    Returns
    -------
    A new list with the normalized strings, in the same order.
    """

    def _slugify(raw_name):
        # Keep only ASCII alphanumerics, spaces and hyphens, then lowercase.
        filtered = re.sub("[^A-Za-z0-9 -]+", "", raw_name).lower()
        # Collapse each run of separators into a single hyphen.
        return re.sub("[^A-Za-z0-9]+", "-", filtered)

    return [_slugify(raw_name) for raw_name in names]


def get_utc_timestamp() -> float:
    """
    Get the current Unix timestamp, truncated to whole seconds.

    Uses the timezone-aware ``datetime.now(timezone.utc)`` instead of the naive
    ``datetime.utcnow()``, which is deprecated since Python 3.12.

    Returns
    -------
    Seconds since the Unix epoch as an ``int`` (fractional seconds are dropped).
    """
    return calendar.timegm(datetime.now(timezone.utc).utctimetuple())


async def read_mocked_response(filename: str, to_json: bool = False):
    """
    Read a file from the ``mocked-responses`` directory next to this module.

    Parameters
    ----------
    filename: file name relative to ``mocked-responses/``
    to_json: when True, parse the content as JSON and return the decoded object;
             otherwise return the raw text

    Note: the read is synchronous, so it blocks the event loop while it runs.
    """
    mock_path = os.path.join(os.path.dirname(__file__), f"mocked-responses/{filename}")
    # Explicit encoding so fixtures decode identically regardless of the platform locale.
    with open(mock_path, "r", encoding="utf-8") as f:
        if to_json:
            return json.load(f)
        return f.read()


def get_relative_file_path(file_path: str) -> str:
    """
    Drop the first two directory components of a path (e.g. ``/home/<user>``).

    e.g.:
    /home/user/subdir/filename.ext -> subdir/filename.ext
    /home/user/filename.ext        -> filename.ext
    /home/user                     -> "" (empty string)

    A path with no ``/<x>/<y>`` sequence is returned unchanged.
    NOTE(review): the pattern is searched anywhere in the string, so this is only
    meaningful for absolute paths (``"a/b/c"`` yields ``""``).

    Parameters:
    -----------
    file_path: file path to be parsed

    Returns:
    --------
    the rest of the path after the first two components, without a leading ``/``
    """
    matches = re.search(r"/([^/]+)/([^/]+)(.*)", file_path)
    if not matches:
        return file_path

    remainder = matches.group(3)
    # group(3) is empty when the path ends right after the second component, so use
    # startswith() instead of indexing [0], which raised IndexError in that case.
    if remainder.startswith("/"):
        return remainder[1:]
    return remainder


def escape_filename(cmd: str) -> str:
    """
    Replace every character that is not a word character, ``-`` or ``.`` with ``_``.

    Word characters follow Python's Unicode-aware regex definition (letters, digits
    and underscore), so accented letters are kept.

    Parameters:
    -----------
    cmd: filename to be escaped

    Returns:
    --------
    the escaped filename
    """
    unsafe_char = re.compile(r"[^\w\-_\.]")
    return unsafe_char.sub("_", cmd)


def handle_network_exceptions(max_retries, default_timeout=DEFAULT_HTTPX_TIMEOUT_SEC):
    """
    Decorator factory that retries an async httpx call when it times out.

    The wrapped coroutine function is first awaited with the caller's arguments. If it
    raises ``httpx.TimeoutException``, it is retried up to ``max_retries`` times. Each
    retry passes ``timeout_sec=default_timeout * 2``, so the wrapped function must accept
    a ``timeout_sec`` keyword. An ``httpx.NetworkError`` is logged and not retried.

    Parameters
    ----------
    max_retries: number of retries allowed after a timed-out first attempt. The budget
        applies to each call of the decorated function separately.
    default_timeout: base timeout in seconds; retries use twice this value

    Returns
    -------
    A decorator. The decorated coroutine returns the wrapped function's result. It
    returns ``None`` (after logging an error) when every attempt timed out or a network
    error occurred. Other exceptions propagate unchanged.
    """

    def actual_decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            custom_logger.debug(f"Calling {func.__name__}")
            try:
                return await func(*args, **kwargs)
            except httpx.TimeoutException:
                pass
            except httpx.NetworkError:
                custom_logger.error(f"Network error while calling {func.__name__}, giving up")
                return None

            retry_timeout = default_timeout * 2
            # Override (rather than duplicate) a caller-supplied timeout_sec; passing it
            # twice would raise "got multiple values for keyword argument".
            retry_kwargs = {**kwargs, "timeout_sec": retry_timeout}
            # Local loop counter: the retry budget resets on every call instead of being
            # permanently consumed from the decorator's closure.
            for _ in range(max_retries):
                custom_logger.warning(
                    f"Connection timed out while calling {func.__name__}. " f"Retrying with {retry_timeout}s timeout"
                )
                try:
                    return await func(*args, **retry_kwargs)
                except httpx.TimeoutException:
                    continue
                except httpx.NetworkError:
                    custom_logger.error(f"Network error while calling {func.__name__}, giving up")
                    return None

            custom_logger.error(
                f"Connection timed out while calling {func.__name__}. " f"Number of retries exceeded, giving up"
            )
            return None

        return wrapper

    return actual_decorator


def get_username() -> str:
    """
    Return the current user name from the environment.

    ``USERNAME`` wins whenever it is set (even to an empty string). Otherwise
    ``PYTHONANYWHERE_USER`` is used, and if neither is set the literal ``"unknown"``.
    """
    fallback = os.getenv("PYTHONANYWHERE_USER", "unknown")
    return os.getenv("USERNAME", fallback)
