import json
import logging

import pandas as pd
from typing import Dict, Tuple, Optional

from dataiku.llm.evaluation.exceptions import TokenCountException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput
from dataiku.llm.evaluation.genai_eval_recipe_desc import GenAIEvalRecipeDesc
from dataiku.llm.evaluation.utils import failure_utils, prompt_recipe_utils

logger = logging.getLogger(__name__)


def compute_token_count(input_df: pd.DataFrame, input_format: str) -> Tuple[dict, pd.DataFrame]:
    """
    Compute the average input/output token counts along with the per-row token counts.

    :param input_df: dataframe handed over to ``_try_get_token_counts`` for parsing
    :param input_format: input format of the evaluation; only 'PROMPT_RECIPE' is accepted
    :return: a tuple made of
        - a dict with the 'inputTokensPerRow' and 'outputTokensPerRow' averages (None when no average
          can be computed)
        - the per-row token counts dataframe, as returned by ``_try_get_token_counts``
    :raises TokenCountException: on any failure, including an unsupported ``input_format``. Every error
        raised in the body is re-wrapped, so the final message is always prefixed the same way.
    """
    try:
        # Rows without a usable value are deliberately kept: they simply end up with 'None' tokens.
        if input_format != 'PROMPT_RECIPE':
            raise TokenCountException("Token Count are only supported for Prompt Recipe Input format")
        per_row_counts = _try_get_token_counts(input_df)
        averages = {
            metric_name: mean_or_none(per_row_counts[metric_name])
            for metric_name in ("inputTokensPerRow", "outputTokensPerRow")
        }
    except Exception as e:
        failure_message = "An error happened during the computation of Token Count metrics : %s" % str(e)
        raise TokenCountException(failure_message, e)
    else:
        logger.info("Token Count results : %s" % str(averages))
        return averages, per_row_counts


def can_token_count(recipe_desc: GenAIEvalRecipeDesc) -> bool:
    """
    Tell whether token counts can be computed for the given recipe description.

    :param recipe_desc: recipe description; its ``input_format`` and ``output_column_name`` are read
    :return: True only when the input format is 'PROMPT_RECIPE' and the output column is
        'llm_raw_response'
    """
    if recipe_desc.input_format != 'PROMPT_RECIPE':
        return False
    # NOTE(review): presumably the same column as prompt_recipe_utils.PROMPT_RECIPE_RAW_RESPONSE_NAME,
    # which the parsing code reads - TODO confirm both stay in sync.
    return recipe_desc.output_column_name == 'llm_raw_response'


def create_empty_token_count(metric_inputs: GenAIMetricInput) -> Tuple[Dict[str, None], pd.DataFrame]:
    """
    Build the placeholder token count result by delegating to ``failure_utils.create_empty_metrics``.

    :param metric_inputs: metric inputs, forwarded untouched to ``failure_utils.create_empty_metrics``
    :return: whatever ``failure_utils.create_empty_metrics`` returns for the two token count metrics;
        annotated as a dict of metric name to None plus a dataframe
    """
    token_metric_names = ["inputTokensPerRow", "outputTokensPerRow"]
    return failure_utils.create_empty_metrics(metric_inputs, token_metric_names)


def _get_input_tokens(json_obj: dict):
    """
    Read the 'promptTokens' entry of a parsed JSON object.

    :param json_obj: parsed JSON object, or None
    :return: the value stored under 'promptTokens'; None when ``json_obj`` is None or has no such key
    """
    return None if json_obj is None else json_obj.get('promptTokens')


def _get_output_tokens(json_obj: dict):
    """
    Read the 'completionTokens' entry of a parsed JSON object.

    :param json_obj: parsed JSON object, or None
    :return: the value stored under 'completionTokens'; None when ``json_obj`` is None or has no such key
    """
    return None if json_obj is None else json_obj.get('completionTokens')


