import logging

import pandas as pd
import numpy as np
from typing import Any, Tuple

from dataiku.llm.evaluation.exceptions import AgentEvalException
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc, NO_COLUMN_SELECTED
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import mean_or_none

logger = logging.getLogger(__name__)


# Names of the dataset-level metrics: the keys of the dict returned by
# compute_tool_statistics (each one is a mean over rows) and of the
# placeholder dict built by create_empty_tool_statistics.
global_metrics = [
    "averageToolExecutionsPerRow",
    "averageFailedToolExecutionsPerRow",
    "averageToolExecutionTimeSecondsPerRow",
]
# Names of the per-row metrics: the columns of the DataFrame returned by
# compute_tool_statistics and by create_empty_tool_statistics.
row_metrics = [
    "totalToolExecutions",
    "totalFailedToolExecutions",
    "totalToolExecutionTimeSeconds",
]


def compute_tool_statistics(input_df: pd.DataFrame, metric_input: GenAIMetricInput) -> Tuple[dict, pd.DataFrame]:
    """Compute tool-usage statistics from the actual tool calls of each row.

    Rows whose actual-tool-calls value is null are removed first through
    ``failure_utils.filter_null_rows``. For every remaining row, three values
    are computed: the number of tool calls, the number of failed tool calls and
    the total tool execution time in seconds.

    :param input_df: not read by this function; kept for interface compatibility.
    :param metric_input: evaluation input whose ``actual_tool_calls`` holds the
        per-row tool calls (presumably a pandas Series; ``.apply`` is called on it).
    :return: a tuple ``(tool_statistics, tool_statistics_by_row)``.
        ``tool_statistics`` maps each name in ``global_metrics`` to the
        ``mean_or_none`` of the matching per-row values.
        ``tool_statistics_by_row`` is a DataFrame with one column per name in
        ``row_metrics``, built from the per-row Series of the filtered input.
    :raises AgentEvalException: if any error occurs while computing the
        statistics. The original exception is chained as ``__cause__``.
    """
    # Empty values will crash so avoid them
    metric_names_by_column_role = {GenAIMetricInputRole.ACTUAL_TOOL_CALLS: global_metrics}
    filtered_metric_input = failure_utils.filter_null_rows(metric_input, metric_names_by_column_role, global_metrics)
    try:
        tool_calls = filtered_metric_input.actual_tool_calls
        total_tool_executions_per_row = tool_calls.apply(_get_total_tools)
        total_failed_tool_executions_per_row = tool_calls.apply(_get_failed_tools)
        total_tool_execution_time_per_row = tool_calls.apply(_get_total_tool_execution_time_seconds)
        tool_statistics = {
            "averageToolExecutionsPerRow": mean_or_none(total_tool_executions_per_row),
            "averageFailedToolExecutionsPerRow": mean_or_none(total_failed_tool_executions_per_row),
            "averageToolExecutionTimeSecondsPerRow": mean_or_none(total_tool_execution_time_per_row),
        }
        tool_statistics_by_row = {
            "totalToolExecutions": total_tool_executions_per_row,
            "totalFailedToolExecutions": total_failed_tool_executions_per_row,
            "totalToolExecutionTimeSeconds": total_tool_execution_time_per_row,
        }
    except Exception as e:
        # Chain explicitly so the traceback shows the real cause rather than
        # "another exception occurred during handling".
        raise AgentEvalException(f"An error happened during the computation of tool statistics metrics : {str(e)}", e) from e
    return tool_statistics, pd.DataFrame(tool_statistics_by_row)


def can_tool_statistics(recipe_desc: GenAIEvalRecipeDesc) -> bool:
    """Tell whether tool statistics can be computed for this recipe.

    This is always possible for the ``"PROMPT_RECIPE"`` input format. For any
    other format, an actual-tool-calls column must be set and must not be the
    ``NO_COLUMN_SELECTED`` placeholder.
    """
    if recipe_desc.input_format == "PROMPT_RECIPE":
        return True
    tool_calls_column = recipe_desc.actual_tool_calls_column_name
    if tool_calls_column is None:
        return False
    return tool_calls_column != NO_COLUMN_SELECTED


def _get_total_tools(tool_calls: list):
    if not isinstance(tool_calls, list):
        return None
    else:
        return len(tool_calls)


def _get_failed_tools(tool_calls: list):
    if not isinstance(tool_calls, list):
        return None

    failed_tools = 0
    dict_tool_calls = (t for t in tool_calls if isinstance(t, dict)) # user-provided list of tool call might be just strings. If so, we don't have info on failure.
    for tool_call in dict_tool_calls:
        if 'error' in tool_call: # Error only if we have an explicit error
            failed_tools += 1

    return failed_tools


def _get_total_tool_execution_time_seconds(tool_calls: list[dict[str, Any]]):
    if not isinstance(tool_calls, list):
        return None
    else:
        durations = [
            c["durationMs"] for c in tool_calls if "durationMs" in c and isinstance(c["durationMs"], (int, float))
        ]
        if not durations:
            return None
        return np.sum(durations) / 1000  # returning seconds


def create_empty_tool_statistics(metric_input: GenAIMetricInput) -> Tuple[dict, pd.DataFrame]:
    """Build placeholder tool statistics containing no values.

    :return: a tuple ``(global_stats, per_row_stats)``.
        ``global_stats`` maps every name in ``global_metrics`` to None.
        ``per_row_stats`` is a DataFrame indexed like
        ``metric_input.input.index``, with one all-NaN float64 column for each
        name in ``row_metrics``.
    """
    empty_global_stats = dict.fromkeys(global_metrics)
    empty_columns = {}
    for metric_name in row_metrics:
        empty_columns[metric_name] = pd.Series(
            None, index=metric_input.input.index, dtype=np.float64
        )
    return empty_global_stats, pd.DataFrame(empty_columns)
