import base64
import logging
from urllib.error import HTTPError

import numpy as np
import pandas as pd

from langchain_core.callbacks import BaseCallbackHandler
from ragas import evaluate, RunConfig
from ragas.metrics.base import Metric
from ragas.metrics import AnswerRelevancy, Faithfulness, ContextRecall, ContextPrecision, AnswerCorrectness, AnswerSimilarity, MultiModalRelevance, MultiModalFaithfulness
from typing import List, Optional, Tuple, Collection, Dict
from datasets import Dataset
from datasets.features.features import Sequence

from dataiku import Folder
from dataiku.langchain.dku_embeddings import DKUEmbeddings, TraceableDKUEmbeddings
from dataiku.llm.evaluation.exceptions import RagasException
from dataiku.llm.evaluation.llm_metrics_input import LLMMetricInput, LLMMetricInputRole
from dataiku.llm.evaluation.utils import failure_utils
from dataiku.llm.evaluation.utils.metrics_utils import get_llm_args, CompletionTraceHandler
from dataiku.llm.evaluation.utils.ragas.ragas_compatible_llm import RagasCompatibleLLM
from dataiku.llm.tracing import SpanBuilder
from dataiku.llm.types import CompletionSettings

# Maps each metric identifier accepted by this module (the strings found in the `metrics` lists)
# to the ragas metric class implementing it; `get_ragas_metrics` calls the value to build an instance.
# The values of this map could also be lambdas. We can't store instances directly as we
# require new instances at each metric computation run to prevent race conditions.
RAGAS_METRICS_GENERATOR_MAP = {
    "answerRelevancy": AnswerRelevancy,
    "answerCorrectness": AnswerCorrectness,
    "answerSimilarity": AnswerSimilarity,
    "faithfulness": Faithfulness,
    "contextRecall": ContextRecall,
    "contextPrecision": ContextPrecision,
    "multimodalRelevancy": MultiModalRelevance,
    "multimodalFaithfulness": MultiModalFaithfulness
}

# Maps the metric names found in the ragas evaluation result back to the metric identifiers used as keys
# of RAGAS_METRICS_GENERATOR_MAP. `RagasMetricsComputer.compute` uses it to key the global results and to
# rename the per-row result columns.
# NOTE(review): "semantic_similarity" and "answer_similarity" both map to "answerSimilarity" -- presumably
# to cover the different names ragas has used for this metric; confirm against the supported ragas versions.
RAGAS_OUTPUT_METRICS_MAP = {
    "answer_relevancy": "answerRelevancy",
    "semantic_similarity": "answerSimilarity",
    "faithfulness": "faithfulness",
    "context_recall": "contextRecall",
    "context_precision": "contextPrecision",
    "answer_correctness": "answerCorrectness",
    "answer_similarity": "answerSimilarity",
    "relevance_rate": "multimodalRelevancy",
    "faithful_rate": "multimodalFaithfulness",
}

