"""Document parameters using comments."""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_docments.ipynb.

# %% ../nbs/04_docments.ipynb 2
from __future__ import annotations

import re,ast,inspect
from tokenize import tokenize,COMMENT
from ast import parse,FunctionDef,AsyncFunctionDef,AnnAssign
from io import BytesIO
from textwrap import dedent
from types import SimpleNamespace
from inspect import getsource,isfunction,ismethod,isclass,signature,Parameter
from dataclasses import dataclass, is_dataclass
from .utils import *
from .meta import delegates
from . import docscrape
from textwrap import fill
from inspect import isclass,getdoc

# %% auto 0
__all__ = ['empty', 'docstring', 'parse_docstring', 'isdataclass', 'get_dataclass_source', 'get_source', 'get_name', 'qual_name',
           'docments', 'sig2str', 'extract_docstrings', 'DocmentTbl', 'ShowDocRenderer', 'MarkdownRenderer']

# %% ../nbs/04_docments.ipynb
def docstring(sym):
    "Get docstring for `sym` for functions ad classes"
    if isinstance(sym, str): return sym
    res = getdoc(sym)
    if not res and isclass(sym): res = getdoc(sym.__init__)
    return res or ""

# %% ../nbs/04_docments.ipynb
def parse_docstring(sym):
    "Parse a numpy-style docstring in `sym`"
    return AttrDict(**docscrape.NumpyDocString(docstring(sym)))

# %% ../nbs/04_docments.ipynb
def isdataclass(s):
    "Check if `s` is a dataclass but not a dataclass' instance"
    return is_dataclass(s) and isclass(s)

# %% ../nbs/04_docments.ipynb
def get_dataclass_source(s):
    "Get source code for dataclass `s`"
    return getsource(s) if not getattr(s, "__module__") == '__main__' else ""

# %% ../nbs/04_docments.ipynb
def get_source(s):
    "Get source code for string, function object or dataclass `s`"
    if isinstance(s,str): return s
    return getsource(s) if isfunction(s) or ismethod(s) else get_dataclass_source(s) if isdataclass(s) else None

# %% ../nbs/04_docments.ipynb
def _parses(s):
    "Parse Python code in string, function object or dataclass `s`"
    return parse(dedent(get_source(s) or ''))

def _tokens(s):
    "Tokenize Python code in string or function object `s`"
    s = get_source(s)
    if not s: return []
    return tokenize(BytesIO(s.encode('utf-8')).readline)

_clean_re = re.compile(r'^\s*#(.*)\s*$')
def _clean_comment(s):
    res = _clean_re.findall(s)
    return res[0] if res else None

def _param_locs(s, returns=True, args_kwargs=False):
    "`dict` of parameter line numbers to names"
    body = _parses(s).body
    if len(body)==1:
        defn = body[0]
        if isinstance(defn, (FunctionDef, AsyncFunctionDef)):
            res = {arg.lineno:arg.arg for arg in defn.args.args}
            # Add *args if present
            if defn.args.vararg: res[defn.args.vararg.lineno] = defn.args.vararg.arg
            # Add keyword-only args
            res.update({arg.lineno:arg.arg for arg in defn.args.kwonlyargs})
            # Add **kwargs if present
            if defn.args.kwarg and args_kwargs: res[defn.args.kwarg.lineno] = defn.args.kwarg.arg
            if returns and defn.returns: res[defn.returns.lineno] = 'return'
            return res
        elif isdataclass(s):
            res = {arg.lineno:arg.target.id for arg in defn.body if isinstance(arg, AnnAssign)}
            return res
    return None

# %% ../nbs/04_docments.ipynb
empty = Parameter.empty

# %% ../nbs/04_docments.ipynb
def _get_comment(line, arg, comments, parms):
    if line in comments: return comments[line].strip()
    line -= 1
    res = []
    while line and line in comments and line not in parms:
        res.append(comments[line])
        line -= 1
    return dedent('\n'.join(reversed(res))) if res else None

def _get_full(p, docs):
    anno = p.annotation
    if anno==empty:
        if p.default!=empty: anno = type(p.default)
        elif p.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD): anno = p.kind
    return AttrDict(docment=docs.get(p.name), anno=anno, default=p.default)

