#!/usr/bin/env python
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import filecmp
import logging
import os
from io import BytesIO
from logging import getLogger
from os import path
from unittest import mock

import pytest

from snowflake.connector import OperationalError

try:
    from snowflake.connector.util_text import random_string
except ImportError:
    from test.randomize import random_string

try:
    from src.snowflake.connector.compat import IS_WINDOWS
except ImportError:
    import platform

    IS_WINDOWS = platform.system() == "Windows"

from test.generate_test_files import generate_k_lines_of_n_files

# Directory that contains this test module (symlinks resolved by realpath).
THIS_DIR = path.dirname(path.realpath(__file__))

# Module-level logger named after this module.
logger = getLogger(__name__)

# Module-wide pytest marker: every test in this file carries the asyncio mark.
pytestmark = pytest.mark.asyncio
# Value of the "cloud_provider" environment variable, "dev" when it is unset.
CLOUD = os.getenv("cloud_provider", "dev")


async def test_utf8_filename(tmp_path, aio_connection):
    """PUT a file whose name contains non-ASCII characters and read it back from the stage."""
    stage_name = random_string(5, "test_utf8_filename_")
    utf8_file = tmp_path / "utf卡豆.csv"
    utf8_file.write_text("1,2,3\n")
    # Swap backslashes for forward slashes before embedding the path in the PUT command.
    put_path = str(utf8_file).replace("\\", "/")

    await aio_connection.connect()
    cursor = aio_connection.cursor()
    await cursor.execute(f"create temporary stage {stage_name}")

    put_cursor = await cursor.execute(f"PUT 'file://{put_path}' @{stage_name}")
    await put_cursor.fetchall()

    # The staged CSV is queried directly; its three columns come back as strings.
    await cursor.execute(f"select $1, $2, $3 from  @{stage_name}")
    first_row = await cursor.fetchone()
    assert first_row == ("1", "2", "3")


async def test_put_threshold(tmp_path, aio_connection, is_public_test):
    """Check that ``threshold=156`` in a PUT command is passed on as ``multipart_threshold``.

    The transfer agent class is replaced with an autospec mock, so no upload
    happens; only the keyword arguments of the mock's call are inspected.
    """
    if is_public_test:
        pytest.xfail(
            reason="This feature hasn't been rolled out for public Snowflake deployments yet."
        )
    stage_name = random_string(5, "test_put_get_threshold_")
    local_file = tmp_path / "test_put_get_with_aws_token.txt.gz"
    local_file.touch()

    await aio_connection.connect()
    cursor = aio_connection.cursor()
    await cursor.execute(f"create temporary stage {stage_name}")
    from snowflake.connector.file_transfer_agent import SnowflakeFileTransferAgent

    agent_patch = mock.patch(
        "snowflake.connector.file_transfer_agent.SnowflakeFileTransferAgent",
        autospec=SnowflakeFileTransferAgent,
    )
    with agent_patch as mock_agent:
        await cursor.execute(f"put file://{local_file} @{stage_name} threshold=156")

    # call_args[1] is the keyword-argument dict of the mock's most recent call.
    agent_kwargs = mock_agent.call_args[1]
    assert agent_kwargs.get("multipart_threshold", -1) == 156


