import re
import string

from conda.models.match_spec import MatchSpec

from .logging import logger as _l, log_prefix
from .utils import MirrorException


# License family names treated as canonical by MatchLicense.match: a
# normalized name found in this set is compared as-is, while any other
# name is mapped onto a member of this set by substring search.
standard_license_families = {
    "AGPL",
    "APACHE",
    "BSD",
    "CC",
    "GPL",
    "GPL2",
    "GPL3",
    "GPL-GCC",
    "LGPL",
    "MIT",
    "MOZILLA",
    "NONE",
    "OTHER",
    "PROPRIETARY",
    "PSF",
    "PUBLICDOMAIN",
}

# Memo of parsed specs. Each record is stored under both the raw input
# string and the normalized spec string.
_matchspec_cache = {}


def _matchspec(dstr):
    """
    Parse a raw MatchSpec string, memoizing the result.

    Returns a 3-tuple ``(spec, normalized, name)``: the MatchSpec
    object, its ``str()`` form, and its ``name`` attribute. The same
    tuple is returned for later lookups by either the raw or the
    normalized string.
    """
    try:
        return _matchspec_cache[dstr]
    except KeyError:
        pass
    spec = MatchSpec(dstr)
    normalized = str(spec)
    record = (spec, normalized, spec.name)
    _matchspec_cache[dstr] = record
    _matchspec_cache[normalized] = record
    return record


def _matchname(name):
    """Return the "name" match component of the spec parsed from *name*."""
    spec, _, _ = _matchspec(name)
    return spec._match_components["name"]


# Name component of the wildcard spec "*"; used as the dictionary key for
# filters that are not tied to one specific package name.
GLOB = _matchname("*")


def filter_format(config, upstream_indices):
    """
    Apply the package-format policy to the upstream indices, in place.

    ``config.format_policy`` is either ``"keep-both"`` (nothing is done)
    or ``"<mode>-<fmt>"``. A ``fmt`` of ``"tarbz2"`` makes ``.tar.bz2``
    the preferred format; any other value makes ``.conda`` preferred.

    Every record whose filename is not in the preferred format is
    handled as follows:

    * if the same package exists in the preferred format in the same
      index, the record is marked ``_skip = "redundant"``;
    * otherwise, if it already carries a truthy ``_skip``, it is left
      alone;
    * otherwise, in ``"only"`` mode it is marked ``_skip = "format"``
      and counted as dropped;
    * otherwise, in ``"transmute"`` mode it is marked
      ``_skip = "transmuting"`` and a copy of the record (taken before
      the mark, with ``"fn"`` set to the source filename) is added to
      the index under the preferred-format filename.

    Args:
        config: object exposing a ``format_policy`` string.
        upstream_indices: mapping of platform -> {filename: record}.

    Returns:
        Mapping of platform -> human-readable summary such as
        ``"2 redundant, 1 dropped"``. Platforms where nothing happened
        are omitted.
    """
    counts = {}
    if config.format_policy == "keep-both" or not upstream_indices:
        return counts
    mode, fmt = config.format_policy.split("-", 1)
    if fmt == "tarbz2":
        good_fmt, bad_fmt = ".tar.bz2", ".conda"
    else:
        good_fmt, bad_fmt = ".conda", ".tar.bz2"
    n_bad = len(bad_fmt)
    for platform, upstream_index in upstream_indices.items():
        # New records are collected separately so the index is not
        # resized while it is being iterated.
        new_pkgs = {}
        t_counts = {"redundant": 0, "dropped": 0, "transforming": 0}
        for fn, info in upstream_index.items():
            if fn.endswith(good_fmt):
                continue
            good_fn = fn[:-n_bad] + good_fmt
            if good_fn in upstream_index:
                t_counts["redundant"] += 1
                info["_skip"] = "redundant"
            elif info.get("_skip"):
                # Already excluded for some other reason.
                pass
            elif mode == "only":
                # No preferred-format twin exists: the package is lost,
                # which is different from being redundant.
                t_counts["dropped"] += 1
                info["_skip"] = "format"
            elif mode == "transmute":
                # For transmuting, we actually need to add the
                # destination filename to the package list. We
                # copy the source record to that new filename.
                t_counts["transforming"] += 1
                info_dst = info.copy()
                info_dst["fn"] = fn
                info["_skip"] = "transmuting"
                new_pkgs[good_fn] = info_dst
        upstream_index.update(new_pkgs)
        if any(t_counts.values()):
            counts[platform] = ", ".join(
                f"{v} {k}" for k, v in t_counts.items() if v
            )
    return counts