# Metric identifiers grouped by the input column they need. These sets drive both the up-front validation
# (`check_use_ragas_metrics`) and the construction of the ragas input dataset (`RagasMetricsComputer.compute`).
# Metrics that need the ground truth column (sent to ragas as 'reference').
RAGAS_METRICS_WITH_GROUND_TRUTH = {"contextPrecision", "contextRecall", "answerCorrectness", "answerSimilarity"}
# Metrics that need a textual context column (sent to ragas as 'retrieved_contexts').
RAGAS_METRICS_WITH_CONTEXT = {"faithfulness", "answerRelevancy", "contextPrecision", "contextRecall"}
# Metrics that need a multimodal context; they cannot be combined with the textual context metrics in one run.
RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT = {"multimodalFaithfulness", "multimodalRelevancy"}
# Metrics that need the output column (sent to ragas as 'response').
RAGAS_METRICS_WITH_OUTPUT = {"faithfulness", "answerRelevancy", "answerSimilarity", "answerCorrectness", "multimodalFaithfulness", "multimodalRelevancy"}

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class RagasMetricsComputer(object):
    """
    Computes RAGAS metrics on evaluation rows, using a completion LLM and an embedding LLM.

    One instance holds the LLM wrapper and the run settings; `compute` can then be called with the
    interpreted columns and the list of metric identifiers to evaluate.
    """
    # Completion LLM wrapper handed to ragas; its `errors` list is inspected after each evaluation.
    llm: RagasCompatibleLLM
    # Identifier of the embedding LLM; the embeddings object itself is built on each `compute` call.
    embeddings_model_id: str
    # Maximum number of workers passed to the ragas `RunConfig`.
    max_workers: int

    def __init__(self, completion_llm_id: str, completion_settings: CompletionSettings, embedding_llm_id: str, max_workers: int, fail_on_row_level_errors: bool,
                 can_compute_multimodal_metrics: bool = False):
        """
        :param completion_llm_id: id of the completion LLM used by ragas.
        :param completion_settings: completion settings, converted to LLM arguments with `get_llm_args`.
        :param embedding_llm_id: id of the embedding LLM used by ragas.
        :param max_workers: maximum number of ragas workers.
        :param fail_on_row_level_errors: forwarded to ragas `evaluate` as `raise_exceptions`.
        :param can_compute_multimodal_metrics: when False, multimodal metrics are silently dropped by `compute`.
        """
        self.llm = RagasCompatibleLLM(llm_id=completion_llm_id, **get_llm_args(completion_settings))
        self.embeddings_model_id = embedding_llm_id
        self.max_workers = max_workers
        self.fail_on_row_level_errors = fail_on_row_level_errors
        self.can_compute_multimodal_metrics = can_compute_multimodal_metrics
        logger.info(f"Ragas metrics will be computed with completion LLM {completion_llm_id} and embedding LLM {embedding_llm_id}, "
                    f"with a max of {max_workers} workers")

    def compute(self, interpreted_columns: LLMMetricInput, metrics: List[str], trace: Optional[SpanBuilder] = None) -> Tuple[dict, pd.DataFrame]:
        """
        Compute the ragas metrics among `metrics` on the given columns.

        :param interpreted_columns: input / output / ground truth / context columns of the evaluated dataset.
        :param metrics: metric identifiers; only those present in RAGAS_METRICS_GENERATOR_MAP are computed.
        :param trace: optional span builder; when given, embedding and completion calls are traced.
        :return: a tuple `(global_metrics, row_metrics)`. `global_metrics` maps each metric identifier to its
                 global value (None when the value is NaN). `row_metrics` is the per-row result, indexed like
                 the filtered input rows, with columns renamed to metric identifiers.
        :raises RagasException: on any failure; the original exception is attached as the cause.
        """
        if not self.can_compute_multimodal_metrics:
            metrics = [metric for metric in metrics if metric not in RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT]
        try:
            ragas_metrics_keys_to_compute = set(metrics) & RAGAS_METRICS_GENERATOR_MAP.keys()
            with_ground_truth_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_GROUND_TRUTH
            with_context_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_CONTEXT
            with_multimodal_context_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT
            with_output_metrics = ragas_metrics_keys_to_compute & RAGAS_METRICS_WITH_OUTPUT

            # Both families write to the same 'retrieved_contexts' column, with incompatible content.
            if with_context_metrics and with_multimodal_context_metrics:
                raise RagasException("Can't compute multimodal and textual context metrics (faithfulness, relevancy, precision, recall) simultaneously. Select only the metrics suitable to your context type.")

            # empty values will crash. Avoid them.
            # Note that, in theory, we could compute some metrics even if some values are empty (e.g. answer relevancy don't need context)
            # We don't bother, and rule out the entire row instead
            ragas_metric_to_compute = get_ragas_metrics(ragas_metrics_keys_to_compute)
            metric_names = [metric.name for metric in ragas_metric_to_compute]
            metric_names_by_column_role = {LLMMetricInputRole.INPUT: metric_names}
            if with_output_metrics:
                metric_names_by_column_role[LLMMetricInputRole.OUTPUT] = [metric.name for metric in get_ragas_metrics(with_output_metrics)]
            if with_ground_truth_metrics:
                metric_names_by_column_role[LLMMetricInputRole.GROUND_TRUTH] = [metric.name for metric in get_ragas_metrics(with_ground_truth_metrics)]
            if with_context_metrics:
                metric_names_by_column_role[LLMMetricInputRole.CONTEXT] = [metric.name for metric in get_ragas_metrics(with_context_metrics)]

            filtered_columns = failure_utils.filter_null_rows(interpreted_columns, metric_names_by_column_role, metric_names)
            initial_index = filtered_columns.input.index  # Keep index for output

            ragas_input_df = pd.DataFrame()
            ragas_input_df['user_input'] = filtered_columns.input

            if with_output_metrics:
                ragas_input_df['response'] = filtered_columns.output
            if with_ground_truth_metrics:
                # Only the first remaining row is inspected to detect a multi-valued ground truth column.
                first_ground_truth = filtered_columns.ground_truth.iloc[0]
                if isinstance(first_ground_truth, (list, np.ndarray)):
                    raise RagasException("Ragas metrics do not support multiple ground truths : the ground truth column needs to be of type string and not %s" % str(type(first_ground_truth)))
                ragas_input_df['reference'] = filtered_columns.ground_truth
            if with_context_metrics:
                # Positional access: the index is the one of the filtered rows, so label 0 may not exist.
                first_context = filtered_columns.context.iloc[0]
                if not isinstance(first_context, list):
                    if not isinstance(first_context, str):
                        raise RagasException("The context column '%s' must be of type string (or array of strings), got: %s" % (interpreted_columns.context.name, first_context))
                    # ragas expects a list of contexts per row: wrap single strings.
                    ragas_input_df['retrieved_contexts'] = filtered_columns.context.apply(lambda x: [x])
                else:
                    ragas_input_df['retrieved_contexts'] = filtered_columns.context
            elif with_multimodal_context_metrics:
                # No checks needed as it should be on the output of prompt recipe
                ragas_input_df['retrieved_contexts'] = filtered_columns.context.apply(read_multimodal_context_from_prompt_recipe)

            logger.info(f"The following RAGAS metric will be computed : {metric_names}")

            ragas_dataset = Dataset.from_pandas(ragas_input_df)
            # Original column names, only used to produce readable validation error messages.
            ragas_column_mapping = {
                "user_input": interpreted_columns.input.name if interpreted_columns.input is not None else None,
                "response": interpreted_columns.output.name if interpreted_columns.output is not None else None,
                "reference": interpreted_columns.ground_truth.name if interpreted_columns.ground_truth is not None else None,
                "retrieved_contexts": interpreted_columns.context.name if interpreted_columns.context is not None else None
            }
            validate_column_dtypes(ragas_dataset, ragas_column_mapping, bool(with_multimodal_context_metrics))

            # Langchain does not support callbacks on embedding, so we rely on our own mechanism by passing the trace to the `TraceableDKUEmbeddings` ctor.
            embeddings = DKUEmbeddings(llm_id=self.embeddings_model_id) if trace is None else TraceableDKUEmbeddings(trace, llm_id=self.embeddings_model_id)
            # For traces on completions, however, we can rely on Langchain callbacks
            callbacks: List[BaseCallbackHandler] = [] if trace is None else [CompletionTraceHandler(trace)]
            ret = evaluate(
                ragas_dataset,
                metrics=ragas_metric_to_compute,
                llm=self.llm,
                embeddings=embeddings,
                raise_exceptions=self.fail_on_row_level_errors,
                run_config=RunConfig(max_workers=self.max_workers, max_retries=3, timeout=600, max_wait=1800),
                callbacks=callbacks,
            )

            # Try to warn the user, but best effort, openAI error messages are not very stable
            if any('"Unsupported value: \'temperature\'' in e for e in self.llm.errors):
                raise RagasException('temperature is not supported with this model. Clear the temperature from the Advanced tab')
            self.llm.errors = []

            logger.info(f"Global Ragas metrics result : {str(ret)}")
            ret_dict = ret._repr_dict
            # Materialize the keys as a list: it is used as a pandas column selector below.
            result_metric_names = list(ret_dict)
            global_metrics = {
                RAGAS_OUTPUT_METRICS_MAP[metric]: ret_dict[metric] if not np.isnan(ret_dict[metric]) else None
                for metric in result_metric_names
            }
            row_metrics = ret.to_pandas().set_index(initial_index)[result_metric_names].rename(columns=RAGAS_OUTPUT_METRICS_MAP)
            return global_metrics, row_metrics

        except Exception as e:
            # Every failure (including the RagasException raised above) is reported under a single message prefix.
            raise RagasException("An error happened during the computation of RAGAS metrics : %s" % str(e), e) from e


