# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import inspect
import json
import os
import threading
import unittest
from pathlib import Path
from typing import TYPE_CHECKING, Any
from unittest import mock
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

import streamlit as st
import streamlit.components.v1 as components
from streamlit.components.lib.local_component_registry import LocalComponentRegistry
from streamlit.components.types.base_component_registry import BaseComponentRegistry
from streamlit.components.v1 import component_arrow
from streamlit.components.v1.component_registry import (
    ComponentRegistry,
    _get_module_name,
)
from streamlit.components.v1.custom_component import CustomComponent
from streamlit.errors import DuplicateWidgetID, StreamlitAPIException
from streamlit.proto.Components_pb2 import SpecialArg
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
from streamlit.runtime import Runtime, RuntimeConfig
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import add_script_run_ctx
from streamlit.type_util import to_bytes
from tests.delta_generator_test_case import DeltaGeneratorTestCase
from tests.testutil import create_mock_script_run_ctx

if TYPE_CHECKING:
    from streamlit.components.types.base_custom_component import BaseCustomComponent

# Dummy, unresolvable URL used when declaring URL-based test components.
URL = "http://not.a.real.url:3001"
# Fake relative component directory. Tests that need it to look like a real
# directory mock `os.path.isdir`, so it never has to exist on disk.
PATH = "not/a/real/path"


def _serialize_dataframe_arg(key: str, value: Any) -> SpecialArg:
    """Build the SpecialArg we expect a component to emit for a dataframe arg.

    The dataframe is marshalled into the arg's ``arrow_dataframe.data`` field
    via ``component_arrow.marshall``, mirroring what the component does.
    """
    expected = SpecialArg(key=key)
    component_arrow.marshall(expected.arrow_dataframe.data, value)
    return expected


def _serialize_bytes_arg(key: str, value: Any) -> SpecialArg:
    """Build the SpecialArg we expect a component to emit for a bytes arg.

    The value is converted with ``to_bytes`` and stored in the ``bytes`` field.
    """
    return SpecialArg(key=key, bytes=to_bytes(value))


