# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from contextlib import contextmanager

import pytest

from neo4j import (
    AsyncManagedTransaction,
    AsyncSession,
    AsyncTransaction,
    Auth,
    Bookmarks,
    unit_of_work,
)
from neo4j._api import TelemetryAPI
from neo4j._async.home_db_cache import AsyncHomeDbCache
from neo4j._async.io import (
    AcquisitionDatabase,
    AsyncBoltPool,
    AsyncNeo4jPool,
)
from neo4j._async_compat.util import AsyncUtil
from neo4j._auth_management import to_auth_dict
from neo4j._conf import SessionConfig
from neo4j.api import (
    AsyncBookmarkManager,
    READ_ACCESS,
    WRITE_ACCESS,
)

from ...._async_compat import mark_async_test


@contextmanager
def assert_warns_tx_func_deprecation(tx_func_name):
    """Expect the rename deprecation for legacy transaction functions.

    For names ending in ``_transaction`` the wrapped code must emit a
    ``DeprecationWarning`` announcing the rename to ``execute_<mode>``.
    For any other name this context manager is a plain pass-through.
    """
    if not tx_func_name.endswith("_transaction"):
        yield
        return
    # "write_transaction" -> "write", "read_transaction" -> "read"
    mode, _, _ = tx_func_name.partition("_")
    expected_message = (
        f"^{mode}_transaction has been renamed to execute_{mode}$"
    )
    with pytest.warns(DeprecationWarning, match=expected_message):
        yield


@mark_async_test
async def test_session_context_calls_close(mocker):
    """Exiting the session context manager calls ``close()`` exactly once."""
    session = AsyncSession(None, SessionConfig())
    # Keep the real implementation so the spy still closes the session.
    real_close = session.close
    close_spy = mocker.patch.object(
        session, "close", autospec=True, side_effect=real_close
    )
    async with session:
        pass
    close_spy.assert_called_once_with()


@pytest.mark.parametrize(
    "test_run_args", (("RETURN $x", {"x": 1}), ("RETURN 1",))
)
@pytest.mark.parametrize(
    ("repetitions", "consume"), ((1, False), (2, False), (2, True))
)
@mark_async_test
async def test_opens_connection_on_run(
    async_fake_pool, test_run_args, repetitions, consume
):
    """Check that an auto-commit ``run`` makes the session hold a connection.

    The query is run ``repetitions`` times. When ``consume`` is set, each
    result is consumed before the next run is started.
    """
    async with AsyncSession(async_fake_pool, SessionConfig()) as session:
        assert session._connection is None
        # Actually honour `repetitions`; previously the parameter was
        # accepted but ignored, so all parametrizations ran a single query.
        for _ in range(repetitions):
            result = await session.run(*test_run_args)
            assert session._connection is not None
            if consume:
                await result.consume()


@pytest.mark.parametrize(
    "test_run_args", (("RETURN $x", {"x": 1}), ("RETURN 1",))
)
@pytest.mark.parametrize("repetitions", range(1, 3))
@mark_async_test
async def test_closes_connection_after_consume(
    async_fake_pool, test_run_args, repetitions
):
    """Check that consuming an auto-commit result drops the connection.

    The run/consume cycle is performed ``repetitions`` times; after every
    consume, and after leaving the session, no connection may be held.
    """
    async with AsyncSession(async_fake_pool, SessionConfig()) as session:
        # Actually honour `repetitions`; previously the parameter was
        # accepted but ignored, so both parametrizations were identical.
        for _ in range(repetitions):
            result = await session.run(*test_run_args)
            await result.consume()
            assert session._connection is None
    assert session._connection is None


@pytest.mark.parametrize(
    "test_run_args", (("RETURN $x", {"x": 1}), ("RETURN 1",))
)
@mark_async_test
async def test_keeps_connection_until_last_result_consumed(
    async_fake_pool, test_run_args
):
    """The connection is only given up once the latest result is consumed."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        first = await session.run(*test_run_args)
        second = await session.run(*test_run_args)
        assert session._connection is not None

        # Consuming the older result must not drop the connection ...
        await first.consume()
        assert session._connection is not None

        # ... consuming the most recent one must.
        await second.consume()
        assert session._connection is None


@mark_async_test
async def test_opens_connection_on_tx_begin(async_fake_pool):
    """Beginning an explicit transaction makes the session hold a connection."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        assert session._connection is None
        tx = await session.begin_transaction()
        async with tx as _:
            assert session._connection is not None


