#!/usr/bin/env python
from __future__ import annotations

import inspect
import sys
import time
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
from unittest.mock import Mock, PropertyMock

import pytest

import snowflake.connector.errors
from snowflake.connector.compat import IS_WINDOWS
from snowflake.connector.constants import OCSPMode
from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION
from snowflake.connector.network import SnowflakeRestful

from .mock_utils import mock_connection

try:  # pragma: no cover
    from snowflake.connector.auth import (
        Auth,
        AuthByDefault,
        AuthByOAuth,
        AuthByOauthCode,
        AuthByOauthCredentials,
        AuthByPlugin,
    )
except ImportError:
    from snowflake.connector.auth import Auth
    from snowflake.connector.auth_by_plugin import AuthByPlugin
    from snowflake.connector.auth_default import AuthByDefault
    from snowflake.connector.auth_oauth import AuthByOAuth
    from snowflake.connector.auth_oauth_code import AuthByOauthCode
    from snowflake.connector.auth_oauth_credentials import AuthByOauthCredentials

from snowflake.connector.errors import DatabaseError
from snowflake.connector.network import ReauthenticationRequest


def _init_rest(application, post_requset):
    """Build a ``SnowflakeRestful`` bound to a mocked connection.

    The connection's ``errorhandler`` is a ``Mock`` returning ``None`` so tests
    can later check ``errorhandler.called``. ``_post_request`` is replaced by
    *post_requset*, so every request goes to the supplied scripted responder.

    Args:
        application: value returned by the connection's ``application``
            property.
        post_requset: callable installed as ``rest._post_request``.

    Returns:
        The configured ``SnowflakeRestful`` instance.
    """
    conn = mock_connection()
    conn.errorhandler = Mock(return_value=None)
    conn._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN)
    conn.cert_revocation_check_mode = "TEST_CRL_MODE"

    # PropertyMock has to be attached to the type, not to the mock instance
    # (see the unittest.mock PropertyMock documentation).
    conn_cls = type(conn)
    conn_cls.application = PropertyMock(return_value=application)
    conn_cls._internal_application_name = PropertyMock(return_value=CLIENT_NAME)
    conn_cls._internal_application_version = PropertyMock(return_value=CLIENT_VERSION)

    restful = SnowflakeRestful(
        host="testaccount.snowflakecomputing.com",
        port=443,
        connection=conn,
    )
    restful._post_request = post_requset
    return restful


def _create_mock_auth_mfs_rest_response(next_action: str):
    """Create a scripted ``_post_request`` replacement for a successful MFA login.

    The returned responder is driven by the module-level ``mock_cnt`` counter:

    * call 0 answers with ``nextAction=next_action`` and an ``inFlightCtx``;
    * call 1 answers with the session ``token`` and ``masterToken``.

    Any further call raises ``AssertionError``. Without that branch, ``ret``
    would be unbound and the failure would surface as an obscure
    ``UnboundLocalError``.

    Args:
        next_action: the MFA ``nextAction`` value returned by the first call.

    Returns:
        A callable with the ``_post_request(url, headers, body, **kwargs)``
        signature.
    """

    def _mock_auth_mfa_rest_response(url, headers, body, **kwargs):
        """Return the scripted response for the current ``mock_cnt``."""
        global mock_cnt
        if mock_cnt == 0:
            ret = {
                "success": True,
                "message": None,
                "data": {
                    "nextAction": next_action,
                    "inFlightCtx": "inFlightCtx",
                },
            }
        elif mock_cnt == 1:
            ret = {
                "success": True,
                "message": None,
                "data": {
                    "token": "TOKEN",
                    "masterToken": "MASTER_TOKEN",
                },
            }
        else:
            raise AssertionError(
                f"unexpected request #{mock_cnt} to the MFA success mock"
            )

        mock_cnt += 1
        return ret

    return _mock_auth_mfa_rest_response


