# encoding: utf-8
"""
Execute an Embed documents recipe. Must be called in a Flow environment.
"""
import logging
import os.path as osp
import sys
from typing import Any

from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import dkujson
from dataiku.core.vector_stores.dku_vector_store import DkuLocalVectorStore, VectorStoreFactory
from dataiku.langchain.dataset_loader import DatasetLoader
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META

import pandas as pd

# Module-level logger; every message from this recipe is tagged with this name.
logger = logging.getLogger("embed_documents_recipe")


class InlineDataset:
    """Minimal stand-in for a DkuDataset, backed by a local JSON Lines file.

    It mimics only the ``iter_dataframes`` entry point of the DkuDataset
    class, so the file can be lazily loaded chunk by chunk.
    """

    def __init__(self, filename: str):
        # Path of the JSON Lines file to read (one JSON object per line).
        self.filename = filename

    def iter_dataframes(self, chunksize: int) -> Any:
        """Return a chunked reader over the JSON Lines file.

        :param chunksize: maximum number of rows in each yielded DataFrame
        :return: the iterable reader produced by ``pd.read_json`` in chunked mode
        """
        chunk_reader = pd.read_json(self.filename, lines=True, chunksize=chunksize)
        return chunk_reader

def embed_data(exec_folder: str, input_docs: InlineDataset, input_docs_uuids: list[str], to_delete_docs_uuids: list[str]) -> None:
    """Bring the knowledge bank's vector store up to date for one recipe run.

    Steps, in order:

    1. Load ``kb.json`` and ``recipe_settings.json`` from *exec_folder*.
    2. Build the embedding model named by ``embeddingLLMId``.
    3. Open the vector store.
    4. Clear the store first if ``clearVectorStore`` is set.
    5. Embed and add the new chunks, skipped when *input_docs_uuids* is empty.
    6. Delete the chunks in *to_delete_docs_uuids*. This call is always made,
       even with an empty list.
    7. Write ``kb_status.json`` back into *exec_folder*. It contains
       ``fileSizeMb`` only for a DkuLocalVectorStore.

    :param exec_folder: folder holding the recipe's JSON inputs; also receives ``kb_status.json``
    :param input_docs: lazily-read dataset containing the chunks to add
    :param input_docs_uuids: uuids of the chunks to add (presumably one per row of
        *input_docs*, in the same order -- confirm against the vector store)
    :param to_delete_docs_uuids: uuids of outdated chunks to remove from the vector store
    """
    def _load_json(file_name):
        # Every recipe input file lives directly inside the execution folder.
        return dkujson.load_from_filepath(osp.join(exec_folder, file_name))

    kb_definition = _load_json("kb.json")
    settings = _load_json("recipe_settings.json")
    embedder = DKUEmbeddings(llm_id=kb_definition["embeddingLLMId"])

    logger.info("Loading vector store")
    store = VectorStoreFactory.get_vector_store(kb_definition, exec_folder)

    if settings.get("clearVectorStore"):
        logger.info("Vector store clear requested before adding/removing new records. Clearing the vector store.")
        store.clear()

    added_count = len(input_docs_uuids)
    if added_count == 0:
        logger.info("No langchain document to add in the vector store")
    else:
        logger.info("Loading dataset records")
        loader = DatasetLoader(
            input_data=input_docs,
            content_column=settings["knowledgeColumn"],
            source_id_column=None,
            # The loader is given the fixed metadata key, not the user's column name.
            security_tokens_column=(
                DKU_SECURITY_TOKENS_META if settings.get("securityTokensColumn") else None
            ),
            metadata_columns=[entry["column"] for entry in settings["metadataColumns"]],
        )
        logger.info("Performing embedding : adding {} total new langchain documents in the vector store".format(added_count))
        store.add_documents(loader, embedder, input_docs_uuids)
        logger.info("Finished adding langchain documents")

    logger.info("Deleting {} outdated langchain documents in the vector store".format(len(to_delete_docs_uuids)))
    store.delete_documents(to_delete_docs_uuids, embedder)

    # Only DkuLocalVectorStore instances contribute a file size to the status.
    status = {"fileSizeMb": store.get_file_size()} if isinstance(store, DkuLocalVectorStore) else {}
    dkujson.dump_to_filepath(osp.join(exec_folder, "kb_status.json"), status)

if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )
    read_dku_env_and_set()
    # First CLI argument: the execution folder holding kb.json / recipe_settings.json.
    run_folder = sys.argv[1]

    def _read_chunk_uuids(csv_path):
        # One uuid per line; passing `names` makes pandas treat the file as headerless.
        return pd.read_csv(csv_path, names=["chunk_uuids"])["chunk_uuids"].to_list()

    with ErrorMonitoringWrapper():
        # NOTE(review): these three paths are resolved against the current working
        # directory, not run_folder; presumably the process is started from inside
        # the run folder -- confirm with the launcher.
        input_docs = InlineDataset("to_add_chunks_content.jsonl")
        input_docs_uuids = _read_chunk_uuids("to_add_chunks_uuids.csv")
        to_delete_docs_uuids = _read_chunk_uuids("to_delete_chunks_uuids.csv")
        embed_data(run_folder, input_docs, input_docs_uuids, to_delete_docs_uuids)

