/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cassert>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>

#include <kvikio/defaults.hpp>
#include <kvikio/error.hpp>
#include <kvikio/parallel_operation.hpp>
#include <kvikio/posix_io.hpp>
#include <kvikio/utils.hpp>

namespace kvikio {
namespace detail {

/**
 * @brief Host-side staging ("bounce") buffer in pinned host memory for host-to-device writes.
 *
 * Incoming host data is accumulated in a host buffer and copied to the output device buffer
 * in batches, either when the host buffer cannot hold the next write or when the object is
 * destroyed.
 *
 * @note Is not thread-safe.
 */
class BounceBufferH2D {
  CUstream _stream;                 // Stream on which every host-to-device copy is issued.
  CUdeviceptr _dev;                 // Start of the output device buffer.
  AllocRetain::Alloc _host_buffer;  // Host buffer that incoming writes are staged in.
  std::ptrdiff_t _dev_offset{0};    // Bytes copied to `_dev` so far.
  std::ptrdiff_t _host_offset{0};   // Bytes staged in `_host_buffer` (back to zero after a flush).

 public:
  /**
   * @brief Set up a bounce buffer in front of an output device buffer.
   *
   * @param stream The CUDA stream used for all copies made by this bounce buffer.
   * @param device_buffer The output device buffer (where the data finally ends up).
   */
  BounceBufferH2D(CUstream stream, void* device_buffer)
    : _stream{stream},
      _dev{convert_void2deviceptr(device_buffer)},
      _host_buffer{AllocRetain::instance().get()}
  {
  }

  /**
   * @brief Flush whatever is still staged to the device.
   *
   * A `CUfileException` raised by the final flush is reported on stderr instead of being
   * propagated out of the destructor.
   */
  ~BounceBufferH2D() noexcept
  {
    try {
      flush();
    } catch (CUfileException const& err) {
      std::cerr << "BounceBufferH2D error on final flush: " << err.what() << std::endl;
    }
  }

  /**
   * @brief Append host memory to the bounce buffer (also host memory).
   *
   * The staged data is pushed to the output device buffer only when the bounce buffer
   * cannot take the new data.
   *
   * @param data The host memory source.
   * @param size Number of bytes to write.
   */
  void write(char const* data, std::size_t size)
  {
    // Make room first: without enough free space the staged bytes go to the device now.
    bool const lacks_room = _host_buffer.size() - _host_offset < size;
    if (lacks_room) {
      flush();
      assert(_host_offset == 0);
    }

    // A write larger than the whole bounce buffer bypasses staging altogether. This only
    // happens when `defaults::bounce_buffer_size()` is smaller than 16kb, so this path is not
    // optimized for performance.
    if (_host_buffer.size() < size) {
      write_to_device(data, size);
      return;
    }

    if (size == 0) { return; }
    std::memcpy(_host_buffer.get(_host_offset), data, size);
    _host_offset += size;
  }

 private:
  /**
   * @brief Copy host memory to the next free position of the output device buffer.
   *
   * The stream is synchronized before returning, and the device offset is advanced by `size`.
   *
   * @param src The host memory source.
   * @param size Number of bytes to copy; nothing happens when zero.
   */
  void write_to_device(void const* src, std::size_t size)
  {
    if (size == 0) { return; }
    CUDA_DRIVER_TRY(cudaAPI::instance().MemcpyHtoDAsync(_dev + _dev_offset, src, size, _stream));
    CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(_stream));
    _dev_offset += size;
  }

  /**
   * @brief Copy all staged bytes to the output device buffer and mark the bounce buffer empty.
   */
  void flush()
  {
    write_to_device(_host_buffer.get(), _host_offset);
    _host_offset = 0;
  }
};

}  // namespace detail

class CurlHandle;  // Forward declaration; only used by reference in the declarations below.

/**
 * @brief Abstract base class for remote endpoints.
 *
 * In this context, an endpoint refers to a remote file that is accessed using a specific
 * communication protocol.
 *
 * Each communication protocol, such as HTTP or S3, derives from this class, overrides its
 * pure virtual member functions, and provides its own constructor taking the
 * protocol-specific arguments.
 */
