# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Server.py unit tests"""

from __future__ import annotations

import asyncio
import contextlib
import errno
import os
import re
import subprocess
import sys
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from unittest.mock import patch

import pytest
import tornado.httpserver
import tornado.testing
import tornado.web
import tornado.websocket
from parameterized import parameterized

import streamlit.web.server.server
from streamlit import config
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime import Runtime, RuntimeState
from streamlit.web.server.server import (
    MAX_PORT_SEARCH_RETRIES,
    RetriesExceededError,
    Server,
    start_listening,
)
from tests.streamlit.message_mocks import create_dataframe_msg
from tests.streamlit.web.server.server_test_case import ServerTestCase
from tests.testutil import patch_config_options

# Module-level logger, named after this test module.
LOGGER = get_logger(__name__)


def _create_script_finished_msg(status) -> ForwardMsg:
    """Build a ForwardMsg whose ``script_finished`` field is set to *status*.

    Parameters
    ----------
    status
        Value assigned verbatim to ``ForwardMsg.script_finished``.

    Returns
    -------
    ForwardMsg
        A freshly created message carrying only the given status.
    """
    finished_msg = ForwardMsg()
    finished_msg.script_finished = status
    return finished_msg


class ServerTest(ServerTestCase):
    """Tests for Server start/stop, websocket session handling and the
    websocket-related Tornado settings.

    Several tests overwrite ``server.enableWebsocketCompression`` and
    ``server.websocketPingInterval``. Both options are snapshotted in
    ``setUp`` and restored in ``tearDown`` so that a failing assertion in one
    test cannot leak a modified config value into the tests that run after it.
    """

    def setUp(self) -> None:
        # Snapshot every config option the tests in this class mutate.
        self.original_ws_compression = config.get_option(
            "server.enableWebsocketCompression"
        )
        self.original_ws_ping_interval = config.get_option(
            "server.websocketPingInterval"
        )
        return super().setUp()

    def tearDown(self):
        # Restore unconditionally: tearDown also runs when the test body failed,
        # which an in-test "reset" statement at the end of the body would not.
        config.set_option(
            "server.enableWebsocketCompression", self.original_ws_compression
        )
        config.set_option(
            "server.websocketPingInterval", self.original_ws_ping_interval
        )
        return super().tearDown()

    @tornado.testing.gen_test
    async def test_start_stop(self):
        """Test that we can start and stop the server."""
        with self._patch_app_session():
            await self.server.start()
            assert self.server._runtime._state == RuntimeState.NO_SESSIONS_CONNECTED

            await self.ws_connect()
            assert (
                self.server._runtime._state
                == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED
            )

            self.server.stop()
            await asyncio.sleep(0)  # Wait a tick for the stop to be acknowledged
            assert self.server._runtime._state == RuntimeState.STOPPING

            await asyncio.sleep(0.1)
            assert self.server._runtime._state == RuntimeState.STOPPED

    @tornado.testing.gen_test
    async def test_websocket_connect(self):
        """Test that we can connect to the server via websocket."""
        with self._patch_app_session():
            await self.server.start()

            assert not self.server.browser_is_connected

            # Open a websocket connection
            ws_client = await self.ws_connect()
            assert self.server.browser_is_connected

            # Get this client's SessionInfo object
            assert self.server._runtime._session_mgr.num_active_sessions() == 1
            session_info = self.server._runtime._session_mgr.list_active_sessions()[0]

            # Close the connection
            ws_client.close()
            await asyncio.sleep(0.1)
            assert not self.server.browser_is_connected

            # Ensure AppSession.disconnect_file_watchers() was called, and that our
            # session exists but is no longer active.
            session_info.session.disconnect_file_watchers.assert_called_once()
            assert self.server._runtime._session_mgr.num_active_sessions() == 0
            assert self.server._runtime._session_mgr.num_sessions() == 1

    @tornado.testing.gen_test
    async def test_websocket_connect_to_nonexistent_session(self):
        """Connecting with an unknown session id yields a session with a new id."""
        with self._patch_app_session():
            await self.server.start()

            ws_client = await self.ws_connect(existing_session_id="nonexistent_session")

            session_info = self.server._runtime._session_mgr.list_active_sessions()[0]

            assert session_info.session.id != "nonexistent_session"

            ws_client.close()
            await asyncio.sleep(0.1)

    @tornado.testing.gen_test
    async def test_websocket_disconnect_and_reconnect(self):
        """Reconnecting with a known session id reuses the original session."""
        with self._patch_app_session():
            await self.server.start()

            ws_client = await self.ws_connect()
            original_session_info = (
                self.server._runtime._session_mgr.list_active_sessions()[0]
            )

            # Disconnect, reconnect with the same session_id, and confirm that the
            # session was reused.
            ws_client.close()
            await asyncio.sleep(0.1)

            ws_client = await self.ws_connect(
                existing_session_id=original_session_info.session.id
            )

            assert self.server._runtime._session_mgr.num_active_sessions() == 1
            new_session_info = self.server._runtime._session_mgr.list_active_sessions()[
                0
            ]
            assert new_session_info.session == original_session_info.session

            ws_client.close()
            await asyncio.sleep(0.1)

    @tornado.testing.gen_test
    async def test_multiple_connections(self):
        """Test multiple websockets can connect simultaneously."""
        with self._patch_app_session():
            await self.server.start()

            assert not self.server.browser_is_connected

            # Open a websocket connection
            ws_client1 = await self.ws_connect()
            assert self.server.browser_is_connected

            # Open another
            ws_client2 = await self.ws_connect()
            assert self.server.browser_is_connected

            # Assert that our session_infos are sane
            session_infos = self.server._runtime._session_mgr.list_active_sessions()
            assert len(session_infos) == 2
            assert session_infos[0].session.id != session_infos[1].session.id

            # Close the first
            ws_client1.close()
            await asyncio.sleep(0.1)
            assert self.server.browser_is_connected

            # Close the second
            ws_client2.close()
            await asyncio.sleep(0.1)
            assert not self.server.browser_is_connected

    @tornado.testing.gen_test
    async def test_websocket_compression(self):
        """With compression enabled the server negotiates permessage-deflate."""
        with self._patch_app_session():
            config._set_option("server.enableWebsocketCompression", True, "test")
            await self.server.start()

            # Connect to the server, and explicitly request compression.
            ws_client = await tornado.websocket.websocket_connect(
                self.get_ws_url("/_stcore/stream"), compression_options={}
            )

            # Ensure that the "permessage-deflate" extension is returned
            # from the server.
            extensions = ws_client.headers.get("Sec-Websocket-Extensions")
            # Fail with a clear assertion (not a TypeError from `in None`) when
            # the header is missing altogether.
            assert extensions is not None
            assert "permessage-deflate" in extensions

    @tornado.testing.gen_test
    async def test_websocket_compression_disabled(self):
        """With compression disabled the server returns no extensions header."""
        with self._patch_app_session():
            config._set_option("server.enableWebsocketCompression", False, "test")
            await self.server.start()

            # Connect to the server, and explicitly request compression.
            ws_client = await tornado.websocket.websocket_connect(
                self.get_ws_url("/_stcore/stream"), compression_options={}
            )

            # Ensure that the "Sec-Websocket-Extensions" header is not
            # present in the response from the server.
            assert ws_client.headers.get("Sec-Websocket-Extensions") is None

    @tornado.testing.gen_test
    async def test_send_message_to_disconnected_websocket(self):
        """Sending a message to a disconnected SessionClient raises an error.
        We should gracefully handle the error by cleaning up the session.
        """
        with self._patch_app_session():
            await self.server.start()
            await self.ws_connect()

            # Get the server's socket and session for this client
            session_info = self.server._runtime._session_mgr.list_active_sessions()[0]

            with (
                patch.object(
                    session_info.session, "flush_browser_queue"
                ) as flush_browser_queue,
                patch.object(session_info.client, "write_message") as ws_write_message,
            ):
                # Patch flush_browser_queue to simulate a pending message.
                flush_browser_queue.return_value = [create_dataframe_msg([1, 2, 3])]

                # Patch the session's WebsocketHandler to raise a
                # WebSocketClosedError when we write to it.
                ws_write_message.side_effect = tornado.websocket.WebSocketClosedError()

                # Tick the server. Our session's browser_queue will be flushed,
                # and the Websocket client's write_message will be called,
                # raising our WebSocketClosedError.
                while not flush_browser_queue.called:
                    self.server._runtime._get_async_objs().need_send_data.set()
                    await asyncio.sleep(0)

                flush_browser_queue.assert_called_once()
                ws_write_message.assert_called_once()

                # Our session should have been removed from the server as
                # a result of the WebSocketClosedError.
                assert (
                    self.server._runtime._session_mgr.get_active_session_info(
                        session_info.session.id
                    )
                    is None
                )

    @tornado.testing.gen_test
    async def test_tornado_settings_applied(self):
        """Test that TORNADO_SETTINGS are properly applied to the app."""
        from streamlit.web.server.server import get_tornado_settings

        # Reset config to test default behavior (restored again in tearDown)
        config._set_option("server.websocketPingInterval", None, "test")

        tornado_settings = get_tornado_settings()
        assert (
            self.app_settings["websocket_ping_interval"]
            == tornado_settings["websocket_ping_interval"]
        )
        assert (
            self.app_settings["websocket_ping_timeout"]
            == tornado_settings["websocket_ping_timeout"]
        )

        # In default case, timeout should always be 30
        assert tornado_settings["websocket_ping_timeout"] == 30

    @tornado.testing.gen_test
    async def test_websocket_ping_interval_custom_config(self):
        """Test that custom websocket ping interval is respected."""
        from streamlit.web.server.server import (
            _get_websocket_ping_interval_and_timeout,
            get_tornado_settings,
        )

        # Test custom configuration that's valid for all versions
        config._set_option("server.websocketPingInterval", 45, "test")
        interval, timeout = _get_websocket_ping_interval_and_timeout()
        assert interval == 45
        assert timeout == 45
        settings = get_tornado_settings()
        assert settings["websocket_ping_interval"] == 45
        assert (
            settings["websocket_ping_timeout"] == 45
        )  # Timeout matches interval when configured

        # Test high value
        config._set_option("server.websocketPingInterval", 120, "test")
        interval, timeout = _get_websocket_ping_interval_and_timeout()
        assert interval == 120
        assert timeout == 120
        settings = get_tornado_settings()
        assert settings["websocket_ping_interval"] == 120
        assert (
            settings["websocket_ping_timeout"] == 120
        )  # Timeout matches interval when configured

    @tornado.testing.gen_test
    @patch("streamlit.web.server.server.is_tornado_version_less_than")
    async def test_websocket_ping_interval_tornado_old(self, mock_version_check):
        """Test websocket ping interval with Tornado < 6.5."""
        from streamlit.web.server.server import (
            _get_websocket_ping_interval_and_timeout,
            get_tornado_settings,
        )

        # Mock old Tornado version
        mock_version_check.return_value = True

        # Test default with old Tornado
        config._set_option("server.websocketPingInterval", None, "test")
        interval, timeout = _get_websocket_ping_interval_and_timeout()
        assert interval == 1
        assert timeout == 30
        settings = get_tornado_settings()
        assert settings["websocket_ping_interval"] == 1
        assert (
            settings["websocket_ping_timeout"] == 30
        )  # Timeout still 30 in default case!

        # Test low values are accepted
        config._set_option("server.websocketPingInterval", 5, "test")
        interval, timeout = _get_websocket_ping_interval_and_timeout()
        assert interval == 5
        assert timeout == 5
        settings = get_tornado_settings()
        assert settings["websocket_ping_interval"] == 5
        assert (
            settings["websocket_ping_timeout"] == 5
        )  # Timeout matches when configured

    @tornado.testing.gen_test
    @patch("streamlit.web.server.server.is_tornado_version_less_than")
    async def test_websocket_ping_interval_tornado_new(self, mock_version_check):
        """Test websocket ping interval with Tornado >= 6.5."""
        from streamlit.web.server.server import _get_websocket_ping_interval_and_timeout

        # Mock new Tornado version
        mock_version_check.return_value = False

        # Test default with new Tornado
        config._set_option("server.websocketPingInterval", None, "test")
        interval, timeout = _get_websocket_ping_interval_and_timeout()
        assert interval == 30
        assert timeout == 30

        # Test that low values are respected
        config._set_option("server.websocketPingInterval", 10, "test")
        interval, timeout = _get_websocket_ping_interval_and_timeout()
        assert interval == 10
        assert timeout == 10

        # Test that values >= 30 are kept as-is
        config._set_option("server.websocketPingInterval", 60, "test")
        interval, timeout = _get_websocket_ping_interval_and_timeout()
        assert interval == 60
        assert timeout == 60


