#!/usr/bin/python3
#
# Copyright (c) 2016 Qualcomm Atheros, Inc.
# Copyright (c) 2018, The Linux Foundation. All rights reserved.
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import os
import logging
import re
import argparse
import shutil
import sys
import filecmp
import functools
import subprocess
import email

# global variables

# Module-wide logger; stays None until main() assigns logging.getLogger().
logger = None
# When True, scan_dir() skips firmware versions listed in FIRMWARE_BLACKLIST.
# cmd_install() is the only place that switches it on.
blacklist = False

# Priority used for a firmware branch when no priority file could be read.
BRANCH_DEFAULT_PRIORITY = 1000
# Name of the optional file inside a branch directory whose content is parsed
# as the integer priority of that branch.
BRANCH_PRIORITY_FILE = '.priority'
# Name of the WHENCE file, looked up directly under the linux-firmware
# directory given to --install.
WHENCE_FILE = 'WHENCE'

# Firmware versions that scan_dir() ignores while 'blacklist' is True.
FIRMWARE_BLACKLIST = [
    '999.999.0.636',
    '10.2-00082-4-2'
]


@functools.total_ordering
class Hardware():
    """A hardware family and version pair, e.g. 'QCA6174' and 'hw3.0'.

    Instances are compared and ordered by their 'name' attribute only.
    'firmware_branches' and 'board_files' start empty and are filled in
    by the scanning functions.
    """

    def __init__(self, hw, hw_ver):
        self.hw = hw          # e.g. QCA6174
        self.hw_ver = hw_ver  # e.g. hw3.0

        # human readable identifier, also used for comparisons
        self.name = '{} {}'.format(hw, hw_ver)

        self.board_files = []
        self.firmware_branches = []

    def get_path(self):
        """Return '<hw>/<hw_ver>' joined with the platform separator."""
        return os.path.join(self.hw, self.hw_ver)

    def __str__(self):
        branches = sorted(self.firmware_branches)
        return "Hardware('%s'): %s %s" % (self.name, self.board_files,
                                          branches)

    def __repr__(self):
        return str(self)

    def __lt__(self, other):
        return self.name < other.name

    def __eq__(self, other):
        return self.name == other.name


@functools.total_ordering
class FirmwareBranch():
    """A firmware branch: a named directory holding firmware images.

    Branches are ordered by priority first and by name second, except
    that the main branch named '.' always sorts before every other
    branch. The priority defaults to BRANCH_DEFAULT_PRIORITY and can be
    overridden with an integer stored in the BRANCH_PRIORITY_FILE file
    inside the branch directory.
    """

    def __init__(self, name, path=None):
        """Create a branch.

        name -- branch name, '.' for the main branch
        path -- optional branch directory; when given, the priority
                file inside it (if any) is read
        """
        self.name = name
        self.firmwares = []
        self.priority = BRANCH_DEFAULT_PRIORITY

        if not path:
            return

        priority_path = os.path.join(path, BRANCH_PRIORITY_FILE)
        if not os.path.isfile(priority_path):
            return

        # Best effort: an unreadable or malformed priority file is
        # reported and the default priority is kept.
        try:
            with open(priority_path, 'r') as f:
                self.priority = int(f.read())
        except (OSError, ValueError) as e:
            print('Failed to read %s: %s' % (priority_path, e))

    def __eq__(self, other):
        return self.priority == other.priority and self.name == other.name

    def __lt__(self, other):
        # '.' is always of the lower priority. Two '.' branches are not
        # less than each other, which keeps the ordering strict.
        if self.name == '.' or other.name == '.':
            return self.name == '.' and other.name != '.'

        if self.priority != other.priority:
            return self.priority < other.priority

        return self.name < other.name

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return 'FirmwareBranch(\'%s\'): %s' % (self.name, sorted(self.firmwares))


