/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.DKUMLUtils;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLLibMLTaskHandler;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.ModelVersioning;
import com.dataiku.dip.analysis.ml.prediction.ClassicalPredictionParamsExpander;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.ml.shared.WorkSetPreparator;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.cluster.SparkSettings;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.spark.SparkJob;
import com.dataiku.dip.spark.SparkJobHelper;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * ML task handler for classical prediction tasks whose training runs as the Spark job
 * {@code com.dataiku.dip.spark.MLLibPredictionDoctorJob}.
 *
 * <p>The same Spark job class is launched for every backend. A boolean computed as
 * {@code backendType == MLTask.BackendType.H2O} is handed to
 * {@code SparkJobHelper.makeClassJobWithNonSecretGlobalFiles}; its exact effect is defined there.
 */
public class MLLibClassicalPredictionMLTaskHandler
        extends MLLibMLTaskHandler<PredictionMLTask.ClassicalPredictionMLTask> {
    /** Identifiers of the models handled by this session; sent to the Spark job as a JSON list of strings. */
    private final List<FullModelId> fullModelIds;

    /**
     * Creates a handler for one training session of a classical prediction task.
     *
     * @param acp       analysis core parameters (project key, input dataset, ...)
     * @param taskLoc   location of the ML task (used e.g. for the splits folder)
     * @param task      the classical prediction task being trained
     * @param sessionId identifier of the training session
     * @param fmis      identifiers of the models to train; the list is kept by reference, not copied
     * @param user      authentication context of the requesting user
     */
    public MLLibClassicalPredictionMLTaskHandler(AnalysisCoreParams acp, MLTaskLoc taskLoc, PredictionMLTask.ClassicalPredictionMLTask task, String sessionId, List<FullModelId> fmis, AuthCtx user) {
        super(acp, taskLoc, task, sessionId, user);
        this.fullModelIds = fmis;
    }

    /** Returns the inherited {@code task} field viewed as a classical prediction task. */
    private PredictionMLTask.ClassicalPredictionMLTask classicalTask() {
        return (PredictionMLTask.ClassicalPredictionMLTask) this.task;
    }

    /**
     * Resolves a (possibly project-qualified) dataset smart name against this analysis' project
     * and returns the project key it belongs to.
     */
    private String resolveProjectKey(String datasetSmartName) {
        return AnyLoc.resolveSmart(this.acp.projectKey, datasetSmartName).getProjectKey();
    }

    /**
     * Dumps the task parameters to disk, expands them into the work set, then lets the
     * preparator prepare that work set.
     */
    @Override
    public void init(WorkSetPreparator preparator) throws Exception {
        DKUMLUtils.dumpParamsOnDisk(this.acp, this.taskLoc, this.task, this.sessionId);
        this.ws = new ClassicalPredictionParamsExpander(classicalTask(), this.sessionId).expand();
        preparator.prepare(this.ws);
    }

    /** Human-readable description attached to the API ticket of this training. */
    @Override
    protected String getAPITicketDescription() {
        return "MLLib doctor prediction";
    }

    /** Name of the DSS metric associated with this training. */
    @Override
    protected String getDSSMetricName() {
        return "dku.ml.predictionTrain.mllibTrain";
    }

    /** Hive database to use for the given split, as computed by {@link DKUMLUtils#getHiveDb}. */
    @Override
    protected String getHiveDB(SplitDesc splitDesc) throws IOException {
        return DKUMLUtils.getHiveDb(this.acp, this.task, this.datasetsDAO, splitDesc.params);
    }

    /**
     * Updates the prediction train info / user meta for the preprocessing set, then dumps
     * train version info for the task's backend.
     */
    @Override
    protected void updateTrainInfo(WorkSet.PreprocessingSet pps) throws IOException {
        PredictionMLTask.ClassicalPredictionMLTask classicalTask = classicalTask();
        ModelStateHelper.updatePredictionTrainInfoAndUserMeta(classicalTask.predictionType, pps);
        // Version info is written only into the run folder of the first modeling set.
        File firstRunFolder = new File(pps.modelingSets.get(0).run_folder);
        ModelVersioning.dumpTrainVersionInfo(classicalTask.backendType, firstRunFolder);
    }

    /**
     * Returns a builder that launches {@code com.dataiku.dip.spark.MLLibPredictionDoctorJob}.
     * The job receives the project key, the session and preprocessing folder paths, and the
     * JSON-encoded list of model ids.
     */
    @Override
    protected MLLibMLTaskHandler.TrainSparkDoctorJobBuilder getTrainSparkDoctorJobBuilder(final File sessionFolder, final File preprocessingFolder, File modelFolder, SplitDesc splitDesc, String hiveDb) {
        return new MLLibMLTaskHandler.TrainSparkDoctorJobBuilder(this.task, sessionFolder, preprocessingFolder, modelFolder, splitDesc, hiveDb) {

            @Override
            public <T extends SparkJob> T buildSparkJob(SparkJobHelper<T> helper, File runDir, SparkSettings sparkSettings, List<SimpleKeyValue> effectiveConf) throws Exception {
                MLLibClassicalPredictionMLTaskHandler outer = MLLibClassicalPredictionMLTaskHandler.this;
                PredictionMLTask.ClassicalPredictionMLTask classicalTask = outer.classicalTask();

                List<String> modelIds = new ArrayList<>(outer.fullModelIds.size());
                for (FullModelId fullModelId : outer.fullModelIds) {
                    modelIds.add(fullModelId.toString());
                }

                String jobName = "DSS (Analysis): " + classicalTask.name;
                boolean isH2OBackend = classicalTask.backendType == MLTask.BackendType.H2O;
                return helper.makeClassJobWithNonSecretGlobalFiles(
                        jobName,
                        effectiveConf,
                        outer.gatherer.getResourceFiles(),
                        isH2OBackend,
                        "com.dataiku.dip.spark.MLLibPredictionDoctorJob",
                        outer.acp.projectKey,
                        sessionFolder.getAbsolutePath(),
                        preprocessingFolder.getAbsolutePath(),
                        JSON.json(modelIds));
            }

            /**
             * Project keys of the datasets referenced by the split policy (they may live in
             * other projects). Policies not listed below contribute no extra keys.
             */
            @Override
            public List<String> getExtraRelevantProjectkeys() {
                MLLibClassicalPredictionMLTaskHandler outer = MLLibClassicalPredictionMLTaskHandler.this;
                SplitParams params = this.splitDesc.params;
                List<String> projectKeys = new ArrayList<>();
                switch (params.ttPolicy) {
                    case EXPLICIT_FILTERING_SINGLE_DATASET:
                        projectKeys.add(outer.resolveProjectKey(params.efsdDatasetSmartName));
                        break;
                    case EXPLICIT_FILTERING_TWO_DATASETS:
                        projectKeys.add(outer.resolveProjectKey(params.eftdTrain.datasetSmartName));
                        projectKeys.add(outer.resolveProjectKey(params.eftdTest.datasetSmartName));
                        break;
                    case SPLIT_SINGLE_DATASET:
                        projectKeys.add(outer.resolveProjectKey(params.ssdDatasetSmartName));
                        break;
                    default:
                        // No dataset reference to resolve for other policies.
                        break;
                }
                return projectKeys;
            }
        };
    }

    /**
     * Builds the single split description used by this task. It is written as pretty JSON to
     * {@code <splitsFolder>/unique-<taskId>.json} and also returned.
     */
    @Override
    protected SplitDesc prepareSplits() throws Exception {
        SplitDesc expandedSplitDesc = new SplitDesc();
        expandedSplitDesc.generationDate = System.currentTimeMillis();

        // Single-dataset policies default to the analysis input dataset when none is set.
        // NOTE: this writes into the task's own SplitParams (not a copy), so the defaulted
        // value is also visible on the task object, not only in the generated description.
        SplitParams splitParams = classicalTask().splitParams;
        if (splitParams.ttPolicy == SplitParams.TrainTestPolicy.SPLIT_SINGLE_DATASET && splitParams.ssdDatasetSmartName == null) {
            splitParams.ssdDatasetSmartName = this.acp.inputDatasetSmartName;
        } else if (splitParams.ttPolicy == SplitParams.TrainTestPolicy.EXPLICIT_FILTERING_SINGLE_DATASET && splitParams.efsdDatasetSmartName == null) {
            splitParams.efsdDatasetSmartName = this.acp.inputDatasetSmartName;
        }

        // The description carries a deep copy so later edits to the task don't leak into it.
        expandedSplitDesc.params = (SplitParams) JSON.deepCopy(splitParams);
        expandedSplitDesc.policyId = "none";
        expandedSplitDesc.instanceId = "unique-" + classicalTask().id;
        expandedSplitDesc.schema = this.dataService.getInferredSchemaForML_NT(this.acp, this.user);
        File splitFile = new File(this.taskLoc.getSplitsFolder(), expandedSplitDesc.instanceId + ".json");
        JSON.prettyToFile(expandedSplitDesc, splitFile);
        return expandedSplitDesc;
    }
}

