import enum
import json
import logging
import os
import os.path as osp
from typing import TYPE_CHECKING, Dict, List, Generator, Optional, Any, Sequence
from collections.abc import Callable

import pandas as pd

import dataiku
from dataiku.base.utils import package_is_at_least_no_import
from dataiku.core import dkujson
from dataiku.langchain import DKUEmbeddings
from dataiku.langchain.dataset_loader import VectorStoreLoader
from dataiku.llm.types import RetrievableKnowledge, BaseVectorStoreQuerySettings, SearchType
from dataikuapi.dss.admin import DSSConnection, DSSConnectionInfo
from dataikuapi.dss.utils import AnyLoc

if TYPE_CHECKING:
    from langchain_core.vectorstores import VectorStoreRetriever, VectorStore
    from langchain_core.documents import Document

logger = logging.getLogger(__name__)


# Names of the files kept next to a Knowledge Bank version.
RECORD_MANAGER_FILENAME = "record_manager_cache.sqlite"  # record manager SQLite file used by the embed dataset recipe
DSS_RECORD_MANAGER_FILENAME = "dss_record_manager_cache.sqlite"  # record manager SQLite file used by the embed document recipe
KB_STATUS_FILENAME = "kb_status.json"  # JSON status file (fileSizeMb / nbChunks / nbDocuments)


class RecordManagerCleanupMode(enum.Enum):
    """Cleanup strategies whose ``value`` is handed to the langchain indexing call."""

    # Removes the docs that are not part of the docs currently being indexed
    FULL = "full"
    # Removes only the previous versions of docs that were modified
    INCREMENTAL = "incremental"
    # Never removes anything
    NONE = None


class UpdateMethod(enum.Enum):
    """Ways of loading documents into a Knowledge Bank."""

    # Shown as "Smart Sync" in the UI; needs a sourceIdColumn in the KB
    SMART_OVERWRITE = "SMART_OVERWRITE"
    # Shown as "Upsert" in the UI; needs a sourceIdColumn in the KB
    SMART_APPEND = "SMART_APPEND"
    OVERWRITE = "OVERWRITE"
    APPEND = "APPEND"