# Snowflake on GCP does not support multipart uploads
@pytest.mark.xfail(reason="multipart transfer is not merged yet")
# @pytest.mark.aws
# @pytest.mark.azure
@pytest.mark.parametrize("use_stream", [False, True])
async def test_multipart_put(aio_connection, tmp_path, use_stream):
    """Upload a generated file with a lowered multipart threshold, download it and compare.

    When ``use_stream`` is true the file content is handed over as an in-memory
    stream; otherwise the PUT command points at the file on disk.
    """
    stage_name = random_string(5, "test_multipart_put_")
    chunk_size = 6967790
    # Generate about 12 MB
    generate_k_lines_of_n_files(100_000, 1, tmp_dir=str(tmp_path))
    upload_file = tmp_path / "file0"
    get_dir = tmp_path / "get_dir"
    get_dir.mkdir()

    await aio_connection.connect()
    cursor = aio_connection.cursor()
    await cursor.execute(f"create temporary stage {stage_name}")
    original_cmd_query = aio_connection.cmd_query

    async def cmd_query_with_threshold(*args, **kwargs):
        """Call the real ``cmd_query`` and overwrite ``threshold`` in the returned JSON."""
        response = await original_cmd_query(*args, **kwargs)
        response["data"]["threshold"] = chunk_size
        return response

    query_patch = mock.patch.object(
        aio_connection, "cmd_query", side_effect=cmd_query_with_threshold
    )
    chunk_size_patch = mock.patch(
        "snowflake.connector.constants.S3_CHUNK_SIZE", chunk_size
    )
    with query_patch, chunk_size_patch:
        if use_stream:
            put_kwargs = {
                "command": f"put file://file0 @{stage_name} AUTO_COMPRESS=FALSE",
                "file_stream": BytesIO(upload_file.read_bytes()),
            }
        else:
            put_kwargs = {
                "command": f"put file://{upload_file} @{stage_name} AUTO_COMPRESS=FALSE",
            }
        await cursor.execute(**put_kwargs)
        listing = await cursor.execute(f"list @{stage_name}")
        print(await listing.fetchall())

    # Fetch the staged file outside the patches and compare it with the source.
    await cursor.execute(f"get @{stage_name}/{upload_file.name} file://{get_dir}")
    downloaded_file = get_dir / upload_file.name
    assert downloaded_file.exists()
    assert filecmp.cmp(upload_file, downloaded_file)


async def test_put_special_file_name(tmp_path, aio_connection):
    """PUT a file whose name contains ``~`` and ``%`` and read its columns back from the stage."""
    stage_name = random_string(5, "test_special_filename_")
    special_file = tmp_path / "data~%23.csv"
    special_file.write_text("1,2,3\n")
    # Swap backslashes for forward slashes before embedding the path in the PUT command.
    filename_in_put = str(special_file).replace("\\", "/")

    await aio_connection.connect()
    cursor = aio_connection.cursor()
    await cursor.execute(f"create temporary stage {stage_name}")

    put_cursor = await cursor.execute(f"PUT 'file://{filename_in_put}' @{stage_name}")
    await put_cursor.fetchall()

    await cursor.execute(f"select $1, $2, $3 from  @{stage_name}")
    first_row = await cursor.fetchone()
    assert first_row == ("1", "2", "3")


async def test_get_empty_file(tmp_path, aio_connection):
    """GET of a name that was never staged must raise and leave no local file behind."""
    stage_name = random_string(5, "test_get_empty_file_")
    staged_file = tmp_path / "data.csv"
    staged_file.write_text("1,2,3\n")
    # Local path where a download of the non-existent stage file would land.
    missing_local_copy = tmp_path / "foo.csv"
    put_source = str(staged_file).replace("\\", "/")

    await aio_connection.connect()
    cur = aio_connection.cursor()
    await cur.execute(f"create temporary stage {stage_name}")
    await cur.execute(f"PUT 'file://{put_source}' @{stage_name}")

    expected_error = pytest.raises(
        OperationalError, match=".*the file does not exist.*$"
    )
    with expected_error:
        await cur.execute(f"GET @{stage_name}/foo.csv file://{tmp_path}")
    assert not missing_local_copy.exists()


