# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import json

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, TypedDict, Union, cast

import dataiku
import pandas as pd

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Columns of the Agent Connect logging dataset, in the order they are declared
# in the output schema and in the dataframes written to it.
AC_LOGGING_DATASET_COLUMNS = [
    # conversation / message identity
    "conversation_id", "conversation_name", "llm_name", "user", "message_id",
    # exchange content
    "question", "answer", "filters", "sources",
    # user feedback
    "feedback_value", "feedback_choice", "feedback_message",
    # bookkeeping and extra context
    "timestamp", "state", "llm_context", "generated_media",
]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
class RetrieverMode(Enum):
    """Kind of retriever a source comes from, as detected by `get_retrieval_mode`."""
    KB = "kb"  # knowledge bank: old-format `excerpt` key or new-format type "SIMPLE_DOCUMENT"
    DB = "db"  # database: old-format `sample` key or new-format type "RECORDS"
    NO_RETRIEVER = "no_retrieval"  # none of the markers above was found

class RAGImage(TypedDict, total=False):
    """
    Image entry listed under the `images` key of a `Source`.

    Declared with `total=False`: every key is optional.
    """
    full_folder_id: str  # presumably identifies the folder holding the image -- TODO confirm
    file_path: str  # presumably the path of the image inside that folder -- TODO confirm
    index: int  # meaning not visible in this file -- TODO confirm
    file_data: str  # presumably the encoded image content -- TODO confirm

class Source(TypedDict, total=False):
    """
    A single source, in either the old or the new (agents-connect) format.

    Declared with `total=False`: every key is optional. `format_sources` reads
    the old-format keys and produces the new-format ones;
    `format_sources_to_old_format` does the reverse.
    """
    # Keys read from old-format sources
    excerpt: Optional[str]  # knowledge bank text, becomes `textSnippet`
    metadata: Optional[Dict[str, Union[str, int, float, List[Any]]]]  # `source_title` / `source_url` are read from it
    sample: Optional[Dict[str, Union[str, int, float]]]  # database sample, becomes `records`
    images: Optional[List[RAGImage]]
    # Keys written for new-format sources
    type: Optional[str]  # "SIMPLE_DOCUMENT" or "RECORDS"
    records: Optional[Dict]
    generatedSqlQuery: Optional[str]  # taken from `sample["query"]`
    usedTables: Optional[str]  # taken from `metadata["source_title"]`
    title: Optional[str]  # taken from `metadata["source_title"]`
    url: Optional[str]  # taken from `metadata["source_url"]`
    textSnippet: Optional[str]
    # Present in both formats
    tool_name_used: Optional[str]

class AggregatedToolSources(TypedDict, total=False):
    """
    New-format (agents-connect) wrapper grouping the sources of one tool.

    Declared with `total=False`: every key is optional. The presence of either
    key is what `sources_are_old_format` uses to recognise the new format.
    """
    toolCallDescription: Optional[str]  # e.g. "Used <tool type> <tool name>" when built by `format_sources`
    items: Optional[List[Source]]  # the sources themselves

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def sources_are_old_format(sources: Union[List[Dict], List[Source], List[AggregatedToolSources]]) -> bool:
    """
    Tell whether a list of sources uses the old (flat) format.

    The list is considered new format as soon as one of its entries carries a
    `toolCallDescription` or an `items` key. Otherwise, including when the
    list is empty or None, it is considered old format.

    Entries that are not dictionaries (e.g. None) cannot carry those keys, so
    they are ignored rather than raising: `format_sources` skips empty entries
    too, and a single null entry should not make the whole list unreadable.
    """
    new_format_keys = ("toolCallDescription", "items")
    return not any(
        key in source
        for source in (sources or [])
        if isinstance(source, dict)
        for key in new_format_keys
    )


def format_sources_if_required(sources: Union[List[Dict], List[Source], List[AggregatedToolSources]]) -> List[AggregatedToolSources]:
    """
    Return `sources` in the agents-connect format.

    A non-empty list detected as old format is converted with
    `format_sources`; anything else (empty value or list already in the new
    format) is returned untouched.
    """
    needs_conversion = bool(sources) and sources_are_old_format(sources)
    if needs_conversion:
        # `cast` only informs the type checker, it does nothing at runtime
        return format_sources(cast(List[Source], sources))
    # `cast` only informs the type checker, it does nothing at runtime
    return cast(List[AggregatedToolSources], sources)