# %% ../nbs/04_docments.ipynb
def _merge_doc(dm, npdoc):
    if not npdoc: return dm
    if not dm.anno or dm.anno==empty: dm.anno = npdoc.type
    if not dm.docment: dm.docment = '\n'.join(npdoc.desc)
    return dm

def _merge_docs(dms, npdocs):
    npparams = npdocs['Parameters']
    params = {nm:_merge_doc(dm,npparams.get(nm,None)) for nm,dm in dms.items()}
    if 'return' in dms: params['return'] = _merge_doc(dms['return'], npdocs['Returns'])
    return params

# %% ../nbs/04_docments.ipynb
def _get_property_name(p):
    "Get the name of property `p`"
    if hasattr(p, 'fget'):
        return p.fget.func.__qualname__ if hasattr(p.fget, 'func') else p.fget.__qualname__
    else: return next(iter(re.findall(r'\'(.*)\'', str(p)))).split('.')[-1]

# %% ../nbs/04_docments.ipynb
def get_name(obj):
    "Get the name of `obj`"
    if hasattr(obj, '__name__'):       return obj.__name__
    elif getattr(obj, '_name', False): return obj._name
    elif hasattr(obj,'__origin__'):    return str(obj.__origin__).split('.')[-1] #for types
    elif type(obj)==property:          return _get_property_name(obj)
    else:                              return str(obj).split('.')[-1]

# %% ../nbs/04_docments.ipynb
def qual_name(obj):
    "Get the qualified name of `obj`"
    if hasattr(obj,'__qualname__'): return obj.__qualname__
    if ismethod(obj):       return f"{get_name(obj.__self__)}.{get_name(fn)}"
    return get_name(obj)

# %% ../nbs/04_docments.ipynb
def _docments(s, returns=True, eval_str=False, args_kwargs=False):
    "`dict` of parameter names to 'docment-style' comments in function or string `s`"
    nps = parse_docstring(s)
    if isclass(s) and not is_dataclass(s): s = s.__init__ # Constructor for a class
    comments = {o.start[0]:_clean_comment(o.string) for o in _tokens(s) if o.type==COMMENT}
    parms = _param_locs(s, returns=returns, args_kwargs=args_kwargs) or {}
    docs = {arg:_get_comment(line, arg, comments, parms) for line,arg in parms.items()}

    sig = signature_ex(s, True)
    res = {name:_get_full(p, docs) for name,p in sig.parameters.items()}
    if returns: res['return'] = AttrDict(docment=docs.get('return'), anno=sig.return_annotation, default=empty)
    res = _merge_docs(res, nps)
    if eval_str:
        hints = type_hints(s)
        for k,v in res.items():
            if k in hints: v['anno'] = hints.get(k)
    return res

# %% ../nbs/04_docments.ipynb
@delegates(_docments)
def docments(elt, full=False, args_kwargs=False, **kwargs):
    "Generates a `docment`"
    if full: args_kwargs=True
    r = {}
    params = set(signature(elt).parameters)
    params.add('return')

    def _update_docments(f, r):
        if hasattr(f, '__delwrap__'): _update_docments(f.__delwrap__, r)
        r.update({k:v for k,v in _docments(f, **kwargs).items() if k in params
                  and (v.get('docment', None) or not nested_idx(r, k, 'docment'))})

    _update_docments(elt, r)
    if not full: r = {k:v['docment'] for k,v in r.items()}
    return AttrDict(r)

# %% ../nbs/04_docments.ipynb
def sig2str(func):
    "Generate function signature with docments as comments"
    docs = docments(func, full=True)
    params = []
    for k,v in docs.items():
        if k == 'return': continue
        anno = getattr(v['anno'], '__name__', str(v['anno'])) if v['anno'] != inspect._empty else ''
        if '|' in str(v['anno']): anno = str(v['anno'])
        p = k + (f':{anno}' if anno and anno != 'inspect._empty' else '')
        if v['default'] != inspect._empty:
            d = getattr(v['default'], '__name__', v['default']) if hasattr(v['default'], '__name__') else v['default']
            p += f'={d}' if d is not None else '=None'
        if v['docment']: p += f' # {v["docment"]}'
        params.append(p)
    
    ret = docs.get('return', {})
    ret_str = ':'
    if ret and ret.get('anno')!=inspect._empty:
        ret_str = f"->{getattr(ret['anno'], '__name__', str(ret['anno']))}" + (f': # {ret["docment"]}' if ret.get('docment') else ':')
    doc_str = f'    "{func.__doc__}"' if func.__doc__ else ''
    return f"def {func.__name__}(\n    " + ",\n    ".join(params) + f"\n){ret_str}\n{doc_str}"

