import json
import os
import shutil
import tempfile

import dataiku
from dataiku.llm.python import BaseLLM
from dataiku.langchain.dku_llm import DKUChatLLM
from dataiku.langchain.dku_embeddings import DKUEmbeddings

from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableParallel
from langchain.prompts.prompt import PromptTemplate

folder = dataiku.Folder("4pbb7xKD")

# Project variables name the LLM Mesh models used by the chain; fetch them once.
_project_variables = dataiku.get_custom_variables()

LLM_ID = _project_variables["LLM_id"]
llm = DKUChatLLM(
    llm_id=LLM_ID,
    temperature=0
)

EMBEDDING_MODEL_ID = _project_variables["embedding_model_id"]
embeddings = DKUEmbeddings(llm_id=EMBEDDING_MODEL_ID)

# Copy every file of the managed folder into a local temporary directory so that
# FAISS.load_local can read the index from disk; the directory is removed once
# the index has been loaded into memory.
with tempfile.TemporaryDirectory() as temp_dir:
    for remote_path in folder.list_paths_in_partition():
        # Files are flattened by basename into temp_dir (as before); presumably
        # the index files live at the folder root -- TODO confirm.
        local_path = os.path.join(temp_dir, os.path.basename(remote_path))
        with folder.get_download_stream(remote_path) as stream, open(local_path, "wb") as local_file:
            # Copy in chunks instead of reading the whole file into memory.
            shutil.copyfileobj(stream, local_file)
    # SECURITY: allow_dangerous_deserialization lets FAISS unpickle the stored
    # docstore. This is only acceptable because the managed folder is assumed to
    # contain trusted content produced by this project -- never point it at
    # externally supplied files.
    vector_store = FAISS.load_local(
        temp_dir,
        embeddings,
        allow_dangerous_deserialization=True
    )

# Prompt

# Instructions for the chat model. The {question} and {sources} placeholders
# are filled when the prompt is formatted.
_RAG_INSTRUCTIONS = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
)
prompt = PromptTemplate(
    template="\n".join([
        _RAG_INSTRUCTIONS,
        "Question: {question}",
        "Context: {sources}",
    ]),
    input_variables=["sources", "question"],
)

# Retrieval-augmented generation chain
def format_docs(docs):
    """Join the ``page_content`` of each document into one context string.

    Consecutive documents are separated by a blank line (two newlines).
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)

def _sources_as_text(chain_input):
    """Render the documents stored under ``"sources"`` as a single context string."""
    return format_docs(chain_input["sources"])


# Replace "sources" with its formatted text, fill the prompt, call the chat
# model, and return the reply as a plain string.
rag_chain_from_docs = (
    RunnablePassthrough.assign(sources=_sources_as_text)
    | prompt
    | llm
    | StrOutputParser()
)

class MyLLM(BaseLLM):
    """Custom LLM that answers a question using retrieval-augmented generation.

    For each query, the most similar chunks are retrieved from the module-level
    FAISS ``vector_store`` and passed, together with the question, through
    ``rag_chain_from_docs``. The answer and its sources are returned as a JSON
    string in the ``text`` field.
    """

    # Number of chunks retrieved when the query context does not specify one.
    DEFAULT_NUM_CHUNKS = 5

    def __init__(self):
        pass

    @staticmethod
    def _extract_question(messages):
        """Return the content of the most recent message whose role is "user".

        Falls back to the content of the last message when no message has the
        "user" role. Raises IndexError if ``messages`` is empty.
        """
        for message in reversed(messages):
            if message.get("role") == "user":
                return message["content"]
        return messages[-1]["content"]

    @classmethod
    def _resolve_num_chunks(cls, query):
        """Return ``query["context"]["num_chunks"]`` as an int, or the default.

        A missing or ``None`` context, or a missing/``None`` value, yields
        ``DEFAULT_NUM_CHUNKS``. Raises ValueError/TypeError if the value cannot
        be converted to int.
        """
        context = query.get("context") or {}
        requested = context.get("num_chunks")
        return cls.DEFAULT_NUM_CHUNKS if requested is None else int(requested)

    def process(self, query, settings, trace):
        """Answer the latest user question in ``query`` with the RAG chain.

        :param query: completion query; ``query["messages"]`` holds the chat
            messages and ``query["context"]["num_chunks"]`` optionally sets how
            many chunks to retrieve.
        :param settings: completion settings (not used by this implementation).
        :param trace: tracing object; one subspan is opened for retrieval and
            one for generation.
        :returns: ``{"text": <JSON string>}`` where the JSON object has a
            ``"result"`` answer string and a ``"sources"`` list of
            ``{"unique_id", "content"}`` entries.
        """
        with trace.subspan("Performing a semantic search") as subspan:
            question = self._extract_question(query["messages"])
            num_chunks = self._resolve_num_chunks(query)
            subspan.attributes["num_chunks"] = num_chunks
            subspan.attributes["question"] = question
            # The number of results must be given as search_kwargs["k"]; passing
            # it as ``num_chunks=`` does not change how many chunks are returned.
            retriever = vector_store.as_retriever(search_kwargs={"k": num_chunks})

        with trace.subspan("Generating the answer"):
            rag_chain_with_source = RunnableParallel(
                {"sources": retriever, "question": RunnablePassthrough()}
            ).assign(result=rag_chain_from_docs)
            response = rag_chain_with_source.invoke(question)
            response["sources"] = [
                # Documents without a "source" metadata entry get a null id
                # instead of failing the whole request.
                {"unique_id": doc.metadata.get("source"), "content": doc.page_content}
                for doc in response["sources"]
            ]
            # The echoed question is not part of the returned payload.
            response.pop("question", None)

        return {"text": json.dumps(response)}