# encoding: utf-8
import json
import typing
from contextlib import contextmanager
from typing import Optional

from dataiku.base import remoterun
from dataiku.core import intercom, default_project_key

if typing.TYPE_CHECKING:
    from dataiku.llm.types import TrustedObject

class KnowledgeBank(object):
    """
    This is a handle to interact with a Dataiku Knowledge Bank flow object

    :param id: identifier of the knowledge bank, either a short id (no ``.``) or a
        full name of the form ``PROJECT_KEY.short_id``
    :type id: str
    :param project_key: project key used to qualify a short id. When a full name is given,
        it must match the project key embedded in the full name. Defaults to the default project key.
    :type project_key: Optional[str]
    :param context_project_key: the context project key (see :meth:`set_context_project_key`)
    :type context_project_key: Optional[str]

    :raises ValueError: if ``project_key`` disagrees with the project key of a full name
    :raises Exception: if a short id is given and no project key could be resolved
    """

    def __init__(self, id, project_key=None, context_project_key=None):
        self.id = id
        if "." not in id:
            try:
                resolved_project_key = project_key or default_project_key()
                # The concatenation is kept inside the try: it also rejects a default
                # project key that resolved to a non-string value such as None
                qualified_name = resolved_project_key + "." + id
            except Exception as e:
                raise Exception("Knowledge bank %s is specified with a relative name, "
                                "but no default project was found. Please use complete name" % self.id) from e
            self.project_key = resolved_project_key
            self.short_name = id
            self.name = qualified_name
        else:
            # user gave a full name
            (self.project_key, self.short_name) = self.id.split(".", 1)
            if project_key is not None and self.project_key != project_key:
                raise ValueError("Project key %s incompatible with fullname Knowledge bank %s." % (project_key, id))
            # Keep `name` available whichever way the handle was built
            self.name = self.id
        self.full_name = "%s.%s" % (self.project_key, self.short_name)
        self.location_info = None
        self.cols = None
        # Cache for the backend description of the knowledge bank, filled by _get()
        self.rk = None

        self.set_context_project_key(context_project_key)

    def set_context_project_key(self, context_project_key):
        """
        Set the context project key to use to report calls
        to the embedding LLM associated with this knowledge bank.

        :param context_project_key: the context project key
        :type context_project_key: str
        """
        self._context_project_key = context_project_key

    def _get_resolved_context_project_key(self):
        """
        Resolve the context project key to use.

        Order of precedence: the explicitly set context project key, then the
        ``DKU_CURRENT_PROJECT_KEY`` remote-run environment variable, then the
        project key of this knowledge bank.

        :rtype: str
        """
        if self._context_project_key:
            return self._context_project_key

        if remoterun.has_env_var("DKU_CURRENT_PROJECT_KEY"):
            return remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY")
        return self.project_key

    def _get(self):
        """
        Fetch the description of this knowledge bank from the backend.

        The result of the first call is cached on the handle and returned by later calls.
        """
        if self.rk is None:
            self.rk = intercom.backend_json_call("knowledge-bank/get", data={
                "contextProjectKey": self._context_project_key,
                "knowledgeBankFullId": self.full_name
            })
        return self.rk

    def get_current_version(self, trusted_object: Optional["TrustedObject"]=None):
        """
        Gets the current version for this knowledge bank.

        :param trusted_object: the optional trusted object using the kb
        :type trusted_object: Optional["TrustedObject"]
        :rtype: str
        """
        api_resp = intercom.backend_json_call("knowledge-bank/get-current-version", data={
            "knowledgeBankFullId": self.full_name,
            "contextProjectKey": self._context_project_key,
            # The trusted object is sent JSON-encoded; a falsy value is sent as None
            "trustedObject": json.dumps(trusted_object) if trusted_object else None
        })

        return api_resp["version"]

    def load_into_isolated_folder(self, version=None, trusted_object: Optional["TrustedObject"]=None):
        """
        Loads the vector store files on disk.
        For local vector stores, downloads metadata files as well as data files.
        For remote vector stores, only downloads metadata files.
        By default, the latest version is loaded on disk.

        .. note::
            Copies are isolated. A new call will create a new folder.

        :param version: the knowledge bank version
        :type version: Optional[str]
        :param trusted_object: the optional trusted object using the kb
        :type trusted_object: Optional["TrustedObject"]

        :return: :class:`dataiku.core.vector_stores.lifecycle.isolated_folder.VectorStoreIsolatedFolder`
        """
        if version is None:
            version = self.get_current_version(trusted_object)

        from dataiku.core.vector_stores.lifecycle.isolated_folder import load_into_isolated_folder
        return load_into_isolated_folder(
            self.project_key, self.short_name, version,
            use_latest_settings=False,
            trusted_object=trusted_object
        )

    @contextmanager
    def get_writer(self):
        """
        Gets a writer on the latest vector store files on disk.
        For local vector stores, downloads metadata files as well as data files.
        For remote vector stores, only downloads metadata files.

        The vector store files are automatically uploaded when the context
        manager is closed.

        .. note::
            Each call creates an isolated writer which works on its own folder.

        :return: :class:`dataiku.core.vector_stores.data.writer.VectorStoreWriter`
        """
        from dataiku.core.vector_stores.lifecycle.isolated_folder import load_into_isolated_folder
        current_version = self.get_current_version()
        isolated_folder = load_into_isolated_folder(
            self.project_key, self.short_name, current_version,
            use_latest_settings=True
        )

        try:
            writer = isolated_folder.get_writer()
            yield writer
            # Only reached when the with-body completed without raising
            writer.save()

        finally:
            # The isolated folder is removed whether or not the with-body raised
            isolated_folder.remove()

    def as_langchain_retriever(self, search_type="similarity", search_kwargs=None, vectorstore_kwargs=None, **retriever_kwargs):
        """
        Get the current version of this knowledge bank as a Langchain Retriever object.

        :param search_type: forwarded as ``search_type`` to the vector store's ``as_retriever``
        :type search_type: str
        :param search_kwargs: forwarded as ``search_kwargs`` to ``as_retriever``; defaults to an empty dict
        :type search_kwargs: Optional[dict]
        :param vectorstore_kwargs: keyword arguments passed to :meth:`as_langchain_vectorstore`;
            defaults to an empty dict
        :type vectorstore_kwargs: Optional[dict]
        :param retriever_kwargs: additional keyword arguments forwarded to ``as_retriever``

        :rtype: :class:`langchain_core.vectorstores.VectorStoreRetriever`
        """
        search_kwargs = search_kwargs or {}
        vectorstore_kwargs = vectorstore_kwargs or {}
        langchain_vectorstore = self.as_langchain_vectorstore(**vectorstore_kwargs)
        return langchain_vectorstore.as_retriever(
            search_type=search_type, search_kwargs=search_kwargs,
            **retriever_kwargs)

    def as_langchain_vectorstore(self, **vectorstore_kwargs):
        """
        Get the current version of this knowledge bank as a Langchain Vectorstore object.

        :param vectorstore_kwargs: keyword arguments handed to the vector stores cache. An optional
            ``context_project_key`` entry overrides the context project key resolved from this handle.

        :rtype: :class:`langchain_core.vectorstores.VectorStore`
        """
        from dataiku.core.langchain_vector_stores_cache import LANGCHAIN_VECTOR_STORES_CACHE
        # The public "context_project_key" kwarg is forwarded under the "_context_project_key" key
        vectorstore_kwargs["_context_project_key"] = vectorstore_kwargs.pop("context_project_key", self._get_resolved_context_project_key())
        return LANGCHAIN_VECTOR_STORES_CACHE.get_or_create(
            self.project_key, self.short_name,
            self.get_current_version(),
            vectorstore_kwargs
        )

    def get_multipart_context(self, docs):
        """
        Convert retrieved documents from the vector store to a multipart context.
        The multipart context contains the parts that can be added to a completion query

        :param docs: A list of retrieved documents from the langchain retriever
        :type docs: List[Document]

        :raises Exception: If the knowledge bank does not contain multimodal content

        :returns: A multipart context object composed by a list of parts containing text or images
        :rtype: :class:`MultipartContext`
        """
        from dataiku.langchain.document_handler import DocumentHandler, RetrievalSource

        # _get() already caches its result on self.rk
        kb_description = self._get()

        multimodal_column = kb_description.get("multimodalColumn")
        if not multimodal_column:
            raise Exception("Knowledge bank {id} does not contain multimodal content".format(id=self.id))

        multipart_context = MultipartContext()

        document_handler = DocumentHandler(kb_description, {}, retrieval_source=RetrievalSource.MULTIMODAL)

        for index, document in enumerate(docs):
            parts = document_handler.get_multipart_content(document, index=index)
            for part in parts:
                multipart_context.append(part)
        return multipart_context