@pytest.mark.parametrize(
    "test_run_args", (("RETURN $x", {"x": 1}), ("RETURN 1",))
)
@pytest.mark.parametrize("repetitions", range(1, 3))
@mark_async_test
async def test_keeps_connection_on_tx_run(
    async_fake_pool, test_run_args, repetitions
):
    """Running queries inside a transaction keeps the connection held."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        transaction = await session.begin_transaction()
        async with transaction as tx:
            for _run in range(repetitions):
                await tx.run(*test_run_args)
                assert session._connection is not None


@pytest.mark.parametrize(
    "test_run_args", (("RETURN $x", {"x": 1}), ("RETURN 1",))
)
@pytest.mark.parametrize("repetitions", range(1, 3))
@mark_async_test
async def test_keeps_connection_on_tx_consume(
    async_fake_pool, test_run_args, repetitions
):
    """Consuming results inside a transaction keeps the connection held."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        transaction = await session.begin_transaction()
        async with transaction as tx:
            for _run in range(repetitions):
                tx_result = await tx.run(*test_run_args)
                await tx_result.consume()
                assert session._connection is not None


@pytest.mark.parametrize(
    "test_run_args", (("RETURN $x", {"x": 1}), ("RETURN 1",))
)
@mark_async_test
async def test_closes_connection_after_tx_close(
    async_fake_pool, test_run_args
):
    """Closing a transaction makes the session give up its connection."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        transaction = await session.begin_transaction()
        async with transaction as tx:
            for _run in range(2):
                tx_result = await tx.run(*test_run_args)
                await tx_result.consume()
            await tx.close()
            # released right after the explicit close ...
            assert session._connection is None
        # ... and still released after leaving the transaction context
        assert session._connection is None


@pytest.mark.parametrize(
    "test_run_args", (("RETURN $x", {"x": 1}), ("RETURN 1",))
)
@mark_async_test
async def test_closes_connection_after_tx_commit(
    async_fake_pool, test_run_args
):
    """Committing a transaction makes the session give up its connection."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        transaction = await session.begin_transaction()
        async with transaction as tx:
            for _run in range(2):
                tx_result = await tx.run(*test_run_args)
                await tx_result.consume()
            await tx.commit()
            # released right after the commit ...
            assert session._connection is None
        # ... and still released after leaving the transaction context
        assert session._connection is None


@pytest.mark.parametrize(
    "bookmark_values",
    (None, [], ["abc"], ["foo", "bar"], {"a", "b"}, ("1", "two")),
)
@mark_async_test
async def test_session_returns_bookmarks_directly(
    async_fake_pool, bookmark_values
):
    """``last_bookmarks`` echoes the configured bookmarks before any work."""
    if bookmark_values is None:
        initial_bookmarks = Bookmarks()
        expected_raw = frozenset()
    else:
        initial_bookmarks = Bookmarks.from_raw_values(bookmark_values)
        expected_raw = frozenset(bookmark_values)

    config = SessionConfig(bookmarks=initial_bookmarks)
    async with AsyncSession(async_fake_pool, config) as session:
        returned = await session.last_bookmarks()
        assert isinstance(returned, Bookmarks)
        assert returned.raw_values == expected_raw


@pytest.mark.parametrize(
    "bookmarks", (None, [], ["abc"], ["foo", "bar"], ("1", "two"))
)
@mark_async_test
async def test_session_last_bookmark_is_deprecated(async_fake_pool, bookmarks):
    """``last_bookmark`` warns and returns the last configured bookmark."""
    if bookmarks is None:
        session = AsyncSession(
            async_fake_pool, SessionConfig(bookmarks=bookmarks)
        )
    else:
        # Building the session with non-None bookmarks is expected to warn.
        with pytest.warns(DeprecationWarning):
            session = AsyncSession(
                async_fake_pool, SessionConfig(bookmarks=bookmarks)
            )

    async with session:
        with pytest.warns(DeprecationWarning):
            last_bookmark = await session.last_bookmark()
            if bookmarks:
                assert last_bookmark == bookmarks[-1]
            else:
                # None or an empty collection: nothing to report
                assert last_bookmark is None


