# -*- coding: utf-8 -*-
import dataiku
from sklearn import metrics
from dataikuapi.dss.modelevaluationstore import DSSModelEvaluationStore
from datetime import datetime

# ID of the Model Evaluation Store that receives the custom evaluations.
EVALUATION_STORE_ID = "8qixI9fR"

# DSS API client and the project this recipe is executed in.
client = dataiku.api_client()
project = client.get_default_project()

# Input dataset holding the ground-truth columns next to each model's
# prediction columns (the column naming is what add_mes_item expects).
annotations_gt_joined_prepared = dataiku.Dataset("joined_outputs_formatted")
annotations_gt_joined_prepared_df = annotations_gt_joined_prepared.get_dataframe()


def _field_accuracy(df, gt_column, pred_column):
    """Return the exact-match accuracy of ``pred_column`` against ``gt_column``.

    Both columns are cast to ``str`` before scoring. Without this, a column
    that mixes strings with missing values (NaN floats) makes scikit-learn's
    ``accuracy_score`` raise instead of returning a score. As a consequence,
    missing values are compared through their string form (e.g. NaN -> "nan"),
    so a value missing on both sides counts as a match.
    """
    return metrics.accuracy_score(
        df[gt_column].astype(str),
        df[pred_column].astype(str),
    )


def add_mes_item(
    df, prefix_predictions, name_model, prefix_gt="", evaluation_store_id=None
):
    """
    Add custom model evaluation metrics to the Model Evaluation Store.

    One accuracy metric is recorded per extracted field (company, address,
    total, date), plus GLOBAL_accuracy, the unweighted mean of the per-field
    accuracies. The evaluation is labelled with the time it was computed.

    Args:
        df (pd.DataFrame): DataFrame containing ground truth and predicted values.
        prefix_predictions (str): Prefix for the prediction columns (e.g., 'GPT_', 'qwen_').
        name_model (str): Name under which the evaluation is stored.
        prefix_gt (str, optional): Prefix for ground truth columns. Defaults to "".
        evaluation_store_id (str, optional): ID of the Model Evaluation Store to
            write to. Defaults to the module-level ``EVALUATION_STORE_ID``.

    Returns:
        None
    """
    fields = ("company", "address", "total", "date")

    # Per-field exact-match accuracy, in the order the metrics are reported.
    accuracies = {
        field: _field_accuracy(
            df, f"{prefix_gt}{field}", f"{prefix_predictions}{field}"
        )
        for field in fields
    }

    scores = [
        DSSModelEvaluationStore.MetricDefinition(
            code=f"{field.upper()}_accuracy",
            value=value,
            name=f"{field.upper()}_accuracy",
            description=f"Accuracy for the '{field}' field",
        )
        for field, value in accuracies.items()
    ]

    # Every field weighs the same in the global score.
    scores.append(
        DSSModelEvaluationStore.MetricDefinition(
            code="GLOBAL_accuracy",
            value=sum(accuracies.values()) / len(accuracies),
            name="GLOBAL_accuracy",
            description="Average accuracy across all fields",
        )
    )

    # Label the evaluation with its creation time (naive local time, ISO 8601).
    label = DSSModelEvaluationStore.LabelDefinition(
        "evaluation:date", datetime.now().isoformat()
    )

    # Resolved at call time so the module constant remains the default target.
    store_id = (
        EVALUATION_STORE_ID if evaluation_store_id is None else evaluation_store_id
    )
    mes = project.get_model_evaluation_store(store_id)
    mes.add_custom_model_evaluation(scores, name=name_model, labels=[label])


# Push one evaluation per model. Each model's predictions live in columns
# carrying its own prefix; ground-truth columns are unprefixed.
for _prediction_prefix, _model_label in (
    ("GPT_", "GPT-4o"),
    ("qwen_", "QWEN-VL2"),
):
    add_mes_item(
        annotations_gt_joined_prepared_df,
        prefix_gt="",
        prefix_predictions=_prediction_prefix,
        name_model=_model_label,
    )
