/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.evaluation.llm;

import com.dataiku.dip.analysis.coreservices.PredictionService;
import com.dataiku.dip.analysis.ml.MLDiagnostics;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.ml.prediction.flow.EvaluationDatasetHelper;
import com.dataiku.dip.analysis.ml.shared.EvaluationLabelsHelper;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.JobAuthCtxService;
import com.dataiku.dip.dataflow.exec.PreRunSchemaPropagationHandler;
import com.dataiku.dip.dataflow.graph.utils.GraphUtils;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.exceptions.IllegalConfigurationException;
import com.dataiku.dip.license.LicenseRestrictionException;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.mec.KernelsModelEvaluationStoresService;
import com.dataiku.dip.mec.LLMModelEvaluation;
import com.dataiku.dip.mec.ModelEvaluationStore;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.recipes.InitializableAbortableRecipeRunner;
import com.dataiku.dip.recipes.nlp.evaluation.AbstractGenAIEvaluationRecipeRunner;
import com.dataiku.dip.recipes.nlp.evaluation.llm.LLMEvaluationRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.evaluation.llm.LLMEvaluationRecipePythonRunner;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.server.services.licensing.LicenseFeaturesStatusBuilder;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.mlflow_project.apachecommons.io.FileUtils;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Job runner for the {@code llm-evaluation-recipe} recipe type.
 *
 * <p>Verifies the license and the recipe settings, optionally prepares a new run in the
 * target model evaluation store, then delegates the evaluation itself to
 * {@link LLMEvaluationRecipePythonRunner}. If the evaluation fails, any model evaluation
 * folder created for this run is deleted so that no half-written evaluation is left behind.
 */
public class LLMEvaluationRecipeRunner extends AbstractGenAIEvaluationRecipeRunner {
    /** Recipe type; also used as the namespace for temporary and job-level folders. */
    private static final String TYPE = "llm-evaluation-recipe";
    /**
     * LLM id value that this runner treats as "no LLM selected" (it skips the friendly-name
     * lookup for it). NOTE(review): presumably written by the recipe UI — confirm.
     */
    private static final String NO_LLM_SELECTED = "None";
    /** Fallback BERTScore model when the recipe settings leave it blank. */
    private static final String DEFAULT_BERT_SCORE_MODEL_TYPE = "bert-base-uncased";
    /** Fallback BLEU tokenizer when the recipe settings leave it blank. */
    private static final String DEFAULT_BLEU_TOKENIZER = "13a";
    /** Diagnostics file optionally produced by the Python runner in its output folder. */
    private static final String DIAGNOSTICS_FILE_NAME = "ml_diagnostics.json";
    /** Input formats for which no explicit input column has to be selected. */
    private static final List<String> INPUT_FORMATS_WITHOUT_INPUT_COLUMN = Arrays.asList("PROMPT_RECIPE", "DATAIKU_ANSWERS");
    private static final DKULogger logger = DKULogger.getLogger("dku.recipes.nlp.evaluation.llm.runner");

    private LLMEvaluationRecipePayloadParams desc;
    private AuthCtx authCtx;
    // NOTE(review): this field is never assigned in this class, so notifyBeforeAborting() is
    // currently a no-op; the Python runner created in run() is not stored here. Confirm
    // whether aborts are meant to be forwarded to it.
    private InitializableAbortableRecipeRunner abortableRecipeRunner = null;
    @Autowired
    private JobAuthCtxService authCtxService;
    @Autowired
    private KernelsModelEvaluationStoresService modelEvaluationStoresService;
    @Autowired
    private LicenseStatusService licenseStatusService;
    @Autowired
    private LLMRefEnricherService llmRefEnricherService;

    public LLMEvaluationRecipeRunner(JobActivity activity) {
        super(activity);
    }

    /** Forwards the abort notification to the delegate runner, if one has been set. */
    @Override
    public void notifyBeforeAborting() {
        if (this.abortableRecipeRunner != null) {
            this.abortableRecipeRunner.notifyBeforeAborting();
        }
    }

    /** Intentionally empty: this runner has no final commit step. */
    @Override
    public void finalCommit() throws Exception {
    }