class DeclareComponentTest(unittest.TestCase):
    """Test component declaration."""

    def setUp(self) -> None:
        config = RuntimeConfig(
            script_path="mock/script/path.py",
            command_line=None,
            component_registry=LocalComponentRegistry(),
            media_file_storage=MemoryMediaFileStorage("/mock/media"),
            uploaded_file_manager=MemoryUploadedFileManager("/mock/upload"),
        )
        self.runtime = Runtime(config)

        # declare_component needs a script_run_ctx to be set
        add_script_run_ctx(threading.current_thread(), create_mock_script_run_ctx())

    def tearDown(self) -> None:
        Runtime._instance = None

    def test_name(self):
        """Test component name generation"""
        # Test a component defined in a module with no package
        component = components.declare_component("foo", url=URL)
        self.assertEqual("tests.streamlit.components_test.foo", component.name)

        # Test a component defined in __init__.py
        from tests.streamlit.component_test_data import component as init_component

        self.assertEqual(
            "tests.streamlit.component_test_data.foo",
            init_component.name,
        )

        # Test a component defined in a module within a package
        from tests.streamlit.component_test_data.outer_module import (
            component as outer_module_component,
        )

        self.assertEqual(
            "tests.streamlit.component_test_data.outer_module.foo",
            outer_module_component.name,
        )

        # Test a component defined in module within a nested package
        from tests.streamlit.component_test_data.nested.inner_module import (
            component as inner_module_component,
        )

        self.assertEqual(
            "tests.streamlit.component_test_data.nested.inner_module.foo",
            inner_module_component.name,
        )

    def test_only_path_str(self):
        """Succeed when a path is provided via str."""

        def isdir(path):
            return path == PATH or path == os.path.abspath(PATH)

        with mock.patch(
            "streamlit.components.v1.component_registry.os.path.isdir",
            side_effect=isdir,
        ):
            component = components.declare_component("test", path=PATH)

        self.assertEqual(PATH, component.path)
        self.assertIsNone(component.url)

        self.assertEqual(
            ComponentRegistry.instance().get_component_path(component.name),
            component.abspath,
        )

    def test_only_path_pathlib(self):
        """Succeed when a path is provided via Path."""

        def isdir(path):
            return path == PATH or path == os.path.abspath(PATH)

        with mock.patch(
            "streamlit.components.v1.component_registry.os.path.isdir",
            side_effect=isdir,
        ):
            component = components.declare_component("test", path=Path(PATH))

        self.assertEqual(PATH, component.path)
        self.assertIsNone(component.url)

        self.assertEqual(
            ComponentRegistry.instance().get_component_path(component.name),
            component.abspath,
        )

    def test_only_url(self):
        """Succeed when a URL is provided."""
        component = components.declare_component("test", url=URL)
        self.assertEqual(URL, component.url)
        self.assertIsNone(component.path)

        # A URL-based component has no filesystem path. Look it up by its real
        # registered name: querying an unregistered name would also return
        # None and make this assertion vacuous.
        self.assertIsNone(component.abspath)
        self.assertEqual(
            ComponentRegistry.instance().get_component_path(component.name),
            component.abspath,
        )

    def test_path_and_url(self):
        """Fail if path AND url are provided."""
        with pytest.raises(StreamlitAPIException) as exception_message:
            components.declare_component("test", path=PATH, url=URL)
        self.assertEqual(
            "Either 'path' or 'url' must be set, but not both.",
            str(exception_message.value),
        )

    def test_no_path_and_no_url(self):
        """Fail if neither path nor url is provided."""
        with pytest.raises(StreamlitAPIException) as exception_message:
            components.declare_component("test", path=None, url=None)
        self.assertEqual(
            "Either 'path' or 'url' must be set, but not both.",
            str(exception_message.value),
        )

    def test_module_name_not_none(self):
        """The registry records the declaring module's name for a component."""
        caller_frame = inspect.currentframe()
        self.assertIsNotNone(caller_frame)
        module_name = _get_module_name(caller_frame=caller_frame)
        # The test's name promises a non-None module name; check it explicitly
        # so a None == None comparison below can't pass vacuously.
        self.assertIsNotNone(module_name)

        component = components.declare_component("test", url=URL)
        self.assertEqual(
            ComponentRegistry.instance().get_module_name(component.name),
            module_name,
        )

    def test_get_registered_components(self):
        component1 = components.declare_component("test1", url=URL)
        component2 = components.declare_component("test2", url=URL)
        component3 = components.declare_component("test3", url=URL)
        expected_registered_component_names = {
            component1.name,
            component2.name,
            component3.name,
        }

        registered_components = ComponentRegistry.instance().get_components()
        self.assertEqual(
            len(registered_components),
            3,
        )
        registered_component_names = {
            component.name for component in registered_components
        }
        self.assertSetEqual(
            registered_component_names, expected_registered_component_names
        )

    def test_when_registry_not_explicitly_initialized_return_defaultregistry(self):
        ComponentRegistry._instance = None
        components.declare_component("test", url=URL)
        self.assertIsInstance(ComponentRegistry.instance(), LocalComponentRegistry)


class ComponentRegistryTest(unittest.TestCase):
    """Tests for registering components with the runtime's component registry."""

    def setUp(self) -> None:
        # Every test gets a fresh Runtime backed by a fresh LocalComponentRegistry.
        self.runtime = Runtime(
            RuntimeConfig(
                script_path="mock/script/path.py",
                command_line=None,
                component_registry=LocalComponentRegistry(),
                media_file_storage=MemoryMediaFileStorage("/mock/media"),
                uploaded_file_manager=MemoryUploadedFileManager("/mock/upload"),
            )
        )

    def tearDown(self) -> None:
        Runtime._instance = None

    def test_register_component_with_path(self):
        """A component registered with a directory reports that directory."""
        component_dir = "/a/test/component/directory"
        registry = ComponentRegistry.instance()

        with mock.patch(
            "streamlit.components.types.base_custom_component.os.path.isdir",
            side_effect=lambda candidate: candidate == component_dir,
        ):
            registry.register_component(
                CustomComponent("test_component", path=component_dir)
            )

        self.assertEqual(component_dir, registry.get_component_path("test_component"))

    def test_register_component_no_path(self):
        """Registering a URL-only component is allowed; it simply has no path."""
        registry = ComponentRegistry.instance()
        name = "test_component"

        # An unknown component reports no path...
        self.assertIsNone(registry.get_component_path(name))

        # ...and so does a component that is served from a URL.
        url_component = CustomComponent(name, url="http://not.a.url")
        registry.register_component(url_component)
        self.assertIsNone(registry.get_component_path(name))

    def test_register_invalid_path(self):
        """Registering a component whose directory does not exist raises."""
        missing_dir = "/a/test/component/directory"
        registry = ComponentRegistry.instance()

        with self.assertRaises(StreamlitAPIException) as raised:
            registry.register_component(CustomComponent("test_component", missing_dir))

        self.assertIn("No such component directory", str(raised.exception))

    def test_register_duplicate_path(self):
        """Re-registering a component under the same name is allowed, and the
        latest registration's path is the one reported.
        (This can happen during development).
        """
        first_dir = "/a/test/component/directory"
        second_dir = "/another/test/component/directory"
        known_dirs = (first_dir, second_dir)
        registry = ComponentRegistry.instance()

        with mock.patch(
            "streamlit.components.types.base_custom_component.os.path.isdir",
            side_effect=lambda candidate: candidate in known_dirs,
        ):
            # Registering the identical component twice must not fail.
            for _ in range(2):
                registry.register_component(
                    CustomComponent("test_component", first_dir)
                )
            self.assertEqual(first_dir, registry.get_component_path("test_component"))

            # A later registration with a new directory replaces the old one.
            registry.register_component(CustomComponent("test_component", second_dir))
            self.assertEqual(
                second_dir, registry.get_component_path("test_component")
            )