class MatchBase:
    """
    Root of the matcher hierarchy.

    Subclasses implement ``match(rec)``. Hashing is derived from the
    instance attributes whose names do not start with an underscore,
    and equality means "same type and same hash".
    """

    @classmethod
    def from_matchspec(cls, ms):
        """
        Split a MatchSpec (or a spec string) into its name matcher and
        a matcher for everything else.

        Returns a ``(name, matcher)`` pair, where ``matcher`` is the
        conjunction of one MatchField per non-name match component.
        """
        if isinstance(ms, str):
            ms = _matchspec(ms)[0]
        components = ms._match_components
        name = components["name"]
        fields = []
        for key, value in components.items():
            if key != "name":
                fields.append(MatchField(key, value))
        return name, MatchAnd.from_specs(*fields)

    def __init__(self):
        # Lazily computed by __hash__.
        self._hash = None

    def __hash__(self):
        cached = self._hash
        if cached is None:
            public = sorted(
                (attr, value)
                for attr, value in vars(self).items()
                if not attr.startswith("_")
            )
            cached = self._hash = hash(tuple(public))
        return cached

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        return self is other or hash(self) == hash(other)

    def match(self, rec):
        """Return whether *rec* satisfies this matcher (abstract)."""
        raise NotImplementedError()

    def match_callback(self, rec, if_true=None, if_false=None):
        """
        Evaluate ``match(rec)`` and report the outcome.

        ``if_true(self, rec)`` is invoked on a truthy result and
        ``if_false(self, rec)`` on a falsy one, when provided. The
        match result is returned unchanged.
        """
        matched = self.match(rec)
        if matched:
            hook = if_true
        else:
            hook = if_false
        if hook is not None:
            hook(self, rec)
        return matched


class MatchNone(MatchBase):
    """Matcher that rejects every record."""

    def __repr__(self):
        return "MatchNone()"

    def __str__(self):
        return "<none>"

    def match(self, rec):
        return False


class MatchAny(MatchBase):
    """Matcher that accepts every record."""

    def __repr__(self):
        return "MatchAny()"

    def __str__(self):
        return "<any>"

    def match(self, rec):
        return True


def _normalize_license_name(s, license):
    if not s:
        s = license
    if not s:
        return "NONE"
    if "GPL" in s and "GCC-exception" in license:
        return "GPL-GCC"
    if "LGPL" in s or "LGPL" in license:
        return "LGPL"
    s = s.upper().strip()
    s = re.sub(r"GENERAL PUBLIC LICENSE", "GPL", s)
    s = re.sub(r"LESSER\s+", "L", s)
    s = re.sub(r"AFFERO\s+", "A", s)
    s = re.sub("[%s]" % re.escape(string.punctuation), "", s)
    s = re.sub(r"\s+", "", s)
    return s.strip() or "NONE"


class MatchLicense(MatchBase):
    """Match records whose license resolves to one specific family."""

    def __init__(self, license):
        super().__init__()
        self.license = license

    def __str__(self):
        return f"license({self.license})"

    def __repr__(self):
        return f"MatchLicense({repr(self.license)})"

    def match(self, rec):
        """
        Return whether the license family of *rec* equals ``self.license``.

        The record's ``license_family`` / ``license`` fields are
        normalized first. A normalized name that is not a standard
        family is mapped to the standard family it contains as a
        substring; if it contains none, it is compared unchanged.
        """
        license_fam = _normalize_license_name(
            rec.get("license_family"), rec.get("license")
        )
        if license_fam not in standard_license_families:
            normalized = license_fam
            candidates = [
                fam for fam in standard_license_families if fam in normalized
            ]
            if candidates:
                # Several families can be substrings of one name (e.g.
                # "GPL" and "GPL3" in "GPL3ONLY"). Taking whichever the
                # set yields first would depend on string hashing, so
                # choose deterministically: earliest occurrence first,
                # then the longest (most specific) family.
                license_fam = min(
                    candidates,
                    key=lambda fam: (normalized.find(fam), -len(fam)),
                )
        return self.license == license_fam


