/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.evaluation;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.model.core.CustomMetricResult;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.futures.FutureService;
import com.dataiku.dip.license.LicenseRestrictionException;
import com.dataiku.dip.llm.evaluation.TestCustomMetricKernelPool;
import com.dataiku.dip.recipes.code.python.PythonRecipeStatusComputerBase;
import com.dataiku.dip.recipes.nlp.llm_evaluation.LLMEvaluationRecipeParams;
import com.dataiku.dip.recipes.nlp.llm_evaluation.LLMEvaluationRecipePayloadParams;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.server.services.licensing.LicenseEnforcementService;
import com.dataiku.dip.threads.BaseProgressingWorkThread;
import com.dataiku.dip.utils.DKULogger;
import com.google.gson.reflect.TypeToken;
import java.util.concurrent.TimeoutException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Runs a single custom metric of an LLM evaluation recipe as an asynchronous "test" future.
 *
 * <p>The test first compiles the metric's Python code. A compile failure is reported as a
 * {@link CustomMetricSyntaxError}. Otherwise the metric is executed through
 * {@link TestCustomMetricKernelPool} with a configurable timeout.
 */
@Service
public class CustomMetricsService {
    /** Configuration key holding the test timeout, in minutes. */
    private static final String TEST_METRIC_TIMEOUT_PARAM = "dku.llm.eval.testMetric.timeoutInMinutes";
    /** Default value passed to {@code getIntParam} for {@link #TEST_METRIC_TIMEOUT_PARAM}. */
    private static final int DEFAULT_TEST_METRIC_TIMEOUT_MINUTES = 5;

    @Autowired
    private FutureService futureService;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private TestCustomMetricKernelPool testCustomMetricKernelPool;
    @Autowired
    private LicenseEnforcementService licenseEnforcementService;
    // NOTE(review): not referenced anywhere in this class.
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.evaluation.custommetrics.service");

    /**
     * Starts a future that tests the custom metric at {@code metricIndex} of {@code recipeDesc}.
     *
     * @param owner                  auth context the work thread runs as
     * @param projectKey             project the future's payload target is attached to
     * @param recipe                 the LLM evaluation recipe (its params supply the code env)
     * @param recipeDesc             recipe payload holding the custom metric definitions
     * @param metricIndex            index into {@code recipeDesc.customMetrics} of the metric to test
     * @param serializedInputDataset dataset the metric is evaluated against
     * @return the response of the future started by {@link FutureService#runFuture}
     * @throws LicenseRestrictionException if the license does not allow Advanced LLM Mesh
     * @throws IllegalArgumentException    if {@code recipeDesc} has no custom metrics or
     *                                     {@code metricIndex} is out of range
     * @throws Exception                   propagated from the future service
     */
    public FutureResponse<CustomMetricResult> startTest(DSSAuthCtx owner, final String projectKey, final SerializedRecipe recipe, final LLMEvaluationRecipePayloadParams recipeDesc, final int metricIndex, final SerializedDataset serializedInputDataset) throws Exception {
        AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus featuresStatus = this.licenseEnforcementService.getFeaturesStatus();
        if (!featuresStatus.advancedLLMMeshAllowed) {
            throw new LicenseRestrictionException("LLM Evaluation requires Advanced LLM Mesh, which is not enabled in your license.");
        }
        // Fail fast with a clear message instead of an opaque error from inside the future.
        checkMetricIndex(recipeDesc, metricIndex);
        return this.futureService.runFuture(new BaseProgressingWorkThread<CustomMetricResult>(owner) {
            private CustomMetricResult result;

            public FuturePayload getPayload() {
                FuturePayload payload = FuturePayload.newSimple("custom_metric_test", "Test Custom metric");
                // NOTE(review): the target id concatenates the recipe id and the index without a
                // separator, so distinct (recipe, index) pairs could in theory produce the same id.
                // Kept as-is because consumers of the payload may match on this exact format.
                String targetId = recipe.getFullId() + metricIndex;
                String targetName = "Custom metric test : " + recipeDesc.customMetrics.get(metricIndex).name;
                payload.targets.add(new FuturePayload.FuturePayloadTarget(projectKey, targetId, targetName, null));
                return payload;
            }

            public double getDangerosity() {
                return 0.0;
            }

            public CustomMetricResult getResult() {
                return this.result;
            }

            public void execute() throws Exception {
                try (FutureProgress.AutocloseableFutureProgressState state = FutureProgress.pushAutoCloseableState("Custom metric test", 100.0, FutureProgressState.StateUnit.NONE)) {
                    // Presumably read by the base class to report percentage progress on the future.
                    this.percentageProgressState = state;
                    this.result = CustomMetricsService.this.testMetric_NT(this.owner, (FutureProgressState) state, recipe, recipeDesc, metricIndex, serializedInputDataset);
                }
            }
        }, 0L, new TypeToken<FutureResponse<CustomMetricResult>>() {});
    }

