/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.evaluation.llm;

import com.dataiku.dip.analysis.model.core.GenAiCustomEvaluationMetric;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.recipes.RecipeSchemaComputer;
import com.dataiku.dip.recipes.nlp.evaluation.AbstractGenAIEvaluationRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.evaluation.GenAIEvaluationUtils;
import com.dataiku.dip.recipes.nlp.evaluation.llm.LLMEvaluationRecipePayloadParams;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Computes the output schemas of the LLM Evaluation recipe.
 *
 * <p>Three output roles are recognised: {@code "main"} (the input dataset's schema extended with
 * metric columns), {@code "metrics"} (a standalone metrics schema) and {@code "evaluationStore"}
 * (no schema). The recipe payload must be supplied through {@link #setPayload(String)} before a
 * {@code "main"} or {@code "metrics"} schema is computed.
 */
public class LLMEvaluationRecipeSchemaComputer
extends RecipeSchemaComputer
implements RecipeSchemaComputer.RecipeSchemaComputerWithPayload {
    @Autowired
    protected TransactionService transactionService;
    @Autowired
    protected DatasetsDAO datasetsDAO;
    private LLMEvaluationRecipePayloadParams params;
    private static final DKULogger logger = DKULogger.getLogger("recipes.nlp.evaluation.llm.schema");

    public LLMEvaluationRecipeSchemaComputer(AuthCtx authCtx, JobActivity activity) {
        super(authCtx, activity);
    }

    /**
     * Parses the JSON recipe payload and stores it for later schema computation.
     *
     * @param payload JSON representation of {@link LLMEvaluationRecipePayloadParams}
     */
    @Override
    public void setPayload(String payload) {
        this.params = (LLMEvaluationRecipePayloadParams)JSON.parse(payload, LLMEvaluationRecipePayloadParams.class);
    }

    /**
     * Returns the schemas for the given output role.
     *
     * <p>The recipe is first checked for compliance with its descriptor inside a read transaction.
     * For {@code "main"} and {@code "metrics"}, an empty list is returned when the recipe has no
     * output for that role; otherwise a single-element list is returned.
     *
     * @param role output role name
     * @return the schemas for that role, possibly empty
     * @throws IllegalArgumentException if the role is not one of the supported roles
     * @throws Exception if validation or schema computation fails
     */
    @Override
    public List<Schema> getSchemasForOutputRole_NT(String role) throws Exception {
        try (Transaction ignored = this.transactionService.beginRead()) {
            this.recipesValidationService.checkComplianceWithRecipeDesc(this.authCtx, this.recipe);
        }
        return switch (role) {
            case "main" -> {
                if (this.recipe.getOutputsForRole("main").isEmpty()) {
                    yield Collections.emptyList();
                }
                yield Lists.newArrayList(this.getMainSchema_NT());
            }
            case "metrics" -> {
                if (this.recipe.getOutputsForRole("metrics").isEmpty()) {
                    yield Collections.emptyList();
                }
                yield Lists.newArrayList(this.getMetricsSchema());
            }
            case "evaluationStore" -> Collections.emptyList();
            default -> throw new IllegalArgumentException(String.format("Output role %s is not compatible with the LLM Evaluation recipe.", role));
        };
    }

    /**
     * Builds the schema of the {@code "main"} output: a copy of the input dataset schema, followed
     * by the selected metric columns, the custom metric columns, and then the columns that depend
     * on the input format.
     *
     * @return the main output schema
     * @throws IllegalArgumentException if the recipe has no {@code "main"} input
     * @throws IllegalStateException if no payload has been set
     * @throws Exception if the input dataset cannot be read
     */
    protected Schema getMainSchema_NT() throws Exception {
        // A missing "main" role is reported the same way as an empty one instead of surfacing as an NPE.
        var mainInputs = this.recipe.getInputsUnsafe().get("main");
        if (mainInputs == null || mainInputs.items.isEmpty()) {
            throw new IllegalArgumentException("The LLM Evaluation recipe requires an input dataset.");
        }
        Schema schema = this.getInputSchemaCopy();
        LLMEvaluationRecipePayloadParams payloadParams = this.requireParams();
        schema.addColumns(LLMEvaluationRecipeSchemaComputer.getMetricsColumns(payloadParams.metrics, payloadParams.inputFormat));
        this.addCustomMetricColumns(schema, payloadParams);
        if (payloadParams.inputFormat == AbstractGenAIEvaluationRecipePayloadParams.GenAiEvalInputFormat.PROMPT_RECIPE) {
            schema.addColumn("dkuReconstructedInput", Type.STRING);
            schema.addColumn("dkuParsedOutput", Type.STRING);
            schema.addColumn("dkuParsedContexts", Type.STRING);
            schema.addColumns(GenAIEvaluationUtils.TOKEN_COUNT_METRIC_OUTPUT_SCHEMA_COLUMNS);
        } else if (payloadParams.inputFormat == AbstractGenAIEvaluationRecipePayloadParams.GenAiEvalInputFormat.DATAIKU_ANSWERS) {
            schema.addColumn("dkuParsedContexts", Type.STRING);
        }
        logger.infoV("Output dataset schema : %s", new Object[]{schema});
        return schema;
    }

    /**
     * Builds the schema of the {@code "metrics"} output: {@code date}, {@code sampleRowCount}, the
     * selected metric columns, the token count columns (PROMPT_RECIPE input format only), and
     * finally the custom metric columns.
     *
     * @return the metrics output schema
     * @throws IllegalStateException if no payload has been set
     */
    protected Schema getMetricsSchema() {
        LLMEvaluationRecipePayloadParams payloadParams = this.requireParams();
        Schema schema = new Schema();
        schema.addColumn(new SchemaColumn("date", Type.DATE));
        schema.addColumn(new SchemaColumn("sampleRowCount", Type.INT));
        schema.addColumns(LLMEvaluationRecipeSchemaComputer.getMetricsColumns(payloadParams.metrics, payloadParams.inputFormat));
        if (payloadParams.inputFormat == AbstractGenAIEvaluationRecipePayloadParams.GenAiEvalInputFormat.PROMPT_RECIPE) {
            schema.addColumns(GenAIEvaluationUtils.TOKEN_COUNT_METRIC_METRIC_SCHEMA_COLUMNS);
        }
        this.addCustomMetricColumns(schema, payloadParams);
        logger.infoV("Metrics schema: %s", new Object[]{schema});
        return schema;
    }

    /**
     * Maps the selected metric names to schema columns.
     *
     * <p>{@code bertScore} and {@code rouge} expand to their predefined column sets. The multimodal
     * metrics are only emitted for the PROMPT_RECIPE input format and are skipped (with a log
     * line) otherwise. Any other name becomes a single DOUBLE column of that name.
     *
     * @param selectedMetrics metric names selected in the recipe; {@code null} is treated as empty
     * @param inputFormat input format of the evaluation dataset
     * @return the metric columns, in selection order
     */
    private static List<SchemaColumn> getMetricsColumns(List<String> selectedMetrics, AbstractGenAIEvaluationRecipePayloadParams.GenAiEvalInputFormat inputFormat) {
        List<SchemaColumn> columns = new ArrayList<>();
        if (selectedMetrics == null) {
            return columns;
        }
        for (String metric : selectedMetrics) {
            boolean multimodal = "multimodalFaithfulness".equals(metric) || "multimodalRelevancy".equals(metric);
            if (multimodal && inputFormat != AbstractGenAIEvaluationRecipePayloadParams.GenAiEvalInputFormat.PROMPT_RECIPE) {
                logger.infoV("The input format of the evaluation dataset is not Prompt Recipe (is %s): not considering multimodal metric %s", new Object[]{inputFormat, metric});
            } else if ("bertScore".equals(metric)) {
                columns.addAll(GenAIEvaluationUtils.BERT_SCORE_METRIC_SCHEMA_COLUMNS);
            } else if ("rouge".equals(metric)) {
                columns.addAll(GenAIEvaluationUtils.ROUGE_METRIC_SCHEMA_COLUMNS);
            } else {
                columns.add(new SchemaColumn(metric, Type.DOUBLE));
            }
        }
        return columns;
    }

    /**
     * Returns a copy of the schema of the single {@code "main"} input dataset. The dataset is
     * looked up inside a read transaction; the schema is copied after the transaction is closed.
     *
     * @return a copy of the input dataset schema
     * @throws Exception if the input dataset cannot be resolved
     */
    protected Schema getInputSchemaCopy() throws Exception {
        Dataset inputDataset;
        try (Transaction ignored = this.transactionService.beginRead()) {
            AnyLoc inputDatasetLoc = this.recipe.getSingleInput("main").getLoc(this.recipe.getProjectKey());
            inputDataset = this.datasetAccessService.getMandatoryUnsafe(inputDatasetLoc);
        }
        return inputDataset.getSchema().getCopy();
    }

    /** Appends one DOUBLE column per custom metric; a null custom metric collection adds nothing. */
    private void addCustomMetricColumns(Schema schema, LLMEvaluationRecipePayloadParams payloadParams) {
        if (payloadParams.customMetrics == null) {
            return;
        }
        for (GenAiCustomEvaluationMetric customMetric : payloadParams.customMetrics) {
            schema.addColumn(new SchemaColumn(customMetric.name, Type.DOUBLE));
        }
    }

    /** Returns the parsed payload, failing with a descriptive error when it has not been set. */
    private LLMEvaluationRecipePayloadParams requireParams() {
        if (this.params == null) {
            throw new IllegalStateException("The LLM Evaluation recipe payload has not been set; call setPayload before computing schemas.");
        }
        return this.params;
    }
}

