#!/usr/bin/env python
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import json
import logging
import stat
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from secrets import token_urlsafe
from test.randomize import random_string
from test.unit.aio.mock_utils import mock_async_request_with_action
from test.unit.mock_utils import zero_backoff
from textwrap import dedent
from unittest import mock
from unittest.mock import patch

import aiohttp
import pytest
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

import snowflake.connector.aio
from snowflake.connector.aio._network import SnowflakeRestful
from snowflake.connector.aio.auth import (
    AuthByDefault,
    AuthByOAuth,
    AuthByOkta,
    AuthByUsrPwdMfa,
    AuthByWebBrowser,
)
from snowflake.connector.config_manager import CONFIG_MANAGER
from snowflake.connector.connection import DEFAULT_CONFIGURATION
from snowflake.connector.constants import (
    _CONNECTIVITY_ERR_MSG,
    ENV_VAR_PARTNER,
    QueryStatus,
)
from snowflake.connector.errors import (
    Error,
    HttpError,
    OperationalError,
    ProgrammingError,
)
from snowflake.connector.wif_util import AttestationProvider


@pytest.fixture(autouse=True)
def mock_detect_platforms():
    """Patch ``detect_platforms`` to return an empty list for every test in this module.

    Yields the patch mock so a test can inspect or reconfigure it.
    """
    patcher = patch(
        "snowflake.connector.auth._auth.detect_platforms", return_value=[]
    )
    with patcher as detect_mock:
        yield detect_mock


def fake_connector(**kwargs) -> snowflake.connector.aio.SnowflakeConnection:
    """Build an async SnowflakeConnection with dummy credentials.

    The connection is only constructed; callers must ``await conn.connect()``
    themselves. Extra keyword arguments are forwarded to the constructor.
    """
    dummy_settings = dict(
        user="user",
        account="account",
        password="testpassword",
        database="TESTDB",
        warehouse="TESTWH",
    )
    return snowflake.connector.aio.SnowflakeConnection(**dummy_settings, **kwargs)


def write_temp_file(file_path: Path, contents: str) -> Path:
    """Write ``contents`` to ``file_path``, restrict it to owner read/write, and return the path."""
    file_path.write_text(contents)
    owner_read_write = stat.S_IRUSR | stat.S_IWUSR
    file_path.chmod(owner_read_write)
    return file_path


@asynccontextmanager
async def fake_db_conn(**kwargs):
    """Yield a connected fake connection and close it when the block exits.

    Keyword arguments are forwarded to ``fake_connector``. The connection is
    closed even if the ``async with`` body raises, so a failing assertion
    inside the block does not leave the connection open.
    """
    conn = fake_connector(**kwargs)
    # If connect() raises, the exception propagates before anything is
    # yielded and close() is not attempted (same as before).
    await conn.connect()
    try:
        yield conn
    finally:
        await conn.close()


@pytest.fixture
def mock_post_requests(monkeypatch):
    """Stub ``SnowflakeRestful._post_request`` with an always-successful reply.

    The JSON body of every intercepted request is merged into the returned
    dict, so tests can assert on what the connector sent.
    """
    captured = {}

    async def fake_post_request(rest_self, url, headers, json_body, **kwargs):
        # Record the outgoing payload, then answer with a fresh success envelope.
        captured.update(json.loads(json_body))
        session_data = {
            "token": "TOKEN",
            "masterToken": "MASTER_TOKEN",
            "idToken": None,
            "parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
        }
        return {"success": True, "message": None, "data": session_data}

    monkeypatch.setattr(
        snowflake.connector.aio._network.SnowflakeRestful,
        "_post_request",
        fake_post_request,
    )
    return captured


async def test_connect_with_service_name(mock_post_requests):
    """SERVICE_NAME from the mocked login response is exposed as ``conn.service_name``."""
    expected_service = "FAKE_SERVICE_NAME"
    async with fake_db_conn() as conn:
        assert conn.service_name == expected_service


