"""launch snakemake"""

from datetime import datetime
import re
import os
import sys
import subprocess
from rich.table import Table
from rich import box
from rich.syntax import Syntax
from harpy.common.file_ops import gzip_file, purge_empty_logs
from harpy.common.printing import CONSOLE, print_onerror, print_setup_error
from harpy.common.progress import harpy_progressbar, harpy_pulsebar, harpy_progresspanel

# Exit codes launch_snakemake uses internally to choose which error report to print.
(
    EXIT_CODE_SUCCESS,        # workflow finished or had nothing to do
    EXIT_CODE_GENERIC_ERROR,  # setup failure, reported via print_setup_error
    EXIT_CODE_CONDA_ERROR,    # conda/software-deployment failure, reported via print_setup_error
    EXIT_CODE_RUNTIME_ERROR,  # failure reported via print_onerror with the snakemake log
) = range(4)

# quiet levels accepted by launch_snakemake:
# quiet = 0 : print all things, full progressbar
# quiet = 1 : print all text, only the overall progressbar
# quiet = 2 : print nothing, no progressbar

def iserror(text: str):
    """Return True if a line of snakemake output contains an error trigger word.

    The check is a plain substring test, so words such as "WorkflowError" or
    "MissingOutputException" match as well.
    """
    trigger_words = ("Exception", "Error", "MissingOutputException")
    return any(word in text for word in trigger_words)

def print_shellcmd(text: str, _process):
    """Pretty-print the shell command of a failed rule and, if present, its log file.

    Args:
        text: the first line of the shell command, already read from stderr.
        _process: the snakemake ``subprocess.Popen``; its ``stderr`` is read for
            the rest of the command and for the optional log-file section.

    The command is accumulated until snakemake's "(command exited ..." marker
    (or EOF), whitespace is normalized, and it is printed as a bash-highlighted
    block. If the next stderr line starts with "Logfile", the lines between the
    following two "====" separator lines are printed in red under a
    "Log File" rule.

    Returns:
        The next stderr line that was read but not consumed ("" at EOF), so the
        caller can continue parsing from it.
    """
    _table = Table(
        show_header=False,
        pad_edge=False,
        show_edge=False,
        padding=(0,0),
        box=box.SIMPLE,
    )
    _table.add_column("Lpadding", justify="left")
    _table.add_column("shell", justify="left")
    _table.add_column("Rpadding", justify="left")

    _text = text
    # keep reading command lines until snakemake's exit marker (or EOF)
    while "(command exited" not in _text:
        _text = _process.stderr.readline()
        if not _text:
            break
        text += _text

    text = text.replace("(command exited with non-zero exit code)", "").strip().replace("\t", "  ")
    text = re.sub(r' {2,}|\t+', '  ', text)
    cmd = Syntax(text, lexer = "bash", tab_size=2, word_wrap=True, padding=1, dedent=True, theme = "paraiso-dark")
    _table.add_row("  ", cmd, "  ")
    CONSOLE.print("[bold default]shell:", _table)

    # if there's a logfile
    _text = _process.stderr.readline()
    if not _text.strip().startswith("Logfile"):
        # no log file: hand the line back so the caller neither loses it nor gets None
        return _text

    _fields = _text.split()
    _log = _fields[1] if len(_fields) > 1 else ""
    CONSOLE.rule(f"[bold]Log File: {_log.rstrip(':')}", style = "yellow")
    merged_text = ""
    lines = 0
    # collect everything between the two "====" separator lines
    while lines < 2:
        _text = _process.stderr.readline()
        if not _text:
            # EOF before the closing separator: print what was collected
            break
        if "====" in _text:
            lines += 1
            continue
        merged_text += _text
    CONSOLE.print("[red]" + re.sub(r'\n{3,}', '\n\n', merged_text), overflow = "ignore", crop = False)
    return _process.stderr.readline()


def highlight_params(text: str):
    """Make important snakemake attributes like 'input:' highlighted in the error output.

    One leading 4-space indent and all trailing whitespace are removed from
    ``text``. If the remaining line (ignoring further leading whitespace)
    starts with a known rule attribute, only that leading attribute name is
    wrapped in rich "[bold default]...[/]" markup. A line of the form
    "[...]" is returned in blue on a new line. Any other line is returned
    unchanged apart from the stripping.
    """
    text = text.removeprefix("    ").rstrip()
    test = text.lstrip()
    for param in (
        "jobid:",
        "input:",
        "output:",
        "log:",
        "conda-env:",
        "container:",
        "shell:",
        "wildcards:",
        "affected files:",
    ):
        if test.startswith(param):
            # replace only the leading attribute; the value may contain the same text
            return text.replace(param, f"[bold default]{param}[/]", 1)
    if test.startswith("[") and text.endswith("]"):
        return f"\n[blue]{text}[/]"
    return text

