#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
"""
Scala UDF utilities for Snowpark Connect.

This module provides utilities for creating and managing Scala User-Defined Functions (UDFs)
in Snowflake through Snowpark Connect. It handles the conversion between different type systems
(Snowpark, Scala, Snowflake, Spark protobuf) and generates the necessary SQL DDL statements
for UDF creation.

Key components:
- ScalaUdf: Reference class for Scala UDFs
- ScalaUDFDef: Definition class for Scala UDF creation
- Type mapping functions for different type systems
- UDF creation and management utilities
"""
import re
from dataclasses import dataclass
from typing import List, Union

import snowflake.snowpark.types as snowpark_type
import snowflake.snowpark_connect.includes.python.pyspark.sql.connect.proto.types_pb2 as types_proto
from snowflake.snowpark_connect.error.error_codes import ErrorCodes
from snowflake.snowpark_connect.error.error_utils import attach_custom_error_code
from snowflake.snowpark_connect.type_mapping import map_type_to_snowflake_type
from snowflake.snowpark_connect.utils.jvm_udf_utils import (
    NullHandling,
    Param,
    ReturnType,
    Signature,
    build_jvm_udxf_imports,
)
from snowflake.snowpark_connect.utils.snowpark_connect_logging import logger
from snowflake.snowpark_connect.utils.udf_utils import (
    ProcessCommonInlineUserDefinedFunction,
)

# Prefix prepended to every internally generated Scala UDF name (see create_scala_udf) so that
# generated functions do not collide with other function names.
CREATE_SCALA_UDF_PREFIX = "__SC_BUILD_IN_CREATE_UDF_SCALA_"


class ScalaUdf:
    """
    Lightweight handle to a Scala UDF that has already been created in Snowflake.

    It mirrors the attribute names exposed by the Python UserDefinedFunction
    (`name`, `_input_types`, `_return_type`) so that code calling a UDF can treat
    both kinds of reference uniformly. It holds metadata only; it does not create
    or execute anything.
    """

    def __init__(
        self,
        name: str,
        input_types: List[snowpark_type.DataType],
        return_type: snowpark_type.DataType,
    ) -> None:
        """
        Initialize a Scala UDF reference.

        Args:
            name: The name of the UDF in Snowflake
            input_types: List of input parameter types
            return_type: The return type of the UDF
        """
        self.name = name
        self._input_types = input_types
        self._return_type = return_type

    def __repr__(self) -> str:
        """Return a debugging representation that shows the stored metadata."""
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"input_types={self._input_types!r}, "
            f"return_type={self._return_type!r})"
        )


