#!/usr/bin/env python
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import os
import pathlib
from getpass import getuser
from logging import getLogger
from os import path

import pytest

try:
    from parameters import CONNECTION_PARAMETERS_ADMIN
except ImportError:
    CONNECTION_PARAMETERS_ADMIN = {}

THIS_DIR = path.dirname(path.realpath(__file__))

logger = getLogger(__name__)


@pytest.fixture()
async def test_data(request, conn_cnx, db_parameters):
    """Provision the load/unload test objects using default connections.

    Args:
        request: pytest ``request`` object, forwarded so that
            ``create_test_data`` can register its cleanup finalizer.
        conn_cnx: Connection factory fixture; called here without overrides.
        db_parameters: Connection parameter mapping forwarded to
            ``create_test_data``.

    Returns:
        The object produced by awaiting ``create_test_data``.
    """

    def connection():
        """Abstracting away connection creation."""
        return conn_cnx()

    # create_test_data is a coroutine function: without ``await`` the tests
    # would receive a bare coroutine object instead of the value whose
    # attributes (``connection``, ``warehouse_name``, ...) they read.
    # NOTE(review): relies on the same async support that runs this module's
    # un-marked ``async def`` tests (presumably pytest-asyncio auto mode) --
    # confirm against the project configuration.
    return await create_test_data(request, db_parameters, connection)


@pytest.fixture()
async def s3_test_data(request, conn_cnx, db_parameters):
    """Provision the load/unload test objects with explicit user/account.

    Args:
        request: pytest ``request`` object, forwarded so that
            ``create_test_data`` can register its cleanup finalizer.
        conn_cnx: Connection factory fixture; called with the ``user`` and
            ``account`` entries of ``db_parameters``.
        db_parameters: Connection parameter mapping; must contain ``"user"``
            and ``"account"``.

    Returns:
        The object produced by awaiting ``create_test_data``.
    """

    def connection():
        """Abstracting away connection creation."""
        return conn_cnx(
            user=db_parameters["user"],
            account=db_parameters["account"],
        )

    # create_test_data is a coroutine function: without ``await`` the tests
    # would receive a bare coroutine object instead of the value whose
    # attributes (``connection``, ``stage_name``, ...) they read.
    # NOTE(review): relies on the same async support that runs this module's
    # un-marked ``async def`` tests (presumably pytest-asyncio auto mode) --
    # confirm against the project configuration.
    return await create_test_data(request, db_parameters, connection)


