import unittest
from unittest.mock import MagicMock, patch

from common.backend.constants import CONVERSATION_DEFAULT_NAME
from common.solutions.chains.title.media_conversation_title import SummaryTitler
from dataikuapi.dss.llm import DSSLLMCompletionResponse


class TestSummaryTitler(unittest.TestCase):
    """Tests for ``SummaryTitler.generate_summary_title`` error/fallback handling.

    The primary and fallback completions are deliberately *distinct* mocks, so
    each test proves which completion actually produced (or failed to produce)
    the title. Sharing one mock would let a test pass even if the fallback path
    were never taken (e.g. the primary completion was simply retried).
    """

    # Module whose imported helpers are patched. The titler looks these names
    # up in its own module namespace, so that is where the patches must apply.
    _MODULE = "common.solutions.chains.title.media_conversation_title"

    def _make_titler(self):
        """Return a ``SummaryTitler`` whose summary listing is stubbed out."""
        titler = SummaryTitler()
        titler.list_summaries = MagicMock(return_value="FAKE_SUMMARIES")
        return titler

    @staticmethod
    def _make_response(text):
        """Build a fake completion response carrying ``text`` and an empty raw payload."""
        response = MagicMock(spec=DSSLLMCompletionResponse)
        response.text = text
        response._raw = {}
        return response

    @patch(f"{_MODULE}.is_fallback_enabled")
    @patch(f"{_MODULE}.get_fallback_completion")
    @patch(f"{_MODULE}.get_llm_completion")
    def test_generate_summary_title_uses_fallback_on_exception(
        self,
        mock_get_llm_completion,
        mock_to_fallback,
        mock_is_fallback_enabled,
    ):
        """Primary LLM fails, fallback enabled: the title comes from the fallback LLM."""
        # Arrange
        titler = self._make_titler()
        dummy_summaries = [MagicMock()]
        mock_is_fallback_enabled.return_value = True

        # The primary completion always fails on execute().
        primary_completion = MagicMock(name="OriginalCompletion")
        primary_completion.execute.side_effect = Exception("simulated LLM failure")
        mock_get_llm_completion.return_value = primary_completion

        # The fallback completion answers with a usable title.
        fake_generated_fallback_title = "The fallback title"
        fallback_completion = MagicMock(name="FallbackCompletion")
        fallback_completion.execute.return_value = self._make_response(
            fake_generated_fallback_title
        )
        mock_to_fallback.return_value = fallback_completion

        # Act
        result = titler.generate_summary_title(dummy_summaries)

        # Assert
        self.assertEqual(result, fake_generated_fallback_title)
        primary_completion.execute.assert_called()
        # The title must really come from the fallback completion.
        fallback_completion.execute.assert_called_once()

    @patch(f"{_MODULE}.is_fallback_enabled")
    @patch(f"{_MODULE}.get_fallback_completion")
    @patch(f"{_MODULE}.get_llm_completion")
    def test_generate_summary_title_fails_without_fallback(
        self,
        mock_get_llm_completion,
        mock_to_fallback,
        mock_is_fallback_enabled,
    ):
        """Primary LLM fails, fallback disabled: the default conversation name is returned."""
        # Arrange
        titler = self._make_titler()
        dummy_summaries = [MagicMock()]
        mock_is_fallback_enabled.return_value = False

        primary_completion = MagicMock(name="OriginalCompletion")
        primary_completion.execute.side_effect = Exception("simulated LLM failure")
        mock_get_llm_completion.return_value = primary_completion

        # A working fallback is wired up on purpose: if it were wrongly used
        # while disabled, the result would differ from the default name.
        fallback_completion = MagicMock(name="FallbackCompletion")
        fallback_completion.execute.return_value = self._make_response(
            "Title that must not be used"
        )
        mock_to_fallback.return_value = fallback_completion

        # Act
        result = titler.generate_summary_title(dummy_summaries)

        # Assert
        self.assertEqual(result, CONVERSATION_DEFAULT_NAME)
        mock_to_fallback.assert_not_called()
        fallback_completion.execute.assert_not_called()

    @patch(f"{_MODULE}.is_fallback_enabled")
    @patch(f"{_MODULE}.get_fallback_completion")
    @patch(f"{_MODULE}.get_llm_completion")
    def test_generate_summary_title_fails_and_fallback_fails_too(
        self,
        mock_get_llm_completion,
        mock_to_fallback,
        mock_is_fallback_enabled,
    ):
        """Both primary and fallback LLMs fail: the default conversation name is returned."""
        # Arrange
        titler = self._make_titler()
        dummy_summaries = [MagicMock()]
        mock_is_fallback_enabled.return_value = True

        primary_completion = MagicMock(name="OriginalCompletion")
        primary_completion.execute.side_effect = Exception("simulated LLM failure")
        mock_get_llm_completion.return_value = primary_completion

        fallback_completion = MagicMock(name="FallbackCompletion")
        fallback_completion.execute.side_effect = Exception("fallback LLM failure")
        mock_to_fallback.return_value = fallback_completion

        # Act
        result = titler.generate_summary_title(dummy_summaries)

        # Assert
        self.assertEqual(result, CONVERSATION_DEFAULT_NAME)
        # The fallback must actually have been attempted before giving up.
        fallback_completion.execute.assert_called_once()