def get_retrieval_mode(source: Source) -> RetrieverMode:
    """
    Detect which kind of retriever produced `source`.

    Both source formats are supported: old-format sources are recognised by
    the presence of a `sample` or `excerpt` key, new-format ones by the value
    of their `type` key.
    """
    present_keys = source.keys()

    # Old format: the presence of the key is enough, `sample` is checked first.
    if "sample" in present_keys:
        return RetrieverMode.DB
    if "excerpt" in present_keys:
        return RetrieverMode.KB

    # New format: rely on the declared type.
    declared_type = source.get("type", "")
    if declared_type == "RECORDS":
        return RetrieverMode.DB
    if declared_type == "SIMPLE_DOCUMENT":
        return RetrieverMode.KB

    # Neither format could be recognised.
    return RetrieverMode.NO_RETRIEVER


def _kb_source_to_item(source: Source) -> Dict[str, Any]:
    """Build the `SIMPLE_DOCUMENT` item describing one old-format knowledge bank source."""
    metadata = source.get("metadata")
    if not isinstance(metadata, dict):
        # `metadata` can be missing or explicitly None
        metadata = {}
    return dict(
        type="SIMPLE_DOCUMENT",
        metadata=metadata,
        title=metadata.get("source_title", ""),
        url=metadata.get("source_url", ""),
        textSnippet=source.get("excerpt"),
        images=source.get("images", []),
        tool_name_used=source.get("tool_name_used", ""),
    )


def _db_source_to_item(source: Source) -> Dict[str, Any]:
    """Build the `RECORDS` item describing one old-format database source."""
    metadata = source.get("metadata", {})
    sample = source.get("sample", {})
    if not (sample or metadata):
        # Nothing to convert: the source is kept as it is.
        return source  # type: ignore
    return dict(
        type="RECORDS",
        records=sample,
        generatedSqlQuery=sample.get("query", "") if sample else "",
        usedTables=metadata.get("source_title", "") if metadata else "",
        tool_name_used=source.get("tool_name_used", ""),
    )


def format_sources(sources: List[Source]) -> List[AggregatedToolSources]:
    """
    Wraps a given list of old-format sources into a single aggregate
    compatible with the agents-connect format.

    Returns an empty list when `sources` is empty, otherwise a one-element
    list holding the `toolCallDescription` and the converted `items`.
    Empty entries of `sources` are skipped. The retrieval mode is detected on
    the first non-empty source and applied to all of them; sources are
    dropped when that mode is neither knowledge bank nor database.
    """

    if not sources:
        return []

    # tool_type and tool_name should be the same for every source in the list
    # so it doesn't really matter in which source we get them
    items = []
    tool_type = ""
    tool_name = ""
    retrieval_mode = None

    for source in sources:
        if not source:
            continue

        if retrieval_mode is None:
            retrieval_mode = get_retrieval_mode(source)

        if retrieval_mode == RetrieverMode.KB:
            tool_type = tool_type or "knowledge bank"
            item = _kb_source_to_item(source)
        elif retrieval_mode == RetrieverMode.DB:
            tool_type = tool_type or "database"
            item = _db_source_to_item(source)
        else:
            continue

        if not tool_name:
            # `tool_name_used` may be present but None: fall back to "" so the
            # description never reads "... None"
            tool_name = source.get("tool_name_used") or ""
        items.append(item)

    return [{
        # tool_name from sources is generated in the kb and db retrieval chains
        # through a call to get_retriever_info function
        "toolCallDescription": f"Used {tool_type} {tool_name}",
        "items": items,  # type: ignore
    }]


