from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

from answers.backend.services.sources.sources_builder_specific_answers import generate_sources_from_sample
from answers.backend.utils.config_utils import get_retriever_info
from answers.llm_assist.llm_tooling.tool_utils import get_all_dataset_descriptions
from answers.solutions.chains.db.select_sql_agent import SelectSQLAgent
from common.backend.constants import PROMPT_DATE_FORMAT
from common.backend.models.base import (
    ConversationParams,
    LLMCompletionSettings,
    LlmHistory,
    LLMStep,
    MediaSummary,
    RetrievalSummaryJson,
)
from common.backend.utils.context_utils import LLMStepName
from common.backend.utils.llm_utils import handle_prompt_media_explanation
from common.backend.utils.prompt_utils import append_user_profile_to_prompt
from common.llm_assist.logging import logger
from common.solutions.chains.generic_answers_chain import GenericAnswersChain
from dataiku.core.knowledge_bank import MultipartContext


class DBRetrievalChain(GenericAnswersChain):
    """Answers chain that can ground the LLM answer on rows read from SQL tables.

    The chain asks a ``SelectSQLAgent`` for a retrieval query. When the agent
    returns a query, it is executed through ``SqlRetrieverTool`` and, on
    failure, the agent is asked to correct it a bounded number of times. When
    the agent returns no query, the chain answers without retrieval and passes
    the agent's justification to the LLM instead.
    """

    def __init__(
        self,
        completion_settings: LLMCompletionSettings,
        decision_completion_settings: LLMCompletionSettings,
        chat_has_media: bool = False,
    ):
        """Create the chain.

        Args:
            completion_settings: Settings exposed through ``completion_settings``.
            decision_completion_settings: Settings handed to the SQL agent when
                it builds (or corrects) the retrieval query.
            chat_has_media: Forwarded as ``has_media`` to
                ``handle_prompt_media_explanation`` when the prompts are loaded.
        """
        super().__init__()
        self.__completion_settings = completion_settings
        self.__decision_completion_settings = decision_completion_settings
        self.__act_like_prompt = ""
        self.__system_prompt = ""
        self.__prompt_with_media_explanation = chat_has_media
        self.__sql_query = ""
        self.__tables_used: List[str] = []
        self.__retrieval_query_agent = SelectSQLAgent()
        self.__use_db_retrieval = False

    @property
    def act_like_prompt(self) -> str:
        """Role prompt read from the ``db_prompt`` webapp setting."""
        return self.__act_like_prompt

    @property
    def system_prompt(self) -> str:
        """System prompt built by ``load_role_and_guidelines_prompts``."""
        return self.__system_prompt

    @property
    def sql_query(self) -> str:
        """SQL query of the last retrieval, or ``""`` when none was run."""
        return self.__sql_query

    @property
    def tables_used(self) -> List[str]:
        """Tables reported by the SQL tool for the last retrieval."""
        return self.__tables_used

    @property
    def retrieval_query_agent(self) -> SelectSQLAgent:
        """Agent that produces (and corrects) the retrieval query."""
        return self.__retrieval_query_agent

    @property
    def use_db_retrieval(self) -> bool:
        """Whether the agent returned a query for the current question."""
        return self.__use_db_retrieval

    @property
    def completion_settings(self) -> LLMCompletionSettings:
        """Completion settings given at construction time."""
        return self.__completion_settings

    @property
    def decision_completion_settings(self) -> LLMCompletionSettings:
        """Completion settings passed to the SQL agent."""
        return self.__decision_completion_settings

    @property
    def chain_purpose(self) -> str:
        """Step name identifying this chain (``LLMStepName.DB_ANSWER``)."""
        return LLMStepName.DB_ANSWER.value

    def load_role_and_guidelines_prompts(self, params: ConversationParams) -> None:
        """Load the role and system prompts from the webapp configuration.

        The system prompt falls back to a built-in default when
        ``db_system_prompt`` is not configured, then goes through
        ``append_user_profile_to_prompt`` and ``handle_prompt_media_explanation``.

        Args:
            params: Conversation parameters; only ``user_profile`` is read.
        """
        act_like_prompt = self.webapp_config.get("db_prompt", "")
        system_prompt = self.webapp_config.get(
            "db_system_prompt",
            "Given the following specific context and the conversation between a user and an assistant, please give a short answer to the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.",
        )

        user_profile = params.get("user_profile", None)
        system_prompt = append_user_profile_to_prompt(system_prompt=system_prompt, user_profile=user_profile)
        system_prompt = handle_prompt_media_explanation(
            system_prompt=system_prompt, has_media=self.__prompt_with_media_explanation
        )
        self.__act_like_prompt = act_like_prompt
        self.__system_prompt = system_prompt

    def create_db_retrieval_prompt(self) -> str:
        """Build the system prompt used when a database retrieval was performed."""
        datetime_now = datetime.now().strftime(PROMPT_DATE_FORMAT)
        return rf"""
        Today's date and time: {datetime_now}
        {self.act_like_prompt}

        {self.system_prompt}
        - Do not attempt to write any SQL
        """

    def create_db_retrieval_prompt_on_but_not_used(self, params: ConversationParams) -> str:
        """Build the system prompt used when the agent chose not to query the database.

        Args:
            params: Conversation parameters; only ``justification`` is read.

        Returns:
            The prompt, including the dataset descriptions and, when present,
            the agent's justification.
        """
        # NOTE(review): the prompt wording below ("for the following reason:"
        # followed by "They gave the following reason", and "if it relevant")
        # is kept byte-identical on purpose -- changing it alters LLM behaviour.
        justification_text = ""
        if justification := params.get("justification"):
            justification_text = f"""
            They gave the following reason
            {justification}
            """

        datetime_now = datetime.now().strftime(PROMPT_DATE_FORMAT)
        return rf"""
        Today's date and time: {datetime_now}
        {self.act_like_prompt}

        {self.system_prompt}

        Be aware that you are part of a team that can retrieve information from external data sources.
        However, your teammate decided not to make any query for the following reason: {justification_text}. Only mention this if it relevant to the user query.
        - Do not attempt to write any SQL
        The data source description is:
        {get_all_dataset_descriptions()}
        """

    def get_computed_system_prompt(self, params: ConversationParams) -> str:
        """Return the system prompt matching the current retrieval decision."""
        if self.use_db_retrieval:
            return self.create_db_retrieval_prompt()
        return self.create_db_retrieval_prompt_on_but_not_used(params)

    def get_computing_prompt_step(self) -> LLMStep:
        """Return the "computing prompt" step matching the retrieval decision."""
        if self.use_db_retrieval:
            return LLMStep.COMPUTING_PROMPT_WITH_DB
        return LLMStep.COMPUTING_PROMPT_WITHOUT_RETRIEVAL

    def finalize_streaming(
        self, params: ConversationParams, answer_context: Union[str, Dict[str, Any], List[str]]
    ) -> RetrievalSummaryJson:
        """Build the final JSON payload for an answer.

        Args:
            params: Conversation parameters; ``user_profile`` and
                ``media_summaries`` are read.
            answer_context: Answer payload forwarded to ``get_as_json``.
        """
        # Send sources and filters at the end of the streaming
        return self.get_as_json(
            answer_context,
            user_profile=params.get("user_profile", None),
            uploaded_docs=params.get("media_summaries", None),
        )

    def finalize_non_streaming(
        self,
        params: ConversationParams,
        answer_context: Union[str, Dict[str, Any], List[str]],
    ) -> RetrievalSummaryJson:
        """Build the final JSON payload; identical to ``finalize_streaming``."""
        return self.finalize_streaming(params=params, answer_context=answer_context)

    def get_querying_step(self, params: ConversationParams) -> LLMStep:
        """Return the "querying LLM" step matching the retrieval decision."""
        if self.use_db_retrieval:
            return LLMStep.QUERYING_LLM_WITH_DB
        return LLMStep.QUERYING_LLM_WITHOUT_RETRIEVAL

    def __exception_correction_loop(
        self,
        conversation_params: ConversationParams,
        db_query: Union[Dict[str, Any], None],
        justification: str,
        user_question: str,
        attempt: int,
        chat_history: list,
        max_attempts: int = 3,
    ) -> Tuple[str, str, List[Any], List[str]]:
        """Ask the agent to correct a failed query and retry until it succeeds.

        Each round hands the failed response back to the SQL agent, runs the
        corrected query, and records the failure in the agent's
        ``previous_sql_errors`` if it fails again.

        Args:
            conversation_params: Conversation parameters forwarded to the agent.
            db_query: The query that just failed; ``None`` ends the loop at once.
            justification: Justification that came with ``db_query``.
            user_question: The user input forwarded to the agent.
            attempt: Number of correction attempts already made.
            chat_history: Conversation history forwarded to the agent.
            max_attempts: Maximum number of correction attempts.

        Returns:
            ``(context, sql_query, records, tables_used)``. When every attempt
            fails, ``context`` describes the errors and the other items are empty.
        """
        from answers.llm_assist.llm_tooling.tools import (
            SqlRetrieverTool,  # Lazy import to prevent loading heavy langchain modules
        )

        # Iterating (instead of recursing) keeps ``max_attempts`` in force for
        # every round; the recursive form fell back to the default of 3.
        while attempt < max_attempts and db_query is not None:
            logger.debug(f"Correction attempt no. {attempt + 1}")
            failed_response = {**db_query, "justification": justification}
            db_query, justification = self.retrieval_query_agent.get_retrieval_query(
                completion_settings=self.decision_completion_settings,
                conversation_params=conversation_params,
                chat_history=chat_history,
                user_input=user_question,
                failed_response=failed_response,
            )
            if db_query is None:
                # TODO: Handle this case as the db_on_but_disabled method
                return "No Need for database query", "", [], []

            # A fresh tool per attempt, as before.
            sql_tool = SqlRetrieverTool()
            try:
                context, sql_query, records, tables_used = sql_tool._run(db_query)
            except Exception as e:
                # Broad on purpose: any failure of the tool triggers another round.
                logger.error(f"Unable to read from SQL table on correction attempt {attempt + 1}: {e}")
                db_query = db_query or {}
                failed_response = {**db_query, "justification": justification}
                # The tool raised before handing back the SQL text, so the
                # recorded query is empty.
                self.retrieval_query_agent.previous_sql_errors.append(
                    {"response": str(failed_response), "query": "", "error": str(e)}
                )
                attempt += 1
            else:
                return context, sql_query, records, tables_used

        logger.error("Unable to read from SQL table after maximum retries")
        context = f"""
            Error while reading from SQL table after {max_attempts} attempts
            {self.retrieval_query_agent.previous_sql_errors}
            The SQL table description is:
            {get_all_dataset_descriptions()}
            The following attempts were made:
            {self.retrieval_query_agent.formatted_errors}
            """
        return context, "", [], []

    def __prep_db_answer_context(self, records: List[Dict[str, List[Any]]], sql_query: str) -> Dict[str, Any]:
        """Shape the retrieved rows for the answer payload.

        Column descriptors are derived from the keys of the first record; no
        records yields an empty column list.
        """
        columns = []
        if records:
            columns = [{"name": key, "label": key, "field": key, "align": "left"} for key in records[0].keys()]
        return {"rows": records, "columns": columns, "query": sql_query}

    def __compute_prompt_for_db(
        self,
        params: ConversationParams,
        user_question: str,
        db_query: dict,
        justification: str,
        chat_history: List[LlmHistory],
    ) -> Tuple[str, str, Dict[Any, Any], List[Any]]:
        """Run the retrieval query, falling back to the correction loop on failure.

        Args:
            params: Conversation parameters forwarded to the correction loop.
            user_question: The user input forwarded to the correction loop.
            db_query: Query produced by the SQL agent.
            justification: Justification that came with ``db_query``.
            chat_history: Conversation history forwarded to the correction loop.

        Returns:
            ``(context, sql_query, answer_context, tables_used)``.
        """
        from answers.llm_assist.llm_tooling.tools import (
            SqlRetrieverTool,  # Lazy import to prevent loading heavy langchain modules
        )

        # Every retrieval starts with an empty error history on the agent.
        self.retrieval_query_agent.previous_sql_errors = []
        logger.info("Running DB retrieval.")
        sql_tool = SqlRetrieverTool()
        try:
            context, sql_query, records, tables_used = sql_tool._run(db_query)
        except Exception as e:
            # The correction loop is only triggered if the query fails but is
            # not triggered if the LLM query fails.
            # An LLM failure will result in the error message being directly
            # returned as the answer to the user and the rest of the chain will
            # not be executed.
            logger.error(f"Unable to read from SQL table on first attempt {e}")
            logger.debug("Entering exception correction loop")
            failed_response = {**db_query, "justification": justification}
            # The tool raised before handing back the SQL text, so the
            # recorded query is empty.
            self.retrieval_query_agent.previous_sql_errors.append(
                {"response": str(failed_response), "query": "", "error": str(e)}
            )
            context, sql_query, records, tables_used = self.__exception_correction_loop(
                conversation_params=params,
                db_query=db_query,
                justification=justification,
                user_question=user_question,
                attempt=0,
                chat_history=chat_history,
            )
        # Built after the try/except rather than in a ``finally``: there is
        # nothing to prepare when an exception is propagating.
        answer_context = self.__prep_db_answer_context(records, sql_query)  # type: ignore
        return context, sql_query, answer_context, tables_used

    def get_retrieval_context(
        self,
        params: ConversationParams,
    ) -> Tuple[Optional[Union[MultipartContext, str]], Dict[str, Any]]:
        """Return the context given to the LLM and the answer context.

        With retrieval enabled, the query stored in ``params["db_query"]`` is
        run and ``sql_query`` / ``tables_used`` are updated. Otherwise the
        context only carries the agent's justification and the answer context
        is empty.

        Args:
            params: Conversation parameters; ``user_query``, ``db_query``,
                ``chat_history`` and ``justification`` are read.
        """
        user_query = params.get("user_query", "")
        db_query = params.get("db_query", {}) or {}
        chat_history = params.get("chat_history", [])
        justification = params.get("justification", "")

        answer_context: Dict[str, Any] = {}
        if self.use_db_retrieval:
            context, self.__sql_query, answer_context, self.__tables_used = self.__compute_prompt_for_db(
                params=params,
                user_question=user_query,
                db_query=db_query,
                justification=justification,
                chat_history=chat_history,
            )
        else:
            # Inner whitespace of this literal is kept exactly as before.
            context = f"""Your teammate decided not make any query for the following reason:
                {justification}
                """
        return context, answer_context

    def get_as_json(  # type: ignore
        self,
        generated_answer: Union[str, Dict[str, Any], List[str]],
        user_profile: Optional[Dict[str, Any]] = None,
        uploaded_docs: Optional[List[MediaSummary]] = None,
    ) -> RetrievalSummaryJson:
        """Convert the generated answer into the JSON payload of the chain.

        Args:
            generated_answer: A plain string, or a dict holding an ``answer``
                key plus the sample used to generate sources.
            user_profile: Added to the LLM context when truthy.
            uploaded_docs: Media summaries listed in the LLM context.

        Returns:
            The payload for ``str`` and ``dict`` answers; an empty dict (after
            logging an error) for any other type.
        """
        llm_context: Dict[str, Any] = {}

        llm_context["selected_retrieval_info"] = get_retriever_info(config=self.webapp_config)
        if len(self.sql_query) > 0:
            llm_context["dataset_context"] = {
                "sql_retrieval_table_list": self.webapp_config.get("sql_retrieval_table_list"),
                "tables_used": self.tables_used,
                "sql_query": self.sql_query.replace("\n", "").replace('\\"', '"'),
            }
        if user_profile:
            llm_context["user_profile"] = user_profile

        if isinstance(generated_answer, str):
            # NOTE(review): this branch does not return ``llm_context`` --
            # confirm that plain-string answers are meant to omit it.
            return {
                "answer": generated_answer,
                "sources": [],
                "filters": None,
                "knowledge_bank_selection": [],
            }

        llm_context["uploaded_docs"] = (
            [
                {
                    "original_file_name": str(uploaded_doc.get("original_file_name")),
                    "file_path": str(uploaded_doc.get("file_path")),
                    "metadata_path": str(uploaded_doc.get("metadata_path")),
                }
                for uploaded_doc in uploaded_docs
            ]
            if uploaded_docs
            else []
        )
        if isinstance(generated_answer, dict):
            answer = generated_answer.get("answer", "")
            tables_used_str = (
                ", ".join(self.tables_used) if isinstance(self.tables_used, list) else str(self.tables_used)
            )
            # sample without answer key in it.
            sample = generated_answer.copy()
            sample.pop("answer", "")

            return {
                "answer": answer,
                "sources": generate_sources_from_sample(sample, tables_used_str),
                "filters": None,
                "llm_context": llm_context,
            }
        logger.error(f"Generated answer type not supported. This should not happen. {generated_answer}")
        return {}

    # TODO: correct retrieval query part - to be handled in different loops
    def create_query_from_history_and_update_params(
        self, chat_history: List[LlmHistory], user_query: str, params: ConversationParams
    ) -> ConversationParams:
        """Ask the SQL agent for a retrieval query and store the decision.

        ``params`` is updated in place with ``justification`` and ``db_query``;
        retrieval is enabled only when the agent returned a query.

        Args:
            chat_history: Conversation history forwarded to the agent.
            user_query: The user input forwarded to the agent.
            params: Conversation parameters, updated and returned.
        """
        self.retrieval_query_agent.get_retrieval_graph(
            completion_settings=self.decision_completion_settings,
            chat_history=chat_history,
            user_input=user_query,
            conversation_params=params,
        )
        retrieval_query, justification = self.retrieval_query_agent.get_retrieval_query(
            completion_settings=self.decision_completion_settings,
            conversation_params=params,
            chat_history=chat_history,
            user_input=user_query,
        )
        params["justification"] = justification

        params["db_query"] = retrieval_query
        logger.debug(
            f"get_retrieval_query performed for db_query with: [{user_query}], computed db_query is [{params['db_query']}]"
        )

        self.__use_db_retrieval = retrieval_query is not None
        logger.debug(f"use_db_retrieval: {self.use_db_retrieval}")
        return params