class PortRotateAHundredTest(unittest.TestCase):
    """Tests that port rotation gives up with RetriesExceededError once every
    listen attempt has failed with EADDRINUSE."""

    def setUp(self) -> None:
        # start_listening is expected to rotate "server.port"; remember the
        # starting value so tearDown can put it back.
        self.original_port = config.get_option("server.port")
        return super().setUp()

    def tearDown(self) -> None:
        config.set_option("server.port", self.original_port)
        return super().tearDown()

    @staticmethod
    def get_httpserver():
        """Return a mock HTTP server whose ``listen`` always raises
        ``OSError(errno.EADDRINUSE)``."""
        httpserver = mock.MagicMock()

        httpserver.listen = mock.Mock()
        httpserver.listen.side_effect = OSError(errno.EADDRINUSE, "test", "asd")

        return httpserver

    def test_rotates_a_hundred_ports(self):
        """start_listening retries, then raises RetriesExceededError."""
        app = mock.MagicMock()
        # Keep a handle on the instance: `listen` is called on the object that
        # the patched HTTPServer class returns, not on the patched class itself.
        httpserver = self.get_httpserver()

        with (
            pytest.raises(RetriesExceededError),
            patch(
                "streamlit.web.server.server.HTTPServer",
                return_value=httpserver,
            ),
        ):
            start_listening(app)

        # This check must live outside the `pytest.raises` block: statements
        # placed after the raising call inside the block never execute.
        # NOTE(review): assumes one `listen` call per retry - confirm against
        # start_listening's implementation.
        assert httpserver.listen.call_count == MAX_PORT_SEARCH_RETRIES


