from typing import Optional, Dict, Any, Sequence, TYPE_CHECKING, List

from dataiku.langchain.base_rag_handler import BaseRagHandler, RetrievalSource
from dataiku.langchain.content_part_types import MultipartContent, TextPart
from dataiku.langchain.metadata_generator import DKU_MULTIMODAL_CONTENT, DKU_DOCUMENT_INFO
from dataiku.langchain.multimodal_content import from_doc
from dataiku.llm.types import BaseVectorStoreQuerySettings, RetrievableKnowledge

# Importing the classes that were moved out of this file for backwards compatibility with Answers plugin <=2.4.1
from dataiku.langchain.content_part_types import ImageRefPart, ImageRetrieval, InlineImagePart  # noqa: F401

if TYPE_CHECKING:
    from langchain_core.documents import Document

# Metadata keys that DocumentHandler.get_metadata_columns removes from the returned metadata
# when the Azure metadata cleanup is enabled (see _should_cleanup_for_azure).
AZURE_CAPTIONS_KEY = "captions"
AZURE_ANSWERS_KEY = "answers"

def _should_cleanup_for_azure(retrievable_knowledge: RetrievableKnowledge, query_settings: BaseVectorStoreQuerySettings) -> bool:
    """
    Tell whether the Azure-specific metadata cleanup applies to a retrieval.

    The cleanup applies only when all three conditions hold:
    - the knowledge bank's ``vectorStoreType`` is ``AZURE_AI_SEARCH``
    - the query's ``searchType`` is ``HYBRID``
    - the query's ``useAdvancedReranking`` flag is truthy (a missing flag counts as disabled)

    :param retrievable_knowledge: The knowledge bank settings
    :param query_settings: The vector store query settings
    :returns: True when the cleanup must be performed, False otherwise
    """
    # The conditions are evaluated left to right and short-circuit on the first failure,
    # so the query settings are only read for an Azure AI Search vector store.
    return bool(
        retrievable_knowledge.get("vectorStoreType", "") == "AZURE_AI_SEARCH"
        and query_settings.get("searchType") == "HYBRID"
        and query_settings.get("useAdvancedReranking", False)
    )

class DocumentHandler(BaseRagHandler):
    """
    RAG handler exposing the metadata and the multipart content of retrieved documents.
    """

    def __init__(
        self,
        retrievable_knowledge: RetrievableKnowledge,
        query_settings: BaseVectorStoreQuerySettings,
        retrieval_source: RetrievalSource = RetrievalSource.EMBEDDING,
        retrieval_columns: Optional[List[str]] = None,
    ):
        """
        :param retrievable_knowledge: The knowledge bank settings; ``sourceIdColumn``, ``managedFolderId``
            and ``vectorStoreType`` are read from it
        :param query_settings: The vector store query settings, used to decide on the Azure metadata cleanup
        :param retrieval_source: Where the content of a retrieved document is taken from
        :param retrieval_columns: Forwarded as is to the base handler
        """
        self.retrieval_source: RetrievalSource = retrieval_source
        self.source_id_column: Optional[str] = retrievable_knowledge.get("sourceIdColumn")
        self.full_folder_id: Optional[str] = retrievable_knowledge.get("managedFolderId")
        self.metadata_cleanup_for_azure: bool = _should_cleanup_for_azure(retrievable_knowledge, query_settings)
        # NOTE(review): the attributes above are set before the base initializer runs,
        # presumably because it may rely on them -- confirm before reordering.
        super().__init__(retrieval_columns)

    def get_metadata_columns(self, doc: "Document") -> Dict[str, Any]:
        """
        Return the metadata of a retrieved document, stripped of the columns that must not be exposed.

        The following keys are removed:
        - the source ID column, when one is configured
        - the Azure captions and answers keys, when the Azure metadata cleanup is enabled
        - every key starting with ``DKU_`` (e.g. DKU_SECURITY_TOKENS), except ``DKU_DOCUMENT_INFO`` and,
          when the retrieval source is not MULTIMODAL, ``DKU_MULTIMODAL_CONTENT``

        :param doc: A retrieved document from the knowledge bank
        :returns: A new dict with the remaining metadata entries
        """
        # Keys dropped regardless of their prefix
        excluded_keys: List[Any] = []
        if self.source_id_column is not None:
            excluded_keys.append(self.source_id_column)
        if self.metadata_cleanup_for_azure:
            excluded_keys.extend((AZURE_CAPTIONS_KEY, AZURE_ANSWERS_KEY))

        # Dataiku internal keys that survive the "DKU_" prefix filtering
        preserved_internal_keys = [DKU_DOCUMENT_INFO]
        if self.retrieval_source != RetrievalSource.MULTIMODAL:
            preserved_internal_keys.append(DKU_MULTIMODAL_CONTENT)

        def is_dropped(key) -> bool:
            if key.startswith("DKU_") and key not in preserved_internal_keys:
                return True
            # A preserved internal key is still dropped when it is explicitly excluded
            return key in excluded_keys

        return {key: value for key, value in doc.metadata.items() if not is_dropped(key)}

    def get_multipart_content(
        self,
        doc: "Document",
        index: Optional[int] = None,
        image_retrieval: ImageRetrieval = ImageRetrieval.IMAGE_INLINE,
    ) -> Sequence[MultipartContent]:
        """
        Build the multipart contents of a retrieved document according to the retrieval source.

        With the EMBEDDING retrieval source, the result is a list holding a single text part.
        With the MULTIMODAL retrieval source, the parts depend on the type found in the multimodal column:
        - text documents give a list with a single text part
        - images documents give a list of image parts
        - captioned_image documents give a list of captioned image parts

        When image_retrieval is `IMAGE_INLINE`, images are downloaded from the given folder.

        When the multimodal metadata is inconsistent, the fallback is a list with a single text part
        built from the embedded data.

        :param doc: A retrieved document from the knowledge bank
        :param index: The index of the document in the retrieved document collection
        :param image_retrieval: The image retrieval mode to use when retrieving image parts
        :returns: A list of multimodal parts
        """
        if self.retrieval_source == RetrievalSource.EMBEDDING:
            embedded_text = self.get_content(doc)
            return [TextPart(index, embedded_text)]

        # From here on, the retrieval source is MULTIMODAL
        parsed_content = from_doc(doc)
        if parsed_content is not None:
            return parsed_content.get_parts(index, image_retrieval, self.full_folder_id)

        # No usable multimodal content: fall back on the page content of the document
        return [TextPart(index, doc.page_content)]
