import datetime
import json
import os
import time
from multiprocessing.pool import ThreadPool
from typing import Dict, Optional, Type

import httpx
import jwt
import tornado
from jupyter_server.base.handlers import APIHandler
from tornado.routing import _RuleList
from tornado.web import HTTPError

from aext_shared.config import SHARED_CONFIG
from aext_shared.errors import UnauthorizedError
from aext_shared.handler import BackendHandler, create_rules

# Module-level alias of the shared configuration object imported from aext_shared.config.
config = SHARED_CONFIG

# anaconda_auth is treated as an optional dependency: when it cannot be imported,
# every name taken from it is bound to None so this module still imports and
# callers can test for availability (e.g. the `TokenInfo is None` guard in
# _load_api_key).
# TODO: This can be removed if/when we add anaconda-cloud-auth as a required dependency
try:
    from anaconda_auth import login, logout
    from anaconda_auth.exceptions import TokenNotFoundError
    from anaconda_auth.handlers import shutdown_all_servers
    from anaconda_auth.token import TokenInfo
except ImportError:  # pragma: no cover
    TokenInfo = None
    TokenNotFoundError = None
    login = None
    logout = None
    # Bound like the other fallbacks so the name always exists; leaving it
    # undefined made any reference raise NameError instead of being testable.
    shutdown_all_servers = None

# Module-level cache of the API key; filled by _load_api_key() after a successful
# TokenInfo.load() and reset to None by _clear_api_key().
_cached_api_key: Optional[str] = None


def _load_api_key() -> Optional[str]:
    """Return the stored API key, or None when no usable key is available.

    A previously loaded key is served from the module-level cache. Otherwise the
    token is read with ``TokenInfo.load()``; the key is cached and returned only
    if the token is not expired.

    Returns:
        The API key, or None when anaconda_auth is not importable, the token is
        expired, or loading raises ``KeyError`` / ``TokenNotFoundError``.
    """
    global _cached_api_key

    # anaconda_auth could not be imported, so there is nothing to load from.
    if TokenInfo is None:
        return None
    if _cached_api_key:
        return _cached_api_key

    try:
        stored = TokenInfo.load()
        if stored.expired:
            print("[Assistant] Token expired, returning None")
            return None
        _cached_api_key = stored.api_key
    except (KeyError, TokenNotFoundError):
        return None
    return _cached_api_key


def _clear_api_key():
    """Reset the cached API key so the next _load_api_key() call reads the token again."""
    global _cached_api_key
    _cached_api_key = None


def raise_if_error(response):
    """Raise an HTTPError carrying the response's status and reason unless ``response.ok``.

    Args:
        response: Object exposing ``ok``, ``status_code`` and ``reason`` attributes.

    Raises:
        HTTPError: When ``response.ok`` is falsy.
    """
    if response.ok:
        return
    raise HTTPError(response.status_code, reason=response.reason)


def get_expires_at(token):
    """Return the token's ``exp`` claim as milliseconds since the epoch.

    The JWT is decoded without verifying its signature; only the ``exp`` claim
    is read.

    Args:
        token: Encoded JWT, or a falsy value when no token is available.

    Returns:
        The expiry in milliseconds, or 0 when the token is missing or anything
        goes wrong while decoding it (the error is printed, not raised).
    """
    if token:
        try:
            claims = jwt.decode(token, algorithms=["RS256"], options={"verify_signature": False})
            expiry = datetime.datetime.fromtimestamp(claims["exp"])
            # Seconds -> milliseconds
            return int(expiry.timestamp() * 1000)
        except Exception as e:
            print(f"Error occurred: {e}")
            return 0

    print("[Assistant] No token found")
    return 0


class ApiKeyRouteHandler(APIHandler):
    """Return the locally stored API key to an authenticated caller."""

    @tornado.web.authenticated
    async def get(self):
        """Respond with ``access_token`` and ``expires_at``.

        ``expires_at`` comes from get_expires_at() (milliseconds, 0 when it
        cannot be determined).

        Raises:
            HTTPError: 403 when _load_api_key() yields no key.
        """
        access_token = _load_api_key()
        if not access_token:
            raise HTTPError(403, reason="missing api_key")

        payload = {
            "access_token": access_token,
            "expires_at": get_expires_at(access_token),
        }
        self.finish(payload)


