import time
from typing import Any, Dict, Generator, List, Optional, Union

from answers.backend.db.user_profile import user_profile_sql_manager
from answers.backend.utils.knowledge_banks_params import KnowledgeBanksParams
from answers.solutions.chains.db.db_retrieval_chain import DBRetrievalChain
from answers.solutions.chains.kb.kb_retrieval_chain import KBRetrievalChain
from common.backend.constants import (
    CONVERSATION_DEFAULT_NAME,
    DEFAULT_TITLE_GENERATION_ERROR,
    MEDIA_CONVERSATION_START_TAG,
)
from common.backend.models.base import (
    ConversationParams,
    ConversationType,
    LLMCompletionSettings,
    LlmHistory,
    LLMStepDesc,
    MediaSummary,
    RetrievalSummaryJson,
    RetrieverMode,
)
from common.backend.services.llm_question_answering import BaseLLMQuestionAnswering
from common.backend.utils.context_utils import LLMStepName, get_main_trace, init_user_trace
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.llm_utils import (
    get_alternative_llm_completion_settings,
    get_image_generation_settings,
    get_llm_capabilities,
    get_llm_completion,
    get_main_llm_completion_settings,
    handle_response_trace,
    resolve_llm_id_from_key,
)
from common.backend.utils.sql_timing import log_execution_time
from common.llm_assist.fallback import get_fallback_completion, is_fallback_enabled
from common.llm_assist.logging import logger
from common.solutions.chains.docs.media_qa_chain import MediaQAChain
from common.solutions.chains.generic_answers_chain import GenericAnswersChain
from common.solutions.chains.image_generation.image_generation_chain import ImageGenerationChain
from common.solutions.chains.image_generation.image_generation_decision_chain import ImageGenerationDecisionChain
from common.solutions.chains.no_retrieval_chain import NoRetrievalChain
from common.solutions.prompts.conversation_title import CONVERSATION_TITLE_PROMPT, TITLE_USER_PROMPT
from dataikuapi.dss.llm import (
    DSSLLMCompletionResponse,
    DSSLLMStreamedCompletionChunk,
    DSSLLMStreamedCompletionFooter,
)


