import logging
from collections import Counter
from typing import Dict, List, Optional, Tuple

import pandas as pd
from dataiku.llm.evaluation.exceptions import ToolCallException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import median_or_none

logger = logging.getLogger(__name__)


# Implements the dynamic programming solution of the LCS problem
def _compute_lcs_length(actual: List[str], reference: List[str]) -> int:
    len_actual, len_reference = len(actual), len(reference)

    # Create a 2D table to store the lengths of longest common subsequences for all subproblems, initialized with zeroes
    # lcs_table[i][j] will hold the length of LCS of actual[:i] and reference[:j]
    lcs_table = [[0] * (len_reference + 1) for _ in range(len_actual + 1)]

    # The LCS table includes an extra row and column to represent the base case of comparing with an empty list
    # This padding ensures that subproblem solutions are well-defined and prevents index errors when accessing neighboring cells
    # It allows safe access to diagonal, left, and above values, hence the loops starting from 1

    # Loop through both lists and fill the LCS table
    for i in range(1, len_actual + 1):
        for j in range(1, len_reference + 1):
            # If the tool calls match, take the diagonal value and add 1
            if actual[i - 1] == reference[j - 1]:
                lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1
            # If the tool calls do not match, take the maximum of the value from the left and above
            else:
                lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1])

    # The bottom-right value in the LCS table contains the length of the longest common subsequence
    return lcs_table[len_actual][len_reference]


def tool_call_exact_match(actual_calls: List[str], reference_calls: List[str]) -> float:
    """Score 1.0 when both call lists are equal (same items, same order), else 0.0.

    Args:
        actual_calls: Tool calls that were actually made.
        reference_calls: Tool calls that were expected.

    Returns:
        ``1.0`` if ``actual_calls == reference_calls``, ``0.0`` otherwise.
    """
    is_same_sequence = actual_calls == reference_calls
    return float(is_same_sequence)


def tool_call_partial_match(actual_calls: List[str], reference_calls: List[str]) -> float:
    """Score how much of the call ordering is shared between both lists.

    The score is the length of the longest common subsequence divided by the
    length of the longer of the two lists.

    Args:
        actual_calls: Tool calls that were actually made.
        reference_calls: Tool calls that were expected.

    Returns:
        A ratio between 0.0 and 1.0; 1.0 when both lists are empty.
    """
    if actual_calls or reference_calls:
        common_length = _compute_lcs_length(actual_calls, reference_calls)
        longest_length = max(len(actual_calls), len(reference_calls))
        return common_length / longest_length

    # Two empty lists are considered a perfect match
    return 1.0


def tool_call_precision_recall_f1(actual_calls: List[str], reference_calls: List[str]) -> Tuple[float, float, float]:
    """Compute order-insensitive precision, recall and F1 between two call lists.

    Calls are compared as multisets: a call repeated ``k`` times in one list
    can match at most ``k`` occurrences in the other.

    Args:
        actual_calls: Tool calls that were actually made.
        reference_calls: Tool calls that were expected.

    Returns:
        A ``(precision, recall, f1)`` tuple. Two empty lists score
        ``(1.0, 1.0, 1.0)``; exactly one empty list scores ``(0.0, 0.0, 0.0)``.
    """
    actual_count = len(actual_calls)
    reference_count = len(reference_calls)

    # Nothing was expected and nothing was called: perfect match
    if actual_count == 0 and reference_count == 0:
        return 1.0, 1.0, 1.0

    # Only one side is empty: either every call was unexpected or every expected call was missed
    if actual_count == 0 or reference_count == 0:
        return 0.0, 0.0, 0.0

    # Multiset intersection keeps, for each call, the smaller of its two occurrence counts
    shared_calls = Counter(actual_calls) & Counter(reference_calls)
    matches = sum(shared_calls.values())

    precision = matches / actual_count
    recall = matches / reference_count

    # Guard against a division by zero in the harmonic mean
    if precision == 0.0 and recall == 0.0:
        return precision, recall, 0.0

    return precision, recall, 2 * precision * recall / (precision + recall)


# Registry of the supported tool call metrics: metric identifier -> function scoring a single row
# from its (actual_calls, reference_calls) pair.
# "toolCallExactMatch" and "toolCallPartialMatch" return a single float, whereas
# "toolCallPrecisionRecallF1" returns a (precision, recall, f1) tuple that
# batch_tool_call_metrics expands into three separate result entries.
TOOL_CALL_METRICS_FUNCTION_MAP = {
    "toolCallExactMatch": tool_call_exact_match,
    "toolCallPartialMatch": tool_call_partial_match,
    "toolCallPrecisionRecallF1": tool_call_precision_recall_f1,
}