class LoginRouteHandler(APIHandler):
    """Run the anaconda_auth login flow unless an API key is already available."""

    # Most recently created ThreadPool. It is a class attribute, so it is shared by
    # all requests: a later request can close() the pool an earlier request created.
    # It is never reset to None, so it may reference an already-terminated pool.
    pool = None

    @tornado.web.authenticated
    async def get(self):
        """
        Run the synchronous ``login()`` without blocking the event loop.

        ``login()`` is not async, so it is submitted to a single-worker thread pool
        and this coroutine polls the result every 0.25 seconds, yielding to the
        event loop between polls. If an API key can already be loaded, the request
        finishes immediately and no login is started.

        If an earlier request left a pool in ``LoginRouteHandler.pool``, it is
        closed before a new one is created, and the worker calls
        ``shutdown_all_servers()`` before ``login()``.

        NOTE(review): the original author states that ``login()`` starts its own
        tornado server; presumably ``shutdown_all_servers()`` is what unblocks a
        login still in progress from a previous request. Confirm against anaconda_auth.
        NOTE(review): ``Pool.close()`` only prevents new tasks from being submitted;
        it does not interrupt a task that is already running.
        """
        try:
            api_key = _load_api_key()
            if api_key:
                # Already logged in: nothing to do.
                self.finish()
                return
        except Exception as e:
            # Ignore errors here, we'll try to login again
            print(f"Error occurred: {e}")

        # A truthy value means some earlier request created a pool (it may already
        # have been terminated by that request's `with` block).
        if LoginRouteHandler.pool:
            LoginRouteHandler.pool.close()

        def login_and_load():
            """Worker body: stop existing auth servers, log in, then warm the API key cache."""
            shutdown_all_servers()
            login()
            _load_api_key()

        # Leaving the `with` block terminates the pool (Pool.__exit__ calls terminate()).
        with ThreadPool(processes=1) as _pool:
            LoginRouteHandler.pool = _pool
            r = LoginRouteHandler.pool.apply_async(login_and_load)
            # Poll instead of blocking so other requests keep being served.
            while not r.ready():
                await tornado.gen.sleep(0.25)
            # Re-raises any exception raised inside the worker.
            r.get()

        self.finish()


class LogoutRouteHandler(APIHandler):
    """Log the user out and reset the persisted assistant state."""

    @tornado.web.authenticated
    async def get(self):
        """Clear the cached API key, call ``logout()``, and overwrite the settings file with ``{}``."""
        _clear_api_key()
        logout()
        filepath = prepare_assistant_settings()
        # Plain "w" truncates the existing file. The previous third positional
        # argument (os.O_EXCL) was interpreted by open() as `buffering`, not as
        # an OS-level flag, so it has been dropped.
        with open(filepath, "w") as f:
            f.write("{}")
        self.finish()


class NucleusUserRouteHandler(BackendHandler):
    """Relay the result of the ``account`` request made through ``anaconda_proxy``."""

    @tornado.web.authenticated
    async def get(self):
        """Request ``account`` via ``self.anaconda_proxy`` and finish with its response.

        The two Cloudflare access headers are sent only when both
        ``NUCLEUS_CLOUDFLARE_CLIENT_ID`` and ``NUCLEUS_CLOUDFLARE_CLIENT_SECRET``
        are set in the environment. An ``UnauthorizedError`` from the proxy call
        is reported as a 403 with an empty body.
        """
        client_id = os.getenv("NUCLEUS_CLOUDFLARE_CLIENT_ID")
        client_secret = os.getenv("NUCLEUS_CLOUDFLARE_CLIENT_SECRET")

        # A lone id or secret is ignored; both must be present.
        if client_id and client_secret:
            proxy_headers = {
                "CF-Access-Client-Id": client_id,
                "CF-Access-Client-Secret": client_secret,
            }
        else:
            proxy_headers = {}

        try:
            account = await self.anaconda_proxy("account", headers=proxy_headers)
        except UnauthorizedError:
            self.set_status(httpx.codes.FORBIDDEN, "Get user account info - Not authenticated")
            return self.finish()

        self.finish(account)