async def create_test_data(request, db_parameters, connection):
    """Create the warehouse, database, schema and file format used by tests.

    Args:
        request: Object exposing ``addfinalizer``; a synchronous finalizer
            that drops the database and the warehouse is registered on it.
        db_parameters: Mapping whose ``"name"`` entry is the prefix for the
            database, warehouse and stage names.
        connection: Zero-argument callable returning an async context manager
            that yields a connection object with a ``cursor()`` method.

    Returns:
        A ``TestData`` instance carrying the generated object names, the
        single-quoted AWS credentials read from the environment, the local
        data directory, the S3 user bucket and the ``connection`` factory.

    Raises:
        AssertionError: If ``AWS_ACCESS_KEY_ID`` or ``AWS_SECRET_ACCESS_KEY``
            is not set in the environment.
    """
    # Function-scope import: only the finalizer below needs asyncio.
    import asyncio

    assert "AWS_ACCESS_KEY_ID" in os.environ, "AWS_ACCESS_KEY_ID is missing"
    assert "AWS_SECRET_ACCESS_KEY" in os.environ, "AWS_SECRET_ACCESS_KEY is missing"

    unique_name = db_parameters["name"]
    database_name = f"{unique_name}_db"
    warehouse_name = f"{unique_name}_wh"

    async def drop_test_objects():
        """Drop the database and warehouse created for this test."""
        async with connection() as cnx:
            async with cnx.cursor() as cur:
                await cur.execute(f"drop database {database_name}")
                await cur.execute(f"drop warehouse {warehouse_name}")

    def fin():
        """Run the asynchronous cleanup to completion."""
        # Finalizers are invoked as plain callables and their return value is
        # discarded, so registering the coroutine function directly would only
        # create a coroutine that is never awaited and nothing would be
        # dropped. Drive it on a private loop instead, leaving whatever loop
        # the test runner manages untouched.
        # NOTE(review): assumes no event loop is running in this thread when
        # the finalizer is called -- confirm for the async plugin in use.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(drop_test_objects())
        finally:
            loop.close()

    request.addfinalizer(fin)

    class TestData:
        def __init__(self):
            # Local "data" directory that sits next to this module.
            self.test_data_dir = (pathlib.Path(__file__).parent / "data").absolute()
            # Credentials are wrapped in single quotes so they can be pasted
            # straight into SQL statements.
            self.AWS_ACCESS_KEY_ID = "'{}'".format(os.environ["AWS_ACCESS_KEY_ID"])
            self.AWS_SECRET_ACCESS_KEY = "'{}'".format(
                os.environ["AWS_SECRET_ACCESS_KEY"]
            )
            self.stage_name = f"{unique_name}_stage"
            self.warehouse_name = warehouse_name
            self.database_name = database_name
            self.connection = connection
            # SF_AWS_USER_BUCKET overrides the per-user default location.
            self.user_bucket = os.getenv(
                "SF_AWS_USER_BUCKET", f"sfc-eng-regression/{getuser()}/reg"
            )

    ret = TestData()

    async with connection() as cnx:
        async with cnx.cursor() as cur:
            await cur.execute("use role sysadmin")
            await cur.execute(
                """
create or replace warehouse {}
warehouse_size = 'small' warehouse_type='standard'
auto_suspend=1800
""".format(
                    warehouse_name
                )
            )
            await cur.execute(
                """
create or replace database {}
""".format(
                    database_name
                )
            )
            await cur.execute(
                """
create or replace schema pytesting_schema
"""
            )
            await cur.execute(
                """
create or replace file format VSV type = 'CSV'
field_delimiter='|' error_on_column_count_mismatch=false
    """
            )
    return ret


@pytest.mark.skipif(
    not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible."
)
async def test_load_s3(test_data):
    """Copy a tweets file from S3 into a freshly created table.

    Creates the ``tweets`` table, checks that listing its table stage reports
    no rows, runs ``copy into`` from the S3 location using the AWS credentials
    carried by ``test_data``, checks the first reported file name and finally
    drops the table.
    """
    create_tweets_sql = """
create or replace table tweets(created_at timestamp,
id number, id_str string, text string, source string,
in_reply_to_status_id number, in_reply_to_status_id_str string,
in_reply_to_user_id number, in_reply_to_user_id_str string,
in_reply_to_screen_name string, user__id number, user__id_str string,
user__name string, user__screen_name string, user__location string,
user__description string, user__url string,
user__entities__description__urls string, user__protected string,
user__followers_count number, user__friends_count number,
user__listed_count number, user__created_at timestamp,
user__favourites_count number, user__utc_offset number,
user__time_zone string, user__geo_enabled string, user__verified string,
user__statuses_count number, user__lang string,
user__contributors_enabled string, user__is_translator string,
user__profile_background_color string,
user__profile_background_image_url string,
user__profile_background_image_url_https string,
user__profile_background_tile string, user__profile_image_url string,
user__profile_image_url_https string, user__profile_link_color string,
user__profile_sidebar_border_color string,
user__profile_sidebar_fill_color string, user__profile_text_color string,
user__profile_use_background_image string, user__default_profile string,
user__default_profile_image string, user__following string,
user__follow_request_sent string, user__notifications string, geo string,
coordinates string, place string, contributors string, retweet_count number,
favorite_count number, entities__hashtags string, entities__symbols string,
entities__urls string, entities__user_mentions string, favorited string,
retweeted string, lang string)
"""
    async with test_data.connection() as conn, conn.cursor() as cursor:
        # Select the warehouse and schema prepared by the fixture.
        await cursor.execute("use warehouse {}".format(test_data.warehouse_name))
        await cursor.execute(
            "use schema {}.pytesting_schema".format(test_data.database_name)
        )
        await cursor.execute(create_tweets_sql)

        # The table stage of a brand-new table must list nothing.
        await cursor.execute("ls @%tweets")
        staged_file_count = cursor.rowcount
        assert (
            staged_file_count == 0
        ), "table newly created should not have any files in its staging area"

        # Load straight from the external S3 location.
        copy_sql = f"""
copy into tweets from s3://sfc-eng-data/twitter/O1k/tweets/
credentials=(AWS_KEY_ID={test_data.AWS_ACCESS_KEY_ID}
AWS_SECRET_KEY={test_data.AWS_SECRET_ACCESS_KEY})
file_format=(skip_header=1 null_if=('') field_optionally_enclosed_by='"')
"""
        await cursor.execute(copy_sql)
        assert cursor.rowcount == 1, "copy into tweets did not set rowcount to 1"

        copied = await cursor.fetchall()
        first_file = copied[0][0]
        expected_file = "s3://sfc-eng-data/twitter/O1k/tweets/1.csv.gz"
        assert first_file == expected_file, "ls @%tweets failed"

        await cursor.execute("drop table tweets")