class InvokeComponentTest(DeltaGeneratorTestCase):
    """Test invocation of a custom component object."""

    def setUp(self):
        super().setUp()
        self.test_component = components.declare_component("test", url=URL)

    def test_only_json_args(self):
        """Test that component with only json args is marshalled correctly."""
        self.test_component(foo="bar")
        proto = self.get_delta_from_queue().new_element.component_instance

        self.assertEqual(self.test_component.name, proto.component_name)
        self.assertJSONEqual(
            {"foo": "bar", "key": None, "default": None}, proto.json_args
        )
        self.assertEqual("[]", str(proto.special_args))

    def test_only_df_args(self):
        """Test that component with only dataframe args is marshalled correctly."""
        raw_data = {
            "First Name": ["Jason", "Molly"],
            "Last Name": ["Miller", "Jacobson"],
            "Age": [42, 52],
        }
        df = pd.DataFrame(raw_data, columns=["First Name", "Last Name", "Age"])
        self.test_component(df=df)
        proto = self.get_delta_from_queue().new_element.component_instance

        self.assertEqual(self.test_component.name, proto.component_name)
        self.assertJSONEqual({"key": None, "default": None}, proto.json_args)
        self.assertEqual(1, len(proto.special_args))
        self.assertEqual(_serialize_dataframe_arg("df", df), proto.special_args[0])

    def test_only_list_args(self):
        """Test that component with only list args is marshalled correctly."""
        self.test_component(data=["foo", "bar", "baz"])
        proto = self.get_delta_from_queue().new_element.component_instance

        self.assertEqual(self.test_component.name, proto.component_name)
        self.assertJSONEqual(
            {"data": ["foo", "bar", "baz"], "key": None, "default": None},
            proto.json_args,
        )
        self.assertEqual("[]", str(proto.special_args))

    def test_no_args(self):
        """Test that component with no args is marshalled correctly."""
        self.test_component()
        proto = self.get_delta_from_queue().new_element.component_instance

        self.assertEqual(self.test_component.name, proto.component_name)
        self.assertJSONEqual({"key": None, "default": None}, proto.json_args)
        self.assertEqual("[]", str(proto.special_args))

    def test_bytes_args(self):
        """Test that bytes args are sent as special args, in call order."""
        self.test_component(foo=b"foo", bar=b"bar")
        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual({"key": None, "default": None}, proto.json_args)
        self.assertEqual(2, len(proto.special_args))
        self.assertEqual(
            _serialize_bytes_arg("foo", b"foo"),
            proto.special_args[0],
        )
        self.assertEqual(
            _serialize_bytes_arg("bar", b"bar"),
            proto.special_args[1],
        )

    def test_mixed_args(self):
        """Test marshalling of a component with varied arg types."""
        df = pd.DataFrame(
            {
                "First Name": ["Jason", "Molly"],
                "Last Name": ["Miller", "Jacobson"],
                "Age": [42, 52],
            },
            columns=["First Name", "Last Name", "Age"],
        )
        self.test_component(string_arg="string", df_arg=df, bytes_arg=b"bytes")
        proto = self.get_delta_from_queue().new_element.component_instance

        self.assertEqual(self.test_component.name, proto.component_name)
        self.assertJSONEqual(
            {"string_arg": "string", "key": None, "default": None},
            proto.json_args,
        )
        self.assertEqual(2, len(proto.special_args))
        self.assertEqual(_serialize_dataframe_arg("df_arg", df), proto.special_args[0])
        self.assertEqual(
            _serialize_bytes_arg("bytes_arg", b"bytes"), proto.special_args[1]
        )

    def test_duplicate_key(self):
        """Two components with the same `key` should throw DuplicateWidgetID exception"""
        self.test_component(foo="bar", key="baz")

        with self.assertRaises(DuplicateWidgetID):
            self.test_component(key="baz")

    def test_key_sent_to_frontend(self):
        """We send the 'key' param to the frontend (even if it's None)."""
        # Test a string key
        self.test_component(key="baz")
        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual({"key": "baz", "default": None}, proto.json_args)

        # Test an empty key
        self.test_component()
        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual({"key": None, "default": None}, proto.json_args)

    def test_widget_id_with_key(self):
        """UNLIKE OTHER WIDGET TYPES, a component with a user-supplied `key` will have a stable widget ID
        even when the component's other parameters change.

        This is important because a component's iframe gets unmounted and remounted - wiping all its
        internal state - when the component's ID changes. We want to be able to pass new data to a
        component's frontend without causing a remount.
        """

        # Create a component instance with a key and some custom data
        self.test_component(key="key", some_data=345)
        proto1 = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual(
            {"key": "key", "default": None, "some_data": 345}, proto1.json_args
        )

        # Clear some ScriptRunCtx data so that we can re-register the same component
        # without getting a DuplicateWidgetID error
        self.script_run_ctx.widget_user_keys_this_run.clear()
        self.script_run_ctx.widget_ids_this_run.clear()

        # Create a second component instance with the same key, and different custom data
        self.test_component(key="key", some_data=678, more_data="foo")
        proto2 = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual(
            {"key": "key", "default": None, "some_data": 678, "more_data": "foo"},
            proto2.json_args,
        )

        # The two component instances should have the same ID, *despite having different
        # data passed to them.*
        self.assertEqual(proto1.id, proto2.id)

    def test_widget_id_without_key(self):
        """Like all other widget types, two component instances with different data parameters,
        and without a specified `key`, will have different widget IDs.
        """

        # Create a component instance without a key and some custom data
        self.test_component(some_data=345)
        proto1 = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual(
            {"key": None, "default": None, "some_data": 345}, proto1.json_args
        )

        # Create a second component instance with different custom data
        self.test_component(some_data=678)
        proto2 = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual(
            {"key": None, "default": None, "some_data": 678}, proto2.json_args
        )

        # The two component instances should have different IDs (just like any other widget would).
        self.assertNotEqual(proto1.id, proto2.id)

    def test_simple_default(self):
        """Test the 'default' param with a JSON value."""
        return_value = self.test_component(default="baz")
        self.assertEqual("baz", return_value)

        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual({"key": None, "default": "baz"}, proto.json_args)

    def test_bytes_default(self):
        """Test the 'default' param with a bytes value."""
        return_value = self.test_component(default=b"bytes")
        self.assertEqual(b"bytes", return_value)

        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual({"key": None}, proto.json_args)
        # The bytes default is the only special arg; nothing else may leak in.
        self.assertEqual(1, len(proto.special_args))
        self.assertEqual(
            _serialize_bytes_arg("default", b"bytes"),
            proto.special_args[0],
        )

    def test_df_default(self):
        """Test the 'default' param with a DataFrame value."""
        df = pd.DataFrame(
            {
                "First Name": ["Jason", "Molly"],
                "Last Name": ["Miller", "Jacobson"],
                "Age": [42, 52],
            },
            columns=["First Name", "Last Name", "Age"],
        )
        return_value = self.test_component(default=df)
        self.assertTrue(df.equals(return_value), "df != return_value")

        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual({"key": None}, proto.json_args)
        # The dataframe default is the only special arg; nothing else may leak in.
        self.assertEqual(1, len(proto.special_args))
        self.assertEqual(
            _serialize_dataframe_arg("default", df),
            proto.special_args[0],
        )

    def test_on_change_handler(self):
        """Test the 'on_change' callback param."""

        # we use a list here so that we can update it in the lambda; we cannot assign a variable there.
        callback_call_value = []
        expected_element_value = "Called with foo"

        def create_on_change_handler(some_arg: str):
            return lambda: callback_call_value.append("Called with " + some_arg)

        return_value = self.test_component(
            key="key", default="baz", on_change=create_on_change_handler("foo")
        )
        self.assertEqual("baz", return_value)

        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertJSONEqual({"key": "key", "default": "baz"}, proto.json_args)
        current_widget_states = self.script_run_ctx.session_state.get_widget_states()
        new_widget_state = WidgetState()
        # copy the custom components state and update the value
        new_widget_state.CopyFrom(current_widget_states[0])
        # update the widget's value so that the rerun will execute the callback
        new_widget_state.json_value = '{"key": "key", "default": "baz2"}'
        self.script_run_ctx.session_state.on_script_will_rerun(
            WidgetStates(widgets=[new_widget_state])
        )
        # Compare the whole list: this fails clearly (rather than with an
        # IndexError) if the callback never ran, and also catches a callback
        # that ran more than once.
        self.assertEqual([expected_element_value], callback_call_value)

    def assertJSONEqual(self, a, b):
        """Asserts that two JSON dicts are equal. If either arg is a string,
        it will be first converted to a dict with json.loads()."""
        # Ensure both objects are dicts.
        dict_a = a if isinstance(a, dict) else json.loads(a)
        dict_b = b if isinstance(b, dict) else json.loads(b)
        self.assertEqual(dict_a, dict_b)

    def test_outside_form(self):
        """Test that form id is marshalled correctly outside of a form."""

        self.test_component()

        proto = self.get_delta_from_queue().new_element.component_instance
        self.assertEqual(proto.form_id, "")

    @patch("streamlit.runtime.Runtime.exists", MagicMock(return_value=True))
    def test_inside_form(self):
        """Test that form id is marshalled correctly inside of a form."""

        with st.form("foo"):
            self.test_component()

        # 2 elements will be created: form block, widget
        self.assertEqual(len(self.get_all_deltas_from_queue()), 2)

        form_proto = self.get_delta_from_queue(0).add_block
        component_instance_proto = self.get_delta_from_queue(
            1
        ).new_element.component_instance
        self.assertEqual(component_instance_proto.form_id, form_proto.form.form_id)


