import asyncio
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor


from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging

logger = logging.getLogger("docextraction_server")

class DocExtractionServer:
    """Command handler for the document extraction kernel.

    Two command types are understood: ``start`` builds the extraction pipeline,
    and ``process-document`` forwards a command to that pipeline.
    """

    def __init__(self):
        self.started = False
        # Single worker thread used to run the blocking ``start`` off the event loop.
        self.executor = ThreadPoolExecutor(1)
        # Set by ``start``; stays None until a start command has succeeded.
        self.pipeline = None

    async def handler(self, command):
        """
        Dispatch one command and yield its single result.

        :param command: dict with a ``"type"`` key, either ``"start"`` or ``"process-document"``
        :raises Exception: if the command type is not recognised
        """
        command_type = command["type"]
        if command_type == "start":
            logger.info("Received start command: %s", command)
            # ``start`` performs a slow import, so it runs in the executor
            # rather than on the event loop. It returns None, which is what gets yielded.
            yield await asyncio.get_running_loop().run_in_executor(
                self.executor, self.start, command
            )
        elif command_type == "process-document":
            logger.info("received process-document")
            logger.debug("Received command: %s", command)
            yield await self.process_document(command)
        else:
            raise Exception("Unknown command type: %s" % command_type)

    def start(self, start_command):
        """
        Docling import is very long, we take the opportunity of loading the module during the start command

        :param start_command: dict; its optional ``"settings"`` entry (default ``{}``) is passed to the pipeline
        :raises AssertionError: if the server has already been started
        :raises ModuleNotFoundError: if the docling extraction module cannot be imported
        """
        logger.info("starting doc extraction server")
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type and message are kept unchanged for existing callers.
        if self.started:
            raise AssertionError("Already started")

        try:
            from dataiku.llm.docextraction.docling_extraction import DoclingExtractorPipeline
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"Unable to start the kernel. Check the code environment to ensure all packages are correctly installed : {e}"
            ) from e

        self.pipeline = DoclingExtractorPipeline(start_command.get("settings", {}))
        self.started = True

    async def process_document(self, process_document_command):
        """
        Forward a process-document command to the pipeline and return its result.

        :raises RuntimeError: if no pipeline exists yet (no successful ``start`` command)
        """
        if self.pipeline is None:
            # Without this guard the caller gets an opaque AttributeError on NoneType.
            raise RuntimeError(
                "Doc extraction server is not started: send a 'start' command before 'process-document'"
            )
        return await self.pipeline.process_document(process_document_command)

def log_exception(loop, context):
    """
    Event loop exception handler: log the error, its context and its stack trace.

    :param loop: the event loop reporting the error (not used here)
    :param context: dict describing the error; its ``"exception"`` entry is used when present,
                    otherwise an Exception is built from its ``"message"`` entry
    """
    error = context.get("exception")
    if error is None:
        # No exception object was supplied: wrap the message so it can be formatted the same way.
        error = Exception(context.get("message"))

    stack_trace = "".join(
        traceback.format_exception(type(error), error, error.__traceback__)
    )
    report = "\n".join(
        (
            f"Caught exception: {error}",
            f"Context: {context}",
            f"Stack trace: {stack_trace}",
        )
    )
    logger.error(report)

if __name__ == "__main__":
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()  # Set LOGLEVEL=DEBUG to debug
    logging.basicConfig(level=LOGLEVEL,
                        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')

    debugging.install_handler()
    # NOTE(review): presumably stops the process when stdin is closed - confirm in dataiku.base.utils
    watch_stdin()

    async def start_server():
        """Connect the AsyncJavaLink and serve commands with DocExtractionServer.handler."""
        # Inside a coroutine, get_running_loop() is the supported way to reach the
        # current loop; it is also what DocExtractionServer.handler uses.
        asyncio.get_running_loop().set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        server = DocExtractionServer()
        await link.connect()
        await link.serve(server.handler)

    # debug=True runs the event loop in asyncio debug mode.
    asyncio.run(start_server(), debug=True)
