import base64
import io
import os
import json
import copy
from typing import List, Dict, Any, Tuple, Optional

import dataiku
from backend.config import get_uploads_managedfolder_id
from backend.utils.logging_utils import get_logger
from PIL import Image

# The document extractor API is optional at import time: if this dataikuapi module
# cannot be imported, both names fall back to None so callers can detect the absence
# (process_derived_documents checks for None and skips processing).
try:
    from dataikuapi.dss.document_extractor import DocumentExtractor, ManagedFolderDocumentRef
except ImportError:  # pragma: no cover
    DocumentExtractor = None  # type: ignore
    ManagedFolderDocumentRef = None  # type: ignore

# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)



def _ensure_folder_prefix(folder, base_path: str) -> None:
    keep_path = f"{base_path}/.init"
    try:
        folder.put_file(keep_path, b"")
    except Exception:
        pass


def _decode_image(image_payload: Any) -> bytes:
    if isinstance(image_payload, bytes):
        return image_payload
    if isinstance(image_payload, str):
        return base64.b64decode(image_payload)
    raise TypeError(f"Unsupported screenshot payload type: {type(image_payload)}")


def resize_with_ratio(
    image: Image.Image,
    max_long_side: int = 720,
    max_short_side: int = 405,
) -> Image.Image:
    """Downscale ``image`` to fit an orientation-matched box, preserving aspect ratio.

    Landscape (and square) images are fit into ``max_long_side`` x ``max_short_side``
    (720x405 by default); portrait images into ``max_short_side`` x ``max_long_side``.
    Images already inside the box are never upscaled.

    Each output side is clamped to at least 1 pixel, so extremely thin images no longer
    round to a zero-sized target (which ``Image.resize`` rejects). A degenerate
    zero-area image is returned as an unchanged copy instead of dividing by zero.

    Args:
        image: Source PIL image.
        max_long_side: Maximum length of the box's longer side.
        max_short_side: Maximum length of the box's shorter side.

    Returns:
        A new resized image.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        return image.copy()

    # Orient the bounding box to match the image.
    if width >= height:
        max_w, max_h = max_long_side, max_short_side
    else:
        max_w, max_h = max_short_side, max_long_side

    # Fit inside the box while keeping aspect ratio; never upscale past 100%.
    scale = min(max_w / width, max_h / height, 1.0)
    new_w = max(1, int(round(width * scale)))
    new_h = max(1, int(round(height * scale)))
    return image.resize((new_w, new_h), Image.LANCZOS)


def _flatten_using_dfs(node: Dict[str, Any], current_outline: List[str]) -> List[Dict[str, Any]]:
    """
    Flatten structured document content using depth-first search.
    
    Args:
        node: The current node in the structured content tree
        current_outline: The current outline path (list of section titles)
        
    Returns:
        List of flattened chunks with text, outline, and pages
    """
    if not node or "type" not in node:
        return []

    # Extract page information (start, end, or list of pages)
    pages = []
    if "pageRange" in node and "start" in node["pageRange"]:
        pages = list(range(node["pageRange"]["start"], node["pageRange"]["end"] + 1))

    # Handle text and table
    if node["type"] in ["text", "table"]:
        if not node.get("text"):
            return []
        return [{
            "text": node["text"],
            "outline": current_outline,
            "pages": pages
        }]

    # Handle image
    elif node["type"] == "image":
        if not node.get("description"):
            return []
        return [{
            "text": node["description"],
            "outline": current_outline,
            "pages": pages
        }]

    # Handle document or section
    elif node["type"] in ["document", "section"]:
        if "content" not in node:
            return []
        deeper_outline = copy.deepcopy(current_outline)
        if node["type"] == "section" and "title" in node:
            deeper_outline.append(node["title"])
        chunks = []
        for child in node.get("content", []):
            chunks.extend(_flatten_using_dfs(child, deeper_outline))
        return chunks

    else:
        raise ValueError("Unsupported structured content type: " + node["type"])


def _chunks_to_text(chunks: List[Dict[str, Any]]) -> str:
    """Convert structured extractor chunks into a readable text blob."""
    lines: List[str] = []
    for chunk in chunks:
        text = (chunk.get("text") or "").strip()
        if not text:
            continue
        outline = [part for part in chunk.get("outline", []) if part]
        if outline:
            lines.append(" > ".join(outline))
        lines.append(text)
        lines.append("")  # spacer line
    return "\n".join(lines).strip()


def _extract_document_text(
    doc_extractor,
    document_ref,
    folder,
    output_base: str,
    image_handling_mode: str = "IGNORE",
    llm_id: Optional[str] = None,
    output_managed_folder: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run structured extraction to obtain text-only content for a document.
    
    Args:
        doc_extractor: DocumentExtractor instance
        document_ref: ManagedFolderDocumentRef instance
        folder: Dataiku folder instance
        output_base: Base path for output files
        image_handling_mode: How to handle images ("IGNORE", "OCR", or "VLM_ANNOTATE")
        llm_id: LLM ID for VLM_ANNOTATE mode (optional)
        output_managed_folder: Managed folder ID for storing images in VLM_ANNOTATE mode (optional)
    """
    try:
        extract_kwargs = {
            "document": document_ref,
            "image_handling_mode": image_handling_mode,
        }
        
        if image_handling_mode == "OCR":
            extract_kwargs["ocr_engine"] = "AUTO"
        elif image_handling_mode == "VLM_ANNOTATE":
            if llm_id:
                extract_kwargs["llm_id"] = llm_id
            if output_managed_folder:
                extract_kwargs["output_managed_folder"] = output_managed_folder
        
        structured_response = doc_extractor.structured_extract(**extract_kwargs)
        structured_response._fail_unless_success()
        
        # Use _flatten_using_dfs instead of text_chunks
        if hasattr(structured_response, 'content') and structured_response.content:
            chunks = _flatten_using_dfs(structured_response.content, [])
        else:
            # Fallback to text_chunks if content is not available
            chunks = structured_response.text_chunks or []
    except Exception as err:  # pragma: no cover - rely on logging
        logger.exception("Failed to extract text for document %s",
                         document_ref.path if hasattr(document_ref, 'path') else document_ref)
        return {}
    
    # Save chunks as JSON
    text_path = f"{output_base}/extracted.json"
    try:
        chunks_json = json.dumps(chunks, ensure_ascii=False, indent=2)
        folder.put_file(text_path, io.BytesIO(chunks_json.encode("utf-8")))
    except Exception as err:  # pragma: no cover - rely on logging
        logger.exception("Failed to persist extracted chunks for %s: %s",
                         document_ref.path if hasattr(document_ref, 'path') else document_ref, err)
        text_path = None
    metadata: Dict[str, Any] = {}
    if text_path:
        metadata["text_path"] = text_path
    return metadata