def _exitcode_at_eof(process, failure_code):
    """Return the harpy exit code for a snakemake process whose stderr reached EOF.

    Blocks until the process has exited, then returns EXIT_CODE_SUCCESS if its
    return code is 0 and ``failure_code`` otherwise.
    """
    return EXIT_CODE_SUCCESS if process.wait() == 0 else failure_code


def _print_error_output(process, output, console):
    """Echo the relevant part of snakemake's remaining stderr after a failed run.

    Args:
        process: the snakemake ``subprocess.Popen`` whose ``stderr`` is read.
        output: the last stderr line already read by the caller; parsing starts here.
        console: the rich console to print to.

    Exception/error lines are printed in bold yellow, shell commands go through
    ``print_shellcmd``, log files are printed under a "Log File" rule, rule
    attributes are highlighted with ``highlight_params``, and snakemake
    bookkeeping lines are skipped. Stops at EOF, at a blank line, or at one of
    the marker lines checked below.
    """
    console.tab_size = 4
    console._highlight = False
    while output and not output.endswith("]") and not output.startswith("Shutting down"):
        if "Exception" in output or "Error" in output:
            if output.startswith("CalledProcessError in file"):
                console.print("[yellow bold]\t" + output.rstrip().rstrip(":"), overflow = "ignore", crop = False)
                # skip the Command source part (the `output and` guard stops at EOF)
                while output and not output.strip().startswith("["):
                    output = process.stderr.readline()
                console.print("\n[blue]" + output.strip(), overflow = "ignore", crop = False)
            else:
                console.print("[yellow bold]" + output.strip(), overflow = "ignore", crop = False)
            output = process.stderr.readline()
            continue
        if output.rstrip() == "Traceback (most recent call last):":
            # skip the python traceback up to the spawned-job error
            while output and output.rstrip() != "snakemake.exceptions.SpawnedJobError":
                output = process.stderr.readline()
            output = process.stderr.readline()
        if output.strip().startswith("Logfile"):
            _fields = output.split()
            _log = _fields[1] if len(_fields) > 1 else ""
            console.rule(f"[bold]Log File: {_log.rstrip(':')}", style = "yellow")
            # print everything between the two "====" separator lines
            lines = 0
            while lines < 2:
                output = process.stderr.readline()
                if not output:
                    break
                if "====" in output:
                    lines += 1
                    continue
                console.print("[red]" + output, overflow = "ignore", crop = False)
            if "====" in output:
                output = process.stderr.readline()
        if not output.rstrip():
            # a blank line (or EOF) ends the report
            break
        if output.lstrip().startswith("message:") or output.startswith("Finished jobid:") or output.rstrip().endswith(") done") or output.startswith("Removing output files of failed job"):
            output = process.stderr.readline().strip()
            continue
        # handle errors in the python run blocks
        if output.startswith("Exiting because a job execution failed. Look below for"):
            while output and not output.startswith("[") and not output.rstrip().endswith("]"):
                output = process.stderr.readline().lstrip()
            console.print("[blue]" + output.rstrip(), overflow = "ignore", crop = False)
            output = process.stderr.readline().lstrip()
            continue
        if not output.startswith("Complete log"):
            if output.startswith("[") and output.rstrip().endswith("]"):
                output = process.stderr.readline().lstrip()
                continue
            if output.strip().startswith("processing file: ") and output.strip().endswith(".qmd"):
                # make quarto error logs a little nicer by skipping progress
                while output and not output.strip().startswith("Error"):
                    output = process.stderr.readline()
            if output.strip().startswith("Trying to restart job"):
                break
            if output.strip().startswith("shell:"):
                # `or ""` guards against print_shellcmd returning None
                output = print_shellcmd(process.stderr.readline(), process) or ""
            if "(command exited with non-zero" in output or output.startswith("Removing temporary output"):
                output = process.stderr.readline()
                continue
            console.print("[red]" + highlight_params(output), overflow = "ignore", crop = False)
        if output.startswith("Removing output files of failed job"):
            break
        output = process.stderr.readline()


