import json
import logging
import base64
import pandas as pd
import numpy as np

from datasets import Dataset
from dataiku import Folder

from ragas import evaluate as ragas_evaluate
from ragas.metrics import MultiModalRelevance
from requests import HTTPError

# Logger used by this module to report rows or images that could not be read.
logger = logging.getLogger("multimodal_relevance")

# Module-level cache of base64-encoded image contents. It is filled by
# read_multi_modal_context_from_prompt_recipe so that an image which has
# already been downloaded is not fetched again for later rows.
image_cache = {}

def read_multi_modal_context_from_prompt_recipe(json_string: str) -> list:
    """Extract the text and image contexts from one serialized sources payload.

    The payload is expected to be a JSON object with a ``sources`` list whose
    items carry an ``excerpt`` object. Two excerpt types are handled:

    * ``TEXT``: the excerpt's ``text`` value is used as a context.
    * ``IMAGE_REF``: every entry of the excerpt's ``images`` list is
      downloaded from the managed folder designated by its ``fullFolderId``
      (``"<project key>.<folder lookup>"``) and ``path``, then base64-encoded.

    Downloaded images are memoized in the module-level ``image_cache``. The
    cache is keyed by ``(fullFolderId, path)`` so that two folders containing
    a file with the same relative path do not return each other's content.

    Args:
        json_string: JSON document to parse.

    Returns:
        The non-empty contexts (plain text or base64-encoded image strings)
        in the order they were encountered. An empty list is returned when
        the payload cannot be parsed or has an unexpected structure; the
        failure is logged as a warning rather than raised.
    """
    try:
        raw_json = json.loads(json_string)
        contexts = []
        # ``or []`` / ``or {}`` also cover keys that are present but null, so
        # one incomplete source does not discard the whole row.
        for source in raw_json.get('sources') or []:
            excerpt = source.get('excerpt') or {}
            excerpt_type = excerpt.get('type')
            if excerpt_type == 'TEXT':
                contexts.append(excerpt.get('text'))
            elif excerpt_type == 'IMAGE_REF':
                for image in excerpt.get('images') or []:
                    full_folder_id = image["fullFolderId"]
                    project_key, lookup = full_folder_id.split(".", 1)
                    image_path = image["path"]
                    # The path alone is only unique within a single folder.
                    cache_key = (full_folder_id, image_path)
                    if cache_key not in image_cache:
                        try:
                            folder = Folder(lookup, project_key, ignore_flow=True)
                            with folder.get_download_stream(image_path) as image_file:
                                image_cache[cache_key] = base64.b64encode(image_file.read()).decode("utf8")
                        except HTTPError as http_err:
                            # Best effort: skip the unreachable image and keep
                            # the remaining contexts of the row.
                            logger.warning("Error when retrieving file {image_path} in folder {folder_id} : {err_msg}".format(image_path=image_path, folder_id=full_folder_id, err_msg=str(http_err)))
                            continue
                    contexts.append(image_cache[cache_key])
        # Drop missing or empty contexts (e.g. a TEXT excerpt without text).
        return [c for c in contexts if c]
    except Exception as err:
        # Row-level boundary: a malformed row yields no context instead of
        # failing the whole evaluation.
        logger.warning("Error when trying to parse the row {json_string} : {err_msg}".format(json_string=json_string, err_msg=str(err)))
        return []

def evaluate(input_df, recipe_params, interpreted_columns, **kwargs):
    """Score every row with the ragas ``MultiModalRelevance`` metric.

    Args:
        input_df: DataFrame holding the column named by
            ``recipe_params.context_column_name``; each of its values is
            parsed with ``read_multi_modal_context_from_prompt_recipe``.
        recipe_params: Settings object providing ``context_column_name``,
            ``completion_llm`` and ``embedding_llm``.
        interpreted_columns: Object whose ``input`` and ``output`` attributes
            provide the user inputs and the responses to evaluate.
        **kwargs: Accepted but not used by this function.

    Returns:
        A ``(global_metric, row_by_row_metrics)`` tuple, where
        ``row_by_row_metrics`` is the ``"relevance_rate"`` entry of the ragas
        result and ``global_metric`` is its mean computed while ignoring NaN
        values.
    """
    # Columns are attached one at a time to an empty frame, using the column
    # names passed on to ragas.
    samples = pd.DataFrame()
    samples["user_input"] = interpreted_columns.input
    samples["response"] = interpreted_columns.output
    context_column = input_df[recipe_params.context_column_name]
    samples["retrieved_contexts"] = context_column.apply(read_multi_modal_context_from_prompt_recipe)

    dataset = Dataset.from_pandas(samples)
    metrics = [MultiModalRelevance()]
    result = ragas_evaluate(
        dataset,
        metrics=metrics,
        llm=recipe_params.completion_llm,
        embeddings=recipe_params.embedding_llm,
    )

    scores = result["relevance_rate"]
    # nanmean skips rows whose score is NaN.
    return np.nanmean(scores), scores
