import asyncio
import json
import subprocess
from typing import Coroutine, Dict, List, Optional, Type

import tornado
from aext_environments_server.consts import CONDA_CMD_TIMEOUT, ProcessResult
from aext_environments_server.metrics import metrics
from aext_environments_server.services.environments import environment_service
from aext_environments_server.services.static_files import envs_static_file_service
from httpx import HTTPError, NetworkError, TimeoutException
from tornado.routing import _RuleList

from aext_shared import logger as custom_logger
from aext_shared.auth import require_api_key
from aext_shared.config import SHARED_CONFIG
from aext_shared.consts import RequestMethods, UserSubscriptions
from aext_shared.handler import BackendHandler, WebSocketBackendHandler, create_rules
from aext_shared.request_utils import anaconda_cloud_request
from aext_shared.utils import Timer

# Module-level logger used by the handlers below (taken from the shared logging helper).
logger = custom_logger.logger

# Shared configuration; read below through config["environment"] when sending metrics.
config = SHARED_CONFIG


class EnvironmentsHandler(BackendHandler):
    """HTTP handler that returns what ``environment_service.list_environments()`` reports."""

    @tornado.web.authenticated
    @require_api_key
    async def get(self):
        """
        Respond with the JSON-encoded list of environments.

        A falsy result from the service is answered with 502. Exceptions are mapped to a status:
        ``RuntimeError`` -> 502, ``TimeoutException`` -> 504, ``NetworkError`` -> 502 and any other
        exception -> 500. Error responses are finished with an empty body.
        """
        try:
            environments = await environment_service.list_environments()
            if not environments:
                self.set_status(502, "Failed to fetch environment specs")
                return self.finish()
            return self.finish(json.dumps(environments))
        except Exception as error:
            # Order matters: the first matching exception type decides the response.
            if isinstance(error, RuntimeError):
                status, reason = 502, "Failed to fetch environments"
            elif isinstance(error, TimeoutException):
                status, reason = 504, "Request timeout while fetching environments"
            elif isinstance(error, NetworkError):
                status, reason = 502, "Network error while fetching environments"
            else:
                status, reason = 500, "Unexpected error while fetching environments"
            self.set_status(status, reason)
            return self.finish()


