import asyncio
import base64
import io
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import List, Awaitable, Optional

import torch
from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
from docling.backend.mspowerpoint_backend import MsPowerpointDocumentBackend
from docling.backend.msword_backend import MsWordDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions, TableFormerMode, AcceleratorOptions, AcceleratorDevice, PipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, PowerpointFormatOption, HTMLFormatOption
from docling.models.layout_model import LayoutModel
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.simple_pipeline import SimplePipeline
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
from docling_core.types import DoclingDocument
from docling_core.types.doc import DocItemLabel, ProvenanceItem
from docling_core.types.io import DocumentStream

from dataiku.base.batcher import Batcher
from dataiku.doctor.utils.gpu_execution import TorchGpuCapability

# Logger shared by the whole docling extraction module.
logger = logging.getLogger("docling_extraction")


@dataclass
class DoclingRequest:
    """A single document extraction request, built by DoclingExtractorPipeline.process_document."""
    file_name: str  # original file name; its extension is used to group requests in batches and to detect PDFs
    document_content: str  # base64-encoded content of the document
    max_section_depth: int  # items nested deeper than this level are returned as text (see build_tree_with_outline)

@dataclass
class DoclingResponse:
    """
    Result of the extraction of a single document.

    On success ``ok`` is True, ``resp`` holds the structured content and ``error`` is None.
    On failure ``ok`` is False, ``resp`` is None and ``error`` holds the formatted traceback.
    """
    ok: bool  # whether the extraction succeeded
    resp: Optional[dict]  # structured content (serialized document tree), None on failure
    error: Optional[str]  # formatted traceback, None on success

    def to_dict(self):
        """Serialize the response into a plain dict with the ``ok``, ``resp`` and ``error`` keys."""
        return asdict(self)

class DocumentNode:
    """
    A tree node structure to map docling converted structure to Dataiku structured extraction API: com.dataiku.dip.docextraction.StructuredContent

    Subclasses set ``node_type`` (exported as the "type" key) and add their own keys in ``to_dict``.
    """
    node_type: Optional[str] = None
    level: Optional[int]

    def __init__(self, node_id: str, children: List["DocumentNode"], label, content: str = "", level: Optional[int] = None,
                 page_provenance: Optional[List[ProvenanceItem]] = None):
        self.node_id = node_id  # docling self_ref of the item (or a "-"-joined list of them for merged nodes)
        self.children = children  # List of children DocumentNode
        self.label = label  # docling label (see DocItemLabel, GroupLabel) or a plain string such as "root"
        self.content = content  # Text, section title or table markdown; a "placeholder" string for images
        self.level = level  # Level within the hierarchy
        self.page_provenance = page_provenance  # page provenance from the original document, may be None

    def to_dict(self):
        """
        Serialize the node and, recursively, its children.

        :return: dict with the "type" and "content" (serialized children) keys, plus "pages" (page numbers taken from the
            provenance) when provenance is available.
        """
        res = {
            "type": self.node_type,
            "content": [child.to_dict() for child in self.children],
        }
        if self.page_provenance:
            res["pages"] = [page.page_no for page in self.page_provenance]
        return res

class ImageNode(DocumentNode):
    """Picture node. Exposes the image metadata (mime type, size, DPI) when docling provides an image."""
    node_type = "image"

    def __init__(self, node_id, children, label, content="", level=None, page_provenance=None, image=None):
        super().__init__(node_id, children, label, content, level, page_provenance)
        if image is None:
            self.mime_type = self.height = self.width = self.resolution = None
        else:
            self.mime_type = image.mimetype
            self.height = image.size.height
            self.width = image.size.width
            self.resolution = image.dpi

    def to_dict(self):
        serialized = super().to_dict()
        serialized["mimeType"] = self.mime_type
        # Size and resolution are only exported when known (and non-zero)
        optional_fields = (("height", self.height), ("width", self.width), ("resolution", self.resolution))
        serialized.update((key, value) for key, value in optional_fields if value)
        return serialized

class TextNode(DocumentNode):
    """Text node: its content is exported under the "text" key."""
    node_type = "text"

    def to_dict(self):
        return {**super().to_dict(), "text": self.content}

class TableNode(DocumentNode):
    """Table node: its content (the table rendered as markdown) is exported under the "text" key."""
    node_type = "table"

    def to_dict(self):
        return dict(super().to_dict(), text=self.content)