class MatchDepVer(MatchBase):
    """
    Match records that are compatible with version *dversion* of the
    package *dname*.

    Compatibility is probed by testing the record's own spec for
    *dname* against two synthetic package records, with versions
    ``<dversion>.0`` and ``<dversion>.1``.
    """

    def __init__(self, dname, dversion):
        super().__init__()
        self.dname = dname
        self.dversion = dversion
        # Underscore-prefixed, so it does not take part in hashing.
        self._prototypes = []
        for minor in (0, 1):
            self._prototypes.append(
                {
                    "name": dname,
                    "version": ".".join((dversion, str(minor))),
                    "build": "test",
                    "build_number": 0,
                }
            )

    def __str__(self):
        return f"depends({self.dname}={self.dversion})"

    def __repr__(self):
        return f"MatchDepVer({self.dname!r},{self.dversion!r})"

    def match(self, rec):
        """
        Return True when *rec* is the package itself, does not mention
        it in ``depends`` / ``constrains``, or mentions it with a spec
        that accepts one of the prototype records. Only the first spec
        naming the package is consulted.
        """
        target = self.dname
        if rec["name"] == target:
            return True
        for field in ("depends", "constrains"):
            for dstr in rec.get(field) or ():
                # Cheap prefix test before paying for a spec parse.
                if not dstr.startswith(target):
                    continue
                if dstr == target:
                    return True
                spec, _, spec_name = _matchspec(dstr)
                if spec_name == target:
                    return any(spec.match(proto) for proto in self._prototypes)
        return True


class MatchField(MatchBase):
    """Apply *matcher* to a single field of a record."""

    def __init__(self, field, matcher):
        super().__init__()
        self.field = field
        self.matcher = matcher

    def __repr__(self):
        return f"MatchField({self.field!r},{self.matcher!r})"

    def __str__(self):
        return f"{self.field}:{self.matcher!s}"

    def match(self, rec):
        """Return False when the field is absent, else the matcher's verdict."""
        if self.field not in rec:
            return False
        return self.matcher.match(rec[self.field])


class MatchNot(MatchBase):
    """Logical negation of a single matcher."""

    @classmethod
    def from_specs(cls, arg):
        """
        Build the negation of *arg*, simplifying trivial cases.

        A string or MatchSpec is converted to a matcher first (its name
        part is discarded). Negating MatchAny / MatchNone yields the
        other one, and negating a MatchNot yields its inner matcher.
        """
        if isinstance(arg, (str, MatchSpec)):
            arg = MatchBase.from_matchspec(arg)[1]
        for trivial, opposite in ((MatchAny, MatchNone), (MatchNone, MatchAny)):
            if isinstance(arg, trivial):
                return opposite()
        if isinstance(arg, MatchNot):
            return arg.atom
        return MatchNot(arg)

    def __init__(self, atom):
        super().__init__()
        self.atom = atom

    def __str__(self):
        return f"not({self.atom!s})"

    def __repr__(self):
        return f"MatchNot({self.atom!r})"

    def match(self, rec):
        return not self.atom.match(rec)

    def match_callback(self, rec, if_true=None, if_false=None):
        # The inner matcher sees the callbacks swapped, so each one
        # still fires according to the outcome of the negation.
        inner = self.atom.match_callback(rec, if_true=if_false, if_false=if_true)
        return not inner


