import json
from abc import ABC, abstractmethod
from typing import Optional

from common.backend.models.base import LLMCompletionSettings
from common.backend.utils.json_utils import extract_json
from common.backend.utils.llm_utils import (
    extract_response_trace,
    get_alternative_llm_completion_settings,
    get_llm_completion,
)
from common.backend.utils.sql_timing import log_execution_time
from common.llm_assist.fallback import get_fallback_completion, is_fallback_enabled
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import DSSLLMCompletionQuery


class GenericSummaryChain(ABC):
    """Abstract base for chains that ask an LLM for a JSON-formatted summary.

    Subclasses provide the completion query via :meth:`get_summary_completion`
    (plus the ``original_file_name`` / ``language`` metadata);
    :meth:`get_summary` executes that query, parses the JSON answer and, when a
    fallback LLM is enabled, retries once on the fallback.
    """

    @property
    @abstractmethod
    def original_file_name(self) -> Optional[str]:
        """Original file name associated with this summary, if any.

        NOTE(review): not read by this base class; semantics are defined by
        subclasses (presumably used when building the prompt) — confirm.
        """
        raise NotImplementedError("Subclasses must implement original_file_name property")

    @property
    @abstractmethod
    def language(self) -> Optional[str]:
        """Language associated with this summary, if any.

        NOTE(review): not read by this base class; semantics are defined by
        subclasses — confirm.
        """
        raise NotImplementedError("Subclasses must implement language property")

    @abstractmethod
    def get_summary_completion(self) -> DSSLLMCompletionQuery:
        """Return the completion query whose text answer is parsed as JSON by :meth:`get_summary`."""
        raise NotImplementedError("Subclasses must implement get_summary_completion method")

    def init_completion(self) -> DSSLLMCompletionQuery:
        """Build a completion query from the alternative LLM settings registered under ``"json_decision_llm_id"``."""
        decision_completion_settings: LLMCompletionSettings = get_alternative_llm_completion_settings("json_decision_llm_id")
        return get_llm_completion(decision_completion_settings)

    @log_execution_time
    def get_summary(self, first_attempt: bool = True):
        """Run the summary completion and return its parsed JSON answer.

        Args:
            first_attempt: ``True`` for the primary call. On failure, if a
                fallback LLM is enabled for the primary LLM id, the method
                recurses once with ``first_attempt=False``, in which case the
                completion is wrapped with ``get_fallback_completion``.

        Returns:
            The object returned by ``extract_json`` (presumably a dict), with
            the response trace stored under the ``"trace"`` key.

        Raises:
            Exception: "Error decoding JSON during summary generation" when JSON
                decoding fails (no fallback is attempted for this case), or
                "Error during summary generation: ..." for any other failure
                once the fallback option is exhausted. The original error is
                chained as ``__cause__``.
        """
        # Bound before the try so the except clause can safely inspect it even
        # when get_summary_completion() itself raised.
        completion: Optional[DSSLLMCompletionQuery] = None
        llm_label = "LLM" if first_attempt else "Fallback LLM"
        try:
            completion = self.get_summary_completion()
            if not first_attempt:
                completion = get_fallback_completion(completion)
            if completion is None:
                # Fail explicitly instead of dereferencing None below.
                raise ValueError("No LLM completion available for summary generation")
            logger.info(f"Calling a LLM for summary chain, llm_id : [{completion.llm.llm_id}]", log_conv_id=True)
            resp = completion.execute()
            trace = extract_response_trace(resp)
            if resp and not resp.text and resp.errorMessage:
                logger.error(f"{llm_label} error: {resp.errorMessage}.", log_conv_id=True)
                raise Exception(resp.errorMessage)
            text = str(resp.text) if resp else ""
            extracted_json = extract_json(text, None)  # type: ignore
            extracted_json["trace"] = trace
            return extracted_json
        except json.JSONDecodeError as e:
            error_message = "Error decoding JSON during summary generation"
            logger.exception(error_message, log_conv_id=True)
            raise Exception(error_message) from e
        except Exception as e:
            error_message = f"Error during summary generation: {str(e)}"
            logger.exception(error_message, log_conv_id=True)
            # Retry once on a fallback LLM, but only from the primary attempt and
            # only if we actually obtained a completion to read the LLM id from.
            if first_attempt and completion is not None and is_fallback_enabled(completion.llm.llm_id):
                return self.get_summary(first_attempt=False)
            raise Exception(error_message) from e
