import logging

from abc import abstractmethod
from typing import List
from typing import Dict

from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.types import ProcessSinglePromptCommand
from dataiku.huggingface.types import ProcessSinglePromptResponseTextFull


class ModelPipelineSummarization(ModelPipelineBatching[ProcessSinglePromptCommand, ProcessSinglePromptResponseTextFull]):
    """Base class for the summarization pipelines to implement.

    Subclasses implement `_run_summarization`. When an input does not fit in the
    model context length, this class falls back to "layered" summarization: the text
    is split into overlapping token chunks, each chunk is summarized, the partial
    summaries are joined, and the process is repeated until the joined text fits.
    """

    def __init__(self, tokenizer, context_length, batch_size):
        super().__init__(batch_size=batch_size)
        self.tokenizer = tokenizer
        self.context_length = context_length

        # Layered summarization settings. They are (re)populated from the request
        # settings each time _get_params is called.
        self.special_tokens_safety_factor = None
        self.overlap_size = None
        self.max_num_split_levels = None

        self.used_engine = "transformers"
        self.model_tracking_data["task"] = "summarization"

    def _get_inputs(self, requests: List[ProcessSinglePromptCommand]) -> List[str]:
        """Extract the text to summarize from each request."""
        return [request["prompt"] for request in requests]

    def _get_params(self, request: ProcessSinglePromptCommand) -> Dict:
        """Build the keyword arguments forwarded to the summarization model.

        Also records the layered summarization settings of the request on the instance.
        :param request: request whose optional "settings" entry is read
        :return: model kwargs (`min_length` / `max_length` when present in the settings)
        :rtype: dict
        """
        # `or {}` also covers a "settings" key explicitly set to None.
        params = request.get("settings") or {}
        kwargs = {}

        if "summarizationMinTokens" in params:
            kwargs["min_length"] = params["summarizationMinTokens"]
        if "summarizationMaxTokens" in params:
            kwargs["max_length"] = params["summarizationMaxTokens"]

        # Not passed to the model, but used in our layered summarization.
        # NOTE(review): these are stored on the instance, so the values of the last
        # request processed by this method are the ones used afterwards.
        self.special_tokens_safety_factor = params.get("summarizationSpecialTokensSafetyFactor", None)
        self.overlap_size = params.get("summarizationNumOverlapTokens", None)
        self.max_num_split_levels = params.get("summarizationMaxNumSplitLevels", None)

        return kwargs

    def _run_inference(self, input_texts: List[str], tf_params: Dict):
        """Summarize a batch of texts, switching to layered summarization when needed.

        :param input_texts: texts to be summarized
        :param tf_params: keyword arguments forwarded to the summarization model
        :return: one model response per input text
        """
        # An empty batch has nothing to measure; max() on it would raise.
        if input_texts and self.context_length and self.tokenizer:
            max_input_length_tokens = max(len(ids) for ids in self.tokenizer(input_texts).input_ids)
            # TODO @llm : this is suboptimal because if just one of the batch items is too long we won't batch anything (besides layered_summarization splits)
            if max_input_length_tokens > self.context_length:
                logging.info("Summarization input is longer than model context length. Using layered summarization mode, skipping batching. Max input length: {input_length} tokens. Model context length: {context_length} tokens.".format(
                    input_length=max_input_length_tokens, context_length=self.context_length
                ))
                return [self._run_layered_summarization(prompt, **tf_params) for prompt in input_texts]

        return self._run_summarization(input_texts, **tf_params)

    @abstractmethod
    def _run_summarization(self, texts, **kwargs):
        """Run the summarization model without layered summarization.
        :param texts: texts to be summarized
        :type texts: list[str]
        :return: summarized texts
        :rtype: list[str]
        """
        pass

    def _get_chunking_settings(self):
        """Validate the layered summarization settings and derive the chunking parameters.
        :return: chunk size and overlap size, both in tokens
        :rtype: (int, int)
        :raises ValueError: if the settings are missing or cannot produce a valid split
        """
        safety_factor = self.special_tokens_safety_factor
        overlap = self.overlap_size
        if safety_factor is None or overlap is None:
            raise ValueError("Layered summarization requires the 'summarizationSpecialTokensSafetyFactor' and "
                             "'summarizationNumOverlapTokens' settings to be set.")

        # Keep some room in each chunk for the special tokens added at inference time.
        chunk_size = self.context_length - safety_factor
        if chunk_size <= 0:
            raise ValueError("Special tokens safety factor ({safety_factor}) must be smaller than the model context length ({context_length}).".format(
                safety_factor=safety_factor, context_length=self.context_length))
        # The window must move forward by at least one token, and a negative overlap would skip tokens.
        if not 0 <= overlap < chunk_size:
            raise ValueError("Number of overlap tokens ({overlap}) must be between 0 and the chunk size ({chunk_size}) excluded.".format(
                overlap=overlap, chunk_size=chunk_size))

        return chunk_size, overlap

    def _split_text(self, input_text):
        """Splits a text in several chunks fitting context_length
        :param input_text: text to be split
        :type input_text: str
        :return: one or several text splits along with their length
        :rtype: list[(str, int)]
        :raises ValueError: if a split is needed and the layered summarization settings are missing or invalid
        """
        input_ids = self.tokenizer(input_text, add_special_tokens=False).input_ids
        num_tokens = len(input_ids)

        if num_tokens <= self.context_length:
            return [(input_text, num_tokens)]

        chunk_size, overlap = self._get_chunking_settings()

        # A chunk starting at `start` only brings new tokens if `start + overlap < num_tokens`;
        # stopping there avoids a trailing chunk fully contained in the previous one.
        chunk_starts = range(0, num_tokens - overlap, chunk_size - overlap)
        token_splits = (input_ids[start:start + chunk_size] for start in chunk_starts)
        splits = [(self.tokenizer.decode(token_split), len(token_split)) for token_split in token_splits]

        logging.info("Input text too long to summarize in one go ({input_length} tokens; LLM context length {context_length} tokens)."
                     " Splitting into {num_splits} chunks (max chunk length {chunk_length} tokens, overlap {overlap} tokens).".format(
                        input_length=num_tokens, context_length=self.context_length,
                        num_splits=len(splits), chunk_length=chunk_size, overlap=overlap))

        return splits

    def _run_layered_summarization(self, input_text, num_splits_made=0, **kwargs):
        """
        Summarizes a long text by splitting it in several chunks and summarizing these chunks recursively
        until getting a combined text fitting the context_length.
        :param input_text: text to be summarized
        :type input_text: str
        :param num_splits_made: inner recursive counter used to stop once max_num_split_levels is reached.
        :type num_splits_made: int
        :return: model response for the final (combined) text, as returned by `_run_summarization`
        :raises ValueError: if a split is needed and the layered summarization settings are missing or invalid
        :raises Exception: if more than max_num_split_levels levels of splits are needed
        """
        splits = self._split_text(input_text)
        if len(splits) == 1:
            logging.info("Running base model on combined split summaries ({num_tokens} tokens)".format(num_tokens=splits[0][1]))
            return self._run_summarization([input_text], **kwargs)[0]

        if self.max_num_split_levels is None:
            raise ValueError("Layered summarization requires the 'summarizationMaxNumSplitLevels' setting to be set.")

        num_splits_made += 1
        if num_splits_made > self.max_num_split_levels:
            raise Exception("Unable to perform summarization. More than {num_levels} levels of splits needed. Consider reducing max summary length.".format(
                num_levels=self.max_num_split_levels))

        logging.info("Running base model on {num_splits} splits of a maximum of {max_tokens} tokens".format(
            num_splits=len(splits), max_tokens=max(length for _, length in splits)))
        splits_responses = self._run_summarization([text for text, _ in splits], **kwargs)
        # The splits stand in for the requests here; _parse_response ignores its request argument.
        parsed_responses = self._parse_responses(splits_responses, splits)
        combined_response = " ".join(r["text"] for r in parsed_responses)

        return self._run_layered_summarization(combined_response, num_splits_made, **kwargs)

    def _parse_response(self, response, request: ProcessSinglePromptCommand) -> ProcessSinglePromptResponseTextFull:
        """Convert a raw model response into the response returned to the caller."""
        # HF doc:
        # Each result comes as a dictionary with the following keys:
        # summary_text (str, present when return_text=True) — The summary of the corresponding input.
        return {"text": response["summary_text"]}
