# encoding: utf-8
"""
Execute an Embed documents recipe. Must be called in a Flow environment
"""
import json
import logging
from typing import Any, Dict, List, Optional

import pandas as pd

from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import get_json_friendly_error, watch_stdin
from dataiku.core.vector_stores.dku_vector_store import DkuLocalVectorStore, VectorStoreFactory
from dataiku.langchain.dataset_loader import DatasetLoader, DummyLoader
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from dataiku.langchain.metadata_generator import DKU_SECURITY_TOKENS_META

# Module-level logger. When this file runs as a script, its records are formatted
# by the logging.basicConfig call in the __main__ block.
logger = logging.getLogger(__name__)

def _read_chunk_uuids(path: str) -> List[str]:
    """Read a header-less, single-column CSV of chunk UUIDs into a list of strings.

    Values are read with ``dtype=str`` so pandas never re-types an identifier.
    If pandas raises ``EmptyDataError`` on an empty file, an empty list is
    returned instead, meaning "nothing to add/delete".
    """
    try:
        uuids = pd.read_csv(path, names=["chunk_uuids"], dtype=str)
    except pd.errors.EmptyDataError:
        return []
    return uuids["chunk_uuids"].to_list()

def serve(port: int, secret: str, server_cert: Optional[str] = None) -> None:
    """Run the knowledge-bank server for an Embed documents recipe.

    Opens a JavaLink with the given connection parameters and waits for a
    ``START`` command. That command carries the recipe settings, the knowledge
    bank settings and the run folder. The server then builds the embeddings
    model and the vector store, and answers commands in a loop:

    * ``UPDATE_KB``: add the chunks of the JSON-lines file ``chunksToAdd``
      under the ids listed in the CSV ``chunksIdsToAdd``, then delete the ids
      listed in the CSV ``chunksIdsToDelete``. Replies with ``kbStatus``.
    * ``CHECK_EMBEDDING``: add, then delete, a dummy document to check that
      the embedding model and the knowledge bank are compatible.
    * ``CLEAR_KB``: clear the vector store.
    * ``STOP``: leave the loop and reply ``Done.``.

    Every command gets exactly one JSON reply, of type ``SUCCESS`` or ``ERROR``.
    After sending an ``ERROR`` reply the function returns, so no further
    commands are processed.

    :param port: port forwarded to ``JavaLink``
    :param secret: secret forwarded to ``JavaLink``
    :param server_cert: optional certificate forwarded to ``JavaLink``
    """
    link = JavaLink(port, secret, server_cert=server_cert)
    link.connect()

    # The first message must be START; it carries everything needed to open the KB.
    command = link.read_json()
    order = command.get("type")
    if order != "START":
        # Explicit check instead of `assert` (which is stripped under -O), and the
        # backend gets an ERROR reply instead of a dead socket.
        link.send_json({"type": "ERROR", "message": "Expected START command, got %r" % (order,)})
        return

    try:
        recipe_settings = json.loads(command["recipeSettings"])
        knowledge_bank_settings = json.loads(command["knowledgeBankSettings"])
        exec_folder = command["runFolder"]
        embeddings = DKUEmbeddings(llm_id=knowledge_bank_settings["embeddingLLMId"])
        logger.info("Loading vector store")
        vector_store = get_vector_store(knowledge_bank_settings, exec_folder)
        link.send_json({"type": "SUCCESS", "message": "Successfully connected to KB"})
    except Exception as e:
        # Keep the reply payload unchanged, but also leave a trace in the logs.
        logger.exception("Knowledge bank server failed to start")
        link.send_json({"type": "ERROR", "message": str(e)})
        return

    while True:
        try:
            command = link.read_json()
            order = command["type"]
            if order == "UPDATE_KB":
                input_docs = InlineDataset(command["chunksToAdd"])
                input_docs_uuids = _read_chunk_uuids(command["chunksIdsToAdd"])
                to_delete_docs_uuids = _read_chunk_uuids(command["chunksIdsToDelete"])
                kb_status = embed_data(vector_store, embeddings, recipe_settings, input_docs,
                                       input_docs_uuids, to_delete_docs_uuids)
                link.send_json({'type': "SUCCESS", "kbStatus": kb_status})
            elif order == "CHECK_EMBEDDING":
                dummy_text = "This is just a dummy query to check embeddingSize"
                dummy_uuid = '00000000-0000-0000-0000-000000000000'
                loader = DummyLoader([dummy_text])
                vector_store.load_documents_to_add(loader, embeddings, [dummy_uuid])
                vector_store.load_documents_to_delete(embeddings, [dummy_uuid])
                link.send_json({"type": "SUCCESS", "message": "Embedding and Knowledge Bank seem compatible."})
            elif order == "CLEAR_KB":
                vector_store.clear()
                link.send_json({"type": "SUCCESS", "message": "Knowledge Bank cleared."})
            elif order == "STOP":
                # This is needed to avoid hanging on the line (EOFException)
                break
            else:
                # Every handled command replies exactly once. An unknown command
                # previously got no reply at all; report it like any other failure.
                raise ValueError("Unknown command type: %r" % (order,))
        except Exception:
            logger.exception("Knowledge bank server failed")
            json_err = get_json_friendly_error()
            json_err["type"] = "ERROR"
            link.send_json(json_err)
            return

    link.send_json({"type": "SUCCESS", "message": "Done."})

