import asyncio
import base64
import io
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import List, Awaitable, Optional

import torch
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
from docling.backend.mspowerpoint_backend import MsPowerpointDocumentBackend
from docling.backend.msword_backend import MsWordDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions, TableFormerMode, AcceleratorOptions, AcceleratorDevice, EasyOcrOptions, \
    PaginatedPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, PowerpointFormatOption, HTMLFormatOption
from docling.models.layout_model import LayoutModel
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.simple_pipeline import SimplePipeline
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
from docling_core.types import DoclingDocument
from docling_core.types.doc import DocItemLabel, ProvenanceItem, ImageRef
from docling_core.types.io import DocumentStream
from pydantic import AnyUrl

from dataiku.base.batcher import Batcher
from dataiku.llm.docextraction.ocr import get_ocr_config, process_images_with_ocr
from dataiku.doctor.utils.gpu_execution import TorchGpuCapability

# Module logger. It uses a fixed name rather than __name__ (presumably so that log configuration can target it — confirm).
logger = logging.getLogger("docling_extraction")

@dataclass
class DoclingRequest:
    """A single document extraction request, as built by DoclingExtractorPipeline.process_document."""
    file_name: str  # original file name; its extension selects the extraction path
    document_content: str  # base64-encoded file content
    max_section_depth: int  # items nested deeper than this are not structured further (merged into the surrounding text)
    do_ocr: bool  # True when the requested image handling mode is "OCR"
    ocr_engine: str  # OCR engine name ("AUTO" when not specified)
    lang: List[str]  # OCR languages (possibly empty)

@dataclass
class DoclingResponse:
    """Result of one extraction: `resp` holds the serialized tree on success, `error` holds a message (or traceback) on failure."""
    ok: bool
    resp: Optional[dict]  # None when ok is False
    error: Optional[str]  # None when ok is True

    def to_dict(self):
        """Return the response as a plain dict (fields are recursively copied by dataclasses.asdict)."""
        return asdict(self)

class DocumentNode:
    """
    A tree node structure to map docling converted structure to Dataiku structured extraction API: com.dataiku.dip.docextraction.StructuredContent

    Subclasses set `node_type` and extend `to_dict` with their type-specific fields.
    """
    node_type: Optional[str] = None  # serialized as "type"; overridden by subclasses
    level: Optional[int]

    def __init__(self, node_id: str, children: List["DocumentNode"], label, content: str = "", level: Optional[int] = None,
                 page_provenance: Optional[List["ProvenanceItem"]] = None):
        self.node_id = node_id
        self.children = children  # List of children DocumentNode
        self.label = label  # docling label see DocItemLabel, GroupLabel
        self.content = content  # Content, text or base64 for images
        self.level = level  # Level within the hierarchy
        self.page_provenance = page_provenance  # page provenance from the original document

    def to_dict(self):
        """
        Serialize the node and, recursively, its children.

        :return: dict with "type", "content" (list of serialized children) and, when page provenance is non-empty, "pages" (the page_no of each
                 provenance entry, in order).
        """
        res = {
            "type": self.node_type,
            "content": [child.to_dict() for child in self.children],
        }
        if self.page_provenance:
            # NOTE(review): one entry per provenance item, so merged paragraphs from the same page repeat the page number — confirm intended.
            res["pages"] = [page.page_no for page in self.page_provenance]
        return res

class ImageNode(DocumentNode):
    """
    Picture node. Image metadata (mime type, size, dpi) is copied from the docling image reference when one is given; the node content is the
    OCR text of the picture and is serialized as its "description".
    """
    node_type = "image"

    def __init__(self, node_id, children, label, content="", level=None, page_provenance=None, image=None):
        super().__init__(node_id, children, label, content, level, page_provenance)
        if image is None:
            self.mime_type = self.height = self.width = self.resolution = None
        else:
            self.mime_type = image.mimetype
            self.height = image.size.height
            self.width = image.size.width
            self.resolution = image.dpi

    def to_dict(self):
        serialized = super().to_dict()
        # The mime type is always emitted (possibly None); the other fields only when truthy
        serialized["mimeType"] = self.mime_type
        optional_fields = (
            ("height", self.height),
            ("width", self.width),
            ("resolution", self.resolution),
            ("description", self.content),
        )
        for key, value in optional_fields:
            if value:
                serialized[key] = value
        return serialized

