# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Packages, inputs and constants

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
import pandas as pd
from llama_index.core.schema import TextNode
from llama_index.core.llama_dataset.generator import RagDatasetGenerator
from dataiku.langchain.dku_llm import DKUChatLLM
from llama_index.llms.langchain import LangChainLLM

# Read a random sample of the "chunks" dataset (capped at 20 rows by the
# reader) and keep only the first 10 rows of that sample.
chunks_dataset = dataiku.Dataset("chunks")
sampled_chunks = chunks_dataset.get_dataframe(sampling="random", limit=20)
df = sampled_chunks.head(10)

# Identifier of the LLM to use, read from the "LLM_id" custom variable.
custom_variables = dataiku.get_custom_variables()
LLM_ID = custom_variables["LLM_id"]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Creation of question/context/answer tuples

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def _build_node(row_label):
    """Return a TextNode for the row of `df` identified by `row_label`.

    The node text is the row's "chunk" value, its id is the stringified
    "chunk_id", and its metadata carries the row's "url" and "chunk_id".
    """
    chunk_id = df.at[row_label, "chunk_id"]
    return TextNode(
        text=df.at[row_label, "chunk"],
        id_=str(chunk_id),
        metadata={"url": df.at[row_label, "url"], "chunk_id": chunk_id},
    )


# One node per sampled row, in the order of the dataframe index.
nodes = [_build_node(row_label) for row_label in df.index]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Wrap the Dataiku chat model (temperature 0) so that LlamaIndex can drive it
# through its LangChain adapter.
dku_chat_model = DKUChatLLM(llm_id=LLM_ID, temperature=0)
llm = LangChainLLM(llm=dku_chat_model)

# Build the generator over the prepared nodes, asking for one question per
# chunk and displaying a progress bar.
# NOTE(review): `from_documents` may apply LlamaIndex's default node
# transformations to `nodes`; if that re-splits long chunks, the returned
# reference contexts would no longer equal the original chunk text - confirm.
dataset_generator = RagDatasetGenerator.from_documents(
    nodes,
    llm=llm,
    num_questions_per_chunk=1,
    show_progress=True,
)

# Run the generation and flatten the resulting dataset into a dataframe.
rag_dataset = dataset_generator.generate_dataset_from_nodes()
output_df = rag_dataset.to_pandas()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Retrieve chunk id and write back

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Use the first reference context of each generated example as the join key
# back to the sampled chunks.
output_df["chunk"] = [
    contexts[0] for contexts in output_df["reference_contexts"]
]

# Columns that are not part of the written dataset.
columns_to_drop = [
    "reference_answer_by",
    "query_by",
    "chunk",
    "reference_contexts",
    "url",
]

df = (
    # Left join: every sampled chunk is kept, even when no generated example
    # matches its text exactly.
    df.merge(output_df, how="left", on=["chunk"])
    .rename(columns={"query": "question"})
    .drop(columns=columns_to_drop)
    # Keep only the first row for each chunk_id.
    .drop_duplicates(subset=["chunk_id"])
    # drop=True discards the old index directly. The previous
    # reset_index().drop("index") round-trip would remove a genuine "index"
    # column (and leave a stray "level_0") if the input ever had one.
    .reset_index(drop=True)
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the resulting dataframe to the "synthetic_questions" dataset, setting
# the dataset schema from the dataframe.
synthetic_questions_dataset = dataiku.Dataset("synthetic_questions")
synthetic_questions_dataset.write_with_schema(df)
