# -*- coding: utf-8 -*-
import dataiku
import json
from project_utils import encode_image, retrieve_chunks

# Retrieval settings: number of chunks to retrieve per question and the
# knowledge bank ID
NUM_CHUNKS = 6
KB_ID = "zdqno9RF"

# Read the "questions" dataset as a pandas DataFrame
df = dataiku.Dataset("questions").get_dataframe()
# Handle on the input folder (handed to encode_image when building messages)
folder = dataiku.Folder("vOjkXoGz")

# Resolve the LLM whose ID is stored in the "LLM_ID" custom variable
LLM_ID = dataiku.get_custom_variables()["LLM_ID"]
project = dataiku.api_client().get_default_project()
llm = project.get_llm(LLM_ID)

# Get the knowledge bank as a core knowledge bank object, then expose it as a
# LangChain retriever whose search parameter "k" is twice NUM_CHUNKS
kb = project.get_knowledge_bank(KB_ID).as_core_knowledge_bank()
retriever = kb.as_langchain_retriever(search_kwargs=dict(k=2 * NUM_CHUNKS))

# Load the "metadata" dataset and turn it into a dictionary: the "index"
# column gives the key, the JSON-decoded "metadata" column gives the value
metadata_df = dataiku.Dataset("metadata").get_dataframe()
metadata = {
    metadata_df.at[row, "index"]: json.loads(metadata_df.at[row, "metadata"])
    for row in metadata_df.index
}
    

def get_messages(chunks_with_metadata, question):
    """
    Assemble the list of messages sent to the multimodal LLM.

    The list opens with a system instruction followed by a user message that
    carries the question. Each chunk then contributes one extra user message:
    - a chunk whose "type" is "text" becomes a single text part built from its
      "content";
    - any other chunk becomes an inline image part (encoded with encode_image
      from its "image_url"), preceded by a text part holding its "caption"
      when that key is present.
    """
    def text_part(text):
        # A single text part of a message
        return {"type": "TEXT", "text": text}

    def message(role, parts):
        # A message authored by `role` and made of `parts`
        return {"role": role, "parts": parts}

    # Start the conversation with the instruction and the question
    conversation = [
        message(
            "system",
            [text_part("You are a helpful assitant. Concisely answer the question of the user based on the facts provided. If you don't know, just say you don't know.")],
        ),
        message(
            "user",
            [text_part(f"Answer the following question: {question}. Use the following facts.")],
        ),
    ]

    # Append every chunk as a fact, either as text or as an image
    for item in chunks_with_metadata:
        if item["type"] == "text":
            parts = [text_part(f"Fact: {item['content']}")]
        else:
            # The caption is optional; when present it comes before the image
            parts = [text_part(f"Fact: {item['caption']}")] if "caption" in item else []
            parts.append(
                {
                    "type": "IMAGE_INLINE",
                    "inlineImage": encode_image(folder, item["image_url"]),
                }
            )
        conversation.append(message("user", parts))

    return conversation
                         
def answer_question(question):
    """
    Answer a question with a multimodal RAG approach.

    Returns a pair made of the text of the LLM response and the retrieved
    chunks (with their metadata) serialized as a JSON string.
    """
    # Fetch the chunks similar to the question
    retrieved = retrieve_chunks(retriever, question, metadata, NUM_CHUNKS)

    # Prepare a new completion whose messages embed the question and the chunks
    query = llm.new_completion()
    query.cq["messages"] = get_messages(retrieved, question)

    # Generation parameters
    for setting, value in (("maxOutputTokens", 300), ("temperature", 0)):
        query.settings[setting] = value

    # Run the query against the LLM
    response = query.execute()
    answer_text = response.text

    return answer_text, json.dumps(retrieved)

# Answer every question of the questions DataFrame.
# The results are gathered in lists and assigned as whole columns so that the
# "generated_answer" and "chunks" columns exist even when there is no question.
generated_answers = []
chunks_json = []
for question in df["question"]:
    answer, sources = answer_question(question)
    generated_answers.append(answer)
    # answer_question already returns the chunks as a JSON string: store it
    # as is, since encoding it a second time would wrap it in an escaped JSON
    # string literal that needs two decoding passes to read back
    chunks_json.append(sources)
df["generated_answer"] = generated_answers
df["chunks"] = chunks_json

# Write the answers to the "answers" dataset
dataiku.Dataset("answers").write_with_schema(df)