def _output_base_for(document_path: str) -> str:
    """Map an uploaded document path to the prefix where its derived artifacts live."""
    # "inputs/<x>" -> "outputs/<x>"; any other path is nested under "outputs/".
    if document_path.startswith("inputs/"):
        return document_path.replace("inputs/", "outputs/", 1)
    return f"outputs/{document_path}"


def _copy_plain_text_document(store, folder, conv_id, doc_name, document_path, output_base) -> None:
    """Copy a txt/md/html upload verbatim to ``<output_base>/content.txt`` and record the outcome."""
    try:
        with folder.get_file(document_path) as response:
            content_bytes = response.raw.read()
        text_path = f"{output_base}/content.txt"
        with io.BytesIO(content_bytes) as bstream:
            folder.put_file(text_path, bstream)
        store.upsert_derived_document(
            conv_id,
            doc_name,
            document_path,
            {"status": "ready", "text_path": text_path, "snapshots": []},
        )
    except Exception as err:
        logger.exception("Failed to copy text document %s for conv %s", doc_name, conv_id)
        store.upsert_derived_document(
            conv_id,
            doc_name,
            document_path,
            {"status": "failed", "error": str(err), "snapshots": []},
        )


def _store_resized_image_document(store, folder, conv_id, doc_name, document_path, output_base, file_ext) -> None:
    """Resize an image upload into a single ``page_1`` snapshot and record the outcome."""
    try:
        with folder.get_file(document_path) as response:
            image_bytes = response.raw.read()

        target_path = f"{output_base}/page_1.{file_ext}"
        save_format = "JPEG" if file_ext in {"jpg", "jpeg"} else "PNG"
        resized = resize_with_ratio(Image.open(io.BytesIO(image_bytes)))
        out_buf = io.BytesIO()
        resized.save(out_buf, format=save_format)
        out_buf.seek(0)
        folder.put_file(target_path, out_buf)

        store.upsert_derived_document(
            conv_id,
            doc_name,
            document_path,
            {"status": "ready", "snapshots": [{"page_number": 1, "screenshot_path": target_path}]},
        )
    except Exception as err:
        logger.exception("Failed to process image %s for conv %s", doc_name, conv_id)
        store.upsert_derived_document(
            conv_id,
            doc_name,
            document_path,
            {"status": "failed", "error": str(err), "snapshots": []},
        )