@pytest.mark.parametrize(
    "bookmarks", (("foo",), ("foo", "bar"), (), ["foo", "bar"], {"a", "b"})
)
@mark_async_test
async def test_session_bookmarks_as_iterable_is_deprecated(
    async_fake_pool, bookmarks
):
    """Plain iterables as bookmarks still work but raise a deprecation."""
    expected_raw = frozenset(bookmarks)
    with pytest.warns(DeprecationWarning):
        async with AsyncSession(
            async_fake_pool, SessionConfig(bookmarks=bookmarks)
        ) as session:
            last_bookmarks = await session.last_bookmarks()
            assert last_bookmarks.raw_values == expected_raw


@pytest.mark.parametrize(
    ("query", "error_type"),
    (
        (None, ValueError),
        (1234, TypeError),
        ({"how about": "no?"}, TypeError),
        (["I don't", "think so"], TypeError),
    ),
)
@mark_async_test
async def test_session_run_wrong_types(async_fake_pool, query, error_type):
    """``session.run`` rejects queries of unsupported types."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        with pytest.raises(error_type):
            await session.run(query)


@pytest.mark.parametrize(
    "tx_type",
    (
        "write_transaction",
        "read_transaction",
        "execute_write",
        "execute_read",
    ),
)
@mark_async_test
async def test_tx_function_argument_type(async_fake_pool, tx_type):
    """Transaction functions receive an ``AsyncManagedTransaction``."""
    received = []

    async def work(tx):
        # Record the call before checking so a type failure is not
        # mistaken for "never called".
        received.append(tx)
        assert isinstance(tx, AsyncManagedTransaction)

    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        with assert_warns_tx_func_deprecation(tx_type):
            await getattr(session, tx_type)(work)
        assert received


@pytest.mark.parametrize(
    "tx_type",
    ("write_transaction", "read_transaction", "execute_write", "execute_read"),
)
@pytest.mark.parametrize(
    "decorator_kwargs",
    (
        {},
        {"timeout": 5},
        {"metadata": {"foo": "bar"}},
        {"timeout": 5, "metadata": {"foo": "bar"}},
    ),
)
@mark_async_test
async def test_decorated_tx_function_argument_type(
    async_fake_pool, tx_type, decorator_kwargs
):
    """``unit_of_work`` settings are forwarded to the connection's ``begin``.

    Also checks that the decorated function still receives an
    ``AsyncManagedTransaction``.
    """
    received = []

    @unit_of_work(**decorator_kwargs)
    async def work(tx):
        received.append(tx)
        assert isinstance(tx, AsyncManagedTransaction)

    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        with assert_warns_tx_func_deprecation(tx_type):
            await getattr(session, tx_type)(work)
        assert received

    acquired = async_fake_pool.acquired_connection_mocks
    assert len(acquired) == 1
    begin_mock = acquired[0].begin
    begin_mock.assert_called_once()
    begin_kwargs = begin_mock.call_args.kwargs
    # Settings missing from the decorator must arrive as None.
    assert begin_kwargs["timeout"] == decorator_kwargs.get("timeout")
    assert begin_kwargs["metadata"] == decorator_kwargs.get("metadata")


@mark_async_test
async def test_session_tx_type(async_fake_pool):
    """``begin_transaction`` hands out an ``AsyncTransaction``."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        transaction = await session.begin_transaction()
        assert isinstance(transaction, AsyncTransaction)


@pytest.mark.parametrize(
    "parameters",
    (
        {"x": None},
        {"x": True},
        {"x": False},
        {"x": 123456789},
        {"x": 3.1415926},
        {"x": float("nan")},
        {"x": float("inf")},
        {"x": float("-inf")},
        {"x": "foo"},
        {"x": bytearray([0x00, 0x33, 0x66, 0x99, 0xCC, 0xFF])},
        {"x": b"\x00\x33\x66\x99\xcc\xff"},
        {"x": [1, 2, 3]},
        {"x": ["a", "b", "c"]},
        {"x": ["a", 2, 1.234]},
        {"x": ["a", 2, ["c"]]},
        {"x": {"one": "eins", "two": "zwei", "three": "drei"}},
        {"x": {"one": ["eins", "uno", 1], "two": ["zwei", "dos", 2]}},
    ),
)
@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed"))
@mark_async_test
async def test_session_run_with_parameters(
    async_fake_pool, parameters, run_type
):
    """Keyword query parameters reach the connection's ``run`` unchanged.

    Covers auto-commit, unmanaged-transaction and managed-transaction runs.
    """

    async def run_in_tx(tx):
        await tx.run("RETURN $x", **parameters)

    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        if run_type == "auto":
            await session.run("RETURN $x", **parameters)
        elif run_type == "unmanaged":
            unmanaged_tx = await session.begin_transaction()
            await run_in_tx(unmanaged_tx)
        elif run_type == "managed":
            await session.execute_write(run_in_tx)
        else:
            raise ValueError(run_type)

    acquired = async_fake_pool.acquired_connection_mocks
    assert len(acquired) == 1
    run_mock = acquired[0].run
    run_mock.assert_called_once()
    run_call = run_mock.call_args
    assert run_call.args[0] == "RETURN $x"
    assert run_call.kwargs["parameters"] == parameters


