
import unittest
from unittest.mock import patch

# from common.solutions.chains.docs.self_service_decision_chain import SelfServiceDecisionChain
# from dataiku.langchain.dku_llm import DKULLM
# from dataikuapi.dss.llm import DSSLLMCompletionResponse


class TestDecisionChainFallback(unittest.TestCase):
    """Tests for the decision chain's LLM-fallback path (currently disabled).

    The executable bodies of both tests are commented out, so the tests are
    explicitly marked as skipped. A body consisting only of ``pass`` would be
    reported as a *passing* test even though nothing is verified; a skip keeps
    the gap visible in every test report.

    TODO: to re-enable a test, remove its ``@unittest.skip`` decorator,
    uncomment its body, restore the commented-out project imports at the top
    of this file (``SelfServiceDecisionChain``, ``DKULLM``,
    ``DSSLLMCompletionResponse``) and import ``MagicMock`` from
    ``unittest.mock``.
    """

    # Shared skip message so both disabled tests report the same reason.
    _DISABLED_REASON = (
        "Decision-chain fallback test is disabled: its body is commented out "
        "and must be restored before this test can verify anything."
    )

    # NOTE: ``patch`` decorators are applied bottom-up, so the mock arguments
    # arrive in the reverse order of the decorator stack. ``unittest.skip`` is
    # outermost, which means the patches are not started while the test is
    # skipped.
    @unittest.skip(_DISABLED_REASON)
    @patch('common.solutions.chains.generic_decision_chain.handle_response_trace')
    @patch('common.solutions.chains.generic_decision_chain.is_fallback_enabled')
    @patch('common.solutions.chains.generic_decision_chain.get_fallback_completion')
    @patch('common.solutions.chains.generic_decision_chain.get_llm_completion')
    def test_decision_chain_uses_fallback_on_exception(
        self,
        mock_get_llm_completion,
        mock_to_fallback,
        mock_is_fallback_enabled,
        mock_handle_response_trace,
    ):
        """Intended scenario: the first completion fails, the fallback succeeds.

        The disabled draft below makes the first ``execute()`` call raise and
        the second one return a JSON reply, then expects the parsed decision
        and two calls to ``get_decision_as_json``.
        """
        # # Arrange
        # llm = MagicMock(spec=DKULLM)
        # decision_chain = SelfServiceDecisionChain(
        #     llm=llm,
        #     user_query="query",
        #     chat_history=[],
        #     media_summaries=[]
        # )

        # mock_is_fallback_enabled.return_value = True

        # fake_completion = MagicMock(name="OriginalCompletion")
        # mock_get_llm_completion.return_value = fake_completion

        # mock_handle_response_trace = MagicMock()

        # # completion.execute() mocks, error on first call
        # # then simulate a reply
        # first_error = Exception("simulated LLM failure")
        # fake_json = '{ "query": "test_query", "justification": "test_justification" }'

        # fallback_response = DSSLLMCompletionResponse({"ok": "ok", "text": fake_json}, text=fake_json)
        # fake_completion.execute.side_effect = [first_error, fallback_response]
        # mock_to_fallback.return_value = fake_completion

        # # Act
        # with patch.object(decision_chain, 'get_decision_as_json', wraps=decision_chain.get_decision_as_json) as mock_get_decision_as_json:
        #     result = decision_chain.get_decision_as_json(user_query="query")

        # # Assert
        # self.assertEqual(result, {'decision': False, 'justification': 'test_justification', 'documents': []})
        # self.assertEqual(mock_get_decision_as_json.call_count, 2)

    @unittest.skip(_DISABLED_REASON)
    @patch('common.solutions.chains.generic_decision_chain.handle_response_trace')
    @patch('common.solutions.chains.generic_decision_chain.is_fallback_enabled')
    @patch('common.solutions.chains.generic_decision_chain.get_fallback_completion')
    @patch('common.solutions.chains.generic_decision_chain.get_llm_completion')
    def test_decision_chain_fails_and_fallback_too(
        self,
        mock_get_llm_completion,
        mock_to_fallback,
        mock_is_fallback_enabled,
        mock_handle_response_trace,
    ):
        """Intended scenario: both the primary and the fallback completion fail.

        The disabled draft below makes every ``execute()`` call raise, then
        expects the exception to propagate after two calls to
        ``get_decision_as_json``.
        """
        # # Arrange
        # llm = MagicMock(spec=DKULLM)
        # decision_chain = SelfServiceDecisionChain(
        #     llm=llm,
        #     user_query="query",
        #     chat_history=[],
        #     media_summaries=[]
        # )

        # mock_is_fallback_enabled.return_value = True

        # fake_completion = MagicMock(name="OriginalCompletion")
        # mock_get_llm_completion.return_value = fake_completion

        # mock_handle_response_trace = MagicMock()

        # # completion.execute() mocks, will keep failing.
        # llm_error = Exception("simulated LLM failure")
        # fake_completion.execute.side_effect = [llm_error, llm_error]
        # mock_to_fallback.return_value = fake_completion

        # # Act
        # with patch.object(decision_chain, 'get_decision_as_json', wraps=decision_chain.get_decision_as_json) as mock_get_decision_as_json:
        #     with self.assertRaises(Exception) as context:
        #         decision_chain.get_decision_as_json(user_query="query")

        #     # Assert
        #     self.assertEqual(mock_get_decision_as_json.call_count, 2)
        #     self.assertIn("simulated LLM failure", str(context.exception))
