# -*- coding: utf-8 -*-
import dataiku
import json
from project_utils import encode_image, retrieve_chunks

# Load the questions dataset as a DataFrame and get a handle on the folder
df = dataiku.Dataset("questions").get_dataframe()
folder = dataiku.Folder("vOjkXoGz")

# Resolve the LLM whose ID is stored in the "LLM_ID" custom variable
LLM_ID = dataiku.get_custom_variables()["LLM_ID"]
client = dataiku.api_client()
project = client.get_default_project()
llm = project.get_llm(LLM_ID)

# Number of chunks to retrieve, and the ID of the knowledge bank to query
NUM_CHUNKS = 6
KB_ID = "zdqno9RF"

# Fetch the knowledge bank, convert it to a core knowledge bank object and
# expose it as a LangChain retriever. The retriever is configured to return
# twice NUM_CHUNKS results per query.
kb_handle = project.get_knowledge_bank(KB_ID)
kb = kb_handle.as_core_knowledge_bank()
search_kwargs = {"k": 2 * NUM_CHUNKS}
retriever = kb.as_langchain_retriever(search_kwargs=search_kwargs)

# Load the metadata dataset and build a dictionary that maps each value of
# its "index" column to the decoded JSON held in its "metadata" column
metadata_df = dataiku.Dataset("metadata").get_dataframe()
metadata = {
    metadata_df.at[row, "index"]: json.loads(metadata_df.at[row, "metadata"])
    for row in metadata_df.index
}

# Retrieve the chunks related to every question and store them in a "chunks"
# column. The whole column is assigned at once (instead of cell by cell with
# df.at) so that the column is always present in the output schema, even when
# the questions dataset is empty.
# NOTE(review): the retrieved chunks are JSON-encoded twice, i.e. each cell is
# a JSON string that itself contains a JSON document, so readers must decode
# it twice. Kept as-is because downstream consumers may rely on it -- confirm
# whether the double encoding is intentional.
df["chunks"] = [
    json.dumps(json.dumps(retrieve_chunks(retriever, question, metadata, NUM_CHUNKS)))
    for question in df["question"]
]

# Write the augmented questions to the output dataset
dataiku.Dataset("questions_augmented").write_with_schema(df)