import logging
import torch

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
from typing import List
from typing import Dict

from dataiku.huggingface.pipeline_text_gen import ModelPipelineTextGeneration
from dataiku.huggingface.chat_template import ChatTemplateRenderer
from dataiku.huggingface.torch_utils import best_supported_dtype


class ModelPipelineTextGenerationMPT(ModelPipelineTextGeneration):
    """Text-generation pipeline for MPT-style causal language models.

    Builds a Hugging Face ``text-generation`` pipeline from a model path and
    runs batched inference with token id 0 used as both EOS and padding token.
    """

    def _initialize_pipeline(self, model_path, model_kwargs, model_settings, trust_remote_code):
        """Load config, model and tokenizer, then build the generation pipeline.

        Sets ``self.cuda_available``, ``self.task`` and
        ``self.chat_template_renderer``.

        Args:
            model_path: Location handed to the ``from_pretrained`` loaders.
            model_kwargs: Extra keyword arguments forwarded to
                ``AutoModelForCausalLM.from_pretrained``; ``None`` means none.
            model_settings: Mapping read for the ``'quantizationMode'`` key
                (treated as ``'NONE'`` when the key is absent).
            trust_remote_code: Forwarded to the config, model and tokenizer
                loaders.
        """
        self.cuda_available = torch.cuda.is_available()
        logging.info("CUDA available: %s", self.cuda_available)
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        if self.cuda_available:
            config.init_device = 'cuda'  # For fast initialization directly on GPU!

        kwargs = {} if model_kwargs is None else model_kwargs
        model = AutoModelForCausalLM.from_pretrained(
            model_path, config=config, trust_remote_code=trust_remote_code, **kwargs
        )

        # Forward trust_remote_code here as well so the tokenizer is loaded
        # under the same policy as the config and the model.
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, padding_side='left', trust_remote_code=trust_remote_code
        )

        quantization_mode = model_settings.get('quantizationMode', 'NONE')
        pipeline_kwargs = {}
        # The pipeline is only pinned to the GPU for non-quantized models.
        # NOTE(review): presumably quantized models are already placed on their
        # device by the loader and must not be moved - confirm.
        if self.cuda_available and quantization_mode == 'NONE':
            pipeline_kwargs['device'] = 'cuda'
        self.task = pipeline('text-generation', model=model, tokenizer=tokenizer, **pipeline_kwargs)

        # Fix an issue in transformers: "Pipeline with tokenizer without pad_token cannot do batching"
        # TODO: use the common ModelPipelineTextGeneration._fixup_tokenizer_if_needed()
        # once the tokenizer used has a proper value for eos_token.
        if self.task.tokenizer.pad_token_id is None:
            self.task.tokenizer.pad_token_id = 0

        self.chat_template_renderer = ChatTemplateRenderer(self.task.tokenizer, 'TEXT_GENERATION_MPT')

    def _run_inference(self, input_texts: List[str], tf_params: Dict) -> List:
        """Run the pipeline on ``input_texts`` and return its raw responses.

        Args:
            input_texts: Prompts passed to the pipeline in a single call.
            tf_params: Extra keyword arguments forwarded to the pipeline call.
                NOTE(review): it must not contain ``eos_token_id`` or
                ``pad_token_id``; both are passed explicitly below, so a
                duplicate would raise ``TypeError``.

        Returns:
            Whatever the pipeline call returns, unmodified.
        """
        # Autocast is only enabled when CUDA was detected at initialization;
        # on CPU-only hosts torch would otherwise warn and disable it itself.
        with torch.autocast('cuda', dtype=best_supported_dtype(), enabled=self.cuda_available):
            responses = self.task(input_texts, eos_token_id=0, pad_token_id=0, **tf_params)

        # Lazy %-formatting: the (potentially large) response list is only
        # rendered to text when DEBUG logging is actually enabled.
        logging.debug("Response received from HF model: \n %s", responses)

        return responses
