import logging

from transformers import T5Tokenizer, T5ForConditionalGeneration
from typing import List
from typing import Dict

from dataiku.huggingface.pipeline_text_gen import ModelPipelineTextGeneration
from dataiku.huggingface.types import ProcessSinglePromptCommand


# TODO @llms We are currently not using this model type - should be either deleted or a T5 model added in the backend
class ModelPipelineTextGenerationT5(ModelPipelineTextGeneration):
    """Text-generation pipeline backed by a T5 conditional-generation model.

    Loads a ``T5Tokenizer`` and a ``T5ForConditionalGeneration`` from a model
    path and generates one text per input prompt.
    """

    def _initialize_pipeline(self, model_path, model_kwargs):
        """Load the tokenizer and the model from ``model_path``.

        :param model_path: location passed to ``from_pretrained`` for both the
            tokenizer and the model.
        :param model_kwargs: optional extra keyword arguments forwarded to
            ``T5ForConditionalGeneration.from_pretrained``; may be ``None``.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        # "auto" is only a default: merging (instead of passing both
        # device_map= and **kwargs) lets the caller override device_map
        # without triggering a duplicate-keyword TypeError.
        kwargs = {"device_map": "auto"}
        if model_kwargs is not None:
            kwargs.update(model_kwargs)
        self.model = T5ForConditionalGeneration.from_pretrained(model_path, **kwargs)

    def _get_inputs(self, requests: List[ProcessSinglePromptCommand]) -> List[str]:
        """Extract the ``"prompt"`` entry of every request, in order."""
        return [request["prompt"] for request in requests]

    def _run_inference(self, input_texts: List[str], tf_params: Dict) -> List:
        """Generate one completion per input text.

        :param input_texts: prompts to run through the model.
        :param tf_params: keyword arguments forwarded to ``model.generate``.
        :return: a list of ``{"generated_text": <str>}`` dicts, one per
            generated sequence.
        """
        # TODO @llm: ensure no issue with temp = 0
        if not input_texts:
            # Nothing to tokenize or generate.
            return []

        # Pad so that prompts of different lengths can be stacked in a single
        # tensor, and follow the device the model was actually loaded on
        # rather than assuming CUDA is available.
        encoded = self.tokenizer(input_texts, return_tensors="pt", padding=True).to(self.model.device)

        # todo should we pop out prompts & promptMessages?
        # The attention mask keeps the padded positions from being attended to.
        outputs = self.model.generate(encoded.input_ids, attention_mask=encoded.attention_mask, **tf_params)

        if logging.root.isEnabledFor(logging.DEBUG):
            logging.debug("Response received from HF model: \n %s", outputs)

        # Drop special tokens (padding / end-of-sequence markers) so that only
        # the generated text is returned.
        decoded_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [{"generated_text": text} for text in decoded_texts]
