#!/usr/bin/env python3

# Copyright 2022 Enrico Lumetti
# SPDX-License-Identifier: GPL-3.0-or-later

from enum import Enum, auto
from collections import namedtuple, defaultdict
import copy
import argparse
import struct
import sys
from pathlib import Path, PurePosixPath
import subprocess
import tempfile

# Parsed sparse image file header. Fields are listed in the order
# read_header() reads them from the file; the leading magic number is only
# validated there and is not stored. 'extra_bytes' holds whatever header
# bytes follow the fixed fields.
Header = namedtuple(
    'Header',
    'version_major version_minor header_size chunk_header_size '
    'block_size output_blocks input_chunks crc32_checksum extra_bytes',
)

class Chunk:
    """Chunk header of a sparse image plus helpers to read its payload.

    Attributes:
        header: the image header this chunk belongs to; ``chunk_header_size``,
            ``block_size`` and ``extra_bytes`` are read from it.
        offset: absolute offset of the chunk header inside the image file.
        type: numeric chunk type (0xCAC1 RAW, 0xCAC2 FILL, 0xCAC3 DONTCARE).
        reserved: value of the reserved 16-bit chunk header field.
        output_blocks: number of output blocks the chunk expands to.
        input_bytes: size of the chunk in the image file, header included.
        extra_bytes: chunk header bytes that follow the four standard fields.
    """

    TYPE_RAW = 0xCAC1
    TYPE_FILL = 0xCAC2
    TYPE_DONTCARE = 0xCAC3

    _TYPE_NAMES = {
        TYPE_RAW: 'RAW',
        TYPE_FILL: 'FILL',
        TYPE_DONTCARE: 'DONTCARE',
    }

    # (chunk type, extra bytes read as a little-endian integer) pairs that
    # identify a Samsung custom chunk.
    _SAMSUNG_SIGNATURES = frozenset({
        (TYPE_RAW, 0xf7722f98),
        (TYPE_RAW, 0xf77234b6),
        (TYPE_DONTCARE, 0x1000),
    })

    def __init__(self, parent_header, offset, type, reserved, output_blocks, input_bytes, extra_bytes):
        self.header = parent_header
        self.offset = offset
        self.type = type
        self.reserved = reserved
        self.output_blocks = output_blocks
        self.input_bytes = input_bytes
        self.extra_bytes = extra_bytes

    def get_type_repr(self):
        """Return 'RAW', 'FILL' or 'DONTCARE', or 'UNKNOWN' for any other type."""
        return self._TYPE_NAMES.get(self.type, 'UNKNOWN')

    def __repr__(self):
        return (
            f'Chunk(offset={self.offset}'
            f', type={self.get_type_repr()}'
            f', reserved={self.reserved:04x}'
            f', output_blocks={self.output_blocks}'
            f', input_bytes={self.input_bytes}'
            f', extra_bytes={self.extra_bytes}'
            ')'
        )

    def is_samsung_custom_chunk(self):
        """Tell whether the chunk carries one of the known Samsung signatures.

        Both the image header and the chunk header must have exactly 4 extra
        bytes, the image header ones being all zero.
        """
        if len(self.header.extra_bytes) != 4 or len(self.extra_bytes) != 4:
            return False
        # Samsung devices after 2014, Android 5
        if int.from_bytes(self.header.extra_bytes, 'little') != 0:
            return False
        extra = int.from_bytes(self.extra_bytes, 'little')
        return (self.type, extra) in self._SAMSUNG_SIGNATURES

    def data_length(self):
        """Return the payload length in bytes, or None for an unknown type.

        RAW chunks store their payload in the file, so the length is what
        remains of ``input_bytes`` after the chunk header. FILL and DONTCARE
        chunks expand to ``output_blocks`` whole blocks.
        """
        if self.type == self.TYPE_RAW:
            return self.input_bytes - self.header.chunk_header_size
        if self.type in (self.TYPE_FILL, self.TYPE_DONTCARE):
            return self.output_blocks * self.header.block_size
        return None

    def fill_bytes(self, fp):
        """Return the 4 fill bytes of a FILL chunk, or None for other types.

        The position of ``fp`` is restored before returning.
        """
        if self.type != self.TYPE_FILL:
            return None
        cur = fp.tell()
        try:
            fp.seek(self.offset + self.header.chunk_header_size)
            return fp.read(4)
        finally:
            fp.seek(cur)

    def read_from_file(self, fp):
        """Return the expanded payload of the chunk as bytes.

        RAW payloads are read from ``fp``, FILL payloads repeat the 4 fill
        bytes and DONTCARE payloads are all zero bytes. The position of ``fp``
        is restored before returning.

        Raises:
            InvalidChunkTypeException: if the chunk type is not one of the
                three supported types.
        """
        if self.type not in self._TYPE_NAMES:
            raise InvalidChunkTypeException(self.type)

        cur = fp.tell()
        try:
            fp.seek(self.offset + self.header.chunk_header_size)
            if self.type == self.TYPE_RAW:
                return fp.read(self.data_length())
            if self.type == self.TYPE_FILL:
                return fp.read(4) * (self.data_length() // 4)
            # DONTCARE
            return bytes(self.data_length())
        finally:
            fp.seek(cur)


class ImgFormatException(Exception):
    """Base class of the errors raised for a malformed sparse image."""


class InvalidMagicNumberException(ImgFormatException):
    """The magic number at the start of the file has an unexpected value."""

    def __init__(self, wrong_value):
        super().__init__(wrong_value)
        self.wrong_value = wrong_value

    def __str__(self):
        return f'invalid magic number {self.wrong_value:#010x}'


class InvalidAlignmentException(ImgFormatException):
    """A size field of the header is not a multiple of 4."""

    def __init__(self, name, size):
        super().__init__(name, size)
        self.name = name
        self.size = size
        # The original attribute name is kept available for compatibility.
        self.block_size = size

    def __str__(self):
        return f'{self.name} ({self.size}) is not a multiple of 4'


class InvalidMajorVersionException(ImgFormatException):
    """The major version of the image is not supported."""

    def __init__(self, major_version, minor_version):
        super().__init__(major_version, minor_version)
        self.major_version = major_version
        self.minor_version = minor_version

    def __str__(self):
        return f'unsupported image version {self.major_version}.{self.minor_version}'


# Name under which the class was originally defined; kept so that both
# spellings resolve to the same exception.
InvalidMajorVersion = InvalidMajorVersionException


class ChunkSizeMismatchException(ImgFormatException):
    """The two sizes computed for a chunk do not agree."""

    def __init__(self, chunk_size, block_size):
        super().__init__(chunk_size, block_size)
        self.chunk_size = chunk_size
        self.block_size = block_size

    def __str__(self):
        return f'chunk size mismatch: {self.chunk_size} != {self.block_size}'


class InvalidChunkTypeException(ImgFormatException):
    """A chunk has a type that is not supported."""

    def __init__(self, chunk_type):
        super().__init__(chunk_type)
        self.chunk_type = chunk_type

    def __str__(self):
        return f'invalid chunk type {self.chunk_type:#06x}'


def read_header(fp):
    """Parse the sparse image file header from the current position of ``fp``.

    All the fields are read as little-endian unsigned integers.

    Args:
        fp: binary file object positioned at the start of the header.

    Returns:
        A ``Header`` holding the parsed fields; ``extra_bytes`` contains the
        header bytes that follow the 28 bytes of fixed fields.

    Raises:
        InvalidMagicNumberException: the magic number is not 0xed26ff3a.
        InvalidMajorVersion: the major version is greater than 1.
        InvalidAlignmentException: the header size, the chunk header size or
            the block size is not a multiple of 4.
        ImgFormatException: the declared header size is smaller than the
            fixed part of the header.
    """
    # magic (4) + versions (2+2) + sizes (2+2) + block size (4)
    # + output blocks (4) + input chunks (4) + checksum (4)
    fixed_header_size = 28

    def read_uint(size):
        return int.from_bytes(fp.read(size), 'little')

    magic_number = read_uint(4)
    if magic_number != 0xed26ff3a:
        raise InvalidMagicNumberException(magic_number)

    version_major = read_uint(2)
    version_minor = read_uint(2)
    if version_major > 1:
        raise InvalidMajorVersion(version_major, version_minor)

    header_size = read_uint(2)
    chunk_header_size = read_uint(2)
    block_size = read_uint(4)
    if header_size % 4 != 0:
        raise InvalidAlignmentException('header size', header_size)
    if chunk_header_size % 4 != 0:
        raise InvalidAlignmentException('chunk header size', chunk_header_size)
    if block_size % 4 != 0:
        raise InvalidAlignmentException('block size', block_size)
    if header_size < fixed_header_size:
        # A negative size would make read() below consume the whole file.
        raise ImgFormatException(
            f'header size {header_size} is smaller than {fixed_header_size} bytes'
        )

    output_blocks = read_uint(4)
    input_chunks = read_uint(4)
    crc32_checksum = read_uint(4)

    extra_bytes = fp.read(header_size - fixed_header_size)

    return Header(
        version_major=version_major,
        version_minor=version_minor,
        header_size=header_size,
        chunk_header_size=chunk_header_size,
        block_size=block_size,
        output_blocks=output_blocks,
        input_chunks=input_chunks,
        crc32_checksum=crc32_checksum,
        extra_bytes=extra_bytes,
    )

def iterate_chunks(fp, header):
    """Yield a ``Chunk`` for each of the ``header.input_chunks`` chunks.

    Reading starts at ``header.header_size``. Before a chunk is yielded the
    file position is moved past its payload, so when the generator is
    exhausted ``fp`` is positioned right after the last chunk.

    Args:
        fp: seekable binary file object of the sparse image.
        header: the ``Header`` returned by ``read_header`` for ``fp``.

    Raises:
        ImgFormatException: the chunk header size declared by the image is
            smaller than the standard fields, or a chunk header is truncated.
        InvalidChunkTypeException: a chunk is not RAW, FILL or DONTCARE.
        ChunkSizeMismatchException: the payload length of a chunk differs
            from ``output_blocks * block_size``.
    """
    chunk_fields = struct.Struct('<HHII')
    extra_len = header.chunk_header_size - chunk_fields.size
    if extra_len < 0:
        # A negative size would make read() below consume the whole file.
        raise ImgFormatException(
            f'chunk header size {header.chunk_header_size} is smaller than'
            f' {chunk_fields.size} bytes'
        )

    fp.seek(header.header_size)

    # Bounding the loop by the declared count also covers images that
    # declare zero chunks.
    for _ in range(header.input_chunks):
        offset = fp.tell()
        raw_fields = fp.read(chunk_fields.size)
        if len(raw_fields) != chunk_fields.size:
            raise ImgFormatException(f'truncated chunk header at offset {offset}')
        chunk_type, reserved, output_blocks, input_bytes = chunk_fields.unpack(raw_fields)
        # header bytes that follow the standard fields
        header_extra = fp.read(extra_len)

        if chunk_type not in (0xCAC1, 0xCAC2, 0xCAC3):
            raise InvalidChunkTypeException(chunk_type)

        ch = Chunk(
            header,
            offset,
            chunk_type,
            reserved,
            output_blocks,
            input_bytes,
            header_extra
        )
        expected_length = output_blocks * header.block_size
        if expected_length != ch.data_length():
            raise ChunkSizeMismatchException(expected_length, ch.data_length())

        # skip the payload stored in the file, if any
        if input_bytes != 0:
            fp.seek(input_bytes - header.chunk_header_size, 1)
        yield ch

def find_first_mismatch_pos(buf1, buf2):
    """Return the index of the first element where the two buffers differ.

    Only the common prefix is compared; when it is identical the length of
    the shorter buffer is returned.
    """
    common_len = min(len(buf1), len(buf2))
    differing = (
        pos
        for pos, (left, right) in enumerate(zip(buf1, buf2))
        if left != right
    )
    return next(differing, common_len)

def handle_analyze(args):
    """Print a summary of the sparse image ``args.input_file``.

    The summary lists the header fields, how many Samsung custom chunks were
    found for each distinct extra bytes value, the total number of output
    blocks, and the number of bytes left in the file after the last chunk
    when there are any.
    """
    with open(args.input_file, 'rb') as fp:
        # TODO: exception handling
        header = read_header(fp)

        print('Sparse Image Header:')
        print(f'  Version: {header.version_major}.{header.version_minor}')
        print(f'  Block size: {header.block_size}')
        print(f'  Output Blocks: {header.output_blocks}')
        print(f'  Chunk Header Size: {header.chunk_header_size}')
        print(f'  Input Chunks: {header.input_chunks}')
        print(f'  CRC32: {header.crc32_checksum}')
        if header.extra_bytes:
            print(f'  Extra Bytes: 0x{header.extra_bytes.hex()}')

        num_chunks = 0
        num_blocks = 0
        samsung_chunks_stats = defaultdict(int)
        for chunk in iterate_chunks(fp, header):
            if chunk.is_samsung_custom_chunk():
                samsung_chunks_stats[chunk.extra_bytes] += 1
            num_chunks += 1
            num_blocks += chunk.output_blocks

        if samsung_chunks_stats:
            print('Found Samsung Proprietary Chunks:')
            for extra, count in samsung_chunks_stats.items():
                # two hex digits per byte, so every value is printed with
                # its full width
                width = 2 * len(extra)
                value = int.from_bytes(extra, 'little')
                print(f'  Extra bytes {value:0{width}x}: {count} occurences')

        print(f'Output Blocks: {num_blocks} ({num_blocks * header.block_size} bytes)')

        # iterate_chunks() leaves the file right after the last chunk:
        # measure what is left up to the end of the file (whence 2)
        cur = fp.tell()
        fp.seek(0, 2)
        extra_bytes_len = fp.tell() - cur

        if extra_bytes_len != 0:
            print(f'Analyzed {num_chunks} chunk')
            print(f'Got {extra_bytes_len} bytes at file end')

def handle_inspect_chunk(args):
    """Print the details of chunk number ``args.chunk_number`` (1-based).

    Exits with status 1 when the image ``args.input_file`` declares fewer
    chunks than the requested number.
    """
    with open(args.input_file, 'rb') as fp:
        # TODO: exception handling
        header = read_header(fp)

        if args.chunk_number > header.input_chunks:
            print(f'Cannot read chunk #{args.chunk_number}: file has {header.input_chunks} chunks', file=sys.stderr)
            # sys.exit() rather than exit(): the latter is only installed by
            # the site module
            sys.exit(1)

        # number of output blocks produced by the chunks before the target
        first_block = 0
        for number, chunk in enumerate(iterate_chunks(fp, header), start=1):
            if number == args.chunk_number:
                break
            first_block += chunk.output_blocks

        print('Chunk offset in image file:', chunk.offset)
        print(f'Chunk output blocks: {first_block}-{first_block+chunk.output_blocks}'
            + f' ({chunk.output_blocks} blocks,'
            + f' {chunk.output_blocks * header.block_size} bytes)')
        print('Chunk type:', chunk.get_type_repr())
        if chunk.type == 0xCAC2: # FILL
            print(f'Fill bytes: 0x{chunk.fill_bytes(fp).hex()}')
        print('Chunk payload length:', chunk.data_length())
        if chunk.is_samsung_custom_chunk():
            # TODO
            pass

def handle_print_chunk(args):
    """Write the payload of chunk ``args.chunk_number`` (1-based) to stdout.

    With ``args.r`` the raw bytes are written, otherwise their hexadecimal
    representation. Exits with status 1 when the image declares fewer chunks
    than the requested number.
    """
    with open(args.input_file, 'rb') as fp:
        # TODO: exception handling
        header = read_header(fp)

        if args.chunk_number > header.input_chunks:
            # stdout carries the payload, so the error goes to stderr
            print(f'Cannot read chunk #{args.chunk_number}: file has {header.input_chunks} chunks', file=sys.stderr)
            sys.exit(1)

        for num, chunk in enumerate(iterate_chunks(fp, header), start=1):
            if num != args.chunk_number:
                continue
            res = chunk.read_from_file(fp)
            if args.r:
                sys.stdout.buffer.write(res)
            else:
                print(res.hex())
            break

def read_local_block(path, block_size, block_offset, block_count):
    """Read ``block_count`` blocks from the file ``path``.

    Reading starts ``block_offset`` blocks into the file, each block being
    ``block_size`` bytes long. Fewer bytes are returned when the file ends
    before the requested range does.
    """
    start = block_offset * block_size
    length = block_count * block_size
    with open(path, 'rb') as partition:
        partition.seek(start)
        data = partition.read(length)
    return data

def read_remote_block(args, block_size, block_offset, block_count, chunk_num):
    """Read ``block_count`` blocks of ``args.remote_device`` through adb.

    The blocks are dumped with ``dd`` to ``blk<chunk_num>.img`` inside
    ``args.remote_tmp_dir`` on the device, pulled into a local temporary
    directory and read from there. The file on the device is removed once
    the pull has been attempted, whether it succeeded or not.

    Returns:
        The bytes read from the pulled file.

    Raises:
        RuntimeError: the remote ``dd`` or the ``adb pull`` command failed.
    """
    def describe(result):
        return result.stderr.decode(errors='replace').strip()

    # TODO: check adb devices?
    out_name = f'blk{chunk_num}.img'
    remote_tmp_path = str(PurePosixPath(args.remote_tmp_dir, out_name))
    dump = subprocess.run(['adb', 'shell', 'su', '-c', '/data/adb/magisk/busybox',
        f'"dd bs={block_size} if={args.remote_device} of={remote_tmp_path} '
         + f'skip={block_offset} count={block_count}"'
    ], capture_output=True)
    if dump.returncode != 0:
        raise RuntimeError(
            f'remote dd of {args.remote_device} failed'
            f' (exit status {dump.returncode}): {describe(dump)}'
        )

    with tempfile.TemporaryDirectory() as tmpdirname:
        try:
            pull = subprocess.run(['adb', 'pull', remote_tmp_path, tmpdirname], capture_output=True)
            if pull.returncode != 0:
                raise RuntimeError(
                    f'adb pull of {remote_tmp_path} failed'
                    f' (exit status {pull.returncode}): {describe(pull)}'
                )
        finally:
            # never leave the dump behind on the device
            cleanup = subprocess.run(['adb', 'shell', 'rm', remote_tmp_path], capture_output=True)
            if cleanup.returncode != 0:
                print(f'Warning: could not remove {remote_tmp_path} from the device',
                    file=sys.stderr)

        return read_local_block(Path(tmpdirname, out_name),
            block_size,
            0,
            block_count
        )

def handle_print_blocks(args):
    """Write to stdout the partition blocks covered by a chunk.

    The blocks matching chunk ``args.chunk_number`` (1-based) are read from
    ``args.remote_device`` when it is given, otherwise from the local file
    ``args.partition_file``. With ``args.r`` the raw bytes are written,
    otherwise their hexadecimal representation. Exits with status 1 when the
    image declares fewer chunks than the requested number.
    """
    # The remote device is preferred when both sources are given.
    remote_read = args.remote_device is not None
    with open(args.input_file, 'rb') as fp:
        # TODO: exception handling
        header = read_header(fp)

        if args.chunk_number > header.input_chunks:
            # stdout carries the data, so the error goes to stderr
            print(f'Cannot read chunk #{args.chunk_number}: file has {header.input_chunks} chunks', file=sys.stderr)
            sys.exit(1)

        # number of output blocks produced by the chunks before the target
        block_count = 0
        for num, chunk in enumerate(iterate_chunks(fp, header), start=1):
            if num == args.chunk_number:
                break
            block_count += chunk.output_blocks

        if remote_read:
            data = read_remote_block(args, header.block_size, block_count, chunk.output_blocks, num)
        else:
            data = read_local_block(args.partition_file, header.block_size, block_count, chunk.output_blocks)

        if args.r:
            sys.stdout.buffer.write(data)
        else:
            print(data.hex())

def handle_compare_blocks(args):
    """Compare every chunk of the image with the matching partition blocks.

    The partition blocks are read from ``args.remote_device`` when it is
    given, otherwise from the local file ``args.partition_file``. The first
    ``args.skip`` chunks are not compared. One OK/FAIL report is printed for
    each compared chunk.
    """
    # TODO: check adb status and busybox
    # The remote device is preferred when both sources are given.
    remote_read = args.remote_device is not None
    with open(args.input_file, 'rb') as fp:
        # TODO: exception handling
        header = read_header(fp)

        skip_threshold = 0 if args.skip is None else args.skip
        # number of output blocks produced by the chunks seen so far
        block_count = 0
        for chunk_num, chunk in enumerate(iterate_chunks(fp, header), start=1):
            if chunk_num > skip_threshold:
                chunk_data = chunk.read_from_file(fp)
                if remote_read:
                    block_data = read_remote_block(args, header.block_size, block_count, chunk.output_blocks, chunk_num)
                else:
                    block_data = read_local_block(args.partition_file, header.block_size, block_count, chunk.output_blocks)

                if block_data == chunk_data:
                    print(f'OK: Chunk {chunk_num} matches blocks {block_count}-{block_count+chunk.output_blocks}')
                else:
                    i = find_first_mismatch_pos(chunk_data, block_data)
                    print(f'FAIL: Chunk {chunk_num} doesn\'t match blocks {block_count}-{block_count+chunk.output_blocks}')
                    print(f'      first non-matching byte at pos {i}')
                    print(f'      chunk_data len: {len(chunk_data)}, block_data len: {len(block_data)}')

            block_count += chunk.output_blocks

def write_header(header, ofp):
    """Serialize ``header`` to ``ofp``.

    The magic number and the fixed fields are packed as little-endian
    values, then ``header.extra_bytes`` is written as is.
    """
    fixed_fields = struct.Struct('<IHHHHIIII')
    values = (
        0xed26ff3a,  # magic number
        header.version_major,
        header.version_minor,
        header.header_size,
        header.chunk_header_size,
        header.block_size,
        header.output_blocks,
        header.input_chunks,
        header.crc32_checksum,
    )
    ofp.write(fixed_fields.pack(*values))
    ofp.write(header.extra_bytes)

def write_chunk(chunk, payload, ofp):
    """Serialize ``chunk`` followed by ``payload`` to ``ofp``.

    The four standard chunk header fields are packed as little-endian
    values, then ``chunk.extra_bytes`` and ``payload`` are written as is.
    """
    standard_fields = struct.Struct('<HHII')
    packed = standard_fields.pack(
        chunk.type,
        chunk.reserved,
        chunk.output_blocks,
        chunk.input_bytes,
    )
    ofp.write(packed)
    ofp.write(chunk.extra_bytes)
    ofp.write(payload)

def patch_header(header, strategy):
    """Return the header of the converted image for the given strategy.

    With strategy 1 both the header size and the chunk header size grow by
    4 bytes and the header extra bytes become 4 zero bytes; every other
    field is copied. Any other strategy yields None.
    """
    if strategy != 1:
        return None

    added_bytes = 4
    grown_header_size = header.header_size + added_bytes
    grown_chunk_header_size = header.chunk_header_size + added_bytes
    return Header(
        header.version_major,
        header.version_minor,
        grown_header_size,
        grown_chunk_header_size,
        header.block_size,
        header.output_blocks,
        header.input_chunks,
        header.crc32_checksum,
        bytes(added_bytes),
    )

def patch_chunk(chunk, ifp, strategy):
    """Return ``(patched_chunk, payload)`` for the converted image.

    ``chunk`` itself is left untouched: a shallow copy is patched. With
    strategy 1:

    * RAW chunks get 4 zero extra bytes and keep their payload;
    * FILL chunks whose fill bytes are all zero become DONTCARE chunks
      tagged with the extra bytes 00 10 00 00, the other FILL chunks become
      RAW chunks holding the expanded fill pattern;
    * DONTCARE chunks get the extra bytes 00 10 00 00.

    With any other strategy the copy is returned unchanged together with an
    empty payload.

    Raises:
        InvalidChunkTypeException: with strategy 1, when the chunk is not
            RAW, FILL or DONTCARE.
    """
    result = copy.copy(chunk)
    if strategy != 1:
        return result, b''

    zero_word = bytes(4)
    dontcare_tag = bytes([0, 0x10, 0, 0])
    kind = chunk.type

    if kind == 0xCAC1: # RAW
        result.extra_bytes = zero_word
        result.input_bytes += 4
        return result, chunk.read_from_file(ifp)

    if kind == 0xCAC2: # FILL
        if chunk.fill_bytes(ifp) == zero_word:
            # Becomes DONTCARE. input_bytes stays as it is: the extra bytes
            # take exactly the room the fill bytes used to take.
            result.type = 0xCAC3
            result.extra_bytes = dontcare_tag
            return result, b''
        # a non-zero pattern has to be stored expanded, as a RAW chunk
        result.type = 0xCAC1
        result.extra_bytes = zero_word
        expanded = chunk.read_from_file(ifp)
        result.input_bytes += len(expanded)
        return result, expanded

    if kind == 0xCAC3: # DONTCARE
        result.extra_bytes = dontcare_tag
        result.input_bytes += 4
        return result, b''

    raise InvalidChunkTypeException(chunk.type)

def handle_samsungify(args):
    """Convert ``args.input_file`` and write the result to ``args.output_file``.

    The header and every chunk of the input image are patched according to
    ``args.strategy`` and written in the same order.
    """
    with open(args.input_file, 'rb') as ifp, open(args.output_file, 'wb') as ofp:
        source_header = read_header(ifp)
        write_header(patch_header(source_header, args.strategy), ofp)

        # the chunks are laid out as described by the header of the input
        for source_chunk in iterate_chunks(ifp, source_header):
            converted, data = patch_chunk(source_chunk, ifp, args.strategy)
            write_chunk(converted, data, ofp)

def build_argument_parser():
    """Build the command line parser of the tool.

    Every sub-command stores the function implementing it in the ``func``
    attribute of the parsed arguments. A sub-command is mandatory, so
    ``func`` is always set after a successful parse; its name is stored in
    the ``command`` attribute.
    """
    parser = argparse.ArgumentParser(
        description = 'Tool to analyze and manipulate Samsung priopretary sparse image file'
    )
    subparsers = parser.add_subparsers(dest='command', required=True)

    def add_command(name, description, handler):
        # the description is also used as the summary in the top-level help
        command = subparsers.add_parser(name, description=description, help=description)
        command.set_defaults(func=handler)
        return command

    parser_analyze = add_command('analyze',
        'Analyze a sparse image file',
        handle_analyze
    )
    parser_analyze.add_argument('input_file')

    parser_inspect_chunk = add_command('inspect-chunk',
        'Inspect a chunk of the given sparse image',
        handle_inspect_chunk
    )
    parser_inspect_chunk.add_argument('input_file')
    parser_inspect_chunk.add_argument('chunk_number', type=int)

    parser_read_chunk = add_command('print-chunk',
        'Read a chunk from the image file and print to stdout. Chunk starts from 1.',
        handle_print_chunk
    )
    parser_read_chunk.add_argument('input_file')
    parser_read_chunk.add_argument('chunk_number', type=int)
    parser_read_chunk.add_argument('-r', action='store_true')

    parser_print_block = add_command('print-blocks',
        'Print the blocks corresponding to a chunk. Chunk starts from 1.',
        handle_print_blocks
    )
    parser_print_block.add_argument('input_file')
    parser_print_block.add_argument('chunk_number', type=int)
    parser_print_block.add_argument('-r', action='store_true')
    parser_print_block.add_argument('--partition-file')
    parser_print_block.add_argument('--remote-device')
    parser_print_block.add_argument('--remote-tmp-dir')

    parser_compare_blocks = add_command('compare-blocks',
        'Compare all the blocks in the input image to a remote or local partition',
        handle_compare_blocks
    )
    parser_compare_blocks.add_argument('input_file')
    parser_compare_blocks.add_argument('--partition-file')
    parser_compare_blocks.add_argument('--remote-device')
    parser_compare_blocks.add_argument('--remote-tmp-dir')
    parser_compare_blocks.add_argument('--skip', type=int)

    parser_samsungify = add_command('samsungify',
        'Convert a standard Android sparse image to Samsung\'s format',
        handle_samsungify
    )
    parser_samsungify.add_argument('input_file')
    parser_samsungify.add_argument('output_file')
    parser_samsungify.add_argument('--strategy', type=int)
    return parser

def validate_args(args):
    """Check the constraints argparse cannot express for each sub-command.

    On the first violation a message is printed to stderr and the program
    exits with status 1.
    """
    def fail(message):
        print(message, file=sys.stderr)
        sys.exit(1)

    if args.func == handle_samsungify:
        strategies = 1
        if args.strategy is None:
            fail('Missing --strategy argument, check help')
        if args.strategy not in range(1, strategies+1):
            fail(f'strategy {args.strategy} is invalid, check help')

    if args.func == handle_compare_blocks:
        # 0 is accepted: it means that no chunk is skipped
        if args.skip is not None and args.skip < 0:
            fail('--skip must be greater than or equal to 0')

    if args.func == handle_compare_blocks or args.func == handle_print_blocks:
        if args.partition_file is None and args.remote_device is None:
            fail('Please specify either --partition-file or --remote-device')
        if args.remote_device is not None and args.remote_tmp_dir is None:
            fail('Please specify --remote-tmp-dir')

    if args.func in (handle_print_blocks, handle_print_chunk, handle_inspect_chunk):
        if args.chunk_number <= 0:
            fail(f'Invalid chunk number {args.chunk_number}: chunk numbering must starts at 1.')


if __name__ == '__main__':
    parser = build_argument_parser()
    args = parser.parse_args()
    if not hasattr(args, 'func'):
        # No sub-command was given: report it the argparse way (usage
        # message on stderr, exit status 2).
        parser.error('a sub-command is required')
    validate_args(args)
    try:
        args.func(args)
    except ImgFormatException as err:
        # A malformed image is a user-facing error, not a crash.
        print(f'{parser.prog}: invalid sparse image: {type(err).__name__}: {err}',
            file=sys.stderr)
        sys.exit(1)