class BoardFile():
    """A board data file, named either 'board.bin' or 'board-<N>.bin'."""

    @staticmethod
    def create_from_path(path):
        """Return a BoardFile for path, or None if it is not a board file.

        The file name must start with 'board.bin' or 'board-<N>.bin';
        <N> (a string of digits) becomes 'bd_api', which is None for
        plain 'board.bin'.
        """
        filename = os.path.basename(path)

        # The optional group captures the board API number. The dots are
        # escaped so that only a literal '.bin' is accepted.
        match = re.search(r'^board(?:-(\d+))?\.bin', filename)
        if match is None:
            return None

        return BoardFile(path, match.group(1))

    def get_installation_name(self):
        """Return the file name (last path component) of the board file."""
        return os.path.basename(self.path)

    def __lt__(self, other):
        # Ordering by path makes sorted() work on lists of board files.
        if not isinstance(other, BoardFile):
            return NotImplemented

        return self.path < other.path

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return '%s' % (self.get_installation_name())

    def __init__(self, path, bd_api):
        self.path = path
        self.bd_api = bd_api


@functools.total_ordering
class Firmware():
    """A firmware image stored as 'firmware-<API>.bin_<VERSION>'.

    Instances are equal when their version strings are equal and are
    ordered by version, see _version_key().
    """

    @staticmethod
    def create_from_path(path):
        """Return a Firmware for path, or None if the name does not match.

        The file name must look like 'firmware-<API>.bin_<VERSION>'.
        """
        filename = os.path.basename(path)

        match = re.search(r'^firmware-(\d+)\.bin_(.+)', filename)
        if match is None:
            return None

        fw_api = match.group(1)
        fw_ver = match.group(2)

        return Firmware(fw_ver, path, fw_api)

    @staticmethod
    def _version_key(fw_ver):
        """Return a sortable key for a version string.

        The version is split into '-' separated segments and every
        segment into '.' separated tokens. Comparing segment by segment
        makes '10.4-3.2-00080' sort before '10.4-3.2.1-00028' and
        '10.2.4.70-2' before '10.2.4.70.2'. Numeric tokens compare as
        integers and always sort before non-numeric tokens, so mixed
        versions never raise TypeError.
        """
        key = []
        for segment in fw_ver.split('-'):
            tokens = []
            for token in segment.split('.'):
                try:
                    tokens.append((0, int(token), ''))
                except ValueError:
                    tokens.append((1, 0, token))

            key.append(tokens)

        return key

    def get_installation_name(self):
        """Return the name the image is installed as."""
        return 'firmware-%s.bin' % (self.fw_api)

    def get_notice_filename(self):
        """Return the name the notice file is installed as."""
        return 'notice_ath10k_firmware-%s.txt' % (self.fw_api)

    def __eq__(self, other):
        if not isinstance(other, Firmware):
            return NotImplemented

        return self.fw_ver == other.fw_ver

    def __lt__(self, other):
        if not isinstance(other, Firmware):
            return NotImplemented

        s = self._version_key(self.fw_ver)
        o = self._version_key(other.fw_ver)

        if s != o:
            return s < o

        # Same key but possibly different spelling (e.g. '10.02' and
        # '10.2'): fall back to the raw string so that the order stays
        # strict and agrees with __eq__.
        return self.fw_ver < other.fw_ver

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return '%s' % (self.fw_ver)

    # path can be None with unittests
    def __init__(self, fw_ver, path=None, fw_api=None):
        self.path = path
        self.fw_api = fw_api
        self.fw_ver = fw_ver
        self.notice_path = None

        if path:
            # a notice file is expected next to the image as
            # 'notice.txt_<VERSION>'
            s = 'notice.txt_%s' % (self.fw_ver)
            n = os.path.join(os.path.dirname(path), s)
            if os.path.isfile(n):
                self.notice_path = n


def scan_dir(path):
    """Collect firmware images and board files found directly in path.

    Only regular files are considered and they are visited in sorted
    name order. Blacklisted firmware versions are dropped while the
    global 'blacklist' flag is set. Notice files and the branch priority
    file are skipped silently, any other file triggers a warning.

    Returns a (firmwares, board_files) tuple of lists.
    """
    firmwares = []
    board_files = []

    for entry in sorted(os.listdir(path)):
        entry_path = os.path.join(path, entry)

        if not os.path.isfile(entry_path):
            continue

        firmware = Firmware.create_from_path(entry_path)
        if firmware:
            if blacklist and firmware.fw_ver in FIRMWARE_BLACKLIST:
                logger.debug('\'%s\' blacklisted' % (firmware.fw_ver))
            else:
                firmwares.append(firmware)
            continue

        boardfile = BoardFile.create_from_path(entry_path)
        if boardfile:
            board_files.append(boardfile)
            continue

        # notice files and the priority file are expected, stay quiet
        if entry.startswith('notice.txt') or entry == BRANCH_PRIORITY_FILE:
            continue

        logger.warning('Unknown file: %s' % (entry_path))

    return firmwares, board_files