class LLM_Question_Answering(BaseLLMQuestionAnswering):
    """LLM_Question_Answering: A class to facilitate the question-answering process using LLM model.

    Each query is routed to one of several chains:
    - image generation;
    - media Q&A;
    - knowledge-bank retrieval;
    - database retrieval;
    - a plain (no-retrieval) LLM chain.

    The choice depends on the webapp configuration, the LLM capabilities and the request itself.
    """

    def __init__(self):
        """Load LLM completion settings and webapp parameters.

        Raises:
            Exception: if the configured ``retrieval_mode`` is an empty string.
        """
        # Constructor parameters:
        self.project = dataiku_api.default_project
        # Settings for the main answering LLM and for the LLM used for JSON decisions.
        self.llm_completion_settings: LLMCompletionSettings = get_main_llm_completion_settings()
        self.decision_completion_settings: LLMCompletionSettings = get_alternative_llm_completion_settings("json_decision_llm_id")
        self.webapp_config = dataiku_api.webapp_config
        # Parameters:
        self.retrieval_mode = self.webapp_config.get("retrieval_mode")
        # NOTE(review): only an explicit empty string is rejected here. A missing key (None)
        # falls through to the no-retrieval path in get_answer_and_sources. Confirm that is intended.
        if self.retrieval_mode == "":
            raise Exception("retrieval_mode must be provided")
        self.include_user_profile_in_prompt = bool(self.webapp_config.get("include_user_profile_in_KB_prompt", False))
        # Recomputed on every call to get_answer_and_sources from the chat history.
        self.chat_has_media = False
        # Initialize knowledge banks parameters once for all
        self.kbs_params = KnowledgeBanksParams()

    def get_answer_and_sources(  # noqa: PLR0917 too many positional arguments
        self,
        query: str,
        conversation_type: ConversationType,
        chat_history: Optional[List[LlmHistory]] = None,
        filters: Optional[Dict[str, List[Any]]] = None,
        chain_type: Optional[str] = None,
        media_summaries: Optional[List[MediaSummary]] = None,
        previous_media_summaries: Optional[List[MediaSummary]] = None,
        retrieval_enabled: bool = False,
        user_profile: Optional[Dict[str, Any]] = None,
        force_non_streaming: bool = False,
    ) -> Generator[
        Union[
            LLMStepDesc,
            RetrievalSummaryJson,
            DSSLLMStreamedCompletionChunk,
            DSSLLMStreamedCompletionFooter,
        ],
        Any,
        None,
    ]:
        """Route a user query to the appropriate chain and return that chain's output.

        Routing order:
            1. If the LLM supports image generation and ``query`` is not the
               media-conversation start tag, an image-generation decision chain runs.
               If it decides to generate an image, the image-generation chain's output
               is returned.
            2. If ``query`` is ``MEDIA_CONVERSATION_START_TAG``, a media Q&A chain is started.
            3. Otherwise a KB retrieval, DB retrieval or no-retrieval chain is chosen from
               ``retrieval_enabled`` and the configured ``retrieval_mode``. In KB mode, the
               no-retrieval chain is used instead when no knowledge bank ends up selected.

        Args:
            query: The user query, or ``MEDIA_CONVERSATION_START_TAG`` to start a media conversation.
            conversation_type: The conversation type. It is only logged here.
            chat_history: Prior interactions of the conversation. Defaults to an empty history.
            filters: Filters to apply to the knowledge bank (KB retrieval chain only).
            chain_type: The type of an uploaded file, indicating how it should be handled.
            media_summaries: Summaries of media attached to the current query.
            previous_media_summaries: Summaries of media attached earlier in the conversation.
            retrieval_enabled: Whether retrieval (KB or DB, per ``retrieval_mode``) may be used.
            user_profile: The user's profile. NOTE: when image generation is available but
                not chosen, the ``media`` and ``generated_media_info`` keys are removed from
                this dict in place.
            force_non_streaming: Sets ``forced_non_streaming`` on the retrieval/no-retrieval
                chain. It is not applied to the image-generation or media Q&A chains.

        Returns:
            The output of the selected chain, typed as a generator of step descriptions,
            retrieval summaries, streamed completion chunks and a completion footer.

        Raises:
            Exception: if the main trace is not initialized.
            ValueError: if image generation is chosen but ``image_generation_llm_id`` is not set.
        """
        # Avoid a shared mutable default: each call gets its own empty history.
        if chat_history is None:
            chat_history = []

        logger.debug("Time ===>: starting tracking time: Generating response")
        if query != MEDIA_CONVERSATION_START_TAG:
            init_user_trace(LLMStepName.DKU_ANSWERS_QUERY.name)
            if not (main_trace := get_main_trace()):
                raise Exception("Main trace is not initialized correctly.")
            main_trace.attributes["query"] = query
        conversation_params: ConversationParams = {
            "user_query": query,
            "chat_history": chat_history,
            "chain_type": chain_type,
            "media_summaries": media_summaries,
            "previous_media_summaries": previous_media_summaries,
            "retrieval_enabled": retrieval_enabled,
            "knowledge_bank_selection": [],
            "use_db_retrieval": False,
            "global_start_time": time.time(),  # Capture start time
            "justification": "",
            "user_profile": user_profile,
            "self_service_decision": None,
            "chain_purpose": LLMStepName.UNKNOWN.value,
        }
        logger.info(f"""retrieval_mode is set to : {self.retrieval_mode}
        Conversation type is set to: {conversation_type}
        Chain type is set to: {chain_type}""")
        if conversation_type is not None:
            logger.debug(f"conversation_type is set to : {conversation_type}")

        # Flag whether any prior turn's output is the "generated_media_by_ai" marker.
        self.chat_has_media = any(
            item.get("output", "") == "generated_media_by_ai" for item in chat_history
        )
        llm_capabilities = get_llm_capabilities()

        if llm_capabilities.get("image_generation", False) and query != MEDIA_CONVERSATION_START_TAG:
            logger.info("Image generation is enabled")
            img_system_prompt = self.webapp_config.get("image_generation_system_prompt")
            img_gen_decision_chain = ImageGenerationDecisionChain(
                completion_settings=self.decision_completion_settings, system_prompt=img_system_prompt
            )
            decision_output = img_gen_decision_chain.get_decision_as_json(user_query=query, chat_history=chat_history)
            logger.info(f"Image generation decision chain output: {decision_output}")
            generate_image = decision_output.get("decision")
            generated_query = decision_output.get("query")
            referred_image_path = decision_output.get("referred_image")
            if generate_image:
                if image_generation_llm_id := dataiku_api.webapp_config.get("image_generation_llm_id"):
                    if not (main_trace := get_main_trace()):
                        raise Exception("Main trace is not initialized correctly.")
                    sub_trace = main_trace.subspan(LLMStepName.IMAGE_GENERATION.name)
                    image_generation_settings = get_image_generation_settings(
                        model_id=image_generation_llm_id,
                        referred_image_path=referred_image_path,
                        user_profile=user_profile,
                    )
                    max_images_to_generate = int(dataiku_api.webapp_config.get("max_images_per_user_per_week", 0))
                    return ImageGenerationChain(
                        completion_settings=self.llm_completion_settings,
                        image_generation_settings=image_generation_settings,
                        user_query=generated_query,
                        user_profile_sql_manager=user_profile_sql_manager,
                        user_profile=user_profile,
                        trace=sub_trace,
                        include_user_profile_in_prompt=self.include_user_profile_in_prompt,
                    ).run_image_generation_query(max_images_to_generate)
                else:
                    logger.error("Image generation LLM ID is not set")
                    raise ValueError("The 'image_generation_llm_id' has not been set.")
            # If image generation is not required, continue with the normal flow and ignore media in the user profile.
            # This mutates the caller's dict (also referenced by conversation_params["user_profile"]).
            if user_profile and user_profile.get("media"):
                user_profile.pop("media", None)
            if user_profile and user_profile.get("generated_media_info"):
                user_profile.pop("generated_media_info", None)

        qa_chain: GenericAnswersChain
        if query == MEDIA_CONVERSATION_START_TAG:
            logger.info("Starting media QA conversation")
            # Second argument is True when there is no prior chat history.
            return MediaQAChain(media_summaries).start_media_qa_chain(user_profile, not chat_history)
        if retrieval_enabled and self.retrieval_mode == RetrieverMode.KB.value:
            logger.info("KB retrieval enabled")
            # TODO: Refactor llm->llm_id & decision_llm --> decision_completion_settings
            qa_chain = KBRetrievalChain(
                self.llm_completion_settings,
                decision_completion_settings=self.decision_completion_settings,
                query=query,
                input_filters=filters,
                chat_has_media=self.chat_has_media,
            )
        elif retrieval_enabled and self.retrieval_mode == RetrieverMode.DB.value:
            logger.info("DB retrieval enabled")
            # TODO: Refactor llm->llm_id & decision_llm --> decision_completion_settings
            qa_chain = DBRetrievalChain(
                self.llm_completion_settings,
                decision_completion_settings=self.decision_completion_settings,
                chat_has_media=self.chat_has_media,
            )
        else:
            logger.info("No retrieval enabled")
            # TODO: Refactor llm->llm_id
            qa_chain = NoRetrievalChain(
                self.llm_completion_settings,
                chat_has_media=self.chat_has_media,
                include_user_profile_in_prompt=self.include_user_profile_in_prompt,
            )

        conversation_params = qa_chain.create_query_from_history_and_update_params(
            chat_history, query, conversation_params
        )
        if (
            retrieval_enabled
            # and conversation_type == ConversationType.GENERAL
            and self.retrieval_mode == RetrieverMode.KB.value
            and not conversation_params["knowledge_bank_selection"]
        ):
            # No knowledge bank was selected: answer without retrieval. Keep chat_has_media
            # consistent with the primary no-retrieval path above.
            # TODO: Refactor llm->llm_id
            qa_chain = NoRetrievalChain(
                self.llm_completion_settings,
                chat_has_media=self.chat_has_media,
                include_user_profile_in_prompt=self.include_user_profile_in_prompt,
            )
        conversation_params["chain_purpose"] = qa_chain.chain_purpose
        if force_non_streaming:
            qa_chain.forced_non_streaming = True
        return qa_chain.run_completion_query(conversation_params)

    @staticmethod
    @log_execution_time
    def get_conversation_title(query: str, answer: str, user_profile: Optional[Dict[str, Any]] = None, first_attempt: bool = True) -> str:
        """Generate a conversation title with the title LLM.

        A failure is either an exception or an ``errorMessage`` in the raw response. On
        failure, the error is logged. On the first attempt only, the call is retried once
        through the fallback LLM when fallback is enabled for the title LLM.

        Args:
            query: The user query, injected into the system prompt.
            answer: The generated content, injected into the user prompt.
            user_profile: The user's profile, injected into the system prompt.
            first_attempt: ``False`` for the fallback retry, in which case the completion is
                wrapped with ``get_fallback_completion``.

        Returns:
            The generated title. ``CONVERSATION_DEFAULT_NAME`` is returned when the LLM gives
            no text, or when it fails and no fallback succeeds.
        """
        title_llm_id = resolve_llm_id_from_key("title_llm_id")
        title_completion_settings = get_alternative_llm_completion_settings("title_llm_id")
        completion = get_llm_completion(title_completion_settings)
        system_prompt = CONVERSATION_TITLE_PROMPT.format(query=query, user_profile=user_profile)
        user_prompt = TITLE_USER_PROMPT.format(generated_content=answer)
        if not first_attempt:
            completion = get_fallback_completion(completion)
        completion.with_message(system_prompt, role="system")
        completion.with_message(user_prompt, role="user")
        try:
            resp: DSSLLMCompletionResponse = completion.execute()
            handle_response_trace(resp)
            # Check for an error before using the text, so an error response never becomes the title.
            if error_message := resp._raw.get("errorMessage"):
                logger.error(error_message)
                raise Exception(error_message)
            return str(resp.text) if resp.text else CONVERSATION_DEFAULT_NAME
        except Exception as e:
            logger.exception(f"{DEFAULT_TITLE_GENERATION_ERROR} {e}.")
            # checking if we can use a fallback LLM
            fallback_enabled = is_fallback_enabled(title_llm_id)
            if first_attempt and fallback_enabled:
                return LLM_Question_Answering.get_conversation_title(query, answer, user_profile, False)
            return CONVERSATION_DEFAULT_NAME

    def get_llm_name(self) -> str:
        """Return the ``llm_id`` of the main answering LLM's completion settings."""
        # TODO should we return the name of all other models used like title gen, decisions image gen..?
        return self.llm_completion_settings["llm_id"]  # type: ignore