class DkuVectorStore:

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str):
        """
        :param kb: The contents of kb.json (dict)
        :param exec_folder: the current KB version subfolder
        :type exec_folder: str
        """
        self.kb = kb
        self.exec_folder = exec_folder

        self.metadata_column_type = {col["name"]: col["type"] for col in kb.get("metadataColumnsSchema", [])}
        self.source_id_column = kb.get("sourceIdColumn")  # Can be None, if not in a smart update mode
        self.document_filter = None
        self.total_added_documents = None # number of chunks added in overwrite mode (embed dataset only)

    @staticmethod
    def _get_mmr_args(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        use_mmr = query_settings.get("searchType") == "MMR"
        if not use_mmr:
            return {}
        return {"fetch_k": query_settings.get("mmrK", 20), "lambda_mult": query_settings.get("mmrDiversity", 0.25)}

    @staticmethod
    def _get_similarity_threshold_args(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        use_similarity_threshold = query_settings.get("searchType") == "SIMILARITY_THRESHOLD"
        if not use_similarity_threshold:
            return {}
        return {"score_threshold": query_settings.get("similarityThreshold", 0.5)}

    @staticmethod
    def _get_search_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        # By default, the k param (nb of documents to return) is passed as part of the search kwargs
        search_kwargs = {
            "k": query_settings.get("maxDocuments", 4),
            **DkuVectorStore._get_mmr_args(query_settings),
            **DkuVectorStore._get_similarity_threshold_args(query_settings)
        }
        return search_kwargs

    @staticmethod
    def _get_retriever_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        # By default no retriever kwargs are needed
        return {}

    def _get_db_kwargs(self, query_settings: BaseVectorStoreQuerySettings) -> Dict:
        # By default no db kwargs are needed
        return {}

    @staticmethod
    def dss_search_type_to_langchain_search_type(search_type: SearchType):
        if search_type == "SIMILARITY":
            return "similarity"
        elif search_type == "SIMILARITY_THRESHOLD":
            return "similarity_score_threshold"
        elif search_type == "MMR":
            return "mmr"
        elif search_type == "HYBRID":
            return "hybrid"
        else:
            return "similarity"

    def include_similarity_score(self, query_settings: BaseVectorStoreQuerySettings) -> bool:
        search_type = query_settings.get("searchType")
        return query_settings.get("includeScore", False) and search_type != "MMR"

    @staticmethod
    def _get_search_type(query_settings: BaseVectorStoreQuerySettings) -> str:
        return DkuVectorStore.dss_search_type_to_langchain_search_type(query_settings.get("searchType", "SIMILARITY"))

    def as_retriever(self, embeddings: DKUEmbeddings, query_settings: BaseVectorStoreQuerySettings, additional_search_kwargs: Optional[Dict] = None) -> "VectorStoreRetriever":
        search_kwargs = self._get_search_kwargs(query_settings)
        if isinstance(additional_search_kwargs, dict):
            search_kwargs = {**search_kwargs, **additional_search_kwargs}

        db_kwargs = self._get_db_kwargs(query_settings)
        search_type = self._get_search_type(query_settings)
        retriever_kwargs = self._get_retriever_kwargs(query_settings)

        logger.info("--> Resulting search kwargs: %s" % search_kwargs)
        logger.info("--> Resulting db kwargs: %s" % db_kwargs)
        logger.info("--> Resulting retriever kwargs: %s" % retriever_kwargs)
        logger.info("--> Resulting search type: %s" % search_type)

        return self.get_db(embeddings, **db_kwargs).as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs,
            **retriever_kwargs
        )

    def search_with_scores(self, query: str, embeddings: DKUEmbeddings, query_settings: BaseVectorStoreQuerySettings, additional_search_kwargs: Optional[Dict] = None) -> List[tuple[Dict, float]]:
        search_kwargs = self._get_search_kwargs(query_settings)
        if isinstance(additional_search_kwargs, dict):
            search_kwargs = {**search_kwargs, **additional_search_kwargs}

        db_kwargs = self._get_db_kwargs(query_settings)
        retriever_kwargs = self._get_retriever_kwargs(query_settings)

        return self.get_db(embeddings, **db_kwargs).similarity_search_with_relevance_scores(query, **search_kwargs, **retriever_kwargs)

    def _prepare_and_clean_metadata(self, document):
        new_meta = {}
        for key, val in document.metadata.items():
            if pd.isna(val):
                # Removing the empty metadata field all-together to stay aligned and safe with all vector stores.
                continue
            column_type = self.metadata_column_type.get(key)
            if column_type is None:
                new_meta[key] = val
            elif column_type in ["bigint", "int", "smallint", "tinyint"]:
                # Casting to int since if a single value is missing on an int column, panda will make the whole column float.
                new_meta[key] = int(val)
            elif column_type == "string":
                # Casting to string for vector stores like azure ai that when given an int on a str column will throw an error.
                new_meta[key] = str(val)
            elif column_type in ["date", "datetimetz", "datetimenotz", "dateonly"] and isinstance(val, pd.Timestamp):
                # Converting to string to avoid issues with vector stores that don't support pandas.Timestamp and for later serialisation
                new_meta[key] = self.process_timestamp(val)
            else:
                new_meta[key] = val
        document.metadata = new_meta
        return document

    def process_timestamp(self, value: pd.Timestamp) -> str:
        return value.isoformat()

    def transform_document_before_load(self, document: "Document") -> "Document":
        """
        Hook for a vectorstore to rewrite the document at load time.
        """
        document = self._prepare_and_clean_metadata(document)
        if self.document_filter is not None:
            document = self.document_filter.add_security_tokens_to_document(document)
        return document

    def add_security_filter(self, search_kwargs: Dict, caller_groups : List) -> None:
        if not self.document_filter:
            raise Exception("Security filter not implemented on this vector store")
        self.document_filter.add_security_filter(search_kwargs, caller_groups)

    def add_filter(self, search_kwargs: Dict, filter_ : Dict) -> None:
        if not self.document_filter:
            raise Exception("Explicit filter not implemented on this vector store")
        self.document_filter.add_filter(search_kwargs, filter_)

    @staticmethod
    def _clean_filter(filter_: Dict, metadata_columns: List[str], vector_store_query_tool_ref: Optional[str]) -> Optional[Dict]:
        """
        Discard irrelevant filters (with a `toolRef` other than this vector store query tool's id).
        On relevant filters, irrelevant clauses (on unknown columns) are discarded.
        """
        # Ensure the filter has no toolRef or that toolRefs match
        filter_tool_ref = filter_.get("toolRef")
        if vector_store_query_tool_ref is not None and filter_tool_ref is not None:
            vector_store_query_tool_loc = AnyLoc.from_full(vector_store_query_tool_ref)
            filter_tool_loc = AnyLoc.from_ref(dataiku.api_client().get_default_project().project_key, filter_tool_ref)
            if filter_tool_loc != vector_store_query_tool_loc:
                logger.info(f"{filter_} in context not applied. The specified toolRef {filter_tool_ref} is different than the toolRef: {vector_store_query_tool_ref}")
                return None
        return DkuVectorStore._clean_clause(filter_.get("filter", {}), metadata_columns)

    @staticmethod
    def _clean_clause(clause: Dict, metadata_columns: List[str]) -> Optional[Dict]:
        """
        Discard clauses whose column doesn't exist in the KB.
        """
        if clause.get("operator") in ["AND", "OR"]: # Several clauses:
            cleaned_clauses = [c for c in [DkuVectorStore._clean_clause(c, metadata_columns) for c in clause.get("clauses", [])] if c is not None]
            if len(cleaned_clauses) == 0:
                return None
            elif len(cleaned_clauses) == 1:
                return cleaned_clauses[0]
            else:
                clause["clauses"] = cleaned_clauses
                return clause
        else:
            if clause.get("column") not in metadata_columns:
                logger.info(f"Knowledge Bank metadata lacks column {clause.get('column')}, discarding context clause {clause}")
                return None
        return clause

    def add_dynamic_filter(self, search_kwargs, caller_filters: List[Dict], vector_store_query_tool_id: Optional[str]):
        """Add dynamic filtering"""
        metadata_columns = [c.get("name") for c in self.kb.get("metadataColumnsSchema", [])]
        cleaned_filters = [cf for cf in [DkuVectorStore._clean_filter(f, metadata_columns, vector_store_query_tool_id) for f in caller_filters] if cf is not None]
        for cf in cleaned_filters:
            self.document_filter.add_filter(search_kwargs, cf)

    def load_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings, update_method:UpdateMethod=UpdateMethod.OVERWRITE) -> "VectorStore":
        """
        Load the given documents into the vector store, with the provided update method.

        :type documents_loader: dataiku.langchain.VectorStoreLoader
        :type embeddings: dataiku.langchain.dku_embeddings.DKUEmbeddings
        :param update_method: One of:
            UpdateMethod.SMART_OVERWRITE: Sync the vector store with the documents, using langchain RecordManager. Requires a sourceIdColumn set on the KB.
            UpdateMethod.SMART_APPEND: Upsert the documents into the vector store, using langchain RecordManager. Requires a sourceIdColumn set on the KB.
            UpdateMethod.OVERWRITE (default): Clear the vector store and add the documents
            UpdateMethod.APPEND: Add the documents to the vector store
        :type update_method: UpdateMethod

        :return: The new version of the vector store, equivalent to calling self.get_db(embeddings)
        :rtype: langchain_community.vectorstores.VectorStore
        """
        if update_method == UpdateMethod.SMART_OVERWRITE:
            return self._managed_overwrite_documents(documents_loader, embeddings)
        elif update_method == UpdateMethod.SMART_APPEND:
            return self._managed_append_documents(documents_loader, embeddings)
        elif update_method == UpdateMethod.OVERWRITE:
            return self._overwrite_documents(documents_loader, embeddings)
        elif update_method == UpdateMethod.APPEND:
            return self._append_documents(documents_loader, embeddings)
        else:
            raise ValueError("Unknown knowledge bank update method %s" % update_method)

    def load_documents_to_add(self, documents_loader:VectorStoreLoader, embeddings: DKUEmbeddings, to_add_documents_uuids: Optional[list[str]]=None)-> "VectorStore":
        if len(to_add_documents_uuids) > 0:
            logger.info("Performing embedding : adding {} total new langchain documents in the vector store".format(len(to_add_documents_uuids)))
            vector_store_db = self.add_documents(documents_loader, embeddings, to_add_documents_uuids)
            logger.info("Finished adding langchain documents")
            return vector_store_db
        else:
            logger.info("No langchain document to add in the vector store")
            return None

    def load_documents_to_delete(self, embeddings: DKUEmbeddings, to_delete_documents_uuids:  Optional[list[str]]=None):
        logger.info("Deleting {} outdated langchain documents in the vector store".format(len(to_delete_documents_uuids)))
        vector_store_db = self.delete_documents(to_delete_documents_uuids, embeddings)
        return vector_store_db

    def _overwrite_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings) -> "VectorStore":
        logger.info("Adding documents to the vector store, in overwrite mode")
        if self.source_id_column is not None:
            raise ValueError("Source ID column cannot be used with overwrite mode")
        self.clear()
        return self.add_documents(documents_loader, embeddings)

    def _managed_overwrite_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings) -> "VectorStore":
        logger.info("Adding documents to the vector store, in managed overwrite mode")
        if not self.source_id_column:
            raise ValueError("Source ID column is required when using smart sync mode")
        vectorstore_db = self.get_db(embeddings, allow_creation=True)
        self._index_documents(documents_loader, vectorstore_db, cleanup_mode=RecordManagerCleanupMode.FULL, source_id_key=self.source_id_column)
        return vectorstore_db

    def _append_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings) -> "VectorStore":
        logger.info("Adding documents to the vector store, in append mode")
        if self.source_id_column is not None:
            raise ValueError("Source ID column cannot be used with upsert mode")
        return self.add_documents(documents_loader, embeddings)

    def _managed_append_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings) -> "VectorStore":
        logger.info("Adding documents to the vector store, in managed append mode")
        if not self.source_id_column:
            raise ValueError("Source ID column is required when using upsert mode")
        vectorstore_db = self.get_db(embeddings, allow_creation=True)
        self._index_documents(documents_loader, vectorstore_db, cleanup_mode=RecordManagerCleanupMode.INCREMENTAL, source_id_key=self.source_id_column)
        return vectorstore_db

    def delete_documents(self, documents_uuids:  list[str], embeddings: DKUEmbeddings) -> "VectorStore":
        if len(documents_uuids) > 0:
            vectorstore_db = self.get_db(embeddings=embeddings, allow_creation=False)
            vectorstore_db.delete(ids=documents_uuids)
            return vectorstore_db
        return None

    def add_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings, documents_uuids:  Optional[list[str]]=None) -> "VectorStore":
        vectorstore_db = self.get_db(embeddings, allow_creation=True)
        total = 0
        nb_batches = 0
        for documents in documents_loader.lazy_load():
            if len(documents) > 0:
                for document in documents:
                    self.transform_document_before_load(document)

                nb_batches += 1
                if documents_uuids:
                    vectorstore_db.add_documents(documents, ids=documents_uuids[total: total + len(documents)])
                else:
                    vectorstore_db.add_documents(documents)
                total += len(documents)
                logger.info(f"Added a batch of {len(documents)} documents to the vector store ({total} added so far)")
        if total > 0:
            logger.info(f"Added {total} documents to the vector store in {nb_batches} batches")
        self.total_added_documents = total
        return vectorstore_db

    def _index_documents(self, documents_loader: VectorStoreLoader, vectorstore_db: "VectorStore", cleanup_mode: RecordManagerCleanupMode, source_id_key: str, batch_size: int=100) -> None:  # 100 is the default batch size used by RecordManager
        """
        Index the documents of the loader into the vector store through langchain's ``index``.

        :param documents_loader: loader whose ``iter_documents()`` yields the documents to index
        :param vectorstore_db: the vector store passed to ``index``
        :param cleanup_mode: its ``value`` is passed as the ``cleanup`` argument of ``index``
        :param source_id_key: passed as the ``source_id_key`` argument of ``index``
        :param batch_size: used for both ``batch_size`` and ``cleanup_batch_size``; it may be
            lowered first by ``_check_batch_size_against_sqlite_version``
        """
        from langchain.indexes import index

        batch_size = self._check_batch_size_against_sqlite_version(batch_size)
        record_manager = self._get_record_manager()

        logger.info(f"Indexing documents, with cleanup_mode={cleanup_mode}, source_id_key={source_id_key}, batch_size={batch_size}")

        def iter_transformed_documents() -> Generator["Document", None, None]:
            # Lazily apply the load-time hook to every document of the loader
            for doc in documents_loader.iter_documents():
                ret = self.transform_document_before_load(doc)
                yield ret

        def patch_for_batched_delete() -> None:
            """
            Replace ``record_manager.delete_keys`` by a version that deletes the keys in
            slices of at most ``batch_size``.

            Warning: this method breaks the query into multiple transactions,
            so if it fails mid way we might end up with an inconsistent index.

            That is already happening for RecordManagerCleanupMode.FULL anyway.
            """
            # Keep a handle on the original bound method: the replacement delegates to it
            non_batched_delete = record_manager.delete_keys
            def batched_delete_keys(self, keys: Sequence[str]) -> None:
                # batch the keys based on batch_size
                for i in range(0, len(keys), batch_size):
                    batch = keys[i:i + batch_size]
                    non_batched_delete(batch)
            from types import MethodType
            # Bind the replacement to this record manager instance only
            record_manager.delete_keys = MethodType(batched_delete_keys, record_manager)

        # The delete patch is only installed for the incremental cleanup mode
        if cleanup_mode == RecordManagerCleanupMode.INCREMENTAL:
            patch_for_batched_delete()

        index_info = index(
            iter_transformed_documents(),
            record_manager,
            vectorstore_db,
            cleanup=cleanup_mode.value,
            source_id_key=source_id_key,
            batch_size=batch_size,
            cleanup_batch_size=batch_size
        )
        logger.info(f"Vector store indexing done {index_info}")

    def _get_chunk_count(self, update_method: UpdateMethod) -> Optional[int]:
        """
        Gets the number of chunks (also known as documents in langchain) for an embed dataset recipe.
        Each update method has a different behavior for determining the number of chunks:
        - Overwrite: Use the number of chunks from this run
        - Append: Do not calculate number of chunks (can become inaccurate if using old KBs without status files
        - Smart sync/smart append: Calculate the number of rows in the record manager db (Note: the record manager does
        not exist for overwrite/append cases)
        :param update_method
        :type update_method: UpdateMethod
        :return: The number of chunks
        :rtype: Optional[int]
        """
        if update_method == UpdateMethod.OVERWRITE:
            return self.total_added_documents
        elif update_method == UpdateMethod.APPEND:
            return None
        else:
            from sqlalchemy import text
            record_manager = self._get_record_manager()
            try:
                with record_manager.engine.connect() as connection:
                    result = connection.execute(text("SELECT COUNT(*) FROM upsertion_record"))
                    return result.scalar()
            except Exception as e:
                logger.warning(f"Failed to get chunk count from {record_manager}: {e}")

        return None

    def _get_source_document_count(self, update_method: UpdateMethod) -> Optional[int]:
        """
        Gets the number of source documents for an embed dataset recipe.
        This is only available for smart sync and smart append update methods. For append and overwrite update methods,
        we do not use the record manager, and therefore we cannot easily access the original source documents.
        :param update_method
        :type update_method: UpdateMethod
        :return: The number of source documents
        :rtype: Optional[int]
        """
        if update_method == UpdateMethod.SMART_OVERWRITE or update_method == UpdateMethod.SMART_APPEND:
            from sqlalchemy import text
            record_manager = self._get_record_manager()
            try:
                with record_manager.engine.connect() as connection:
                    result = connection.execute(text("SELECT COUNT(DISTINCT group_id) FROM upsertion_record"))
                    return result.scalar()
            except Exception as e:
                logger.warning(f"Failed to get document count from {record_manager}: {e}")

        return None

    def get_status(self, update_method: UpdateMethod) -> Dict:
        """
        Gets the number of chunks and documents, with the provided update method.
        Note this is only called for the embed dataset recipe. See Java implementation for the embed document recipe.
        :param update_method
        :type update_method: UpdateMethod
        :return: The number of chunks and documents
        :rtype: Dict
        """
        chunk_count = self._get_chunk_count(update_method)
        document_count = self._get_source_document_count(update_method)

        return {
            "nbChunks": chunk_count,
            "nbDocuments": document_count
        }

    def _get_record_manager(self):
        """
        Open the SQLite-backed langchain record manager stored in the exec folder.

        Its schema is created when the SQLite file did not exist yet.

        :return: the ``SQLRecordManager`` instance
        """
        from langchain.indexes import SQLRecordManager

        record_manager_file_path = osp.join(self.exec_folder, RECORD_MANAGER_FILENAME)
        # Checked before the manager is instantiated
        must_create_schema = not osp.exists(record_manager_file_path)
        if must_create_schema:
            logger.info(f"Creating new record manager: {record_manager_file_path}")
        else:
            logger.info(f"Using existing record manager: {record_manager_file_path}")

        record_manager = SQLRecordManager(namespace="record_manager_namespace", db_url=f"sqlite:///{record_manager_file_path}")
        if must_create_schema:
            record_manager.create_schema()
        return record_manager

    @staticmethod
    def _check_batch_size_against_sqlite_version(batch_size: int) -> int:
        # Older versions of sqlite have a lower SQLITE_MAX_VARIABLE_NUMBER, which causes issues for RecordManager when the batch size is too high
        try:
            import sqlite3
            safe_batch_size = 190  # RecordManager uses 5 variables per row, and SQLITE_MAX_VARIABLE_NUMBER=999 on sqlite versions before 3.32
            if batch_size > safe_batch_size and sqlite3.sqlite_version_info < (3, 32, 0):
                logger.warning(f"Reducing the RecordManager batch size to {safe_batch_size} because of an older version of sqlite (<3.32).")
                batch_size = safe_batch_size
        except:
            pass  # ignore missing sqlite in case future versions of RecordManager use a different sql engine

        return batch_size

    def get_db(self, embeddings: DKUEmbeddings, allow_creation: bool = False, **kwargs: Any) -> "VectorStore":
        """
        :type embeddings: dataiku.langchain.dku_embeddings.DKUEmbeddings
        :param allow_creation: whether creating new resources if the db doesn't exist is allowed.
        :type allow_creation: boolean
        :rtype: langchain_community.vectorstores.VectorStore
        """
        raise NotImplementedError()

    def clear(self) -> None:
        """ Clear all data in the knowledge bank - used only from the clear server for now"""
        raise NotImplementedError()

    def clear_record_manager(self, folder_path: str) -> None:
        """Delete both record manager files (embed dataset and embed document recipes) from ``folder_path``, if present."""
        logger.info("Clearing record manager")
        for filename in (RECORD_MANAGER_FILENAME, DSS_RECORD_MANAGER_FILENAME):
            record_manager_file_path = osp.join(folder_path, filename)
            if osp.isfile(record_manager_file_path):
                os.remove(record_manager_file_path)

    def clear_kb_status(self, folder_path: str) -> None:
        """Overwrite the KB status file of ``folder_path`` with zeroed counters."""
        logger.info("Clearing KB status file")
        zeroed_status = {
            "fileSizeMb": 0,
            "nbChunks": 0,
            "nbDocuments": 0
        }
        dkujson.dump_to_filepath(osp.join(folder_path, KB_STATUS_FILENAME), zeroed_status)

    def get_vector_size(self) -> int:
        # create a dummy embedding query via the public api to retrieve the embedding size no matter which llm was selected.
        if not self.kb.get("embeddingLLMId"):
            raise ValueError("An embedding model must be selected")

        project_handle = dataiku.api_client().get_default_project()
        llm_model = project_handle.get_llm(self.kb["embeddingLLMId"])

        query = llm_model.new_embeddings()
        query.add_text("This is just a dummy query to get embeddingSize")
        model_embedding_size = len(query.execute().get_embeddings()[0])

        logger.info("Retrieved vector size for LLM id {}: {}".format(self.kb["embeddingLLMId"], model_embedding_size))
        return model_embedding_size