# QCA988X/hw2.0
def scan_hw_ver(hw):
    """Fill hw with the firmware branches and board files under its path.

    Every subdirectory of hw.get_path() becomes a FirmwareBranch holding
    the firmware images found in it. The directory itself is scanned
    last as the main branch named '.', which also provides the board
    files of the hardware.
    """
    hw_path = hw.get_path()

    for entry in sorted(os.listdir(hw_path)):
        branch_path = os.path.join(hw_path, entry)

        if not os.path.isdir(branch_path):
            continue

        logger.debug('Found firmware branch: %s' % (entry))
        branch = FirmwareBranch(entry, branch_path)
        hw.firmware_branches.append(branch)

        # board files should be only in "." (main) branch, not in other
        # branches, so the ones found here are ignored
        firmwares, _ = scan_dir(branch_path)
        branch.firmwares += firmwares

    # this is the main branch named "."
    main_branch = FirmwareBranch('.')

    firmwares, board_files = scan_dir(hw_path)
    main_branch.firmwares += firmwares
    hw.board_files += board_files

    hw.firmware_branches.append(main_branch)


# QCA98XX
def scan_hw(path):
    """Return a list of Hardware objects, one per subdirectory of path.

    path is used both as the directory to list and as the hardware name
    handed to Hardware(); each subdirectory name is the hardware
    version. Every Hardware is populated with scan_hw_ver().
    """
    found = []

    for entry in sorted(os.listdir(path)):
        if not os.path.isdir(os.path.join(path, entry)):
            continue

        logger.debug('Found hw version: %s' % (entry))

        hardware = Hardware(path, entry)
        scan_hw_ver(hardware)
        found.append(hardware)

    return found


def scan_repository(directory):
    """Scan the hardware directories listed in directory.

    Returns a dict mapping Hardware.name to the Hardware object.

    NOTE: the entries returned by os.listdir(directory) are tested and
    scanned by their bare name, i.e. relative to the current working
    directory, not relative to 'directory'. All callers in this file
    pass '.'.
    """
    found = {}

    for entry in sorted(os.listdir(directory)):
        # only non-hidden directories are hardware directories
        if not os.path.isdir(entry) or entry.startswith('.'):
            continue

        logger.debug('Found hw: %s' % (entry))

        for hardware in scan_hw(entry):
            found[hardware.name] = hardware

    return found


def install_file(args, src, installdir, dest):
    """Copy src to installdir/dest, creating missing directories.

    args       -- parsed arguments; nothing is touched when
                  args.dry_run is set
    src        -- path of the file to copy
    installdir -- installation root directory
    dest       -- destination path relative to installdir

    Returns the destination path. It is returned in dry-run mode as
    well, so that callers collecting installed paths never receive
    None.
    """
    destpath = os.path.join(installdir, dest)

    if args.dry_run:
        return destpath

    # destdir is empty when destpath has no directory part, and
    # os.makedirs('') would fail
    destdir = os.path.dirname(destpath)
    if destdir:
        os.makedirs(destdir, exist_ok=True)

    shutil.copyfile(src, destpath)

    return destpath


def get_firmware_version(path):
    """Return the 'FirmwareVersion' field reported for a firmware image.

    Runs 'ath10k-fwencoder --info <path>' and parses its output as
    'Name: value' header lines. Returns None when the field is absent;
    subprocess.CalledProcessError is raised if the tool fails.
    """
    output = subprocess.check_output(['ath10k-fwencoder', '--info', path],
                                     universal_newlines=True)
    headers = email.message_from_string(output)
    return headers['FirmwareVersion']