def _sum_tokens_not_none(tokens: int, reported_tokens: int):
    """
    Add up two token counts, ignoring the ones that are None.

    :param tokens: first token count, or None
    :param reported_tokens: second token count, or None
    :return: the sum of the counts that are not None; None when both are None
    """
    known_counts = [count for count in (tokens, reported_tokens) if count is not None]
    if not known_counts:
        return None
    # sum() starts from 0, so a single known count is returned as-is (0 included).
    return sum(known_counts)


def _read_token_count_from_json_cell(json_string: str):
    """
    Extract the total input and output token counts from one JSON cell, on a best-effort basis.

    The 'promptTokens' / 'completionTokens' values found at the top level of the JSON document are
    added to the ones found in its 'reportedUsageMetadata' object, when present.

    :param json_string: content of the cell, expected to be a JSON document
    :return: a tuple (total input tokens, total output tokens). Each element is None when no matching
        count was found; (None, None) is returned when the cell cannot be parsed or processed.
    """
    try:
        raw_json = json.loads(json_string)
        input_tokens = _get_input_tokens(raw_json)
        output_tokens = _get_output_tokens(raw_json)
        reported_json = raw_json.get('reportedUsageMetadata', None)
        reported_input_tokens = _get_input_tokens(reported_json)
        reported_output_tokens = _get_output_tokens(reported_json)
        total_input_tokens = _sum_tokens_not_none(input_tokens, reported_input_tokens)
        total_output_tokens = _sum_tokens_not_none(output_tokens, reported_output_tokens)
        return total_input_tokens, total_output_tokens
    except Exception:
        # Deliberately silent (no logs): an unusable cell simply yields no token count.
        # 'Exception' rather than a bare 'except' so that KeyboardInterrupt / SystemExit still propagate.
        return None, None


def _try_get_token_counts(input_df: pd.DataFrame) -> Optional[pd.DataFrame]:
    """
    Try to get token counts from a PromptRecipe's "raw" output

    :param input_df: dataframe expected to hold the column named
        ``prompt_recipe_utils.PROMPT_RECIPE_RAW_RESPONSE_NAME``
    :return: pd.DataFrame with the token counts, one row per row of ``input_df``, in two columns:
        'inputTokensPerRow' and 'outputTokensPerRow'. Cells without a token count are left empty.
    :raises TokenCountException: when the raw response column is missing, or when not a single token
        count could be read from it
    """
    column_name = prompt_recipe_utils.PROMPT_RECIPE_RAW_RESPONSE_NAME
    # Indexing a dataframe with an unknown column raises a KeyError instead of returning None, so the
    # presence of the column has to be tested explicitly to surface the actionable message below.
    if column_name not in input_df.columns:
        raise TokenCountException('Can\'t find column "%s". Check that your input dataset was produced by a prompt recipe with "Raw response output mode" set to "Raw" or "Raw without traces"' % column_name)
    raw_response = input_df[column_name]
    logger.info('Column "%s" is from a prompt recipe, trying to parse it for token counts', column_name)
    input_and_output_tokens = raw_response.apply(_read_token_count_from_json_cell).tolist()
    # Every entry is an (input, output) tuple and a non-empty tuple is always truthy, even (None, None):
    # the individual counts must be inspected to detect that nothing at all could be read.
    found_token_count = any(
        count is not None
        for row_counts in input_and_output_tokens
        for count in row_counts
    )
    if not found_token_count:
        raise TokenCountException('Column "%s" does not contain token counts from a prompt recipe.' % column_name)
    logger.info('Found token counts in "%s", from a prompt recipe. Parsing it.', column_name)
    return pd.DataFrame(input_and_output_tokens, columns=['inputTokensPerRow', 'outputTokensPerRow'])


def mean_or_none(pd_series):
    """
    Compute the mean of a series, skipping missing values.

    :param pd_series: series to average
    :return: the mean of the series, or None when that mean is itself a missing value
    """
    average = pd_series.mean(skipna=True)
    if pd.isna(average):
        return None
    return average
