/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionEvaluationRecipeRunner;
import com.dataiku.dip.analysis.ml.prediction.flow.PyPredictionScoringRecipeSubrunner;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.CausalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.graph.FlowSavedModel;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.remoterun.RemoteRunsRegistry;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.datasets.DatasetAccessService;
import com.dataiku.dip.server.services.SingleWriteTransactionTransactionService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Recipe runner that evaluates a causal prediction saved model on an input dataset.
 *
 * <p>Only the {@code PY_MEMORY} backend and the {@code CAUSAL_BINARY_CLASSIFICATION} /
 * {@code CAUSAL_REGRESSION} prediction types are accepted. The evaluation itself is delegated to
 * the Python entry point {@code dataiku.doctor.causal.evaluate.launch_evaluation_recipe} through a
 * {@link PyPredictionScoringRecipeSubrunner.Evaluation}. Outputs in the {@code evaluationStore}
 * role are rejected.
 */
public class CausalPredictionEvaluationRecipeRunner extends AbstractPredictionEvaluationRecipeRunner {
    /** Output role that would target an evaluation store; not supported for causal predictions. */
    private static final String EVALUATION_STORE_ROLE = "evaluationStore";

    @Autowired
    private DatasetAccessService datasetAccessService;

    public CausalPredictionEvaluationRecipeRunner(JobActivity activity) {
        super(activity);
    }

    /**
     * Validates the recipe configuration, writes the evaluation inputs (preparation script,
     * descriptor, preparation output schema) to a temporary folder and starts the Python
     * evaluation subrunner on the saved model's active version.
     *
     * <p>NOTE: this method comes from decompiled code (CFR reported "Removed try catching
     * itself"), so its original try structure may have differed slightly.
     *
     * @throws IllegalArgumentException if the backend type or prediction type is unsupported, or
     *     if the recipe declares an evaluation-store output
     * @throws IllegalStateException if the folder of the active model version cannot be resolved
     */
    @Override
    public void run() throws Exception {
        this.checkSupportedDescriptor();
        // Fail fast: finalCommit() rejects evaluation-store outputs anyway, so report the problem
        // before spending a full evaluation run.
        this.checkNoEvaluationStoreOutput();

        List<FlowDataset> inputFDS = this.activity.getSubgraph().getSourceDatasets();
        if (inputFDS.isEmpty()) {
            throw ErrorContext.iae("Missing input dataset in scoring recipe");
        }
        Dataset inputDataset = this.getInputDataset("main");
        String inputDatasetSmartName = AnyLoc.resolveFull(inputDataset.getFullName()).getSmartName(this.recipe.getProjectKey());
        String outputDatasetSmartName = this.getOutputDatasetSmartName();
        String metricsDatasetSmartName = this.getMetricsDatasetSmartName();

        FlowSavedModel fsm = MLFlowUtils.getSMInput(this.activity);
        SavedModel sm = fsm.getSavedModel();
        MLFlowUtils.checkActiveVersion(sm);
        File activeModelFolder = MLPaths.savedModelVersionFolder(sm, sm.activeVersion);
        // Explicit check instead of `assert`: assertions are disabled by default, and a null parent
        // would make new File(null, "script.json") silently resolve against the working directory.
        if (activeModelFolder == null) {
            throw new IllegalStateException("Cannot resolve folder of active version " + sm.activeVersion + " of saved model " + sm.id);
        }
        FullModelId fmi = new FullModelId(sm.projectKey, sm.id, sm.activeVersion);
        CausalPredictionModelDetails details = (CausalPredictionModelDetails) PredictionResultsReader.makeModelDetails(fmi);
        this.executionParams = details.coreParams.executionParams;

        SerializedShakerScript script = (SerializedShakerScript) JSON.parseFile(new File(activeModelFolder, "script.json"), SerializedShakerScript.class);
        script.contextProjectKey = sm.projectKey;
        Schema inferredPreparedSchema = this.inferPreparedSchema(sm.projectKey, inputDataset, script);
        Schema preparationOSchemaToUse = MLFlowUtils.getSchemaToUseForPreparedScoringInput(details.preprocessing, details.splitDesc.schema, inferredPreparedSchema);

        try (AutoDelete outputTmpDir = FlowJobUtils.getTmpFolder("evaluation-recipe", "pyrun")) {
            File scriptFile = new File((File) outputTmpDir, "script.json");
            File descFile = new File((File) outputTmpDir, "desc.json");
            File schemaFile = new File((File) outputTmpDir, "preparation_output_schema.json");
            JSON.prettyToFile((Object) script, scriptFile);
            JSON.prettyToFile((Object) this.desc, descFile);
            JSON.prettyToFile((Object) preparationOSchemaToUse, schemaFile);

            JobContext.getCurrentActivitySummary().engineType = "PYTHON";
            JsonObject containerPayload = buildContainerPayload(inputDatasetSmartName, outputDatasetSmartName, metricsDatasetSmartName, fmi);
            PyPredictionScoringRecipeSubrunner.Evaluation runner = new PyPredictionScoringRecipeSubrunner.Evaluation(
                    this.activity, fmi, activeModelFolder, null, this.authCtx, this.executionParams.envName, this.desc,
                    RemoteRunsRegistry.ExecutionType.RECIPE_PREDICTION_EVAL_CAUSAL, outputTmpDir, containerPayload,
                    "dataiku.doctor.causal.evaluate.launch_evaluation_recipe",
                    activeModelFolder.getAbsolutePath(),
                    inputDatasetSmartName,
                    StringUtils.defaultIfBlank(outputDatasetSmartName, ""),
                    StringUtils.defaultIfBlank(metricsDatasetSmartName, ""),
                    descFile.getAbsolutePath(),
                    scriptFile.getAbsolutePath(),
                    schemaFile.getAbsolutePath());
            this.startRunner(runner);
        }
    }