    /**
     * Runs the LLM evaluation.
     *
     * @throws LicenseRestrictionException if the license does not allow Advanced LLM Mesh
     * @throws IllegalArgumentException if custom metrics are invalid or the recipe has no output
     * @throws IllegalConfigurationException if mandatory recipe settings are missing
     * @throws Exception any failure from the Python evaluation or the evaluation store; in that
     *     case the model evaluation folder created for this run (if any) is deleted first
     */
    @Override
    public void run() throws Exception {
        this.checkLicenseAllowsLLMEvaluation();
        List<InfoMessage> customMetricsErrors = PredictionService.checkCustomMetricParams(this.desc.customMetrics);
        if (!customMetricsErrors.isEmpty()) {
            throw new IllegalArgumentException(customMetricsErrors.get(0).toString());
        }
        logger.info((Object)"LLM Evaluation recipe runner started");

        AbstractGenAIEvaluationRecipeRunner.GenAiEvaluationRecipeInputAndOutputs recipeIO = this.getRecipeInputAndOutputs();
        Dataset inputDataset = recipeIO.inputDataset();
        String inputDatasetSmartName = inputDataset.getSmartName(this.recipe.getProjectKey());
        AnyLoc evaluationStoreLoc = recipeIO.evaluationStoreLoc();
        boolean hasModelEvaluationStore = evaluationStoreLoc != null;
        if (recipeIO.outputDatasetSmartName() == null && recipeIO.metricsDatasetSmartName() == null && !hasModelEvaluationStore) {
            throw new IllegalArgumentException("The recipe needs at least one output.");
        }
        int ragasMaxWorkers = LLMEvaluationRecipeRunner.getRagasMaxWorkers(this.desc.ragasMaxWorkers);
        LLMEvaluationRecipeRunner.validateCustomMetrics(this.authCtx, this.desc.customMetrics);
        this.checkNonNullAndEmptyFields(this.desc, inputDataset);

        List<Partition> inputPartitions = this.subgraph.getSourcePartitions(GraphUtils.getSingleSource(this.recipe));
        KernelsModelEvaluationStoresService.DataTypeAndParams datasetTypeAndParams = this.modelEvaluationStoresService.makeDataTypeAndParams(this.recipe.getProjectKey(), inputDataset, inputPartitions);

        LLMModelEvaluation modelEvaluation = null;
        File modelEvaluationFolder = null;
        if (hasModelEvaluationStore) {
            KernelsModelEvaluationStoresService.LLMEvaluationModelInfo evaluationModelInfo = this.getEvaluationModelInfo(this.desc, inputDataset.serialize());
            ModelEvaluationStore mes = this.modelEvaluationStoresService.getMandatory(evaluationStoreLoc.getProjectKey(), evaluationStoreLoc.getId());
            modelEvaluation = this.modelEvaluationStoresService.setupLLMRun(mes, datasetTypeAndParams, evaluationModelInfo);
            modelEvaluationFolder = modelEvaluation.ref.getMainFolder();
            try {
                FilesystemACLUtils.grantFSFullACLs(this.authCtx, this.recipe.getProjectKey(), true, modelEvaluationFolder);
            }
            catch (Exception e) {
                // The folder was just created by setupLLMRun; do not leave it behind.
                deleteModelEvaluationFolderAfterFailure(modelEvaluationFolder, e);
                throw e;
            }
        }

        try (AutoDelete outputTmpDir = FlowJobUtils.getTmpFolder(TYPE, "pyrun")) {
            JSON.prettyToFile((Object)this.desc, new File((File)outputTmpDir, "desc.json"));
            JobContext.getCurrentActivitySummary().engineType = "PYTHON";
            File additionalLogsDir = FlowJobUtils.getJobMadeDir(TYPE, "additional-logs");
            File mainLogFile = FlowJobUtils.getJobTouchedFile(TYPE, "python.log");
            LLMEvaluationRecipePythonRunner runner = new LLMEvaluationRecipePythonRunner(this.activity, outputTmpDir, inputDatasetSmartName, recipeIO.outputDatasetSmartName(), recipeIO.metricsDatasetSmartName(), modelEvaluationFolder, additionalLogsDir, mainLogFile, ragasMaxWorkers);
            SpringUtils.getInstance().autowire((Object)runner);
            runner.init();
            runner.run();
            this.importDiagnostics((File)outputTmpDir, modelEvaluation);
        }
        catch (Exception e) {
            deleteModelEvaluationFolderAfterFailure(modelEvaluationFolder, e);
            throw e;
        }

        if (hasModelEvaluationStore) {
            this.activity.getTargetStatus((String)evaluationStoreLoc.getFullName()).evaluationId = modelEvaluation.ref.evaluationId;
            this.modelEvaluationStoresService.finaliseRun(modelEvaluation);
        }
    }

