import logging

from tabulate import tabulate

from repo_cli.utils.constant import CVE_STATUS, LICENSE, PLATFORMS
from repo_cli.utils.policy_helper import (
    add_package_rules,
    create_filter,
    parse_and_check_set,
    remove_package_rules,
    validate_cve_id,
)
from repo_cli.utils.validators import check_cve_score, check_date_format

from ..utils.format import string_to_bool
from .base import SubCommandBase

# Logger for the repo CLI.
logger = logging.getLogger("repo_cli")


# Columns shown, in order, by ``policy --list-all`` (one row per policy).
policy_keys = [
    "name", "id", "created_at", "updated_at",
    "artifact_family", "policy_state", "channel_names",
]
# Fields shown, in order, by ``policy --show`` for a single policy.
# TODO: Should be removed when BE starts sending channel_names for single policy (CBR-8459)
single_policy_keys = [
    "id", "name", "description", "created_at", "updated_at",
    "artifact_family", "policy_state", "schema_version",
]


class SubCommand(SubCommandBase):
    """``policy`` sub-command: create, edit, show, list, assign and delete policies.

    Every action requires the current user to have the ``admin`` role.
    """

    name = "policy"

    # Options whose parser default is ``None`` so that ``--edit`` can tell
    # "not supplied" apart from an explicit value and leave the stored rule
    # untouched. ``--create`` falls back to these values before the rules are
    # built, which keeps creation behaving as it did with the old defaults.
    _CREATE_DEFAULTS = {
        "only_signed_packages": False,
        "keep_legacy_packages": False,
        "cve_link_status_and_score": "and",
    }

    def main(self):
        """Run the single action selected on the command line.

        Actions are checked in order (create, edit, delete, show, assign,
        unassign, list-all) and only the first one present is executed.
        """
        user = self.api.get_current_user()
        if "admin" not in user["roles"]:
            self.log.error("You need to be an admin to manage policies")
            return

        args = self.args

        if args.create:
            self._create_from_args(args)
            return

        if args.edit:
            self.edit_policy(
                id=args.edit,
                name=args.name,
                description=args.description,
                args=args,
            )
            # Actions are exclusive: do not fall through to the ones below.
            return

        if args.delete:
            self.delete_policy(args.delete)
            return

        if args.show:
            self.show(args.show)
            return

        if args.assign:
            if not args.channel:
                self.log.error("Channel name is required")
                return
            self.assign_policy(args.assign, args.channel)
            return

        if args.unassign:
            if not args.channel:
                self.log.error("Channel name is required")
                return
            self.unassign_policy(args.unassign, args.channel)
            return

        if args.list_all:
            self.list_policies()
            return

    def _create_from_args(self, args):
        """Validate the ``--create`` options, build the rules and create the policy.

        Logs an error and returns without creating anything when the name is
        missing; returns silently when either rule helper returns a falsy value.
        """
        if not args.name:
            self.log.error("Policy name is required")
            return

        self._apply_create_defaults(args)

        add_rules = add_package_rules(args)
        if not add_rules:
            return

        remove_rules = remove_package_rules(args)
        if not remove_rules:
            return

        self.create_policy(
            name=args.name,
            description=args.description,
            add_package_rules=add_rules,
            remove_package_rules=remove_rules,
        )

    def _apply_create_defaults(self, args):
        """Set ``_CREATE_DEFAULTS`` values on ``args`` for options left unset (``None``)."""
        for option, default in self._CREATE_DEFAULTS.items():
            if getattr(args, option, None) is None:
                setattr(args, option, default)

    def create_policy(
        self,
        name,
        description,
        add_package_rules,
        remove_package_rules,
    ):
        """Create a conda policy (schema version 2) via the API and log its new id."""
        policy = {
            "name": name,
            "description": description,
            "id": "",
            "artifact_family": "conda",
            "schema_version": "2",
            "org_id": "",
            "updated_at": None,
            "created_at": None,
            "names_of_channels_applied_to": "",
        }

        response = self.api.add_policy(policy, add_package_rules, remove_package_rules)
        self.log.info(f"Policy: {name}, ID: {response['id']} created successfully.")

    def find(self, lst, key, value):
        """Return the index of the first dict in ``lst`` whose ``key`` equals ``value``, else -1."""
        return next(
            (i for i, dic in enumerate(lst) if dic.get(key) == value),
            -1,
        )

    def find_mult(self, lst, keys, values):
        """Return the index of the first dict in ``lst`` matching every key/value pair, else -1.

        ``keys[j]`` is compared against ``values[j]``.
        """
        pairs = list(zip(keys, values))
        return next(
            (
                i
                for i, dic in enumerate(lst)
                if all(dic.get(key) == value for key, value in pairs)
            ),
            -1,
        )

    def edit_policy(self, id, name, description, args):
        """Apply the edits requested in ``args`` to policy ``id`` and save it.

        Only options that were supplied on the command line are applied; all
        other rules of the stored policy are kept as they are. On success the
        updated policy is displayed.

        Returns:
            ``False`` if an option fails validation (nothing is saved),
            otherwise ``None``.
        """
        policy = self.api.get_policy(id)

        if name:
            policy["name"] = name
        if description:
            policy["description"] = description

        add_rules = policy["filters"]["add_packages"]["subfilters"]
        if not self._edit_add_package_rules(add_rules, args):
            return False
        if not self._edit_cve_rules(add_rules, args):
            return False

        remove_rules = policy["filters"]["remove_packages"]["subfilters"]
        if not self._edit_remove_package_rules(remove_rules, args):
            return False

        self.api.edit_policy(id, policy, add_rules, remove_rules)
        self.log.info(f"Policy with id {id} edited")
        self.show(id)

    def _edit_add_package_rules(self, rules, args):
        """Apply the non-CVE add-package options from ``args`` to ``rules`` in place.

        Returns ``False`` when an option fails validation, ``True`` otherwise.
        """
        # (raw value, facet, comparator, allowed values, error prefix)
        membership_options = (
            (args.platform, "platform", "in", PLATFORMS, "Invalid Platforms : "),
            (args.platform_not, "platform", "not in", PLATFORMS, "Invalid Platforms : "),
            (args.license, "license_family", "in", LICENSE, "Invalid Licenses : "),
            (args.license_not, "license_family", "not in", LICENSE, "Invalid Licenses : "),
        )
        for raw, facet, comparator, allowed, error_prefix in membership_options:
            if raw is None:
                continue
            criterion = parse_and_check_set(raw, allowed, error_prefix)
            if criterion is False:
                return False
            self._set_membership_filter(rules, facet, criterion, comparator)

        if args.package_name is not None:
            self._upsert_filter(
                rules, "conda_spec", self._split_names(args.package_name)
            )

        if args.include_dependencies or args.exclude_dependencies:
            spec_idx = self.find(rules, "facet", "conda_spec")
            if (
                spec_idx == -1
                or not args.package_name
                or rules[spec_idx]["criterion"] == []
            ):
                self.log.error(
                    "You need to specify package names to include dependencies."
                )
                return False
            self._upsert_filter(
                rules,
                "include_dependencies",
                bool(args.include_dependencies),
                comparator="==",
            )

        if args.only_signed_packages is not None:
            self._upsert_filter(
                rules, "only_signed", bool(args.only_signed_packages), comparator="=="
            )

        if args.keep_legacy_packages is not None:
            self._upsert_filter(
                rules, "keep_legacy", bool(args.keep_legacy_packages), comparator="=="
            )

        if args.package_created_from:
            self._upsert_filter(
                rules, "package_from_date", args.package_created_from, comparator=">="
            )
        if args.package_created_to:
            self._upsert_filter(
                rules, "package_to_date", args.package_created_to, comparator="<="
            )

        return True

    def _edit_cve_rules(self, rules, args):
        """Apply the CVE options from ``args`` to the compound CVE filter in ``rules``.

        The first ``type == "compound"`` entry of ``rules`` is treated as the
        CVE filter; an empty one (operator ``"and"``) is appended if none exists.
        Returns ``False`` when an option fails validation, ``True`` otherwise.
        """
        cve_idx = self.find(rules, "type", "compound")
        if cve_idx == -1:
            cve_filters = {"operator": "and", "type": "compound", "subfilters": []}
            rules.append(cve_filters)
        else:
            cve_filters = rules[cve_idx]
        cve_subfilters = cve_filters["subfilters"]

        if args.cve_link_status_and_score is not None:
            cve_filters["operator"] = args.cve_link_status_and_score

        if args.cve_score is not None:
            index = self.find(cve_subfilters, "facet", "cve_score")
            if index == -1:
                cve_subfilters.append(
                    create_filter(
                        args.cve_score,
                        "cve_score",
                        comparator=args.cve_score_comparator,
                    )
                )
            else:
                # Unlike other facets, the comparator is user-selectable here.
                cve_subfilters[index].update(
                    {
                        "criterion": args.cve_score,
                        "comparator": args.cve_score_comparator,
                    }
                )

        if args.cve_status is not None:
            criterion = parse_and_check_set(
                args.cve_status, CVE_STATUS, "Invalid cve status : "
            )
            if criterion is False:
                return False
            self._upsert_filter(cve_subfilters, "cve_status", criterion)

        if args.cve_allowlist is not None:
            # Strip each id so "CVE-A, CVE-B" validates; a blank value clears the list.
            cve_allowlist = [
                cve.strip() for cve in args.cve_allowlist.split(",") if cve.strip()
            ]
            for cve in cve_allowlist:
                if not validate_cve_id(cve):
                    self.log.error(
                        f"Invalid CVE id : {cve} - should be in the format CVE-NNNN-NNNN"
                    )
                    return False
            self._upsert_filter(cve_subfilters, "cve_allowlist", cve_allowlist)

        return True

    def _edit_remove_package_rules(self, rules, args):
        """Apply the remove-package options from ``args`` to ``rules`` in place.

        Package-name exclusions live in the first ``type == "compound"`` entry
        of ``rules``; if it is needed but missing an error is logged.
        Returns ``False`` on validation failure, ``True`` otherwise.
        """
        if args.exclude_cve_status is not None:
            criterion = parse_and_check_set(
                args.exclude_cve_status, CVE_STATUS, "Invalid CVE status : "
            )
            if criterion is False:
                return False
            self._upsert_filter(rules, "cve_status", criterion)

        if (
            args.exclude_package_name is None
            and args.exclude_package_name_exception is None
        ):
            return True

        exclusion_idx = self.find(rules, "type", "compound")
        if exclusion_idx == -1:
            self.log.error("Policy has no package exclusion rules to edit")
            return False
        package_exclusion = rules[exclusion_idx]["subfilters"]

        if args.exclude_package_name is not None:
            self._upsert_filter(
                package_exclusion,
                "conda_spec",
                self._split_names(args.exclude_package_name),
            )

        if args.exclude_package_name_exception is not None:
            self._upsert_filter(
                package_exclusion,
                "conda_spec_exception",
                self._split_names(args.exclude_package_name_exception),
            )

        return True

    def _set_membership_filter(self, rules, facet, criterion, comparator):
        """Set an ``in``/``not in`` filter on ``facet`` and drop the opposite one.

        A facet carries either an inclusion or an exclusion list, never both.
        """
        opposite = "not in" if comparator == "in" else "in"
        # Pop the opposite filter *before* looking up our own index so the
        # index cannot be invalidated by the removal.
        opposite_idx = self.find_mult(
            rules, ["facet", "comparator"], [facet, opposite]
        )
        if opposite_idx != -1:
            rules.pop(opposite_idx)

        index = self.find_mult(rules, ["facet", "comparator"], [facet, comparator])
        if index != -1:
            rules[index].update({"criterion": criterion})
        elif comparator == "in":
            rules.append(create_filter(criterion, facet))
        else:
            rules.append(create_filter(criterion, facet, comparator=comparator))

    def _upsert_filter(self, rules, facet, criterion, comparator=None):
        """Set the criterion of the first ``facet`` filter in ``rules``, or append one.

        ``comparator`` is only used when a new filter is created; when it is
        ``None`` it is not passed, so ``create_filter``'s own default applies.
        """
        index = self.find(rules, "facet", facet)
        if index != -1:
            rules[index].update({"criterion": criterion})
        elif comparator is None:
            rules.append(create_filter(criterion, facet))
        else:
            rules.append(create_filter(criterion, facet, comparator=comparator))

    @staticmethod
    def _split_names(value):
        """Split a comma-separated option value; a blank value yields ``[]``."""
        return value.split(",") if value.strip() else []

    def delete_policy(self, policy_id):
        """Delete policy ``policy_id`` via the API."""
        self.api.delete_policy(policy_id)
        self.log.info(f"Policy with id {policy_id} deleted")

    def show(self, policy_id):
        """Fetch policy ``policy_id`` and log its metadata and rules as grid tables."""
        policy = self.api.get_policy(policy_id)
        filters = policy.pop("filters")

        add_packages = filters["add_packages"]["subfilters"]
        cve_filter, cve_filter_operator = self._pop_compound(add_packages)

        remove_packages = filters["remove_packages"]["subfilters"]
        package_exclusion_filters, _ = self._pop_compound(remove_packages)

        # TODO remove policy data filtering when BE starts sending correct channel_names values. (CBR-8459)
        filtered_policy = {
            key: policy[key] for key in single_policy_keys if key in policy
        }

        self.log.info("Policy: " + str(policy["name"]))
        # Pair labels with the keys actually present so rows never misalign.
        self.log.info(
            tabulate(
                {
                    "keys": list(filtered_policy),
                    "values": list(filtered_policy.values()),
                },
                tablefmt="grid",
            )
        )
        self.log.info("")
        self.log.info("Add Packages Rules: ")
        self.log.info(tabulate(add_packages, tablefmt="grid"))
        self.log.info("")

        self.log.info("CVE Rules: ")
        self.log.info(tabulate(cve_filter, tablefmt="grid"))
        if cve_filter_operator is not None:
            self.log.info(f"CVE Status and Score Operator: '{cve_filter_operator}'")
        self.log.info("")

        self.log.info("Remove Packages Rules: ")
        self.log.info(tabulate(remove_packages, tablefmt="grid"))
        self.log.info(tabulate(package_exclusion_filters, tablefmt="grid"))
        self.log.info("")

    def _pop_compound(self, rules):
        """Remove the first compound filter from ``rules``; return ``(subfilters, operator)``.

        Returns ``([], None)`` and leaves ``rules`` untouched when there is none.
        """
        index = self.find(rules, "type", "compound")
        if index == -1:
            return [], None
        compound = rules.pop(index)
        return compound.get("subfilters", []), compound.get("operator")

    def assign_policy(self, policy_id, channel_name):
        """Assign policy ``policy_id`` to ``channel_name`` via the API."""
        self.api.assign_policy(policy_id, channel_name)
        self.log.info(f"Policy with id {policy_id} assigned to channel {channel_name}")

    def unassign_policy(self, policy_id, channel_name):
        """Unassign policy ``policy_id`` from ``channel_name`` via the API."""
        self.api.unassign_policy(policy_id, channel_name)
        self.log.info(
            f"Policy with id {policy_id} unassigned from channel {channel_name}"
        )

    def list_policies(self):
        """Log every policy returned by the API as one row of a grid table."""
        data = self.api.get_all_policies()
        policies = [self.format_policy_for_list(policy) for policy in data["items"]]
        self.log.info(tabulate(policies, headers=policy_keys, tablefmt="grid"))

    def format_policy_for_list(self, policy):
        """Return the ``policy_keys`` values of ``policy`` formatted for one table row.

        Dates lose their fractional seconds, and channels are joined as
        ``parent/name`` (or ``name``) separated by ", ". Missing keys are ``None``.
        """
        filtered_policy = {key: policy.get(key) for key in policy_keys}

        for date_key in ("created_at", "updated_at"):
            if filtered_policy[date_key]:
                filtered_policy[date_key] = self.format_date(filtered_policy[date_key])

        if filtered_policy["channel_names"]:
            filtered_policy["channel_names"] = ", ".join(
                f"{channel['parent_name']}/{channel['name']}"
                if channel.get("parent_name") is not None
                else channel["name"]
                for channel in filtered_policy["channel_names"]
            )

        return filtered_policy.values()

    def format_date(self, date):
        """Drop everything from the first ``.`` on (fractional seconds) of ``date``."""
        return date.split(".")[0]

    def add_parser(self, subparsers):
        """Register the ``policy`` sub-parser and all of its options."""
        self.subparser = subparsers.add_parser(
            "policy", help="Policies for filtering", description=__doc__
        )

        self.subparser.add_argument("--channel", help="channel name")

        # subcommands
        self.subparser.add_argument("--delete", help="delete a policy by id")
        self.subparser.add_argument("--assign", help="assign a policy to a channel")
        self.subparser.add_argument(
            "--unassign", help="unassign a policy from a channel"
        )
        self.subparser.add_argument(
            "--show",
            help="show a policy by id",
        )
        self.subparser.add_argument(
            "--list-all", action="store_true", help="list all policies"
        )
        self.subparser.add_argument(
            "--create", action="store_true", help="create a new policy"
        )
        self.subparser.add_argument("--edit", help="ID: Edit a policy ID")

        # arguments for creating and editing a policy
        policygroup = self.subparser.add_argument_group(
            "policy creation and editing arguments",
            "arguments for policy creation / editing",
        )
        policygroup.add_argument("--name", help="policy name")
        policygroup.add_argument("--channel-name", help="channel name")
        policygroup.add_argument("--description", help="policy description", default="")

        # Package rules
        policygroup.add_argument(
            "--platform",
            help=f"platforms to include separated by comma. Must be {', '.join(PLATFORMS)}",
        )
        policygroup.add_argument(
            "--platform-not",
            help=f"platform to be excluded separated by comma. Must be {', '.join(PLATFORMS)}",
        )
        policygroup.add_argument(
            "--license",
            help=f"license to include separated by comma. Must be {', '.join(LICENSE)}",
        )
        policygroup.add_argument(
            "--license-not",
            help=f"license to be excluded separated by comma. Must be {', '.join(LICENSE)}",
        )
        policygroup.add_argument(
            "--package-name", help="package names to be included separated by comma"
        )
        policygroup.add_argument(
            "--include-dependencies",
            help="include dependencies",
            action="store_true",
        )
        policygroup.add_argument(
            "--exclude-dependencies",
            help="exclude dependencies",
            action="store_true",
        )
        # default=None (not False) so that --edit leaves the rule unchanged
        # when the flag is absent; --create substitutes False.
        policygroup.add_argument(
            "--only-signed-packages",
            type=string_to_bool,
            nargs="?",
            const=True,
            default=None,
            help="Keep only signed packages",
        )
        policygroup.add_argument(
            "--keep-legacy-packages",
            type=string_to_bool,
            nargs="?",
            const=True,
            default=None,
            help="Keep only legacy packages",
        )

        policygroup.add_argument(
            "--package-created-from",
            help="package created from",
            type=check_date_format,
        )
        policygroup.add_argument(
            "--package-created-to", help="package created to", type=check_date_format
        )

        # CVE rules
        policygroup.add_argument(
            "--cve-score", help="CVE score", type=check_cve_score, default=None
        )
        policygroup.add_argument(
            "--cve-score-comparator",
            help='cve score comparator, can be "==", "<=", ">=", "<" and ">" default value is "<=" must be escaped by ""',
            default="<=",
        )
        policygroup.add_argument(
            "--cve-status",
            help=f"cve status to include separated by comma. Must be {', '.join(CVE_STATUS)}",
        )
        # default=None so --edit keeps the stored operator; --create uses "and".
        policygroup.add_argument(
            "--cve-link-status-and-score",
            help="Set the link between cve status and cve score, can be 'and' or 'or'. Default is 'and' on --create; unchanged on --edit",
            default=None,
        )
        policygroup.add_argument(
            "--cve-allowlist",
            help="CVE allowlist: add CVE IDs separated by comma. Must be in the format 'CVE-NNNN-NNNN'",
        )

        # remove package rules
        policygroup.add_argument(
            "--exclude-cve-status",
            help=f"CVE status to be excluded, separated by comma. Must be {', '.join(CVE_STATUS)}",
        )
        policygroup.add_argument(
            "--exclude-package-name",
            help="Package names to be excluded, separated by comma",
        )
        policygroup.add_argument(
            "--exclude-package-name-exception",
            help="Package names to be removed from the exclusion, separated by comma. Package names accept matchspec protocols, see: https://docs.conda.io/projects/conda/en/latest/dev-guide/api/conda/models/match_spec/index.html ",
        )

        self.subparser.set_defaults(main=self.main)
