# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import numpy as np
import pytest

from gluonts.dataset.common import Dataset, ListDataset
from gluonts.model.estimator import Estimator
from gluonts.torch.model.simple_feedforward import SimpleFeedForwardEstimator


@pytest.mark.parametrize(
    "dataset, estimator",
    [
        (
            ListDataset(
                data_iter=[
                    {
                        "item_id": item_id,
                        "target": np.random.normal(
                            loc=100, scale=10, size=100
                        ),
                        "start": "2020-01-01 00:00:00",
                        "info": {"some_key": some_key},
                    }
                    # Item ids are deliberately not in sorted order so the
                    # test can detect forecasts being matched to the wrong
                    # entry.
                    for item_id, some_key in [
                        ("3", [1, 2, 3]),
                        ("2", [2, 3, 4]),
                        ("1", [4, 5, 6]),
                    ]
                ],
                freq="5min",
            ),
            SimpleFeedForwardEstimator(
                prediction_length=4,
                context_length=20,
                trainer_kwargs={"max_epochs": 2},
            ),
        ),
    ],
)
def test_item_id_info(dataset: Dataset, estimator: Estimator):
    """Check that forecasts carry over ``item_id`` and ``info`` from their inputs.

    The estimator is trained on ``dataset``, then used to predict on the same
    dataset. Each forecast is paired positionally with its input entry. When
    the entry has an ``item_id`` or ``info`` field, that value must equal the
    corresponding attribute of the forecast.
    """
    predictor = estimator.train(dataset)

    # Materialize both sides: ``zip`` silently stops at the shorter input, so
    # a predictor yielding too few (or zero) forecasts would otherwise pass
    # without any entry being checked.
    entries = list(dataset)
    forecasts = list(predictor.predict(dataset))
    assert len(forecasts) == len(entries)

    for data_entry, forecast in zip(entries, forecasts):
        if "item_id" in data_entry:
            assert data_entry["item_id"] == forecast.item_id
        if "info" in data_entry:
            assert data_entry["info"] == forecast.info