@pytest.mark.parametrize(
    ("params", "kw_params", "expected_params"),
    (
        ({"x": 1}, {}, {"x": 1}),
        ({}, {"x": 1}, {"x": 1}),
        ({"x": 1}, {"y": 2}, {"x": 1, "y": 2}),
        ({"x": 1}, {"x": 2}, {"x": 2}),
        ({"x": 1, "y": 3}, {"x": 2}, {"x": 2, "y": 3}),
        ({"x": 1}, {"x": 2, "y": 3}, {"x": 2, "y": 3}),
        # potentially internally used keyword arguments
        ({}, {"timeout": 2}, {"timeout": 2}),
        ({"timeout": 2}, {}, {"timeout": 2}),
        ({}, {"imp_user": "hans"}, {"imp_user": "hans"}),
        ({"imp_user": "hans"}, {}, {"imp_user": "hans"}),
        ({}, {"db": "neo4j"}, {"db": "neo4j"}),
        ({"db": "neo4j"}, {}, {"db": "neo4j"}),
        ({}, {"database": "neo4j"}, {"database": "neo4j"}),
        ({"database": "neo4j"}, {}, {"database": "neo4j"}),
    ),
)
@pytest.mark.parametrize("run_type", ("auto", "unmanaged", "managed"))
@mark_async_test
async def test_session_run_parameter_precedence(
    async_fake_pool, params, kw_params, expected_params, run_type
):
    """Keyword parameters win over the positional parameter dict.

    The merged parameters must be what the connection's ``run`` receives,
    for auto-commit, unmanaged-transaction and managed-transaction runs.
    """

    async def run_in_tx(tx):
        await tx.run("RETURN $x", params, **kw_params)

    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        if run_type == "auto":
            await session.run("RETURN $x", params, **kw_params)
        elif run_type == "unmanaged":
            unmanaged_tx = await session.begin_transaction()
            await run_in_tx(unmanaged_tx)
        elif run_type == "managed":
            await session.execute_write(run_in_tx)
        else:
            raise ValueError(run_type)

    acquired = async_fake_pool.acquired_connection_mocks
    assert len(acquired) == 1
    run_mock = acquired[0].run
    run_mock.assert_called_once()
    run_call = run_mock.call_args
    assert run_call.args[0] == "RETURN $x"
    assert run_call.kwargs["parameters"] == expected_params