def _mock_auth_mfa_rest_response_failure(url, headers, body, **kwargs):
    """Scripted ``_post_request`` replacement for a failed MFA login.

    Driven by the module-level ``mock_cnt`` counter:

    * call 0 asks for ``EXT_AUTHN_DUO_ALL`` MFA;
    * call 1 answers with ``nextAction`` ``"BAD"``;
    * call 2 answers with ``data=None``.

    Any further call raises ``AssertionError`` instead of failing with an
    ``UnboundLocalError`` on ``ret``.
    """
    global mock_cnt

    if mock_cnt == 0:
        ret = {
            "success": True,
            "message": None,
            "data": {
                "nextAction": "EXT_AUTHN_DUO_ALL",
                "inFlightCtx": "inFlightCtx",
            },
        }
    elif mock_cnt == 1:
        ret = {
            "success": True,
            "message": None,
            "data": {
                "nextAction": "BAD",
                "inFlightCtx": "inFlightCtx",
            },
        }
    elif mock_cnt == 2:
        ret = {
            "success": True,
            "message": None,
            "data": None,
        }
    else:
        raise AssertionError(f"unexpected request #{mock_cnt} to the MFA failure mock")
    mock_cnt += 1
    return ret


def _mock_auth_mfa_rest_response_timeout(url, headers, body, **kwargs):
    """Scripted ``_post_request`` replacement for an MFA login that times out.

    Driven by the module-level ``mock_cnt`` counter:

    * call 0 asks for ``EXT_AUTHN_DUO_ALL`` MFA;
    * call 1 sleeps for 10 seconds and then returns an empty dict;
    * call 2 answers with ``data=None``.

    Any further call raises ``AssertionError`` instead of failing with an
    ``UnboundLocalError`` on ``ret``.
    """
    global mock_cnt
    if mock_cnt == 0:
        ret = {
            "success": True,
            "message": None,
            "data": {
                "nextAction": "EXT_AUTHN_DUO_ALL",
                "inFlightCtx": "inFlightCtx",
            },
        }
    elif mock_cnt == 1:
        # Long enough that a caller using a short authenticate() timeout
        # gives up before this response is produced.
        time.sleep(10)
        ret = {}
    elif mock_cnt == 2:
        ret = {
            "success": True,
            "message": None,
            "data": None,
        }
    else:
        raise AssertionError(f"unexpected request #{mock_cnt} to the MFA timeout mock")

    mock_cnt += 1
    return ret


@pytest.mark.skipif(
    IS_WINDOWS,
    reason="There are consistent race condition issues with the global mock_cnt used for this test on windows",
)
@pytest.mark.parametrize(
    "next_action", ("EXT_AUTHN_DUO_ALL", "EXT_AUTHN_DUO_PUSH_N_PASSCODE")
)
def test_auth_mfa(next_action: str):
    """Authentication by MFA.

    Runs four scenarios against scripted ``_post_request`` mocks. The mocks
    share the module-level ``mock_cnt`` counter, which is reset before each
    scenario:

    1. a successful MFA round trip that yields session tokens;
    2. a second MFA response with ``nextAction`` ``"BAD"`` (error handler called);
    3. an MFA request whose mock sleeps past ``timeout=1`` (error handler called);
    4. a response whose ``data`` is ``None`` (``authenticate`` raises).
    """
    global mock_cnt
    application = "testapplication"
    account = "testaccount"
    user = "testuser"
    password = "testpassword"

    # success test case
    mock_cnt = 0
    rest = _init_rest(application, _create_mock_auth_mfs_rest_response(next_action))
    auth = Auth(rest)
    auth_instance = AuthByDefault(password)
    auth.authenticate(auth_instance, account, user)
    assert not rest._connection.errorhandler.called  # not error
    assert rest.token == "TOKEN"
    assert rest.master_token == "MASTER_TOKEN"

    # failure test case
    mock_cnt = 0
    rest = _init_rest(application, _mock_auth_mfa_rest_response_failure)
    auth = Auth(rest)
    auth_instance = AuthByDefault(password)
    auth.authenticate(auth_instance, account, user)
    assert rest._connection.errorhandler.called  # error

    # timeout 1 second
    # NOTE(review): the mock's 10 s sleep presumably keeps running in a
    # background request after authenticate() gives up. That is the source of
    # the mock_cnt races mentioned in the skipif reason.
    mock_cnt = 0
    rest = _init_rest(application, _mock_auth_mfa_rest_response_timeout)
    auth = Auth(rest)
    auth_instance = AuthByDefault(password)
    auth.authenticate(auth_instance, account, user, timeout=1)
    assert rest._connection.errorhandler.called  # error

    # ret["data"] is None. Only authenticate() belongs inside pytest.raises,
    # so an Error raised during setup cannot satisfy the expectation.
    mock_cnt = 2
    rest = _init_rest(application, _mock_auth_mfa_rest_response_timeout)
    auth = Auth(rest)
    auth_instance = AuthByDefault(password)
    with pytest.raises(snowflake.connector.errors.Error):
        auth.authenticate(auth_instance, account, user)