@pytest.mark.skipif(
    not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible."
)
async def test_put_local_file(test_data):
    """PUT local order files into a table stage and copy them into the table.

    The table is created with an S3 stage location, the local
    ``orders_10*.csv`` files are uploaded to its table stage, and the test
    then checks the listing, the copy result, the row count, the removal of
    the staged files and the load history before dropping the table.
    """
    async with test_data.connection() as conn, conn.cursor() as cursor:
        run = cursor.execute

        await run("alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false")
        await run("use warehouse {}".format(test_data.warehouse_name))
        await run("use schema {}.pytesting_schema".format(test_data.database_name))

        # Table whose stage lives under the regression bucket on S3.
        create_table_sql = f"""
create or replace table pytest_putget_t1 (c1 STRING, c2 STRING, c3 STRING,
c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING)
stage_file_format = (field_delimiter = '|' error_on_column_count_mismatch=false)
stage_copy_options = (purge=false)
stage_location = (url = 's3://sfc-eng-regression/jenkins/{test_data.stage_name}'
credentials = (
AWS_KEY_ID={test_data.AWS_ACCESS_KEY_ID}
AWS_SECRET_KEY={test_data.AWS_SECRET_ACCESS_KEY}))
"""
        await run(create_table_sql)

        # Upload the local files into the table stage.
        put_sql = (
            f"put file://{test_data.test_data_dir!s}"
            "/ExecPlatform/Database/data/orders_10*.csv @%pytest_putget_t1"
        )
        await run(put_sql)

        # Listing is fetched only to drain the result; rowcount is checked.
        await run("ls @%pytest_putget_t1")
        await cursor.fetchall()
        assert cursor.rowcount == 2, "ls @%pytest_putget_t1 did not return 2 rows"

        await run("copy into pytest_putget_t1")
        copy_rows = await cursor.fetchall()
        assert len(copy_rows) == 2, "2 files were not copied"
        first_status, second_status = copy_rows[0][1], copy_rows[1][1]
        assert first_status == "LOADED", "file 1 was not loaded after copy"
        assert second_status == "LOADED", "file 2 was not loaded after copy"

        await run("select count(*) from pytest_putget_t1")
        count_rows = await cursor.fetchall()
        assert count_rows[0][0] == 73, "73 rows not loaded into putest_putget_t1"

        await run("rm @%pytest_putget_t1")
        removed_rows = await cursor.fetchall()
        assert len(removed_rows) == 2, "two files were not removed"

        await run(
            "select STATUS from information_schema.load_history "
            "where table_name='PYTEST_PUTGET_T1'"
        )
        history_rows = await cursor.fetchall()
        assert (
            history_rows[0][0] == "LOADED"
        ), "history does not show file to be loaded"

        await run("drop table pytest_putget_t1")