@pytest.mark.parametrize("db", (None, "adb"))
@pytest.mark.parametrize("routing", (True, False))
# no home db resolution when connected to Neo4j 4.3 or earlier
@pytest.mark.parametrize("home_db_gets_resolved", (True, False))
@pytest.mark.parametrize(
    "additional_session_bookmarks", (None, ["session", "bookmarks"])
)
@mark_async_test
async def test_with_bookmark_manager(
    async_fake_pool,
    db,
    routing,
    async_scripted_connection,
    home_db_gets_resolved,
    additional_session_bookmarks,
    mocker,
):
    """Check how a session talks to its bookmark manager around one run.

    A single auto-commit query is run on a session configured with a mocked
    bookmark manager (and optionally extra session bookmarks). The test then
    asserts

    * how often ``get_bookmarks`` was called while running the query,
    * that exactly one ``update_bookmarks`` call follows once the session
      is closed, carrying the bookmarks used for the run and the bookmark
      scripted in the pull response (``"res:bm1"``),
    * which pool methods were called, in which order, and with which
      bookmarks,
    * which bookmarks were passed to the connection's ``run``.
    """

    # Stand-in for the pool's routing table update: optionally reports a
    # home database through the callback it was given.
    async def update_routing_table_side_effect(
        database,
        imp_user,
        bookmarks,
        auth=None,
        acquisition_timeout=None,
        database_callback=None,
    ):
        if home_db_gets_resolved:
            database_callback("homedb")

    # Returns a different bookmark set on every call ("bookmarks1",
    # "bookmarks2", ...) so the assertions below can tell which call's
    # result ended up where.
    async def bmm_get_bookmarks():
        nonlocal get_bookmarks_count
        get_bookmarks_count += 1
        return ["all", f"bookmarks{get_bookmarks_count}"]

    get_bookmarks_count = 0

    # Script the connection so that pulling the result reports a bookmark.
    async_scripted_connection.set_script(
        [
            ("run", {"on_success": None, "on_summary": None}),
            (
                "pull",
                {
                    "on_success": (
                        {"bookmark": "res:bm1", "has_more": False},
                    ),
                    "on_summary": None,
                    "on_records": None,
                },
            ),
        ]
    )
    async_fake_pool.buffered_connection_mocks.append(async_scripted_connection)

    bmm = mocker.Mock(spec=AsyncBookmarkManager)
    bmm.get_bookmarks.side_effect = bmm_get_bookmarks

    # Make the fake pool look like a routing or a direct pool.
    if routing:
        async_fake_pool.mock_add_spec(AsyncNeo4jPool)
        async_fake_pool.update_routing_table.side_effect = (
            update_routing_table_side_effect
        )
        async_fake_pool.is_direct_pool = False
    else:
        async_fake_pool.mock_add_spec(AsyncBoltPool)
        async_fake_pool.is_direct_pool = True

    config = SessionConfig()
    config.bookmark_manager = bmm
    if db is not None:
        config.database = db
    if additional_session_bookmarks:
        config.bookmarks = Bookmarks.from_raw_values(
            additional_session_bookmarks
        )
    async with AsyncSession(async_fake_pool, config) as session:
        # creating/entering the session alone must not touch the manager
        assert not bmm.method_calls

        await session.run("RETURN 1")

        # assert called bmm accordingly
        expected_bmm_method_calls = [
            mocker.call.get_bookmarks(),
            mocker.call.get_bookmarks(),
        ]
        if routing and db is None:
            expected_bmm_method_calls = [
                # extra call for resolving the home database
                mocker.call.get_bookmarks(),
                *expected_bmm_method_calls,
            ]
        assert bmm.method_calls == expected_bmm_method_calls
        assert bmm.get_bookmarks.await_count == len(expected_bmm_method_calls)
        # start from a clean slate so that only calls made while closing
        # the session are inspected below
        bmm.method_calls.clear()

    # The first get_bookmarks call feeds home db resolution (if any), the
    # second-to-last one feeds the acquire, the last one feeds the run.
    expected_bookmarks_home_db_resolution = {
        "all",
        "bookmarks1",
        *(additional_session_bookmarks or []),
    }
    expected_bookmarks_acquire = {
        "all",
        f"bookmarks{len(expected_bmm_method_calls) - 1}",
        *(additional_session_bookmarks or []),
    }
    expected_bookmarks_run = {
        "all",
        f"bookmarks{len(expected_bmm_method_calls)}",
        *(additional_session_bookmarks or []),
    }
    # exactly one update: (bookmarks used for the run, new bookmarks)
    assert [call[0] for call in bmm.method_calls] == ["update_bookmarks"]
    assert bmm.method_calls[0].kwargs == {}
    assert len(bmm.method_calls[0].args) == 2
    assert set(bmm.method_calls[0].args[0]) == expected_bookmarks_run
    assert set(bmm.method_calls[0].args[1]) == {"res:bm1"}

    expected_pool_method_calls = ["acquire", "release"]
    if routing and db is None:
        expected_pool_method_calls = [
            "update_routing_table",
            *expected_pool_method_calls,
        ]
    assert [
        call[0] for call in async_fake_pool.method_calls
    ] == expected_pool_method_calls
    assert (
        set(async_fake_pool.acquire.call_args.kwargs["bookmarks"])
        == expected_bookmarks_acquire
    )
    if routing and db is None:
        assert (
            set(
                async_fake_pool.update_routing_table.call_args.kwargs[
                    "bookmarks"
                ]
            )
            == expected_bookmarks_home_db_resolution
        )

    assert len(async_fake_pool.acquired_connection_mocks) == 1
    connection_mock = async_fake_pool.acquired_connection_mocks[0]
    connection_mock.run.assert_called_once()
    connection_run_call_kwargs = connection_mock.run.call_args.kwargs
    assert (
        set(connection_run_call_kwargs["bookmarks"]) == expected_bookmarks_run
    )