@dataclass(frozen=True)
class ScalaUDFDef:
    """
    Complete definition for creating a Scala UDF in Snowflake.

    Contains all the information needed to generate the CREATE FUNCTION SQL statement
    and the Scala code body for the UDF.

    Attributes:
        name: UDF name. The generated Scala code also deserializes the UDF payload from "<name>.bin".
        signature: SQL signature (for Snowflake function definition)
        scala_signature: Scala signature (for Scala code generation)
        scala_invocation_args: Scala expressions, one per parameter, that are passed to the wrapped
            function (e.g. "arg0", or a map-conversion expression for Map parameters)
        imports: List of JAR files to import
        null_handling: Null handling behavior (defaults to RETURNS_NULL_ON_NULL_INPUT)
    """

    name: str
    signature: Signature
    scala_signature: Signature
    scala_invocation_args: List[str]
    imports: List[str]
    null_handling: NullHandling = NullHandling.RETURNS_NULL_ON_NULL_INPUT

    # -------------------- DDL Emitter --------------------

    def _gen_body_scala(self) -> str:
        """
        Generate the Scala code body for the UDF.

        Creates a Scala object (`__RecreatedSparkUdf`) that lazily loads the serialized function
        from "<name>.bin" and exposes `__wrapperFunc`, which invokes that function with
        `scala_invocation_args`.

        Returns:
            String containing the complete Scala code for the UDF body
        """
        # Convert Array to Seq for Scala compatibility in function signatures.
        udf_func_input_types = (
            ", ".join(p.data_type for p in self.scala_signature.params)
        ).replace("Array", "Seq")
        # Create the Scala arguments and input types string: "arg0: Type0, arg1: Type1, ...".
        joined_wrapper_arg_and_input_types_str = ", ".join(
            f"{p.name}: {p.data_type}" for p in self.scala_signature.params
        )
        # This is used in defining the input types for the wrapper function. For Maps to work correctly with Scala UDFs,
        # we need to set the Map types to Map[String, String]. These get cast to the respective original types
        # when the original UDF function is invoked.
        # Key/value type names may be dotted (java.sql.Date, java.sql.Timestamp, java.math.BigDecimal) or carry
        # one level of type arguments (Array[Byte]); a plain `\w+` would silently skip those maps and leave the
        # wrapper declared with a Map type other than Map[String, String].
        wrapper_arg_and_input_types_str = re.sub(
            pattern=r"Map\[[\w.]+(?:\[[\w.]+\])?,\s[\w.]+(?:\[[\w.]+\])?\]",
            repl="Map[String, String]",
            string=joined_wrapper_arg_and_input_types_str,
        )
        invocation_args = ", ".join(self.scala_invocation_args)

        # Cannot directly return a map from a Scala UDF due to issues with non-String values. Snowflake SQL Scala only
        # supports Map[String, String] as input types. Therefore, we convert the map to a JSON string before returning.
        # This is processed as a Variant by SQL.
        udf_func_return_type = self.scala_signature.returns.data_type
        is_map_return = udf_func_return_type.startswith("Map")
        wrapper_return_type = "String" if is_map_return else udf_func_return_type

        # NOTE(review): only the first parameter's type is inspected here, because the check is a prefix test on
        # the comma-joined type list.
        is_variant_input = udf_func_input_types.startswith("Variant")
        # For handling Seq type correctly, ensure that the wrapper function always uses Array as its input and
        # return types (when required) and the wrapped function uses Seq.
        udf_func_return_type = udf_func_return_type.replace("Array", "Seq")
        is_seq_return = udf_func_return_type.startswith("Seq")
        from_variant_imports = ""

        # NOTE(review): the branches below are mutually exclusive, so a Variant input combined with a Map or Seq
        # return type does not go through the fromVariant path.
        # Need to call the map to JSON string converter when a map is returned by the user's function.
        if is_map_return:
            invoke_udf_func = f"write(func({invocation_args}))"
        elif is_seq_return:
            # TODO: SNOW-2339385 Handle Array[T] return types correctly. Currently, only Seq[T] is supported.
            invoke_udf_func = f"func({invocation_args}).toArray"
        elif is_variant_input:
            udf_func_input_types = "Any"
            # When the UDF input is a Variant type (typically for Struct/complex types), we need to:
            # 1. Deserialize the UdfPacket to get the input encoder information
            # 2. Use the encoder to understand the target type structure (e.g., case class fields)
            # 3. Extract values from the Variant object and convert them to the appropriate Scala types
            from_variant_imports = """
    import com.snowflake.sas.scala.UdfPacketUtils._

    """
            invoke_udf_func = f"func(udfPacket.fromVariant({invocation_args}))"
        else:
            invoke_udf_func = f"func({invocation_args})"

        # The lines of code below are required only when a Map is returned by the UDF. This is needed to serialize the
        # map output to a JSON string.
        map_return_imports = (
            ""
            if not is_map_return
            else """
import shaded_json4s._
import shaded_json4s.native.Serialization._
import shaded_json4s.native.Serialization
"""
        )
        map_return_formatter = (
            ""
            if not is_map_return
            else """
  implicit val formats = Serialization.formats(NoTypeHints)
"""
        )

        return f"""import org.apache.spark.sql.connect.common.UdfPacket
    {map_return_imports}
    {from_variant_imports}
    import com.snowflake.sas.scala.Utils
    import com.snowflake.snowpark_java.types.Variant

    object __RecreatedSparkUdf {{
      {map_return_formatter}

      private lazy val udfPacket: UdfPacket = Utils.deserializeUdfPacket("{self.name}.bin")

      private lazy val func: ({udf_func_input_types}) => {udf_func_return_type} = udfPacket.function.asInstanceOf[({udf_func_input_types}) => {udf_func_return_type}]


      def __wrapperFunc({wrapper_arg_and_input_types_str}): {wrapper_return_type} = {{
        {invoke_udf_func}
      }}
    }}
    """

    def to_create_function_sql(self) -> str:
        """
        Generate the complete CREATE FUNCTION SQL statement for the Scala UDF.

        Creates a Snowflake CREATE OR REPLACE TEMPORARY FUNCTION statement with
        all necessary clauses including language, runtime version, packages,
        imports, and the Scala code body.

        Returns:
            Complete SQL DDL statement for creating the UDF
        """
        args = ", ".join(f"{p.name} {p.data_type}" for p in self.signature.params)
        ret_type = self.signature.returns.data_type

        def quote_single(s: str) -> str:
            """Helper function to wrap strings in single quotes for SQL."""
            # NOTE(review): embedded single quotes are not escaped; this assumes import paths never contain one.
            return "'" + s + "'"

        # Handler and imports
        imports_sql = f"IMPORTS = ({', '.join(quote_single(x) for x in self.imports)})"

        return f"""
CREATE OR REPLACE TEMPORARY FUNCTION {self.name}({args})
RETURNS {ret_type}
LANGUAGE SCALA
{self.null_handling.value}
RUNTIME_VERSION = 2.12
PACKAGES = ('com.snowflake:snowpark:latest')
{imports_sql}
HANDLER = '__RecreatedSparkUdf.__wrapperFunc'
AS
$$
{self._gen_body_scala()}
$$;"""


