import asyncio
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging

logger = logging.getLogger(__name__)

class DocExtractionServer:
    """
    Command handler of the document extraction server.

    Commands are dicts carrying a "type" key:
      - "start" stores the settings found in the command,
      - "process-document-structured" is forwarded to a DoclingExtractorPipeline,
      - "process-document-raw" is forwarded to a RawExtractorPipeline.

    Both pipelines are created lazily, by the first command that needs them.
    """

    def __init__(self):
        self.started = False
        # Single worker thread, used to run `start` outside of the event loop thread
        self.executor = ThreadPoolExecutor(1)
        # Pipelines stay None until the first matching process-document command
        self._docling_pipeline = None
        self._raw_pipeline = None
        # Settings taken from the "start" command; None until `start` has run
        self.settings = None

    async def handler(self, command):
        """
        Handle one command and yield its single result.

        This is an async generator: it yields exactly one value per command
        (None for "start", the pipeline result for the process-document commands).

        :param command: dict with a "type" key selecting the action
        :raises KeyError: if the command has no "type" key
        :raises ValueError: if the command type is not one of the supported ones
        """
        command_type = command["type"]
        if command_type == "start":
            logger.info("Received start command: %s", command)
            yield await asyncio.get_running_loop().run_in_executor(
                self.executor, self.start, command
            )
        elif command_type == "process-document-structured":
            logger.info("received process-document-structured")
            logger.debug("Received command: %s", command)
            yield await self.docling_process_document(command)
        elif command_type == "process-document-raw":
            logger.info("received process-document-raw")
            logger.debug("Received command: %s", command)
            yield await self.raw_process_document(command)
        else:
            raise ValueError("Unknown command type: %s" % command_type)

    def start(self, start_command):
        """
        Store settings for lazy pipeline initialization on first document processing.
        Don't import docling stuff here : too slow and caused timeouts (sc-289562)

        :param start_command: the "start" command; its optional "settings" entry
                              is kept as-is (defaults to an empty dict when absent)
        :raises AssertionError: if the server was already started
        """
        logger.info("starting doc extraction server")
        # Raised explicitly instead of using `assert`, so that the check is not
        # stripped when Python runs with -O
        if self.started:
            raise AssertionError("Already started")
        self.settings = start_command.get("settings", {})
        self.started = True

    @staticmethod
    def _code_env_error(cause):
        """Build the error raised when a pipeline module cannot be imported."""
        return ModuleNotFoundError(
            f"Unable to initialize pipeline. Check the code environment: {cause}"
        )

    async def docling_process_document(self, process_document_command):
        """
        Process a document with the DoclingExtractorPipeline, creating it on first use.

        :param process_document_command: the command, passed untouched to the pipeline
        :return: whatever the pipeline's `process_document` returns
        :raises ModuleNotFoundError: if the docling extraction module cannot be imported
        """
        # There is no `await` between the `None` check and the assignment below, so
        # coroutines scheduled on the same event loop cannot run in between: the
        # initialization is atomic from their point of view.
        # NOTE(review): this relies on the handler only being driven from one event loop.
        if self._docling_pipeline is None:
            logger.info("First document processing - initializing DoclingExtractorPipeline")
            try:
                from dataiku.llm.docextraction.docling_extraction import DoclingExtractorPipeline
            except ModuleNotFoundError as e:
                raise self._code_env_error(e) from e

            self._docling_pipeline = DoclingExtractorPipeline(self.settings)
            logger.info("DoclingExtractorPipeline initialized successfully")

        return await self._docling_pipeline.process_document(process_document_command)

    async def raw_process_document(self, process_document_command):
        """
        Process a document with the RawExtractorPipeline, creating it on first use.

        :param process_document_command: the command, passed untouched to the pipeline
        :return: whatever the pipeline's `process_document` returns
        :raises ModuleNotFoundError: if the raw extraction module cannot be imported
        """
        # Same reasoning as for the docling pipeline: no `await` between the check
        # and the assignment, so the lazy initialization cannot be interleaved.
        if self._raw_pipeline is None:
            logger.info("First document processing - initializing RawExtractorPipeline")
            try:
                from dataiku.llm.docextraction.raw_extraction import RawExtractorPipeline
            except ModuleNotFoundError as e:
                raise self._code_env_error(e) from e

            self._raw_pipeline = RawExtractorPipeline(self.settings)
            logger.info("RawExtractorPipeline initialized successfully")

        return await self._raw_pipeline.process_document(process_document_command)


def log_exception(loop, context):
    """
    Exception handler for the asyncio event loop: log the error with its context.

    :param loop: the event loop reporting the error (not used here)
    :param context: dict describing the error; its "exception" entry is logged when
                    present, otherwise an Exception is built from its "message" entry
    """
    error = context.get("exception")
    if error is None:
        # No exception object was provided: wrap the message so that the
        # formatting below can be applied uniformly
        error = Exception(context.get("message"))

    formatted_lines = traceback.format_exception(type(error), error, error.__traceback__)
    stack_trace = ''.join(formatted_lines)
    logger.error(
        f"Caught exception: {error}\n"
        f"Context: {context}\n"
        f"Stack trace: {stack_trace}"
    )

if __name__ == "__main__":
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()  # Set LOGLEVEL=DEBUG to debug
    logging.basicConfig(level=LOGLEVEL,
                        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')

    debugging.install_handler()
    watch_stdin()

    async def start_server():
        """Connect to the Java link and serve commands with a DocExtractionServer."""
        # get_running_loop() is the documented way to reach the loop from inside a
        # coroutine; it is also what DocExtractionServer.handler uses
        asyncio.get_running_loop().set_exception_handler(log_exception)

        # Connection parameters of the Java link
        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        server = DocExtractionServer()
        await link.connect()
        await link.serve(server.handler)

    # debug=True runs the event loop in asyncio debug mode
    asyncio.run(start_server(), debug=True)
