import enum
import json
from typing import Optional, List

from langchain_core.documents import Document

from dataiku.langchain.base_rag_handler import BaseRagHandler
from dataiku.langchain.content_part_types import ImageRetrieval, ImageRefPart, CaptionedImageRefPart
from dataiku.langchain.document_handler import RetrievalSource
from dataiku.langchain.metadata_generator import DKU_DOCUMENT_INFO
from dataiku.langchain.multimodal_content import from_doc
from dataiku.llm.types import FileRef, SourceItem, SourcesSettings, ImageRef

class SourcesType(enum.Enum):
    """Kind of a source item; the member's value is stored under the item's ``type`` key."""
    # Used when no file reference could be derived from the document's metadata.
    SIMPLE_DOCUMENT = "SIMPLE_DOCUMENT"
    # Used when the source item also carries a ``fileRef`` entry.
    FILE_BASED_DOCUMENT = "FILE_BASED_DOCUMENT"

# Maps a sources-settings key to the source-item key it populates. The setting's
# value names a document metadata field; when that field exists on the document,
# its value is copied into the source item under the mapped key.
SETTINGS_KEY_TO_SOURCE_KEY: dict[str, str] = {
    "titleMetadata": "title",
    "urlMetadata": "url",
    "thumbnailURLMetadata": "thumbnailURL",
}

# Maps a snippet format name to the source-item key under which the snippet
# data is stored.
SNIPPET_FORMAT_TO_SOURCE_KEY: dict[str, str] = {
    "TEXT": "textSnippet",
    "MARKDOWN": "markdownSnippet",
    "HTML": "htmlSnippet",
    "JSON": "jsonSnippet"
}

def _get_file_based_document_info(doc: "Document") -> Optional[FileRef]:
    """Extract the source-file reference recorded in a document's metadata.

    The ``DKU_DOCUMENT_INFO`` metadata entry, when present, is parsed as JSON.
    A ``FileRef`` is produced only when its ``source_file`` object provides both
    a folder reference and a path; otherwise the document is treated as having
    no file reference.

    :param doc: document whose ``metadata`` mapping is inspected.
    :return: the ``FileRef`` for the document, or ``None`` when the metadata
        entry is absent or does not describe a complete source file.
    :raises ValueError: if the metadata entry cannot be parsed as JSON.
    """
    if DKU_DOCUMENT_INFO not in doc.metadata:
        return None

    raw_document_info = doc.metadata[DKU_DOCUMENT_INFO]
    try:
        document_info = json.loads(raw_document_info)
    except (TypeError, ValueError) as e:
        # json.loads raises JSONDecodeError (a ValueError subclass) for malformed
        # text and TypeError for input that is not str/bytes. Both are reported as
        # a single-message ValueError that keeps the original error as its cause.
        raise ValueError(f"Metadata {raw_document_info} is not a valid json: {e}") from e

    # Valid JSON that is not an object (e.g. ``null`` or a list) cannot describe
    # a source file, so there is nothing to reference.
    if not isinstance(document_info, dict):
        return None

    source_file = document_info.get("source_file")
    if not isinstance(source_file, dict):
        return None

    # "folder_full_id" is deprecated but still read so that existing KBs keep working.
    folder_ref = source_file.get("folder_ref", source_file.get("folder_full_id"))
    path = source_file.get("path")

    # A file reference is only meaningful when both the folder and the path are
    # known; otherwise it is omitted entirely.
    if folder_ref is None or path is None:
        return None

    return FileRef(
        folderId=folder_ref,
        path=path,
        pageRange=document_info.get("page_range"),
        sectionOutline=document_info.get("section_outline"),
    )

