# -*- coding: utf-8 -*-

"""Utilities for easier test parameterization."""

from __future__ import annotations

# Only the `cases` decorator is public; `Case` and `FuncT` below are typing helpers for it.
__all__ = ['cases']

import os
import traceback
import typing

import pytest
import yaml


# Type of the decorated test function; `cases` is annotated to return a decorator that gives back the same type.
FuncT = typing.TypeVar('FuncT', bound=typing.Callable[..., typing.Any])


class _CaseBase(typing.TypedDict):
    """Keys every test case must define."""

    values: typing.Mapping[str, typing.Any]  # argument name -> argument value


class Case(_CaseBase, total=False):
    """
    Common structure for a test case.

    ``values`` is required, while ``id`` is optional: a case without it is identified by its position in the file.
    """

    id: str  # pylint: disable=invalid-name


def cases(name: str, *, root: typing.Optional[str] = None) -> typing.Callable[[FuncT], FuncT]:
    """
    Fetch cases from a file.

    Works just like :code:`@pytest.mark.parametrize`, but fetches all cases from a file.

    The file might be in YAML or JSON format (both are read with :func:`yaml.safe_load`). It must contain a list
    of cases; each case is a mapping with an optional ``id`` and a required ``values`` mapping of argument name
    to argument value:

    .. code-block:: yaml

        - id: optional_id_for_first_case
          values:
            field_1: value
            field_2: value

        - values:
            field_2: value
            field_3: value

    Argument names are the union of all ``values`` keys, sorted alphabetically; a case that omits an argument
    gets ``None`` for it. A case without an ``id`` is identified by its 1-based position padded to four digits
    (``0001``, ``0002``, ...); a non-string ``id`` (e.g. a number) is converted with :func:`str`.

    The only required parameter is ``name`` of the file with test cases. The extension might be omitted, in
    which case the single file named ``<name>.<extension>`` is used. ``name`` may also contain sub-folders
    relative to ``root``.

    By default, cases are looked up in a ``cases`` folder alongside the file with the decorated test. If the
    cases are located elsewhere, a custom ``root`` folder might be provided; when ``root`` is a file, its folder
    is used instead.

    :param name: case file name relative to ``root``, with or without extension.
    :param root: folder with case files (or any file inside it); defaults to ``cases`` next to the caller's file.
    :return: a :code:`pytest.mark.parametrize` decorator for the test function.
    :raises ValueError: if the file cannot be resolved unambiguously, holds no cases, or is malformed.
    """
    if root is None:
        # limit=2 keeps only the caller's frame and this one; [0] is the code that applied the decorator.
        # This must stay directly inside `cases`: moving it into a helper would shift the frame index.
        root = os.path.join(os.path.dirname(traceback.extract_stack(limit=2)[0].filename), 'cases')
    elif os.path.isfile(root):
        root = os.path.dirname(root)

    path: str = _resolve_cases_path(root, name)

    stream: typing.TextIO
    with open(path, 'rt', encoding='utf-8') as stream:
        content: typing.Sequence['Case'] = yaml.safe_load(stream)

    if not content:
        raise ValueError('no cases found')
    _validate_cases(path, content)

    fields: typing.Sequence[str] = sorted({key for case in content for key in case['values']})

    return pytest.mark.parametrize(','.join(fields), [
        pytest.param(*(case['values'].get(field) for field in fields), id=_case_id(case, index))
        for index, case in enumerate(content, start=1)
    ])


def _resolve_cases_path(root: str, name: str) -> str:
    """Return the path of case file ``name`` under ``root``, guessing the extension when it was omitted."""
    path = os.path.join(root, name)
    if os.path.isfile(path):
        return path

    # Extension omitted: look for exactly one *file* named `<name>.<extension>` in the target folder
    # (which may be a sub-folder of `root` when `name` contains a path separator).
    folder, stem = os.path.split(path)
    prefix = stem + '.'
    candidates = sorted(
        item for item in os.listdir(folder)
        if item.startswith(prefix) and os.path.isfile(os.path.join(folder, item))
    )
    if not candidates:
        raise ValueError(f'no case file {name!r} found in {folder!r}')
    if len(candidates) > 1:
        raise ValueError(f'ambiguous case file {name!r} in {folder!r}: {", ".join(candidates)}')
    return os.path.join(folder, candidates[0])


def _validate_cases(path: str, content: typing.Any) -> None:
    """Raise ``ValueError`` unless ``content`` is a list of mappings, each holding a ``values`` mapping."""
    if not isinstance(content, list):
        raise ValueError(f'{path}: expected a list of cases, got {type(content).__name__}')
    for index, case in enumerate(content, start=1):
        # Rejects e.g. a list of single-key mappings under `values`, which the code cannot index by field name.
        if not isinstance(case, dict) or not isinstance(case.get('values'), dict):
            raise ValueError(f'{path}: case #{index} must be a mapping with a "values" mapping of field -> value')


def _case_id(case: typing.Mapping[str, typing.Any], index: int) -> typing.Optional[str]:
    """Return the pytest id for ``case``: its ``id`` (as a string), or ``index`` zero-padded when absent."""
    case_id = case.get('id', f'{index:04d}')
    # pytest.param() accepts only str or None (auto-generated id), but YAML parses e.g. `id: 7` as an int.
    if case_id is None or isinstance(case_id, str):
        return case_id
    return str(case_id)
