# encoding: utf-8
"""
Execute a RAG embedding recipe. Must be called in a Flow environment
"""
import logging
import os.path as osp
import sys

import dataiku
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import dkujson
from dataiku.core.vector_stores.dku_vector_store import UpdateMethod, VectorStoreFactory
from dataiku.langchain.dataset_loader import DatasetLoader, TextSplitterLoader, DEFAULT_DATASET_BATCH_SIZE
from dataiku.langchain.dku_embeddings import DKUEmbeddings

logger = logging.getLogger("rag_embedding_recipe")


def _validate_settings(rk: dict, recipe_settings: dict, update_method: UpdateMethod) -> None:
    """
    Reject recipe configurations that cannot be executed.

    Only the loaded settings are inspected; nothing external (vector store, dataset) is touched,
    so this is safe to call before any destructive operation.

    :param rk: Knowledge Bank definition (content of kb.json)
    :param recipe_settings: recipe settings (content of recipe_settings.json)
    :param update_method: vector store update method requested by the recipe
    :raises Exception: if a smart update method is requested for a PINECONE vector store,
        or if the document splitting mode is not recognized
    :raises ValueError: if no embedding (knowledge) column is configured
    """
    # TODO @rag Workaround for issues with PINECONE that do not work well with smart update modes. Can remove this when RecordManager works with PINECONE
    #           See also rag_embedding.js RAGEmbeddingRecipeEditionController $scope.updateMethodDisabledReason
    if update_method in (UpdateMethod.SMART_OVERWRITE, UpdateMethod.SMART_APPEND) and rk["vectorStoreType"] == "PINECONE":
        raise Exception("Smart update methods not supported for PINECONE")

    if not recipe_settings.get("knowledgeColumn"):
        raise ValueError("The Embedding column is missing please select one")

    if recipe_settings["documentSplittingMode"] not in ("CHARACTERS_BASED", "NONE"):
        raise Exception("Illegal splitting mode %s" % recipe_settings["documentSplittingMode"])


def _build_documents_loader(input_data: dataiku.Dataset, recipe_settings: dict):
    """
    Build the document loader over ``input_data``.

    The loader is wrapped in a ``TextSplitterLoader`` when the splitting mode is
    ``CHARACTERS_BASED``. It is returned as-is for ``NONE``.

    :param input_data: dataset whose rows are turned into documents
    :param recipe_settings: recipe settings (content of recipe_settings.json)
    :returns: a DatasetLoader, or a TextSplitterLoader wrapping one
    :raises Exception: if the document splitting mode is not recognized
    """
    documents_loader = DatasetLoader(
        input_data=input_data,
        content_column=recipe_settings["knowledgeColumn"],
        source_id_column=recipe_settings.get("sourceIdColumn"),
        security_tokens_column=recipe_settings.get("securityTokensColumn"),
        # A missing/null metadataColumns setting means "no metadata columns".
        metadata_columns=[c["column"] for c in recipe_settings.get("metadataColumns") or []],
        limit=recipe_settings.get("maxRecords"),
        chunk_size=recipe_settings.get("datasetBatchSize", DEFAULT_DATASET_BATCH_SIZE),
    )

    splitting_mode = recipe_settings["documentSplittingMode"]
    if splitting_mode == "CHARACTERS_BASED":
        return TextSplitterLoader(
            documents_loader,
            recipe_settings["chunkSizeCharacters"],
            recipe_settings["chunkOverlapCharacters"]
        )
    if splitting_mode == "NONE":
        logger.info("Not performing splitting")
        return documents_loader
    raise Exception("Illegal splitting mode %s" % splitting_mode)


def embed_data(exec_folder: str, input_data: dataiku.Dataset) -> None:
    """
    Embed the rows of ``input_data`` and index them into the Knowledge Bank's vector store.

    The steps are:

    1. Read ``kb.json`` (the Knowledge Bank definition) and ``recipe_settings.json`` from ``exec_folder``.
    2. Validate the settings.
    3. Load the vector store and optionally clear it.
    4. Build a document loader over the dataset, with optional character-based splitting.
    5. Embed the documents with the Knowledge Bank's embedding LLM and load them using the configured update method.
    6. Write the resulting vector store status to ``kb_status.json`` in ``exec_folder``.

    :param exec_folder: folder holding the recipe execution files; also receives kb_status.json
    :param input_data: dataset whose rows are embedded
    :raises ValueError: if no embedding (knowledge) column is configured
    :raises Exception: if a smart update method is requested for a PINECONE vector store,
        or if the document splitting mode is not recognized
    """
    rk = dkujson.load_from_filepath(osp.join(exec_folder, "kb.json"))
    recipe_settings = dkujson.load_from_filepath(osp.join(exec_folder, "recipe_settings.json"))

    vector_store_update_method = UpdateMethod(recipe_settings.get("vectorStoreUpdateMethod", UpdateMethod.OVERWRITE))

    # Validate before loading/clearing the vector store: a configuration error detected after
    # vector_store.clear() would otherwise leave the store emptied.
    _validate_settings(rk, recipe_settings, vector_store_update_method)

    logger.info("Loading vector store")
    vector_store = VectorStoreFactory.get_vector_store(rk, exec_folder, VectorStoreFactory.get_connection_details_from_env)

    if recipe_settings.get("clearVectorStore"):
        logger.info("Vector store clear requested before indexing new records. Clearing the vector store.")
        vector_store.clear()

    logger.info("Loading dataset records")
    documents_loader = _build_documents_loader(input_data, recipe_settings)

    embeddings = DKUEmbeddings(llm_id=rk["embeddingLLMId"])

    logger.info("Performing embedding and indexing")

    if recipe_settings.get("sourceIdColumn") != rk.get("sourceIdColumn"):
        # NOTE(review): the loader above is built from the recipe's sourceIdColumn; presumably the
        # vector store applies the Knowledge Bank's value during load_documents — confirm.
        logger.warning(
            "sourceIdColumn in the embedding recipe (%s) is out of sync with the sourceIdColumn in the Knowledge Bank (%s)."
            " Using the one from the knowledge bank.",
            recipe_settings.get("sourceIdColumn"),
            rk.get("sourceIdColumn"),
        )

    vector_store.load_documents(documents_loader, embeddings, vector_store_update_method)

    logger.info("Finished loading documents")

    kb_status = vector_store.get_status(vector_store_update_method)
    kb_status_path = osp.join(exec_folder, "kb_status.json")
    dkujson.dump_to_filepath(kb_status_path, kb_status)


if __name__ == "__main__":
    # Script entry point. Expected arguments:
    #   argv[1] -> execution folder holding kb.json / recipe_settings.json
    #   argv[2] -> name of the input dataset to embed
    log_format = '[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)

    read_dku_env_and_set()

    run_folder = sys.argv[1]
    input_dataset_name = sys.argv[2]

    # Any failure during the run is reported through the error-monitoring wrapper.
    with ErrorMonitoringWrapper():
        embed_data(run_folder, dataiku.Dataset(input_dataset_name))