    /** Throws if the current license does not enable Advanced LLM Mesh, which LLM Evaluation requires. */
    private void checkLicenseAllowsLLMEvaluation() throws Exception {
        LicenseStatusService.LicensingStatus ls = this.licenseStatusService.getLicensingStatus();
        AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus featuresStatus = LicenseFeaturesStatusBuilder.getFeaturesStatus(ls);
        if (!featuresStatus.advancedLLMMeshAllowed) {
            throw new LicenseRestrictionException("LLM Evaluation requires Advanced LLM Mesh, which is not enabled in your license.");
        }
    }

    /**
     * Merges the diagnostics written by the Python runner (if any) into the job warnings, and
     * copies the diagnostics file into the model evaluation when one is being produced.
     *
     * @param outputTmpDir the Python runner's working folder
     * @param modelEvaluation the evaluation being produced, or {@code null} when the recipe has no evaluation store output
     */
    private void importDiagnostics(File outputTmpDir, LLMModelEvaluation modelEvaluation) throws Exception {
        File diagnosticsFile = DKUFileUtils.getWithin(outputTmpDir, new String[]{DIAGNOSTICS_FILE_NAME});
        if (!diagnosticsFile.exists()) {
            return;
        }
        MLDiagnostics mlDiagnostics = ModelLikeId.parseJsonFile(diagnosticsFile, MLDiagnostics.class);
        mlDiagnostics.mergeIntoWarnings(this.activity.warnContext);
        if (modelEvaluation != null) {
            FileUtils.copyFile(diagnosticsFile, modelEvaluation.ref.getEvaluationFile(DIAGNOSTICS_FILE_NAME));
        }
    }

    /**
     * Best-effort deletion of a partially written model evaluation folder after a failure.
     * A failure of the deletion itself is attached to {@code cause} as a suppressed exception,
     * so the original error is never masked.
     *
     * @param modelEvaluationFolder the folder to delete; nothing is done when {@code null}
     * @param cause the failure that triggered the cleanup
     */
    private static void deleteModelEvaluationFolderAfterFailure(File modelEvaluationFolder, Exception cause) {
        if (modelEvaluationFolder == null) {
            return;
        }
        logger.warn((Object)"Fail to run LLM Evaluation, cleaning up model evaluation folder");
        try {
            DKUFileUtils.forceDelete(modelEvaluationFolder);
        }
        catch (Exception cleanupError) {
            cause.addSuppressed(cleanupError);
        }
    }

    /** Returns true when {@code llmId} designates an actual LLM (neither blank nor the "None" placeholder). */
    private static boolean isLLMSelected(String llmId) {
        return StringUtils.isNotBlank(llmId) && !NO_LLM_SELECTED.equals(llmId);
    }