def _mock_auth_password_change_rest_response(url, headers, body, **kwargs):
    """Scripted ``_post_request`` replacement for a login requiring a password change.

    Driven by the module-level ``mock_cnt`` counter:

    * call 0 answers with ``nextAction`` ``"PWD_CHANGE"`` and an ``inFlightCtx``;
    * call 1 answers with the session ``token`` and ``masterToken``.

    Any further call raises ``AssertionError`` instead of failing with an
    ``UnboundLocalError`` on ``ret``.
    """
    global mock_cnt
    if mock_cnt == 0:
        ret = {
            "success": True,
            "message": None,
            "data": {
                "nextAction": "PWD_CHANGE",
                "inFlightCtx": "inFlightCtx",
            },
        }
    elif mock_cnt == 1:
        ret = {
            "success": True,
            "message": None,
            "data": {
                "token": "TOKEN",
                "masterToken": "MASTER_TOKEN",
            },
        }
    else:
        raise AssertionError(
            f"unexpected request #{mock_cnt} to the password-change mock"
        )

    mock_cnt += 1
    return ret


def test_auth_password_change():
    """Tests password change.

    The server mock first answers ``PWD_CHANGE`` and then returns tokens. The
    callback passed as ``password_callback`` supplies ``"NEW_PASSWORD"``.
    Authentication must finish without calling the connection's error handler.
    """
    global mock_cnt

    mock_cnt = 0
    rest = _init_rest("testapplication", _mock_auth_password_change_rest_response)
    Auth(rest).authenticate(
        AuthByDefault("testpassword"),
        "testaccount",
        "testuser",
        password_callback=lambda: "NEW_PASSWORD",
    )

    # No error may have been reported.
    assert not rest._connection.errorhandler.called


def test_authbyplugin_abc_api():
    """This test verifies that the abstract function signatures have not changed.

    Each expected signature is written once, as the ``str()`` of its
    ``inspect.Parameter`` objects. ``_expected_parameters_str`` renders that
    list exactly as ``str(inspect.signature(f).parameters)`` prints it on the
    running Python version.
    """
    bc = AuthByPlugin

    # starting from python 3.12 the repr of collections.OrderedDict is changed
    # to use regular dictionary formatting instead of pairs of keys and values.
    # see https://github.com/python/cpython/issues/101446
    legacy_ordered_dict_repr = sys.version_info < (3, 12)

    def _expected_parameters_str(params: list[str]) -> str:
        """Render *params* (``str()`` of each Parameter) like ``Signature.parameters``."""
        entries = []
        for text in params:
            # "**kwargs: 'Any'" -> "kwargs", "body: 'dict[Any, Any]'" -> "body"
            name = text.lstrip("*").partition(":")[0].partition("=")[0]
            # Parameter.__repr__ is '<Parameter "{str(param)}">'
            param_repr = f'<Parameter "{text}">'
            if legacy_ordered_dict_repr:
                entries.append(f"('{name}', {param_repr})")
            else:
                entries.append(f"'{name}': {param_repr}")
        if legacy_ordered_dict_repr:
            return "OrderedDict([" + ", ".join(entries) + "])"
        return "OrderedDict({" + ", ".join(entries) + "})"

    def _assert_signature(func, params: list[str]) -> None:
        """Assert *func* is a plain function whose parameters match *params*."""
        assert inspect.isfunction(func)
        assert str(inspect.signature(func).parameters) == _expected_parameters_str(
            params
        )

    # Verify properties
    assert inspect.isdatadescriptor(bc.timeout)
    assert inspect.isdatadescriptor(bc.type_)
    assert inspect.isdatadescriptor(bc.assertion_content)

    # Verify method signatures
    # update_body
    _assert_signature(bc.update_body, ["self", "body: 'dict[Any, Any]'"])

    # prepare
    _assert_signature(
        bc.prepare,
        [
            "self",
            "conn: 'SnowflakeConnection'",
            "authenticator: 'str'",
            "service_name: 'str | None'",
            "account: 'str'",
            "user: 'str'",
            "password: 'str | None'",
            "**kwargs: 'Any'",
        ],
    )

    # _handle_failure
    _assert_signature(
        bc._handle_failure,
        [
            "self",
            "conn: 'SnowflakeConnection'",
            "ret: 'dict[Any, Any]'",
            "**kwargs: 'Any'",
        ],
    )

    # handle_timeout
    _assert_signature(
        bc.handle_timeout,
        [
            "self",
            "authenticator: 'str'",
            "service_name: 'str | None'",
            "account: 'str'",
            "user: 'str'",
            "password: 'str'",
            "**kwargs: 'Any'",
        ],
    )