class PortRotateOneTest(unittest.TestCase):
    """Tests that a busy port makes start_listening try the next port."""

    # Kept for backward compatibility only: nothing in this module wires this
    # mock into the server, so it is never called.
    which_port = mock.Mock()

    @staticmethod
    def get_httpserver():
        """Return a mock HTTP server whose ``listen`` always raises
        ``OSError(errno.EADDRINUSE)``."""
        httpserver = mock.MagicMock()

        httpserver.listen = mock.Mock()
        httpserver.listen.side_effect = OSError(errno.EADDRINUSE, "test", "asd")

        return httpserver

    @mock.patch("streamlit.web.server.server.config._set_option")
    @mock.patch("streamlit.web.server.server.server_port_is_manually_set")
    def test_rotates_one_port(
        self, patched_server_port_is_manually_set, patched__set_option
    ):
        """A busy port is answered by storing the next port number in config."""
        app = mock.MagicMock()

        patched_server_port_is_manually_set.return_value = False
        httpserver = self.get_httpserver()
        # `_set_option` is mocked for the duration of this test, so the port
        # read here is presumably also what start_listening reads on every
        # retry - TODO confirm.
        starting_port = config.get_option("server.port")

        with (
            pytest.raises(RetriesExceededError),
            patch(
                "streamlit.web.server.server.HTTPServer",
                return_value=httpserver,
            ),
        ):
            start_listening(app)

        # These checks must live outside the `pytest.raises` block: statements
        # placed after the raising call inside the block never execute.
        httpserver.listen.assert_called()
        # NOTE(review): expects the server to store "current port + 1" with
        # STREAMLIT_DEFINITION as the source; verify against start_listening.
        patched__set_option.assert_called_with(
            "server.port",
            starting_port + 1,
            config.ConfigOption.STREAMLIT_DEFINITION,
        )