class RemoteEndpoint {
 public:
  /**
   * @brief Set the connection options needed by this endpoint on a curl handle.
   *
   * Subsequently, a call to `curl.perform()` should connect to the endpoint.
   *
   * @param curl The curl handle to configure.
   */
  virtual void setopt(CurlHandle& curl) = 0;

  /**
   * @brief Get a description of this remote endpoint instance.
   *
   * @returns A string description.
   */
  virtual std::string str() const = 0;

  // Virtual so that derived endpoints can be destroyed through a `RemoteEndpoint` pointer.
  virtual ~RemoteEndpoint() = default;
};

/**
 * @brief A remote endpoint using http.
 */
class HttpEndpoint : public RemoteEndpoint {
 private:
  std::string _url;  // The url given at construction.

 public:
  /**
   * @brief Create an http endpoint from a url.
   *
   * The url is stored as given; no validation is performed here.
   *
   * @param url The full http url to the remote file.
   */
  HttpEndpoint(std::string url) : _url{std::move(url)} {}
  // Implements `RemoteEndpoint::setopt()`; the definition is not part of this section.
  void setopt(CurlHandle& curl) override;
  // The description of an http endpoint is simply its url.
  std::string str() const override { return _url; }
  ~HttpEndpoint() override = default;
};

/**
 * @brief A remote endpoint using AWS's S3 protocol.
 */
class S3Endpoint : public RemoteEndpoint {
 private:
  std::string _url;          // Full http(s) url of the S3 object.
  std::string _aws_sigv4;    // Value built for the CURLOPT_AWS_SIGV4 option.
  std::string _aws_userpwd;  // Value built for the CURLOPT_USERPWD option.

  /**
   * @brief Unwrap an optional parameter, obtaining a default from the environment.
   *
   * If not nullopt, the optional's value is returned. Otherwise, the environment
   * variable `env_var` is used. If that also doesn't have a value:
   *   - if `err_msg` is empty, the empty string is returned.
   *   - if `err_msg` is not empty, `std::invalid_argument(`err_msg`)` is thrown.
   *
   * @throws std::invalid_argument if neither `aws_arg` nor `env_var` provides a value and
   * `err_msg` is not empty.
   *
   * @param aws_arg The value to unwrap.
   * @param env_var The name of the environment variable to check if `aws_arg` isn't set.
   * @param err_msg The error message to throw on error or the empty string.
   * @return The parsed AWS argument or the empty string.
   */
  static std::string unwrap_or_default(std::optional<std::string> aws_arg,
                                       std::string const& env_var,
                                       std::string const& err_msg = "")
  {
    // `aws_arg` is a by-value parameter, so its content can be moved out safely.
    if (aws_arg.has_value()) { return std::move(*aws_arg); }

    char const* env = std::getenv(env_var.c_str());
    if (env == nullptr) {
      if (err_msg.empty()) { return std::string(); }
      throw std::invalid_argument(err_msg);
    }
    return std::string(env);
  }

 public:
  /**
   * @brief Get url from a AWS S3 bucket and object name.
   *
   * @throws std::invalid_argument if no endpoint url is available, no region is specified,
   * and no default region is specified in the environment.
   *
   * @param bucket_name The name of the S3 bucket.
   * @param object_name The name of the S3 object.
   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
   * `AWS_DEFAULT_REGION` environment variable is used. Only consulted when no endpoint url
   * override is available.
   * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
   * the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
   * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
   * url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
   * @return The http(s) url of the object.
   */
  static std::string url_from_bucket_and_object(std::string const& bucket_name,
                                                std::string const& object_name,
                                                std::optional<std::string> const& aws_region,
                                                std::optional<std::string> aws_endpoint_url)
  {
    auto const endpoint_url = unwrap_or_default(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL");
    std::stringstream ss;
    if (endpoint_url.empty()) {
      // `aws_region` is a const reference and therefore copied into the helper; wrapping it in
      // `std::move()` would not avoid the copy.
      auto const region =
        unwrap_or_default(aws_region,
                          "AWS_DEFAULT_REGION",
                          "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");
      // We default to the official AWS url scheme.
      ss << "https://" << bucket_name << ".s3." << region << ".amazonaws.com/" << object_name;
    } else {
      ss << endpoint_url << "/" << bucket_name << "/" << object_name;
    }
    return ss.str();
  }

  /**
   * @brief Given an url like "s3://<bucket>/<object>", return the name of the bucket and object.
   *
   * The "s3" scheme is matched case-insensitively. The bucket name is everything up to the
   * first "/" after the scheme; the object name is everything after that "/".
   *
   * @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name.
   *
   * @param s3_url S3 url.
   * @return Pair of strings: [bucket-name, object-name].
   */
  [[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url)
  {
    // Regular expression to match s3://<bucket>/<object>
    std::regex const pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase};
    std::smatch matches;
    if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; }
    throw std::invalid_argument("Input string does not match the expected S3 URL format.");
  }