def get_board_crc32(path):
    """Return the output of 'ath10k-fwencoder --crc32 <path>'.

    Surrounding whitespace is stripped from the output.
    subprocess.CalledProcessError is raised if the tool fails.
    """
    output = subprocess.check_output(['ath10k-fwencoder', '--crc32', path],
                                     universal_newlines=True)
    return output.strip()


# print indent
def pi(level, msg):
    """Print msg prefixed with 'level' tab characters."""
    indent = '\t' * level
    print('%s%s' % (indent, msg))


def whence_update(linux_firmware, firmware_path, version):
    """Replace the version recorded for firmware_path in the WHENCE file.

    linux_firmware -- directory containing the WHENCE file
    firmware_path  -- path as written on the 'File:' line
    version        -- new text for the following 'Version:' line

    The file is rewritten only when exactly one 'File:'/'Version:' pair
    matches; otherwise a message is printed and the file is left
    untouched.
    """
    whence_path = os.path.join(linux_firmware, WHENCE_FILE)

    with open(whence_path, 'r') as f:
        buf = f.read()

    # firmware_path contains dots, escape it so that it only matches
    # literally
    pattern = r'(File: %s\nVersion: ).*\n' % (re.escape(firmware_path))

    # A function replacement inserts the version verbatim, so no
    # character in it is interpreted as a template escape.
    def replace(match):
        return '%s%s\n' % (match.group(1), version)

    (buf, sub_count) = re.subn(pattern, replace, buf, flags=re.MULTILINE)

    if sub_count != 1:
        print('Failed to update %s to WHENCE: %d' % (firmware_path, sub_count))
        return

    with open(whence_path, 'w') as f:
        f.write(buf)


def whence_add(linux_firmware, firmware_path, version, license_path=None):
    """Append a firmware entry to the ath10k section of the WHENCE file.

    linux_firmware -- directory containing the WHENCE file
    firmware_path  -- path written on the new 'File:' line
    version        -- text written on the new 'Version:' line
    license_path   -- optional path written on an extra 'File:' line

    The entry is inserted where the second paragraph after
    'Driver: ath10k' ends. The file is rewritten only when exactly one
    such place is found; otherwise a message is printed and the file is
    left untouched.
    """
    whence_path = os.path.join(linux_firmware, WHENCE_FILE)

    with open(whence_path, 'r') as f:
        buf = f.read()

    pattern = r'(Driver: ath10k.*?\n\n.*?)\n\n'

    # A function replacement inserts the paths and the version
    # verbatim, so no character in them is interpreted as a template
    # escape.
    def replace(match):
        lines = [match.group(1),
                 'File: %s' % (firmware_path),
                 'Version: %s' % (version)]

        if license_path is not None:
            lines.append('File: %s' % (license_path))

        # the blank line consumed by the pattern is put back after the
        # new entry
        return '\n'.join(lines) + '\n\n'

    (buf, sub_count) = re.subn(pattern, replace, buf, flags=re.MULTILINE | re.DOTALL)

    if sub_count != 1:
        print('Failed to add %s to WHENCE: %d' % (firmware_path, sub_count))
        return

    with open(whence_path, 'w') as f:
        f.write(buf)


def git_commit(args, msg, repodir, files):
    """Commit files in the git repository at repodir with message msg.

    Runs 'git commit --quiet --signoff' restricted to the given files.
    Does nothing unless args.commit is set. Raises
    subprocess.CalledProcessError if git fails.
    """
    if args.commit:
        cmd = ['git', '-C', repodir, 'commit', '--quiet', '--signoff']
        cmd += ['-m', msg]
        cmd += files

        logger.debug('Running: %r' % (cmd))
        subprocess.check_call(cmd)


def git_add(args, repodir, files):
    """Run 'git add' for files in the git repository at repodir.

    Does nothing unless args.commit is set. Raises
    subprocess.CalledProcessError if git fails.
    """
    if args.commit:
        cmd = ['git', '-C', repodir, 'add']
        cmd += files

        logger.debug('Running: %r' % (cmd))
        subprocess.check_call(cmd)