def check_use_ragas_metrics(metrics: List[str], has_ground_truth: bool, has_context: bool, completion_llm_id: str, embedding_llm_id: str,
                            can_compute_multimodal_metrics: bool) -> None:
    """
    Check that the metrics selected by the user are compatible with the dataset and the recipe configuration.

    Raises a RagasException when:
    - a selected metric needs a ground truth column and there is none,
    - a selected metric needs a (textual) context column and there is none,
    - multimodal metrics can be computed, one is selected and there is no context column,
    - the completion LLM or the embedding LLM is not set.
    """
    selected = set(metrics)

    needing_ground_truth = list(selected & RAGAS_METRICS_WITH_GROUND_TRUTH)
    if not has_ground_truth and needing_ground_truth:
        raise RagasException(f"The following metrics require a ground truth column : {needing_ground_truth}")

    needing_context = list(selected & RAGAS_METRICS_WITH_CONTEXT)
    if not has_context and needing_context:
        raise RagasException(f"The following metrics require a context column : {needing_context}")

    # Multimodal metrics are only validated when they can actually be computed.
    if can_compute_multimodal_metrics:
        needing_multimodal_context = list(selected & RAGAS_METRICS_WITH_MULTIMODAL_CONTEXT)
        if not has_context and needing_multimodal_context:
            raise RagasException(f"The following metrics require a multimodal context column : {needing_multimodal_context}")

    if not (completion_llm_id and embedding_llm_id):
        selected_ragas_metrics = list(selected & RAGAS_METRICS_GENERATOR_MAP.keys())
        raise RagasException(f"You need to select both an Embedding LLM and a Completion LLM to compute : {selected_ragas_metrics}. Please verify your recipe configuration.")


