
import base64
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Tuple

import dataiku
from common.backend.models.base import MediaSummary, UploadFileError
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.sql_timing import log_execution_time
from common.backend.utils.upload_utils import get_checked_config
from common.llm_assist.logging import logger
from dataikuapi.dss.document_extractor import DocumentExtractor, ManagedFolderDocumentRef
from PIL import Image

# Module-level alias of the webapp settings exposed by the shared dataiku_api
# helper; the functions below read the "upload_folder" entry from it.
webapp_config: Dict[str, str] = dataiku_api.webapp_config


def resize_base64_image(b64_image: str, log: bool = False) -> str:
    """
    Downscale a base64-encoded slide image and re-encode it as a base64 JPEG.

    The target size is 720x405 multiplied by the "pptx_as_image_resize_ratio"
    webapp setting (1 when the setting is absent). Do not shrink too much:
    slide content becomes unreadable and information is lost.

    Args:
        b64_image: Base64-encoded image bytes in any format Pillow can open.
        log: When True, emit a debug line with the original and target sizes.

    Returns:
        The resized image, converted to RGB when needed, JPEG-encoded and
        returned as a base64 UTF-8 string.
    """
    # Slide screenshots are expected to come out of extraction at (1440, 810),
    # so the default target is half of that in each dimension.
    DEFAULT_WIDTH = 720
    DEFAULT_HEIGHT = 405

    raw_ratio = dataiku_api.webapp_config.get("pptx_as_image_resize_ratio", 1)
    try:
        # The setting may arrive as a string or None; coerce before doing math on it.
        ratio = float(raw_ratio)
    except (TypeError, ValueError):
        ratio = float("nan")
    # A single chained comparison rejects zero, negatives, NaN and infinity,
    # all of which would otherwise produce an invalid target size.
    if not 0 < ratio < float("inf"):
        logger.warning(f"Invalid pptx_as_image_resize_ratio {raw_ratio!r}, falling back to 1")
        ratio = 1.0

    # Never go below one pixel per side, even for very small ratios.
    size: Tuple[int, int] = (
        max(1, int(DEFAULT_WIDTH * ratio)),
        max(1, int(DEFAULT_HEIGHT * ratio)),
    )

    with Image.open(BytesIO(base64.b64decode(b64_image))) as image:
        if log:
            logger.debug(f"Image of size : {image.size} will be resized to {size} using a ratio of {ratio:g}")
        image_resized = image.resize(size, Image.Resampling.LANCZOS)

    if image_resized.mode != "RGB":
        # JPEG cannot store alpha or palette modes.
        image_resized = image_resized.convert("RGB")

    buffer = BytesIO()
    image_resized.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@log_execution_time
def extract_pptx_slides_as_imgs(file_path: str, file_data: bytes) -> List[str]:
    """
    Render every slide of a PPTX document as a resized base64 image.

    The document is uploaded to the managed folder configured as
    "upload_folder" so that the DSS DocumentExtractor can reference it, a
    screenshot is generated for each page, and the uploaded copy is deleted
    afterwards, including when screenshot generation or resizing fails.

    Args:
        file_path: Path of the document inside the upload folder.
        file_data: Raw content of the PPTX file.

    Returns:
        One base64-encoded image per screenshot, each passed through
        resize_base64_image.

    Raises:
        Any exception raised by the upload, the extractor or the resize step
        is propagated to the caller.
    """
    folder_id = webapp_config.get("upload_folder")
    upload_folder = dataiku.Folder(folder_id)
    upload_folder.upload_stream(file_path, file_data)

    try:
        project_key = dataiku.default_project_key()
        doc_extractor = DocumentExtractor(dataiku.api_client(), project_key)
        document_ref = ManagedFolderDocumentRef(file_path, folder_id)

        # Each screenshot is an InlineImageRef from dataikuapi.dss.document_extractor.
        # Only the first resize is logged, to avoid one debug line per slide.
        return [
            resize_base64_image(image.as_json().get("content", ""), img_nb == 0)
            for img_nb, image in enumerate(doc_extractor.generate_pages_screenshots(document_ref))
        ]
    finally:
        # Always remove the uploaded copy so a failed extraction does not leave it behind.
        upload_folder.delete_path(file_path)


