
import unittest
from typing import List, cast
from unittest.mock import MagicMock, patch

from common.backend.models.source import Source
from common.backend.services.sources.sources_builder import build_augmented_llm_or_agent_sources
from common.backend.services.sources.sources_formatter import (
    format_sources,
    serialize_aggregated_sources_for_api,
    sources_are_old_format,
)
from dataikuapi.dss.llm import DSSLLM
from sources_constants import (
    AGENT_LLM_RESPONSE,
    AUGMENTED_LLM_RESPONSE,
    EXPECTED_AGENT_SOURCES,
    EXPECTED_AUGMENTED_SOURCES,
    EXPECTED_SERIALIZED_DB_SOURCES,
    EXPECTED_SERIALIZED_KB_SOURCES,
    NEW_DB_FORMAT,
    NEW_KB_FORMAT,
    OLD_DB_FORMAT,
    OLD_KB_FORMAT,
)


class TestSourcesFormatter(unittest.TestCase):
    """Tests for source format detection, conversion, API serialization and source building.

    Note: the format-conversion tests pass raw lists of dictionaries cast to the
    expected typed model. This approach is chosen because the dictionaries are
    derived directly from a specific use case, ensuring they contain the correct
    data structure.
    """

    def test_sources_need_formatting(self):
        """Old-format fixtures are detected as old; new-format fixtures are not."""
        # assertIs pins the exact booleans (not just truthiness), like the
        # previous `is True` / `is False` checks, but is not stripped under -O.
        self.assertIs(sources_are_old_format(OLD_KB_FORMAT), True)
        self.assertIs(sources_are_old_format(OLD_DB_FORMAT), True)
        self.assertIs(sources_are_old_format(NEW_KB_FORMAT), False)
        self.assertIs(sources_are_old_format(NEW_DB_FORMAT), False)

    def test_answers_format_conversion_KB(self):
        """format_sources turns the old KB fixture into the new KB fixture."""
        # The fixture is a raw list of dicts; cast only satisfies the type checker.
        converted = format_sources(cast(List[Source], OLD_KB_FORMAT))
        self.assertEqual(NEW_KB_FORMAT, converted)

    def test_answers_format_conversion_DB(self):
        """format_sources turns the old DB fixture into the new DB fixture."""
        # The fixture is a raw list of dicts; cast only satisfies the type checker.
        converted = format_sources(cast(List[Source], OLD_DB_FORMAT))
        self.assertEqual(NEW_DB_FORMAT, converted)

    def test_serialize_aggregated_sources_for_api(self):
        """New-format DB and KB sources serialize to the expected API payloads."""
        # API schemas differ from data models.
        # For instance, we do not require the tool_name_used information.
        serialized_db_sources = serialize_aggregated_sources_for_api(NEW_DB_FORMAT)
        serialized_kb_sources = serialize_aggregated_sources_for_api(NEW_KB_FORMAT)

        self.assertEqual(EXPECTED_SERIALIZED_DB_SOURCES, serialized_db_sources)
        self.assertEqual(EXPECTED_SERIALIZED_KB_SOURCES, serialized_kb_sources)

    @patch('common.backend.services.sources.sources_builder.get_llm_friendly_name')
    def test_extract_augmented_llm_sources(self, get_llm_friendly_name):
        """Sources built from the augmented LLM response match the expected fixture."""
        # get_llm_friendly_name is patched where sources_builder looks it up,
        # so the friendly name is a fixed value for this test.
        get_llm_friendly_name.return_value = "mocked_augmented_llm"
        # The mock is specced on a DSSLLM instance built with placeholder arguments.
        llm_used = MagicMock(spec=DSSLLM(client="client", project_key="PROJECT", llm_id="augmented_llm"))
        extracted_sources = build_augmented_llm_or_agent_sources(AUGMENTED_LLM_RESPONSE, llm_used)
        self.assertEqual(EXPECTED_AUGMENTED_SOURCES, extracted_sources)

    @patch('common.backend.services.sources.sources_builder.get_llm_friendly_name')
    def test_extract_agent_sources(self, get_llm_friendly_name):
        """Sources built from the agent LLM response match the expected fixture."""
        # get_llm_friendly_name is patched where sources_builder looks it up,
        # so the friendly name is a fixed value for this test.
        get_llm_friendly_name.return_value = "mocked_agent_llm"
        # The mock is specced on a DSSLLM instance built with placeholder arguments.
        llm_used = MagicMock(spec=DSSLLM(client="client", project_key="PROJECT", llm_id="agent_llm"))
        extracted_sources = build_augmented_llm_or_agent_sources(AGENT_LLM_RESPONSE, llm_used)
        self.assertEqual(EXPECTED_AGENT_SOURCES, extracted_sources)