def _render_page_screenshots(doc_extractor, folder, document_ref, output_base) -> List[Dict[str, Any]]:
    """Render, resize and persist every page as PNG; return the snapshot metadata entries."""
    screenshots_response = doc_extractor.generate_pages_screenshots(
        document=document_ref,
        output_managed_folder=None,
        fetch_size=50,
    )
    entries: List[Dict[str, Any]] = []
    for page_index in range(screenshots_response.total_count):
        screenshot = screenshots_response.fetch_screenshot(page_index)
        image_bytes = _decode_image(screenshot.image)
        resized = resize_with_ratio(Image.open(io.BytesIO(image_bytes)))
        out_buf = io.BytesIO()
        resized.save(out_buf, format="PNG")
        out_buf.seek(0)

        page_number = page_index + 1
        target_path = f"{output_base}/page_{page_number}.png"
        folder.put_file(target_path, out_buf)
        entries.append({"page_number": page_number, "screenshot_path": target_path})
    return entries


def _run_screenshot_extraction(
        store, doc_extractor, folder, conv_id, doc_name, document_path,
        document_ref, output_base, existing_text_path,
) -> None:
    """Generate page screenshots for one document and record the outcome, keeping any text artifact."""
    # Carry an earlier text extraction forward, mirroring how the text branch keeps snapshots.
    preserved = {"text_path": existing_text_path} if existing_text_path else {}
    try:
        entries = _render_page_screenshots(doc_extractor, folder, document_ref, output_base)
        logger.info(
            "Derived document ready for conv=%s doc=%s (%d pages)",
            conv_id,
            doc_name,
            len(entries),
        )
    except Exception as err:
        logger.exception("Failed to process derived document %s for conv %s", doc_name, conv_id)
        store.upsert_derived_document(
            conv_id,
            doc_name,
            document_path,
            {"snapshots": [], "status": "failed", "error": str(err), **preserved},
        )
        return
    store.upsert_derived_document(
        conv_id,
        doc_name,
        document_path,
        {"snapshots": entries, "status": "ready", **preserved},
    )


def _run_text_extraction(
        store, doc_extractor, folder, conv_id, doc_name, document_path,
        document_ref, output_base, existing_snapshots,
) -> None:
    """Extract structured text for one document and record the outcome, keeping existing snapshots."""
    # Imported here (not at module top) so these settings are only read when needed.
    from backend.config import get_text_extraction_type, get_conversation_vision_llm, get_default_llm_id

    # Already in API format: "IGNORE", "OCR" or "VLM_ANNOTATE".
    image_handling_mode = get_text_extraction_type()
    # _extract_document_text forwards llm_id only in VLM_ANNOTATE mode.
    if image_handling_mode == "VLM_ANNOTATE":
        llm_id = get_conversation_vision_llm()
    else:
        llm_id = get_default_llm_id()

    text_metadata = _extract_document_text(
        doc_extractor,
        document_ref,
        folder,
        output_base,
        image_handling_mode=image_handling_mode,
        llm_id=llm_id,
    )
    if not text_metadata:
        store.upsert_derived_document(
            conv_id,
            doc_name,
            document_path,
            {"snapshots": existing_snapshots, "status": "failed", "error": "text_extraction_failed"},
        )
        return
    store.upsert_derived_document(
        conv_id,
        doc_name,
        document_path,
        {"snapshots": existing_snapshots, "status": "ready", **text_metadata},
    )


