import logging

import pandas as pd
from typing import Dict, List, Tuple

from dataiku.llm.evaluation.exceptions import BleuException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import mean_or_none

logger = logging.getLogger(__name__)


def compute_bleu(metric_inputs: GenAIMetricInput, tokenizer: str = '13a') -> Tuple[dict, pd.DataFrame]:
    """Compute per-row BLEU scores with sacrebleu, plus their aggregate.

    The inputs are first passed through ``failure_utils.filter_null_rows`` for
    the OUTPUT and GROUND_TRUTH roles; only the rows it returns are scored.
    Each row is scored independently with ``BLEU.sentence_score``.

    :param metric_inputs: evaluation inputs holding the generated outputs and
        their ground truths.
    :param tokenizer: name of a sacrebleu tokenizer; must be a key of
        sacrebleu's ``_TOKENIZERS`` registry.
    :return: a tuple ``(aggregate, per_row)`` where ``aggregate`` is
        ``{"bleu": mean_or_none(<row scores>)}`` and ``per_row`` is a DataFrame
        with a single ``"bleu"`` column indexed like the filtered input.
    :raises BleuException: if the tokenizer is unknown, if no row is left to
        score after filtering, or if anything fails during the computation.
    """
    # sacrebleu is imported at call time rather than at module import time
    # (presumably because it is an optional dependency -- TODO confirm).
    from sacrebleu.metrics import BLEU
    from sacrebleu.metrics.bleu import _TOKENIZERS
    if tokenizer not in _TOKENIZERS:
        raise BleuException("Unknown BLEU tokenizer : %s. Possible tokenizers are %s." % (tokenizer, set(_TOKENIZERS.keys())))
    try:
        # empty values will crash. Avoid them.
        filtered_columns = failure_utils.filter_null_rows(
            metric_inputs,
            {GenAIMetricInputRole.OUTPUT: ["BLEU"], GenAIMetricInputRole.GROUND_TRUTH: ["BLEU"]},
            ["BLEU"]
        )
        initial_index = filtered_columns.input.index  # Keep index for output

        candidate = filtered_columns.output.to_list()
        reference = filtered_columns.ground_truth.to_list()
        if not reference:
            # Without this guard, reference[0] below fails with an opaque
            # "list index out of range".
            raise BleuException("No row left to compute BLEU on after filtering out empty values.")

        # The reference mode is decided from the first row only: a non-empty
        # list of strings means every row carries several references.
        first_reference = reference[0]
        multiple_references = (
            isinstance(first_reference, list)
            and len(first_reference) > 0
            and all(isinstance(element, str) for element in first_reference)
        )
        logger.info(
            "Computing BLEU with %s reference and %s tokenizer",
            'multiple' if multiple_references else 'single',
            tokenizer,
        )
        bleu_line_scorer = BLEU(tokenize=tokenizer, effective_order=True)
        bleu_line_scores = []
        for i, hypothesis in enumerate(candidate):
            if i % 1000 == 0:
                logger.info("Computing BLEU of row %s", i)
            # sentence_score expects a list of references, so a single
            # reference is wrapped in a one-element list.
            ref = reference[i] if multiple_references else [reference[i]]
            try:
                bleu_line_scores.append(bleu_line_scorer.sentence_score(hypothesis, ref).score)
            except Exception as e:
                # i is the position within the filtered rows.
                raise BleuException("Error on line %i : %s" % (i, str(e))) from e
        bleu_score_perf = {
            "bleu": mean_or_none(pd.Series(bleu_line_scores))
        }

        logger.info("BLEU results : %s", bleu_score_perf)
        return bleu_score_perf, pd.DataFrame({"bleu": bleu_line_scores}, index=initial_index)
    except Exception as e:
        # Single failure type for callers; this also re-wraps the
        # BleuExceptions raised inside the try block.
        raise BleuException("An error happened during the computation of BLEU metrics : %s" % str(e), e) from e


def has_bleu(metrics: List[str]) -> bool:
    """Tell whether the ``"bleu"`` metric name is present in ``metrics``.

    :param metrics: metric names to look through.
    :return: ``True`` if ``"bleu"`` is one of them, ``False`` otherwise.
    """
    bleu_requested = "bleu" in metrics
    return bleu_requested


def create_empty_bleu(metric_inputs: GenAIMetricInput) -> Tuple[Dict[str, None], pd.DataFrame]:
    """Build the placeholder BLEU result for the given inputs.

    Delegates to ``failure_utils.create_empty_metrics`` with the single metric
    name ``"bleu"``.

    :param metric_inputs: evaluation inputs forwarded unchanged to the helper.
    :return: whatever ``failure_utils.create_empty_metrics`` returns; per the
        annotation, a dict of ``None`` values and a DataFrame.
    """
    bleu_metric_names = ["bleu"]
    return failure_utils.create_empty_metrics(metric_inputs, bleu_metric_names)