class IFrameTest(DeltaGeneratorTestCase):
    """Tests for the `components.iframe` and `components.html` helpers."""

    def _assert_last_iframe(self, *, src: str, srcdoc: str, width: int) -> None:
        """Check the most recently queued IFrame element against expectations.

        Both helpers under test are called with `scrolling=True` and an
        explicit width, so those flags are always expected to be set.
        """
        iframe_proto = self.get_delta_from_queue().new_element.iframe
        self.assertEqual(iframe_proto.src, src)
        self.assertEqual(iframe_proto.srcdoc, srcdoc)
        self.assertEqual(iframe_proto.width, width)
        self.assertTrue(iframe_proto.has_width)
        self.assertTrue(iframe_proto.scrolling)

    def test_iframe(self):
        """Test components.iframe"""
        target_url = "http://not.a.url"
        components.iframe(target_url, width=200, scrolling=True)

        # A URL-based iframe sets `src` and leaves `srcdoc` empty.
        self._assert_last_iframe(src=target_url, srcdoc="", width=200)

    def test_html(self):
        """Test components.html"""
        markup = r"<html><body>An HTML string!</body></html>"
        components.html(markup, width=200, scrolling=True)

        # An HTML-based iframe sets `srcdoc` and leaves `src` empty.
        self._assert_last_iframe(src="", srcdoc=markup, width=200)


