
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from common.backend.constants import DEFAULT_DECISIONS_GENERATION_ERROR
from common.backend.models.base import LLMCompletionSettings, LlmHistory
from common.backend.utils.json_utils import extract_json
from common.backend.utils.llm_utils import (
    get_llm_completion,
    handle_response_trace,
    parse_error_messages,
)
from common.backend.utils.sql_timing import log_execution_time
from common.llm_assist.fallback import get_fallback_completion, is_fallback_enabled
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import DSSLLMCompletionQuery, DSSLLMCompletionResponse


class GenericDecisionJSONChain(ABC):
    """Abstract base for decision chains that ask an LLM for a JSON decision.

    Subclasses provide the system prompt, the completion settings, a short
    purpose label used in log messages, and a validation step applied to the
    JSON extracted from the LLM answer. ``get_decision_as_json`` drives the
    call and, when a fallback LLM is enabled, retries once through it.
    """

    class _ResponseError(Exception):
        """Internal marker for failures detected in the LLM response itself.

        Its message is already the final error detail, so the generic
        handler must append it as-is (once) instead of re-parsing it.
        """

    @property
    @abstractmethod
    def prompt(self) -> str:
        """System prompt used to generate the decision.

        It should tell the model which decision to make and what information
        to include in it, and it MUST instruct the model to output JSON.
        """
        raise NotImplementedError("Subclasses must implement prompt property")

    @property
    @abstractmethod
    def completion_settings(self) -> LLMCompletionSettings:
        """Settings passed to ``get_llm_completion``; must contain an ``llm_id`` key."""
        raise NotImplementedError("Subclasses must implement completion_settings property")

    @property
    @abstractmethod
    def chain_purpose(self) -> str:
        """Short label for this chain, used as a prefix in log messages."""
        raise NotImplementedError("Subclasses must implement chain_purpose property")

    @abstractmethod
    def verified_json_output(self, json_output: Dict[str, Any]) -> Dict[str, Any]:
        """Validate/normalize the JSON extracted from the LLM answer.

        The returned dict is what ``get_decision_as_json`` returns. Any
        exception raised here is handled like any other generation failure
        (logged, fallback LLM attempted if enabled, then re-raised).
        """

    def get_decision_json_pattern(self) -> str:
        """Regex handed to ``extract_json`` to locate the JSON in the answer.

        The default is non-greedy, so it stops at the first ``}``; subclasses
        expecting nested objects should override it.
        """
        return r"{[\s\S]*?}"

    def add_history(self, completion: DSSLLMCompletionQuery, chat_history: List[LlmHistory]) -> DSSLLMCompletionQuery:
        """Append past turns to ``completion`` in order and return it.

        For each history entry, a non-empty ``input`` becomes a user message
        and a non-empty ``output`` becomes an assistant message.
        """
        for hist_item in chat_history:
            if input_ := hist_item.get("input"):
                completion.with_message(message=input_, role="user")
            if output := hist_item.get("output"):
                completion.with_message(message=output, role="assistant")
        return completion

    @log_execution_time
    def get_decision_as_json(self, user_query: str, chat_history: List[LlmHistory], first_attempt: bool = True) -> Dict[str, Any]:
        """Ask the LLM for a decision and return it as a verified JSON dict.

        Args:
            user_query: The latest user message, sent after the history.
            chat_history: Previous turns, replayed via ``add_history``.
            first_attempt: ``False`` when this call is the fallback retry;
                the completion is then routed through
                ``get_fallback_completion``.

        Returns:
            The dict returned by ``verified_json_output``.

        Raises:
            Exception: If the prompt is ``None``. Also raised when generation
                fails and no fallback retry is made (or the retry itself fails).
                The message starts with ``DEFAULT_DECISIONS_GENERATION_ERROR``.
                JSON decoding errors are re-raised directly, without a retry.
        """
        error_message = DEFAULT_DECISIONS_GENERATION_ERROR
        prompt = self.prompt
        if prompt is None:
            logger.error(f"Prompt is not defined ({self.chain_purpose})", log_conv_id=True)
            raise Exception("Decision prompt is not defined")
        logger.info(f"{self.chain_purpose} Final Prompt: {prompt}", log_conv_id=True)
        try:
            completion = get_llm_completion(self.completion_settings)
            if not first_attempt:
                completion = get_fallback_completion(completion)
            logger.info(f"Calling a LLM for decision chain, llm_id : [{completion.llm.llm_id}]", log_conv_id=True)
            completion.with_message(prompt, role="system")
            completion = self.add_history(completion, chat_history)
            completion.with_message(user_query, role="user")
            llm_response: DSSLLMCompletionResponse = completion.execute()
            handle_response_trace(llm_response)
            response: str = llm_response.text
            logger.debug(f"{self.chain_purpose} llm response: '{response}'", log_conv_id=True)
            if not response:
                # Read the error from the raw payload: the response object is
                # not guaranteed to expose an ``errorMessage`` attribute.
                raw_error = (llm_response._raw or {}).get("errorMessage")
                raise self._ResponseError(str(raw_error) if raw_error else "No response from LLM")
            if isinstance(response, dict):
                json_response = response
            else:
                json_response = extract_json(response, json_pattern=self.get_decision_json_pattern())
            logger.debug(f"Extracted response  {json_response}", log_conv_id=True)
            return self.verified_json_output(json_response)
        except json.JSONDecodeError as e:
            # Deliberately no fallback retry for malformed JSON.
            error_message += "Error decoding JSON during decision generation"
            logger.exception(error_message, log_conv_id=True)
            raise Exception(error_message) from e
        except Exception as e:
            if isinstance(e, self._ResponseError):
                # Detail is already final; append it exactly once.
                error_message += str(e)
            else:
                error_message += parse_error_messages(str(e)) or str(e)
            logger.exception(error_message, log_conv_id=True)
            # Retry once through the fallback LLM, if one is configured.
            fallback_enabled = is_fallback_enabled(self.completion_settings["llm_id"])
            if first_attempt and fallback_enabled:
                return self.get_decision_as_json(user_query, chat_history, False)  # type: ignore # mypy struggles with recursion here
            raise Exception(error_message) from e