/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.ScoringRecipeUtils;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractScoringRecipeRunner;
import com.dataiku.dip.analysis.ml.prediction.flow.PyPredictionScoringRecipeSubrunner;
import com.dataiku.dip.analysis.ml.prediction.flow.TabularPredictionScoringRecipePayloadParams;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.graph.FlowSavedModel;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.remoterun.RemoteRunsRegistry;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.datasets.DatasetAccessService;
import com.dataiku.dip.server.services.SingleWriteTransactionTransactionService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Scoring-recipe runner for causal prediction saved models.
 *
 * <p>Validates the recipe payload and resolves the input/output datasets and the active version
 * of the input saved model. It then writes the files expected by the Python scoring entry point
 * ({@code dataiku.doctor.causal.score.launch_scoring_recipe}) into a temporary folder and hands a
 * {@link PyPredictionScoringRecipeSubrunner.Scoring} to {@code startRunner}.
 *
 * <p>Supported configurations: only the {@code PY_MEMORY} backend, and only the
 * {@code CAUSAL_BINARY_CLASSIFICATION} and {@code CAUSAL_REGRESSION} prediction types.
 */
public class CausalPredictionScoringRecipeRunner
extends AbstractScoringRecipeRunner {
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.recipes.causal.scoring");

    /** File (in the run's temp folder) holding the serialized recipe payload. */
    private static final String DESC_FILE = "desc.json";
    /** File (in the run's temp folder) holding the inferred preparation output schema. */
    private static final String PREPARATION_OUTPUT_SCHEMA_FILE = "preparation_output_schema.json";
    /** File (in the run's temp folder) holding the model's preparation script. */
    private static final String SCRIPT_FILE = "script.json";

    /** Recipe payload parsed by {@link #setPayload(String)}; {@code null} until then. */
    private TabularPredictionScoringRecipePayloadParams desc;
    @Autowired
    private DatasetAccessService datasetAccessService;

    public CausalPredictionScoringRecipeRunner(JobActivity activity) {
        super(activity);
    }

    /**
     * Parses the recipe payload JSON into {@link TabularPredictionScoringRecipePayloadParams}.
     *
     * @param payload the recipe payload, as JSON
     */
    @Override
    public void setPayload(String payload) {
        this.desc = (TabularPredictionScoringRecipePayloadParams)JSON.parse(payload, TabularPredictionScoringRecipePayloadParams.class);
    }

    /**
     * Runs the causal scoring recipe.
     *
     * <p>Steps:
     * <ol>
     *   <li>Validate the payload and the recipe inputs.</li>
     *   <li>Load the preparation script of the saved model's active version.</li>
     *   <li>Infer the preparation output schema.</li>
     *   <li>Write the payload, schema and script into a temporary folder.</li>
     *   <li>Start the Python scoring subrunner on that folder.</li>
     * </ol>
     *
     * <p>The temporary folder is an {@link AutoDelete} closed when the try-with-resources ends.
     *
     * @throws IllegalStateException if no payload has been set
     * @throws IllegalArgumentException if the backend or prediction type is unsupported, or if
     *     model metadata output is requested while the input schema already contains the
     *     reserved metadata column
     * @throws Exception propagated from dataset, model, schema inference or subrunner operations
     */
    @Override
    public void run() throws Exception {
        this.checkPayloadSupported();
        List<FlowDataset> inputFDS = this.activity.getSubgraph().getSourceDatasets();
        if (inputFDS.isEmpty()) {
            throw ErrorContext.iae((String)"Missing input dataset in scoring recipe");
        }
        Dataset inputDataset = this.getInputDataset("main");
        Dataset outputDataset = this.getOutputDataset("main");
        String reservedColName = ScoringRecipeUtils.ModelMetadataUtils.schemaIncludesModelMetadata(inputDataset.getSchema());
        if (this.desc.outputModelMetadata && reservedColName != null) {
            // Previously a raw Exception; IllegalArgumentException matches the other input validations.
            throw new IllegalArgumentException("\"" + reservedColName + "\" is a reserved column name for model metadata output");
        }
        FlowSavedModel fsm = MLFlowUtils.getSMInput(this.activity);
        SavedModel sm = fsm.getSavedModel();
        MLFlowUtils.checkActiveVersion(sm);
        File activeModelFolder = MLPaths.savedModelVersionFolder(sm, sm.activeVersion);
        this.copyModelDetailsForFutureJobDiagnostic(activeModelFolder);
        SerializedShakerScript script = (SerializedShakerScript)JSON.parseFile((File)new File(activeModelFolder, SCRIPT_FILE), SerializedShakerScript.class);
        // The script is evaluated in the context of the saved model's project.
        script.contextProjectKey = sm.projectKey;
        final Schema inferredPreparationOSchema = this.inferPreparationOutputSchema(sm, inputDataset, outputDataset, script);
        String inputDatasetSmartName = AnyLoc.resolveFull(inputDataset.getFullName()).getSmartName(this.recipe.getProjectKey());
        String outputDatasetSmartName = AnyLoc.resolveFull(outputDataset.getFullName()).getSmartName(this.recipe.getProjectKey());
        JobContext.getCurrentActivitySummary().engineType = "DSS";
        try (AutoDelete outputTmpDir = FlowJobUtils.getTmpFolder("causal-prediction-scoring-recipe", "run");){
            JSON.prettyToFile((Object)this.desc, (File)new File((File)outputTmpDir, DESC_FILE));
            JSON.prettyToFile((Object)inferredPreparationOSchema, (File)new File((File)outputTmpDir, PREPARATION_OUTPUT_SCHEMA_FILE));
            JSON.prettyToFile((Object)script, (File)new File((File)outputTmpDir, SCRIPT_FILE));
            PyPredictionScoringRecipeSubrunner.Scoring runner = this.createRunner(activeModelFolder, sm, inputDatasetSmartName, outputDatasetSmartName, outputTmpDir);
            this.startRunner(runner);
        }
    }

    /**
     * Checks that a payload is present and requests a backend and prediction type this runner supports.
     */
    private void checkPayloadSupported() {
        if (this.desc == null) {
            throw new IllegalStateException("Scoring recipe payload was not set");
        }
        if (this.desc.backendType != MLTask.BackendType.PY_MEMORY) {
            throw new IllegalArgumentException("Unsupported backend type: " + String.valueOf((Object)this.desc.backendType));
        }
        if (this.desc.predictionType != PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION && this.desc.predictionType != PredictionMLTask.PredictionType.CAUSAL_REGRESSION) {
            throw new IllegalArgumentException("Unsupported prediction type: " + String.valueOf((Object)this.desc.predictionType));
        }
    }

    /**
     * Infers the output schema of {@code script} applied to {@code inputDataset}.
     *
     * <p>The single write transaction is stashed around the call and always unstashed afterwards.
     * NOTE(review): presumably because the {@code _NT} helper must run outside that transaction;
     * confirm.
     */
    private Schema inferPreparationOutputSchema(SavedModel sm, Dataset inputDataset, Dataset outputDataset, SerializedShakerScript script) throws Exception {
        SingleWriteTransactionTransactionService txService = (SingleWriteTransactionTransactionService)SpringUtils.getBean(SingleWriteTransactionTransactionService.class);
        txService.stashTheSingleTransaction();
        try {
            logger.info((Object)"Adapting inferred output schema from training schema to avoid inconsistencies");
            return MLFlowUtils.getInferredPreparationOutputSchema_NT(sm.projectKey, inputDataset, script, outputDataset.getType(), this.authCtx);
        }
        finally {
            txService.unstashTheSingleTransaction();
        }
    }

    /**
     * Builds the Python scoring subrunner for the active version of {@code sm}.
     *
     * <p>The container payload carries the input/output dataset smart names and the full model id.
     * The trailing constructor arguments are the Python module
     * {@code dataiku.doctor.causal.score.launch_scoring_recipe} followed by what are presumably its
     * arguments:
     * <ul>
     *   <li>the model folder;</li>
     *   <li>the dataset smart names;</li>
     *   <li>the absolute paths of the desc, script and schema files written by {@link #run()};</li>
     *   <li>the full model id.</li>
     * </ul>
     *
     * <p>The code env is taken from the model's resolved core params.
     */
    private PyPredictionScoringRecipeSubrunner.Scoring createRunner(File activeModelFolder, SavedModel sm, String inputDatasetSmartName, String outputDatasetSmartName, AutoDelete outputTmpDir) throws IOException {
        FullModelId fmi = new FullModelId(sm.projectKey, sm.id, sm.activeVersion);
        JsonObject containerPayload = new JsonObject();
        containerPayload.addProperty("inputDatasetSmartName", inputDatasetSmartName);
        containerPayload.addProperty("outputDatasetSmartName", outputDatasetSmartName);
        containerPayload.addProperty("inputModel", fmi.toString());
        String codeEnvName = fmi.getResolvedCoreParams().executionParams.envName;
        logger.info((Object)"Running with a python backend");
        String descPath = new File((File)outputTmpDir, DESC_FILE).getAbsolutePath();
        String scriptPath = new File((File)outputTmpDir, SCRIPT_FILE).getAbsolutePath();
        String schemaPath = new File((File)outputTmpDir, PREPARATION_OUTPUT_SCHEMA_FILE).getAbsolutePath();
        return new PyPredictionScoringRecipeSubrunner.Scoring(this.activity, fmi, activeModelFolder, this.authCtx, codeEnvName, this.desc, RemoteRunsRegistry.ExecutionType.RECIPE_PREDICTION_SCORE_CAUSAL, outputTmpDir, containerPayload, "dataiku.doctor.causal.score.launch_scoring_recipe", activeModelFolder.getAbsolutePath(), inputDatasetSmartName, outputDatasetSmartName, descPath, scriptPath, schemaPath, fmi.toString());
    }

    /** Returns the dataset bound to the recipe's single input for {@code role}. */
    private Dataset getInputDataset(String role) throws IOException {
        return this.datasetAccessService.getMandatory(this.recipe.getModel().getSingleInput(role).getLoc(this.recipe.getProjectKey()));
    }

    /** Returns the dataset bound to the recipe's single output for {@code role}. */
    private Dataset getOutputDataset(String role) throws IOException {
        return this.datasetAccessService.getMandatory(this.recipe.getModel().getSingleOutput(role).getLoc(this.recipe.getProjectKey()));
    }
}