  /**
   * @brief Create a S3 endpoint from a url.
   *
   * @throws std::invalid_argument if `url` doesn't start with "http://" or "https://", or if
   * the region, the access key, or the secret access key is provided neither as an argument
   * nor through its environment variable.
   *
   * @param url The full http url to the S3 file. NB: this should be an url starting with
   * "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>", please
   * use `S3Endpoint::parse_s3_url()` and `S3Endpoint::url_from_bucket_and_object()` to convert
   * it.
   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
   * `AWS_DEFAULT_REGION` environment variable is used.
   * @param aws_access_key The AWS access key to use. If nullopt, the value of the
   * `AWS_ACCESS_KEY_ID` environment variable is used.
   * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
   * `AWS_SECRET_ACCESS_KEY` environment variable is used.
   */
  S3Endpoint(std::string url,
             std::optional<std::string> aws_region            = std::nullopt,
             std::optional<std::string> aws_access_key        = std::nullopt,
             std::optional<std::string> aws_secret_access_key = std::nullopt)
    : _url{std::move(url)}
  {
    // Regular expression to match http[s]://
    std::regex const pattern{R"(^https?://.*)", std::regex_constants::icase};
    if (!std::regex_search(_url, pattern)) {
      throw std::invalid_argument("url must start with http:// or https://");
    }

    auto const region =
      unwrap_or_default(std::move(aws_region),
                        "AWS_DEFAULT_REGION",
                        "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");

    auto const access_key =
      unwrap_or_default(std::move(aws_access_key),
                        "AWS_ACCESS_KEY_ID",
                        "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set.");

    auto const secret_access_key = unwrap_or_default(
      std::move(aws_secret_access_key),
      "AWS_SECRET_ACCESS_KEY",
      "S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set.");

    // Create the CURLOPT_AWS_SIGV4 option
    _aws_sigv4 = "aws:amz:" + region + ":s3";

    // Create the CURLOPT_USERPWD option
    // Notice, curl uses `secret_access_key` to generate a AWS V4 signature. It is NOT included
    // in the http header. See
    // <https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html>
    _aws_userpwd = access_key + ":" + secret_access_key;
  }

  /**
   * @brief Create a S3 endpoint from a bucket and object name.
   *
   * The url is built with `url_from_bucket_and_object()` and the remaining work is delegated
   * to the url-based constructor.
   *
   * @throws std::invalid_argument if the region, the access key, or the secret access key is
   * provided neither as an argument nor through its environment variable, or if the resulting
   * url doesn't start with "http://" or "https://".
   *
   * @param bucket_name The name of the S3 bucket.
   * @param object_name The name of the S3 object.
   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
   * `AWS_DEFAULT_REGION` environment variable is used.
   * @param aws_access_key The AWS access key to use. If nullopt, the value of the
   * `AWS_ACCESS_KEY_ID` environment variable is used.
   * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
   * `AWS_SECRET_ACCESS_KEY` environment variable is used.
   * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
   * the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
   * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
   * url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
   */
  S3Endpoint(std::string const& bucket_name,
             std::string const& object_name,
             std::optional<std::string> aws_region            = std::nullopt,
             std::optional<std::string> aws_access_key        = std::nullopt,
             std::optional<std::string> aws_secret_access_key = std::nullopt,
             std::optional<std::string> aws_endpoint_url      = std::nullopt)
    // `aws_region` is needed by two arguments of the delegated call, and the order in which
    // those arguments are initialized is unspecified. It is therefore passed by copy in both
    // places: moving it in one argument could leave the other one reading a moved-from value.
    : S3Endpoint(url_from_bucket_and_object(
                   bucket_name, object_name, aws_region, std::move(aws_endpoint_url)),
                 aws_region,
                 std::move(aws_access_key),
                 std::move(aws_secret_access_key))
  {
  }