def has_context_based_metrics(metrics: List[str]) -> bool:
    """
    Checks if at least one of the given metrics requires a (textual) context column
    """
    return not set(metrics).isdisjoint(RAGAS_METRICS_WITH_CONTEXT)

def has_ragas_metrics(metrics: List[str]) -> bool:
    """
    Tells whether the given metric list contains at least one ragas metric,
    i.e. at least one key of RAGAS_METRICS_GENERATOR_MAP
    """
    return not set(metrics).isdisjoint(RAGAS_METRICS_GENERATOR_MAP)


def get_ragas_metrics(metrics: Collection[str]) -> List[Metric]:
    """
    Builds a fresh ragas metric instance for each given identifier that is a key of RAGAS_METRICS_GENERATOR_MAP.
    Unknown identifiers are ignored; the iteration order of `metrics` is kept.
    """
    return [
        RAGAS_METRICS_GENERATOR_MAP[metric_key]()
        for metric_key in metrics
        if metric_key in RAGAS_METRICS_GENERATOR_MAP
    ]


def create_empty_ragas_metrics(interpreted_columns: LLMMetricInput, metrics: List[str]):
    """
    Delegates to `failure_utils.create_empty_metrics`, passing the ragas names of the ragas metrics found in `metrics`.
    """
    return failure_utils.create_empty_metrics(
        interpreted_columns,
        [ragas_metric.name for ragas_metric in get_ragas_metrics(metrics)],
    )