    /** Rejects any backend other than PY_MEMORY and any non-causal prediction type. */
    private void checkSupportedDescriptor() {
        if (this.desc.backendType != MLTask.BackendType.PY_MEMORY) {
            throw new IllegalArgumentException("Unsupported backend type: " + this.desc.backendType);
        }
        if (this.desc.predictionType != PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION
                && this.desc.predictionType != PredictionMLTask.PredictionType.CAUSAL_REGRESSION) {
            throw new IllegalArgumentException("Unsupported prediction type: " + this.desc.predictionType);
        }
    }

    /** Throws if the recipe declares any output in the evaluation-store role. */
    private void checkNoEvaluationStoreOutput() throws Exception {
        List<SerializedRecipe.RecipeOutput> evaluationStoreOutputs = this.recipe.getModel().getOutputsForRole(EVALUATION_STORE_ROLE);
        if (!evaluationStoreOutputs.isEmpty()) {
            throw new IllegalArgumentException("Evaluation store not supported for causal predictions");
        }
    }

    /**
     * Infers the output schema of the preparation script applied to the input dataset, with the
     * single write transaction stashed for the duration of the call and unstashed afterwards.
     */
    private Schema inferPreparedSchema(String projectKey, Dataset inputDataset, SerializedShakerScript script) throws Exception {
        SingleWriteTransactionTransactionService txService =
                (SingleWriteTransactionTransactionService) SpringUtils.getBean(SingleWriteTransactionTransactionService.class);
        txService.stashTheSingleTransaction();
        try {
            // The input dataset's own type is passed as the output dataset type for inference.
            return MLFlowUtils.getInferredPreparationOutputSchema_NT(projectKey, inputDataset, script, inputDataset.getType(), this.authCtx);
        } finally {
            txService.unstashTheSingleTransaction();
        }
    }

    /** Builds the JSON payload handed to the subrunner as its {@code containerPayload}. */
    private static JsonObject buildContainerPayload(String inputDatasetSmartName, String outputDatasetSmartName,
                                                    String metricsDatasetSmartName, FullModelId fmi) {
        JsonObject payload = new JsonObject();
        payload.addProperty("inputDatasetSmartName", inputDatasetSmartName);
        payload.addProperty("outputDatasetSmartName", outputDatasetSmartName);
        payload.addProperty("metricsDatasetSmartName", metricsDatasetSmartName);
        payload.addProperty("inputModel", fmi.toString());
        return payload;
    }

    /** Resolves the single input of the given role to its dataset, failing if it does not exist. */
    private Dataset getInputDataset(String role) throws IOException {
        return this.datasetAccessService.getMandatory(this.recipe.getModel().getSingleInput(role).getLoc(this.recipe.getProjectKey()));
    }

    /**
     * Rejects recipes that declare an evaluation-store output, which causal predictions do not
     * support.
     *
     * @throws IllegalArgumentException if an evaluation-store output is configured
     */
    @Override
    public void finalCommit() throws Exception {
        this.checkNoEvaluationStoreOutput();
    }
}