@pytest.mark.parametrize("auto_compress", ["TRUE", "FALSE"])
@pytest.mark.skipif(IS_WINDOWS, reason="not supported on Windows")
async def test_get_file_permission(tmp_path, aio_connection, caplog, auto_compress):
    """Verify the permission bits of a file retrieved with GET.

    The file is uploaded, deleted locally, downloaded back into ``tmp_path``
    and its mode is compared against ``0o600`` with the process umask applied.
    Runs once per ``auto_compress`` value.
    """
    test_file = tmp_path / "data.csv"
    test_file.write_text("1,2,3\n")
    stage_name = random_string(5, "test_get_empty_file_")
    await aio_connection.connect()
    cur = aio_connection.cursor()
    await cur.execute(f"create temporary stage {stage_name}")
    filename_in_put = str(test_file).replace("\\", "/")
    await cur.execute(
        f"PUT 'file://{filename_in_put}' @{stage_name} AUTO_COMPRESS={auto_compress}",
    )
    # Remove the source so the only file left in tmp_path is the download.
    test_file.unlink()

    with caplog.at_level(logging.ERROR):
        await cur.execute(f"GET @{stage_name}/data.csv file://{tmp_path}")
    assert "FileNotFoundError" not in caplog.text
    downloaded_files = list(tmp_path.iterdir())
    assert len(downloaded_files) == 1
    downloaded_file = downloaded_files[0]

    # get the default mask, usually it is 0o022
    # os.umask() can only be read by setting it, so restore it right away.
    default_mask = os.umask(0)
    os.umask(default_mask)
    # files by default are given the permission 600 (Octal)
    # umask is for denial, we need to negate
    # Compare integers rather than sliced oct() strings: slicing breaks when the
    # expected mode has fewer than three octal digits (e.g. oct(0) == "0o0").
    expected_mode = 0o600 & ~default_mask
    actual_mode = os.stat(downloaded_file).st_mode & 0o777
    assert (
        actual_mode == expected_mode
    ), f"expected mode {expected_mode:03o}, got {actual_mode:03o}"


@pytest.mark.parametrize("auto_compress", ["TRUE", "FALSE"])
@pytest.mark.skipif(IS_WINDOWS, reason="not supported on Windows")
async def test_get_unsafe_file_permission_when_flag_set(
    tmp_path, aio_connection, caplog, auto_compress
):
    """Verify the permission bits of a GET download when ``unsafe_file_write`` is enabled.

    Same flow as the default-permission test, but ``unsafe_file_write`` is set
    on the connection and the expected mode is ``0o666`` with the process umask
    applied. Runs once per ``auto_compress`` value.
    """
    test_file = tmp_path / "data.csv"
    test_file.write_text("1,2,3\n")
    stage_name = random_string(5, "test_get_empty_file_")
    await aio_connection.connect()
    aio_connection.unsafe_file_write = True
    cur = aio_connection.cursor()
    await cur.execute(f"create temporary stage {stage_name}")
    filename_in_put = str(test_file).replace("\\", "/")
    await cur.execute(
        f"PUT 'file://{filename_in_put}' @{stage_name} AUTO_COMPRESS={auto_compress}",
    )
    # Remove the source so the only file left in tmp_path is the download.
    test_file.unlink()

    with caplog.at_level(logging.ERROR):
        await cur.execute(f"GET @{stage_name}/data.csv file://{tmp_path}")
    assert "FileNotFoundError" not in caplog.text
    downloaded_files = list(tmp_path.iterdir())
    assert len(downloaded_files) == 1
    downloaded_file = downloaded_files[0]

    # get the default mask, usually it is 0o022
    # os.umask() can only be read by setting it, so restore it right away.
    default_mask = os.umask(0)
    os.umask(default_mask)
    # when unsafe_file_write is set, permission is 644 (Octal) under the usual 0o022 mask
    # umask is for denial, we need to negate
    # Compare integers rather than sliced oct() strings: slicing breaks when the
    # expected mode has fewer than three octal digits (e.g. oct(0o066) == "0o66").
    expected_mode = 0o666 & ~default_mask
    actual_mode = os.stat(downloaded_file).st_mode & 0o777
    assert (
        actual_mode == expected_mode
    ), f"expected mode {expected_mode:03o}, got {actual_mode:03o}"


async def test_get_multiple_files_with_same_name(tmp_path, aio_connection, caplog):
    """Check that a GET matching two staged files with the same name logs a warning.

    The same local file is uploaded under two stage sub-paths, then a single
    pattern GET is issued and the captured log is searched for the warning.
    """
    import asyncio

    stage_name = random_string(5, "test_get_multiple_files_with_same_name_")
    source_file = tmp_path / "data.csv"
    source_file.write_text("1,2,3\n")
    filename_in_put = str(source_file).replace("\\", "/")

    await aio_connection.connect()
    cur = aio_connection.cursor()
    await cur.execute(f"create temporary stage {stage_name}")
    await cur.execute(f"PUT 'file://{filename_in_put}' @{stage_name}/data/1/")
    await cur.execute(f"PUT 'file://{filename_in_put}' @{stage_name}/data/2/")

    # Poll the stage listing (at most 10 times, one second apart) until both
    # uploads show up; fail the test if they never do.
    for _ in range(10):
        ls_cursor = await cur.execute(f"LS @{stage_name}")
        file_list = await ls_cursor.fetchall()
        if len(file_list) >= 2:
            break
        await asyncio.sleep(1)
    else:
        pytest.fail(f"Files not available in stage after 10 seconds: {file_list}")

    get_command = f"GET @{stage_name} file://{tmp_path} PATTERN='.*data.csv.gz'"
    with caplog.at_level(logging.WARNING):
        try:
            await cur.execute(get_command)
        except OperationalError:
            # Deliberately tolerated: this can happen due to cloud storage
            # timing issues, and only the logged warning is under test.
            pass

    # The warning must have been logged whether or not the GET itself succeeded.
    warning_logged = "Downloading multiple files with the same name" in caplog.text
    assert warning_logged, f"Expected warning not found in logs: {caplog.text}"