class DkuLocalVectorStore(DkuVectorStore):
    """Base class of the vector stores whose data is kept in files of the exec folder."""

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str, collection_name: str):
        super().__init__(kb, exec_folder)
        self.collection_name = collection_name

    def clear(self) -> None:
        """
        Delete all the KB versioned data
        """
        folder = self.exec_folder
        self.clear_files(folder)
        self.clear_record_manager(folder)
        self.clear_kb_status(folder)
        logger.info("Cleared local vector store at {}".format(self.exec_folder))

    def clear_files(self, folder_path: str) -> None:
        """Delete the vector store files of ``folder_path``; to be implemented by each local vector store."""
        raise NotImplementedError

    def get_file_size(self) -> int:
        """Return the value reported as "fileSizeMb" in the status; to be implemented by each local vector store."""
        raise NotImplementedError

    def get_status(self, update_method) -> Dict:
        """Return the base status, completed with the "fileSizeMb" entry."""
        status = super().get_status(update_method)
        status.update(fileSizeMb=self.get_file_size())
        return status


class DkuRemoteVectorStore(DkuVectorStore):

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str, connection_info_retriever: Callable[[str], DSSConnectionInfo], bulk_size: int = 1000):
        """
        Set up a vector store hosted behind a DSS connection.

        :param kb: The contents of kb.json (dict); "resolvedIndexName", "vectorStoreType" and
            "connection" are read from it
        :param exec_folder: the current KB version subfolder
        :param connection_info_retriever: callable returning the connection info for a connection name
        :param bulk_size: batch size used unless the connection defines the
            "dku.embedding.bulkSize" property
        :raises ValueError: when the index name or the connection is missing
        """
        super(DkuRemoteVectorStore, self).__init__(kb, exec_folder)
        self.index_name: Optional[str] = None

        index_name = kb.get("resolvedIndexName")
        if index_name is None:
            if kb.get("indexName"):
                raise ValueError("The index name variables could not be resolved.")
            else:
                raise ValueError("You must provide a value for the Knowledge Bank index name.")
        self.set_index_name(index_name)

        self.type = kb["vectorStoreType"]
        self.connection_info_retriever = connection_info_retriever
        # .get() rather than [] so that a KB without a "connection" entry reaches the explicit
        # error of check_connection() instead of failing here with a bare KeyError
        self.connection_name = kb.get("connection")
        self.check_connection()
        self.init_connection()

        connection_params = self.connection_info_retriever(self.connection_name).get_params()
        # The bulk size configured on the connection, when valid, wins over the argument
        self.bulk_size = self.get_bulk_size_from_connection_params(connection_params) or bulk_size

    def get_bulk_size_from_connection_params(self, connection_params: Dict) -> Optional[int]:
        for property in connection_params.get('dkuProperties', []):
            if property['name'] == "dku.embedding.bulkSize":
                try:
                    bulk_size = int(property['value'])
                    logger.info("Set custom bulk size: {}".format(bulk_size))
                    return bulk_size
                except ValueError:
                    logger.warning("Ignoring invalid value for bulk size connection property (expected an integer): {}".format(property['value']))
        return None

    def _load_remote_resources_references(self) -> Dict:
        """ Loads the references to pre-existing remote-resources for the selected vector type
            :return dict containing the references to remote resources (ids/names required to manage their lifecycle),
                    empty when the references file does not exist
        """
        filename = "{}_remote_resources.json".format(self.type)
        filepath = osp.join(self.exec_folder, filename)  # todo double check read/write permissions on this file
        if not osp.exists(filepath):
            return {}
        logger.info("Loading remote resources references from {}".format(filepath))
        return dkujson.load_from_filepath(filepath)

    def _dump_remote_resources_references(self, remote_resources_ref: Dict) -> None:
        """ Write the references to the remote resources used for this KB version into the exec folder
            :param remote_resources_ref: dict containing the references to remote resources (ids/names required to manage their lifecycle)
        """
        filename = "{}_remote_resources.json".format(self.type)
        filepath = osp.join(self.exec_folder, filename)
        logger.info("Dumping remote resources references to {}".format(filepath))
        dkujson.dump_to_filepath(filepath, remote_resources_ref)

    def set_index_name(self, index_name: str) -> None:
        self.index_name = index_name

    def init_connection(self):
        raise NotImplementedError()

    def clear(self) -> None:
        """
        Delete current index & any local files in exec_folder
        """
        self.clear_index()   # to improve: if index name changed between 2 kb build, this will only clean remote resources with current index name.
        self.clear_record_manager(self.exec_folder)
        self.clear_kb_status(self.exec_folder)
        logger.info("Clearing remote resources from KB version '{}'".format(self.exec_folder))

    def clear_index(self) -> None:
        """ Clear remote resources with current index name & kb version"""
        raise NotImplementedError()

    def check_connection(self):
        if self.connection_name is None:
            raise ValueError("You must provide a connection to be used with the knowledge bank.")
        connection_info = self.connection_info_retriever(self.connection_name)
        if "params" not in connection_info:
            raise Exception("You lack the permission to read the details of the connection " + self.connection_name + ".")

    def get_batches(self, documents: List["Document"]) -> Generator[List["Document"], None, None]:
        return (documents[i:min(len(documents), i + self.bulk_size)] for i in range(0, len(documents), self.bulk_size))

    def ensure_documents_are_indexed(self) -> None:
        # nothing to do for most implementations
        pass

    def add_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings, documents_uuids:  Optional[list[str]]=None) -> "VectorStore":
        vectorstore_db = self.get_db(embeddings, allow_creation=True)
        total = 0
        nb_batches = 0
        for documents in documents_loader.lazy_load():
            if len(documents) > 0:
                for document in documents:
                    self.transform_document_before_load(document)
                for batch in self.get_batches(documents):
                    if documents_uuids:
                        vectorstore_db.add_documents(batch, ids=documents_uuids[total: total + len(batch)])
                    else:
                        vectorstore_db.add_documents(batch)
                    total += len(batch)
                    nb_batches += 1
                    logger.info(f"Added a batch of {len(batch)} documents to the vector store ({total} added so far)")
        if total > 0:
            logger.info(f"Added {total} documents to the vector store in {nb_batches} batches")
        self.ensure_documents_are_indexed()
        self.total_added_documents = total
        return vectorstore_db

    def _index_documents(self, documents_loader: VectorStoreLoader, vectorstore_db: "VectorStore", cleanup_mode: RecordManagerCleanupMode, source_id_key: str, batch_size: Optional[int] = None) -> None:
        """
        Index the documents like the base class does, then run the ``ensure_documents_are_indexed`` hook.

        :param batch_size: batch size handed to the base implementation; ``self.bulk_size`` when None
        """
        effective_batch_size = self.bulk_size if batch_size is None else batch_size
        super()._index_documents(documents_loader, vectorstore_db, cleanup_mode, source_id_key, effective_batch_size)
        self.ensure_documents_are_indexed()