class MatchTree(MatchBase):
    """
    Common base of the n-ary boolean combinators MatchAnd and MatchOr.

    ``oper`` names the operation ("and" / "or") and ``atoms`` holds the
    child matchers as a frozenset, so that hashing and equality (which
    MatchBase derives from public attributes) do not depend on the
    order in which the children were supplied.
    """

    oper = "<none>"

    @classmethod
    def from_specs(cls, *args, internal=False):
        """
        Combine *args* under this class's operator, with simplification.

        Each argument may be ``None`` (ignored), a spec string or
        MatchSpec (converted; its name part is discarded), or a
        MatchBase instance. Identity elements are dropped, nested trees
        with the same operator are flattened, a negated tree of the
        other operator is expanded by De Morgan's law, and an absorbing
        element (MatchAny under "or", MatchNone under "and")
        short-circuits the result.

        Returns the raw set of atoms when ``internal`` is true.
        Otherwise returns a tree for two or more atoms, the atom itself
        for exactly one, and MatchAny (for MatchAnd) or MatchNone (for
        any other class) for none.

        Raises:
            TypeError: if an argument is of an unsupported type.
        """
        no_op = MatchAny if cls is MatchAnd else MatchNone
        new_args = set()
        for arg in args:
            if arg is None or isinstance(arg, no_op):
                continue
            elif isinstance(arg, (str, MatchSpec)):
                _, arg = MatchBase.from_matchspec(arg)
            elif not isinstance(arg, MatchBase):
                raise TypeError(f"atom must not be a(n) {repr(arg)}")
            if isinstance(arg, MatchTree) and cls.oper == arg.oper:
                # associativity
                new_args.update(arg.atoms)
            elif (
                isinstance(arg, MatchNot)
                and isinstance(arg.atom, MatchTree)
                and cls.oper != arg.atom.oper
            ):
                # demorgans
                new_args.update(MatchNot(a) for a in arg.atom.atoms)
            elif isinstance(arg, MatchAny):
                if cls.oper == "or":
                    new_args = {arg}
                    break
            elif isinstance(arg, MatchNone):
                if cls.oper == "and":
                    new_args = {arg}
                    break
            else:
                new_args.add(arg)
        if internal:
            return new_args
        elif len(new_args) > 1:
            return cls(*new_args, internal=True)
        elif len(new_args) == 1:
            return next(iter(new_args))
        elif cls == MatchAnd:
            return MatchAny()
        else:
            return MatchNone()

    def __init__(self, *args, internal=False):
        super().__init__()
        if not internal:
            args = self.from_specs(*args, internal=True)
        # A frozenset is hashable (a plain set is not, which would make
        # hashing and comparing the tree raise) and compares
        # independently of ordering (a tuple would not).
        self.atoms = frozenset(args)

    def __repr__(self):
        return f"{self.__class__.__name__}({','.join(map(repr,self.atoms))})"

    def __str__(self):
        return f"{self.oper}({','.join(map(str,self.atoms))})"


class MatchAnd(MatchTree):
    """Conjunction: matches when every atom matches."""

    oper = "and"

    def match(self, rec):
        for atom in self.atoms:
            if not atom.match(rec):
                return False
        return True

    def match_callback(self, rec, if_true=None, if_false=None):
        # Every atom is evaluated, without short-circuiting, so that
        # if_false fires for each atom that rejects the record.
        outcomes = [atom.match_callback(rec, if_false=if_false) for atom in self.atoms]
        matched = all(outcomes)
        if matched and if_true is not None:
            if_true(self, rec)
        return matched


class MatchOr(MatchTree):
    """Disjunction: matches when at least one atom matches."""

    oper = "or"

    def match(self, rec):
        for atom in self.atoms:
            if atom.match(rec):
                return True
        return False

    def match_callback(self, rec, if_true=None, if_false=None):
        # Every atom is evaluated, without short-circuiting, so that
        # if_true fires for each atom that accepts the record.
        outcomes = [atom.match_callback(rec, if_true=if_true) for atom in self.atoms]
        matched = any(outcomes)
        if not matched and if_false is not None:
            if_false(self, rec)
        return matched


def filter_spec_dict(matches):
    """
    Group an iterable of specs by their name matcher.

    Returns a dict mapping each name matcher to the disjunction of the
    matchers built from all specs that share that name.
    """
    by_name = {}
    for spec in matches:
        key, matcher = MatchBase.from_matchspec(spec)
        previous = by_name.get(key)
        by_name[key] = MatchOr.from_specs(previous, matcher)
    return by_name


def filter_version_dict(versions, package):
    """
    Build the constraint filters for a set of allowed *package* versions.

    Trailing ``*`` and ``.`` characters are stripped from each version.
    The result has two entries: one keyed by the package's own name
    matcher, accepting any of the ``<package>=<version>.*`` specs, and
    one keyed by GLOB, accepting any of the per-version MatchDepVer
    checks.
    """
    stems = [raw.rstrip("*").rstrip(".") for raw in versions]
    constraints = {}
    own_versions = MatchOr.from_specs(*[f"{package}={stem}.*" for stem in stems])
    constraints[_matchname(package)] = own_versions
    dependents = MatchOr.from_specs(*[MatchDepVer(package, stem) for stem in stems])
    constraints[GLOB] = MatchAnd.from_specs(constraints.get(GLOB), dependents)
    return constraints