async def test_transfer_error_message(tmp_path, aio_connection):
    """Check that a failure while finishing an upload surfaces as ``OperationalError``.

    ``SnowflakeStorageClient.finish_upload`` is patched to raise
    ``ConnectionError``; the PUT is then expected to raise ``OperationalError``.
    """
    test_file = tmp_path / "data.csv"
    test_file.write_text("1,2,3\n")
    stage_name = random_string(5, "test_utf8_filename_")
    await aio_connection.connect()
    cursor = aio_connection.cursor()
    await cursor.execute(f"create temporary stage {stage_name}")
    with mock.patch(
        "snowflake.connector.aio._storage_client.SnowflakeStorageClient.finish_upload",
        side_effect=ConnectionError,
    ):
        with pytest.raises(OperationalError):
            # fetchall() is a coroutine on the async cursor and must be awaited;
            # otherwise, if execute() did not raise, it would be created and dropped.
            await (
                await cursor.execute(
                    "PUT 'file://{}' @{}".format(
                        str(test_file).replace("\\", "/"), stage_name
                    )
                )
            ).fetchall()


@pytest.mark.skipolddriver
async def test_put_md5(tmp_path, aio_connection):
    """This test uploads a single and a multi part file and makes sure that md5 is populated.

    A tiny file and a 200 MB file are written to ``tmp_path``, uploaded
    uncompressed under separate stage prefixes, and the stage listing is then
    checked: both files must be listed and the third column of every row must
    not be ``None``.
    """
    # Create files directly without subfolders for efficiency
    # Small file for single-part upload test
    small_test_file = tmp_path / "small_file.txt"
    small_test_file.write_text("test content\n")  # Minimal content

    # Big file for multi-part upload test - 200MB (well over 64MB threshold)
    big_test_file = tmp_path / "big_file.txt"
    chunk_size = 1024 * 1024  # 1MB chunks
    chunk_data = "A" * chunk_size  # 1MB of 'A' characters
    # Written chunk by chunk so the full 200MB never has to sit in memory.
    with open(big_test_file, "w") as f:
        for _ in range(200):  # Write 200MB total
            f.write(chunk_data)

    stage_name = random_string(5, "test_put_md5_")
    # Use the async connection for PUT/LS operations
    await aio_connection.connect()
    async with aio_connection.cursor() as cur:
        await cur.execute(f"create temporary stage {stage_name}")

        # Upload both files in sequence
        small_filename_in_put = str(small_test_file).replace("\\", "/")
        big_filename_in_put = str(big_test_file).replace("\\", "/")

        await cur.execute(
            f"PUT 'file://{small_filename_in_put}' @{stage_name}/small AUTO_COMPRESS = FALSE"
        )
        await cur.execute(
            f"PUT 'file://{big_filename_in_put}' @{stage_name}/big AUTO_COMPRESS = FALSE"
        )

        file_list = await (await cur.execute(f"LS @{stage_name}")).fetchall()
        # all() over an empty listing is vacuously true, so first make sure both
        # uploads are actually listed before checking their MD5 column.
        assert (
            len(file_list) == 2
        ), f"Expected both uploaded files in stage listing, got: {file_list}"

        # Verify MD5 is populated for both files
        assert all(
            file_info[2] is not None for file_info in file_list
        ), "MD5 should be populated for all uploaded files"