    /**
     * Builds the model info stored with the evaluation from the recipe settings. Blank BERTScore
     * model / BLEU tokenizer settings fall back to defaults. Friendly names are resolved only for
     * the LLMs that are actually selected.
     *
     * @param desc the recipe settings
     * @param serializedInputDataset the serialized input dataset, used to compute evaluation labels
     * @return the populated model info, never {@code null}
     */
    @Nonnull
    private KernelsModelEvaluationStoresService.LLMEvaluationModelInfo getEvaluationModelInfo(LLMEvaluationRecipePayloadParams desc, SerializedDataset serializedInputDataset) throws Exception {
        KernelsModelEvaluationStoresService.LLMEvaluationModelInfo evaluationModelInfo = new KernelsModelEvaluationStoresService.LLMEvaluationModelInfo();
        evaluationModelInfo.evaluationId = desc.evaluationId;
        evaluationModelInfo.name = desc.evaluationName;
        evaluationModelInfo.limitSampling = false;
        evaluationModelInfo.evaluationDatasetType = EvaluationDatasetHelper.EvaluationDatasetType.CLASSIC;
        evaluationModelInfo.inputFormat = desc.inputFormat;
        evaluationModelInfo.taskType = desc.llmTaskType;
        evaluationModelInfo.inputColumnName = desc.inputColumnName;
        evaluationModelInfo.outputColumnName = desc.outputColumnName;
        evaluationModelInfo.groundTruthColumnName = desc.groundTruthColumnName;
        evaluationModelInfo.contextColumnName = desc.contextColumnName;
        evaluationModelInfo.bertScoreModelType = StringUtils.isNotBlank(desc.bertScoreModelType) ? desc.bertScoreModelType : DEFAULT_BERT_SCORE_MODEL_TYPE;
        evaluationModelInfo.bleuTokenizer = StringUtils.isNotBlank(desc.bleuTokenizer) ? desc.bleuTokenizer : DEFAULT_BLEU_TOKENIZER;

        evaluationModelInfo.embeddingLLMId = desc.embeddingLLMId;
        if (isLLMSelected(desc.embeddingLLMId)) {
            EnrichedLLMStructuredRef embeddingLLM = this.llmRefEnricherService.getEnrichedLLMRef(desc.embeddingLLMId, this.authCtx, this.recipe.getProjectKey());
            evaluationModelInfo.embeddingLLMFriendlyName = embeddingLLM.friendlyName;
        }
        evaluationModelInfo.embeddingSettings = desc.embeddingSettings;

        evaluationModelInfo.completionLLMId = desc.completionLLMId;
        // Check the completion LLM's own id (the previous code tested the embedding LLM's id here).
        if (isLLMSelected(desc.completionLLMId)) {
            EnrichedLLMStructuredRef completionLLM = this.llmRefEnricherService.getEnrichedLLMRef(desc.completionLLMId, this.authCtx, this.recipe.getProjectKey());
            evaluationModelInfo.completionLLMFriendlyName = completionLLM.friendlyName;
        }
        evaluationModelInfo.completionSettings = desc.completionSettings;

        evaluationModelInfo.labels = EvaluationLabelsHelper.getLLMEvaluationTimeLabels_T(this.recipe.getProjectKey(), serializedInputDataset, desc.labels, this.subgraph.getSourcePartitions(GraphUtils.getSingleSource(this.recipe)), evaluationModelInfo.embeddingLLMFriendlyName, evaluationModelInfo.completionLLMFriendlyName);
        evaluationModelInfo.customMetrics = desc.customMetrics;
        evaluationModelInfo.selection = desc.selection;
        return evaluationModelInfo;
    }

    /** Resolves the job's auth context and propagates the output schema first if needed. */
    @Override
    public void init() throws Exception {
        this.authCtx = this.authCtxService.getAuthCtx();
        new PreRunSchemaPropagationHandler(this.activity, this.recipe).propagateIfNeeded();
    }

    /** Parses the recipe settings from the JSON payload. */
    @Override
    public void setPayload(String payload) {
        this.desc = (LLMEvaluationRecipePayloadParams)JSON.parse(payload, LLMEvaluationRecipePayloadParams.class);
    }

    /**
     * Validates the mandatory recipe settings against the input dataset.
     *
     * @param desc the recipe settings
     * @param inputDataset the recipe's input dataset, used to check that the input column exists
     * @throws IllegalConfigurationException if the input format, input column (for formats that
     *     need one), task, or at least one metric / custom metric is missing
     */
    protected void checkNonNullAndEmptyFields(LLMEvaluationRecipePayloadParams desc, Dataset inputDataset) {
        if (desc.inputFormat == null) {
            throw new IllegalConfigurationException("You need to select an Input Dataset Format.");
        }
        boolean inputColumnValid = desc.inputColumnName != null && inputDataset.getSchema().hasColumn(desc.inputColumnName);
        if (!inputColumnValid && !INPUT_FORMATS_WITHOUT_INPUT_COLUMN.contains(desc.inputFormat.name())) {
            throw new IllegalConfigurationException("You need to check that you have selected an Input column and that it exists in the current Dataset Schema.");
        }
        if (desc.llmTaskType == null) {
            throw new IllegalConfigurationException("You need to select a Task.");
        }
        // Treat missing lists like empty ones so the user gets a configuration error, not an NPE.
        boolean noMetrics = desc.metrics == null || desc.metrics.isEmpty();
        boolean noCustomMetrics = desc.customMetrics == null || desc.customMetrics.isEmpty();
        if (noMetrics && noCustomMetrics) {
            throw new IllegalConfigurationException("You need to select at least one metric or define at least one custom metric to run the recipe.");
        }
    }
}