def process_derived_documents(
        store,
        conv_id: str,
        user_id: str,
        attachments: List[Dict[str, Any]],
        *,
        extraction_mode: str = "pagesScreenshots",
) -> None:
    """
    Populate derived document artifacts for each attachment.

    Every document already recorded for the conversation, plus every new attachment,
    is visited. Documents whose artifacts are already ready are skipped. The rest are
    processed according to their extension:

    * ``txt``/``md``/``html`` are copied verbatim as text;
    * ``png``/``jpg``/``jpeg`` are resized and stored as a single snapshot;
    * anything else goes through the DocumentExtractor, either as page screenshots
      or as structured text depending on ``extraction_mode``.

    Outcomes (including failures) are written with ``store.upsert_derived_document``.

    Args:
        store: Persistence object exposing ``get_derived_documents`` and
            ``upsert_derived_document``.
        conv_id: Conversation whose documents are processed.
        user_id: Currently unused; kept for interface compatibility.
        attachments: New attachment dicts (``document_path`` / ``document_name`` keys).
        extraction_mode: "pagesScreenshots" (default) to generate screenshots, or
            "pagesText" for text-only extraction.
    """
    if not attachments:
        return
    if DocumentExtractor is None or ManagedFolderDocumentRef is None:
        logger.warning("DocumentExtractor not available; skipping derived document processing")
        return

    folder_id = get_uploads_managedfolder_id()
    client = dataiku.api_client()
    project = client.get_default_project()
    folder = project.get_managed_folder(folder_id)
    doc_extractor = DocumentExtractor(client, project.project_key)

    existing_docs = {
        doc.get("document_path"): doc
        for doc in store.get_derived_documents(conv_id) or []
    }
    attachment_map = {att.get("document_path"): att for att in attachments if att.get("document_path")}

    # Revisit previously stored documents as well as the new attachments.
    for document_path in set(existing_docs) | set(attachment_map):
        doc_record = existing_docs.get(document_path)
        doc_name = (attachment_map.get(document_path) or doc_record or {}).get("document_name")
        if not doc_name or not document_path:
            continue

        file_ext = os.path.splitext(doc_name.lower())[1].lstrip(".")
        # A stored record may carry metadata=None; treat it as empty rather than crashing.
        doc_meta = (doc_record or {}).get("metadata") or {}
        is_ready = doc_meta.get("status") == "ready"
        existing_snapshots = doc_meta.get("snapshots") or []
        existing_text_path = doc_meta.get("text_path")
        output_base = _output_base_for(document_path)

        # Simple text files: copied verbatim; skip when already ready.
        if file_ext in {"txt", "md", "html"}:
            if is_ready and existing_text_path:
                continue
            _ensure_folder_prefix(folder, output_base)
            _copy_plain_text_document(store, folder, conv_id, doc_name, document_path, output_base)
            continue

        # Images: one resized snapshot; skip when already ready.
        if file_ext in {"png", "jpg", "jpeg"}:
            if is_ready and existing_snapshots:
                continue
            _ensure_folder_prefix(folder, output_base)
            _store_resized_image_document(
                store, folder, conv_id, doc_name, document_path, output_base, file_ext
            )
            continue

        # Default (PDF / DOCX / PPTX ...): decide whether work is needed before touching the folder.
        if extraction_mode == "pagesText":
            if existing_text_path:
                continue
        elif is_ready and existing_snapshots:
            continue

        _ensure_folder_prefix(folder, output_base)
        document_ref = ManagedFolderDocumentRef(document_path, folder_id)
        if extraction_mode == "pagesText":
            _run_text_extraction(
                store, doc_extractor, folder, conv_id, doc_name, document_path,
                document_ref, output_base, existing_snapshots,
            )
        else:
            _run_screenshot_extraction(
                store, doc_extractor, folder, conv_id, doc_name, document_path,
                document_ref, output_base, existing_text_path,
            )


def get_structured_documents(store, conv_id: str, user_id: str) -> Dict[str, Dict[str, Any]]:
    """Return derived document metadata keyed by source path.

    For each derived document of the conversation, this function:

    * reads its text artifact (``metadata["text_path"]``), if any, from the uploads
      managed folder as UTF-8, replacing undecodable bytes;
    * numbers its screenshot snapshots with ids that are unique across all documents.

    If a document has status "failed" and no text, a short human-readable failure
    message is used as its text.

    Args:
        store: Object exposing ``get_derived_documents(conv_id)``.
        conv_id: Conversation identifier.
        user_id: Currently unused; kept for interface compatibility.

    Returns:
        Mapping ``document_path -> {"name", "snapshots", "status", "error", "text",
        "document_path"}``. It is empty when the store returns no documents.
    """
    derived_docs = store.get_derived_documents(conv_id)
    documents_out: Dict[str, Dict[str, Any]] = {}
    if not derived_docs:
        return documents_out

    uploads_folder = None  # opened lazily, once, only if some document has a text artifact
    next_snapshot_id = 1  # ids are unique across every document of the conversation

    for doc in derived_docs:
        metadata = doc.get("metadata") or {}
        document_path = doc.get("document_path")
        doc_name = doc.get("document_name")
        status = metadata.get("status")
        error_msg = metadata.get("error")

        text_content = ""
        text_path = metadata.get("text_path")
        if text_path:
            try:
                if uploads_folder is None:
                    import dataiku
                    uploads_folder = dataiku.Folder(get_uploads_managedfolder_id())
                with uploads_folder.get_download_stream(text_path) as stream:
                    text_content = stream.read().decode("utf-8", errors="replace")
            except Exception:
                # Best-effort: an unreadable artifact yields empty text, but leave a trace.
                logger.warning(
                    "Could not read text artifact %s for document %s", text_path, doc_name, exc_info=True
                )
                text_content = ""

        if not text_content and status == "failed":
            human_doc_name = doc_name or document_path or "Document"
            reason = f": {error_msg}" if error_msg else ""
            text_content = (
                f"Extraction failed for \"{human_doc_name}\"{reason}. "
                "No readable content is available for this file."
            )

        prepared_snapshots: List[Dict[str, Any]] = []
        for snap in metadata.get("snapshots") or []:
            path = snap.get("screenshot_path")
            if not path:
                continue
            prepared_snapshots.append(
                {
                    "id": next_snapshot_id,
                    "file_name": path.rsplit("/", 1)[-1],
                    "page": snap.get("page") or snap.get("page_number"),
                    "screenshot_path": path,
                }
            )
            next_snapshot_id += 1

        documents_out[document_path] = {
            "name": doc_name,
            "snapshots": prepared_snapshots,
            "status": status,
            "error": error_msg,
            "text": text_content,
            "document_path": document_path,
        }

    return documents_out


