import itertools
import json
import uuid
from typing import Any, Dict, Generator, List, Optional

import pandas as pd
from answers.backend.db.sql.dataset_schemas import ANSWERS_DATASETS_METADATA, MESSAGE_DATASET_CONF_ID
from common.backend.constants import LOGGING_DATES_FORMAT
from common.backend.db.sql.queries import (
    DeleteQueryBuilder,
    UpdateQueryBuilder,
    WhereCondition,
    get_post_queries,
)
from common.backend.db.sql.tables_managers import GenericMessageSQL
from common.backend.models.base import (
    APIFeedback,
    APIGeneratedMedia,
    APIMessageResponse,
    APIProcessedFile,
    APIRetrieval,
    MessageInsertInfo,
    RecordState,
)
from common.backend.services.sources.sources_formatter import format_sources_if_required
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.file_utils import delete_files, get_file_format
from common.backend.utils.json_utils import load_json
from common.backend.utils.sql_timing import log_execution_time
from common.backend.utils.uploaded_files_utils import (
    extract_uploaded_docs_from_history,
)
from common.llm_assist.logging import logger
from dataiku.sql.expression import Operator
from werkzeug.exceptions import BadRequest


def stringify(data, default_value: Optional[Any]=None) -> Optional[str]:
    """Serialize ``data`` to a JSON string, falling back to ``default_value``.

    Args:
        data: Value to serialize. Any falsy value (``None``, empty container,
            ``0``, ``""``...) is treated as absent.
        default_value: Value serialized instead when ``data`` is falsy.

    Returns:
        The JSON text (non-ASCII characters are kept as-is), or ``None`` when
        ``data`` is falsy and no default is provided.
    """
    payload = data if data else default_value
    if payload is None:
        return None
    return json.dumps(payload, ensure_ascii=False)