class SslServerTest(unittest.TestCase):
    """Checks how start_listening reacts to incomplete or broken SSL settings."""

    @parameterized.expand(["server.sslCertFile", "server.sslKeyFile"])
    def test_requires_two_options(self, option_name):
        """
        Setting only one of the two required SSL options must log an error and
        exit.
        """
        lone_option = {option_name: "/tmp/file"}
        with (
            patch_config_options(lone_option),
            pytest.raises(SystemExit),
            self.assertLogs("streamlit.web.server.server") as captured,
        ):
            start_listening(mock.MagicMock())

        expected_record = (
            "ERROR:streamlit.web.server.server:Options 'server.sslCertFile' and "
            "'server.sslKeyFile' must be set together. Set missing options or delete "
            "existing options."
        )
        assert captured.output == [expected_record]

    @parameterized.expand(["server.sslCertFile", "server.sslKeyFile"])
    def test_missing_file(self, option_name):
        """
        Both SSL options are set, but only the file named by ``option_name``
        exists on disk; the other one is missing.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            ssl_dir = Path(tmp_dir)
            ssl_options = {
                "server.sslCertFile": ssl_dir / "cert.cert",
                "server.sslKeyFile": ssl_dir / "key.key",
            }

            with patch_config_options(ssl_options):
                # Create only one of the two files.
                Path(ssl_options[option_name]).write_text(
                    "TEST-CONTENT", encoding="utf-8"
                )

                with (
                    pytest.raises(SystemExit),
                    self.assertLogs("streamlit.web.server.server") as captured,
                ):
                    start_listening(mock.MagicMock())

        first_record = captured.output[0]
        assert re.search(
            r"ERROR:streamlit\.web\.server\.server:(Cert|Key) file '.+' does not exist\.",
            first_record,
        )

    @parameterized.expand(["server.sslCertFile", "server.sslKeyFile"])
    @unittest.skipIf("win32" in sys.platform, "Windows does not natively have openssl")
    def test_invalid_file_content(self, option_name):
        """
        Both SSL files exist, but the one named by ``option_name`` has been
        overwritten with invalid content.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            ssl_dir = Path(tmp_dir)
            cert_path = ssl_dir / "cert.cert"
            key_path = ssl_dir / "key.key"

            # Generate a self-signed certificate/key pair with the openssl CLI.
            openssl_cmd = [
                "openssl",
                "req",
                "-x509",
                "-newkey",
                "rsa:4096",
                "-keyout",
                str(key_path),
                "-out",
                str(cert_path),
                "-sha256",
                "-days",
                "365",
                "-nodes",
                "-subj",
                "/CN=localhost",
                # subjectAltName is required by modern browsers
                # See: https://github.com/urllib3/urllib3/issues/497
                "-addext",
                "subjectAltName = DNS:localhost",
            ]
            subprocess.check_call(openssl_cmd)

            ssl_options = {
                "server.sslCertFile": cert_path,
                "server.sslKeyFile": key_path,
            }

            with patch_config_options(ssl_options):
                # Overwrite one of the generated files with invalid content.
                Path(ssl_options[option_name]).write_text(
                    "INVALID-CONTENT", encoding="utf-8"
                )

                with (
                    pytest.raises(SystemExit),
                    self.assertLogs("streamlit.web.server.server") as captured,
                ):
                    start_listening(mock.MagicMock())

        first_record = captured.output[0]
        assert re.search(
            r"ERROR:streamlit\.web\.server\.server:Failed to load SSL certificate\. Make "
            r"sure cert file '.+' and key file '.+' are correct\.",
            first_record,
        )


