import logging
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple

import numpy as np
import pandas as pd
from dataiku.llm.evaluation.genai_eval_recipe_desc import NO_COLUMN_SELECTED
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils import metrics_utils
from dataiku.llm.evaluation.utils.common import PROMPT_RECIPE_RAW_RESPONSE_NAME
from dataiku.llm.evaluation.utils.common import _get_trajectory

logger = logging.getLogger(__name__)


def can_compute_total_execution_time(recipe_desc: GenAIEvalRecipeDesc) -> bool:
    """
    Tell whether total execution times can be derived for this recipe configuration.

    This requires the input to come from a prompt recipe and an actual tool calls column to be selected
    (neither None nor the NO_COLUMN_SELECTED placeholder).
    """
    if recipe_desc.input_format != "PROMPT_RECIPE":
        return False
    tool_calls_column = recipe_desc.actual_tool_calls_column_name
    return tool_calls_column is not None and tool_calls_column != NO_COLUMN_SELECTED


def compute_total_execution_time(input_df: pd.DataFrame) -> Tuple[Dict[str, Optional[float]], pd.DataFrame]:
    """
    Get the total execution times from the trajectory in a PromptRecipe's "raw" output (Raw Response), in the format
    consumed by AgentEvaluationRecipe._update_outputs.

    :param input_df: input rows; must contain the PROMPT_RECIPE_RAW_RESPONSE_NAME column.
    :return: a tuple of
        - a dict holding the "p95TotalAgentCallExecutionTimeSecondsPerRow" aggregate (as computed by
          metrics_utils.p95_or_none over the per-row values), and
        - a DataFrame with one "totalAgentCallExecutionTimeSeconds" column, holding the per-row duration in
          seconds (missing when a row's duration could not be parsed).
    """
    logger.info(
        f'Column "{PROMPT_RECIPE_RAW_RESPONSE_NAME}" is from a prompt recipe, trying to parse it for total elapsed times'
    )
    total_execution_time_per_row = input_df[PROMPT_RECIPE_RAW_RESPONSE_NAME].apply(_get_duration)
    # Check for non-missing values rather than truthiness: a genuine 0.0 second duration is falsy and must still
    # count as a successfully parsed execution time.
    if total_execution_time_per_row.notna().any():
        logger.info(f'Found total execution times in "{PROMPT_RECIPE_RAW_RESPONSE_NAME}", from a prompt recipe.')
    else:
        failure_utils.warn(
            f'Column "{PROMPT_RECIPE_RAW_RESPONSE_NAME}" does not contain parsable total execution times from a prompt recipe (it might be empty).'
        )

    p95_execution_time = metrics_utils.p95_or_none(total_execution_time_per_row)
    return (
        {"p95TotalAgentCallExecutionTimeSecondsPerRow": p95_execution_time},
        pd.DataFrame({"totalAgentCallExecutionTimeSeconds": total_execution_time_per_row}),
    )


def create_empty_p95_total_execution_time() -> Tuple[Dict[str, Any], pd.DataFrame]:
    """
    Build the placeholder result used when total execution times cannot be computed.

    :return: a metrics dict whose p95 entry is None, and an empty DataFrame with a single float64
        "totalAgentCallExecutionTimeSeconds" column.
    """
    no_times = pd.Series(dtype=np.float64)
    metrics: Dict[str, Any] = {"p95TotalAgentCallExecutionTimeSecondsPerRow": None}
    per_row_frame = pd.DataFrame({"totalAgentCallExecutionTimeSeconds": no_times})
    return metrics, per_row_frame


def _get_duration(prompt_recipe_raw_response: str) -> Optional[float]:
    """
    Extract the total elapsed time, in seconds, from the trajectory in a prompt recipe raw response.

    The trajectory is re-extracted here, which partly duplicates _read_tool_calls_from_json_cell. This is intentional,
    so that the tool-call reader stays specific to tool calls.

    :return: the trajectory's "durationMs" value divided by 1000, or None when no (truthy) trajectory is found or it
        has no "durationMs" entry.
    """
    trajectory = _get_trajectory(prompt_recipe_raw_response)
    duration_ms = trajectory.get("durationMs") if trajectory else None
    return None if duration_ms is None else duration_ms / 1000