def compute_tool_call_metrics(
    metric_inputs: GenAIMetricInput, metrics: List[str]
) -> Tuple[Dict[str, Optional[float]], pd.DataFrame]:
    """Compute the requested tool call metrics for every row of ``metric_inputs``.

    Only the entries of ``metrics`` that are registered in
    ``TOOL_CALL_METRICS_FUNCTION_MAP`` are computed; the others are ignored.

    Args:
        metric_inputs: Evaluation inputs; the actual and reference tool call
            columns are read from it.
        metrics: Names of all requested metrics (tool call related or not).

    Returns:
        A tuple ``(metrics_results, scores_frame)`` where ``metrics_results``
        holds one aggregated value per metric and ``scores_frame`` holds the
        per-row scores, indexed like ``metric_inputs.input`` after null-row
        filtering (default index when ``metric_inputs.input`` is None).

    Raises:
        ToolCallException: If anything fails during the computation; the
            original exception is attached as the cause.
    """
    try:
        metrics_to_compute = get_tool_call_metrics(metrics)
        metric_names_by_column_role = {
            GenAIMetricInputRole.ACTUAL_TOOL_CALLS: metrics_to_compute,
            GenAIMetricInputRole.REFERENCE_TOOL_CALLS: metrics_to_compute,
        }

        # Empty values will crash so avoid them
        metric_inputs = failure_utils.filter_null_rows(metric_inputs, metric_names_by_column_role, metrics_to_compute)

        initial_index = metric_inputs.input.index if metric_inputs.input is not None else None  # Keep index for output

        # Actual calls may be given as dicts carrying a "toolName" key: reduce those to the tool name.
        # Any other entry is kept untouched. Reference calls are used exactly as provided.
        actual_calls_list = [
            [
                tool_call["toolName"] if isinstance(tool_call, dict) and "toolName" in tool_call else tool_call
                for tool_call in row
            ]
            for row in metric_inputs.actual_tool_calls.to_list()
        ]
        reference_calls_list = metric_inputs.reference_tool_calls.to_list()

        logger.info(f"The following tool call metrics will be computed : {metrics_to_compute}")

        metrics_results, metrics_scores = batch_tool_call_metrics(
            metrics_to_compute, actual_calls_list, reference_calls_list
        )

        logger.info(f"Global tool call metrics result : {str(metrics_results)}")
        return metrics_results, pd.DataFrame(metrics_scores, initial_index)

    except Exception as e:
        # NOTE: this also catches and re-wraps a ToolCallException raised by batch_tool_call_metrics.
        # Chain explicitly so the original error is reported as the direct cause of this one.
        raise ToolCallException(f"An error happened during the computation of tool call metrics : {str(e)}", e) from e


def batch_tool_call_metrics(
    metrics_to_compute: List[str], actual_calls_list: List[List[str]], reference_calls_list: List[List[str]]
) -> Tuple[Dict[str, Optional[float]], Dict[str, List[float]]]:
    """Evaluate each requested metric on every (actual, reference) pair of rows.

    Args:
        metrics_to_compute: Metric names; each must be a key of
            ``TOOL_CALL_METRICS_FUNCTION_MAP``.
        actual_calls_list: One list of actual tool calls per row.
        reference_calls_list: One list of reference tool calls per row.

    Returns:
        A tuple ``(metrics_results, metrics_scores)``: the first maps each
        output metric name to the value returned by ``median_or_none`` over
        its per-row scores, the second maps it to the per-row scores.
        ``"toolCallPrecisionRecallF1"`` produces the three entries
        ``"toolCallPrecision"``, ``"toolCallRecall"`` and ``"toolCallF1"``.

    Raises:
        ToolCallException: If both inputs do not have the same number of rows.
    """
    if len(actual_calls_list) != len(reference_calls_list):
        raise ToolCallException(
            f"Column mismatch: actual calls ({len(actual_calls_list)}) and reference calls ({len(reference_calls_list)}) are not the same length"
        )

    aggregated: Dict[str, Optional[float]] = {}
    per_row: Dict[str, List[float]] = {}

    def _record(name: str, values: List[float]) -> None:
        # Store the row-level scores together with their aggregated value
        per_row[name] = values
        aggregated[name] = median_or_none(pd.Series(values))

    for metric in metrics_to_compute:
        compute = TOOL_CALL_METRICS_FUNCTION_MAP[metric]
        row_outputs = [
            compute(actual, reference) for actual, reference in zip(actual_calls_list, reference_calls_list)
        ]

        if metric != "toolCallPrecisionRecallF1":
            _record(metric, row_outputs)
            continue

        # Each row output is a (precision, recall, f1) triple: split it into one column per component
        for position, component in enumerate(("toolCallPrecision", "toolCallRecall", "toolCallF1")):
            _record(component, [triple[position] for triple in row_outputs])

    return aggregated, per_row


def get_tool_call_metrics(metrics: List[str]) -> List[str]:
    """Keep only the metric names registered in ``TOOL_CALL_METRICS_FUNCTION_MAP``.

    Args:
        metrics: Names of all requested metrics.

    Returns:
        The tool call metric names, in their order of appearance in ``metrics``.
    """
    return [name for name in metrics if name in TOOL_CALL_METRICS_FUNCTION_MAP]


def has_tool_call_metrics(metrics: List[str]) -> bool:
    """Tell whether at least one registered tool call metric is present in ``metrics``.

    Args:
        metrics: Names of all requested metrics.

    Returns:
        True if any key of ``TOOL_CALL_METRICS_FUNCTION_MAP`` appears in
        ``metrics``, False otherwise.
    """
    for known_metric in TOOL_CALL_METRICS_FUNCTION_MAP:
        if known_metric in metrics:
            return True
    return False


def create_empty_tool_call_metrics(
    interpreted_columns: GenAIMetricInput, metrics: List[str]
) -> Tuple[Dict[str, None], pd.DataFrame]:
    """Build the empty result for the tool call metrics found in ``metrics``.

    Delegates to ``failure_utils.create_empty_metrics`` with the tool call
    subset of ``metrics``.

    Args:
        interpreted_columns: Evaluation inputs, forwarded as-is.
        metrics: Names of all requested metrics.

    Returns:
        Whatever ``failure_utils.create_empty_metrics`` returns; per the
        annotation, a dict of None values and a DataFrame.
    """
    return failure_utils.create_empty_metrics(interpreted_columns, get_tool_call_metrics(metrics))