def cmd_check(args):
    """Handle --check: scan the current directory.

    The scan result is discarded; the visible effect is whatever the
    scan itself logs, e.g. warnings about unknown files.
    """
    _ = scan_repository('.')


def cmd_list(args):
    """Handle --list: print everything found in the current directory.

    Output is a tab-indented tree: the hardware name at level 0, the
    'board' heading and the branch names at level 1, and the board files
    and firmware versions at level 2. Branches without firmware images
    are not printed.
    """
    hws = scan_repository('.')
    for hw in sorted(hws.values()):
        pi(0, '%s:' % (hw.name))

        # print board files
        if hw.board_files:
            pi(1, 'board')

            # sort with an explicit key, so that listing does not depend
            # on board file objects being comparable with each other
            for board_file in sorted(hw.board_files,
                                     key=lambda bd: bd.get_installation_name()):
                pi(2, board_file)

        # print firmware branches
        for branch in sorted(hw.firmware_branches):
            if not branch.firmwares:
                # don't print empty branches
                continue

            pi(1, '%s' % (branch.name))

            for fw in sorted(branch.firmwares):
                pi(2, fw.fw_ver)


def cmd_list_lib_dir(args):
    """Handle --list-lib-dir: list ath10k files of a firmware directory.

    Walks the 'ath10k' subdirectory of args.list_lib_dir[0]. For every
    directory containing firmware or board files it prints the last two
    path components followed by one line per file: the file name and
    either the firmware version or the board file CRC32.

    Exits with status 1 if the 'ath10k' subdirectory does not exist or
    is not a directory.
    """
    fw_dir = args.list_lib_dir[0]
    ath10k_dir = os.path.join(fw_dir, 'ath10k')

    if not os.path.exists(ath10k_dir):
        print('directory %s does not exist, aborting' % (ath10k_dir))
        sys.exit(1)

    if not os.path.isdir(ath10k_dir):
        print('%s is not a directory, aborting' % (ath10k_dir))
        sys.exit(1)

    # raw strings: '\.' is not a valid escape in a normal string literal
    firmware_re = re.compile(r'firmware.*\.bin')
    board_re = re.compile(r'board.*\.bin')

    # sort the results based on dirpath
    for (dirpath, dirnames, filenames) in sorted(os.walk(ath10k_dir)):
        found = []
        for filename in sorted(filenames):
            path = os.path.join(dirpath, filename)

            if firmware_re.match(filename) is not None:
                # this is a firmware file
                found.append('%s\t%s' % (filename, get_firmware_version(path)))

            if board_re.match(filename) is not None:
                # this is a board file
                found.append('%s\t%s' % (filename, get_board_crc32(path)))

        if found:
            # Just show QCA1234/hw1.0 directories. I would have liked
            # to use os.path functions here but just could not find
            # anything sensible there.
            pi(0, '%s:' % ('/'.join(dirpath.split('/')[-2:])))
            for line in found:
                pi(1, line)


def cmd_get_latest_in_branch(args):
    """Handle --get-latest-in-branch HW HWVER BRANCH.

    Prints the path of the highest sorted firmware image of the given
    branch and exits with status 0; nothing is printed when the branch
    has no images. Exits with status 1 if the hardware or the branch is
    not found.
    """
    # This command is mostly for scripts to parse, so keep warnings etc
    # out of the output, unless we are debugging of course.
    if not args.debug:
        logger.setLevel(logging.ERROR)

    hws = scan_repository('.')

    requested = args.get_latest_in_branch
    branch_name = requested[2]

    # TODO: hw is always in uppercase and hwver lower case, check that
    hw_name = '%s %s' % (requested[0], requested[1])

    if hw_name not in hws:
        print('Did not find hardware: %s' % (hw_name))
        sys.exit(1)

    # first branch with a matching name, if any
    fw_branch = next((candidate
                      for candidate in hws[hw_name].firmware_branches
                      if candidate.name == branch_name), None)

    if fw_branch is None:
        print('Did not find firmware branch: %s' % (branch_name))
        sys.exit(1)

    if not fw_branch.firmwares:
        # no firmware images in this branch, just use return value 0 with no output
        sys.exit(0)

    ordered = sorted(fw_branch.firmwares)
    print(ordered[-1].path)

    sys.exit(0)