  // Implements `RemoteEndpoint::setopt()`; the definition is not part of this section.
  void setopt(CurlHandle& curl) override;
  // The description of a S3 endpoint is its http(s) url.
  std::string str() const override { return _url; }
  ~S3Endpoint() override = default;
};

/**
 * @brief Handle of a remote file.
 *
 * Owns the endpoint used to reach the file and stores the file size.
 */
class RemoteHandle {
 private:
  std::unique_ptr<RemoteEndpoint> _endpoint;  // Endpoint owned by this handle.
  std::size_t _nbytes;                        // Size of the remote file in bytes.

 public:
  /**
   * @brief Create a new remote handle from an endpoint and a file size.
   *
   * The given size is stored as is; the remote server is not contacted by this constructor.
   *
   * @param endpoint Remote endpoint used for subsequent IO.
   * @param nbytes The size of the remote file (in bytes).
   */
  RemoteHandle(std::unique_ptr<RemoteEndpoint> endpoint, std::size_t nbytes)
    : _endpoint{std::move(endpoint)}, _nbytes{nbytes}
  {
  }

  /**
   * @brief Create a new remote handle from an endpoint (infers the file size).
   *
   * The file size is received from the remote server using `endpoint`.
   *
   * @param endpoint Remote endpoint used for subsequent IO.
   */
  RemoteHandle(std::unique_ptr<RemoteEndpoint> endpoint);

  // A remote handle is moveable but not copyable.
  RemoteHandle(RemoteHandle&& o)               = default;
  RemoteHandle& operator=(RemoteHandle&& o)    = default;
  RemoteHandle(RemoteHandle const&)            = delete;
  RemoteHandle& operator=(RemoteHandle const&) = delete;

  /**
   * @brief Get the file size.
   *
   * This returns the size stored in the handle, so no communication is needed.
   *
   * @return The number of bytes.
   */
  [[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; }

  /**
   * @brief Get a const reference to the underlying remote endpoint.
   *
   * NOTE(review): `_endpoint` is dereferenced unconditionally; presumably this must not be
   * called on a moved-from handle -- confirm.
   *
   * @return The remote endpoint.
   */
  [[nodiscard]] RemoteEndpoint const& endpoint() const noexcept { return *_endpoint; }

  /**
   * @brief Read from remote source into buffer (host or device memory).
   *
   * When reading into device memory, a bounce buffer is used to avoid many small memory
   * copies to device. Use `kvikio::defaults::bounce_buffer_size_reset()` to set the size
   * of this bounce buffer (default 16 MiB).
   *
   * @param buf Pointer to host or device memory.
   * @param size Number of bytes to read.
   * @param file_offset File offset in bytes.
   * @return Number of bytes read, which is always `size`.
   */
  std::size_t read(void* buf, std::size_t size, std::size_t file_offset = 0);

  /**
   * @brief Read from remote source into buffer (host or device memory) in parallel.
   *
   * This API is a parallel async version of `.read()` that partitions the operation
   * into tasks of size `task_size` for execution in the default thread pool.
   *
   * @param buf Pointer to host or device memory.
   * @param size Number of bytes to read.
   * @param file_offset File offset in bytes.
   * @param task_size Size of each task in bytes. Defaults to `defaults::task_size()`.
   * @return Future that on completion returns the number of bytes read, which is always `size`.
   */
  std::future<std::size_t> pread(void* buf,
                                 std::size_t size,
                                 std::size_t file_offset = 0,
                                 std::size_t task_size   = defaults::task_size());
};

}  // namespace kvikio