class VectorStoreFactory:
    """
    Builds the :class:`DkuVectorStore` implementation matching the ``vectorStoreType`` of a knowledge bank.

    Each backend module is imported inside the branch that uses it, so selecting one
    backend does not import the modules of the others.
    """

    @staticmethod
    def get_vector_store(kb: RetrievableKnowledge, exec_folder: str, connection_info_retriever: Optional[Callable[[str], DSSConnectionInfo]] = None) -> DkuVectorStore:
        """
        Instantiate the vector store implementation for ``kb["vectorStoreType"]``.

        :param kb: The contents of kb.json
        :type kb: RetrievableKnowledge
        :param exec_folder: Vector Store location on disk
        :type exec_folder: str
        :param connection_info_retriever: handler to retrieve connection info from a connection name.
            When None and ``kb`` has a ``connection`` key, the retriever returned by
            :meth:`get_default_connection_retriever` is used. It is not passed to the
            FAISS, CHROMA and QDRANT_LOCAL implementations.
        :type connection_info_retriever: Callable[[str], DSSConnectionInfo] | None
        :raises NotImplementedError: if the vector store type is not one of the handled types
        :rtype: DkuVectorStore
        """
        vector_store_type = kb["vectorStoreType"]

        if connection_info_retriever is None and 'connection' in kb:
            connection_info_retriever = VectorStoreFactory.get_default_connection_retriever()

        if vector_store_type == "FAISS":
            from dataiku.core.vector_stores.faiss_vector_store import FAISSVectorStore
            return FAISSVectorStore(kb, exec_folder)

        elif vector_store_type == "PINECONE":
            return VectorStoreFactory.get_correct_pinecone_vectorstore(kb, exec_folder, connection_info_retriever)

        elif vector_store_type == "AZURE_AI_SEARCH":
            from dataiku.core.vector_stores.azureaisearch_vector_store import AzureAISearchVectorStore
            return AzureAISearchVectorStore(kb, exec_folder, connection_info_retriever)

        elif vector_store_type == "VERTEX_AI_GCS_BASED":
            from dataiku.core.vector_stores.vertexai_vector_store import VertexAiVectorStore
            return VertexAiVectorStore(kb, exec_folder, connection_info_retriever)

        elif vector_store_type == "ELASTICSEARCH":
            return VectorStoreFactory.get_correct_elasticsearch_vectorstore(kb, exec_folder, connection_info_retriever)

        elif vector_store_type == "CHROMA":
            from dataiku.core.vector_stores.chroma_vector_store import ChromaVectorStore
            return ChromaVectorStore(kb, exec_folder)

        elif vector_store_type == "QDRANT_LOCAL":
            from dataiku.core.vector_stores.qdrant_local_vector_store import QDrantLocalVectorStore
            return QDrantLocalVectorStore(kb, exec_folder)

        else:
            raise NotImplementedError("Requested vector store type invalid: {}".format(vector_store_type))

    @staticmethod
    def get_default_connection_retriever() -> Callable[[str], DSSConnectionInfo]:
        """
        Return a retriever that fetches connection info through ``dataiku.api_client()``.

        :return: a callable taking a connection name and returning its info
        :rtype: Callable[[str], DSSConnectionInfo]
        """

        # Retrieves connection info from backend, requires "Read details" permission on the connection
        def remote_connection_retriever(connection_name: str) -> DSSConnectionInfo:
            connection = dataiku.api_client().get_connection(connection_name)
            return connection.get_info()

        return remote_connection_retriever

    @staticmethod
    def get_connection_details_from_env(connection_name: str) -> DSSConnectionInfo:
        """
        Build the connection info from the JSON stored in the ``DKU_KB_CONNECTION_INFO`` environment variable.

        :param connection_name: name of the connection; only used in the error message
        :type connection_name: str
        :raises Exception: if the environment variable is not set
        :rtype: DSSConnectionInfo
        """
        connection_info_json = os.getenv("DKU_KB_CONNECTION_INFO")
        if connection_info_json is None:
            raise Exception("Missing connection details for connection {}".format(connection_name))
        # Connection details are already retrieved, we just build DSSConnectionInfo
        return DSSConnectionInfo(json.loads(connection_info_json))

    @staticmethod
    def needs_local_path(kb: RetrievableKnowledge) -> bool:
        """
        Tell whether the knowledge bank needs a local path.

        :param kb: The contents of kb.json
        :type kb: RetrievableKnowledge
        :return: False for PINECONE, ELASTICSEARCH and AZURE_AI_SEARCH knowledge banks that have
            no ``sourceIdColumn``; True in every other case
        :rtype: bool
        """
        vector_store_type = kb["vectorStoreType"]

        if vector_store_type in {"PINECONE", "ELASTICSEARCH", "AZURE_AI_SEARCH"} and not kb.get("sourceIdColumn"):
            return False

        return True

    @staticmethod
    def get_correct_pinecone_vectorstore(kb: RetrievableKnowledge, exec_folder: str, connection_info_retriever: Optional[Callable[[str], DSSConnectionInfo]]) -> DkuVectorStore:
        """
        Pick the Pinecone implementation matching the installed Pinecone package.

        The V3 implementation is used when ``package_is_at_least_no_import`` reports
        ``pinecone`` >= 6.0 or ``pinecone-client`` >= 3.0; the V2 implementation otherwise.

        :param kb: The contents of kb.json
        :param exec_folder: Vector Store location on disk
        :param connection_info_retriever: handler to retrieve connection info
        :rtype: DkuVectorStore
        """
        if package_is_at_least_no_import("pinecone", "6.0") or package_is_at_least_no_import("pinecone-client", "3.0"):
            from dataiku.core.vector_stores.pinecone_v3_vector_store import PineconeV3VectorStore
            return PineconeV3VectorStore(kb, exec_folder, connection_info_retriever)
        else:
            from dataiku.core.vector_stores.pinecone_v2_vector_store import PineconeV2VectorStore
            return PineconeV2VectorStore(kb, exec_folder, connection_info_retriever)

    @staticmethod
    def _is_opensearch_error_body(body: Any) -> bool:
        """
        Tell whether the body attached to an ``UnsupportedProductError`` describes an OpenSearch server.

        The body is only trusted when it has the expected shape: anything that is not a dict,
        or whose ``version`` / ``tagline`` entries have an unexpected type, is treated as
        "not OpenSearch" instead of raising while the original error is being handled.

        :param body: the ``body`` attribute of the caught error
        :rtype: bool
        """
        if not isinstance(body, dict):
            return False

        version_info = body.get("version")
        distribution = version_info.get("distribution", "") if isinstance(version_info, dict) else ""
        if distribution == "opensearch":
            return True

        # AWS OpenSearch Managed Cluster in compatibility mode doesn't report distribution
        tagline = body.get("tagline")
        return isinstance(tagline, str) and "opensearch" in tagline.lower()

    @staticmethod
    def get_correct_elasticsearch_vectorstore(kb: RetrievableKnowledge, exec_folder: str, connection_info_retriever: Optional[Callable[[str], DSSConnectionInfo]]) -> DkuVectorStore:
        """
        Build the ElasticSearch vector store, falling back to the OpenSearch one when relevant.

        The ElasticSearch implementation is tried first. If building it raises
        ``UnsupportedProductError`` and the error body identifies an OpenSearch server,
        the OpenSearch implementation is returned instead.

        :param kb: The contents of kb.json
        :param exec_folder: Vector Store location on disk
        :param connection_info_retriever: handler to retrieve connection info
        :raises UnsupportedProductError: if the ElasticSearch implementation is rejected and
            the error body does not identify an OpenSearch server
        :rtype: DkuVectorStore
        """
        # This is an ugly hack to auto-detect the distribution of the connection
        # it will be replaced once we have settled for the proper way for the user to declare
        # and configure OpenSearch connections
        from elasticsearch import UnsupportedProductError
        try:
            from dataiku.core.vector_stores.elasticsearch_vector_store import ElasticSearchVectorStore
            vs = ElasticSearchVectorStore(kb, exec_folder, connection_info_retriever)
            logger.info("Using ElasticSearch vector store implementation.")
            return vs
        except UnsupportedProductError as e:
            if VectorStoreFactory._is_opensearch_error_body(e.body):
                from dataiku.core.vector_stores.opensearch_vector_store import OpenSearchVectorStore
                logger.info("Using OpenSearch vector store implementation.")
                return OpenSearchVectorStore(kb, exec_folder, connection_info_retriever)
            else:
                raise UnsupportedProductError("You need to use an ElasticSearch version >= 7.14.0", e.meta, e.body)