def cmd_get_latest_in_hw(args):
    """Handle --get-latest HW HWVER.

    Picks the last firmware branch in sorted order and prints the path
    of its highest sorted firmware image, then exits with status 0;
    nothing is printed when that branch has no images. Exits with
    status 1 if the hardware is not found.
    """
    # This command is mostly for scripts to parse, so keep warnings etc
    # out of the output, unless we are debugging of course.
    if not args.debug:
        logger.setLevel(logging.ERROR)

    hws = scan_repository('.')

    requested = args.get_latest

    # TODO: hw is always in uppercase and hwver lower case, check that
    hw_name = '%s %s' % (requested[0], requested[1])

    if hw_name not in hws:
        print('Did not find hardware: %s' % (hw_name))
        sys.exit(1)

    branches = sorted(hws[hw_name].firmware_branches)
    fw_branch = branches[-1]

    if not fw_branch.firmwares:
        # no firmware images in this branch, just use return value 0 with no output
        sys.exit(0)

    ordered = sorted(fw_branch.firmwares)
    print(ordered[-1].path)

    sys.exit(0)


def _install_firmware(args, hw, fw, linux_firmware, installdir):
    """Install firmware image fw of hw, update WHENCE and commit."""
    destdir = hw.get_path()
    dest = os.path.join(destdir, fw.get_installation_name())
    installed_path = os.path.join(installdir, dest)

    already_installed = os.path.exists(installed_path)
    if already_installed and filecmp.cmp(fw.path, installed_path):
        logger.debug(
            'No update needed for %s (%s)' % (dest, fw.fw_ver))
        return

    action = 'update' if already_installed else 'add'

    logger.info('Installing %s (%s)' % (dest, fw.fw_ver))
    installed = [install_file(args, fw.path, installdir, dest)]
    firmware_path = os.path.join('ath10k', dest)

    notice_path = None
    if fw.notice_path:
        notice_dest = os.path.join(destdir, fw.get_notice_filename())
        logger.info('Installing %s (%s)' % (notice_dest, fw.fw_ver))
        installed.append(install_file(args, fw.notice_path, installdir,
                                      notice_dest))
        notice_path = os.path.join('ath10k', notice_dest)

    if args.dry_run:
        # nothing was copied, so WHENCE and git must stay untouched too
        return

    if action == 'update':
        # updating an existing firmware file
        whence_update(linux_firmware, firmware_path, fw.fw_ver)
    else:
        # adding a new firmware file
        whence_add(linux_firmware, firmware_path, fw.fw_ver, notice_path)

    # Always stage the files: also an update can bring a notice file
    # which git does not know about yet.
    git_add(args, linux_firmware, installed)

    installed.append(WHENCE_FILE)

    # "ath10k: QCA9888 hw2.0: update firmware-5.bin to 10.4-3.5.1-00035"
    log = 'ath10k: %s: %s %s to %s' % (hw.name,
                                       action,
                                       fw.get_installation_name(),
                                       fw.fw_ver)

    git_commit(args, log, linux_firmware, installed)


def _install_board_file(args, hw, bd, linux_firmware, installdir):
    """Install board file bd of hw and commit it."""
    dest = os.path.join(installdir, bd.path)

    already_installed = os.path.exists(dest)
    if already_installed and filecmp.cmp(bd.path, dest):
        logger.debug('No update needed for %s' % (dest))
        return

    action = 'update' if already_installed else 'add'

    logger.info('Installing %s' % (bd.path))
    installed = [install_file(args, bd.path, installdir, bd.path)]

    if args.dry_run:
        # nothing was copied, so there is nothing to commit
        return

    # TODO: update WHENCE file when adding a new board
    # file to linux-firmware

    # a new file has to be staged before it can be committed by path
    git_add(args, linux_firmware, installed)

    log = 'ath10k: %s: %s %s' % (hw.name,
                                 action,
                                 bd.get_installation_name())
    git_commit(args, log, linux_firmware, installed)