@pytest.mark.parametrize("routing", (True, False))
@pytest.mark.parametrize("session_method", ("run", "get_server_info"))
@mark_async_test
async def test_last_bookmarks_does_not_leak_bookmark_managers_bookmarks(
    async_fake_pool, routing, session_method, mocker
):
    """``last_bookmarks`` only reports the session's own bookmarks.

    Bookmarks served by the bookmark manager must not show up in the
    result, neither while the session is open nor after it was closed.
    """

    async def bmm_get_bookmarks():
        return ["bmm:bm1"]

    pool_spec = AsyncNeo4jPool if routing else AsyncBoltPool
    async_fake_pool.mock_add_spec(pool_spec)

    bookmark_manager = mocker.Mock(spec=AsyncBookmarkManager)
    bookmark_manager.get_bookmarks.side_effect = bmm_get_bookmarks

    session_config = SessionConfig()
    session_config.bookmark_manager = bookmark_manager
    session_config.bookmarks = Bookmarks.from_raw_values(
        ["session", "bookmarks"]
    )
    expected_raw = {"session", "bookmarks"}

    async with AsyncSession(async_fake_pool, session_config) as session:
        if session_method not in ("run", "get_server_info"):
            raise NotImplementedError
        if session_method == "run":
            await session.run("RETURN 1")
        else:
            await session._get_server_info()
        last_bookmarks = await session.last_bookmarks()

        assert last_bookmarks.raw_values == expected_raw
    assert last_bookmarks.raw_values == expected_raw


@pytest.mark.parametrize("routing", (True, False))
@mark_async_test
async def test_run_notification_min_severity(async_fake_pool, routing):
    """The configured minimum severity is handed to the connection as is."""
    pool_spec = AsyncNeo4jPool if routing else AsyncBoltPool
    async_fake_pool.mock_add_spec(pool_spec)
    # A unique sentinel lets the test check pass-through by identity.
    sentinel = object()
    session_config = SessionConfig(notifications_min_severity=sentinel)
    async with AsyncSession(async_fake_pool, session_config) as session:
        await session.run("RETURN 1")
        acquired = async_fake_pool.acquired_connection_mocks
        assert len(acquired) == 1
        run_mock = acquired[0].run
        run_mock.assert_called_once()
        run_kwargs = run_mock.call_args.kwargs
        assert run_kwargs["notifications_min_severity"] is sentinel


@pytest.mark.parametrize("routing", (True, False))
@mark_async_test
async def test_run_notification_disabled_classifications(
    async_fake_pool, routing
):
    """Configured disabled classifications reach the connection as is."""
    pool_spec = AsyncNeo4jPool if routing else AsyncBoltPool
    async_fake_pool.mock_add_spec(pool_spec)
    # A unique sentinel lets the test check pass-through by identity.
    sentinel = object()
    session_config = SessionConfig(
        notifications_disabled_classifications=sentinel
    )
    async with AsyncSession(async_fake_pool, session_config) as session:
        await session.run("RETURN 1")
        acquired = async_fake_pool.acquired_connection_mocks
        assert len(acquired) == 1
        run_mock = acquired[0].run
        run_mock.assert_called_once()
        run_kwargs = run_mock.call_args.kwargs
        passed_on = run_kwargs["notifications_disabled_classifications"]
        assert passed_on is sentinel