def launch_snakemake(sm_args, workflow, outdir, sm_logfile, quiet, CONSOLE = CONSOLE):
    """launch snakemake with the given commands

    Runs ``sm_args`` (split on whitespace) as a subprocess and parses its stderr
    to drive the harpy display: a spinner while the DAG is built, a pulsing bar
    while software is deployed, then per-rule and overall progress bars while
    jobs run.

    Args:
        sm_args: the full snakemake command line as one string.
        workflow: not used by this function.
        outdir: workflow output directory; used to locate/compress the snakemake log.
        sm_logfile: snakemake log file name, relative to ``outdir``.
        quiet: 0 = print all things, full progressbar; 1 = print all text, only
            the overall progressbar; 2 = print nothing, no progressbar.
        CONSOLE: rich console used for output (defaults to the harpy console).

    Returns None when snakemake exits with return code 0. Otherwise it may print
    a setup or runtime error summary, echoes the relevant part of snakemake's
    stderr and calls ``sys.exit`` with snakemake's return code. On
    KeyboardInterrupt it terminates snakemake, gzips the snakemake log, purges
    empty logs and exits with 1.
    """
    exitcode = None
    # bound before the try so the KeyboardInterrupt handler can tell whether snakemake started
    process = None
    sm_start = datetime.now()
    try:
        # Start snakemake as a subprocess
        process = subprocess.Popen(sm_args.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, text = True)
        deps = False
        # read up to the job summary, but break early if dependency text appears.
        # EXIT_CODE_SUCCESS is 0 (falsy), so "no exit code yet" must be tested with `is None`;
        # a truthiness test re-enters the loop after a successful run and spins at EOF.
        while exitcode is None:
            output = process.stderr.readline()
            if not output:
                # EOF: snakemake closed stderr before reaching the job summary
                exitcode = _exitcode_at_eof(process, EXIT_CODE_GENERIC_ERROR)
                break
            # check for syntax errors at the very beginning
            if process.poll() or iserror(output):
                exitcode = EXIT_CODE_SUCCESS if process.poll() == 0 else EXIT_CODE_GENERIC_ERROR
                exitcode = EXIT_CODE_CONDA_ERROR if "Conda" in output else exitcode
                break
            if quiet < 2:
                with CONSOLE.status("[dim]Preparing workflow", spinner = "point", spinner_style="yellow"):
                    while output.startswith("Building DAG of jobs...") or output.startswith("Assuming"):
                        output = process.stderr.readline()
                if "Nothing to be" in output:
                    exitcode = EXIT_CODE_SUCCESS
                    break
            else:
                while output.startswith("Building DAG of jobs...") or output.startswith("Assuming"):
                    output = process.stderr.readline()
            if process.poll() or iserror(output):
                exitcode = EXIT_CODE_SUCCESS if process.poll() == 0 else EXIT_CODE_GENERIC_ERROR
                break
            while not output.startswith("Job stats:"):
                if not output:
                    exitcode = _exitcode_at_eof(process, EXIT_CODE_GENERIC_ERROR)
                    break
                # print dependency text only once
                if "Downloading and installing remote packages" in output or "Running post-deploy" in output:
                    deps = True
                    deploy_text = "[dim]Installing workflow software"
                    break
                if "Pulling singularity image" in output:
                    deps = True
                    deploy_text = "[dim]Building software container"
                    break
                if "Nothing to be" in output:
                    exitcode = EXIT_CODE_SUCCESS
                    break
                if "MissingInput" in output:
                    exitcode = EXIT_CODE_GENERIC_ERROR
                    break
                if "AmbiguousRuleException" in output or "Error" in output or "Exception" in output:
                    exitcode = EXIT_CODE_RUNTIME_ERROR
                    break
                output = process.stderr.readline()
            # if dependency text present, print pulsing progress bar
            if deps:
                progress = harpy_pulsebar(quiet)
                with harpy_progresspanel(progress, quiet=quiet, title = deploy_text):
                    progress.add_task("[dim]Working...", total = None)
                    while not output.startswith("Job stats:"):
                        output = process.stderr.readline()
                        if not output:
                            exitcode = _exitcode_at_eof(process, EXIT_CODE_CONDA_ERROR)
                            break
                        if process.poll() or iserror(output):
                            exitcode = EXIT_CODE_SUCCESS if process.poll() == 0 else EXIT_CODE_CONDA_ERROR
                            break
                    progress.stop()
            if process.poll() or exitcode is not None:
                break
            if "Nothing to be" in output:
                exitcode = EXIT_CODE_SUCCESS
                break
            progress = harpy_progressbar(quiet)
            with harpy_progresspanel(progress, quiet = quiet):
                # process the job summary: "<rule> <count>" rows, including a "total" row
                job_inventory = {}
                while True:
                    output = process.stderr.readline()
                    if not output:
                        exitcode = _exitcode_at_eof(process, EXIT_CODE_GENERIC_ERROR)
                        break
                    # the summary table ends right before "Select jobs to execute"
                    if output.startswith("Select jobs to execute"):
                        break
                    try:
                        rule,count = output.split()
                        rule_desc = rule.replace("_", " ")
                        # rule : display_name, count_total, set of job_id's
                        job_inventory[rule] = [rule_desc, int(count), set()]
                    except ValueError:
                        # not a "<rule> <count>" row (header, separator, blank line)
                        pass
                    # checkpoint
                    if process.poll() or iserror(output):
                        exitcode = EXIT_CODE_SUCCESS if process.poll() == 0 else EXIT_CODE_GENERIC_ERROR
                        break
                if process.poll() or exitcode is not None:
                    break
                total_text = "[bold blue]Total" if quiet == 0 else "[bold blue]Progress"
                task_ids = {"total_progress" : progress.add_task(total_text, total=job_inventory["total"][1])}

                while output:
                    output = process.stderr.readline()
                    if iserror(output) or process.poll() == 1:
                        progress.stop()
                        exitcode = EXIT_CODE_RUNTIME_ERROR
                        break
                    if process.poll() == 0 or output.startswith("Complete log") or output.startswith("Nothing to be"):
                        progress.stop()
                        exitcode = EXIT_CODE_SUCCESS if process.poll() == 0 else EXIT_CODE_RUNTIME_ERROR
                        break
                    # add new progress bar track if the rule doesn't have one yet
                    if output.lstrip().startswith(("rule ", "localrule ")):
                        # catch the 2nd word and remove the colon
                        rule = output.split()[-1].replace(":", "")
                        if rule not in job_inventory:
                            # not listed in the job summary: nothing to track it against
                            continue
                        # add progress bar if it doesn't exist
                        if rule not in task_ids:
                            task_ids[rule] = progress.add_task(job_inventory[rule][0], total=job_inventory[rule][1], visible = quiet != 1)
                        # parse the rest of the rule block to get the job ID and store it in the
                        # inventory so "Finished jobid" lines can be mapped back to their rule
                        while True:
                            output = process.stderr.readline()
                            if not output:
                                break
                            if "jobid: " in output:
                                job_id = int(output.strip().split()[-1])
                                job_inventory[rule][2].add(job_id)
                                break
                    # check which rule the job is associated with and update the corresponding progress bar
                    if output.startswith("Finished jobid: "):
                        completed = int(re.search(r"\d+", output).group())
                        for job,details in job_inventory.items():
                            if completed in details[2]:
                                progress.advance(task_ids[job])
                                progress.advance(task_ids["total_progress"])
                                if progress.tasks[task_ids[job]].completed == progress.tasks[task_ids[job]].total:
                                    progress.update(task_ids[job], description=f"[dim]{details[0]}")
                                # remove the job to save memory. wont be seen again
                                details[2].discard(completed)
                                break
                if exitcode is None:
                    # stderr reached EOF while jobs were running
                    exitcode = _exitcode_at_eof(process, EXIT_CODE_RUNTIME_ERROR)
        process.wait()
        # only a clean exit is success; a negative return code means snakemake was killed by a signal
        if process.returncode == 0:
            return
        if exitcode in (EXIT_CODE_GENERIC_ERROR, EXIT_CODE_CONDA_ERROR):
            print_setup_error(exitcode)
        elif exitcode == EXIT_CODE_RUNTIME_ERROR:
            print_onerror(os.path.join(os.path.basename(outdir), sm_logfile), datetime.now() - sm_start)
        _print_error_output(process, output, CONSOLE)
        sys.exit(process.returncode)
    except KeyboardInterrupt:
        CONSOLE.print("")
        CONSOLE.rule("[bold]Terminating Harpy", style = "yellow")
        if process is not None:
            process.terminate()
            process.wait()
        gzip_file(os.path.join(outdir,sm_logfile))
        purge_empty_logs(outdir)
        sys.exit(1)
