/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.llm_evaluation;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.coreservices.PredictionService;
import com.dataiku.dip.analysis.ml.MLDiagnostics;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.ml.prediction.flow.EvaluationDatasetHelper;
import com.dataiku.dip.analysis.ml.shared.EvaluationLabelsHelper;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.JobAuthCtxService;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.exec.AbstractPythonRecipeRunner;
import com.dataiku.dip.dataflow.exec.ContainerRecipeParams;
import com.dataiku.dip.dataflow.exec.FinalCommitable;
import com.dataiku.dip.dataflow.exec.FlowRunnable;
import com.dataiku.dip.dataflow.exec.PreRunSchemaPropagationHandler;
import com.dataiku.dip.dataflow.exec.RecipeRunnerWithPayload;
import com.dataiku.dip.dataflow.graph.FlowRecipe;
import com.dataiku.dip.dataflow.graph.utils.GraphUtils;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.exceptions.IllegalConfigurationException;
import com.dataiku.dip.license.LicenseRestrictionException;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.mec.KernelsModelEvaluationStoresService;
import com.dataiku.dip.mec.LLMModelEvaluation;
import com.dataiku.dip.mec.ModelEvaluationStore;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.recipes.InitializableAbortableRecipeRunner;
import com.dataiku.dip.recipes.RecipeRunner;
import com.dataiku.dip.recipes.consistency.RecipeCodes;
import com.dataiku.dip.recipes.nlp.llm_evaluation.LLMEvaluationRecipeParams;
import com.dataiku.dip.recipes.nlp.llm_evaluation.LLMEvaluationRecipePayloadParams;
import com.dataiku.dip.remoterun.RemoteRunsRegistry;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.server.services.licensing.LicenseFeaturesStatusBuilder;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.mlflow_project.apachecommons.io.FileUtils;
import com.google.gson.JsonObject;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class LLMEvaluationRecipeRunner
implements InitializableAbortableRecipeRunner,
RecipeRunner,
FlowRunnable,
RecipeRunnerWithPayload,
FinalCommitable {
    private static final String TYPE = "llm-evaluation-recipe";
    private final JobActivity activity;
    private final FlowRecipe recipe;
    private LLMEvaluationRecipePayloadParams desc;
    private AuthCtx authCtx;
    private InitializableAbortableRecipeRunner abortableRecipeRunner = null;
    private final RecipeRunnableSubgraph subgraph;
    @Autowired
    private JobAuthCtxService authCtxService;
    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private KernelsModelEvaluationStoresService modelEvaluationStoresService;
    @Autowired
    private LicenseStatusService licenseStatusService;
    @Autowired
    private LLMRefEnricherService llmRefEnricherService;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.recipes.nlp.llm_evaluation");

    public LLMEvaluationRecipeRunner(JobActivity activity) {
        this.activity = activity;
        this.subgraph = (RecipeRunnableSubgraph)activity.getSubgraph();
        this.recipe = this.subgraph.getRecipe();
        this.activity.initStatus();
    }

    @Override
    public void notifyBeforeAborting() {
        if (this.abortableRecipeRunner != null) {
            this.abortableRecipeRunner.notifyBeforeAborting();
        }
    }

    @Override
    public void finalCommit() throws Exception {
    }

    @Override
    public void run() throws Exception {
        File modelEvaluationFolder;
        LLMModelEvaluation modelEvaluation;
        List wrongMinMax;
        int ragasMaxWorkers;
        boolean hasModelEvaluationStore;
        LicenseStatusService.LicensingStatus ls = this.licenseStatusService.getLicensingStatus();
        AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus featuresStatus = LicenseFeaturesStatusBuilder.getFeaturesStatus(ls);
        if (!featuresStatus.advancedLLMMeshAllowed) {
            throw new LicenseRestrictionException("LLM Evaluation requires Advanced LLM Mesh, which is not enabled in your license.");
        }
        List<InfoMessage> customMetricsErrors = PredictionService.checkCustomMetricParams(this.desc.customMetrics);
        if (!customMetricsErrors.isEmpty()) {
            throw new IllegalArgumentException(customMetricsErrors.get(0).toString());
        }
        logger.info((Object)"LLM Evaluation recipe runner started");
        AnyLoc evaluationStoreLoc = null;
        SerializedRecipe.RecipeInput recipeInput = this.recipe.getModel().getSingleInput("main");
        SerializedDataset serializedInputDataset = (SerializedDataset)this.datasetsDAO.getMandatory(recipeInput.getLoc(this.recipe.getProjectKey()));
        final Dataset inputDataset = Dataset.fromSerialized(serializedInputDataset);
        SerializedRecipe.RecipeOutput outputDatasetRO = this.recipe.getModel().getSingleOutputOrNull("main");
        final String outputDatasetSmartName = outputDatasetRO != null ? outputDatasetRO.getLoc(this.recipe.getProjectKey()).getSmartName(this.recipe.getProjectKey()) : null;
        SerializedRecipe.RecipeOutput metricsDatasetRO = this.recipe.getModel().getSingleOutputOrNull("metrics");
        final String metricsDatasetSmartName = metricsDatasetRO != null ? metricsDatasetRO.getLoc(this.recipe.getProjectKey()).getSmartName(this.recipe.getProjectKey()) : null;
        SerializedRecipe.RecipeOutput mesRO = this.recipe.getModel().getSingleOutputOrNull("evaluationStore");
        if (mesRO != null) {
            evaluationStoreLoc = mesRO.getLoc(this.recipe.getProjectKey());
            hasModelEvaluationStore = true;
        } else {
            hasModelEvaluationStore = false;
        }
        if (outputDatasetSmartName == null && metricsDatasetSmartName == null && evaluationStoreLoc == null) {
            throw new IllegalArgumentException("The recipe needs at least one output.");
        }
        if (this.desc.ragasMaxWorkers != null) {
            if (this.desc.ragasMaxWorkers < 1) {
                throw new IllegalArgumentException("The max number of workers for RAGAS metric computation must be equal to or greater than 1. Please review your advanced recipe configuration.");
            }
            ragasMaxWorkers = this.desc.ragasMaxWorkers;
        } else {
            ragasMaxWorkers = ApplicationConfigurator.getProperty((String)"dku.llmEvaluation.defaultRagasMaxWorkers", (int)9);
        }
        if (CollectionUtils.isNotEmpty(this.desc.customMetrics)) {
            this.authCtx.failIfNoSafeCode("You do not have the required permission to run code");
        }
        if (!(wrongMinMax = this.desc.customMetrics.stream().filter(cm -> cm.minValue != null && cm.maxValue != null && cm.minValue > cm.maxValue).collect(Collectors.toList())).isEmpty()) {
            throw new IllegalArgumentException("Some Custom Metrics have a Minimum value greater than the Maximum value.");
        }
        this.checkNonNullAndEmptyFields(this.desc, inputDataset);
        List<Partition> inputPartitions = this.subgraph.getSourcePartitions(GraphUtils.getSingleSource(this.recipe));
        KernelsModelEvaluationStoresService.DataTypeAndParams datasetTypeAndParams = this.modelEvaluationStoresService.makeDataTypeAndParams(this.recipe.getProjectKey(), inputDataset, inputPartitions);
        if (hasModelEvaluationStore) {
            KernelsModelEvaluationStoresService.LLMEvaluationModelInfo evaluationModelInfo = new KernelsModelEvaluationStoresService.LLMEvaluationModelInfo();
            evaluationModelInfo.evaluationId = this.desc.evaluationId;
            evaluationModelInfo.name = this.desc.evaluationName;
            evaluationModelInfo.limitSampling = false;
            evaluationModelInfo.evaluationDatasetType = EvaluationDatasetHelper.EvaluationDatasetType.CLASSIC;
            evaluationModelInfo.inputFormat = this.desc.inputFormat;
            evaluationModelInfo.taskType = this.desc.llmTaskType;
            evaluationModelInfo.inputColumnName = this.desc.inputColumnName;
            evaluationModelInfo.outputColumnName = this.desc.outputColumnName;
            evaluationModelInfo.groundTruthColumnName = this.desc.groundTruthColumnName;
            evaluationModelInfo.contextColumnName = this.desc.contextColumnName;
            evaluationModelInfo.bertScoreModelType = StringUtils.isNotBlank((String)this.desc.bertScoreModelType) ? this.desc.bertScoreModelType : "bert-base-uncased";
            evaluationModelInfo.bleuTokenizer = StringUtils.isNotBlank((String)this.desc.bleuTokenizer) ? this.desc.bleuTokenizer : "13a";
            evaluationModelInfo.embeddingLLMId = this.desc.embeddingLLMId;
            if (StringUtils.isNotBlank((String)this.desc.embeddingLLMId)) {
                EnrichedLLMStructuredRef embeddingLLM = this.llmRefEnricherService.getEnrichedLLMRef(this.desc.embeddingLLMId, this.authCtx, this.recipe.getProjectKey());
                evaluationModelInfo.embeddingLLMFriendlyName = embeddingLLM.friendlyName;
            }
            evaluationModelInfo.embeddingSettings = this.desc.embeddingSettings;
            evaluationModelInfo.completionLLMId = this.desc.completionLLMId;
            if (StringUtils.isNotBlank((String)this.desc.completionLLMId)) {
                EnrichedLLMStructuredRef completionLLM = this.llmRefEnricherService.getEnrichedLLMRef(this.desc.completionLLMId, this.authCtx, this.recipe.getProjectKey());
                evaluationModelInfo.completionLLMFriendlyName = completionLLM.friendlyName;
            }
            evaluationModelInfo.completionSettings = this.desc.completionSettings;
            evaluationModelInfo.labels = EvaluationLabelsHelper.getLLMEvaluationTimeLabels_T(this.recipe.getProjectKey(), serializedInputDataset, this.desc.labels, this.subgraph.getSourcePartitions(GraphUtils.getSingleSource(this.recipe)), evaluationModelInfo.embeddingLLMFriendlyName, evaluationModelInfo.completionLLMFriendlyName);
            evaluationModelInfo.customMetrics = this.desc.customMetrics;
            evaluationModelInfo.selection = this.desc.selection;
            ModelEvaluationStore mes = this.modelEvaluationStoresService.getMandatory(evaluationStoreLoc.getProjectKey(), evaluationStoreLoc.getId());
            modelEvaluation = this.modelEvaluationStoresService.setupLLMRun(mes, datasetTypeAndParams, evaluationModelInfo);
            modelEvaluationFolder = modelEvaluation.ref.getMainFolder();
            FilesystemACLUtils.grantFSFullACLs(this.authCtx, this.recipe.getProjectKey(), true, modelEvaluationFolder);
        } else {
            modelEvaluation = null;
            modelEvaluationFolder = null;
        }
        try (final AutoDelete outputTmpDir = FlowJobUtils.getTmpFolder(TYPE, "pyrun");){
            JSON.prettyToFile((Object)this.desc, (File)new File((File)outputTmpDir, "desc.json"));
            JobContext.getCurrentActivitySummary().engineType = "PYTHON";
            AbstractPythonRecipeRunner runner = new AbstractPythonRecipeRunner(this.activity){

                @Override
                public void run() throws Exception {
                    FilesystemACLUtils.grantFSFullACLs(LLMEvaluationRecipeRunner.this.authCtx, this.projectKey, new File[]{outputTmpDir});
                    File additionalLogsDir = FlowJobUtils.getJobMadeDir(LLMEvaluationRecipeRunner.TYPE, "additional-logs");
                    File mainLogFile = FlowJobUtils.getJobTouchedFile(LLMEvaluationRecipeRunner.TYPE, "python.log");
                    LLMEvaluationRecipeParams params = this.recipe.getModel().getParamsAs(LLMEvaluationRecipeParams.class);
                    String envName = new CodeEnvSelector().selectForPythonRecipe(this.recipe.getProjectKey(), params.getCodeEnvSelection());
                    CodeEnvModel.UsedCodeEnvRef codeEnvRef = new CodeEnvModel.UsedCodeEnvRef(CodeEnvModel.EnvLang.PYTHON, envName);
                    ContainerExecSelection containerSelection = this.recipe.getModel().getParamsAs(ContainerRecipeParams.class).getContainerSelection();
                    ContainerExecRuntimeConfig containerConfig = new ContainerExecConfigSelector().selectForML_autoTXN(LLMEvaluationRecipeRunner.this.authCtx, this.recipe.getProjectKey(), containerSelection, MLTask.BackendType.PY_MEMORY);
                    logger.info((Object)("Run llm evaluation in code env " + StringUtils.defaultIfBlank((String)envName, (String)"built-in") + " and container " + (containerConfig == null ? "local" : containerConfig.name)));
                    String inputDatasetSmartName = inputDataset.getSmartName(this.recipe.getProjectKey());
                    if (containerConfig == null) {
                        this.executeModule(envName, (File)outputTmpDir, "dataiku.llm.evaluation.llm_evaluation_recipe", outputTmpDir.getAbsolutePath(), hasModelEvaluationStore ? modelEvaluationFolder.getAbsolutePath() : "", inputDatasetSmartName, outputDatasetSmartName != null ? outputDatasetSmartName : "", metricsDatasetSmartName != null ? metricsDatasetSmartName : "", Integer.toString(ragasMaxWorkers));
                    } else {
                        ArrayList<String> readableAndWritablePaths = new ArrayList<String>(List.of(outputTmpDir.getAbsolutePath()));
                        if (hasModelEvaluationStore) {
                            readableAndWritablePaths.add(modelEvaluationFolder.getAbsolutePath());
                        }
                        JsonObject payload = new JsonObject();
                        payload.addProperty("inputDatasetSmartName", inputDatasetSmartName);
                        payload.addProperty("ragasMaxWorkers", (Number)ragasMaxWorkers);
                        if (outputDatasetSmartName != null) {
                            payload.addProperty("outputDatasetSmartName", outputDatasetSmartName);
                        }
                        if (metricsDatasetSmartName != null) {
                            payload.addProperty("metricsDatasetSmartName", metricsDatasetSmartName);
                        }
                        if (modelEvaluationFolder != null) {
                            payload.addProperty("modelEvaluationFolder", modelEvaluationFolder.getAbsolutePath());
                        }
                        switch (containerConfig.type) {
                            case DOCKER: {
                                this.executeDockerCodeRecipe(codeEnvRef, containerConfig, null, mainLogFile, outputTmpDir, RemoteRunsRegistry.ExecutionType.RECIPE_LLM_EVALUATION_PYTHON, payload.toString(), Collections.emptyMap(), readableAndWritablePaths, readableAndWritablePaths);
                                break;
                            }
                            case KUBERNETES: {
                                this.executeKubernetesCodeRecipe(codeEnvRef, containerConfig, null, mainLogFile, additionalLogsDir, outputTmpDir, RemoteRunsRegistry.ExecutionType.RECIPE_LLM_EVALUATION_PYTHON, payload.toString(), Collections.emptyMap(), readableAndWritablePaths, readableAndWritablePaths, () -> RecipeCodes.ERR_RECIPE_ML_EVALUATION_K8S_OOM);
                            }
                        }
                    }
                }

                @Override
                public void init() {
                }
            };
            SpringUtils.getInstance().autowire((Object)runner);
            runner.init();
            runner.run();
            File diagnosticsFile = DKUFileUtils.getWithin((File)outputTmpDir, (String[])new String[]{"ml_diagnostics.json"});
            if (diagnosticsFile.exists() && hasModelEvaluationStore) {
                MLDiagnostics mlDiagnostics = ModelLikeId.parseJsonFile(diagnosticsFile, MLDiagnostics.class);
                mlDiagnostics.mergeIntoWarnings(this.activity.warnContext);
                FileUtils.copyFile((File)diagnosticsFile, (File)modelEvaluation.ref.getEvaluationFile("ml_diagnostics.json"));
            }
        }
        catch (Exception e) {
            if (modelEvaluationFolder != null) {
                logger.warn((Object)"Fail to perform llm evaluation, cleaning up model evaluation folder");
                DKUFileUtils.forceDelete(modelEvaluationFolder);
            }
            throw e;
        }
        if (hasModelEvaluationStore) {
            this.activity.getTargetStatus((String)evaluationStoreLoc.getFullName()).evaluationId = modelEvaluation.ref.evaluationId;
            this.modelEvaluationStoresService.finaliseRun(modelEvaluation);
        }
    }

    @Override
    public void init() throws Exception {
        this.authCtx = this.authCtxService.getAuthCtx();
        new PreRunSchemaPropagationHandler(this.activity, this.recipe).propagateIfNeeded();
    }

    @Override
    public void setPayload(String payload) {
        this.desc = (LLMEvaluationRecipePayloadParams)JSON.parse((String)payload, LLMEvaluationRecipePayloadParams.class);
    }

    private void checkNonNullAndEmptyFields(LLMEvaluationRecipePayloadParams desc, Dataset inputDataset) {
        if (desc.inputFormat == null) {
            throw new IllegalConfigurationException("You need to select an Input Dataset Format.");
        }
        if (!(desc.inputColumnName != null && inputDataset.getSchema().hasColumn(desc.inputColumnName) || Arrays.asList("PROMPT_RECIPE", "DATAIKU_ANSWERS").contains(desc.inputFormat.name()))) {
            throw new IllegalConfigurationException("You need to check that you have selected an Input column and that it exists in the current Dataset Schema.");
        }
        if (desc.llmTaskType == null) {
            throw new IllegalConfigurationException("You need to select a Task.");
        }
        if (desc.metrics.isEmpty() && desc.customMetrics.isEmpty()) {
            throw new IllegalConfigurationException("You need to select at least one metric or define at least one custom metric to run the recipe.");
        }
    }
}

