import json
from abc import ABC, abstractmethod

# Formatters used to prepare datasets for fine-tuning with Hugging Face TRL's SFTTrainer, see https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support


class BaseFormatter(ABC):
    """ Base class for dataset formatters used as a TRL formatting function

    Subclasses implement ``_call_batched``, which receives a batch of samples
    (column name -> list of values) and returns a list with one formatted item per sample.
    """
    def __call__(self, samples):
        """
        Format either a batch of samples or a single sample.

        :param dict samples: column name -> list of values (batched), or column name -> value (single sample)
        :return: a list with one formatted item per sample when batched, otherwise the single formatted item
        :raises ValueError: if ``samples`` has no column at all
        """
        if not samples:
            # Without any column we cannot tell a batch from a single sample; fail explicitly
            # instead of leaking a bare StopIteration from next() below.
            raise ValueError("Cannot format samples: no column found")

        # Keep supporting older TRL versions since this code can be called by users in a managed code env
        # Exposed in https://github.com/dataiku/dku-contrib-private/blob/master/snippets/BUILTIN/python/llm_finetuning/variations/adapter-conversational.py
        # NOTE: batch detection only inspects the first column; a single sample whose first column
        # value is itself a list would be treated as a batch.
        if isinstance(next(iter(samples.values())), list):
            return self._call_batched(samples)

        # ... but TRL does not use batched formatting anymore
        # See: https://github.com/huggingface/trl/pull/3147
        # Wrap the single sample into a batch of size one, then unwrap the result.
        batched_input = {column: [value] for column, value in samples.items()}
        batched_output = self._call_batched(batched_input)
        return batched_output[0]

    @abstractmethod
    def _call_batched(self, samples):
        """
        Format a batch of samples.

        :param dict samples: column name -> list of values, one entry per sample
        :return: one formatted item per sample, in input order
        :rtype: list
        """
        raise NotImplementedError


class InstructPromptFormatter(BaseFormatter):
    """ Formats a dataset used for fine-tuning an instruct model

    Each sample is serialized (with ``json.dumps``) into a JSON string holding its prompt and completion.

    Input format: {
        "{prompt_column}": [promptA, promptB],
        "{completion_column}": [completionA, completionB],
    }

    Output format (a list of JSON strings): [
        json.dumps({"prompt": promptA, "completion": completionA}),
        json.dumps({"prompt": promptB, "completion": completionB}),
    ]
    """
    def __init__(self, prompt_column, completion_column):
        """
        :param str prompt_column: name of the column hosting the prompt
        :param str completion_column: name of the column hosting the completion
        """
        self.prompt_column = prompt_column
        self.completion_column = completion_column

    def _call_batched(self, samples):
        """
        :param dict samples: column name -> list of values, one entry per sample
        :return: one JSON-encoded ``{"prompt": ..., "completion": ...}`` string per sample, in input order
        :rtype: list[str]
        :raises ValueError: if the prompt and completion columns do not have the same number of values
        """
        prompts = samples[self.prompt_column]
        completions = samples[self.completion_column]
        # Fail loudly rather than silently dropping or misaligning samples.
        if len(prompts) != len(completions):
            raise ValueError(
                "Columns '{}' and '{}' have different lengths ({} vs {})".format(
                    self.prompt_column, self.completion_column, len(prompts), len(completions)
                )
            )
        return [
            json.dumps({"prompt": prompt, "completion": completion})
            for prompt, completion in zip(prompts, completions)
        ]


class ConversationalPromptFormatter(BaseFormatter):
    """ Formats a dataset used for fine-tuning a conversational model

    Each sample is turned into a list of messages (optional system, then user, then assistant),
    which is then rendered with ``chat_template_func(messages, tokenize=False)``.

    Input format: {
        "{system_column}": [systemMessageA, systemMessageB],
        "{user_column}": [userMessageA, userMessageB],
        "{assistant_column}": [assistantMessageA, assistantMessageB],
    }

    Messages built for sample A before rendering (sample B is analogous): [
        {"role": "system", "content": systemMessageA},
        {"role": "user", "content": userMessageA},
        {"role": "assistant", "content": assistantMessageA},
    ]

    Output format: [
        chat_template_func(messagesA, tokenize=False),
        chat_template_func(messagesB, tokenize=False),
    ]
    """
    # Substring of the ValueError raised by transformers when the tokenizer has no chat template.
    # The message prefix has been reworded across transformers versions ("Cannot use apply_chat_template()
    # because ..." vs "Cannot use chat template functions because ..."), so match on the common part.
    _MISSING_CHAT_TEMPLATE_MARKER = "tokenizer.chat_template is not set"

    def __init__(self, chat_template_func, user_column, assistant_column, system_column=None, static_system_message=None):
        """
        :param func chat_template_func: a function called for each sample to turn it into the model's specific conversational template
        :param str user_column: name of the column hosting the user message
        :param str assistant_column: name of the column hosting the assistant message
        :param Union[str|None] system_column: name of the column hosting the system message, None by default
        :param Union[str|None] static_system_message: static string to use as system message for every sample, None by default
            (ignored when ``system_column`` is set)
        """
        self.chat_template_func = chat_template_func
        self.user_column = user_column
        self.assistant_column = assistant_column
        self.system_column = system_column
        self.static_system_message = static_system_message

    def _call_batched(self, samples):
        """
        :param dict samples: column name -> list of values, one entry per sample
        :return: the output of ``chat_template_func`` for each sample, in input order
        :rtype: list
        :raises ValueError: if the base model has no chat template (or if ``chat_template_func`` raises another ValueError)
        """
        return [
            self._render(self._build_messages(samples, idx))
            for idx in range(len(samples[self.user_column]))
        ]

    def _build_messages(self, samples, idx):
        """ Build the (optional system,) user and assistant message list for the sample at ``idx`` """
        messages = []
        # A per-sample system column takes precedence over the static system message.
        if self.system_column:
            messages.append({"role": "system", "content": samples[self.system_column][idx]})
        elif self.static_system_message:
            messages.append({"role": "system", "content": self.static_system_message})
        messages.append({"role": "user", "content": samples[self.user_column][idx]})
        messages.append({"role": "assistant", "content": samples[self.assistant_column][idx]})
        return messages

    def _render(self, messages):
        """ Apply ``chat_template_func`` to ``messages``, turning a missing-template error into a clearer one """
        try:
            return self.chat_template_func(messages, tokenize=False)
        except ValueError as exc:
            if self._MISSING_CHAT_TEMPLATE_MARKER in str(exc):
                raise ValueError("The base model doesn't have a chat template set in tokenizer_config.json, fine-tuning it is not supported with transformers>=4.44.") from exc
            raise