def filter_license_dict(licenses):
    """
    Build the exclusion filter for a list of license families.

    Returns a single-entry dict keyed by GLOB whose matcher accepts a
    record when any of the given licenses matches it.
    """
    any_license = MatchOr.from_specs(*map(MatchLicense, licenses))
    # Nothing is registered under GLOB yet, so the conjunction starts
    # from None, which from_specs ignores.
    return {GLOB: MatchAnd.from_specs(None, any_license)}


def filter_requests(groups, requests, prune=False):
    """
    Mark the packages selected by *requests* and prune the others.

    *requests* maps name matchers to package matchers. Every group
    name accepted by a name matcher is recorded, and each package in
    it accepted by the paired matcher is flagged ``_touched``. Pruning
    with reason ``"pkg_list"`` then covers all groups when *prune* is
    true, and only the recorded names otherwise.

    Returns the set of recorded group names.
    """
    touched_names = set()
    for name_matcher, pkg_matcher in requests.items():
        for name in groups:
            if not name_matcher.match(name):
                continue
            touched_names.add(name)
            for sdict in groups[name].values():
                for pdict in sdict.values():
                    for pkg in pdict.values():
                        if pkg_matcher.match(pkg):
                            pkg["_touched"] = True
    prune_untouched(groups, "pkg_list", None if prune else touched_names)
    return touched_names


def filter_spec(groups, filter_dict, positive, reason):
    """
    Apply a set of filters to *groups*, marking or unmarking packages.

    Args:
        groups: nested dict name -> version -> subdir -> filename -> record.
        filter_dict: dict of name matcher -> package matcher, or an
            iterable of specs from which such a dict is built.
        positive: a package is acted upon when the truth value of its
            match result equals this flag.
        reason: when not ``None``, acted-upon packages get
            ``_skip = "filter:<reason>"``; when ``None``, their
            ``_skip`` mark is removed.

    Packages skipped for a non-filter reason are never examined, and
    with a falsy *reason* only packages already skipped are examined.

    Returns:
        Dict of ``"<name matcher>:<matcher>"`` -> {package name: count},
        filled in by the match callbacks.
    """
    if not isinstance(filter_dict, dict):
        raw_specs, filter_dict = filter_dict, {}
        for raw in raw_specs:
            key, built = MatchBase.from_matchspec(raw)
            filter_dict[key] = MatchOr.from_specs(filter_dict.get(key), built)

    counts = {}
    current_key = ""

    def _count(matcher, pkg):
        # current_key is rebound by the loop below; the closure always
        # sees the name matcher being processed at call time.
        label = f"{current_key}:{matcher}"
        pkg_name = pkg["name"]
        tally = counts.setdefault(label, {})
        tally[pkg_name] = tally.get(pkg_name, 0) + 1

    callback_kw = {("if_true" if positive else "if_false"): _count}
    for current_key, matcher in filter_dict.items():
        for name in groups:
            if not current_key.match(name):
                continue
            for sdict in groups[name].values():
                for pdict in sdict.values():
                    for pkg in pdict.values():
                        skip = pkg.get("_skip")
                        if skip and not skip.startswith("filter:"):
                            continue
                        if not reason and not skip:
                            continue
                        if bool(matcher.match_callback(pkg, **callback_kw)) != positive:
                            continue
                        if reason is None:
                            del pkg["_skip"]
                        else:
                            pkg["_skip"] = f"filter:{reason}"
    return counts