# %% ../nbs/04_docments.ipynb
def _get_params(node):
    params = [a.arg for a in node.args.args]
    if node.args.vararg: params.append(f"*{node.args.vararg.arg}")
    if node.args.kwarg: params.append(f"**{node.args.kwarg.arg}")
    return ", ".join(params)

# %% ../nbs/04_docments.ipynb
class _DocstringExtractor(ast.NodeVisitor):
    def __init__(self): self.docstrings,self.cls,self.cls_init = {},None,None

    def visit_FunctionDef(self, node):
        name = node.name
        if name == '__init__':
            self.cls_init = node
            return
        elif name.startswith('_'): return
        elif self.cls: name = f"{self.cls}.{node.name}"
        docs = ast.get_docstring(node)
        params = _get_params(node)
        if docs: self.docstrings[name] = (docs, params)
        self.generic_visit(node)

    def visit_ClassDef(self, node):
        self.cls,self.cls_init = node.name,None
        docs = ast.get_docstring(node)
        if docs: self.docstrings[node.name] = ()
        self.generic_visit(node)
        if not docs and self.cls_init: docs = ast.get_docstring(self.cls_init)
        params = _get_params(self.cls_init) if self.cls_init else ""
        if docs: self.docstrings[node.name] = (docs, params)
        self.cls,self.cls_init = None,None

    def visit_Module(self, node):
        module_doc = ast.get_docstring(node)
        if module_doc: self.docstrings['_module'] = (module_doc, "")
        self.generic_visit(node)

# %% ../nbs/04_docments.ipynb
def extract_docstrings(code):
    "Create a dict from function/class/method names to tuples of docstrings and param lists"
    extractor = _DocstringExtractor()
    extractor.visit(ast.parse(code))
    return extractor.docstrings

# %% ../nbs/04_docments.ipynb
def _non_empty_keys(d:dict): return L([k for k,v in d.items() if v != inspect._empty])
def _bold(s): return f'**{s}**' if s.strip() else s

# %% ../nbs/04_docments.ipynb
def _escape_markdown(s):
    for c in '|^': s = re.sub(rf'\\?\{c}', rf'\{c}', s)
    return s.replace('\n', '<br>')

# %% ../nbs/04_docments.ipynb
def _maybe_nm(o):
    if (o == inspect._empty): return ''
    else: return o.__name__ if hasattr(o, '__name__') else _escape_markdown(str(o))

# %% ../nbs/04_docments.ipynb
def _list2row(l:list): return '| '+' | '.join([_maybe_nm(o) for o in l]) + ' |'

# %% ../nbs/04_docments.ipynb
class DocmentTbl:
    # this is the column order we want these items to appear
    _map = {'anno':'Type', 'default':'Default', 'docment':'Details'}

    def __init__(self, obj, verbose=True, returns=True):
        "Compute the docment table string"
        self.verbose = verbose
        self.returns = False if isdataclass(obj) else returns
        try: self.params = L(signature_ex(obj, eval_str=True).parameters.keys())
        except (ValueError,TypeError): self.params=[]
        try: _dm = docments(obj, full=True, returns=returns)
        except: _dm = {}
        if 'self' in _dm: del _dm['self']
        for d in _dm.values(): d['docment'] = ifnone(d['docment'], inspect._empty)
        self.dm = _dm

    @property
    def _columns(self):
        "Compute the set of fields that have at least one non-empty value so we don't show tables empty columns"
        cols = set(flatten(L(self.dm.values()).filter().map(_non_empty_keys)))
        candidates = self._map if self.verbose else {'docment': 'Details'}
        return {k:v for k,v in candidates.items() if k in cols}

    @property
    def has_docment(self): return 'docment' in self._columns and self._row_list

    @property
    def has_return(self): return self.returns and bool(_non_empty_keys(self.dm.get('return', {})))

    def _row(self, nm, props):
        "unpack data for single row to correspond with column names."
        return [nm] + [props[c] for c in self._columns]

    @property
    def _row_list(self):
        "unpack data for all rows."
        ordered_params = [(p, self.dm[p]) for p in self.params if p != 'self' and p in self.dm]
        return L([self._row(nm, props) for nm,props in ordered_params])

    @property
    def _hdr_list(self): return ['  '] + [_bold(l) for l in L(self._columns.values())]

    @property
    def hdr_str(self):
        "The markdown string for the header portion of the table"
        md = _list2row(self._hdr_list)
        return md + '\n' + _list2row(['-' * len(l) for l in self._hdr_list])

    @property
    def params_str(self):
        "The markdown string for the parameters portion of the table."
        return '\n'.join(self._row_list.map(_list2row))

    @property
    def return_str(self):
        "The markdown string for the returns portion of the table."
        return _list2row(['**Returns**']+[_bold(_maybe_nm(self.dm['return'][c])) for c in self._columns])

    def _repr_markdown_(self):
        if not self.has_docment: return ''
        _tbl = [self.hdr_str, self.params_str]
        if self.has_return: _tbl.append(self.return_str)
        return '\n'.join(_tbl)

    def __eq__(self,other): return self.__str__() == str(other).strip()

    __str__ = _repr_markdown_
    __repr__ = basic_repr()

