import asyncio
import copy
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import cast, Type

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin, get_clazz_in_code
from dataiku.core import debugging
from dataiku.llm.python import BaseLLM
from dataiku.llm.python.types import StartAgentServerCommand
from dataiku.llm.python.processing.completion_processor import CompletionProcessor

logger = logging.getLogger(__name__)


def plugin_get_redacted_command(command):
    """Return a deep copy of *command* with plugin password parameters masked.

    Every key listed in ``command["pluginConfig"]["dkuPasswordParams"]`` that is
    present in the copy's ``config`` is replaced by ``"**redacted**"``, and the
    ``dkuPasswordParams`` list itself is removed from the copy. The input
    command is never modified.

    :param command: command dict; must contain a ``"config"`` key.
    :returns: the redacted deep copy.
    """
    assert "config" in command
    redacted_command = copy.deepcopy(command)
    config = redacted_command["config"]
    # ``or {}`` also covers an explicit ``None`` pluginConfig, which
    # ``.get("pluginConfig", {})`` alone would pass through and then crash on.
    plugin_config = redacted_command.get("pluginConfig") or {}

    # plugin params with type PASSWORD
    if password_params := plugin_config.get("dkuPasswordParams"):
        # A missing/None config has nothing to mask; don't crash the logger on it.
        if isinstance(config, dict):
            for key in password_params:
                if key in config:
                    config[key] = "**redacted**"
        # plugin_config is the copy's own dict here (it was truthy), so this
        # only strips the list from the redacted copy.
        del plugin_config["dkuPasswordParams"]

    return redacted_command


def get_redacted_command(command):
    """Return a version of *command* that is safe to write to the logs.

    Commands carrying a ``"config"`` key are passed to
    ``plugin_get_redacted_command``, which returns a redacted deep copy. Any
    other command is returned as-is (the very same object).
    """
    if "config" in command:
        return plugin_get_redacted_command(command)
    return command


class PythonLLMServer:
    """Serve agent commands received over the Java link.

    A ``start-agent-server`` command must come first. It resolves the BaseLLM
    implementation and builds the CompletionProcessor. Subsequent
    ``process-completion-query`` commands are delegated to that processor,
    either streamed or as a single response.
    """

    # Assigned by start(); declared here for type checkers only.
    processor: CompletionProcessor
    started = False
    # Class-level pool, created when the class is defined and shared by every
    # instance. It runs start() and is handed to the CompletionProcessor.
    executor = ThreadPoolExecutor(32)
    # Next run number used to tag completion-query log lines.
    run_counter = 1

    async def handler(self, command: StartAgentServerCommand):
        """Dispatch one command and yield its response(s).

        :param command: command dict; ``command["type"]`` selects the action.
        :raises Exception: for an unknown command type, or when a completion
            query arrives before the server has been started.
        """
        logger.info("PythonLLMServer handler received command: %s", get_redacted_command(command))

        command_type = command["type"]
        if command_type == "start-agent-server":
            # start() is a plain synchronous method; run it in the executor so
            # it does not block the event loop.
            yield await asyncio.get_running_loop().run_in_executor(
                self.executor, self.start, command
            )

        elif command_type == "process-completion-query":
            if not self.started:
                # Without this, the failure would be an opaque AttributeError
                # on the never-assigned `processor` attribute.
                raise Exception("Agent server is not started")
            # Reserve the run number before doing any work. Start and End lines
            # of the same query then always agree, even if several queries
            # overlap (e.g. if the link dispatches commands concurrently).
            run_id = self.run_counter
            self.run_counter += 1
            logger.info("\n===============  Start completion query - run %s ===============", run_id)
            if command["stream"]:
                async for resp in self.processor.process_query_stream(command):
                    yield resp
            else:
                yield await self.processor.process_query(command)
            logger.info("\n=============== End completion query - run %s ===============", run_id)
        else:
            raise Exception("Unknown command type: %s" % command_type)

    def start(self, start_command: StartAgentServerCommand):
        """Resolve the BaseLLM implementation and build the completion processor.

        The implementation is loaded from ``start_command["code"]`` when that
        key is present and not None. Otherwise ``start_command["pyClazz"]`` must
        name the built-in tools-using agent.

        :param start_command: the ``start-agent-server`` command dict.
        :returns: ``{"ok": True}`` once the processor is ready.
        :raises AssertionError: if the server was already started.
        :raises Exception: if no BaseLLM implementation can be resolved.
        """
        assert not self.started, "Already started"
        code = start_command.get("code")
        if code is not None:
            clazz = cast(Type[BaseLLM], get_clazz_in_code(code, BaseLLM, strict_module=True))
        else:
            py_clazz = start_command["pyClazz"]
            # NOTE(review): this identifier says "ToolUsingAgent" while the
            # imported class is "ToolsUsingAgent". It must match what the caller
            # sends, so it is kept unchanged; confirm against the Java side.
            if py_clazz == "dataiku.llm.python.tools_using.ToolUsingAgent":
                from .tools_using import ToolsUsingAgent
                clazz = ToolsUsingAgent
            else:
                raise Exception("Missing BaseLLM implementation")

        self.processor = CompletionProcessor(
            clazz,
            self.executor,
            start_command.get("config", {}),
            start_command.get("pluginConfig", {}),
            trace_name="DKU_AGENT_CALL",
        )
        # Flip the flag only after the processor exists, so handler() never
        # sees started=True without a processor.
        self.started = True
        return {"ok": True}

def log_exception(loop, context):
    """Log an error reported through the asyncio event loop's exception handler.

    :param loop: the event loop reporting the error. It is required by the
        handler signature but not used here.
    :param context: asyncio error context dict. Its ``"exception"`` entry is
        used when present; otherwise an ``Exception`` is built from
        ``"message"``.
    """
    reported = context.get("exception")
    # With no exception object attached, synthesize one so the stack-trace
    # formatting below works uniformly (it simply carries no traceback).
    exc = Exception(context.get("message")) if reported is None else reported
    stack_trace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    logger.error("Caught exception: %s\nContext: %s\nStack trace: %s", exc, context, stack_trace)


if __name__ == "__main__":
    # Process-wide logging at DEBUG level, tagging each line with pid/thread.
    log_format = '[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=log_format)
    debugging.install_handler()
    watch_stdin()

    async def start_server():
        """Connect to the Java side and serve commands through PythonLLMServer."""
        loop = asyncio.get_running_loop()
        loop.set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        llm_server = PythonLLMServer()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        await link.connect()
        await link.serve(llm_server.handler)

    # NOTE(review): asyncio debug mode is always enabled here; it adds runtime
    # overhead. Confirm this is intended outside development.
    asyncio.run(start_server(), debug=True)