class TextNode(DocumentNode):
    """Text node; its content is serialized under "text"."""
    node_type = "text"

    def to_dict(self):
        return {**super().to_dict(), "text": self.content}

class TableNode(DocumentNode):
    """Table node; its content (the table rendered as text, e.g. markdown) is serialized under "text"."""
    node_type = "table"

    def to_dict(self):
        serialized = super().to_dict()
        serialized.update(text=self.content)
        return serialized

class SectionNode(DocumentNode):
    """Section node; its content is the section title, serialized with the section level."""
    node_type = "section"

    def to_dict(self):
        return {**super().to_dict(), "title": self.content, "level": self.level}

class RootNode(DocumentNode):
    """Top of the extracted tree, serialized with type "document"."""
    node_type: str = "document"

class DoclingExtractorPipeline:
    """
    Structured extraction of documents with docling.

    Requests are queued in a Batcher and processed by batches of compatible requests (same extension and same OCR settings). Each batch runs on a
    thread pool so that the blocking conversion does not run on the event loop.
    """
    batcher: Batcher[DoclingRequest, List[dict]]
    image_formats: List[str] = ["png", "jpg", "jpeg"]
    supported_formats: List[str] = ["pdf", "docx", "pptx", "html"] + image_formats

    def __init__(self):
        self.executor = ThreadPoolExecutor()
        self.batcher = Batcher[DoclingRequest, List[dict]](
            batch_size=4,
            timeout=1,
            process_batch=self._process_batch_async,
            group_by=self._batch_key,
        )

    @staticmethod
    def _batch_key(request: DoclingRequest) -> int:
        """
        Grouping key of a request for the batcher.

        docling_batch_structured_extract configures a whole batch from requests[0]: extension, OCR flag, OCR engine and OCR languages. All of these
        must therefore be part of the key. The extension is lower-cased, as it is when the batch is processed.
        """
        extension = os.path.splitext(request.file_name)[1].lower()
        return hash((extension, request.do_ocr, request.ocr_engine, tuple(request.lang)))

    def _run_batch_sync(self, requests: List[DoclingRequest]) -> List[dict]:
        logger.info("Processing a batch of %s document extraction requests", len(requests))
        return self.docling_batch_structured_extract(requests)

    async def _process_batch_async(self, requests: List[DoclingRequest]) -> List[dict]:
        # The conversion is blocking, so it runs on the thread pool rather than on the event loop
        return await asyncio.get_running_loop().run_in_executor(self.executor, self._run_batch_sync, requests)

    async def process_document(self, process_document_command):
        """
        Queue one document for extraction and wait for its result.

        :param process_document_command: dict with "fileName", "documentContent" (base64) and "maxSectionDepth". It may also contain
                                         "imageHandlingMode" (OCR is applied when it is "OCR") and "ocrSettings" ("ocrEngine", "ocrLanguages").
        :return: the result produced by the batcher for this request (presumably the DoclingResponse dict built by docling_batch_structured_extract)
        """
        # `or` also covers explicit nulls in the command
        ocr_settings = process_document_command.get("ocrSettings") or {}
        request = DoclingRequest(
            file_name=process_document_command["fileName"],
            document_content=process_document_command["documentContent"],
            max_section_depth=process_document_command["maxSectionDepth"],
            do_ocr=process_document_command.get("imageHandlingMode", "IGNORE") == "OCR",
            ocr_engine=ocr_settings.get("ocrEngine", "AUTO"),
            lang=ocr_settings.get("ocrLanguages") or [],
        )
        return await self.batcher.process(request)

    @staticmethod
    def find_docling_pdf_models_in_resources(mode: TableFormerMode) -> Optional[Path]:
        """
        Look in the resources folder whether the docling PDF models are available. Returns the path to the directory containing the models (that should be
        passed to docling) or None if one or more models is missing
        """
        models_dir = os.environ.get("DOCUMENT_EXTRACTION_MODELS")
        if models_dir is None:
            return None
        models_path = Path(models_dir).expanduser()
        logger.info(u"Pdf extraction requires layout and tableformer model from docling. Checking if those models are already in the ressources folder.")
        layout_model_path = (models_path / LayoutModel._model_repo_folder / LayoutModel._model_path)
        logger.info(u"Searching for layout model in {}".format(layout_model_path))
        tableformer_path = (models_path / TableStructureModel._model_repo_folder / TableStructureModel._model_path / mode.value)
        logger.info(u"Searching for tableformer model in {}".format(tableformer_path))
        if layout_model_path.exists() and tableformer_path.exists():
            logger.info("Docling models for pdf extraction were found in the resources folder of the code env. Will use them for extraction")
            return models_path
        return None

    @staticmethod
    def detect_accelerator_device() -> AcceleratorDevice:
        """Return CUDA when a GPU is available, else MPS when built and available (macOS), else CPU."""
        is_gpu_available = TorchGpuCapability.is_gpu_available()
        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
        device = AcceleratorDevice.CPU
        if has_mps:
            device = AcceleratorDevice.MPS
        if is_gpu_available:
            device = AcceleratorDevice.CUDA

        if device == AcceleratorDevice.CPU:
            logger.info("No accelerator nor GPU was detected in the environment")
        else:
            logger.info(f"Detected the following accelerator option: {device.value}. Will be used for extraction")
        return device

    def docling_batch_structured_extract(self, requests: List[DoclingRequest]) -> List[dict]:
        """
        Extract the structure of a batch of documents.

        `requests` share the same extension and the same OCR settings (see _batch_key), so a single OCR configuration and a single docling converter
        are built per batch.
        - Images (png/jpg/jpeg) do not go through docling, because docling only supports OCR for PDFs. Raw OCR is applied instead.
        - PDFs use docling's standard PDF pipeline, with OCR when requested and available.
        - docx/pptx/html use docling's simple pipeline. When OCR is requested, picture images are generated and OCR'd afterwards in
          build_tree_with_outline.

        :return: one DoclingResponse dict per request, in the same order. A failing document yields an error response and does not fail the batch.
        :raises ValueError: for image files when no OCR configuration is available (OCR not requested or engine unavailable)
        """
        if not requests:
            return []

        accelerator_options = AcceleratorOptions(device=self.detect_accelerator_device())
        use_accelerator = accelerator_options.device != AcceleratorDevice.CPU
        extension = os.path.splitext(requests[0].file_name)[1].lower().lstrip(".")
        needs_ocr = any(req.do_ocr for req in requests)
        # May be None when the requested OCR engine is not available
        ocr_options = get_ocr_config(requests[0].lang, requests[0].ocr_engine, use_accelerator) if needs_ocr else None

        if extension in self.image_formats:
            return self._extract_images_with_ocr(requests, ocr_options)

        if extension == "pdf":
            doc_converter = self._build_pdf_converter(needs_ocr, ocr_options, accelerator_options)
            apply_additional_ocr_on_generated_images = False
        else:
            doc_converter, apply_additional_ocr_on_generated_images = self._build_office_converter(needs_ocr, ocr_options, accelerator_options)

        res = []
        for request in requests:
            try:
                # Decoding is done per document so that a malformed payload only fails its own document
                document = DocumentStream(name=request.file_name, stream=io.BytesIO(base64.b64decode(request.document_content)))
                conversion_result = doc_converter.convert(document)
                tree = build_tree_with_outline(conversion_result.document, request.max_section_depth, True, needs_ocr,
                                               apply_additional_ocr_on_generated_images, ocr_options, use_accelerator)
                response = DoclingResponse(True, tree.to_dict(), None).to_dict()
            except Exception as e:
                logger.exception("An error occurred during docling processing")
                res.append(DoclingResponse(False, None, ''.join(traceback.format_exception(type(e), e, e.__traceback__))).to_dict())
            else:
                res.append(response)
                logger.info(build_message_log_for_document(request.file_name, "Done processing document with docling"))
        return res

    @staticmethod
    def _extract_images_with_ocr(requests: List[DoclingRequest], ocr_options) -> List[dict]:
        """Extract the text of image files with raw OCR; one response dict per request, in order."""
        if ocr_options is None:
            raise ValueError("OCR options must be provided for image extraction")
        text_results = process_images_with_ocr(ocr_options, [(idx, request.document_content) for idx, request in enumerate(requests)])
        res = []
        for idx, request in enumerate(requests):
            if idx in text_results:
                res.append(DoclingResponse(True, ImageNode(request.file_name, [], "image", text_results[idx], 0).to_dict(), None).to_dict())
            else:
                # No OCR result was returned for this image: something went wrong with the OCR processing
                res.append(DoclingResponse(False, None, f"Error processing image {request.file_name} with OCR").to_dict())
        return res

    def _build_pdf_converter(self, needs_ocr: bool, ocr_options, accelerator_options: AcceleratorOptions) -> DocumentConverter:
        """Build the docling converter for PDFs (layout + accurate table structure, OCR when requested and available)."""
        if needs_ocr:
            if ocr_options is None:
                logger.warning("The provided OCR engine is not available, deactivating OCR for extraction")
            else:
                # Apply OCR on each image if the pdf is not fully scanned. This ensures good results on embedded images.
                ocr_options.force_full_page_ocr = False
        pipeline_options = PdfPipelineOptions(do_table_structure=True, do_ocr=(needs_ocr and ocr_options is not None),
                                              ocr_options=(ocr_options if ocr_options is not None else EasyOcrOptions()))  # Can't set ocr_options to None
        # Pdf extraction with docling requires the IBM models tableformer and layout: https://huggingface.co/ds4sd/docling-models
        # If they are available in the resources folder they are used; otherwise docling's regular resolution applies (hf cache, then download).
        pipeline_options.table_structure_options.mode = TableFormerMode.ACCURATE
        pipeline_options.artifacts_path = self.find_docling_pdf_models_in_resources(TableFormerMode.ACCURATE)
        pipeline_options.table_structure_options.do_cell_matching = True
        pipeline_options.accelerator_options = accelerator_options
        return DocumentConverter(format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_options=pipeline_options,
                pipeline_cls=StandardPdfPipeline, backend=DoclingParseV4DocumentBackend)
        })

    @staticmethod
    def _build_office_converter(needs_ocr: bool, ocr_options, accelerator_options: AcceleratorOptions):
        """
        Build the docling converter for docx, pptx and html.

        :return: (converter, apply_additional_ocr_on_generated_images). Docling does not apply OCR on these formats, so when OCR is requested and
                 available, picture images are generated and the flag tells the caller to OCR them itself.
        """
        pipeline_options = PaginatedPipelineOptions(accelerator_options=accelerator_options)
        apply_additional_ocr_on_generated_images = False
        if needs_ocr:
            if ocr_options is None:
                logger.warning("The provided OCR engine is not available, deactivating OCR for extraction")
            else:
                pipeline_options.generate_picture_images = True
                apply_additional_ocr_on_generated_images = True

        doc_converter = DocumentConverter(format_options={
            InputFormat.DOCX: WordFormatOption(
                pipeline_options=pipeline_options,
                pipeline_cls=SimplePipeline, backend=MsWordDocumentBackend
            ),
            InputFormat.PPTX: PowerpointFormatOption(
                pipeline_options=pipeline_options,
                pipeline_cls=SimplePipeline, backend=MsPowerpointDocumentBackend
            ),
            InputFormat.HTML: HTMLFormatOption(
                pipeline_options=pipeline_options,
                pipeline_cls=SimplePipeline, backend=HTMLDocumentBackend
            )
        })
        return doc_converter, apply_additional_ocr_on_generated_images