@log_execution_time
def extract_pptx_text(file_data: bytes, extension: str, slides_number: int) -> Tuple[str, bool]:
    """
    Extract the text of a PPTX document and decide whether it should also be
    handled as images.

    Args:
        file_data: Raw content of the PPTX file.
        extension: File extension; only used in the debug log.
        slides_number: Number of slides in the document, compared against the
            "docs_per_page_as_image" setting.

    Returns:
        A tuple (extracted_text, is_doc_as_image). is_doc_as_image is True when
        slides_number does not exceed "docs_per_page_as_image", the LLM reports
        the "multi_modal" capability and "allow_doc_as_image" is enabled.

    Raises:
        Exception: With UploadFileError.PARSING_ERROR.value as message when
            loading or assembling the text fails; the original error is
            chained as the cause.
    """
    # Importing doclingloader inside the function helps avoid increasing the
    # startup time of the web application. Docling has many internal dependencies
    # that slow down Flask's initial loading if imported globally (around +2s to start).
    from common.backend.utils.llm_utils import get_llm_capabilities
    from langchain_core.documents.base import (
        Document,  # Lazy import to prevent 'langchain_core' heavy modules to be loaded
    )
    from langchain_docling import DoclingLoader

    llm_caps = get_llm_capabilities()
    multimodal_enabled = bool(llm_caps.get("multi_modal"))
    allow_doc_as_image = bool(get_checked_config("allow_doc_as_image"))
    docs_per_page_as_image = int(get_checked_config("docs_per_page_as_image"))

    try:
        with tempfile.NamedTemporaryFile(delete=True) as temp_file:
            temp_file.write(file_data)
            # Flush so the loader, which reopens the file by name, sees the full content.
            temp_file.flush()
            logger.debug(f"extension {extension}, temp file {temp_file.name}")
            loader = DoclingLoader(file_path=temp_file.name)
            document: List[Document] = loader.load()

        # Note: len(document) will not return the number of slides.
        # LangChain loaders parse the file into semantic chunks of text (e.g., paragraphs, titles),
        # so the length represents the total number of these chunks, not the slide count.

        is_doc_as_image = slides_number <= docs_per_page_as_image and multimodal_enabled and allow_doc_as_image
        logger.debug(f"PPTX slides_nb / docs_per_page_as_image : {slides_number}/{docs_per_page_as_image}, is_doc_as_image : {is_doc_as_image}, multimodal_enabled : {multimodal_enabled}")

        # Each chunk is wrapped in a leading newline + 16 spaces and a trailing
        # newline + 12 spaces, so the returned text layout stays unchanged.
        prefix = "\n" + " " * 16
        suffix = "\n" + " " * 12
        extracted_text = "".join(f"{prefix}{page.page_content}{suffix}" for page in document)
        return extracted_text, is_doc_as_image
    except Exception as e:
        logger.exception(f"Error in extract_pptx_text: {e}")
        raise Exception(UploadFileError.PARSING_ERROR.value) from e

def save_pptx_slides_next_to_original_document(slides_as_images: List[str], pptx_path: str):
    """
    Store base64 slide images in the upload folder, named after the PPTX file.

    For a document at "<dir>/<name>.<ext>", slide i (1-based) is written to
    "<dir>/<name>_slide_<i>.png" in the managed folder configured as
    "upload_folder".

    NOTE(review): files are always named ".png" whatever the encoding of the
    decoded bytes, while resize_base64_image in this module emits JPEG data;
    confirm that no consumer relies on the extension.

    Args:
        slides_as_images: Base64-encoded images, in slide order.
        pptx_path: Path of the original document inside the upload folder.

    Returns:
        The list of paths the images were uploaded to, in slide order.
    """
    # Drop the document's extension to get the common prefix of every slide file.
    base_path = Path(pptx_path).with_suffix('')
    target_folder = dataiku.Folder(webapp_config.get("upload_folder"))

    saved_paths = []
    for position, encoded_slide in enumerate(slides_as_images, start=1):
        raw_bytes = base64.b64decode(encoded_slide)
        destination = f"{base_path}_slide_{position}.png"
        target_folder.upload_stream(destination, BytesIO(raw_bytes))
        saved_paths.append(destination)

    return saved_paths


def load_pptx_slides_from_summary(summary: MediaSummary) -> List[str]:
    """
    Load the stored slide images referenced by a media summary as base64 strings.

    The summary's "metadata_path" points to a JSON file, read through
    dataiku_api.folder_handle, whose "extracted_images_path" entry lists the
    image paths. Each image is then downloaded from the managed folder
    configured as "upload_folder".

    Args:
        summary: Media summary holding the "metadata_path" entry.

    Returns:
        The base64-encoded content of each listed image, in listing order.
        Empty when the metadata lists no image.

    Raises:
        Exception: With UploadFileError.PARSING_ERROR.value as message when an
            image cannot be downloaded or encoded; the original error is
            chained as the cause.
    """
    folder = dataiku_api.folder_handle
    metadata_path = summary.get("metadata_path", "")
    extract_summary = folder.read_json(metadata_path)
    # "or []" also covers a key that is present but set to null in the JSON.
    slides_path = extract_summary.get("extracted_images_path", []) or []
    upload_folder = dataiku.Folder(webapp_config.get("upload_folder"))
    encoded_images: List[str] = []

    for filepath in slides_path:
        try:
            with upload_folder.get_download_stream(filepath) as f:
                encoded_images.append(base64.b64encode(f.read()).decode("utf-8"))
        except Exception as e:
            logger.exception(f"Error while loading images from folder {e}")
            raise Exception(UploadFileError.PARSING_ERROR.value) from e
    return encoded_images