class UnixSocketTest(unittest.TestCase):
    """Tests start_listening uses a unix socket when server.address starts with
    unix://"""

    def setUp(self) -> None:
        # The test overwrites "server.address"; remember it for tearDown.
        self.original_address = config.get_option("server.address")
        return super().setUp()

    def tearDown(self) -> None:
        config.set_option("server.address", self.original_address)
        return super().tearDown()

    @staticmethod
    def get_httpserver():
        """Return a mock HTTP server with a plain ``Mock`` as ``add_socket``."""
        httpserver = mock.MagicMock()

        httpserver.add_socket = mock.Mock()

        return httpserver

    @unittest.skipIf("win32" in sys.platform, "Windows does not have unix sockets")
    def test_unix_socket(self):
        """A ``unix://~/...`` address is expanded against HOME, bound as a unix
        socket, and the resulting socket is handed to the HTTP server."""
        # Import the submodule explicitly instead of relying on another
        # tornado import having loaded it as a side effect.
        import tornado.netutil

        app = mock.MagicMock()

        config.set_option("server.address", "unix://~/fancy-test/testasd")
        some_socket = object()

        mock_server = self.get_httpserver()
        with (
            patch("streamlit.web.server.server.HTTPServer", return_value=mock_server),
            patch.object(
                tornado.netutil, "bind_unix_socket", return_value=some_socket
            ) as bind_unix_socket,
            # Pin HOME so the expanded "~" is deterministic.
            patch.dict(os.environ, {"HOME": "/home/superfakehomedir"}),
        ):
            start_listening(app)

            bind_unix_socket.assert_called_with(
                "/home/superfakehomedir/fancy-test/testasd"
            )
            mock_server.add_socket.assert_called_with(some_socket)