def format_sources_to_old_format(aggregatedToolSources: Union[List[AggregatedToolSources], List[Source]]) -> List[Source]:
    """
    Revert agents-connect sources to the old (flat) source format.

    This function needs to be used only to store sources in database.
    Everywhere else in the app we want to use the new sources format.
    Also this function will be removed at some point.

    Returns an empty list for an empty input and the input itself when it is
    already in the old format. Sources whose type is not supported are
    reported with a `print` and left out of the result.
    """

    if not aggregatedToolSources:
        return []

    if sources_are_old_format(aggregatedToolSources):
        return aggregatedToolSources  # type: ignore

    reverted_sources: List[Source] = []
    for aggregatedToolSource in aggregatedToolSources:
        sources: Optional[List[Source]] = aggregatedToolSource.get("items", [])  # type: ignore
        if not sources:
            continue

        # The sources of one aggregate come from the same tool, so their mode
        # is detected once on the first of them. It must NOT be reused for the
        # next aggregate, which may come from a tool of another type.
        retrieval_mode = None
        for source in sources:
            if retrieval_mode is None:
                retrieval_mode = get_retrieval_mode(source)

            if retrieval_mode == RetrieverMode.KB:
                reverted_sources.append({
                    "excerpt": source.get("textSnippet", ""),
                    "metadata": source.get("metadata", {}),
                    "images": source.get("images", []),
                    "tool_name_used": source.get("tool_name_used", "")
                })
            elif retrieval_mode == RetrieverMode.DB:
                reverted_sources.append({
                    "sample": source.get("records", {}),
                    "metadata": {
                        "source_title": source.get("usedTables", "") or ""
                    },
                    "tool_name_used": source.get("tool_name_used", "")
                })
            else:
                source_type = source.get("type")
                print(f"The source type `{source_type}` is not supported")

    return reverted_sources

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def migrate_llm_context_to_agent_connect(webapp_project_key: str, webapp_id: str, webapp_name: str, llm_context: str,
                                         question: str, file_path: Optional[str]=None):
    """
    Convert the `llm_context` logged by an Answers webapp into the Agent
    Connect one.

    The Answers webapp is recorded as the selected agent
    (`answers:<project key>:<webapp id>`) called with `question`, and the
    `user_profile`, `media_qa_context`, `uploaded_docs` and `trace` fields are
    carried over from the Answers context. When `file_path` is set, it is
    appended to `uploaded_docs`.

    Returns an empty dict (not a JSON string) when `llm_context` is empty,
    otherwise the new context serialized as a JSON string. Raises whatever
    `json.loads` raises when `llm_context` is not valid JSON.
    """
    if not llm_context:
        return {}
    answers_llm_context = json.loads(llm_context)

    # The field may be stored as an explicit `null`: treat it like a missing
    # field so that appending `file_path` below cannot fail.
    uploaded_docs = answers_llm_context.get("uploaded_docs")
    if uploaded_docs is None:
        uploaded_docs = []
    if file_path:
        uploaded_docs.append(file_path)

    agent_connect_llm_context = {
        "agents_selection": {
            "calls": [
                {"agent_id": f"answers:{webapp_project_key}:{webapp_id}",
                 "query": question}
            ],
            "justification": f"[Migration from Answers `{webapp_name}` ({webapp_project_key}.{webapp_id}) to `Agent Connect`]"
        },
        "user_profile": answers_llm_context.get("user_profile", {}),
        "media_qa_context": answers_llm_context.get("media_qa_context", []),
        "uploaded_docs": uploaded_docs,
        # A default trace must be set for Answers Webapps that did not have that field.
        # TODO: check the best way to handle this case
        "trace": answers_llm_context.get("trace", {}),
    }
    return json.dumps(agent_connect_llm_context)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def migrate_sources_to_agent_connect(webapp_project_key: str, webapp_id: str, webapp_name: str, sources: str, answer:str):
    """
    Convert the `sources` logged by an Answers webapp into the Agent Connect
    ones, serialized as a JSON string.

    The Answers sources are wrapped into a single "answers webapp" entry
    holding the webapp name, its `answers:<project key>:<webapp id>` id, the
    sources converted by `format_sources_if_required` and the answer.

    Empty `sources`, or any error met during the conversion (which is
    printed), yield the default `{"sources": []}` payload.
    """
    fallback = {"sources": []}
    if not sources:
        return json.dumps(fallback)

    try:
        parsed_sources = json.loads(sources)
        webapp_entry = {
            "name": webapp_name,
            "id": f"answers:{webapp_project_key}:{webapp_id}",
            "type": "answers webapp",
            "items": format_sources_if_required(parsed_sources["sources"]),
            "answer": answer,
        }
        return json.dumps({"sources": [webapp_entry]})
    except Exception as error:
        # Best effort: a row whose sources cannot be migrated keeps the default ones.
        print(f"Exception met when migrating the source '{sources}': {error!s}")
        print(f"The default sources `{fallback}` will be used")
        return json.dumps(fallback)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def migrate_chat_logs_dataset(answers_logging_dataset_name: str,
                              answers_webapp_id: str,
                              answers_webapp_name:str,
                              agent_connect_logging_dataset_name: str,
                              chunksize: int=500):
    """
    Migrate an Answers logging dataset into an Agent Connect logging dataset.

    The schema of the output dataset is first rewritten (with
    `drop_and_create=True`) to the `AC_LOGGING_DATASET_COLUMNS` columns, all
    typed as strings. The input dataset is then read by chunks of `chunksize`
    rows and every row is written to the output dataset: most columns are
    copied as they are, `filters` is emptied, `sources` and `llm_context` are
    converted with `migrate_sources_to_agent_connect` and
    `migrate_llm_context_to_agent_connect`, and `generated_media` is
    re-serialized (or defaulted to `{"images": []}` when empty).

    The project key is read from the module-level `current_project`.

    :param answers_logging_dataset_name: name of the dataset to read.
    :param answers_webapp_id: ID of the Answers webapp the logs come from.
    :param answers_webapp_name: name of the Answers webapp the logs come from.
    :param agent_connect_logging_dataset_name: name of the dataset to write.
    :param chunksize: number of rows read and written at a time.
    """
    # Loop invariants, computed once. Reading the project key first also makes
    # a missing `current_project` fail before the output schema is dropped.
    project_key = current_project.project_key
    default_generated_media = json.dumps({"images": []})

    answers_logging_dataset = dataiku.Dataset(answers_logging_dataset_name)
    agent_connect_logging_dataset = dataiku.Dataset(agent_connect_logging_dataset_name)
    print("Setting the Agent Connect logging dataset schema")
    agent_connect_logging_dataset.write_schema(
        columns=[{"name": column_name, "type": "string"} for column_name in AC_LOGGING_DATASET_COLUMNS],
        drop_and_create=True,
    )

    with agent_connect_logging_dataset.get_writer() as writer:
        first_record = 1
        for chat_logs_df in answers_logging_dataset.iter_dataframes(chunksize=chunksize, infer_with_pandas=False):
            # Replace missing values (NaN/NaT) with None
            chat_logs_df = chat_logs_df.where(pd.notnull(chat_logs_df), None)

            # Inclusive range of the records held by this chunk
            last_record = first_record + len(chat_logs_df) - 1
            iteration_state = f"records {first_record} to {last_record}"
            print(f"Migrating the 'Answers' logging dataset '{answers_logging_dataset_name}' ({iteration_state})")

            # The leading empty frame fixes the column order and keeps an
            # empty chunk writable. Row frames are concatenated once per chunk
            # instead of once per row.
            row_frames = [pd.DataFrame(columns=AC_LOGGING_DATASET_COLUMNS)]
            for _, row in chat_logs_df.iterrows():
                if row["generated_media"]:
                    generated_media = json.dumps(json.loads(row["generated_media"]))
                else:
                    generated_media = default_generated_media

                row_frames.append(pd.DataFrame({
                    "conversation_id": [row["conversation_id"]],
                    "conversation_name": [row["conversation_name"]],
                    "llm_name": [row["llm_name"]],
                    "user": [row["user"]],
                    "message_id": [row["message_id"]],
                    "question": [row["question"]],
                    "answer": [row["answer"]],
                    "filters": [""],
                    "sources": [migrate_sources_to_agent_connect(webapp_project_key=project_key,
                                                                 webapp_id=answers_webapp_id,
                                                                 webapp_name=answers_webapp_name,
                                                                 sources=row["sources"], answer=row["answer"])],
                    "feedback_value": [row["feedback_value"]],
                    "feedback_choice": [row["feedback_choice"]],
                    "feedback_message": [row["feedback_message"]],
                    "timestamp": [row["timestamp"]],
                    "state": [row["state"]],
                    "llm_context": [migrate_llm_context_to_agent_connect(webapp_project_key=project_key,
                                                                         webapp_id=answers_webapp_id,
                                                                         webapp_name=answers_webapp_name,
                                                                         llm_context=row["llm_context"],
                                                                         question=row["question"],
                                                                         file_path=row["file_path"])],
                    "generated_media": [generated_media],
                }))

            writer.write_dataframe(pd.concat(row_frames, ignore_index=True))

            first_record = last_record + 1

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Project resolved from `dataiku.default_project_key()`; `migrate_chat_logs_dataset`
# reads its `project_key` to build the migrated sources and llm context.
current_project = dataiku.api_client().get_project(dataiku.default_project_key())

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Migration settings: placeholders to fill in before running the migration cell below.
INPUT_ANSWERS_LOGGING_DATASET = None  # Fill the name of the answers logging dataset to migrate
OUTPUT_AGENT_CONNECT_LOGGING_DATASET = None  # Fill the name of the agent connect logging dataset to create
ANSWERS_WEBAPP_ID = None  # Fill the ID of the Answers webapp (used in the `answers:<project key>:<webapp id>` agent id)
ANSWERS_WEBAPP_NAME = None  # Fill the name of the Answers webapp (stored as the name of the migrated sources entry)
# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Run the migration with the settings above, 100 rows at a time.
# NOTE: the schema of the output dataset is rewritten with `drop_and_create=True`.
migrate_chat_logs_dataset(answers_logging_dataset_name=INPUT_ANSWERS_LOGGING_DATASET,
                          answers_webapp_id=ANSWERS_WEBAPP_ID,
                          answers_webapp_name=ANSWERS_WEBAPP_NAME,
                          agent_connect_logging_dataset_name=OUTPUT_AGENT_CONNECT_LOGGING_DATASET,
                          chunksize=100
                         )