def create_scala_udf(pciudf: ProcessCommonInlineUserDefinedFunction) -> ScalaUdf:
    """
    Create a Scala UDF in Snowflake from a ProcessCommonInlineUserDefinedFunction object.

    Steps performed, in order:
    1. Make sure the Scala UDF jars are uploaded.
    2. Derive the UDF name: the shared prefix plus either the given function name or,
       when none is given, the first ten hex characters of the payload's SHA-256 digest.
    3. Build the Scala parameters, the Snowflake SQL parameters and the Scala invocation
       arguments from the input types.
    4. Map the return type to Scala and to Snowflake SQL; MAP and OBJECT SQL return types
       are declared as VARIANT.
    5. Build the imports list, generate the CREATE OR REPLACE TEMPORARY FUNCTION statement,
       log it and execute it.

    The statement is issued on every call; this function itself performs no cache lookup.

    Args:
        pciudf: The ProcessCommonInlineUserDefinedFunction object containing UDF details.

    Returns:
        A ScalaUdf object referencing the created function, carrying the Snowpark input
        and return types taken from `pciudf`.
    """
    from snowflake.snowpark_connect.resources_initializer import (
        ensure_scala_udf_jars_uploaded,
    )

    # Jars are uploaded on demand, the first time a Scala UDF is actually created.
    # NOTE(review): the helper is expected to be thread-safe and to upload only once - confirm in
    # resources_initializer.
    ensure_scala_udf_jars_uploaded()

    base_name = pciudf._function_name
    if not base_name:
        # No explicit name: derive a short one from the hash of the binary payload.
        import hashlib

        base_name = hashlib.sha256(pciudf._payload).hexdigest()[:10]
    udf_name = CREATE_SCALA_UDF_PREFIX + base_name

    # Prefer the Spark Scala input types (from protobuf), which are populated for UDFs created through
    # `spark.udf.register`. UDFs created with an explicit data type, e.g.
    # udf((i: Long) => (i + 1).toInt, IntegerType), leave that field empty, so fall back to the input
    # types inferred by Snowpark.
    input_types = pciudf._scala_input_types or pciudf._input_types

    scala_params: List[Param] = []
    sql_params: List[Param] = []
    invocation_args: List[str] = []  # expressions handed to the wrapped udf function
    # input_types can be None when no arguments are provided.
    for index, arg_type in enumerate(input_types or ()):
        arg_name = f"arg{index}"
        # Scala side: "arg0: Type0, arg1: Type1, ...".
        scala_params.append(Param(arg_name, map_type_to_scala_type(arg_type)))
        # Snowflake SQL side: "arg0 TYPE0, arg1 TYPE1, ...".
        sql_params.append(Param(arg_name, map_type_to_snowflake_type(arg_type)))
        # Map arguments arrive as MAP[VARCHAR, VARCHAR] and must be cast back to their real types in Scala.
        invocation_args.append(cast_scala_map_args_from_given_type(arg_name, arg_type))

    original_return_type = pciudf._original_return_type
    scala_return_type = map_type_to_scala_type(original_return_type)
    mapped_sql_return_type = map_type_to_snowflake_type(original_return_type)

    from snowflake.snowpark_connect.utils.session import get_or_create_snowpark_session

    session = get_or_create_snowpark_session()
    imports = build_jvm_udxf_imports(
        session,
        pciudf._payload,
        udf_name,
        is_map_return=mapped_sql_return_type.startswith("MAP"),
    )

    # MAP and STRUCT (OBJECT) results are declared as VARIANT because of issues with Scala UDFs.
    if mapped_sql_return_type.startswith(("MAP", "OBJECT")):
        sql_return_type = "VARIANT"
    else:
        sql_return_type = mapped_sql_return_type

    sql_signature = Signature(params=sql_params, returns=ReturnType(sql_return_type))
    scala_signature = Signature(
        params=scala_params, returns=ReturnType(scala_return_type)
    )
    udf_def = ScalaUDFDef(
        name=udf_name,
        signature=sql_signature,
        imports=imports,
        scala_signature=scala_signature,
        scala_invocation_args=invocation_args,
    )

    create_udf_sql = udf_def.to_create_function_sql()
    logger.info(f"Creating Scala UDF: {create_udf_sql}")
    session.sql(create_udf_sql).collect()
    return ScalaUdf(udf_name, pciudf._input_types, pciudf._return_type)


