import logging

import pandas as pd
from typing import Dict, List, Tuple

from dataiku.llm.evaluation.exceptions import RougeException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import mean_or_none

logger = logging.getLogger(__name__)

def compute_rouge(metric_inputs: GenAIMetricInput) -> Tuple[dict, pd.DataFrame]:
    """
    Compute ROUGE-1, ROUGE-2 and ROUGE-L scores of the outputs against the ground truths.

    Rows are first passed through ``failure_utils.filter_null_rows`` because empty values
    make the scorer crash. Each remaining row is scored with ``RougeScorer.score``, or
    with ``RougeScorer.score_multi`` when the first ground truth is a non-empty list of
    strings (the first row decides for the whole column).

    :param metric_inputs: inputs holding the output and ground truth columns to compare
    :return: a tuple ``(aggregated, per_row)`` where ``aggregated`` maps each of the nine
        metric names (``rouge1Precision`` ... ``rougeLF1``) to the result of ``mean_or_none``
        over the row scores, and ``per_row`` is a DataFrame with one column per metric name,
        indexed like the filtered input
    :raises RougeException: if anything goes wrong, including when there is no row to score
    """
    from rouge_score import rouge_scorer

    rouge_types = ['rouge1', 'rouge2', 'rougeL']
    # (suffix of the output column name, attribute of the rouge_score Score tuple)
    score_fields = (("Precision", "precision"), ("Recall", "recall"), ("F1", "fmeasure"))

    try:
        # empty values will crash. Avoid them.
        filtered_columns = failure_utils.filter_null_rows(metric_inputs,
                                                          {GenAIMetricInputRole.OUTPUT: ["ROUGE"], GenAIMetricInputRole.GROUND_TRUTH: ["ROUGE"]},
                                                          ["ROUGE"])
        initial_index = filtered_columns.input.index  # Keep index for output

        candidate = filtered_columns.output.to_list()
        reference = filtered_columns.ground_truth.to_list()
        if not reference:
            # Fail with an explicit message instead of an opaque IndexError on reference[0].
            # Raised inside the try block so that it is wrapped in a RougeException below.
            raise ValueError("No row to score: the ground truth column is empty after filtering out null rows")

        first_reference = reference[0]
        multiple_references = (
            isinstance(first_reference, list)
            and len(first_reference) > 0
            and all(isinstance(element, str) for element in first_reference)
        )
        logger.info("Computing ROUGE with %s reference", 'multiple' if multiple_references else 'single')

        scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)
        # Only the selected method is looked up, so score_multi is not required for single references
        score_row = scorer.score_multi if multiple_references else scorer.score

        # One list of row scores per output column; insertion order defines the column order:
        # rouge1Precision, rouge1Recall, rouge1F1, rouge2Precision, ..., rougeLF1
        per_row_scores = {
            rouge_type + suffix: []
            for rouge_type in rouge_types
            for suffix, _ in score_fields
        }

        for i, row_candidate in enumerate(candidate):
            if i % 1000 == 0:
                logger.info("Computing ROUGE of row %s", i)
            try:
                scores = score_row(reference[i], row_candidate)
                for rouge_type in rouge_types:
                    score = scores[rouge_type]
                    for suffix, attribute in score_fields:
                        per_row_scores[rouge_type + suffix].append(getattr(score, attribute))
            except Exception as e:
                raise RougeException("Error on line %i : %s" % (i, str(e))) from e

        rouge_score_perf = {
            name: mean_or_none(pd.Series(values))
            for name, values in per_row_scores.items()
        }

        logger.info("ROUGE results : %s", rouge_score_perf)
        return rouge_score_perf, pd.DataFrame(per_row_scores, initial_index)
    except Exception as e:
        raise RougeException("An error happened during the computation of ROUGE metrics : %s" % str(e), e) from e


def has_rouge(metrics: List[str]) -> bool:
    """
    Tell whether ROUGE is part of the requested metrics.

    :param metrics: names of the metrics to compute
    :return: True if ``"rouge"`` is contained in ``metrics``, False otherwise
    """
    rouge_requested = "rouge" in metrics
    return rouge_requested


def create_empty_rouge(metric_inputs: GenAIMetricInput) -> Tuple[Dict[str, None], pd.DataFrame]:
    """
    Build placeholder ROUGE results by delegating to ``failure_utils.create_empty_metrics``.

    :param metric_inputs: inputs forwarded as-is to ``failure_utils.create_empty_metrics``
    :return: whatever ``failure_utils.create_empty_metrics`` returns for the nine ROUGE
        metric names; per the annotation, a dict of empty values and a DataFrame
    """
    # Precision / recall / F1 for each of ROUGE-1, ROUGE-2 and ROUGE-L
    rouge_metric_names = [
        "rouge1Precision", "rouge1Recall", "rouge1F1",
        "rouge2Precision", "rouge2Recall", "rouge2F1",
        "rougeLPrecision", "rougeLRecall", "rougeLF1",
    ]
    return failure_utils.create_empty_metrics(metric_inputs, rouge_metric_names)