class InlineDataset:
    """Minimal stand-in for a Dataiku dataset, backed by a JSON-lines file.

    It mimics the ``iter_dataframes`` method of the DSS dataset class, so that
    a loader such as ``DatasetLoader`` can read the rows lazily, chunk by chunk,
    instead of loading the whole file at once.
    """

    def __init__(self, filename: str):
        # Path to the JSON-lines file (one JSON object per row).
        self.filename = filename

    def iter_dataframes(self, chunksize: int) -> Any:
        """Iterate over the file as DataFrames of at most ``chunksize`` rows.

        Values keep the types given by the JSON itself. pandas' dtype and
        date inference are disabled because they would otherwise:

        * coerce numeric-looking strings to numbers (chunk text ``"0042"``
          becomes ``42``);
        * turn columns with date-like names (``date``, ``*_at``, ...) into
          Timestamps.

        :param chunksize: maximum number of rows per DataFrame
        :returns: the iterable reader returned by ``pandas.read_json``, which
            yields DataFrames
        """
        return pd.read_json(self.filename, chunksize=chunksize, lines=True,
                            dtype=False, convert_dates=False)

def get_vector_store(kb_settings: Dict[str, Any], exec_folder: str):
    """Build the vector store described by the knowledge bank settings.

    Thin wrapper over ``VectorStoreFactory.get_vector_store`` that resolves
    connection details through
    ``VectorStoreFactory.get_connection_details_from_env``.

    :param kb_settings: parsed knowledge bank settings
    :param exec_folder: working folder of the run (``runFolder`` in ``serve``)
    """
    factory = VectorStoreFactory
    connection_details_provider = factory.get_connection_details_from_env
    return factory.get_vector_store(
        kb_settings,
        exec_folder,
        connection_details_provider,
    )

def embed_data(vector_store, embeddings, recipe_settings: Dict[str, Any], input_docs: InlineDataset,
               to_add_docs_uuids: List[str], to_delete_docs_uuids: List[str]) -> Dict[str, Any]:
    """Apply one incremental update to the knowledge bank.

    Wraps ``input_docs`` in a ``DatasetLoader`` and passes it, with
    ``to_add_docs_uuids``, to ``vector_store.load_documents_to_add``. Then it
    calls ``vector_store.load_documents_to_delete`` with
    ``to_delete_docs_uuids``.

    :param vector_store: vector store built by ``get_vector_store``
    :param embeddings: embeddings model used for both calls
    :param recipe_settings: recipe settings. ``knowledgeColumn`` names the text
        column. ``metadataColumns`` is a list of ``{"column": ...}`` entries
        (a missing or null value means no metadata). A truthy
        ``securityTokensColumn`` enables security tokens.
    :param input_docs: chunks to add
    :param to_add_docs_uuids: ids for the chunks to add
    :param to_delete_docs_uuids: ids of the documents to delete
    :returns: status dict; holds ``fileSizeMb`` only for a ``DkuLocalVectorStore``
    """
    logger.info("Loading dataset records")
    # NOTE(review): assumes the input file stores security tokens under
    # DKU_SECURITY_TOKENS_META whenever a tokens column is configured. Confirm.
    security_tokens_column = DKU_SECURITY_TOKENS_META if recipe_settings.get("securityTokensColumn") else None
    metadata_columns = [c["column"] for c in (recipe_settings.get("metadataColumns") or [])]
    documents_loader = DatasetLoader(
        input_data=input_docs,
        content_column=recipe_settings["knowledgeColumn"],
        source_id_column=None,
        security_tokens_column=security_tokens_column,
        metadata_columns=metadata_columns,
    )
    vector_store.load_documents_to_add(documents_loader, embeddings, to_add_docs_uuids)
    vector_store.load_documents_to_delete(embeddings, to_delete_docs_uuids)

    kb_status = {}
    if isinstance(vector_store, DkuLocalVectorStore):
        kb_status["fileSizeMb"] = vector_store.get_file_size()

    return kb_status

if __name__ == "__main__":
    # Script entry point: configure logging, prepare the process, then serve.
    log_format = "[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    # NOTE(review): watch_stdin presumably ties this process's lifetime to the
    # parent's stdin. Confirm against dataiku.base.utils.
    watch_stdin()
    link_port, link_secret, link_cert = parse_javalink_args()
    read_dku_env_and_set()
    serve(link_port, link_secret, server_cert=link_cert)