class MessageSQL(GenericMessageSQL):
    def __init__(self):
        """Set up the messages table manager when the Answers API is enabled.

        When the webapp config does not turn on ``enable_answers_api``, the
        parent initializer is skipped and no attribute is set on the instance.
        """
        webapp_config = dataiku_api.webapp_config
        if webapp_config.get("enable_answers_api", False):
            super().__init__(
                config=webapp_config,
                columns=ANSWERS_DATASETS_METADATA[MESSAGE_DATASET_CONF_ID]["columns"],
                dataset_conf_id=MESSAGE_DATASET_CONF_ID,
                default_project_key=dataiku_api.default_project_key,
            )
            # When True, deleting a conversation removes its rows instead of flagging them.
            self.permanent_delete = self.config.get("permanent_delete", True)

    @log_execution_time
    def add_message(
        self,
        user: str,
        message_info: MessageInsertInfo,
        timestamp: str
    ) -> str:
        """Insert a new message record and return its generated identifier.

        Args:
            user: Identifier stored with the message.
            message_info: Mapping holding the message fields (platform, query,
                answer, sources, ...). Structured fields are stored as JSON text.
            timestamp: Creation timestamp stored as provided.

        Returns:
            The UUID4 string generated for the new message.
        """
        logger.debug("Adding an API message...")
        new_message_id = str(uuid.uuid4())
        field = message_info.get
        # NOTE(review): the value order presumably mirrors the dataset column
        # order declared in ANSWERS_DATASETS_METADATA - confirm before reordering.
        self.insert_record(
            [
                new_message_id,
                user,
                field("platform"),
                field("llm_name"),
                field("query"),
                field("answer"),
                stringify(field("filters")),
                stringify(field("sources"), []),  # sources default to an empty JSON list
                timestamp,
                stringify(field("llm_context")),
                stringify(field("generated_media")),
                stringify(field("feedback")),
                field("conversation_id"),
                RecordState.PRESENT.value,
                stringify(field("history")),
            ]
        )
        return new_message_id

    @log_execution_time
    def delete_messages_permanently(self, conditions: List[WhereCondition]):
        """Delete the message rows matching ``conditions`` from the dataset.

        Args:
            conditions: Where-conditions added to the delete query.

        Raises:
            BadRequest: If executing the query fails. The original error is
                logged and attached as the exception cause.
        """
        delete_query = DeleteQueryBuilder(self.dataset).add_conds(conds=conditions).build()
        try:
            # Only the query execution is guarded; the success log stays outside.
            self.executor.query_to_df(delete_query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.error(err)
            # Explicit chaining keeps the underlying error as __cause__.
            raise BadRequest(f"Error when executing SQL query: {err}") from err
        logger.info(f"Conversation messages successfully deleted (conditions={conditions})")

    @log_execution_time
    def update_messages_as_deleted(self, conditions: List[WhereCondition]):
        """Flag the message rows matching ``conditions`` as deleted.

        The ``state`` column of the matching rows is set to
        ``RecordState.DELETED``; the rows themselves are kept.

        Args:
            conditions: Where-conditions added to the update query.

        Raises:
            BadRequest: If executing the query fails. The original error is
                logged and attached as the exception cause.
        """
        update_query = (
            UpdateQueryBuilder(self.dataset)
            .add_set_cols([(self.col("state"), RecordState.DELETED.value)])
            .add_conds(conditions)
            .build()
        )
        try:
            # Only the query execution is guarded; the success log stays outside.
            self.executor.query_to_df(update_query, post_queries=get_post_queries(self.dataset))
        except Exception as err:
            logger.error(err)
            # Explicit chaining keeps the underlying error as __cause__.
            raise BadRequest(f"Error when executing SQL query: {err}") from err
        logger.info(f"Conversation messages successfully flagged as deleted (conditions={conditions})")

    # for now copied from conversation.py - could be moved to another file
    @staticmethod
    def get_paths_from_history(series: pd.Series, data_key: str, additional_file_name: str) -> List[str]:
        def get_paths(row):
            if not row:
                return []
            media_qa_context = json.loads(row).get(data_key, [])
            return itertools.chain(
                (media.get(additional_file_name) for media in media_qa_context if media.get(additional_file_name)),
                (media.get("file_path") for media in media_qa_context if media.get("file_path")),
            )

        series = series.map(get_paths)  # type: ignore
        unique_values = list(set(item for sublist in series for item in sublist if item))
        return unique_values
    
    @staticmethod
    def format_feedback_choice(choice: Optional[str]):
        if choice is None or choice == "":
            return []
        return [value.strip() for value in choice.split(";")]
    
    @log_execution_time
    def get_message_media_paths(self, conditions: List[WhereCondition]) -> List[str]:
        """List the media file paths referenced by the messages matching ``conditions``.

        Args:
            conditions: Where-conditions selecting the messages to inspect.

        Returns:
            Paths found in the ``llm_context`` column followed by those found in
            the ``generated_media`` column. An empty list is returned when no
            ``upload_folder`` is configured, or when the query yields a
            generator or an empty dataframe.
        """
        if dataiku_api.webapp_config.get("upload_folder") is None:
            return []

        context_col = self.col("llm_context")
        media_col = self.col("generated_media")
        media_df = self.select_columns_from_dataset(
            column_names=[context_col, media_col],
            eq_cond=conditions,
        )
        if isinstance(media_df, Generator) or media_df.empty:
            return []

        # Null cells are dropped before parsing the JSON payloads.
        uploaded_paths = MessageSQL.get_paths_from_history(
            series=media_df[context_col].dropna(),
            data_key="media_qa_context",
            additional_file_name="metadata_path",
        )
        generated_paths = MessageSQL.get_paths_from_history(
            series=media_df[media_col].dropna(),
            data_key="images",
            additional_file_name="referred_file_path",
        )
        return uploaded_paths + generated_paths

    @log_execution_time
    def get_all_conversation_messages(self, platform: str, user: str, conversation_id: str, only_present: bool= True)->List[APIMessageResponse]:
        """Return the messages of one conversation, ordered by creation date.

        Args:
            platform: Value matched against the ``platform`` column.
            user: Value matched against the ``user`` column.
            conversation_id: Value matched against the ``conversation_id`` column.
            only_present: When True, rows whose state is not
                ``RecordState.PRESENT`` are skipped.

        Returns:
            One ``APIMessageResponse`` per retained row, oldest first. The list
            is empty when no row matches.
        """

        def to_processed_file(file_info: Dict[str, Any]) -> APIProcessedFile:
            # Map a stored file description onto the API file model.
            file_path = file_info.get("file_path", "")
            return APIProcessedFile(
                name=file_info.get("original_file_name", ""),
                format=get_file_format(file_path),
                path=file_path,
                thumbnail=file_info.get("preview", ""),
                chainType=file_info.get("chain_type", ""),
                jsonFilePath=file_info.get("metadata_path", ""),
            )

        def to_generated_media(media: Dict[str, Any]) -> APIGeneratedMedia:
            # Map a stored generated-image description onto the API media model.
            return APIGeneratedMedia(
                data=media.get("file_data", ""),
                format=media.get("file_format", ""),
                path=media.get("file_path", ""),
                referredFilePath=media.get("referred_file_path", ""),
            )

        def to_message(row: pd.Series) -> APIMessageResponse:
            # Build the API response for a single dataset row.
            llm_context = load_json(row[self.col("llm_context")], {})
            retrieval_info: Dict[str, Any] = llm_context.get("selected_retrieval_info", {})
            dataset_context = llm_context.get("dataset_context", {})
            raw_filters = row[self.col("filters")]

            used_retrieval = APIRetrieval(
                name=retrieval_info.get("name", ""),
                type=retrieval_info.get("type", ""),
                alias=retrieval_info.get("alias", ""),
                filters=load_json(raw_filters, {}).get("filters", {}) if raw_filters else {},
                sources=format_sources_if_required(load_json(row[self.col("sources")], [])),
                generatedSqlQuery=dataset_context.get("sql_query", ""),
                usedTables=dataset_context.get("tables_used", []),
            )

            feedback = load_json(row[self.col("feedback")], APIFeedback(value="", choice=[], message=""))

            # Files come from the documents found in the chat history first,
            # then from the documents attached to this message's LLM context.
            chat_history = load_json(row[self.col("history")], [])
            files: List[APIProcessedFile] = [
                to_processed_file(summary) for summary in extract_uploaded_docs_from_history(chat_history)
            ]
            files.extend(to_processed_file(doc) for doc in llm_context.get("uploaded_docs", []))

            # A payload without an "images" key (or with a null one) yields no
            # media instead of raising KeyError.
            row_medias = load_json(row[self.col("generated_media")], {}) or {}
            generated_media: List[APIGeneratedMedia] = [
                to_generated_media(media) for media in (row_medias.get("images") or [])
            ]

            return APIMessageResponse(
                id=row[self.col("message_id")],
                createdAt=row[self.col("created_at")],
                query=row[self.col("query")],
                answer=row[self.col("answer")],
                usedRetrieval=used_retrieval,
                feedback=feedback,
                files=files,
                generatedMedia=generated_media,
            )

        conditions = [
            WhereCondition(column=self.col("platform"), value=platform, operator=Operator.EQ),
            WhereCondition(column=self.col("user"), value=user, operator=Operator.EQ),
            WhereCondition(
                column=self.col("conversation_id"),
                value=conversation_id,
                operator=Operator.EQ,
            ),
        ]
        messages_df: pd.DataFrame = self.select_columns_from_dataset(
            column_names=self.columns, eq_cond=conditions, format_="dataframe"
        )
        if messages_df.empty:
            return []

        # Parse the creation dates so that sorting is chronological.
        created_at_col = self.col("created_at")
        messages_df[created_at_col] = pd.to_datetime(messages_df[created_at_col], format=LOGGING_DATES_FORMAT)
        messages_df = messages_df.sort_values(by=created_at_col, ascending=True)

        state_col = self.col("state")
        return [
            to_message(row)
            for _, row in messages_df.iterrows()
            if not only_present or row[state_col] == RecordState.PRESENT.value
        ]
        
    @log_execution_time
    def delete_conversation_messages(self, platform: str, user: str, conversation_id: str):
        """Delete, or flag as deleted, all messages of a user's conversation.

        When ``self.permanent_delete`` is set, the referenced media files are
        removed and the rows are deleted; otherwise the rows are only flagged
        as deleted.

        Args:
            platform: Value matched against the ``platform`` column.
            user: Value matched against the ``user`` column.
            conversation_id: Value matched against the ``conversation_id`` column.
        """
        criteria = (
            ("platform", platform),
            ("user", user),
            ("conversation_id", conversation_id),
        )
        conditions = [
            WhereCondition(column=self.col(column_name), value=expected, operator=Operator.EQ)
            for column_name, expected in criteria
        ]

        if not self.permanent_delete:
            logger.info(f"Flagging the user '{user}' messages from  conversation '{conversation_id}' as deleted ...")
            self.update_messages_as_deleted(conditions)
            return

        logger.info(f"Deleting permanently the user '{user}' messages from conversation '{conversation_id}' ...")
        # Files are resolved from the rows, so they must be removed before the rows.
        self.delete_files(conditions)
        self.delete_messages_permanently(conditions)
    
    def delete_files(self, conditions: List[WhereCondition]):
        # TODO handle permenant delete and not permenant delete
        media_to_delete = self.get_message_media_paths(conditions)
        if media_to_delete:
            logger.debug(f"Deleting media files: {media_to_delete}")
            delete_files(media_to_delete)
    
# Module-level shared instance, created at import time. MessageSQL.__init__
# returns before calling the parent initializer when "enable_answers_api" is
# not enabled in the webapp config, so in that case no attribute is set on it.
messages_sql_manager = MessageSQL()