def validate_column_dtypes(ds: Dataset, column_mapping: dict, expects_multimodal_context: bool):
    """
    Validates the feature types of the ragas input dataset.

    - 'user_input', 'response' and 'reference', when present, must have the "string" dtype.
    - 'retrieved_contexts', when present, must either contain at least one non-empty row (multimodal context),
      or be a sequence of strings (textual context).

    `column_mapping` gives, for each ragas column name, the name to report in error messages.
    Raises a ValueError on the first invalid column.
    """
    for ragas_column in ("user_input", "response", "reference"):
        if ragas_column not in ds.features:
            continue
        found_dtype = ds.features[ragas_column].dtype
        if found_dtype != "string":
            raise ValueError(
                f'Dataset feature "{column_mapping[ragas_column]}" should be of type string, got {found_dtype}'
            )

    if "retrieved_contexts" not in ds.features:
        return

    if expects_multimodal_context:
        # Every row being empty means that nothing could be read from the context column.
        if not any(ds["retrieved_contexts"]):
            raise ValueError(
                f"Can't get multimodal context: unable to read images from Dataset feature \"{column_mapping['retrieved_contexts']}\". It should contain paths to images from a managed folder.")
        return

    context_feature = ds.features["retrieved_contexts"]
    is_sequence_of_strings = (
        isinstance(context_feature, Sequence)
        and hasattr(context_feature, "feature")
        and context_feature.feature.dtype == "string"
    )
    if not is_sequence_of_strings:
        raise ValueError(
            f"Can't get textual context from Dataset feature \"{column_mapping['retrieved_contexts']}\", it should be of type string (or array of strings).")


def read_multimodal_context_from_prompt_recipe(sources: List[Dict]) -> List[str]:
    """
    Extract the contexts of one row from its list of sources.

    Each source is expected to hold an 'excerpt' dict:
    - a 'TEXT' excerpt contributes its 'text',
    - an 'IMAGE_REF' excerpt contributes, for each of its 'images', the base64-encoded content of the file
      at 'path' in the managed folder designated by 'fullFolderId' ("<project key>.<folder lookup>").
    Other excerpt types are ignored, as well as empty contexts.

    This is best effort: an image that can't be downloaded (HTTPError) is skipped with a warning, and any other
    error on the row is logged as a warning and results in an empty list.

    :param sources: the sources of one row.
    :return: the text contexts and base64-encoded images, in source order.
    """
    # Cache local to this call. Keyed by (folder id, path): the same path may exist in several managed folders.
    image_cache = {}
    try:
        excerpts = [source.get('excerpt', {}) for source in sources]
        contexts = []
        for excerpt in excerpts:
            excerpt_type = excerpt.get('type')
            if excerpt_type == 'TEXT':
                contexts.append(excerpt.get('text'))
            elif excerpt_type == 'IMAGE_REF':
                # An image excerpt without images contributes nothing, instead of failing the whole row.
                for image in excerpt.get('images') or []:
                    full_folder_id = image["fullFolderId"]
                    project_key, lookup = full_folder_id.split(".", 1)
                    image_path = image["path"]
                    cache_key = (full_folder_id, image_path)
                    if cache_key not in image_cache:
                        try:
                            folder = Folder(lookup, project_key, ignore_flow=True)
                            with folder.get_download_stream(image_path) as image_file:
                                image_cache[cache_key] = base64.b64encode(image_file.read()).decode("utf8")
                        except HTTPError as http_err:
                            logger.warning("Error when retrieving file {image_path} in folder {folder_id} : {err_msg}".format(image_path=image_path, folder_id=full_folder_id, err_msg=str(http_err)))
                            continue
                    contexts.append(image_cache[cache_key])
        return [context for context in contexts if context]
    except Exception as err:
        logger.warning("Error when trying to parse multimodal sources on the row {sources} : {err_msg}".format(sources=sources, err_msg=str(err)))
        return []
