import logging

import pandas as pd
import numpy as np
from typing import List, Any, Tuple

from numpy import ndarray, dtype

from dataiku.llm.evaluation.exceptions import BertScoreException
from dataiku.llm.evaluation.llm_metrics_input import LLMMetricInputRole, LLMMetricInput
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import mean_or_none

logger = logging.getLogger(__name__)


def compute_bert_score(interpreted_columns: LLMMetricInput, model_type='bert-base-uncased') -> Tuple[dict, pd.DataFrame]:
    """Compute BERT Score precision, recall and F1 of the outputs against the ground truth.

    Rows whose output or ground truth is null are dropped first (through
    ``failure_utils.filter_null_rows``) because empty values make the scorer crash.

    :param interpreted_columns: metric input whose ``output`` column provides the
        candidates and whose ``ground_truth`` column provides the references.
    :param model_type: model name handed to ``bert_score.BERTScorer``; it must be a
        key of ``bert_score.utils.model2layers``.
    :return: a tuple ``(perf, per_row)``. ``perf`` maps ``bertScorePrecision``,
        ``bertScoreRecall`` and ``bertScoreF1`` to the value of ``mean_or_none``
        over the per-row scores. ``per_row`` is a DataFrame with those three
        columns, indexed like the filtered output rows.
    :raises BertScoreException: if ``model_type`` is not a known model, or if any
        error occurs while scoring (the original error is chained as the cause).
    """
    from bert_score import BERTScorer
    from bert_score.utils import model2layers
    if not model2layers.get(model_type):
        logger.error("Could not find model name %s. Available models are: %s", model_type, model2layers.keys())
        raise BertScoreException("Model {} not found".format(model_type))

    # empty values will crash. Avoid them.
    filtered_columns = failure_utils.filter_null_rows(interpreted_columns,
                                                      {LLMMetricInputRole.OUTPUT: ["BERT Score"], LLMMetricInputRole.GROUND_TRUTH: ["BERT Score"]},
                                                      ["BERT Score"])
    # Keep index for output. Take it from the column the scores are computed from,
    # so that it has exactly one entry per score and does not depend on the input
    # column (which BERT Score never reads).
    initial_index = filtered_columns.output.index

    try:
        logger.info("BERT-score model type : %s", model_type)
        candidate = filtered_columns.output.to_list()
        reference = filtered_columns.ground_truth.to_list()
        scorer = BERTScorer(model_type=model_type)
        bert_score_precision_tensor, bert_score_recall_tensor, bert_score_f1_tensor = scorer.score(candidate, reference)
        bert_score_precision = _tensor_list_to_array(bert_score_precision_tensor)
        bert_score_recall = _tensor_list_to_array(bert_score_recall_tensor)
        bert_score_f1 = _tensor_list_to_array(bert_score_f1_tensor)

        bert_score_perf = {
            "bertScorePrecision": mean_or_none(pd.Series(bert_score_precision)),
            "bertScoreRecall": mean_or_none(pd.Series(bert_score_recall)),
            "bertScoreF1": mean_or_none(pd.Series(bert_score_f1)),
        }

        logger.info("BERT Score results : %s", bert_score_perf)

        return bert_score_perf, pd.DataFrame({"bertScorePrecision": bert_score_precision, "bertScoreRecall": bert_score_recall,
                                              "bertScoreF1": bert_score_f1}, initial_index)
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise BertScoreException("An error happened during the computation of BERT Score metrics : %s" % str(e), e) from e


def _tensor_list_to_array(tensor_list: List[object]) -> 'ndarray[Any, dtype[Any]]':
    return np.array([tensor.item() for tensor in tensor_list])


def has_bert_score(metrics: List[str]) -> bool:
    """Return True when the ``"bertScore"`` metric is among the requested ``metrics``."""
    bert_score_metric_name = "bertScore"
    return bert_score_metric_name in metrics


def create_empty_bert_score(interpreted_columns: LLMMetricInput) -> Tuple[dict, pd.DataFrame]:
    """Return ``failure_utils.create_empty_metrics`` results for the BERT Score metric names.

    The metric names used are ``bertScorePrecision``, ``bertScoreRecall`` and
    ``bertScoreF1``, the same keys that ``compute_bert_score`` produces.
    """
    metric_names = ["bertScorePrecision", "bertScoreRecall", "bertScoreF1"]
    return failure_utils.create_empty_metrics(interpreted_columns, metric_names)