def map_type_to_scala_type(
    t: Union[snowpark_type.DataType, types_proto.DataType]
) -> str:
    """
    Map a Snowpark or Spark protobuf type to a Scala type string.

    Snowpark types are dispatched on their exact class; protobuf types on the name of
    the populated `kind` oneof field. Array and Map types are mapped recursively.

    Args:
        t: The Snowpark or Spark protobuf data type. A falsy value maps to "String".

    Returns:
        The Scala type name, e.g. "Int", "Array[String]" or "Map[String, Long]".

    Raises:
        ValueError: If the type has no Scala mapping (tagged with ErrorCodes.UNSUPPORTED_TYPE).
    """
    if not t:
        return "String"

    from_snowpark = isinstance(t, snowpark_type.DataType)
    kind = type(t) if from_snowpark else t.WhichOneof("kind")

    # Container types first: they recurse into their element types.
    if kind in (snowpark_type.ArrayType, "array"):
        element_type = t.element_type if from_snowpark else t.array.element_type
        return f"Array[{map_type_to_scala_type(element_type)}]"

    if kind in (snowpark_type.MapType, "map"):  # can also map to OBJECT in Snowflake
        key_type = map_type_to_scala_type(
            t.key_type if from_snowpark else t.map.key_type
        )
        value_type = map_type_to_scala_type(
            t.value_type if from_snowpark else t.map.value_type
        )
        return f"Map[{key_type}, {value_type}]"

    # Everything else is a fixed name, keyed by Snowpark class or protobuf kind.
    fixed_names = {
        snowpark_type.BinaryType: "Array[Byte]",
        "binary": "Array[Byte]",
        snowpark_type.BooleanType: "Boolean",
        "boolean": "Boolean",
        snowpark_type.ByteType: "Byte",
        "byte": "Byte",
        snowpark_type.DateType: "java.sql.Date",
        "date": "java.sql.Date",
        snowpark_type.DecimalType: "java.math.BigDecimal",
        "decimal": "java.math.BigDecimal",
        snowpark_type.DoubleType: "Double",
        "double": "Double",
        snowpark_type.FloatType: "Float",
        "float": "Float",
        snowpark_type.GeographyType: "Geography",
        snowpark_type.IntegerType: "Int",
        "integer": "Int",
        snowpark_type.LongType: "Long",
        "long": "Long",
        # cannot set the return type to Null in Snowpark Scala UDFs
        snowpark_type.NullType: "String",
        "null": "String",
        snowpark_type.ShortType: "Short",
        "short": "Short",
        snowpark_type.StringType: "String",
        "string": "String",
        "char": "String",
        "varchar": "String",
        snowpark_type.StructType: "Variant",
        "struct": "Variant",
        snowpark_type.TimestampType: "java.sql.Timestamp",
        "timestamp": "java.sql.Timestamp",
        "timestamp_ntz": "java.sql.Timestamp",
        snowpark_type.VariantType: "Variant",
    }
    scala_name = fixed_names.get(kind)
    if scala_name is not None:
        return scala_name

    exception = ValueError(f"Unsupported Snowpark type: {t}")
    attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
    raise exception