class ScriptCheckEndpointExistsTest(tornado.testing.AsyncHTTPTestCase):
    """Script health-check endpoints with ``server.scriptHealthCheckEnabled`` on."""

    async def does_script_run_without_error(self):
        """Stand-in installed on the runtime: reports success with a fixed message."""
        return (True, "test_message")

    def setUp(self):
        # Switch the option on before the base-class setUp runs, remembering
        # the previous value for tearDown.
        self._old_config = config.get_option("server.scriptHealthCheckEnabled")
        config._set_option("server.scriptHealthCheckEnabled", True, "test")
        super().setUp()

    def tearDown(self):
        # Put the option back and clear the Runtime singleton reference.
        config._set_option("server.scriptHealthCheckEnabled", self._old_config, "test")
        Runtime._instance = None
        super().tearDown()

    def get_app(self):
        """Create the app from a Server whose runtime uses the stand-in check."""
        srv = Server("mock/script/path", is_hello=False)
        srv._runtime.does_script_run_without_error = self.does_script_run_without_error
        srv._runtime._eventloop = self.io_loop.asyncio_loop
        return srv._create_app()

    def test_endpoint(self):
        """The ``/_stcore`` endpoint answers 200 with the check's message."""
        resp = self.fetch("/_stcore/script-health-check")
        assert resp.code == 200
        assert resp.body == b"test_message"

    def test_deprecated_endpoint(self):
        """The legacy path still answers, and advertises the new endpoint."""
        resp = self.fetch("/script-health-check")
        assert resp.code == 200
        assert resp.body == b"test_message"

        expected_link = f'<http://127.0.0.1:{self.get_http_port()}/_stcore/script-health-check>; rel="alternate"'
        assert resp.headers["link"] == expected_link
        assert resp.headers["deprecation"] == "True"


class ScriptCheckEndpointDoesNotExistTest(tornado.testing.AsyncHTTPTestCase):
    """Script health-check endpoint with ``server.scriptHealthCheckEnabled`` off."""

    async def does_script_run_without_error(self):
        """Stand-in installed on the runtime: fails the test if ever invoked."""
        self.fail("Should not be called")

    def setUp(self):
        # Switch the option off before the base-class setUp runs, remembering
        # the previous value for tearDown.
        self._old_config = config.get_option("server.scriptHealthCheckEnabled")
        config._set_option("server.scriptHealthCheckEnabled", False, "test")
        super().setUp()

    def tearDown(self):
        # Put the option back and clear the Runtime singleton reference.
        config._set_option("server.scriptHealthCheckEnabled", self._old_config, "test")
        Runtime._instance = None
        super().tearDown()

    def get_app(self):
        """Create the app from a Server whose runtime uses the failing stand-in."""
        srv = Server("mock/script/path", is_hello=False)
        srv._runtime.does_script_run_without_error = self.does_script_run_without_error
        return srv._create_app()

    def test_endpoint(self):
        """The legacy health-check path answers 404 when the feature is off."""
        resp = self.fetch("/script-health-check")
        assert resp.code == 404
