# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for MemoryMediaFileStorage"""

from __future__ import annotations

import unittest
from unittest import mock
from unittest.mock import MagicMock, mock_open

import pytest
from parameterized import parameterized

from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorageError
from streamlit.runtime.memory_media_file_storage import (
    MemoryFile,
    MemoryMediaFileStorage,
    get_extension_for_mimetype,
)


class MemoryMediaFileStorageTest(unittest.TestCase):
    """Tests for loading, looking up, deleting, URL-building and stats reporting
    in MemoryMediaFileStorage."""

    def setUp(self):
        super().setUp()
        # Fresh, empty storage per test; URLs are built under this endpoint.
        self.storage = MemoryMediaFileStorage(media_endpoint="/mock/media")

    @mock.patch(
        "streamlit.runtime.memory_media_file_storage.open",
        mock_open(read_data=b"mock_bytes"),
    )
    def test_load_with_path(self):
        """Adding a file by path creates a MemoryFile instance."""
        file_id = self.storage.load_and_get_id(
            "mock/file/path",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )
        assert MemoryFile(
            content=b"mock_bytes",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        ) == self.storage.get_file(file_id)

    def test_load_with_bytes(self):
        """Adding a file with bytes creates a MemoryFile instance."""
        file_id = self.storage.load_and_get_id(
            b"mock_bytes",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )
        assert MemoryFile(
            content=b"mock_bytes",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        ) == self.storage.get_file(file_id)

    def test_identical_files_have_same_id(self):
        """Two files with the same content, mimetype, and filename should share an ID."""
        # Create 2 identical files. We'll just get one ID.
        file_id1 = self.storage.load_and_get_id(
            b"mock_bytes",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )
        file_id2 = self.storage.load_and_get_id(
            b"mock_bytes",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )
        assert file_id1 == file_id2

        # Change file content -> different ID
        changed_content = self.storage.load_and_get_id(
            b"mock_bytes_2",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )
        assert file_id1 != changed_content

        # Change mimetype -> different ID
        changed_mimetype = self.storage.load_and_get_id(
            b"mock_bytes",
            mimetype="image/png",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )
        assert file_id1 != changed_mimetype

        # Change (or omit) filename -> different ID
        changed_filename = self.storage.load_and_get_id(
            b"mock_bytes", mimetype="video/mp4", kind=MediaFileKind.MEDIA
        )
        assert file_id1 != changed_filename

    @mock.patch(
        "streamlit.runtime.memory_media_file_storage.open",
        MagicMock(side_effect=Exception),
    )
    def test_load_with_bad_path(self):
        """Adding a file by path raises a MediaFileStorageError if the file can't be read."""
        # `open` is patched to raise, simulating an unreadable path.
        with pytest.raises(MediaFileStorageError):
            self.storage.load_and_get_id(
                "mock/file/path",
                mimetype="video/mp4",
                kind=MediaFileKind.MEDIA,
                filename="file.mp4",
            )

    @parameterized.expand(
        [
            ("video/mp4", ".mp4"),
            ("audio/wav", ".wav"),
            ("image/png", ".png"),
            ("image/jpeg", ".jpg"),
        ]
    )
    def test_get_url(self, mimetype, extension):
        """URLs should be formatted correctly, and have the expected extension."""
        file_id = self.storage.load_and_get_id(
            b"mock_bytes", mimetype=mimetype, kind=MediaFileKind.MEDIA
        )
        url = self.storage.get_url(file_id)
        assert f"/mock/media/{file_id}{extension}" == url

    def test_get_url_invalid_fileid(self):
        """get_url raises if it gets a bad file_id."""
        with pytest.raises(MediaFileStorageError):
            self.storage.get_url("not_a_file_id")

    def test_delete_file(self):
        """delete_file removes the file with the given ID."""
        file_id1 = self.storage.load_and_get_id(
            b"mock_bytes_1",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )
        file_id2 = self.storage.load_and_get_id(
            b"mock_bytes_2",
            mimetype="video/mp4",
            kind=MediaFileKind.MEDIA,
            filename="file.mp4",
        )

        # delete file 1. It should not exist, but file2 should.
        self.storage.delete_file(file_id1)
        with pytest.raises(MediaFileStorageError):
            self.storage.get_file(file_id1)

        assert self.storage.get_file(file_id2) is not None

        # delete file 2
        self.storage.delete_file(file_id2)
        with pytest.raises(MediaFileStorageError):
            self.storage.get_file(file_id2)

    def test_delete_invalid_file_is_a_noop(self):
        """Deleting a file that doesn't exist doesn't raise an error."""
        self.storage.delete_file("mock_file_id")

    def test_cache_stats(self):
        """Test our CacheStatsProvider implementation."""
        assert len(self.storage.get_stats()) == 0

        # Add several files to storage. They share content and mimetype, so we
        # unique-ify them by filename.
        mock_data = b"some random mock binary data"
        num_files = 5
        file_ids = [
            self.storage.load_and_get_id(
                mock_data,
                mimetype="video/mp4",
                kind=MediaFileKind.MEDIA,
                filename=f"{idx}.mp4",
            )
            for idx in range(num_files)
        ]
        # Distinct filenames must produce distinct IDs; otherwise the files
        # would be deduplicated and the byte-length total below would be wrong.
        assert len(set(file_ids)) == num_files

        stats = self.storage.get_stats()
        # All stored files are expected to be reported under a single entry.
        assert len(stats) == 1
        assert stats[0].category_name == "st_memory_media_file_storage"
        assert len(mock_data) * num_files == sum(stat.byte_length for stat in stats)

        # Remove files through the public API (using the IDs we were handed,
        # rather than reaching into private storage state), and ensure our
        # cache doesn't report they still exist.
        for file_id in file_ids:
            self.storage.delete_file(file_id)

        assert len(self.storage.get_stats()) == 0


class MemoryMediaFileStorageUtilTest(unittest.TestCase):
    """Unit tests for utility functions in memory_media_file_storage.py"""

    # (mimetype, expected file extension) pairs fed to the parameterized test.
    _MIMETYPE_EXTENSION_CASES = [
        ("video/mp4", ".mp4"),
        ("audio/wav", ".wav"),
        ("image/png", ".png"),
        ("image/jpeg", ".jpg"),
    ]

    @parameterized.expand(_MIMETYPE_EXTENSION_CASES)
    def test_get_extension_for_mimetype(self, mimetype: str, expected_extension: str):
        # Each mimetype must map to its canonical extension (note "image/jpeg"
        # is expected to map to ".jpg", not ".jpeg").
        assert expected_extension == get_extension_for_mimetype(mimetype)