def prepare_assistant_settings():
    """
    Creates the ~/.anaconda/assistant.json file if it doesn't exist.
    TODO: Use Jupyterlab's settings system instead of this file

    A newly created file contains an empty JSON object (``{}``); an existing
    file is left untouched.

    Returns:
        The path to the settings file.
    """
    directory = os.path.join(os.path.expanduser("~"), ".anaconda")
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(directory, exist_ok=True)
    filepath = os.path.join(directory, "assistant.json")
    try:
        # "x" is open()'s exclusive-creation mode. The previous code passed
        # os.O_EXCL as the third positional argument, which open() treats as
        # `buffering`, so it never provided exclusivity.
        with open(filepath, "x", encoding="utf-8") as f:
            f.write("{}")
    except FileExistsError:
        # Already present: keep the existing contents.
        pass
    return filepath


class GetDiskStateRouteHandler(BackendHandler):
    """Return the persisted assistant state to the caller."""

    # Every other handler in this module requires an authenticated request; this
    # one returns the user's saved state, so it is guarded the same way.
    @tornado.web.authenticated
    async def get(self, matched_part=None, *args, **kwargs):
        """
        Reads the application state from ~/.anaconda/assistant.json
        We never parse the file, we just return it as a string.

        Responds with ``{}`` if the file cannot be found.
        """
        try:
            with open(prepare_assistant_settings()) as f:
                contents = f.read()
        except FileNotFoundError:
            contents = "{}"
        self.finish(contents)


class SyncDiskStateRouteHandler(APIHandler):
    """Persist the assistant state sent by the caller."""

    @tornado.web.authenticated
    async def post(self, matched_part=None, *args, **kwargs):
        """
        Writes the application state to ~/.anaconda/assistant.json
        We never parse the file, we just write the string as-is.
        The file is opened in "w" mode, so its previous contents are replaced.

        Raises:
            HTTPError: 500 if the body cannot be decoded or the file cannot be written.
        """
        try:
            # Decode before opening: opening with "w" truncates the file, so a
            # body that is not valid UTF-8 must fail before the old state is lost.
            state = self.request.body.decode("utf-8")
            filepath = prepare_assistant_settings()
            # No os.O_EXCL here: open()'s third positional argument is
            # `buffering`, not an OS-level flag.
            with open(filepath, "w") as f:
                f.write(state)
            self.finish()
        except Exception as e:
            print(f"Error occurred: {e}")
            raise HTTPError(500, reason=f"Server Error: {e}")


# In order to summarize files for the user, we keep track of when they have been updated.
# Maps a notebook file name (as returned by os.listdir) to the most recent
# modification time (os.path.getmtime) seen by MonitorFileChangesRouteHandler.
# Module-level, so it is shared by all requests.
files_last_modified: Dict[str, float] = {}


# Time to wait for a file to change
def get_wait_time():
    """Return the file-monitor wait time in seconds.

    Read from the ``ASSISTANT_MONITOR_FILES_WAIT_TIME`` environment variable.
    Falls back to 30 when the variable is unset or is not a valid integer, so a
    bad value cannot make every monitor request fail.
    """
    default_wait = 30
    raw_value = os.environ.get("ASSISTANT_MONITOR_FILES_WAIT_TIME")
    if raw_value is None:
        return default_wait
    try:
        return int(raw_value)
    except ValueError:
        print(f"[Assistant] Invalid ASSISTANT_MONITOR_FILES_WAIT_TIME {raw_value!r}, using {default_wait}")
        return default_wait


class MonitorFileChangesRouteHandler(APIHandler):
    """Long-poll endpoint that reports the first notebook seen to change."""

    @tornado.web.authenticated
    async def get(self, matched_part=None, *args, **kwargs):
        """
        Lists all the `.ipynb` files in the current directory,
        then waits until one of the files changes,
        then returns an object with the following properties:
        - path: the name of the file
        - last_modified: the last time the file was modified

        The file list is taken once, when the request starts. A file seen for
        the first time only has its modification time recorded; it is reported
        on a later check if that time increases. If nothing changes within
        get_wait_time() seconds (30 by default), or there are no notebooks,
        the response is ``{"path": null, "last_modified": 0}``.

        Raises:
            HTTPError: 500 on any unexpected error.
        """
        no_change = json.dumps({"path": None, "last_modified": 0})
        try:
            # Get the current directory, then list all the .ipynb files
            directory = os.path.abspath(os.getcwd())
            files = [f for f in os.listdir(directory) if f.endswith(".ipynb")]

            # If there are no files, return an empty object
            if not files:
                self.finish(no_change)
                return

            # A monotonic deadline is unaffected by wall-clock adjustments, and
            # the wait time is read once instead of on every iteration.
            deadline = time.monotonic() + get_wait_time()
            while time.monotonic() < deadline:
                for f in files:
                    try:
                        last_modified = os.path.getmtime(os.path.join(directory, f))
                    except OSError:
                        # The file was deleted or renamed since it was listed;
                        # skip it rather than failing the whole request.
                        continue
                    previous = files_last_modified.get(f)
                    # Always remember the latest modification time.
                    files_last_modified[f] = last_modified
                    # Only a file that was seen before can count as "changed".
                    if previous is not None and last_modified > previous:
                        self.finish(json.dumps({"path": f, "last_modified": last_modified}))
                        return
                await tornado.gen.sleep(1)
            self.finish(no_change)
        except Exception as e:
            print(f"Error occurred: {e}")
            raise HTTPError(500, reason=f"Server Error: {e}")