class EnvironmentManagementWebSocket(WebSocketBackendHandler):
    running_processes: Dict[str, asyncio.subprocess.Process] = {}
    subscribers: Dict[str, List["EnvironmentManagementWebSocket"]] = {}

    @require_api_key
    async def open(self, action, env_name):
        """
        Called when a new WebSocket connection is opened.
        The URL provides the action ("install" or "remove")
        and the environment name.

        The connection is closed right away when the action is invalid, when the list of
        environments cannot be fetched or when `env_name` is not part of that list. Otherwise the
        connection is subscribed to the environment's stream; if no process is registered for the
        environment yet, the requested operation is started in a background task.
        Args:
            action: "install" or "remove" (case-insensitive)
            env_name: name of the environment the action applies to
        """
        action = action.lower()
        if action not in ("install", "remove"):
            self.write_message("Invalid action. Use /install/<env_name> or /remove/<env_name>.")
            self.close()
            return

        try:
            environments = await environment_service.list_environments()
        except Exception:
            # The HTTP handler of this module treats failures of this call as expected errors,
            # so tell the client what happened instead of letting the handshake blow up.
            logger.debug("Not possible to fetch the list of environments")
            self.write_message("Failed to fetch environments")
            self.close()
            return

        # The service may hand back a falsy value instead of a list; treat it as "no environments".
        available_envs = [env["name"] for env in environments or []]
        if env_name not in available_envs:
            self.write_message("Unknown environment")
            self.close()
            return

        # Every accepted connection becomes a subscriber, whether it starts the process or joins it.
        already_running = env_name in EnvironmentManagementWebSocket.running_processes
        EnvironmentManagementWebSocket.subscribers.setdefault(env_name, []).append(self)
        if already_running:
            self.write_message(f"Resuming stream for environment: {env_name}")
            return

        if action == "install":
            self.write_message(f"Starting installation of environment: {env_name}")
            operation = self.install_environment(env_name)
        else:
            self.write_message(f"Starting removal of environment: {env_name}")
            operation = self.remove_environment(env_name)

        # The event loop only keeps weak references to tasks; hold one on the handler so the
        # operation is referenced for at least as long as this handler object is alive.
        self._operation_task = asyncio.create_task(operation)

    def on_close(self):
        """Drop this connection from every subscribers list it is registered in."""
        for listeners in EnvironmentManagementWebSocket.subscribers.values():
            try:
                listeners.remove(self)
            except ValueError:
                # This connection was not subscribed to that environment.
                pass

    async def _broadcast_message(self, env_name: str, message: str):
        """
        Sends a message to all connections subscribed to a particular process.

        A subscriber whose connection is already closed is skipped, so the remaining subscribers
        still receive the message and the caller is not interrupted.
        Args:
            env_name: name of the environment
            message: content of the message that will be sent
        """
        # Local import: this module only imports the top-level `tornado` package.
        from tornado.websocket import WebSocketClosedError

        for subscriber in EnvironmentManagementWebSocket.subscribers.get(env_name, []):
            try:
                subscriber.write_message(message)
            except WebSocketClosedError:
                # NOTE(review): assumes subscribers are Tornado WebSocket handlers, whose
                # write_message raises this error once the connection is closed - confirm.
                logger.debug("Skipping a closed subscriber while broadcasting")

    async def _keep_alive(self, env_name: str, process: subprocess.Popen):
        """
        Broadcasts a keep-alive message every 45 seconds until the process terminates.

        The process is polled once per second between two keep-alive messages, so this coroutine
        returns about one second after the process exits instead of sleeping through the rest of
        the 45 second window. `_stream_stdout` gathers this coroutine together with the pipe
        readers, so a late return here delays the completion of the whole command.
        Args:
            env_name: name of the environment whose subscribers receive the keep-alive
            process: the OS process that is being watched
        """
        keep_alive_interval = 45  # seconds between two keep-alive messages
        poll_interval = 1  # seconds between two checks of the process state
        while True:
            for _ in range(keep_alive_interval // poll_interval):
                await asyncio.sleep(poll_interval)
                if process.poll() is not None:
                    return
            await self._broadcast_message(env_name, "[SRV] keep-alive")

    async def _stream_stdout(self, env_name: str, process: subprocess.Popen):
        """
        Forwards the output of a process to the subscribers of an environment, line by line.

        stdout and stderr are read concurrently and the keep-alive loop runs alongside them; the
        coroutine returns once all three have finished.
        Args:
            env_name: environment's name
            process: the OS process that is executing the task
        """
        event_loop = asyncio.get_running_loop()

        async def forward_lines(stream, is_stderr=False):
            """Broadcast every line of `stream`; stderr lines get a "[STDERR] " prefix."""
            # readline blocks, so it is executed in the default executor.
            line = await event_loop.run_in_executor(None, stream.readline)
            while line:
                stripped = line.strip()
                if is_stderr:
                    text = f"[STDERR] {stripped}"
                else:
                    text = stripped
                await self._broadcast_message(env_name, text)
                await asyncio.sleep(0.01)  # Ensure some time for tornado's process to stream the message
                line = await event_loop.run_in_executor(None, stream.readline)

        await asyncio.gather(
            forward_lines(process.stdout),
            forward_lines(process.stderr, is_stderr=True),
            self._keep_alive(env_name, process),
        )

    async def _disconnect_all(self, env_name: str):
        """
        Closes every connection listening to an environment and forgets the environment.

        Subscribers first receive a "Disconnecting" message; afterwards the environment is dropped
        from both the running-processes and the subscribers registries.
        Args:
            env_name: name of the environment that is being listened
        """
        registry = EnvironmentManagementWebSocket
        await self._broadcast_message(env_name, "Disconnecting")
        registry.running_processes.pop(env_name, None)
        listeners = registry.subscribers.get(env_name, [])
        for listener in listeners:
            listener.close()
        registry.subscribers.pop(env_name, None)

    async def _execute_command_with_timeout(
        self, env_name: str, environment_service_coro: Coroutine, success_message: Optional[str] = None
    ) -> bool:
        """
        Awaits a coroutine that creates a process, then waits for that process to exit while its
        output is broadcast to all WebSocket subscribers.

        The process is killed and reported with exit code -1 when it does not exit within
        CONDA_CMD_TIMEOUT.
        Args:
            env_name: name of the environment
            environment_service_coro: EnvironmentsService coroutine that creates the process
            success_message: optional template, formatted with `status`, sent when the exit code is 0
        Returns:
            True when the process exited with code 0; False otherwise, including when no process
            was created.
        """
        await self._broadcast_message(env_name, "[SRV] Creating conda process...")
        process = await environment_service_coro
        if not process:
            await self._broadcast_message(env_name, "[SRV] Process creation failed.")
            return False

        await self._broadcast_message(env_name, "[SRV] Process created")
        EnvironmentManagementWebSocket.running_processes[env_name] = process
        event_loop = asyncio.get_running_loop()
        output_streaming = asyncio.create_task(self._stream_stdout(env_name, process))
        try:
            await self._broadcast_message(env_name, "[SRV] Start streaming")
            # process.wait blocks, so it runs in the default executor.
            exit_waiter = event_loop.run_in_executor(None, process.wait)
            exit_code = await asyncio.wait_for(exit_waiter, timeout=CONDA_CMD_TIMEOUT)
        except asyncio.TimeoutError:
            process.kill()
            exit_code = -1
            await self._broadcast_message(env_name, "[SRV] Operation timed out")
        finally:
            # Let the remaining output reach the subscribers before reporting the exit code.
            await output_streaming

        await self._broadcast_message(env_name, f"[SRV] Process exited with code {exit_code}")
        if exit_code != 0:
            return False
        if success_message:
            await self._broadcast_message(env_name, success_message.format(status=ProcessResult.SUCCEEDED.value))
        return True

    async def _remove_environment(self, env_name: str, success_message: Optional[str] = None) -> bool:
        """
        Runs the process created by `environment_service.remove_environment`
        Args:
            env_name: name of the environment
            success_message: optional message that is sent if the process succeeds
        Returns:
            the result of `_execute_command_with_timeout` for that process
        """
        return await self._execute_command_with_timeout(
            env_name, environment_service.remove_environment(env_name), success_message
        )

    async def _install_environment(self, env_name: str, success_message: Optional[str] = None) -> bool:
        """
        Execute a command to install a new environment
        Args:
            env_name: name of the environment
            success_message: optional message that is sent if the installation succeeds
        Returns:
            the result of `_execute_command_with_timeout` for the installation process
        """
        install_environment_coro = environment_service.install_environment(env_name)
        try:
            return await self._execute_command_with_timeout(env_name, install_environment_coro, success_message)
        finally:
            # Run the clean-up even when the command raised, so a failed attempt is cleaned up too.
            await environment_service.clean_up_installation(env_name)

    async def _install_kernel(self, env_name: str, success_message: Optional[str] = None) -> bool:
        """
        Runs the process created by `environment_service.install_kernel`
        Args:
            env_name: name of the environment that will have a kernel installed
            success_message: optional message that will be sent if the process succeeds
        Returns:
            the result of `_execute_command_with_timeout` for that process
        """
        return await self._execute_command_with_timeout(
            env_name, environment_service.install_kernel(env_name), success_message
        )

    async def _remove_kernel(self, env_name: str, success_message: Optional[str] = None) -> bool:
        """
        Runs the process created by `environment_service.remove_kernel`
        Args:
            env_name: name of the environment whose kernel is handled
            success_message: optional message that will be sent if the process succeeds
        Returns:
            the result of `_execute_command_with_timeout` for that process
        """
        remove_kernel_coro = environment_service.remove_kernel(env_name)
        return await self._execute_command_with_timeout(env_name, remove_kernel_coro, success_message)

    async def install_environment(self, env_name: str):
        """
        Installs a conda environment and its kernel, streaming the progress to the subscribers.

        The installation only starts when the environment is present in the environments index and
        the user's subscription is one of the subscriptions the environment is installable for.
        Whatever the outcome (rejection, failure, success or an unexpected error), every subscriber
        of `env_name` is disconnected at the end.
        Args:
            env_name: Name of the environment that will be installed
        """
        try:
            rejection_message = await self._get_installation_rejection(env_name)
            if rejection_message is not None:
                await self._broadcast_message(env_name, rejection_message)
                return
            await self._install_environment_and_kernel(env_name)
        finally:
            # Always release the subscribers and the running-process entry; otherwise an
            # unexpected error would leave the environment registered as running.
            await self._disconnect_all(env_name)

    async def _get_installation_rejection(self, env_name: str) -> Optional[str]:
        """Returns the message explaining why `env_name` cannot be installed, or None if it can."""
        try:
            environments_index: Dict = await envs_static_file_service.get_index_file()
            required_subscriptions = environments_index[env_name]["metadata"]["installable_for"]
        except KeyError:
            return f"[SRV] Not able to find environment {env_name}"
        except Exception:
            logger.debug("Not possible to fetch environments index")
            return f"[SRV] Not possible to fetch environments index {env_name}"

        try:
            account_info_response = await anaconda_cloud_request(
                "account/notebooks", RequestMethods.GET, user_credentials=await self.get_user_access_credentials()
            )
        except Exception:
            logger.debug("Could not fetch user subscription")
            return "[SRV] Could not fetch user subscription"

        subscription_required = f"[SRV] Business subscription required for {env_name}"
        if account_info_response.status_code != 200:
            # only free environments can be installed - user is either not logged in or has a free subscription
            if UserSubscriptions.FREE_SUBSCRIPTION not in required_subscriptions:
                return subscription_required
            return None

        try:
            user_notebook_subscription = account_info_response.json()["notebooks_service_subscription"]
        except (KeyError, TypeError, ValueError):
            # Missing key, non-mapping payload or a body that is not valid JSON.
            logger.debug("Could not fetch user subscription")
            return "[SRV] Could not fetch user subscription"

        if user_notebook_subscription not in required_subscriptions:
            return subscription_required
        return None

    async def _install_environment_and_kernel(self, env_name: str) -> None:
        """Installs the environment and then its kernel, broadcasting the outcome of each step."""
        environment_installation_message = f"Installation of '{env_name}' {{status}}"
        kernel_installation_message = f"Installation of kernel for '{env_name}' {{status}}"
        with Timer() as timer:
            try:
                env_succeeded = await self._install_environment(env_name, environment_installation_message)
            except Exception as ex:
                # Treat an unexpected error like a failed installation; the failure message is
                # broadcast once, below.
                env_succeeded = False
                await self._broadcast_message(env_name, f"[SRV] Not able to install a new environment: {ex}")

        if not env_succeeded:
            await self._broadcast_message(
                env_name, environment_installation_message.format(status=ProcessResult.FAILED.value)
            )
            return

        kernel_succeeded = await self._install_kernel(env_name, kernel_installation_message)

        await metrics.send(
            user_credentials=await self.get_user_access_credentials(),
            metric_data={
                "event": metrics.EVENT_INSTALL_ENVIRONMENT,
                "event_params": {
                    "environment_name": env_name,
                    "elapsed_time_sec": timer.elapsed,
                    "user_environment": config["environment"].value,
                },
                "service_id": "aext-cloud",
            },
        )
        if kernel_succeeded:
            await environment_service.conda_clean()
        else:
            # Roll the environment back: it is removed again when its kernel could not be installed.
            await self._remove_environment(env_name)
            await self._broadcast_message(
                env_name, kernel_installation_message.format(status=ProcessResult.FAILED.value)
            )

    async def remove_environment(self, env_name: str):
        """
        Removes a conda environment with a timeout.

        When the removal succeeds the kernel of the environment is removed as well; otherwise a
        failure message is broadcast. A metric is sent in both cases and every subscriber of
        `env_name` is disconnected at the end, even if an unexpected error occurs.
        Args:
            env_name: name of the environment that should be removed
        """
        message = f"Removal of '{env_name}' {{status}}"
        try:
            with Timer() as timer:
                try:
                    env_succeeded = await self._remove_environment(env_name, message)
                except Exception as ex:
                    # Report an unexpected error as a failed removal instead of dying silently.
                    env_succeeded = False
                    await self._broadcast_message(env_name, f"[SRV] Not able to remove the environment: {ex}")
            if env_succeeded:
                await self._remove_kernel(env_name)
            else:
                await self._broadcast_message(env_name, message.format(status=ProcessResult.FAILED.value))

            await metrics.send(
                user_credentials=await self.get_user_access_credentials(),
                metric_data={
                    "event": metrics.EVENT_UNINSTALL_ENVIRONMENT,
                    "event_params": {
                        "environment_name": env_name,
                        "elapsed_time_sec": timer.elapsed,
                        "user_environment": config["environment"].value,
                    },
                    "service_id": "aext-cloud",
                },
            )
        finally:
            # Always release the subscribers and the running-process entry; otherwise an
            # unexpected error would leave the environment registered as running.
            await self._disconnect_all(env_name)


class EnvironmentHealthz(BackendHandler):
    """Health-check endpoint: answers an authenticated GET request with the body "ok"."""

    @tornado.web.authenticated
    async def get(self):
        """Finish the request with the plain body ``ok``."""
        self.finish("ok")


def get_routes(base_url: str) -> _RuleList:
    """
    Build the routing rules of this module.

    Maps the route patterns "healthz", "environments" and "(install|remove)/(.*)" to their
    handlers and passes them to `create_rules` together with `base_url` and the
    "aext_environments_server" name.
    Args:
        base_url: base URL forwarded to `create_rules`
    Returns:
        whatever `create_rules` builds for the route table
    """
    route_table: Dict[str, Type[BackendHandler]] = {}
    route_table["healthz"] = EnvironmentHealthz
    route_table["environments"] = EnvironmentsHandler
    # The two regex groups become the `action` and `env_name` arguments of the WebSocket's open().
    route_table[r"(install|remove)/(.*)"] = EnvironmentManagementWebSocket
    return create_rules(base_url, "aext_environments_server", route_table)
