import base64
from concurrent.futures import ThreadPoolExecutor
import logging
from typing import Callable, Coroutine, Type, cast

from dataiku.llm.python import BaseImageGenerationModel
from dataiku.llm.python.processing.base_processor import BaseProcessor
from dataiku.llm.python.types import ImageGenerationResponse, PluginImageGenerationResponse, ProcessSingleImageGenerationCommand
from dataikuapi.dss.llm import DSSLLMImageGenerationResponse

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

class ImageGenerationProcessor(BaseProcessor[BaseImageGenerationModel, ProcessSingleImageGenerationCommand, ImageGenerationResponse]):
    """Processor wiring a plugin ``BaseImageGenerationModel`` into ``BaseProcessor``.

    It exposes the model instance's ``process`` / ``aprocess`` methods as the
    sync / async inference functions, extracts the ``"query"`` from each
    command, and normalizes the many shapes a plugin may return into an
    ``ImageGenerationResponse``.
    """

    def __init__(self, clazz: Type[BaseImageGenerationModel], executor: ThreadPoolExecutor, config: dict, pluginConfig: dict, trace_name: str):
        # All arguments are forwarded unchanged to BaseProcessor.
        super().__init__(clazz, executor, config, pluginConfig, trace_name)

    def get_inference_params(self, command: ProcessSingleImageGenerationCommand) -> dict:
        """Build the parameters for the model's inference call from ``command``.

        :param command: a single image-generation command; must hold a ``"query"`` entry.
        :returns: ``{"query": <query>}`` (presumably passed as keyword arguments
            to ``process`` / ``aprocess`` by ``BaseProcessor`` — confirm there).
        :raises Exception: if ``"query"`` is absent from ``command`` or is ``None``.
        """
        query = command.get("query", None)
        if query is None:
            raise Exception(f"'query' missing from command {command}")
        return {"query": query}

    def get_async_inference_func(self) -> Callable[..., Coroutine]:
        """Return the model instance's ``aprocess`` coroutine function."""
        # NOTE(review): self._instance is presumably created by BaseProcessor
        # from `clazz`; it is not set in this class.
        return self._instance.aprocess

    def get_sync_inference_func(self) -> Callable:
        """Return the model instance's blocking ``process`` method."""
        return self._instance.process

    @staticmethod
    def _to_image_data(item):
        """Convert a single image value to the string stored under ``"data"``.

        ``str`` values are returned unchanged (assumed to already be encoded,
        presumably base64). Bytes-like values (``bytes``, ``bytearray``,
        ``memoryview``) are base64-encoded. Any other value yields ``None`` so
        the caller can decide how to report it.
        """
        if isinstance(item, str):
            return item
        if isinstance(item, (bytes, bytearray, memoryview)):
            return base64.b64encode(item).decode("utf8")
        return None

    def parse_raw_response(self, raw_response: PluginImageGenerationResponse) -> ImageGenerationResponse:
        """Normalize whatever the plugin model returned into an ``ImageGenerationResponse``.

        Accepted shapes:

        * a single image as ``str`` (used as-is) or bytes-like (base64-encoded),
          which gives ``{"images": [{"data": ...}]}``;
        * a ``list`` / ``tuple`` of such images. Each element is converted on
          its own, so mixed ``str`` / bytes lists are handled, and an empty
          sequence gives ``{"images": []}``;
        * a ``DSSLLMImageGenerationResponse``, whose raw payload is returned;
        * a ``dict``, returned unchanged.

        :raises Exception: if the response, or an element of a response
            sequence, is of an unsupported type.
        """
        # Try to eat anything the user can throw at us.
        single = self._to_image_data(raw_response)
        if single is not None:
            return {"images": [{"data": single}]}

        if isinstance(raw_response, (list, tuple)):
            images = []
            # Convert every element rather than trusting the first one's type.
            # A mixed list would otherwise leak raw bytes or crash in b64encode.
            for index, item in enumerate(raw_response):
                data = self._to_image_data(item)
                if data is None:
                    raise Exception(
                        "Unrecognized image type (%s) at index %d of response list" % (type(item), index)
                    )
                images.append({"data": data})
            return {"images": images}

        if isinstance(raw_response, DSSLLMImageGenerationResponse):
            return raw_response._raw

        if isinstance(raw_response, dict):
            return raw_response

        raise Exception("Unrecognized response type (%s)" % type(raw_response))