def cmd_install(args):
    """Handle --install: install the latest firmware and board files.

    The repository in the current directory is scanned with the
    firmware blacklist enabled. For every hardware the highest sorted
    image of the last firmware branch in sorted order is installed
    below '<DESTINATION>/ath10k', followed by the board files. Files
    that are already identical at the destination are left alone.
    Hardware without any image in that branch is skipped completely,
    including its board files.

    With --dry-run only the log messages are produced: no file is
    copied, WHENCE is not modified and git is not run. With --commit
    every installed file is committed separately.

    Exits with status 1 if '<DESTINATION>/ath10k' is not a directory.
    """
    global blacklist

    blacklist = True

    hws = scan_repository('.')

    linux_firmware = args.install[0]
    installdir = os.path.join(linux_firmware, 'ath10k')

    if not os.path.isdir(installdir):
        logger.error('%s is not a directory' % (installdir))
        sys.exit(1)

    logger.debug('Installing to directory %s' % (installdir))

    for hw in sorted(hws.values()):
        # every Hardware() should have at least one firmware branch, the
        # main '.' branch so no need to check the length
        fw_list = sorted(sorted(hw.firmware_branches)[-1].firmwares)

        if not fw_list:
            # no firmware images found
            continue

        # install latest firmware
        #
        # FIXME: should we install the latest for every FW API
        # version (2, 4, 5 etc.), not just the latest?
        _install_firmware(args, hw, fw_list[-1], linux_firmware, installdir)

        # install board files
        for bd in hw.board_files:
            _install_board_file(args, hw, bd, linux_firmware, installdir)


def main():
    """Parse the command line and run the first selected command.

    Sets up the module logger: DEBUG level with a level prefix when
    --debug is given, INFO level otherwise. The commands are checked in
    a fixed order and only the first one given is run. Without any
    command an error is logged and the usage is printed.
    """
    global logger

    logger = logging.getLogger('ath10k-fw-repo')

    parser = argparse.ArgumentParser(
        description='Install firmware images from the ath10k-firmware git repository. Run it from the top directory of the working tree.')
    add = parser.add_argument

    # generic options
    add('--debug', action='store_true',
        help='Enable debug messages.')
    add('--dry-run', action='store_true',
        help='Do not run any actual commands.')

    # commands and their modifiers
    add('--check', action='store_true',
        help='Check the ath10k-firmware repository content for validity.')
    add('--list', action='store_true',
        help='List all files found from the ath10k-firmware repository.')

    add('--list-lib-dir', action='store',
        nargs=1, metavar='LIB_FIRMWARE_DIRECTORY',
        help='List all files found from the specified directory, which can either be a linux-firmware repository or /lib/firmware directory.')

    add('--install', action='store', nargs=1, metavar='DESTINATION',
        help='Install all ath10k firmware images to DESTINATION folder, for example /lib/firmware.')

    add('--commit', action='store_true',
        help='When installing files also git commit them, for example when updating linux-firmware.git.')

    add('--get-latest-in-branch', action='store', nargs=3,
        metavar=('HW', 'HWVER', 'BRANCH'),
        help='Show latest firmware version from a firmware branch. Just outputs the version for easy parsing in scripts.')

    add('--get-latest', action='store', nargs=2,
        metavar=('HW', 'HWVER'),
        help='Show latest firmware version for hardware version. Just outputs the version for easy parsing in scripts.')

    args = parser.parse_args()

    if args.debug:
        log_format = '%(levelname)s: %(message)s'
        log_level = logging.DEBUG
    else:
        log_format = '%(message)s'
        log_level = logging.INFO

    logging.basicConfig(format=log_format)
    logger.setLevel(log_level)

    # commands, in order of precedence
    commands = (
        (args.check, cmd_check),
        (args.list, cmd_list),
        (args.list_lib_dir, cmd_list_lib_dir),
        (args.install, cmd_install),
        (args.get_latest_in_branch, cmd_get_latest_in_branch),
        (args.get_latest, cmd_get_latest_in_hw),
    )

    for selected, handler in commands:
        if selected:
            handler(args)
            break
    else:
        logger.error('No command defined')
        parser.print_usage()


if __name__ == "__main__":
    main()