    /**
     * Validates that {@code metricIndex} designates an existing custom metric of {@code recipeDesc}.
     *
     * @throws IllegalArgumentException if there are no custom metrics or the index is out of range
     */
    private static void checkMetricIndex(LLMEvaluationRecipePayloadParams recipeDesc, int metricIndex) {
        if (recipeDesc == null || recipeDesc.customMetrics == null) {
            throw new IllegalArgumentException("No custom metrics are defined for this recipe");
        }
        int metricCount = recipeDesc.customMetrics.size();
        if (metricIndex < 0 || metricIndex >= metricCount) {
            throw new IllegalArgumentException("Invalid custom metric index " + metricIndex + " (recipe defines " + metricCount + " custom metric(s))");
        }
    }

    /**
     * Compiles, then runs, one custom metric.
     *
     * <p>If the Python compile check reports a fatal message, a {@link CustomMetricSyntaxError}
     * built from that message is returned and the metric is not run. Otherwise progress is
     * advanced by 10 and the metric is executed by the kernel pool.
     *
     * @return the kernel pool's result, or a {@link CustomMetricSyntaxError} on a compile failure
     * @throws TimeoutException if the metric does not finish within the configured timeout; the
     *                          kernel pool's original exception is attached as the cause
     * @throws Exception        propagated from the compile check or the kernel pool
     */
    private CustomMetricResult testMetric_NT(DSSAuthCtx owner, FutureProgressState percentageProgressState, SerializedRecipe recipe, LLMEvaluationRecipePayloadParams recipeDesc, int metricIndex, SerializedDataset serializedInputDataset) throws Exception {
        LLMEvaluationRecipeParams recipeParams = recipe.getParamsAs(LLMEvaluationRecipeParams.class);
        InfoMessage.InfoMessages messages = new InfoMessage.InfoMessages();
        PythonRecipeStatusComputerBase.checkPythonCompile(messages, recipeDesc.customMetrics.get(metricIndex).metricCode, recipeParams.getCodeEnvSelection().envName, recipe.getProjectKey());
        if (messages.anyFatal()) {
            InfoMessage fatalMessage = messages.firstFatal();
            return new CustomMetricSyntaxError(fatalMessage.details, fatalMessage.line, fatalMessage.column);
        }
        percentageProgressState.increment(10.0);
        Dataset inputDataset = Dataset.fromSerialized(serializedInputDataset);
        int timeoutInMinutes = ApplicationConfigurator.getParams().getIntParam(TEST_METRIC_TIMEOUT_PARAM, Integer.valueOf(DEFAULT_TEST_METRIC_TIMEOUT_MINUTES));
        try {
            return this.testCustomMetricKernelPool.testCustomMetric(owner, recipe, recipeDesc, metricIndex, inputDataset, timeoutInMinutes);
        } catch (TimeoutException e) {
            // TimeoutException has no (message, cause) constructor: attach the cause explicitly so
            // the original stack trace is not lost.
            TimeoutException timeout = new TimeoutException("Timeout: Custom metric did not complete in less than " + timeoutInMinutes + " minutes");
            timeout.initCause(e);
            throw timeout;
        }
    }

    /**
     * Result returned when the custom metric code fails the Python compile check.
     * {@code didSucceed} is always {@code false}.
     */
    static class CustomMetricSyntaxError
    extends CustomMetricResult {
        /** Details of the first fatal compile message. */
        public String error;
        /** Line reported by the first fatal compile message. */
        public int line;
        /** Column reported by the first fatal compile message. */
        public int column;

        public CustomMetricSyntaxError(String error, int line, int column) {
            this.didSucceed = false;
            this.error = error;
            this.line = line;
            this.column = column;
        }

        /** No-arg constructor; presumably needed for reflective (de)serialization — confirm. */
        public CustomMetricSyntaxError() {
            this.didSucceed = false;
        }
    }
}