class MultipartContext:
    """
    An ordered collection of text or image parts that can be added to a completion query
    """

    def __init__(self):
        # Parts are kept in insertion order
        self.parts = []

    def append(self, part):
        """
        Add a part at the end of the accumulated parts

        :param part: Part of a completion query
        :type part: :class:`MultipartContent`
        """
        self.parts.append(part)

    def add_to_completion_query(self, completion, role="user"):
        """
        Create a new multipart-message on the completion query and fill it with the accumulated parts

        Parts whose type is neither ``TEXT`` nor ``IMAGE_INLINE`` are not added to the message.

        :param completion: the completion query to be edited
        :type completion: :class:`DSSLLMCompletionsQuerySingleQuery`

        :param str role: The message role. Use ``system`` to set the LLM behavior, ``assistant`` to store predefined
          responses, ``user`` to provide requests or comments for the LLM to answer to. Defaults to ``user``.
        """
        message = completion.new_multipart_message(role=role)
        for item in self.parts:
            kind = item.type
            if kind == "TEXT":
                message.with_text(item.text)
                continue
            if kind == "IMAGE_INLINE":
                message.with_inline_image(item.inline_image, mime_type=item.image_mime_type)
        message.add()

    def is_text_only(self):
        """
        :returns: True if all the accumulated parts are text parts (including when there are no parts),
            False otherwise
        :rtype: bool
        """
        return not any(item.type != "TEXT" for item in self.parts)

    def to_text(self):
        """
        :returns: the accumulated text parts, each followed by a blank line (other parts are skipped)
        :rtype: str
        """
        return "".join(item.text + "\n\n" for item in self.parts if item.type == "TEXT")