def count_conversation_images(store, conv_id: str) -> int:
    """
    Count the total number of images (snapshots) generated for all documents in a conversation.

    Args:
        store: The database store instance (exposes ``get_derived_documents``).
        conv_id: The conversation ID.

    Returns:
        Total count of image snapshots across all derived documents. It is 0 when the
        store returns no documents (including ``None``).
    """
    # Same None-tolerance as the other readers of get_derived_documents in this module.
    derived_docs = store.get_derived_documents(conv_id) or []
    total_images = sum(
        len((doc.get("metadata") or {}).get("snapshots") or [])
        for doc in derived_docs
    )
    logger.debug(
        "Conversation %s has %d total images across %d documents",
        conv_id,
        total_images,
        len(derived_docs),
    )
    return total_images


def predict_new_image_count(attachments: List[Dict[str, Any]], folder_id: str) -> int:
    """
    Predict the number of images that would be generated from new attachments.

    The prediction mirrors ``process_derived_documents``:

    * ``txt``/``md``/``html`` files produce no snapshots and count as 0;
    * ``png``/``jpg``/``jpeg`` images produce exactly one snapshot and count as 1;
    * other documents count as the page total reported by the DocumentExtractor.

    Attachments whose page count cannot be obtained are logged and counted as 0.

    Args:
        attachments: List of attachment metadata dicts with ``document_path`` (and,
            optionally, ``document_name``) keys.
        folder_id: The managed folder ID where documents are stored.

    Returns:
        Predicted total number of images (pages) across all attachments. It is 0 when
        there are no attachments or the DocumentExtractor is unavailable.
    """
    if not attachments:
        return 0
    if DocumentExtractor is None or ManagedFolderDocumentRef is None:
        logger.warning("DocumentExtractor not available; cannot predict image count")
        return 0

    client = dataiku.api_client()
    project = client.get_default_project()
    doc_extractor = DocumentExtractor(client, project.project_key)

    total_pages = 0
    for attachment in attachments:
        path = attachment.get("document_path")
        if not path:
            continue

        label = attachment.get("document_name") or path
        file_ext = os.path.splitext(label.lower())[1].lstrip(".")
        if file_ext in {"txt", "md", "html"}:
            continue  # stored as text only; no snapshots are created
        if file_ext in {"png", "jpg", "jpeg"}:
            total_pages += 1  # stored as a single resized snapshot
            continue

        try:
            document_ref = ManagedFolderDocumentRef(path, folder_id)
            # fetch_size=1 keeps the response small; only total_count is read here.
            # NOTE(review): whether the backend still renders every page for this call
            # is not visible from this code.
            response = doc_extractor.generate_pages_screenshots(
                document=document_ref,
                output_managed_folder=None,
                fetch_size=1,
            )
            page_count = response.total_count
            total_pages += page_count
            logger.debug("Document %s has %s pages", label, page_count)
        except Exception as err:
            # Continue with other documents even if one fails.
            logger.warning("Failed to extract page count for %s: %s", path, err)
            continue

    logger.info("Predicted %d total images from %d new attachments", total_pages, len(attachments))
    return total_pages