class SummarizeFileRouteHandler(BackendHandler):
    """Send a local file's contents to the Assistant API and relay the summary."""

    @tornado.web.authenticated
    async def post(self, matched_part=None, *args, **kwargs):
        """
        Receives a file path, then sends it to the Assistant API
        to get the summary of the file by making a request to
        /v1/pro/summaries

        The JSON request body supplies ``file_path`` and, optionally,
        ``x_client_source``, ``x_client_version``, ``x_sdk_version``,
        ``session_id``, ``response_message_id`` and ``skip_logging``; missing
        optional keys fall back to the defaults below.

        Raises:
            HTTPError: 500 on any failure (invalid body, unreadable file, proxy error).
        """
        try:
            # Load the file from the request body
            json_body = json.loads(self.request.body)
            file_path = json_body.get("file_path")

            # Notebook files are UTF-8 JSON, so decode them explicitly rather
            # than with the platform's locale encoding.
            # Trim to 150,000 characters (roughly 50,000 tokens; GPT-4o mini has
            # a limit of 128k tokens).
            with open(file_path, encoding="utf-8") as f:
                file_contents = f.read(150000)

            headers = {
                "Content-Type": "application/json",
                "X-Client-Source": json_body.get("x_client_source", "anaconda-local-dev"),
                "X-Client-Version": json_body.get("x_client_version", "0.4.0"),
                "X-SDK-Version": json_body.get("x_sdk_version", "0.0.1"),
            }

            assistant_api_request_body = {
                "source": {
                    "name": file_path,
                    "data": file_contents,
                    "type": "ipynb",
                },
                "session_id": json_body.get("session_id", "session_123"),
                "response_message_id": json_body.get("response_message_id", "response_123"),
                "skip_logging": json_body.get("skip_logging", False),
            }

            # Update headers with CloudFlare credentials if available
            cloud_flare_client = os.getenv("NUCLEUS_CLOUDFLARE_CLIENT_ID")
            cloud_flare_secret = os.getenv("NUCLEUS_CLOUDFLARE_CLIENT_SECRET")
            if cloud_flare_client and cloud_flare_secret:
                headers["CF-Access-Client-Id"] = cloud_flare_client
                headers["CF-Access-Client-Secret"] = cloud_flare_secret

            response = await self.anaconda_proxy(
                "assistant/v1/pro/summaries",
                method="POST",
                headers=headers,
                json=assistant_api_request_body,
            )
            self.finish(response["remote_data"])
        except Exception as e:
            print(f"Error occurred: {e}")
            raise HTTPError(500, reason=f"Server Error: {e}")


def get_routes(base_url: str) -> _RuleList:
    """Build the routing rules for every handler in this module.

    Each route name is paired with its handler class and passed to
    ``create_rules`` together with ``base_url`` and the
    ``aext_assistant_server`` name.
    """
    route_table: Dict[str, Type[BackendHandler]] = dict(
        nucleus_user=NucleusUserRouteHandler,
        api_key=ApiKeyRouteHandler,
        login=LoginRouteHandler,
        logout=LogoutRouteHandler,
        get_disk_state=GetDiskStateRouteHandler,
        sync_disk_state=SyncDiskStateRouteHandler,
        monitor_file_changes=MonitorFileChangesRouteHandler,
        summarize_file=SummarizeFileRouteHandler,
    )
    return create_rules(base_url, "aext_assistant_server", route_table)