def prune_untouched(groups, reason=None, names=None):
    """
    Remove skipped packages from *groups* and clear ``_touched`` marks.

    For every package in the selected groups: a truthy ``_touched``
    mark is removed; otherwise, when *reason* is given, the package is
    marked ``_skip = reason`` unless it already has a ``_skip`` key.
    Packages with a truthy ``_skip`` are then deleted from *groups*
    (the record itself keeps its ``_skip`` mark), and any subdir,
    version or name left empty is deleted as well.

    Args:
        groups: nested dict name -> version -> subdir -> filename -> record.
        reason: skip reason for untouched packages, or ``None`` to only
            clear marks and drop packages that are already skipped.
        names: iterable of group names to process. ``None`` means every
            group; an empty iterable means none.

    Raises:
        KeyError: if *names* contains a name that is not in *groups*.
    """
    # Only None selects everything: an explicitly empty limit must not
    # fall back to pruning the whole database. The names are snapshotted
    # because groups is modified while looping.
    targets = list(groups) if names is None else list(names)
    for name in targets:
        vdict = groups[name]
        for version, sdict in list(vdict.items()):
            for subdir, pdict in list(sdict.items()):
                for fn, pkg in list(pdict.items()):
                    if pkg.get("_touched"):
                        del pkg["_touched"]
                    elif reason is not None:
                        pkg.setdefault("_skip", reason)
                    if pkg.get("_skip"):
                        del pdict[fn]
                if not pdict:
                    del sdict[subdir]
            if not sdict:
                del vdict[version]
        if not vdict:
            del groups[name]


def prune_uninstallable(groups):
    """Remove packages from the database that are missing at least one
    valid dependency. This should in theory produce a set of packages
    that can be installed singly. But this is a conservative pruning;
    dependency cycles and noarch packages are treated in such a way
    that some uninstallable packages may remain.

    ``groups`` is the nested dict name -> version -> subdir -> filename
    -> record, and every record must have a "depends" list. Packages
    found installable are flagged ``_touched``; everything else is then
    removed by ``prune_untouched`` with the reason "uninstallable".
    """

    # Memo of dependency searches: (target platform, normalized spec)
    # -> whether some installable package satisfies the spec.
    match_cache = {}

    def _test_dep(top_sub, dstr):
        # Assume virtual packages are available
        if dstr.startswith("__"):
            return True
        dep, dstr, dname = _matchspec(dstr)
        dkey = (top_sub, dstr)
        if dkey in match_cache:
            return match_cache[dkey]
        # Provisionally assume success so that a recursive search for
        # the same spec returns instead of looping forever.
        match_cache[dkey] = True
        dversion = None if dep is None else dep.version
        for version, sdict in groups.get(dname, {}).items():
            # Skip whole versions the spec rejects before looking at
            # the individual packages.
            if not (dversion is None or dversion.match(version)):
                continue
            # Candidates may come from the target platform or noarch.
            for subdir in (top_sub, "noarch"):
                for fn, pkg in sdict.get(subdir, {}).items():
                    if dep is None or dep.match(pkg):
                        if _touch(top_sub, subdir, fn, pkg):
                            match_cache[dkey] = True
                            return True
        match_cache[dkey] = False
        return False

    # Memo of package checks: (target platform, subdir, filename) ->
    # whether all of the package's dependencies can be satisfied.
    pkg_touched = {}

    def _touch(top_sub, sub, fn, pkg):
        pkey = (top_sub, sub, fn)
        if pkey in pkg_touched:
            return pkg_touched[pkey]
        # This breaks potential cycles in the dep graph.
        # Packages that are uninstallable despite this
        # assumption will still be marked False below.
        pkg_touched[pkey] = True
        result = all(_test_dep(top_sub, dstr) for dstr in pkg["depends"])
        pkg_touched[pkey] = result
        if result:
            pkg["_touched"] = True
        return result

    # All of the platforms we need to search over, minus noarch
    platforms = sorted(
        set(
            subdir
            for vdict in groups.values()
            for sdict in vdict.values()
            for subdir in sdict
        )
        - {"noarch"}
    )

    # Iterate through each package and check it for individual
    # installability. Some packages may prove installable on some
    # platforms, but not others. prune_untouched groups the platform
    # tests together so such packages will be preserved.
    for name, vdict in groups.items():
        for version, sdict in vdict.items():
            for subdir, pdict in sdict.items():
                # A noarch package is tested against every platform.
                top_subs = platforms if subdir == "noarch" else (subdir,)
                for fn, pkg in pdict.items():
                    for top_sub in top_subs:
                        _touch(top_sub, subdir, fn, pkg)

    prune_untouched(groups, "uninstallable")