# %% ../nbs/04_docments.ipynb
def _docstring(sym):
    npdoc = parse_docstring(sym)
    return '\n\n'.join([npdoc['Summary'], npdoc['Extended']]).strip()

# %% ../nbs/04_docments.ipynb
def _fullname(o):
    module,name = getattr(o, "__module__", None),qual_name(o)
    return name if module is None or module in ('__main__','builtins') else module + '.' + name

class ShowDocRenderer:
    def __init__(self, sym, name:str|None=None, title_level:int=3):
        "Show documentation for `sym`"
        sym = getattr(sym, '__wrapped__', sym)
        sym = getattr(sym, 'fget', None) or getattr(sym, 'fset', None) or sym
        store_attr()
        self.nm = name or qual_name(sym)
        self.isfunc = inspect.isfunction(sym)
        try: self.sig = signature_ex(sym, eval_str=True)
        except (ValueError,TypeError): self.sig = None
        self.docs = _docstring(sym)
        self.dm = DocmentTbl(sym)
        self.fn = _fullname(sym)

    __repr__ = basic_repr()

# %% ../nbs/04_docments.ipynb
def _f_name(o): return f'<function {o.__name__}>' if isinstance(o, FunctionType) else None
def _fmt_anno(o): return inspect.formatannotation(o).strip("'").replace(' ','')

def _show_param(param):
    "Like `Parameter.__str__` except removes: quotes in annos, spaces, ids in reprs"
    kind,res,anno,default = param.kind,param._name,param._annotation,param._default
    kind = '*' if kind==inspect._VAR_POSITIONAL else '**' if kind==inspect._VAR_KEYWORD else ''
    res = kind+res
    if anno is not inspect._empty: res += f':{_f_name(anno) or _fmt_anno(anno)}'
    if default is not inspect._empty: res += f'={_f_name(default) or repr(default)}'
    return res

# %% ../nbs/04_docments.ipynb
def _fmt_sig(sig):
    if sig is None: return ''
    p = {k:v for k,v in sig.parameters.items()}
    _params = [_show_param(p[k]) for k in p.keys() if k != 'self']
    return "(" + ', '.join(_params)  + ")"

def _wrap_sig(s):
    "wrap a signature to appear on multiple lines if necessary."
    pad = '> ' + ' ' * 5
    indent = pad + ' ' * (s.find('(') + 1)
    return fill(s, width=80, initial_indent=pad, subsequent_indent=indent)

def _ital_first(s:str):
    "Surround first line with * for markdown italics, preserving leading spaces"
    return re.sub(r'^(\s*)(.+)', r'\1*\2*', s, count=1)

# %% ../nbs/04_docments.ipynb
def _ext_link(url, txt, xtra=""): return f'[{txt}]({url}){{target="_blank" {xtra}}}'

class MarkdownRenderer(ShowDocRenderer):
    "Markdown renderer for `show_doc`"
    def _repr_markdown_(self):
        doc = _wrap_sig(f"{self.nm} {_fmt_sig(self.sig)}") if self.sig else ''
        if self.docs: doc += f"\n\n{_ital_first(self.docs)}"
        if self.dm.has_docment: doc += f"\n\n{self.dm}"
        return doc
    __repr__=__str__=_repr_markdown_