class SourcesHandler(BaseRagHandler):
    """Builds ``SourceItem`` payloads describing retrieved documents.

    The content of each source item (metadata, title/url/thumbnail, snippet,
    image references) is driven by the sources settings and by the way the
    documents were retrieved.
    """

    def __init__(self,
                 sources_settings: SourcesSettings,
                 full_folder_id: Optional[str] = None,
                 retrieval_source: RetrievalSource = RetrievalSource.EMBEDDING,
                 retrieval_columns: Optional[List[str]] = None):
        """Store the settings used to build source items.

        :param sources_settings: settings read with ``.get`` for the keys
            ``metadataInSources``, ``titleMetadata``, ``urlMetadata``,
            ``thumbnailURLMetadata``, ``snippetMetadata`` and ``snippetFormat``.
        :param full_folder_id: folder id passed on when resolving image parts;
            must be set whenever a document carries multimodal content.
        :param retrieval_source: how documents were retrieved; selects the
            snippet strategy.
        :param retrieval_columns: forwarded to ``BaseRagHandler.__init__``
            (presumably exposed as ``self.retrieval_columns`` -- confirm in the
            base class).
        """
        self.sources_settings: SourcesSettings = sources_settings
        self.retrieval_source: RetrievalSource = retrieval_source
        self.full_folder_id = full_folder_id
        super().__init__(retrieval_columns)

    def build_role_based_source_from(self, doc: "Document") -> SourceItem:
        """Build the source item describing ``doc``.

        :param doc: retrieved document, read through ``metadata`` and
            ``page_content``.
        :return: a ``SourceItem`` with a ``type``, a snippet stored under a
            format-specific key and, depending on the settings and the
            document, ``metadata``, ``title``/``url``/``thumbnailURL``,
            ``fileRef`` and ``imageRefs`` entries.
        :raises ValueError: if the document-info metadata is not valid JSON.
        :raises KeyError: if the configured snippet format is not a key of
            ``SNIPPET_FORMAT_TO_SOURCE_KEY``.
        """
        source_item: SourceItem = SourceItem()

        # Only the metadata fields explicitly selected in the settings are exposed.
        selected_metadata = self.sources_settings.get("metadataInSources")
        if selected_metadata is not None:
            source_item["metadata"] = {k: v for k, v in doc.metadata.items() if k in selected_metadata}

        # Each of these settings names the metadata field to copy into the source item.
        for setting_key, source_key in SETTINGS_KEY_TO_SOURCE_KEY.items():
            metadata_field = self.sources_settings.get(setting_key)
            if metadata_field and metadata_field in doc.metadata:
                source_item[source_key] = doc.metadata[metadata_field]

        file_ref = _get_file_based_document_info(doc)
        if file_ref is not None:
            source_item["fileRef"] = file_ref
            source_item["type"] = SourcesType.FILE_BASED_DOCUMENT.value
        else:
            source_item["type"] = SourcesType.SIMPLE_DOCUMENT.value

        multimodal_content = from_doc(doc)
        snippet_metadata = self.sources_settings.get("snippetMetadata")
        # ``or`` also covers a setting that is present but null/empty, which
        # would otherwise fail the format lookup below.
        snippet_format = self.sources_settings.get("snippetFormat") or "TEXT"

        if snippet_metadata is not None and snippet_metadata in doc.metadata:
            # An explicitly configured snippet metadata field takes precedence.
            snippet_data = doc.metadata[snippet_metadata]
        elif self.retrieval_source == RetrievalSource.MULTIMODAL and multimodal_content:
            if multimodal_content.type == "multipart":
                # NOTE(review): assumes every part exposes a string ``text`` -- confirm.
                text_snippet = "\n".join([part.text for part in multimodal_content.content])
            else:
                text_snippet = multimodal_content.text
            # A missing or empty text is reported as "no snippet".
            snippet_data = text_snippet or None
        elif self.retrieval_source == RetrievalSource.EMBEDDING:
            snippet_data = self.get_content(doc)
            assert self.retrieval_columns is not None
            if len(self.retrieval_columns) > 1:
                # With several retrieval columns the snippet is stored as JSON,
                # whatever format was configured.
                snippet_format = "JSON"
        else:
            # Fallback to the page content of the document.
            snippet_data = doc.page_content
        source_item[SNIPPET_FORMAT_TO_SOURCE_KEY[snippet_format]] = snippet_data

        if multimodal_content:
            # In case of images we can have multiple parts all the same type
            assert self.full_folder_id is not None
            image_parts = multimodal_content.get_parts(None, ImageRetrieval.IMAGE_REF, self.full_folder_id)
            # isinstance (rather than an exact type match) also keeps subclasses
            # of the image reference parts.
            image_refs: List[ImageRef] = [
                {"folderId": image_part.full_folder_id, "path": image_part.path}
                for image_part in image_parts
                if isinstance(image_part, (ImageRefPart, CaptionedImageRefPart))
            ]
            source_item["imageRefs"] = image_refs

        return source_item
