#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import asyncio
from unittest.mock import AsyncMock, MagicMock

import aiohttp

from snowflake.connector.aio._session_manager import SessionManager
from snowflake.connector.auth.by_plugin import DEFAULT_AUTH_CLASS_TIMEOUT
from snowflake.connector.connection import DEFAULT_BACKOFF_POLICY


def mock_async_request_with_action(next_action, sleep=None):
    """Build a fake async request callable that simulates one outcome.

    Args:
        next_action: ``"RETRY"`` makes the fake return a mock response with
            ``status=503`` and a no-op ``close()``; ``"ERROR"`` makes it raise
            ``aiohttp.ClientConnectionError``. Any other value makes it
            return ``None``.
        sleep: Optional number of seconds to ``await asyncio.sleep`` before
            producing the outcome (useful for exercising timeouts).

    Returns:
        An ``async def`` function accepting and ignoring arbitrary
        positional and keyword arguments.
    """

    async def _fake_request(*_args, **_kwargs):
        if sleep is not None:
            await asyncio.sleep(sleep)

        if next_action == "ERROR":
            raise aiohttp.ClientConnectionError()

        if next_action == "RETRY":
            # A retryable HTTP status; close() is a plain no-op callable.
            return MagicMock(status=503, close=lambda: None)

        return None

    return _fake_request


def get_mock_session_manager(allow_send: bool = False):
    """Create a mock async SessionManager that prevents actual network calls in tests.

    Args:
        allow_send: When True, sessions produced by the manager are returned
            unmodified. When False (the default), each session's
            ``connector._connect`` is replaced with a coroutine that raises
            ``NotImplementedError``.

    Returns:
        A new ``MockSessionManager`` instance (a ``SessionManager`` subclass
        that overrides ``make_session``).
    """

    async def _refuse_connect(*_args, **_kwargs):
        raise NotImplementedError("Unit test tried to make real network connection")

    class MockSessionManager(SessionManager):
        def make_session(self, *, url: str | None = None):
            new_session = super().make_session(url=url)
            if allow_send:
                return new_session
            # Block at the connector's low-level _connect (the sync tests block
            # session.send instead); patches applied to session.request still
            # take effect because only the connection step is replaced.
            new_session.connector._connect = _refuse_connect
            return new_session

    return MockSessionManager()


def mock_connection(
    login_timeout=DEFAULT_AUTH_CLASS_TIMEOUT,
    network_timeout=None,
    socket_timeout=None,
    backoff_policy=DEFAULT_BACKOFF_POLICY,
    disable_saml_url_check=False,
    session_manager=None,
    cert_revocation_check_mode="DISABLED",
    platform_detection_timeout_seconds=0.0,
):
    """Build an ``AsyncMock`` standing in for an async Snowflake connection.

    Timeout and backoff settings are exposed under both their private
    (``_login_timeout``) and public (``login_timeout``) attribute names, so
    code reading either spelling sees the same value.

    Args:
        login_timeout: Stored as ``_login_timeout`` and ``login_timeout``.
        network_timeout: Stored as ``_network_timeout`` and ``network_timeout``.
        socket_timeout: Stored as ``_socket_timeout`` and ``socket_timeout``.
        backoff_policy: Callable stored as ``_backoff_policy`` and
            ``backoff_policy``; it is called once here and its result is
            stored as ``_backoff_generator``.
        disable_saml_url_check: Stored as ``_disable_saml_url_check``.
        session_manager: Stored as ``_session_manager``. When ``None``, a fresh
            manager from ``get_mock_session_manager()`` is used, which blocks
            real network connections.
        cert_revocation_check_mode: Stored as ``cert_revocation_check_mode``.
        platform_detection_timeout_seconds: Stored as
            ``platform_detection_timeout_seconds``.

    Returns:
        An ``AsyncMock`` configured with the attributes above, plus an
        ``_update_parameters`` ``AsyncMock`` that resolves to ``None``.
    """
    # Check for None explicitly: ``session_manager or ...`` would also discard
    # a caller-supplied manager that happens to evaluate as falsy.
    if session_manager is None:
        session_manager = get_mock_session_manager()

    return AsyncMock(
        _login_timeout=login_timeout,
        login_timeout=login_timeout,
        _network_timeout=network_timeout,
        network_timeout=network_timeout,
        _socket_timeout=socket_timeout,
        socket_timeout=socket_timeout,
        _backoff_policy=backoff_policy,
        backoff_policy=backoff_policy,
        _backoff_generator=backoff_policy(),
        _disable_saml_url_check=disable_saml_url_check,
        _session_manager=session_manager,
        _update_parameters=AsyncMock(return_value=None),
        cert_revocation_check_mode=cert_revocation_check_mode,
        platform_detection_timeout_seconds=platform_detection_timeout_seconds,
    )