class AlternativeComponentRegistryTest(unittest.TestCase):
    """Test alternative component registry initialization.

    A Runtime configured with a custom BaseComponentRegistry implementation
    must expose that exact registry through ``ComponentRegistry.instance()``.
    """

    class AlternativeComponentRegistry(BaseComponentRegistry):
        """Minimal no-op registry used to check that custom registries are honored."""

        def __init__(self):
            """Dummy implementation"""
            pass

        def register_component(self, component: BaseCustomComponent) -> None:
            return None

        def get_component_path(self, name: str) -> str | None:
            return None

        def get_module_name(self, name: str) -> str | None:
            return None

        def get_component(self, name: str) -> BaseCustomComponent | None:
            return None

        def get_components(self) -> list[BaseCustomComponent]:
            return []

    def setUp(self) -> None:
        super().setUp()
        self.registry = AlternativeComponentRegistryTest.AlternativeComponentRegistry()
        # Install the alternative registry the same way the other test classes
        # install a LocalComponentRegistry: through the RuntimeConfig.
        config = RuntimeConfig(
            script_path="mock/script/path.py",
            command_line=None,
            component_registry=self.registry,
            media_file_storage=MemoryMediaFileStorage("/mock/media"),
            uploaded_file_manager=MemoryUploadedFileManager("/mock/upload"),
        )
        self.runtime = Runtime(config)

    def tearDown(self) -> None:
        Runtime._instance = None
        super().tearDown()

    def test_instance_returns_configured_registry(self):
        """ComponentRegistry.instance() returns the registry given to the Runtime."""
        registry = ComponentRegistry.instance()
        self.assertIs(registry, self.registry)
        self.assertIsInstance(
            registry, AlternativeComponentRegistryTest.AlternativeComponentRegistry
        )
