
import logging

import numpy as np
import pandas as pd
from typing import List, Tuple, Dict, Iterable

from dataiku.doctor.diagnostics import diagnostics
from dataiku.llm.evaluation.exceptions import LLMEvalException, MissingColumnException
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput, GenAIMetricInputRole

logger = logging.getLogger(__name__)


def filter_null_rows(columns: GenAIMetricInput, metric_names_by_input_role: Dict[GenAIMetricInputRole, List[str]], metric_names: List[str]) -> GenAIMetricInput:
    """Drop the rows that miss a value in any of the columns required by the metrics.

    :param columns: the evaluation inputs; each required role is fetched with ``columns.get(role)``.
    :param metric_names_by_input_role: for each required input role, the names of the metrics that need it.
    :param metric_names: names of the metrics being computed; only used to build messages.
    :return: ``columns`` itself when no input role is required, otherwise a new ``GenAIMetricInput`` built
        from the rows where every required column is non-null.
    :raises MissingColumnException: if a required column is absent or entirely null.
    :raises LLMEvalException: if every row has a null value in at least one required column.
    """
    # any() over a dict iterates its keys: with no required role there is nothing to filter.
    if not any(metric_names_by_input_role):
        return columns

    rows_without_null = None
    column_names = []
    for input_role, role_metric_names in metric_names_by_input_role.items():
        col = columns.get(input_role)
        if col is None or col.isnull().all():
            raise MissingColumnException(
                "Error computing %s: those metrics require column '%s', which is missing or empty."
                % (role_metric_names, input_role.value))
        # Keep the column identifier (not the Series itself) so the "all rows empty" error stays readable.
        column_names.append(input_role.value)
        col_not_null = col.notnull()
        rows_without_null = col_not_null if rows_without_null is None else rows_without_null & col_not_null

    rows_with_null = ~rows_without_null
    null_indexes = rows_with_null[rows_with_null].index
    # Test how many rows are faulty, not the truthiness of their labels: `Index.any()` is False when
    # the only faulty row is labelled 0, which would silently skip the warning.
    if len(null_indexes) > 0:
        max_print = 20
        error_message = (
            "Warning computing %s: some rows are missing value for one of the following columns: %s. Dismissing "
            "them from computation. Faulty rows: %s."
            % (metric_names,
               metric_names_by_input_role,
               # Slicing past the end is harmless, so no length check is needed here.
               null_indexes[:max_print].to_list()))
        if len(null_indexes) > max_print:
            error_message += " (and %s other rows)" % (len(null_indexes) - max_print)
        logger.warning(error_message)
        diagnostics.add_or_update(
            diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR,
            error_message
        )
    # Each column has at least one value (checked above), but the nulls of different columns
    # can still cover every row once combined.
    if rows_with_null.all():
        raise LLMEvalException(
            "Error computing %s: all rows have at least one empty value on the required columns %s. Can't compute metrics."
            % (metric_names, column_names))

    def _keep_valid_rows(series):
        # Columns that were not provided stay None; the others are restricted to the complete rows.
        return series[rows_without_null] if series is not None else None

    return GenAIMetricInput.from_series(
        _keep_valid_rows(columns.input),
        _keep_valid_rows(columns.output),
        _keep_valid_rows(columns.ground_truth),
        _keep_valid_rows(columns.context),
        _keep_valid_rows(columns.actual_tool_calls),
        _keep_valid_rows(columns.reference_tool_calls))


def raise_or_continue(e, metric_name, fail_on_errors, input_format):
    """Either re-raise a metric computation error or record it and carry on.

    When the error wraps a ``MissingColumnException`` and the input format is one of the
    known ones, the error is first wrapped into an ``LLMEvalException`` whose message is
    extended with a configuration hint.

    :param e: the error to handle; its ``original_exception`` attribute is read, and its
        ``message`` attribute too when a hint is appended.
    :param metric_name: name of the metric being computed, used in the reported message.
    :param fail_on_errors: if truthy the (possibly wrapped) error is raised, otherwise it is
        logged and added to the diagnostics.
    :param input_format: input format identifier; 'PROMPT_RECIPE' and 'DATAIKU_ANSWERS' get a hint.
    """
    cause = e.original_exception
    if cause is not None and isinstance(cause, MissingColumnException):
        hint = None
        if input_format == 'PROMPT_RECIPE':
            hint = (
                ' Make sure "Raw query output mode" and "Raw response output mode" in your Prompt recipe are not set to "None".'
                ' If computing context-based metrics, make sure that your Prompt recipe uses a Retrieval-augmented LLM with '
                '"Source output format" set to "Separated".'
            )
        elif input_format == 'DATAIKU_ANSWERS':
            hint = " Make sure the \"Retrieval Method\" is set to 'Use knowledge bank retrieval' in Dataiku Answers' settings."
        if hint is not None:
            # The initial error is passed along as second argument of the wrapping exception.
            e = LLMEvalException(e.message + hint, e)

    if not fail_on_errors:
        explicit_error = "Error computing %s metric: %s." % (metric_name, str(e))
        logger.error(explicit_error + " Stop on errors is not enabled, carrying on with other metrics.")
        diagnostics.add_or_update(
            diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR,
            explicit_error
        )
        return

    raise e


def warn(explicit_warning, raise_diagnostic=True):
    """Log a warning and, unless disabled, register it as a diagnostic.

    :param explicit_warning: the message to log and report.
    :param raise_diagnostic: when falsy, the message is only logged.
    """
    logger.warning(explicit_warning)
    if not raise_diagnostic:
        return
    diagnostics.add_or_update(diagnostics.DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR, explicit_warning)


def create_empty_metrics(columns: GenAIMetricInput, metrics: Iterable[str]) -> Tuple[Dict[str, None], pd.DataFrame]:
    """Build placeholder results in which every requested metric is empty.

    :param columns: the evaluation inputs; only the index of ``columns.input`` is read, to index
        the row-by-row frame.
    :param metrics: names of the metrics to create placeholders for. One-shot iterables such as
        generators are supported.
    :return: a tuple ``(empty_perf, empty_row_by_row)`` where ``empty_perf`` maps each metric name to
        ``None`` and ``empty_row_by_row`` is a DataFrame holding one float64, all-NaN column per metric.
    """
    # `metrics` is only promised to be an Iterable: materialize it once, otherwise a generator would be
    # exhausted by the first comprehension and the DataFrame would end up with no columns.
    metric_names = list(metrics)
    row_index = columns.input.index
    empty_perf = {metric: None for metric in metric_names}
    empty_row_by_row = pd.DataFrame(
        {metric: pd.Series(None, index=row_index, dtype=np.dtype(np.float64)) for metric in metric_names}
    )
    return empty_perf, empty_row_by_row