@mark_async_test
async def test_session_run_api_telemetry(async_fake_pool):
    """An auto-commit run reports ``TelemetryAPI.AUTO_COMMIT`` once."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        await session.run("RETURN 1")
        acquired = async_fake_pool.acquired_connection_mocks
        assert len(acquired) == 1
        telemetry_mock = acquired[0].telemetry
        telemetry_mock.assert_called_once()
        reported_api = telemetry_mock.call_args.args[0]
        assert reported_api == TelemetryAPI.AUTO_COMMIT


@mark_async_test
async def test_session_unmanaged_transaction_api_telemetry(async_fake_pool):
    """Beginning an explicit transaction reports ``TelemetryAPI.TX`` once."""
    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        await session.begin_transaction()
        acquired = async_fake_pool.acquired_connection_mocks
        assert len(acquired) == 1
        telemetry_mock = acquired[0].telemetry
        telemetry_mock.assert_called_once()
        reported_api = telemetry_mock.call_args.args[0]
        assert reported_api == TelemetryAPI.TX


@pytest.mark.parametrize(
    "tx_type",
    (
        "write_transaction",
        "read_transaction",
        "execute_write",
        "execute_read",
    ),
)
@mark_async_test
async def test_session_managed_transaction_api_telemetry(
    async_fake_pool, tx_type
):
    """Transaction functions report ``TelemetryAPI.TX_FUNC`` once."""

    async def noop_work(_):
        pass

    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        with assert_warns_tx_func_deprecation(tx_type):
            await getattr(session, tx_type)(noop_work)
        acquired = async_fake_pool.acquired_connection_mocks
        assert len(acquired) == 1
        telemetry_mock = acquired[0].telemetry
        telemetry_mock.assert_called_once()
        reported_api = telemetry_mock.call_args.args[0]
        assert reported_api == TelemetryAPI.TX_FUNC


@pytest.mark.parametrize("mode", (WRITE_ACCESS, READ_ACCESS))
@mark_async_test
async def test_session_custom_api_telemetry(async_fake_pool, mode):
    """``_run_transaction`` reports the telemetry API it was called with."""

    async def noop_work(_):
        pass

    config = SessionConfig()
    async with AsyncSession(async_fake_pool, config) as session:
        await session._run_transaction(
            mode, TelemetryAPI.DRIVER, noop_work, (), {}
        )
        acquired = async_fake_pool.acquired_connection_mocks
        assert len(acquired) == 1
        telemetry_mock = acquired[0].telemetry
        telemetry_mock.assert_called_once()
        reported_api = telemetry_mock.call_args.args[0]
        assert reported_api == TelemetryAPI.DRIVER


@pytest.mark.parametrize(
    ("db", "pool_ssr", "pool_routing", "expect_cache_usage"),
    (
        (db, ssr, routing, ssr and routing and not db)
        for ssr in (True, False)
        for routing in (True, False)
        for db in (None, "mydb")
    ),
)
@pytest.mark.parametrize("imp_user", (None, "imp_user"))
@pytest.mark.parametrize(
    "auth",
    (
        None,
        Auth(scheme="magic-auth", principal=None, credentials="tada"),
    ),
)
@mark_async_test
async def test_uses_home_db_cache_when_expected(
    async_fake_pool,
    mocker,
    db,
    pool_ssr,
    pool_routing,
    expect_cache_usage,
    imp_user,
    auth,
):
    """Check when the session consults the home database cache.

    The cache is expected to be used only with SSR enabled, a routing pool
    and no explicitly configured database. In that case the cached value is
    passed to ``acquire`` as a guess; otherwise the configured database is
    passed as a non-guess and the cache is never read.
    """
    # pool setup
    async_fake_pool.ssr_enabled = pool_ssr
    if pool_routing:
        async_fake_pool.is_direct_pool = False
        async_fake_pool.mock_add_spec(AsyncNeo4jPool)

    # cache spy with canned key and lookup result
    cache_spy = mocker.Mock(spec=AsyncHomeDbCache, wraps=AsyncHomeDbCache())
    cached_db = "nice_cached_home_db"
    cache_key = object()
    cache_spy.compute_key.return_value = cache_key
    cache_spy.get.return_value = cached_db
    async_fake_pool.home_db_cache = cache_spy

    session_config = SessionConfig()
    session_config.impersonated_user = imp_user
    session_config.auth = auth
    session_config.database = db

    async with AsyncSession(async_fake_pool, session_config) as session:
        await session.run("RETURN 1")

        if expect_cache_usage:
            # the cache is keyed by impersonated user and auth, then read
            expected_auth = to_auth_dict(auth) if auth else None
            assert cache_spy.mock_calls == [
                mocker.call.compute_key(imp_user, expected_auth),
                mocker.call.get(cache_key),
            ]
            # the cache result goes to the pool as a guess
            expected_database = AcquisitionDatabase(cached_db, guessed=True)
        else:
            # the cache must not be read
            cache_spy.get.assert_not_called()
            # the configured database goes to the pool as a non-guess
            expected_database = AcquisitionDatabase(db, guessed=False)

        async_fake_pool.acquire.assert_awaited_once_with(
            access_mode=mocker.ANY,
            timeout=mocker.ANY,
            database=expected_database,
            bookmarks=mocker.ANY,
            auth=mocker.ANY,
            liveness_check_timeout=mocker.ANY,
            database_callback=mocker.ANY,
        )


@pytest.mark.parametrize(
    ("db", "pool_ssr", "pool_routing", "expect_cache_usage"),
    (
        (db, ssr, routing, ssr and routing and not db)
        for ssr in (True, False)
        for routing in (True, False)
        for db in (None, "mydb")
    ),
)
@pytest.mark.parametrize("resolution_at", ("route", "run", "begin"))
@mark_async_test
async def test_pinns_session_db_with_cache(
    async_fake_pool,
    mocker,
    db,
    pool_ssr,
    pool_routing,
    expect_cache_usage,
    resolution_at,
):
    """Check when a session pins its database and updates the home db cache.

    After starting a query, the database callback the session handed out is
    invoked with ``"resolved_db"``. Depending on ``resolution_at`` the
    callback is taken from the pool's ``acquire`` call (``"route"``), from
    the patched ``AsyncResult`` (``"run"``) or from the patched
    ``AsyncTransaction`` (``"begin"``).

    The test asserts the session's pinning state and ``config.database``
    before and after the callback, and whether the cache's ``set`` was
    called with the resolved database.
    """

    # Invoke the database callback captured at the selected stage.
    # NOTE: `res_mock` / `tx_mock` are only bound for the matching
    # `resolution_at` value (see below); each branch reads only its own.
    async def resolve_db():
        if resolution_at == "route":
            database_callback = async_fake_pool.acquire.call_args.kwargs[
                "database_callback"
            ]
            await AsyncUtil.callback(database_callback, resolved_db)
        elif resolution_at == "run":
            # callback is the last positional argument of the constructor
            database_callback = res_mock.call_args.args[-1]
            await AsyncUtil.callback(database_callback, resolved_db)
        elif resolution_at == "begin":
            # callback is the last positional argument of the constructor
            database_callback = tx_mock.call_args.args[-1]
            await AsyncUtil.callback(database_callback, resolved_db)
        else:
            raise ValueError(f"Unknown resolution_at: {resolution_at}")

    # Patch the class the session instantiates so that its constructor
    # arguments (including the database callback) can be inspected.
    if resolution_at == "run":
        res_mock = mocker.patch(
            "neo4j._async.work.session.AsyncResult", autospec=True
        )
    elif resolution_at == "begin":
        tx_mock = mocker.patch(
            "neo4j._async.work.session.AsyncTransaction", autospec=True
        )

    resolved_db = "resolved_db"
    async_fake_pool.ssr_enabled = pool_ssr
    if pool_routing:
        async_fake_pool.is_direct_pool = False
        async_fake_pool.mock_add_spec(AsyncNeo4jPool)
    cache_spy = mocker.Mock(spec=AsyncHomeDbCache, wraps=AsyncHomeDbCache())
    key = object()
    cache_spy.compute_key.return_value = key
    async_fake_pool.home_db_cache = cache_spy

    config = SessionConfig()
    config.database = db

    async with AsyncSession(async_fake_pool, config) as session:
        if resolution_at == "begin":
            async with await session.begin_transaction() as tx:
                await tx.run("RETURN 1")
        else:
            await session.run("RETURN 1")

        if expect_cache_usage:
            # assert never using cache to pin a database
            assert not session._pinned_database
            assert config.database == db

            await resolve_db()

            # only the resolved database pins the session and is cached
            assert session._pinned_database
            assert config.database == resolved_db
            cache_spy.set.assert_called_once_with(key, resolved_db)
        else:
            # direct pool or explicit database: pinned from the start
            if not pool_routing or db:
                assert session._pinned_database
                assert config.database == db

            await resolve_db()

            if not pool_routing or db:
                # the callback must neither re-pin nor touch the cache
                assert session._pinned_database
                assert config.database == db
                cache_spy.set.assert_not_called()
            else:
                # routing pool without explicit database (SSR disabled in
                # this branch): the resolved database is cached and pinned
                cache_spy.set.assert_called_once_with(key, resolved_db)
                assert session._pinned_database
                assert config.database == resolved_db
