import logging

from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Coroutine, Type

from dataiku.llm.python import BaseEmbeddingModel
from dataiku.llm.python.processing.base_processor import BaseProcessor
from dataiku.llm.python.types import EmbeddingResponse, ProcessSingleEmbeddingCommand, SimpleEmbeddingResponse
from dataikuapi.dss.llm import DSSLLMEmbeddingsResponse

logger = logging.getLogger(__name__)

class EmbeddingProcessor(BaseProcessor[BaseEmbeddingModel, ProcessSingleEmbeddingCommand, SimpleEmbeddingResponse]):
    """Processor specialization for embedding models.

    Supplies the embedding-specific pieces of ``BaseProcessor``: extracting the
    call arguments from a command, selecting the model's sync/async entry
    points, and normalizing whatever the model returned into a single shape.
    """

    def __init__(
        self,
        clazz: Type[BaseEmbeddingModel],
        executor: ThreadPoolExecutor,
        config: dict,
        pluginConfig: dict,
        trace_name: str,
    ):
        """Forward every argument unchanged to ``BaseProcessor.__init__``."""
        super().__init__(
            clazz=clazz,
            executor=executor,
            config=config,
            pluginConfig=pluginConfig,
            trace_name=trace_name,
        )

    def get_inference_params(self, command: ProcessSingleEmbeddingCommand) -> dict:
        """Build the keyword arguments for the model call from ``command``.

        :param command: the incoming command; must carry non-``None`` values
            for both ``"query"`` and ``"settings"``.
        :return: a dict with exactly the keys ``"query"`` and ``"settings"``.
        :raises Exception: if either key is absent or mapped to ``None``;
            ``"query"`` is checked before ``"settings"``.
        """
        params = {}
        for field in ("query", "settings"):
            value = command.get(field)
            if value is None:
                raise Exception(f"'{field}' missing from command {command}")
            params[field] = value
        return params

    def get_async_inference_func(self) -> Callable[..., Coroutine]:
        """Return the ``aprocess`` bound method of the wrapped model instance."""
        # NOTE(review): self._instance is not assigned in this class; presumably
        # BaseProcessor creates it from `clazz` - confirm.
        return self._instance.aprocess

    def get_sync_inference_func(self) -> Callable:
        """Return the ``process`` bound method of the wrapped model instance."""
        return self._instance.process

    def parse_raw_response(self, raw_response: EmbeddingResponse) -> SimpleEmbeddingResponse:
        """Normalize the value returned by user code.

        Accepted inputs, tested in this order:

        * a ``list`` - wrapped as ``{"embedding": <list>}``;
        * a ``DSSLLMEmbeddingsResponse`` - its ``_raw`` payload is returned;
        * a ``dict`` - returned as is, without validation.

        :raises Exception: for any other type.
        """
        # Be lenient: user implementations may hand back any of several shapes.
        parsed: SimpleEmbeddingResponse
        if isinstance(raw_response, list):
            # The list's elements are not inspected (they are not checked to be floats).
            parsed = {"embedding": raw_response}  # type: ignore
        elif isinstance(raw_response, DSSLLMEmbeddingsResponse):
            parsed = raw_response._raw
        elif isinstance(raw_response, dict):
            parsed = raw_response
        else:
            raise Exception("Unrecognized response type (%s)" % type(raw_response))
        return parsed