@patch("snowflake.connector.aio._network.SnowflakeRestful._post_request")
async def test_connection_ignore_exception(mockSnowflakeRestfulPostRequest):
    """Closing a connection must not raise when the server reports the session is gone.

    The first mocked POST answers the login request. The next one (presumably
    the session teardown issued by ``close()``) returns a 390111
    "Session gone" failure. Any later calls return None.
    """
    # Per-test call counter kept in the closure instead of a module global,
    # so no state leaks between tests.
    call_index = 0

    async def mock_post_request(url, headers, json_body, **kwargs):
        nonlocal call_index
        ret = None
        if call_index == 0:
            # return from /v1/login-request
            ret = {
                "success": True,
                "message": None,
                "data": {
                    "token": "TOKEN",
                    "masterToken": "MASTER_TOKEN",
                    "idToken": None,
                    "parameters": [
                        {"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}
                    ],
                },
            }
        elif call_index == 1:
            ret = {
                "success": False,
                "message": "Session gone",
                "data": None,
                "code": 390111,
            }
        call_index += 1
        return ret

    # POST requests mock
    mockSnowflakeRestfulPostRequest.side_effect = mock_post_request

    # connection
    con = snowflake.connector.aio.SnowflakeConnection(
        account="testaccount",
        user="testuser",
        password="testpassword",
        database="TESTDB",
        warehouse="TESTWH",
    )
    await con.connect()
    # Test to see if closing connection works or raises an exception. If an exception is raised, test will fail.
    await con.close()


def test_is_still_running():
    """Checks that is_still_running returns the expected flag for every QueryStatus listed."""
    is_still_running = snowflake.connector.aio.SnowflakeConnection.is_still_running
    expectations = (
        (QueryStatus.RUNNING, True),
        (QueryStatus.ABORTING, False),
        (QueryStatus.SUCCESS, False),
        (QueryStatus.FAILED_WITH_ERROR, False),
        (QueryStatus.ABORTED, False),
        (QueryStatus.QUEUED, True),
        (QueryStatus.FAILED_WITH_INCIDENT, False),
        (QueryStatus.DISCONNECTED, False),
        (QueryStatus.RESUMING_WAREHOUSE, True),
        (QueryStatus.QUEUED_REPARING_WAREHOUSE, True),
        (QueryStatus.RESTARTED, False),
        (QueryStatus.BLOCKED, True),
        (QueryStatus.NO_DATA, True),
    )
    for query_status, still_running in expectations:
        assert is_still_running(query_status) == still_running, query_status


async def test_partner_env_var(mock_post_requests, monkeypatch):
    """ENV_VAR_PARTNER sets the application name locally and in the login payload."""
    partner_name = "Amanda"
    monkeypatch.setenv(ENV_VAR_PARTNER, partner_name)

    async with fake_db_conn() as conn:
        assert conn.application == partner_name

    client_environment = mock_post_requests["data"]["CLIENT_ENVIRONMENT"]
    assert client_environment["APPLICATION"] == partner_name


@pytest.mark.skipolddriver
@pytest.mark.parametrize(
    "sys_modules,application",
    [
        ({"streamlit": None}, "streamlit"),
        (
            {"ipykernel": None, "jupyter_core": None, "jupyter_client": None},
            "jupyter_notebook",
        ),
        ({"snowbooks": None}, "snowflake_notebook"),
    ],
)
async def test_imported_module(mock_post_requests, sys_modules, application):
    """Presence of certain modules in sys.modules selects the reported application name."""
    with patch.dict(sys.modules, sys_modules):
        async with fake_db_conn() as conn:
            assert conn.application == application

    client_environment = mock_post_requests["data"]["CLIENT_ENVIRONMENT"]
    assert client_environment["APPLICATION"] == application


@pytest.mark.parametrize(
    "auth_class",
    [
        pytest.param(
            type("auth_class", (AuthByDefault,), {})("my_secret_password"),
            id="AuthByDefault",
        ),
        pytest.param(
            type("auth_class", (AuthByOAuth,), {})("my_token"),
            id="AuthByOAuth",
        ),
        pytest.param(
            type("auth_class", (AuthByOkta,), {})("Python connector"),
            id="AuthByOkta",
        ),
        pytest.param(
            type("auth_class", (AuthByUsrPwdMfa,), {})("password", "mfa_token"),
            id="AuthByUsrPwdMfa",
        ),
        pytest.param(
            type("auth_class", (AuthByWebBrowser,), {})(None, None),
            id="AuthByWebBrowser",
        ),
    ],
)
async def test_negative_custom_auth(auth_class):
    """Tests that non-AuthByKeyPair custom auth is not allowed.

    Each parameter is an instance of an ad-hoc subclass of a built-in
    authenticator. Constructing or connecting with it must raise TypeError.
    """
    expected_error = "auth_class must be a child class of AuthByKeyPair"
    with pytest.raises(TypeError, match=expected_error):
        conn = snowflake.connector.aio.SnowflakeConnection(
            account="account",
            user="user",
            auth_class=auth_class,
        )
        await conn.connect()


async def test_missing_default_connection(monkeypatch, tmp_path):
    """With neither config nor connections file on disk, 'default' cannot be resolved."""
    connections_file = tmp_path / "aio_connections.toml"
    config_file = tmp_path / "aio_config.toml"
    expected_msg = (
        "Default connection with name 'default' cannot be found, known ones are \\[\\]"
    )

    with monkeypatch.context() as m:
        for env_var in ("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "SNOWFLAKE_CONNECTIONS"):
            m.delenv(env_var, raising=False)
        m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
        m.setattr(CONFIG_MANAGER, "file_path", config_file)

        with pytest.raises(Error, match=expected_msg):
            snowflake.connector.aio.SnowflakeConnection(
                connections_file_path=connections_file
            )


async def test_missing_default_connection_conf_file(monkeypatch, tmp_path):
    """A default_connection_name from the config file that matches no known connection is reported."""
    connection_name = random_string(5)
    connections_file = tmp_path / "aio_connections.toml"
    config_file = write_temp_file(
        tmp_path / "aio_config.toml",
        dedent(
            f"""\
            default_connection_name = "{connection_name}"
            """
        ),
    )
    expected_msg = f"Default connection with name '{connection_name}' cannot be found, known ones are \\[\\]"

    with monkeypatch.context() as m:
        for env_var in ("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "SNOWFLAKE_CONNECTIONS"):
            m.delenv(env_var, raising=False)
        m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
        m.setattr(CONFIG_MANAGER, "file_path", config_file)

        with pytest.raises(Error, match=expected_msg):
            conn = snowflake.connector.aio.SnowflakeConnection(
                connections_file_path=connections_file
            )
            await conn.connect()


async def test_missing_default_connection_conn_file(monkeypatch, tmp_path):
    """The implicit 'default' connection is missing; the error lists the connections that do exist."""
    config_file = tmp_path / "aio_config.toml"
    connections_file = write_temp_file(
        tmp_path / "aio_connections.toml",
        dedent(
            """\
            [con_a]
            user = "test user"
            account = "test account"
            password = "test password"
            """
        ),
    )
    expected_msg = "Default connection with name 'default' cannot be found, known ones are \\['con_a'\\]"

    with monkeypatch.context() as m:
        for env_var in ("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "SNOWFLAKE_CONNECTIONS"):
            m.delenv(env_var, raising=False)
        m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
        m.setattr(CONFIG_MANAGER, "file_path", config_file)

        with pytest.raises(Error, match=expected_msg):
            conn = snowflake.connector.aio.SnowflakeConnection(
                connections_file_path=connections_file
            )
            await conn.connect()


async def test_missing_default_connection_conf_conn_file(monkeypatch, tmp_path):
    """A configured default_connection_name absent from a non-empty connections file is reported."""
    connection_name = random_string(5)
    config_file = write_temp_file(
        tmp_path / "aio_config.toml",
        dedent(
            f"""\
            default_connection_name = "{connection_name}"
            """
        ),
    )
    connections_file = write_temp_file(
        tmp_path / "aio_connections.toml",
        dedent(
            """\
            [con_a]
            user = "test user"
            account = "test account"
            password = "test password"
            """
        ),
    )
    expected_msg = f"Default connection with name '{connection_name}' cannot be found, known ones are \\['con_a'\\]"

    with monkeypatch.context() as m:
        for env_var in ("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "SNOWFLAKE_CONNECTIONS"):
            m.delenv(env_var, raising=False)
        m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
        m.setattr(CONFIG_MANAGER, "file_path", config_file)

        with pytest.raises(Error, match=expected_msg):
            conn = snowflake.connector.aio.SnowflakeConnection(
                connections_file_path=connections_file
            )
            await conn.connect()


async def test_invalid_backoff_policy():
    """backoff_policy must be a generator function; other callables or objects are rejected."""
    cases = (
        # zero_backoff() is a generator, not a generator function
        (zero_backoff(), ProgrammingError),
        # passing a non-generator function should not work
        (lambda: None, ProgrammingError),
        # passing a generator function should make it pass config and error during connection
        (zero_backoff, HttpError),
    )
    for policy, expected_exc in cases:
        with pytest.raises(expected_exc):
            await fake_connector(backoff_policy=policy).connect()


@pytest.mark.parametrize("next_action", ("RETRY", "ERROR"))
@patch("aiohttp.ClientSession.request")
async def test_handle_timeout(mockSessionRequest, next_action):
    """Slow login responses exhaust login_timeout and surface as OperationalError."""
    slow_response = mock_async_request_with_action(next_action, sleep=5)
    mockSessionRequest.side_effect = slow_response

    with pytest.raises(OperationalError):
        # no backoff for testing
        async with fake_db_conn(login_timeout=9, backoff_policy=zero_backoff):
            pass

    # authenticator should be the only retry mechanism for login requests
    # 9 seconds should be enough for authenticator to attempt twice
    # however, loosen restrictions to avoid thread scheduling causing failure
    attempts = mockSessionRequest.call_count
    assert 1 < attempts < 4


async def test_private_key_file_reading(tmp_path: Path):
    """An unencrypted PEM key file is loaded and passed to AuthByKeyPair as DER bytes."""
    key_file = tmp_path / "aio_key.pem"
    private_key = rsa.generate_private_key(
        backend=default_backend(), public_exponent=65537, key_size=2048
    )
    pkcs8 = serialization.PrivateFormat.PKCS8
    no_encryption = serialization.NoEncryption()

    key_file.write_bytes(
        private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=pkcs8,
            encryption_algorithm=no_encryption,
        )
    )
    expected_der = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=pkcs8,
        encryption_algorithm=no_encryption,
    )

    stop_msg = "stop execution"
    # Abort inside AuthByKeyPair.__init__ so its arguments can be inspected.
    with mock.patch(
        "snowflake.connector.aio.auth.AuthByKeyPair.__init__",
        side_effect=Exception(stop_msg),
    ) as init_mock:
        with pytest.raises(Exception, match=stop_msg):
            await snowflake.connector.aio.SnowflakeConnection(
                account="test_account",
                user="test_user",
                private_key_file=str(key_file),
            ).connect()

    assert init_mock.call_count == 1
    assert init_mock.call_args_list[0].kwargs["private_key"] == expected_der


async def test_encrypted_private_key_file_reading(tmp_path: Path):
    """A password-protected PEM key is decrypted and passed to AuthByKeyPair as DER bytes."""
    key_file = tmp_path / "aio_key.pem"
    private_key_password = token_urlsafe(25)
    private_key = rsa.generate_private_key(
        backend=default_backend(), public_exponent=65537, key_size=2048
    )
    pkcs8 = serialization.PrivateFormat.PKCS8
    password_encryption = serialization.BestAvailableEncryption(
        private_key_password.encode("utf-8")
    )

    key_file.write_bytes(
        private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=pkcs8,
            encryption_algorithm=password_encryption,
        )
    )
    expected_der = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=pkcs8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    stop_msg = "stop execution"
    # Abort inside AuthByKeyPair.__init__ so its arguments can be inspected.
    with mock.patch(
        "snowflake.connector.aio.auth.AuthByKeyPair.__init__",
        side_effect=Exception(stop_msg),
    ) as init_mock:
        with pytest.raises(Exception, match=stop_msg):
            await snowflake.connector.aio.SnowflakeConnection(
                account="test_account",
                user="test_user",
                private_key_file=str(key_file),
                private_key_file_pwd=private_key_password,
            ).connect()

    assert init_mock.call_count == 1
    assert init_mock.call_args_list[0].kwargs["private_key"] == expected_der


async def test_expired_detection():
    """A 390114 'token expired' fetch response makes execute() raise and flags conn.expired."""
    login_response = {
        "data": {
            "masterToken": "some master token",
            "token": "some token",
            "validityInSeconds": 3600,
            "masterValidityInSeconds": 14400,
            "displayUserName": "TEST_USER",
            "serverVersion": "7.42.0",
        },
        "code": None,
        "message": None,
        "success": True,
    }
    expired_response = {
        "data": {
            "errorCode": "390114",
            "reAuthnMethods": ["USERNAME_PASSWORD"],
        },
        "code": "390114",
        "message": "Authentication token has expired.  The user must authenticate again.",
        "success": False,
        "headers": None,
    }

    with mock.patch(
        "snowflake.connector.aio._network.SnowflakeRestful._post_request",
        return_value=login_response,
    ):
        conn = fake_connector()
        await conn.connect()
    assert not conn.expired

    async with conn.cursor() as cur:
        with mock.patch(
            "snowflake.connector.aio._network.SnowflakeRestful.fetch",
            return_value=expired_response,
        ), pytest.raises(ProgrammingError):
            await cur.execute("select 1;")
    assert conn.expired


async def test_disable_saml_url_check_config():
    """When not set explicitly, _disable_saml_url_check equals its DEFAULT_CONFIGURATION value."""
    login_response = {
        "data": {
            "serverVersion": "a.b.c",
        },
        "code": None,
        "message": None,
        "success": True,
    }
    default_value = DEFAULT_CONFIGURATION.get("disable_saml_url_check")[0]

    with mock.patch(
        "snowflake.connector.aio._network.SnowflakeRestful._post_request",
        return_value=login_response,
    ):
        async with fake_db_conn() as conn:
            assert conn._disable_saml_url_check == default_value


def test_request_guid():
    """add_request_guid appends request_guid to Snowflake hosts and leaves other hosts unchanged.

    Each case is a separate assert, so a failure identifies the offending URL
    instead of reporting one opaque ``and`` chain.
    """
    add_guid = SnowflakeRestful.add_request_guid

    assert add_guid("https://test.snowflakecomputing.com").startswith(
        "https://test.snowflakecomputing.com?request_guid="
    )
    # An existing query string is preserved and the guid is appended with '&'.
    assert add_guid("http://test.snowflakecomputing.cn?a=b").startswith(
        "http://test.snowflakecomputing.cn?a=b&request_guid="
    )
    assert add_guid("https://test.snowflakecomputing.com.cn").startswith(
        "https://test.snowflakecomputing.com.cn?request_guid="
    )
    # Non-Snowflake hosts are returned unchanged.
    assert add_guid("https://test.abc.cn?a=b") == "https://test.abc.cn?a=b"


async def test_ssl_error_hint(caplog):
    """An aiohttp SSL failure surfaces as OperationalError with the connectivity hint, and is logged."""
    ssl_failure = aiohttp.ClientSSLError(mock.Mock(), OSError("SSL error"))

    with mock.patch(
        "aiohttp.ClientSession.request", side_effect=ssl_failure
    ), caplog.at_level(logging.DEBUG):
        with pytest.raises(OperationalError) as exc:
            await fake_connector().connect()

    raised = exc.value
    assert _CONNECTIVITY_ERR_MSG in raised.msg
    assert isinstance(raised, OperationalError)
    assert "SSL error" in caplog.text
    assert _CONNECTIVITY_ERR_MSG in caplog.text


async def test_otel_error_message_async(caplog, mock_post_requests):
    """This test assumes that OpenTelemetry is not installed when tests are running."""
    # Must match the logged text exactly, including its spelling.
    marker = "Opentelemtry otel injection failed"

    with mock.patch(
        "snowflake.connector.aio._network.SnowflakeRestful._post_request"
    ), caplog.at_level(logging.DEBUG):
        async with fake_connector():
            pass

    assert caplog.records
    otel_records = [rec for rec in caplog.records if marker in rec.message]
    assert len(otel_records) == 1
    # The failure should be logged together with its exception info.
    assert otel_records[0].exc_text is not None


@pytest.mark.parametrize(
    "dependent_param,value",
    [
        ("workload_identity_provider", "AWS"),
        (
            "workload_identity_entra_resource",
            "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b",
        ),
        ("workload_identity_impersonation_path", ["subject-b", "subject-c"]),
    ],
)
async def test_cannot_set_dependent_params_without_wlid_authenticator(
    mock_post_requests, dependent_param, value
):
    """Workload-identity-only parameters are rejected unless authenticator is WORKLOAD_IDENTITY."""
    connect_kwargs = {
        "user": "user",
        "account": "account",
        "password": "password",
        dependent_param: value,
    }
    with pytest.raises(ProgrammingError) as excinfo:
        await snowflake.connector.aio.connect(**connect_kwargs)

    expected_msg = (
        f"{dependent_param} was set but authenticator was not set to WORKLOAD_IDENTITY"
    )
    assert expected_msg in str(excinfo.value)


@pytest.mark.parametrize(
    "provider_param",
    [
        None,
        "",
        "INVALID",
    ],
)
async def test_workload_identity_provider_is_required_for_wif_authenticator(
    monkeypatch, provider_param
):
    """WORKLOAD_IDENTITY without a valid workload_identity_provider raises ProgrammingError.

    A missing provider (None) and an unrecognized one ("" or "INVALID") produce
    different error messages.
    """

    # Stub _authenticate with a coroutine function, as the other WIF tests in
    # this module do (presumably because connect() awaits it). A sync
    # ``lambda: None`` would raise TypeError if awaited and mask the
    # ProgrammingError under test.
    async def mock_authenticate(*_):
        pass

    with monkeypatch.context() as m:
        m.setattr(
            "snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
            mock_authenticate,
        )

        with pytest.raises(ProgrammingError) as excinfo:
            await snowflake.connector.aio.connect(
                account="account",
                authenticator="WORKLOAD_IDENTITY",
                workload_identity_provider=provider_param,
            )
        expected_error_msg = (
            "workload_identity_provider must be set to one of AWS,AZURE,GCP,OIDC when authenticator is WORKLOAD_IDENTITY"
            if provider_param is None
            else f"Unknown workload_identity_provider: '{provider_param}'. Expected one of: AWS, AZURE, GCP, OIDC"
        )
        assert expected_error_msg in str(excinfo.value)


@pytest.mark.parametrize(
    "provider_param",
    [
        # Strongly-typed values.
        AttestationProvider.AZURE,
        AttestationProvider.OIDC,
        # String values.
        "AZURE",
        "OIDC",
    ],
)
async def test_workload_identity_impersonation_path_errors_for_unsupported_providers(
    monkeypatch, provider_param
):
    """workload_identity_impersonation_path is rejected for AZURE and OIDC providers."""
    expected_msg = "workload_identity_impersonation_path is currently only supported for GCP and AWS."

    async def skip_authentication(*_):
        pass

    with monkeypatch.context() as m:
        m.setattr(
            "snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
            skip_authentication,
        )

        with pytest.raises(ProgrammingError) as excinfo:
            await snowflake.connector.aio.connect(
                account="account",
                authenticator="WORKLOAD_IDENTITY",
                workload_identity_provider=provider_param,
                workload_identity_impersonation_path=[
                    "sa2@project.iam.gserviceaccount.com"
                ],
            )
        assert expected_msg in str(excinfo.value)


@pytest.mark.parametrize(
    "provider_param,impersonation_path",
    [
        (AttestationProvider.GCP, ["sa2@project.iam.gserviceaccount.com"]),
        (AttestationProvider.AWS, ["arn:aws:iam::1234567890:role/role2"]),
        ("GCP", ["sa2@project.iam.gserviceaccount.com"]),
        ("AWS", ["arn:aws:iam::1234567890:role/role2"]),
    ],
)
async def test_workload_identity_impersonation_path_populates_auth_class_for_supported_provider(
    monkeypatch, provider_param, impersonation_path
):
    """For GCP and AWS, the impersonation path is forwarded to the connection's auth class."""

    async def skip_authentication(*_):
        pass

    with monkeypatch.context() as m:
        m.setattr(
            "snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
            skip_authentication,
        )

        conn = await snowflake.connector.aio.connect(
            account="account",
            authenticator="WORKLOAD_IDENTITY",
            workload_identity_provider=provider_param,
            workload_identity_impersonation_path=impersonation_path,
        )
        configured_path = conn.auth_class.impersonation_path
        assert configured_path == impersonation_path


@pytest.mark.parametrize(
    "provider_param, parsed_provider",
    [
        # Strongly-typed values.
        (AttestationProvider.AWS, AttestationProvider.AWS),
        (AttestationProvider.AZURE, AttestationProvider.AZURE),
        (AttestationProvider.GCP, AttestationProvider.GCP),
        (AttestationProvider.OIDC, AttestationProvider.OIDC),
        # String values.
        ("AWS", AttestationProvider.AWS),
        ("AZURE", AttestationProvider.AZURE),
        ("GCP", AttestationProvider.GCP),
        ("OIDC", AttestationProvider.OIDC),
    ],
)
async def test_connection_params_are_plumbed_into_authbyworkloadidentity(
    monkeypatch, provider_param, parsed_provider
):
    """WIF connect() kwargs reach the auth class; string providers become AttestationProvider values."""
    entra_resource = "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b"

    async def skip_authentication(*_):
        pass

    with monkeypatch.context() as m:
        m.setattr(
            "snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
            skip_authentication,
        )

        conn = await snowflake.connector.aio.connect(
            account="my_account_1",
            workload_identity_provider=provider_param,
            workload_identity_entra_resource=entra_resource,
            token="my_token",
            authenticator="WORKLOAD_IDENTITY",
        )
        auth = conn.auth_class
        assert auth.provider == parsed_provider
        assert auth.entra_resource == entra_resource
        assert auth.token == "my_token"


async def test_toml_connection_params_are_plumbed_into_authbyworkloadidentity(
    monkeypatch, tmp_path
):
    """WIF settings in connections.toml (including token_file_path) reach the auth class."""
    entra_resource = "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b"
    token_file = write_temp_file(tmp_path / "token.txt", contents="my_token")
    # On Windows, this path includes backslashes which will result in errors while parsing the TOML.
    # Escape the backslashes to ensure it parses correctly.
    token_file_path_escaped = str(token_file).replace("\\", "\\\\")
    toml_text = dedent(
        f"""\
        [default]
        account = "my_account_1"
        authenticator = "WORKLOAD_IDENTITY"
        workload_identity_provider = "OIDC"
        workload_identity_entra_resource = "api://0b2f151f-09a2-46eb-ad5a-39d5ebef917b"
        token_file_path = "{token_file_path_escaped}"
        """
    )
    connections_file = write_temp_file(
        tmp_path / "connections.toml", contents=toml_text
    )

    async def skip_authentication(*_):
        pass

    with monkeypatch.context() as m:
        for env_var in ("SNOWFLAKE_DEFAULT_CONNECTION_NAME", "SNOWFLAKE_CONNECTIONS"):
            m.delenv(env_var, raising=False)
        m.setattr(CONFIG_MANAGER, "conf_file_cache", None)
        m.setattr(
            "snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
            skip_authentication,
        )

        conn = await snowflake.connector.aio.connect(
            connections_file_path=connections_file
        )
        auth = conn.auth_class
        assert auth.provider == AttestationProvider.OIDC
        assert auth.entra_resource == entra_resource
        assert auth.token == "my_token"


@pytest.mark.parametrize("rtr_enabled", [True, False])
async def test_single_use_refresh_tokens_option_is_plumbed_into_authbyauthcode_async(
    monkeypatch, rtr_enabled: bool
):
    """oauth_enable_single_use_refresh_tokens is forwarded to the authorization-code auth class."""

    async def skip_authentication(*_):
        pass

    with monkeypatch.context() as m:
        m.setattr(
            "snowflake.connector.aio._connection.SnowflakeConnection._authenticate",
            skip_authentication,
        )

        conn = await snowflake.connector.aio.connect(
            account="my_account_1",
            user="user",
            oauth_client_id="client_id",
            oauth_client_secret="client_secret",
            authenticator="OAUTH_AUTHORIZATION_CODE",
            oauth_enable_single_use_refresh_tokens=rtr_enabled,
        )
        single_use_enabled = conn.auth_class._enable_single_use_refresh_tokens
        assert single_use_enabled == rtr_enabled


@pytest.mark.skipolddriver
async def test_invalid_authenticator():
    """An unrecognized authenticator name is rejected with ProgrammingError."""
    with pytest.raises(ProgrammingError) as excinfo:
        await snowflake.connector.aio.SnowflakeConnection(
            account="account",
            authenticator="INVALID",
        ).connect()

    error_text = str(excinfo.value)
    assert "Unknown authenticator: INVALID" in error_text


@pytest.mark.skipolddriver
def test_connect_metadata_preservation():
    """Test that the async connect function preserves metadata from SnowflakeConnection.__init__.

    This test verifies that various inspection methods return consistent metadata,
    ensuring IDE support, type checking, and documentation generation work correctly.
    """
    import inspect

    # Use already imported snowflake.connector.aio
    connect = snowflake.connector.aio.connect
    SnowflakeConnection = snowflake.connector.aio.SnowflakeConnection

    # Test 1: Check __name__ and __qualname__ mirror SnowflakeConnection.__init__
    assert (
        connect.__name__ == "__init__"
    ), f"connect.__name__ should be '__init__', but got '{connect.__name__}'"
    assert connect.__qualname__ == "SnowflakeConnection.__init__", (
        "connect.__qualname__ should be 'SnowflakeConnection.__init__', "
        f"but got '{connect.__qualname__}'"
    )

    # Test 2: Check __wrapped__ points to SnowflakeConnection.__init__
    assert hasattr(connect, "__wrapped__"), "connect should have __wrapped__ attribute"
    assert (
        connect.__wrapped__ is SnowflakeConnection.__init__
    ), "connect.__wrapped__ should reference SnowflakeConnection.__init__"

    # Test 3: Check __module__ is preserved
    assert hasattr(connect, "__module__"), "connect should have __module__ attribute"
    assert connect.__module__ == SnowflakeConnection.__init__.__module__, (
        f"connect.__module__ should match SnowflakeConnection.__init__.__module__, "
        f"but got '{connect.__module__}' vs '{SnowflakeConnection.__init__.__module__}'"
    )

    # Test 4: Check __doc__ is preserved
    assert hasattr(connect, "__doc__"), "connect should have __doc__ attribute"
    assert (
        connect.__doc__ == SnowflakeConnection.__init__.__doc__
    ), "connect.__doc__ should match SnowflakeConnection.__init__.__doc__"

    # Test 5: Check __annotations__ are preserved (or at least available)
    assert hasattr(
        connect, "__annotations__"
    ), "connect should have __annotations__ attribute"
    src_annotations = getattr(SnowflakeConnection.__init__, "__annotations__", {})
    connect_annotations = getattr(connect, "__annotations__", {})
    assert connect_annotations == src_annotations, (
        f"connect.__annotations__ should match SnowflakeConnection.__init__.__annotations__, "
        f"but got {connect_annotations} vs {src_annotations}"
    )

    # Test 6: Check inspect.signature works correctly.
    # Only the introspection calls are guarded. The comparison assert stays
    # outside the try so its own message is reported as-is, instead of being
    # caught and re-reported as "inspect.signature(connect) failed".
    try:
        connect_sig = inspect.signature(connect)
        source_sig = inspect.signature(SnowflakeConnection.__init__)
    except (TypeError, ValueError) as e:
        pytest.fail(f"inspect.signature(connect) failed: {e}")
    assert str(connect_sig) == str(source_sig), (
        f"inspect.signature(connect) should match inspect.signature(SnowflakeConnection.__init__), "
        f"but got '{connect_sig}' vs '{source_sig}'"
    )

    # Test 7: Check inspect.getdoc works correctly
    connect_doc = inspect.getdoc(connect)
    source_doc = inspect.getdoc(SnowflakeConnection.__init__)
    assert (
        connect_doc == source_doc
    ), "inspect.getdoc(connect) should match inspect.getdoc(SnowflakeConnection.__init__)"

    # Test 8: Check that connect is callable
    assert callable(connect), "connect should be callable"

    # Test 9: Check type() and __class__ values (important for user introspection)
    assert (
        type(connect).__name__ == "function"
    ), f"type(connect).__name__ should be 'function', but got '{type(connect).__name__}'"
    assert (
        connect.__class__.__name__ == "function"
    ), f"connect.__class__.__name__ should be 'function', but got '{connect.__class__.__name__}'"
    assert inspect.isfunction(
        connect
    ), "connect should be recognized as a function by inspect.isfunction()"

    # Test 10: Verify the function has proper introspection capabilities
    # IDEs and type checkers should be able to resolve parameters
    # (e.g. account, user, password).
    params = list(connect_sig.parameters)
    assert params, "connect should have parameters from SnowflakeConnection.__init__"
    # Should have parameters like account, user, password, etc.


@pytest.mark.skipolddriver
async def test_server_session_keep_alive_skips_async_check(mock_post_requests):
    """Test that server_session_keep_alive=True skips _all_async_queries_finished check."""
    conn = fake_connector(server_session_keep_alive=True)
    await conn.connect()

    # Replace the async hooks that close() may invoke with inspectable mocks.
    conn._all_async_queries_finished = mock.AsyncMock(return_value=True)
    delete_session_spy = mock.AsyncMock()
    # close() deletes conn.rest, so keep our own reference to this mock.
    conn.rest.delete_session = delete_session_spy

    await conn.close()

    # Keep-alive means close() neither checks async queries nor deletes the session.
    conn._all_async_queries_finished.assert_not_called()
    delete_session_spy.assert_not_called()


@pytest.mark.skipolddriver
async def test_server_session_keep_alive_false_calls_async_check(mock_post_requests):
    """Test that server_session_keep_alive=False calls _all_async_queries_finished check."""
    conn = fake_connector(server_session_keep_alive=False)
    await conn.connect()

    # Replace the async hooks that close() is expected to invoke with inspectable mocks.
    conn._all_async_queries_finished = mock.AsyncMock(return_value=True)
    delete_session_spy = mock.AsyncMock()
    # close() deletes conn.rest, so keep our own reference to this mock.
    conn.rest.delete_session = delete_session_spy

    await conn.close()

    # Without keep-alive, close() checks async queries once and, since they are
    # reported finished, deletes the session once.
    conn._all_async_queries_finished.assert_called_once()
    delete_session_spy.assert_called_once()