class SectionNode(DocumentNode):
    """Section node: exports its heading text as "title" and its hierarchy level as "level"."""
    node_type = "section"

    def to_dict(self):
        return {**super().to_dict(), "title": self.content, "level": self.level}

class GroupNode(DocumentNode):
    """Node serialized with type "group". NOTE(review): not instantiated in the visible code of this module."""
    node_type = "group"

class RootNode(DocumentNode):
    """Top-level node of the tree built by build_tree_with_outline, serialized with type "document"."""
    node_type = "document"

class DoclingExtractorPipeline:
    """
    Runs document extraction requests through docling, in batches.

    Requests are submitted to a ``Batcher`` and each batch is converted in a worker thread, so the event loop is not
    blocked by docling. Every request of a batch gets its own ``DoclingResponse`` dict. A failure on one document
    (decoding or conversion) does not fail the other documents of the batch.
    """
    batcher: Batcher[DoclingRequest, List[dict]]
    pdf_pipeline_options: PdfPipelineOptions
    simple_pipeline_options: PipelineOptions

    def __init__(self):
        self.pdf_pipeline_options = PdfPipelineOptions(do_table_structure=True, do_ocr=False, do_picture_description=False)
        self.pdf_pipeline_options.table_structure_options.do_cell_matching = True
        self.pdf_pipeline_options.table_structure_options.mode = TableFormerMode.ACCURATE
        accelerator_options = AcceleratorOptions(device=self.detect_accelerator_device())
        self.pdf_pipeline_options.accelerator_options = accelerator_options
        self.simple_pipeline_options = PipelineOptions(accelerator_options=accelerator_options)
        self.executor = ThreadPoolExecutor()
        self.batcher = Batcher[DoclingRequest, List[dict]](
            batch_size=4,
            timeout=1,
            process_batch=self._process_batch_async,
            # Requests are grouped by file extension (presumably so that a batch only holds one document type)
            group_by=lambda request: os.path.splitext(request.file_name)[1]
        )

    def _run_batch_sync(self, requests: List[DoclingRequest]) -> List[dict]:
        """Blocking batch processing, executed in the thread pool."""
        logger.info("Processing a batch of %s document extraction requests", len(requests))
        return self.docling_batch_structured_extract(requests)

    async def _process_batch_async(self, requests: List[DoclingRequest]) -> List[dict]:
        """Run the blocking docling conversion of a batch in the thread pool without blocking the event loop."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._run_batch_sync, requests)

    async def process_document(self, process_document_command):
        """
        Submit one extraction command to the batcher and wait for its result.

        :param process_document_command: dict with the "fileName", "documentContent" (base64) and "maxSectionDepth" keys
        :return: whatever the batcher returns for this request (presumably its ``DoclingResponse`` dict)
        """
        return await self.batcher.process(DoclingRequest(
            file_name=process_document_command["fileName"],
            document_content=process_document_command["documentContent"],
            max_section_depth=process_document_command["maxSectionDepth"],
        ))

    def find_docling_pdf_models_in_resources(self) -> Optional[Path]:
        """
        Look in the resources folder whether the docling PDF models are available. Returns the path to the directory containing the models (that should be
        passed to docling) or None if one or more models is missing (or if the DOCUMENT_EXTRACTION_MODELS env variable is not set)
        """
        models_env_value = os.environ.get("DOCUMENT_EXTRACTION_MODELS")
        if models_env_value is None:
            return None
        models_path = Path(models_env_value).expanduser()
        logger.info(u"Pdf extraction requires layout and tableformer model from docling. Checking if those models are already in the ressources folder.")
        layout_model_path = models_path / LayoutModel._model_repo_folder / LayoutModel._model_path
        logger.info(u"Searching for layout model in {}".format(layout_model_path))
        # The tableformer weights are stored in a sub-folder named after the table structure mode (e.g. accurate)
        tableformer_path = (models_path / TableStructureModel._model_repo_folder / TableStructureModel._model_path /
                            self.pdf_pipeline_options.table_structure_options.mode.value)
        logger.info(u"Searching for tableformer model in {}".format(tableformer_path))
        if layout_model_path.exists() and tableformer_path.exists():
            logger.info("Docling models for pdf extraction were found in the resources folder of the code env. Will use them for extraction")
            return models_path
        return None

    @staticmethod
    def detect_accelerator_device() -> AcceleratorDevice:
        """Pick the accelerator used by docling: CUDA if a GPU is available, else MPS (macOS) if available, else CPU."""
        is_gpu_available = TorchGpuCapability.is_gpu_available()
        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
        device = AcceleratorDevice.CPU
        if has_mps:
            device = AcceleratorDevice.MPS
        if is_gpu_available:
            device = AcceleratorDevice.CUDA

        if device == AcceleratorDevice.CPU:
            logger.info("No accelerator nor GPU was detected in the environment")
        else:
            logger.info(f"Detected the following accelerator option: {device.value}. Will be used for extraction")
        return device

    def docling_batch_structured_extract(self, requests: List[DoclingRequest]) -> List[dict]:
        """
        Convert a batch of documents with docling and build the structured content of each one.

        :param requests: documents to convert; ``document_content`` is the base64-encoded file content
        :return: one ``DoclingResponse`` dict per request, in the same order as ``requests``. A document that cannot be
            decoded or converted yields ``ok=False`` with the formatted traceback in ``error``.
        """
        if any(req.file_name.lower().endswith('.pdf') for req in requests):
            # Pdf extraction with docling requires the IBM tableformer and layout models: https://huggingface.co/ds4sd/docling-models
            # If they are available in the resources folder they are used for extraction. Otherwise artifacts_path is None
            # and docling falls back to its regular behavior (look in the HF cache, download if missing).
            # NOTE(review): this mutates options shared by every batch running in the thread pool; the value written is
            # the same each time, but consider computing it once.
            self.pdf_pipeline_options.artifacts_path = self.find_docling_pdf_models_in_resources()
        doc_converter = DocumentConverter(
            allowed_formats=[
                InputFormat.PDF,
                InputFormat.DOCX,
                InputFormat.PPTX,
                InputFormat.HTML
            ],
            format_options={
                InputFormat.PDF: PdfFormatOption(
                    pipeline_options=self.pdf_pipeline_options,
                    pipeline_cls=StandardPdfPipeline, backend=DoclingParseV2DocumentBackend
                ),
                InputFormat.DOCX: WordFormatOption(
                    pipeline_options=self.simple_pipeline_options,
                    pipeline_cls=SimplePipeline, backend=MsWordDocumentBackend
                ),
                InputFormat.PPTX: PowerpointFormatOption(
                    pipeline_options=self.simple_pipeline_options,
                    pipeline_cls=SimplePipeline, backend=MsPowerpointDocumentBackend
                ),
                InputFormat.HTML: HTMLFormatOption(
                    pipeline_options=self.simple_pipeline_options,
                    pipeline_cls=SimplePipeline, backend=HTMLDocumentBackend
                ),
            },
        )
        res = []
        for request in requests:
            try:
                # Decode inside the try: a malformed base64 payload must only fail its own request, not the whole batch
                byte_data = base64.b64decode(request.document_content)
                document = DocumentStream(name=request.file_name, stream=io.BytesIO(byte_data))
                conversion_result = doc_converter.convert(document)
                tree = build_tree_with_outline(conversion_result.document, request.max_section_depth, True)
                res.append(DoclingResponse(True, tree.to_dict(), None).to_dict())
            except Exception:
                logger.exception(build_message_log_for_document(request.file_name, "An error occurred during docling processing"))
                res.append(DoclingResponse(False, None, traceback.format_exc()).to_dict())
        return res

def build_tree_with_outline(document: DoclingDocument, max_depth: int, merge_list_items: bool) -> RootNode:
    """
    Given a docling conversion result, build a tree that represents the structure of the document. The nodes can be
    sections, texts, images or tables.

    Consecutive paragraphs are merged into a single text node. When ``merge_list_items`` is set, consecutive list items
    are merged into a single "merged_list_items" text node (one item per line). Pending paragraphs and list items are
    added to the tree, in document order, as soon as an item of another kind is met and at the end of the document.

    :param document: result of Docling conversion
    :param max_depth: items whose level (as yielded by ``document.iterate_items()``) is deeper than this are returned as
        plain text, section headers included
    :param merge_list_items: whether consecutive list items are merged into a single text node (True) or returned as
        separate text nodes (False)
    :return: the root node of the tree
    """
    root = RootNode(document.body.self_ref, [], "root", "", 0)
    # Stack of the current title hierarchy. A deeper section is pushed, so that deeper text gets the whole outline.
    # A section at the same or a higher level pops the sections it closes. The root always stays at the bottom.
    stack: List[DocumentNode] = [root]
    # Paragraphs not yet added to the tree. Docling creates multiple paragraphs that we want to merge.
    current_text: List[TextNode] = []
    # List items not yet added to the tree (only filled when merge_list_items is True).
    pending_list_items: List[TextNode] = []
    # Invariant: at most one of current_text / pending_list_items is non-empty, so flushing both keeps document order.

    def flush_text() -> None:
        # Add the pending paragraphs to the current section as a single merged text node
        if current_text:
            stack[-1].children.append(merge_paragraphs(current_text))
            current_text.clear()

    def flush_list_items() -> None:
        # Add the pending list items to the current section as a single text node, one item per line
        if pending_list_items:
            merged_provenance = []
            for pending in pending_list_items:
                merged_provenance.extend(pending.page_provenance or [])
            stack[-1].children.append(TextNode(
                "-".join(pending.node_id for pending in pending_list_items), [], "merged_list_items",
                "".join(pending.content + "\n" for pending in pending_list_items),
                pending_list_items[0].level, merged_provenance))
            pending_list_items.clear()

    def flush_pending() -> None:
        flush_text()
        flush_list_items()

    for node_item, level in document.iterate_items():
        provenance = getattr(node_item, "prov", None)
        if level > max_depth:
            # Anything deeper than max_depth (section headers included) is returned as a text node and merged with the
            # surrounding text content. Items without text are skipped.
            # NOTE(review): tables and pictures have no "text", so they are dropped when deeper than max_depth - confirm.
            if getattr(node_item, "text", None):
                flush_list_items()
                current_text.append(TextNode(node_item.self_ref, [], node_item.label, node_item.text, level, provenance))
            continue
        if not hasattr(node_item, "label"):
            continue
        label = node_item.label
        if label == DocItemLabel.TITLE or label == DocItemLabel.SECTION_HEADER:
            # Pending content belongs to the section being closed
            flush_pending()
            # Pop the sections at the same or a deeper level (level 0 is the root, which is never popped)
            while len(stack) > 1 and stack[-1].level >= level:
                stack.pop()
            section = SectionNode(node_item.self_ref, [], label, getattr(node_item, "text", ""), level, provenance)
            stack[-1].children.append(section)
            stack.append(section)
        elif label == DocItemLabel.LIST_ITEM and hasattr(node_item, "text"):
            flush_text()
            list_node = TextNode(node_item.self_ref, [], label, node_item.text, level, provenance)
            if merge_list_items:
                pending_list_items.append(list_node)
            else:
                stack[-1].children.append(list_node)
        elif label == DocItemLabel.PICTURE:
            flush_pending()
            stack[-1].children.append(ImageNode(node_item.self_ref, [], label, "placeholder", level, provenance,
                                                getattr(node_item, "image", None)))
        elif label == DocItemLabel.TABLE:
            flush_pending()
            stack[-1].children.append(TableNode(node_item.self_ref, [], label, node_item.export_to_markdown(), level,
                                                provenance))
        elif getattr(node_item, "text", None):
            # Any other text item; empty paragraphs are skipped
            flush_list_items()
            current_text.append(TextNode(node_item.self_ref, [], label, node_item.text, level, provenance))
    # Add the content still pending at the end of the document (paragraphs or list items)
    flush_pending()
    return root

def merge_paragraphs(current_text: List[TextNode]) -> TextNode:
    """
    Merge consecutive text nodes into a single "merged_text" node.

    :param current_text: text nodes to merge, in document order
    :return: a TextNode with:
        - node id: the ids joined with "-"
        - content: the contents joined with newlines
        - children and provenance: concatenated from the inputs (nodes without provenance are skipped)
        - level: the first node's level, or None for an empty input
    """
    children, provenance = [], []
    for node in current_text:
        children.extend(node.children)
        # page_provenance may be None (e.g. items without "prov"); extend(None) would raise TypeError
        if node.page_provenance:
            provenance.extend(node.page_provenance)
    node_id = "-".join(node.node_id for node in current_text)
    content = "\n".join(node.content for node in current_text)
    return TextNode(node_id, children, "merged_text", content, current_text[0].level if current_text else None, provenance)

def build_message_log_for_document(source_file_path, message):
    """Prefix ``message`` with the document it relates to, e.g. ``[report.pdf] - message``."""
    return "[{}] - {}".format(source_file_path, message)