def build_tree_with_outline(document: DoclingDocument, max_depth: int, merge_list_items: bool, needs_ocr, apply_additional_ocr_on_generated_images=False,
                            ocr_options=None, is_gpu_available=False) -> RootNode:
    """
    Given a docling conversion result, build a tree that represents the structure of the document. Nodes are sections, texts, images or tables.

    Consecutive paragraphs are merged into one "merged_text" node. When merge_list_items is True, consecutive list items are merged into one
    "merged_list_items" node. Pending paragraphs / list items are added to the tree, in document order, when a section, picture, table or a
    different kind of text is met, and at the end of the document.

    :param document: result of Docling conversion
    :param max_depth: items nested deeper than this level are not structured: those carrying text are merged into the surrounding text, others are dropped
    :param merge_list_items: whether consecutive list items are merged into a single text node (True) or returned as separate text nodes (False)
    :param needs_ocr: whether OCR was requested for the document
    :param apply_additional_ocr_on_generated_images: for docx/pptx/html, docling does not apply OCR on images; if True, OCR is run here on the
                                                     generated picture images
    :param ocr_options: OCR options used when additional OCR needs to be applied
    :param is_gpu_available: currently unused; kept for backward compatibility
    :return: the root node of the tree
    """
    root = RootNode(document.body.self_ref, [], "root", "", 0)
    # Stack of the currently open sections, root first. A deeper section is pushed so that the following content is nested under it. A section at
    # the same or a shallower level first closes the deeper ones. The root (level 0) is never popped, so every node stays attached to the tree.
    stack = [root]
    # Paragraphs not yet added to the tree; docling creates multiple paragraphs that we want to merge.
    current_text = []
    # Consecutive list items not yet added to the tree (only used when merge_list_items is True).
    list_items_text = ""
    list_items_level = None
    list_items_provenance = []

    def flush_paragraphs():
        """Add the pending paragraphs, merged into a single text node, to the innermost open section."""
        nonlocal current_text
        if current_text:
            stack[-1].children.append(merge_paragraphs(current_text))
            current_text = []

    def flush_list_items():
        """Add the pending list items, merged into a single text node carrying their own provenance, to the innermost open section."""
        nonlocal list_items_text, list_items_level, list_items_provenance
        if list_items_text:
            stack[-1].children.append(TextNode("", [], "merged_list_items", list_items_text, list_items_level, list_items_provenance))
            list_items_text, list_items_level, list_items_provenance = "", None, []

    def flush_pending():
        """Add everything pending to the tree (at most one of paragraphs / list items is pending at a time)."""
        flush_paragraphs()
        flush_list_items()

    def add_paragraph(item, item_level):
        """Queue a text item to be merged with the adjacent paragraphs; items without text or with empty text are skipped."""
        if getattr(item, "text", None):
            # A paragraph ends the pending list, so that list items and paragraphs keep their document order
            flush_list_items()
            current_text.append(TextNode(item.self_ref, [], item.label, item.text, item_level, getattr(item, "prov", None)))

    # Docling only applies OCR on PDFs, so docx/pptx images are OCR'd here, as one batch (OCR engines such as easyocr.Reader are slow to initialize).
    images_text = {}
    if apply_additional_ocr_on_generated_images and ocr_options is not None:
        image_refs = []
        for picture in document.pictures:
            image = picture.image
            # Check that docling correctly included the image in the document as a data URI. If so apply OCR on it.
            if isinstance(image, ImageRef) and isinstance(image.uri, AnyUrl) and image.uri.scheme == "data":
                # Keep the base64 payload, i.e. what follows the first comma of the data URI
                payload = (image.uri.path or "").partition(",")[2]
                if payload:
                    image_refs.append((picture.self_ref, payload))
        images_text = process_images_with_ocr(ocr_options, image_refs)

    for node_item, level in document.iterate_items():
        if level > max_depth:
            # Content nested deeper than max_depth (titles included) is merged as text with the surrounding text.
            # Empty paragraphs and items without text are skipped.
            # NOTE(review): pictures and tables nested this deep presumably have no `text` and are therefore dropped — confirm this is intended.
            add_paragraph(node_item, level)
            continue
        if not hasattr(node_item, "label"):
            continue

        item_label = node_item.label
        if item_label == DocItemLabel.TITLE or item_label == DocItemLabel.SECTION_HEADER:
            section = SectionNode(node_item.self_ref, [], item_label, getattr(node_item, "text", ""), level, getattr(node_item, "prov", None))
            # Pending content belongs to the section that was open before this header
            flush_pending()
            # Close the sections at the same or a deeper level than the new one (never the root)
            while len(stack) > 1 and stack[-1].level >= level:
                stack.pop()
            stack[-1].children.append(section)
            stack.append(section)
        elif item_label == DocItemLabel.LIST_ITEM and hasattr(node_item, "text"):
            flush_paragraphs()
            if merge_list_items:
                if not list_items_text:
                    list_items_level = level
                list_items_text += node_item.text + "\n"
                list_items_provenance.extend(getattr(node_item, "prov", None) or [])
            else:
                stack[-1].children.append(
                    TextNode(node_item.self_ref, [], item_label, node_item.text, level, getattr(node_item, "prov", None)))
        elif item_label == DocItemLabel.PICTURE:
            flush_pending()
            content = ""
            if needs_ocr:
                if not apply_additional_ocr_on_generated_images:
                    # OCR was applied on the document by docling, so we look for text nodes inside the image.
                    if node_item.children:
                        for child in node_item.children:
                            # Resolve the children and append them to a single text
                            try:
                                child_node = child.resolve(document)
                                if child_node.label == DocItemLabel.TEXT:
                                    content += child_node.text + " "
                            except Exception as e:
                                logger.error(f"Error resolving child node {child.cref} in image {node_item.self_ref}: {e}")
                                continue
                elif node_item.self_ref in images_text:
                    # Retrieve the OCR text for the image that was computed at the beginning of the function
                    content = images_text[node_item.self_ref]
            stack[-1].children.append(
                ImageNode(node_item.self_ref, [], item_label, content, level, getattr(node_item, "prov", None), getattr(node_item, "image", None)))
        elif item_label == DocItemLabel.TABLE:
            flush_pending()
            stack[-1].children.append(TableNode(node_item.self_ref, [], item_label, node_item.export_to_markdown(document), level,
                                                getattr(node_item, "prov", None)))
        else:
            add_paragraph(node_item, level)

    # Add whatever is still pending at the end of the document (trailing paragraphs or list items)
    flush_pending()
    return root

def merge_paragraphs(current_text: List[TextNode]) -> TextNode:
    """
    Merge consecutive text nodes into a single "merged_text" TextNode.

    Node ids are joined with "-" and contents with newlines. Children and page provenance are concatenated in order; nodes whose page provenance
    is None contribute none. The merged node takes the level of the first node (None when current_text is empty).
    """
    node_ids = [node.node_id for node in current_text]
    children = [child for node in current_text for child in node.children]
    contents = [node.content for node in current_text]
    provenance = [prov for node in current_text for prov in (node.page_provenance or [])]
    level = current_text[0].level if current_text else None
    return TextNode("-".join(node_ids), children, "merged_text", "\n".join(contents), level, provenance)

def build_message_log_for_document(source_file_path, message):
    """Prefix a log message with the document it relates to: "[<source_file_path>] - <message>"."""
    template = "[{}] - {}"
    return template.format(source_file_path, message)