def prune_untouchables(groups, requests, reason="unreachable"):
    """
    Keep only the packages reachable from the requested names.

    Starting from every package of each name in *requests* (all of
    *groups* when *requests* is ``None``), dependencies are followed
    recursively and each package reached is flagged ``_touched``. The
    remainder is removed by ``prune_untouched`` with *reason*.
    """

    # (target platform, subdir, filename) triples already visited. The
    # target platform is part of the key because the dependency graph
    # of a noarch package differs from platform to platform.
    visited = set()

    # (target platform, normalized spec) searches already carried out.
    searched = set()

    def _touch(subdir, fn, pkg, top_sub, dep=None):
        key = (top_sub, subdir, fn)
        if key in visited:
            return
        if dep is not None and not dep.match(pkg):
            return
        visited.add(key)
        pkg["_touched"] = True
        # "pip" is followed as though python depended on it.
        if pkg["name"] == "python":
            dep_strings = pkg["depends"] + ["pip"]
        else:
            dep_strings = pkg["depends"]
        for raw in dep_strings:
            spec, normalized, dname = _matchspec(raw)
            search_key = (top_sub, normalized)
            if search_key in searched:
                continue
            searched.add(search_key)
            wanted = spec.version
            for version, sdict in groups.get(dname, {}).items():
                if wanted is not None and not wanted.match(version):
                    continue
                for cand_sub in (top_sub, "noarch"):
                    for cand_fn, cand_pkg in sdict.get(cand_sub, {}).items():
                        _touch(cand_sub, cand_fn, cand_pkg, top_sub, spec)

    # Every platform present in the data, except noarch.
    found_subdirs = {
        subdir
        for vdict in groups.values()
        for sdict in vdict.values()
        for subdir in sdict
    }
    found_subdirs.discard("noarch")
    platforms = sorted(found_subdirs)

    if requests is None:
        requests = groups
    for name in requests:
        for sdict in groups.get(name, {}).values():
            for subdir, pdict in sdict.items():
                # A noarch package is explored once per platform.
                targets = platforms if subdir == "noarch" else (subdir,)
                for fn, pkg in pdict.items():
                    for target in targets:
                        _touch(subdir, fn, pkg, target)

    prune_untouched(groups, reason)


def count_groups(groups, limit=None):
    """
    Count names, versions and package files in *groups*.

    When *limit* is not ``None``, only names contained in it are
    counted. Returns ``(n_names, n_versions, n_packages)``.
    """
    selected = [
        vdict for name, vdict in groups.items() if limit is None or name in limit
    ]
    n_names = len(selected)
    n_versions = sum(len(vdict) for vdict in selected)
    n_packages = sum(
        len(pdict)
        for vdict in selected
        for sdict in vdict.values()
        for pdict in sdict.values()
    )
    return n_names, n_versions, n_packages


def repodata_to_groups(upstream_indices):
    """
    Regroup per-subdir indices by package name and version.

    *upstream_indices* maps subdir -> {filename: record}. Records with
    a truthy ``_skip`` are left out. The result is the nested dict
    name -> version -> subdir -> filename -> record, holding the same
    record objects (not copies).
    """
    groups = {}
    for subdir, index in upstream_indices.items():
        for fn, pkg in index.items():
            if pkg.get("_skip"):
                continue
            by_version = groups.setdefault(pkg["name"], {})
            by_subdir = by_version.setdefault(pkg["version"], {})
            by_subdir.setdefault(subdir, {})[fn] = pkg
    return groups


