import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import traceback
from typing import Type

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin, get_clazz_in_code
from dataiku.core import debugging
from dataiku.llm.python.custom.base_model import BaseModel
from dataiku.llm.python.processing.base_processor import BaseProcessor
from dataiku.llm.python.processing.completion_processor import CompletionProcessor
from dataiku.llm.python.processing.embeddings_processor import EmbeddingProcessor
from dataiku.llm.python.processing.image_generation_processor import ImageGenerationProcessor
from dataiku.llm.python import BaseLLM
from dataiku.llm.python import BaseEmbeddingModel
from dataiku.llm.python import BaseImageGenerationModel
from dataiku.llm.python.exception import RetryableException
from dataiku.llm.python.types import StartCustomLLMServerCommand

logger = logging.getLogger(__name__)


class LLMPluginServer:
    """Dispatch commands received over the Java link to a custom LLM processor.

    The server first expects a "start-custom-llm-server" command, which builds
    the processor from the code carried by that command. Query commands are
    then forwarded to that processor.
    """

    # Processor built by start(); not set until start() has succeeded.
    processor: BaseProcessor
    # True once start() has completed successfully.
    started: bool
    # Thread pool used to run start() off the event loop; also handed to the processor.
    executor: ThreadPoolExecutor

    def __init__(self):
        self.started = False
        self.executor = ThreadPoolExecutor(32)

    async def handler(self, command):
        """Handle one command and yield its response(s).

        Args:
            command: mapping with a "type" entry; query commands may also carry
                a "stream" entry selecting the streaming code path.

        Yields:
            For a start command, the value returned by start(). For a query
            command, every item produced by the processor's
            process_query_stream() when "stream" is truthy, otherwise the
            single result of process_query().

        Raises:
            AssertionError: a query command arrived before the server was started.
            Exception: the command type is not one of the handled types.
        """
        command_type = command["type"]
        logger.info("PythonCustomLLMServer handler received command of type: %s", command_type)

        if command_type == "start-custom-llm-server":
            # start() is a plain synchronous method, so it is run in the thread
            # pool instead of directly on the event loop.
            loop = asyncio.get_running_loop()
            yield await loop.run_in_executor(self.executor, self.start, command)
        elif command_type in ("process-completion-query", "process-embedding-query", "process-image-generation-query"):
            # Raised explicitly rather than with `assert` so that the check is
            # still enforced when Python runs with -O; the exception type and
            # message are the ones `assert` used to produce.
            if not self.started:
                raise AssertionError("Not started")
            if command.get("stream", False):
                async for resp in self.processor.process_query_stream(command):
                    yield resp
            else:
                yield await self.processor.process_query(command)
        else:
            raise Exception(f"Command '{command_type}' is not implemented")

    def start(self, start_command: StartCustomLLMServerCommand):
        """Build the processor from the start command and mark the server as started.

        Args:
            start_command: the "start-custom-llm-server" command; it must carry
                a non-None "code" entry.

        Returns:
            The dict ``{"ok": True}``.

        Raises:
            AssertionError: the server has already been started.
            Exception: the command has no "code" entry, or it is None.
        """
        # Explicit raise instead of `assert`: under -O an assert would vanish
        # and a second start command would silently replace the processor.
        if self.started:
            raise AssertionError("Already started")

        if "code" not in start_command or start_command["code"] is None:
            raise Exception("Code missing.")

        self.processor = self._get_processor(command=start_command)
        self.started = True
        return {"ok": True}

    def _get_processor(self, command: StartCustomLLMServerCommand) -> BaseProcessor:
        """Create the processor matching the capability declared in the command.

        The capability selects the base class the user code is expected to
        implement, the trace name, and the processor type to instantiate.

        Args:
            command: start command providing the "capability", "config",
                "pluginConfig" and "code" entries.

        Returns:
            A processor constructed with the class found in the code, this
            server's executor, both configs and the trace name.

        Raises:
            Exception: the capability is not one of the supported values.
        """
        # We check what capability the model has so we can decide which class it's supposed to implement
        capability = command["capability"]
        processor_ctor: Type[BaseProcessor]
        parent_class: Type[BaseModel]
        trace_name: str
        if capability in ("TEXT_COMPLETION_MULTIMODAL", "TEXT_COMPLETION"):
            parent_class = BaseLLM
            trace_name = "DKU_LLM_PLUGIN_COMPLETION_CALL"
            processor_ctor = CompletionProcessor

        elif capability in ("TEXT_EMBEDDING", "TEXT_IMAGE_EMBEDDING_EXTRACTION"):
            parent_class = BaseEmbeddingModel
            trace_name = "DKU_LLM_PLUGIN_EMBEDDING_CALL"
            processor_ctor = EmbeddingProcessor

        elif capability == "IMAGE_GENERATION":
            parent_class = BaseImageGenerationModel
            trace_name = "DKU_LLM_PLUGIN_IMAGE_GENERATION_CALL"
            processor_ctor = ImageGenerationProcessor

        else:
            raise Exception(f"Model capability '{capability}' not supported.")

        config = command['config']
        plugin_config = command['pluginConfig']
        code = command['code']

        # We retrieve the class from the code and pass it to the processor
        clazz = get_clazz_in_code(code, parent_class)  # strict_module?
        return processor_ctor(
            clazz=clazz,
            executor=self.executor,
            config=config,
            pluginConfig=plugin_config,
            trace_name=trace_name,
        )

def log_exception(loop, context):
    """Log an error reported through an event-loop exception-handler context.

    Args:
        loop: accepted to match the handler signature; not used here.
        context: mapping that may hold an "exception" entry and/or a
            "message" entry. When there is no exception, one is built from
            the message so that there is always something to format.
    """
    error = context.get("exception")
    if error is None:
        error = Exception(context.get("message"))

    formatted_trace = ''.join(
        traceback.format_exception(type(error), error, error.__traceback__)
    )
    logger.error(
        f"Caught exception: {error}\n"
        f"Context: {context}\n"
        f"Stack trace: {formatted_trace}"
    )

if __name__ == "__main__":
    # Script entry point: set up logging and process-level hooks, then serve
    # commands over the Java link until the link stops serving.
    logging.basicConfig(
        level=logging.DEBUG,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )
    # NOTE(review): presumably installs a debugging hook for this process -
    # confirm in dataiku.core.debugging.
    debugging.install_handler()
    # NOTE(review): presumably watches stdin to detect the parent process going
    # away - confirm in dataiku.base.utils.
    watch_stdin()

    async def start_server():
        """Connect to the Java side and serve commands with an LLMPluginServer."""
        # Route errors reported to the event loop through log_exception.
        running_loop = asyncio.get_running_loop()
        running_loop.set_exception_handler(log_exception)

        # Connection parameters are parsed by parse_javalink_args().
        port, secret, server_cert = parse_javalink_args()
        java_link = AsyncJavaLink(port, secret, server_cert=server_cert)
        plugin_server = LLMPluginServer()
        await java_link.connect()
        await java_link.serve(plugin_server.handler)

    # debug=True runs the event loop in asyncio debug mode.
    asyncio.run(start_server(), debug=True)