@pytest.mark.skipif(
    not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible."
)
async def test_put_load_from_user_stage(test_data):
    """PUT local order files into a named S3 stage and copy them with purge.

    Creates a named stage under the user bucket and an empty table, uploads
    the local ``orders_10*.csv`` files to the stage, copies them into the
    table with ``purge=true``, checks the two loaded file names and verifies
    that the stage is empty afterwards. The table and stage are then dropped.
    """
    stage = test_data.stage_name
    async with test_data.connection() as conn, conn.cursor() as cursor:
        await cursor.execute(
            "alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false"
        )
        await cursor.execute(f"\nuse warehouse {test_data.warehouse_name}\n")
        await cursor.execute(
            f"\nuse schema {test_data.database_name}.pytesting_schema\n"
        )

        create_stage_sql = f"""
create or replace stage {stage}
url='s3://{test_data.user_bucket}/{stage}'
credentials = (
AWS_KEY_ID={test_data.AWS_ACCESS_KEY_ID}
AWS_SECRET_KEY={test_data.AWS_SECRET_ACCESS_KEY})
"""
        await cursor.execute(create_stage_sql)

        create_table_sql = """
create or replace table pytest_putget_t2 (c1 STRING, c2 STRING, c3 STRING,
c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING)
"""
        await cursor.execute(create_table_sql)

        put_sql = (
            f"put file://{test_data.test_data_dir}"
            f"/ExecPlatform/Database/data/orders_10*.csv @{stage}"
        )
        await cursor.execute(put_sql)
        # Two files should have been put in the staging area.
        uploaded = await cursor.fetchall()
        assert len(uploaded) == 2

        # Nothing has been put into the table's own stage.
        await cursor.execute("ls @%pytest_putget_t2")
        table_stage_files = await cursor.fetchall()
        assert len(table_stage_files) == 0, "no files should have been loaded yet"

        # Copy from the named stage, purging the files once loaded.
        copy_sql = f"""
copy into pytest_putget_t2 from @{stage}
file_format = (field_delimiter = '|' error_on_column_count_mismatch=false)
purge=true
"""
        await cursor.execute(copy_sql)
        loaded = sorted(await cursor.fetchall())
        assert len(loaded) == 2, "copy failed to load two files from the stage"

        stage_url = f"s3://{test_data.user_bucket}/{stage}"
        first_loaded, second_loaded = loaded[0][0], loaded[1][0]
        assert (
            first_loaded == stage_url + "/orders_100.csv.gz"
        ), "copy did not load file orders_100"
        assert (
            second_loaded == stage_url + "/orders_101.csv.gz"
        ), "copy did not load file orders_101"

        # The stage should be empty again because of purge=true.
        await cursor.execute("ls @{}".format(stage))
        leftover = await cursor.fetchall()
        assert len(leftover) == 0, "copied files not purged"

        await cursor.execute("drop table pytest_putget_t2")
        await cursor.execute("drop stage {}".format(stage))