def cast_scala_map_args_from_given_type(
    arg_name: str, input_type: Union[snowpark_type.DataType, types_proto.DataType]
) -> str:
    """If the input_type is a Map or Struct, cast the argument arg_name to the correct type in Scala."""
    is_snowpark_type = isinstance(input_type, snowpark_type.DataType)

    def convert_from_string_to_type(
        arg_name: str, t: Union[snowpark_type.DataType, types_proto.DataType]
    ) -> str:
        """Convert the string argument arg_name to the specified type t in Scala."""
        condition = type(t) if is_snowpark_type else t.WhichOneof("kind")
        match condition:
            case snowpark_type.BinaryType | "binary":
                return arg_name + ".getBytes()"
            case snowpark_type.BooleanType | "boolean":
                return arg_name + ".toBoolean"
            case snowpark_type.ByteType | "byte":
                return arg_name + ".getBytes().head"  # TODO: verify if this is correct
            case snowpark_type.DateType | "date":
                return f"java.sql.Date.valueOf({arg_name})"
            case snowpark_type.DecimalType | "decimal":
                return f"new BigDecimal({arg_name})"
            case snowpark_type.DoubleType | "double":
                return arg_name + ".toDouble"
            case snowpark_type.FloatType | "float":
                return arg_name + ".toFloat"
            case snowpark_type.IntegerType | "integer":
                return arg_name + ".toInt"
            case snowpark_type.LongType | "long":
                return arg_name + ".toLong"
            case snowpark_type.ShortType | "short":
                return arg_name + ".toShort"
            case snowpark_type.StringType | "string" | "char" | "varchar":
                return arg_name
            case snowpark_type.TimestampType | "timestamp" | "timestamp_ntz":
                return f"java.sql.Timestamp.valueOf({arg_name})"  # todo add test
            case _:
                exception = ValueError(f"Unsupported Snowpark type: {t}")
                attach_custom_error_code(exception, ErrorCodes.UNSUPPORTED_TYPE)
                raise exception

    if (is_snowpark_type and isinstance(input_type, snowpark_type.MapType)) or (
        not is_snowpark_type and input_type.WhichOneof("kind") == "map"
    ):
        key_type = input_type.key_type if is_snowpark_type else input_type.map.key_type
        value_type = (
            input_type.value_type if is_snowpark_type else input_type.map.value_type
        )
        return f"{arg_name}.map {{ case (k, v) => ({convert_from_string_to_type('k', key_type)}, {convert_from_string_to_type('v', value_type)})}}"
    else:
        return arg_name