def filter_all(config, upstream_indices):
    """
    Run the complete filtering pipeline over the upstream indices.

    The records in *upstream_indices* (subdir -> {filename: record})
    are modified in place: packages that are filtered out receive a
    ``_skip`` mark. The stages, each controlled by an attribute of
    *config*, run in this order: ``pkg_list`` (optionally extended with
    ``dependencies``), ``constraints``, ``python_versions``,
    ``r_versions``, ``license_exclusions``, ``exclusions``,
    ``inclusions``, the format policy, and finally ``pruning`` of
    uninstallable and unreachable packages. Progress is written to the
    module logger.

    Raises:
        MirrorException: if ``config.pkg_list`` is set and one of its
            requests has no matching package left once filtering is
            complete.
    """

    def _info(groups, why):
        # Log the current size of the database and return the blob count.
        why = f"{why}: " if why.strip() else why
        n_names, n_versions, n_packages2 = count_groups(groups)
        _l.info(f"{why}{n_names} names / {n_versions} versions / {n_packages2} blobs")
        return n_packages2

    _l.info("Applying filters")
    with log_prefix("  | "):
        groups = repodata_to_groups(upstream_indices)
        n_pkgs = _info(groups, "initial")

        if config.pkg_list:
            requests = filter_spec_dict(config.pkg_list)
            request_names = filter_requests(
                groups, requests, prune=not config.dependencies
            )
            t_groups = {k: v for k, v in groups.items() if k in request_names}
            n_pkgs = _info(t_groups, "pkg_list")
            if config.dependencies:
                prune_untouchables(groups, request_names, "pkg_list")
                t_groups = {k: v for k, v in groups.items() if k not in request_names}
                _info(t_groups, "dependencies")
                n_pkgs = _info(groups, "total")

        if config.constraints:
            constraints = filter_spec_dict(config.constraints)
            counts = filter_spec(groups, constraints, False, "constraints")
            _info(groups, "constraints")
            for reason, count in counts.items():
                _l.info(f"  {reason}: {count}")

        if config.python_versions:
            constraints = filter_version_dict(config.python_versions, "python")
            counts = filter_spec(groups, constraints, False, "python_version")
            _info(groups, "python versions")
            for reason, count in counts.items():
                _l.info(f"  {reason}: {sum(count.values())}")

        if config.r_versions:
            # The R filter must be built from the configured R versions,
            # not from the Python ones.
            constraints = filter_version_dict(config.r_versions, "r-base")
            counts = filter_spec(groups, constraints, False, "r_version")
            _info(groups, "R versions")
            for reason, count in counts.items():
                _l.info(f"  {reason}: {sum(count.values())}")

        if config.license_exclusions:
            exclusions = filter_license_dict(config.license_exclusions)
            counts = filter_spec(groups, exclusions, True, "license")
            _info(groups, "license")
            for reason, count in counts.items():
                _l.info(f"  {reason}: {', '.join(count)}")

        if config.exclusions:
            exclusions = filter_spec_dict(config.exclusions)
            counts = filter_spec(groups, exclusions, True, "exclusions")
            _info(groups, "exclusions")
            for reason, count in counts.items():
                _l.info(f"  {reason}: {sum(count.values())}")

        if config.inclusions:
            # A reason of None makes filter_spec remove the filter marks
            # of matching packages instead of adding one.
            inclusions = filter_spec_dict(config.inclusions)
            counts = filter_spec(groups, inclusions, True, None)
            _info(groups, "inclusions")
            for reason, count in counts.items():
                _l.info(f"  {reason}: {sum(count.values())}")

        counts = filter_format(config, upstream_indices)
        # This rolls up all of the packages eliminated from
        # the filters above as well
        groups = repodata_to_groups(upstream_indices)
        if counts:
            _info(groups, "format policy")
            for platform, count in counts.items():
                _l.info(f"  {platform}: {count}")

        if config.pruning:
            if not config.pkg_list:
                request_names = groups.keys()
            # Alternate the two pruning passes until the blob count
            # stops changing.
            n_iter = 1
            while True:
                prune_uninstallable(groups)
                n_pkgs2 = _info(groups, "pruning uninstallable %d" % n_iter)
                prune_untouchables(groups, request_names)
                if n_iter > 1 and n_pkgs2 == n_pkgs:
                    break
                n_pkgs3 = _info(groups, "pruning unreachable %d" % n_iter)
                if n_pkgs3 == n_pkgs or n_iter == 1 and n_pkgs3 == n_pkgs2:
                    break
                n_pkgs = n_pkgs3
                n_iter += 1

    if config.pkg_list:
        # Verify that every request still has at least one matching
        # package among the records that were not skipped.
        missing = set()
        groups = repodata_to_groups(upstream_indices)
        for msname, msmatch in requests.items():
            for name in groups:
                if msname.match(name) and any(
                    msmatch.match(pkg)
                    for vdict in groups[name].values()
                    for sdict in vdict.values()
                    for pkg in sdict.values()
                ):
                    break
            else:
                missing.add(f"{msname}: {msmatch}")
        if missing:
            msg = [
                "One or more requested packages were found to be uninstallable:",
                *(" - " + r for r in sorted(missing)),
                "Please examine your filter and try again.",
            ]
            for line in msg:
                _l.error("! %s", line)
            raise MirrorException("\n".join(msg))
