import json

from transformers import pipeline
from typing import List
from typing import Dict

from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.types import ProcessSinglePromptCommand
from dataiku.huggingface.types import ProcessSinglePromptResponseTextFull
from dataiku.huggingface.env_collector import extract_model_architecture


class ModelPipelineTextClassif(ModelPipelineBatching[ProcessSinglePromptCommand, ProcessSinglePromptResponseTextFull]):
    """Batched text-classification pipeline backed by ``transformers.pipeline``."""

    # Models for which the pipeline is created without device_map="auto".
    # Each entry holds the model identifier and its encoded cache-directory name:
    # ModelCacheService.java:: ModelStorageDefinition calls PartitioningUtils.encode which replaces / with _2f
    # (the entries below also show "_" appearing as "_5f" in the encoded form).
    no_device_map_auto_support = [
        {
            'model_path': 'nlptown/bert-base-multilingual-uncased-sentiment',
            'model_cache_path': 'nlptown_2fbert-base-multilingual-uncased-sentiment',
        },
        {
            'model_path': 'distilbert/distilbert-base-uncased-finetuned-sst-2-english',
            'model_cache_path': 'distilbert_2fdistilbert-base-uncased-finetuned-sst-2-english',
        },
        {
            'model_path': 'unitary/toxic-bert',
            'model_cache_path': 'unitary_2ftoxic-bert',
        },
        {
            'model_path': 'citizenlab/distilbert-base-multilingual-cased-toxicity',
            'model_cache_path': 'citizenlab_2fdistilbert-base-multilingual-cased-toxicity',
        },
        {
            'model_path': 'EIStakovskii/french_toxicity_classifier_plus_v2',
            'model_cache_path': 'EIStakovskii_2ffrench_5ftoxicity_5fclassifier_5fplus_5fv2',
        },
        {
            'model_path': 'meta-llama/Prompt-Guard-86M',
            'model_cache_path': 'meta-llama_2fPrompt-Guard-86M',
        },
    ]

    def __init__(self, model_path, model_kwargs, batch_size):
        """Create the text-classification pipeline and record tracking data.

        :param model_path: model identifier or local path handed to ``pipeline``.
        :param model_kwargs: forwarded to ``pipeline`` as ``model_kwargs``.
        :param batch_size: forwarded to the parent class constructor.
        """
        super().__init__(batch_size=batch_size)

        # bert based models do not support device_map="auto": a model is excluded
        # when its path equals a listed identifier or contains the encoded cache name.
        device_map_auto_unsupported = any(
            model_path == entry["model_path"] or entry["model_cache_path"] in model_path
            for entry in self.no_device_map_auto_support
        )

        pipeline_kwargs = {"model": model_path, "model_kwargs": model_kwargs}
        if not device_map_auto_unsupported:
            pipeline_kwargs["device_map"] = "auto"
        self.task = pipeline("text-classification", **pipeline_kwargs)

        self.model_tracking_data["task"] = "text-classification"
        self.model_tracking_data["used_engine"] = "transformers"
        architecture = extract_model_architecture(self.task.model.config)
        self.model_tracking_data["model_architecture"] = architecture

    def _get_inputs(self, requests: List[ProcessSinglePromptCommand]) -> List[str]:
        """Return the ``"prompt"`` value of every request, in order."""
        prompts = []
        for single_request in requests:
            prompts.append(single_request["prompt"])
        return prompts

    def _run_inference(self, input_texts: List[str], tf_params: Dict):
        """Call the pipeline on ``input_texts``, passing ``tf_params`` as keyword arguments."""
        predictions = self.task(input_texts, **tf_params)
        return predictions

    def _get_params(self, request: ProcessSinglePromptCommand) -> Dict:
        """Build the pipeline call kwargs for ``request``.

        Returns ``{"top_k": 100}`` when the request's output mode is ``"ALL"``,
        otherwise an empty dict.
        """
        output_mode = request["settings"]["textClassificationOutputMode"]
        return {"top_k": 100} if output_mode == "ALL" else {}

    def _parse_response(self, response, request: ProcessSinglePromptCommand) -> ProcessSinglePromptResponseTextFull:
        """Convert a pipeline result into a ``{"text": <json string>}`` response.

        In ``"ALL"`` mode ``response`` is iterated as dicts with ``"label"`` and
        ``"score"`` keys and flattened into one label -> score mapping; in any
        other mode ``response`` is JSON-serialized unchanged.
        """
        if request["settings"]["textClassificationOutputMode"] != "ALL":
            # Any other mode: serialize the pipeline output as-is.
            return {"text": json.dumps(response)}

        # List of dict(label, score). Reformat to dict(label->score)
        scores_by_label = {item["label"]: item["score"] for item in response}
        return {"text": json.dumps(scores_by_label)}