@pytest.mark.aws
@pytest.mark.skipif(
    not CONNECTION_PARAMETERS_ADMIN, reason="Snowflake admin account is not accessible."
)
async def test_unload(db_parameters, s3_test_data):
    """Unload a table to an S3 stage and load it back, with gzip and deflate.

    Loads the local ``orders_10*.csv`` files into ``pytest_t3``, unloads the
    table to the named stage, reloads the unloaded data into
    ``pytest_t3_copy`` and checks that both tables hold the same rows. The
    round trip is done once with gzip and once with deflate compression.
    ``db_parameters`` is requested as a fixture but not used in the body.
    """
    stage = s3_test_data.stage_name
    # The same removal statement is issued after each round trip.
    remove_unloaded_sql = f"rm @{stage}/pytest_t3/data_"

    async with s3_test_data.connection() as conn, conn.cursor() as cursor:
        await cursor.execute("use warehouse {}".format(s3_test_data.warehouse_name))
        await cursor.execute(
            "use schema {}.pytesting_schema".format(s3_test_data.database_name)
        )

        create_stage_sql = f"""
create or replace stage {stage}
url='s3://{s3_test_data.user_bucket}/{stage}/unload/'
credentials = (
AWS_KEY_ID={s3_test_data.AWS_ACCESS_KEY_ID}
AWS_SECRET_KEY={s3_test_data.AWS_SECRET_ACCESS_KEY})
"""
        await cursor.execute(create_stage_sql)

        create_source_sql = """
CREATE OR REPLACE TABLE pytest_t3  (c1 STRING, c2 STRING, c3 STRING,
c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING, c9 STRING)
stage_file_format = (format_name = 'vsv' field_delimiter = '|'
error_on_column_count_mismatch=false)
"""
        await cursor.execute(create_source_sql)

        await cursor.execute(
            f"""
alter stage {stage} set file_format = (format_name = 'VSV' )
"""
        )

        # Start from an empty stage.
        await cursor.execute("rm @{}".format(stage))

        # Upload the local files to the source table's stage.
        put_sql = (
            f"put file://{s3_test_data.test_data_dir}"
            "/ExecPlatform/Database/data/orders_10*.csv @%pytest_t3"
        )
        await cursor.execute(put_sql)

        # Load them into the source table.
        load_source_sql = """
copy into pytest_t3
file_format = (field_delimiter = '|' error_on_column_count_mismatch=false)
purge=true
"""
        await cursor.execute(load_source_sql)

        # Round trip 1: unload with gzip compression.
        await cursor.execute(
            f"""
copy into @{stage}/pytest_t3/data_
from pytest_t3 file_format=(format_name='VSV' compression='gzip')
max_file_size=10000000
"""
        )

        # Target table for the reloaded data.
        create_copy_sql = """
CREATE OR REPLACE TABLE pytest_t3_copy
(c1 STRING, c2 STRING, c3 STRING, c4 STRING, c5 STRING,
c6 STRING, c7 STRING, c8 STRING, c9 STRING)
stage_file_format = (format_name = 'VSV' )
"""
        await cursor.execute(create_copy_sql)

        await cursor.execute(
            f"""
copy into pytest_t3_copy
from @{stage}/pytest_t3/data_ return_failed_only=true
"""
        )

        # The symmetric difference of both tables must be empty.
        symmetric_diff_sql = """
(select * from pytest_t3 minus select * from pytest_t3_copy)
union
(select * from pytest_t3_copy minus select * from pytest_t3)
"""
        await cursor.execute(symmetric_diff_sql)
        assert cursor.rowcount == 0, "unloaded/reloaded data were not the same"

        # Remove the unloaded file from the stage.
        await cursor.execute(remove_unloaded_sql)
        assert cursor.rowcount == 1, "only one file was expected to be removed"

        # Round trip 2: unload with deflate compression.
        await cursor.execute(
            f"""
copy into @{stage}/pytest_t3/data_
from pytest_t3 file_format=(format_name='VSV' compression='deflate')
max_file_size=10000000
"""
        )
        unload_rows = await cursor.fetchall()
        assert unload_rows[0][0] == 73, "73 rows were expected to be loaded"

        # Recreate the target table with a deflate-aware stage file format.
        recreate_copy_sql = """
CREATE OR REPLACE TABLE pytest_t3_copy
(c1 STRING, c2 STRING, c3 STRING, c4 STRING, c5 STRING, c6 STRING,
c7 STRING, c8 STRING, c9 STRING)
stage_file_format = (format_name = 'VSV'
compression='deflate')
"""
        await cursor.execute(recreate_copy_sql)
        create_rows = await cursor.fetchall()
        assert create_rows[0][0] == "Table PYTEST_T3_COPY successfully created."

        await cursor.execute(
            f"""
alter stage {stage} set file_format = (format_name = 'VSV'
     compression='deflate')"""
        )

        await cursor.execute(
            f"""
copy into pytest_t3_copy from @{stage}/pytest_t3/data_
return_failed_only=true
"""
        )
        reload_rows = await cursor.fetchall()
        first_reload = reload_rows[0]
        assert first_reload[2] == "LOADED"
        assert first_reload[4] == 73

        # Both tables must again hold the same rows.
        await cursor.execute(
            """
(select * from pytest_t3 minus select * from pytest_t3_copy) union
(select * from pytest_t3_copy minus select * from pytest_t3)"""
        )
        assert cursor.rowcount == 0, "unloaded/reloaded data were not the same"

        await cursor.execute(remove_unloaded_sql)
        assert cursor.rowcount == 1, "only one file was expected to be removed"

        # Final clean-up of the stage path, then drop the created objects.
        await cursor.execute(remove_unloaded_sql)

        await cursor.execute("drop table pytest_t3_copy")
        await cursor.execute("drop stage {}".format(stage))
