import tempfile
from typing import List, Optional, Union

from common.backend.constants import PROMPT_SEPARATOR_LENGTH
from common.backend.models.base import MediaSummary, UploadChainTypes, UploadFileError
from common.backend.utils.sql_timing import log_execution_time
from common.backend.utils.upload_utils import save_extracted_json
from common.llm_assist.logging import logger
from common.solutions.chains.summary.text_extraction_summary_chain import TextExtractionSummaryChain


@log_execution_time
def extract_plain_text(file_data: bytes) -> str:
    """Decode the raw bytes of a text file into a string.

    The payload is decoded as UTF-8. A leading byte-order mark is dropped and
    byte sequences that are not valid UTF-8 are replaced with U+FFFD instead
    of aborting the extraction.

    Args:
        file_data: Raw content of the file.

    Returns:
        The decoded text.

    Raises:
        Exception: With ``UploadFileError.PARSING_ERROR.value`` as message if
            decoding fails (for example when ``file_data`` has no usable
            ``decode`` method); the original error is chained as the cause.
    """
    try:
        # "utf-8-sig" decodes exactly like "utf-8" but strips a leading BOM,
        # so the extracted text does not start with a stray U+FEFF.
        return file_data.decode("utf-8-sig", errors="replace")
    except Exception as e:
        logger.exception(f"Error in extract_plain_text: {e}")
        raise Exception(UploadFileError.PARSING_ERROR.value) from e


@log_execution_time
def extract_docx_text(file_data: bytes) -> str:
    """Extract the text of a ``.docx`` file given as raw bytes.

    The bytes are written to a temporary file so that ``Docx2txtLoader`` can
    read them from a path. Every document returned by the loader is rendered
    as a separator line, a ``page:`` line and the document content, and the
    rendered sections are concatenated.

    Args:
        file_data: Raw content of the ``.docx`` file.

    Returns:
        The concatenated sections; an empty string if the loader returns no
        documents.

    Raises:
        Exception: With ``UploadFileError.PARSING_ERROR.value`` as message if
            loading or formatting fails; the original error is chained as the
            cause.
    """
    from langchain_community.document_loaders import Docx2txtLoader
    from langchain_core.documents.base import (
        Document,  # Lazy import to prevent 'langchain_core' heavy modules to be loaded
    )
    try:
        # NOTE(review): the loader opens the file by name while this handle is
        # still open. The tempfile documentation states that this works on
        # Unix but not on Windows - confirm the deployment target.
        with tempfile.NamedTemporaryFile(delete=True) as temp_file:
            temp_file.write(file_data)
            # Flush so the loader sees the complete payload on disk.
            temp_file.flush()
            loader = Docx2txtLoader(temp_file.name)
            document: List[Document] = loader.load()

        # The indentation inside the literal below is part of the emitted
        # text and is kept exactly as it was.
        return "".join(
            f"""{'-'*PROMPT_SEPARATOR_LENGTH}
        page: {page.metadata.get('page', 'Unknown')}
        {page.page_content}
        """
            for page in document
        )
    except Exception as e:
        logger.exception(f"Error in extract_docx_text: {e}")
        raise Exception(UploadFileError.PARSING_ERROR.value) from e


@log_execution_time
def extract_text_summary(
    file_path: str, file_data: bytes, secure_name: str, extension: str, language: Optional[str], begin_time: int
) -> MediaSummary:
    """Extract the text of a file, summarize it and persist the result as JSON.

    Args:
        file_path: Path of the file; stored in the summary and passed to
            ``save_extracted_json``.
        file_data: Raw content of the file.
        secure_name: File name handed to ``TextExtractionSummaryChain``.
        extension: File extension; ``"docx"`` selects the docx extractor, any
            other value falls back to plain-text decoding.
        language: Optional language handed to ``TextExtractionSummaryChain``.
        begin_time: Value stored under the ``"begin_time"`` key.

    Returns:
        The chain result extended with ``chain_type``, ``file_path``,
        ``preview``, ``begin_time`` and ``metadata_path``. The full extracted
        text is part of the saved JSON but is removed from the returned dict.

    Raises:
        Exception: With ``UploadFileError.PARSING_ERROR.value`` as message if
            the chain returns ``None``; errors from the extractors propagate.
    """
    extract = extract_docx_text if extension == "docx" else extract_plain_text
    extracted_text: str = extract(file_data)

    chain_result: Optional[MediaSummary] = TextExtractionSummaryChain(
        extracted_text, secure_name, language
    ).get_summary()
    if chain_result is None:
        raise Exception(UploadFileError.PARSING_ERROR.value)

    # No "summary" value marks the result as a long document, otherwise short.
    chain_result["chain_type"] = (
        UploadChainTypes.LONG_DOCUMENT.value
        if chain_result.get("summary") is None
        else UploadChainTypes.SHORT_DOCUMENT.value
    )

    # Work on a copy from here on; the full text is only needed for the JSON.
    media_summary = dict(chain_result)
    media_summary.update(
        {
            "file_path": file_path,
            "preview": None,
            "full_extracted_text": extracted_text,
            "begin_time": begin_time,
        }
    )

    media_summary["metadata_path"] = save_extracted_json(file_path, media_summary)
    media_summary.pop("full_extracted_text")
    return media_summary