def test_auth_by_default_prepare_body_does_not_overwrite_client_environment_fields():
    """Check ``AuthByDefault`` leaves CLIENT_ENVIRONMENT values unchanged.

    The body produced by ``apply_auth_class_update_body`` must carry the same
    value as the original mock body for every CLIENT_ENVIRONMENT key.
    """
    auth_class = AuthByDefault("testpassword")

    req_body_before = create_mock_auth_body()
    req_body_after = apply_auth_class_update_body(auth_class, req_body_before)

    env_before = req_body_before["data"]["CLIENT_ENVIRONMENT"]
    env_after = req_body_after["data"]["CLIENT_ENVIRONMENT"]
    changed_keys = [key for key in env_before if not env_before[key] == env_after[key]]
    assert not changed_keys


def _mock_oauth_token_expired_rest_response(url, headers, body, **kwargs):
    """Stand-in for ``_post_request`` that always reports an expired OAuth token.

    All arguments are ignored. Every call returns an unsuccessful response
    whose ``code`` is ``OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE``.
    """
    from snowflake.connector.network import (
        OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE as expired_code,
    )

    response = dict(
        success=False,
        message="OAuth access token expired",
        code=expired_code,
        data={},
    )
    return response


@pytest.mark.skipolddriver
@pytest.mark.parametrize(
    "auth_instance, expected_exc_type",
    [
        (AuthByOAuth("test_oauth_token"), DatabaseError),
        (
            AuthByOauthCode(
                application="testapp",
                client_id="test_client_id",
                client_secret="test_client_secret",
                authentication_url="https://auth.example.com",
                token_request_url="https://token.example.com",
                redirect_uri="http://localhost:8080",
                scope="session:role-any",
                host="testaccount.snowflakecomputing.com",
            ),
            ReauthenticationRequest,
        ),
        (
            AuthByOauthCredentials(
                application="testapp",
                client_id="test_client_id",
                client_secret="test_client_secret",
                token_request_url="https://token.example.com",
                scope="session:role-any",
            ),
            ReauthenticationRequest,
        ),
    ],
)
def test_oauth_token_expired_error_handling(auth_instance, expected_exc_type):
    """Check how each OAuth authenticator reacts to an expired-token response.

    The mocked server always answers with ``OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE``,
    and the connection's error handler re-raises whatever it is given. The
    test asserts that:

    - ``AuthByOAuth`` raises ``DatabaseError`` (general error handling);
    - ``AuthByOauthCode`` and ``AuthByOauthCredentials`` raise
      ``ReauthenticationRequest``.
    """

    def _raise_reported_error(connection, cursor, error_class, error_value):
        # Turn every reported error into a real exception so that
        # pytest.raises can observe it.
        raise error_class(**error_value)

    rest = _init_rest("testapplication", _mock_oauth_token_expired_rest_response)
    rest._connection.errorhandler = _raise_reported_error
    auth = Auth(rest)

    with pytest.raises(expected_exc_type):
        auth.authenticate(auth_instance, "testaccount", "testuser")
