/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml;

import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.analysis.coreservices.AnalysisDataService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.MLTaskHandler;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.ml.spark.SparkBasedDoctorJob;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.distributed.metrics.ContainerUsageMetrics;
import com.dataiku.dip.export.ZipUnzipDir;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.tickets.APITicketService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.resources.ResourcesGatherer;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.CollectionUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import com.google.common.collect.Lists;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Base {@link MLTaskHandler} for ML tasks trained through a {@link SparkBasedDoctorJob}.
 *
 * <p>{@link #train()} prepares the splits, writes the session-level inputs (split descriptor,
 * script copy, resource mapping) into the session folder, then runs one Spark doctor job per
 * queued preprocessing set. Subclasses provide the engine-specific pieces through the abstract
 * hooks below.
 *
 * @param <T> the concrete ML task type handled
 */
public abstract class MLLibMLTaskHandler<T extends MLTask>
extends MLTaskHandler<T> {
    @Autowired
    protected DatasetsDAO datasetsDAO;
    @Autowired
    protected AnalysisDataService dataService;
    @Autowired
    protected VariablesService variablesService;
    @Autowired
    protected APITicketService apiTicketService;
    @Autowired
    protected TransactionService transactionService;
    /** Gathers the resources referenced by the script steps; its mapping is written to the session folder. */
    protected final ResourcesGatherer gatherer = new ResourcesGatherer();
    private static final DKULogger logger = DKULogger.getLogger("dku.analysis.train");

    /** Returns the description attached to the expiring API ticket created for each Spark run. */
    protected abstract String getAPITicketDescription();

    /**
     * Creates the builder that configures the Spark doctor job for one preprocessing set.
     *
     * @param sessionFolder       folder holding the session-level inputs (split and script files)
     * @param preprocessingFolder run folder of the preprocessing set being trained
     * @param modelFolder         run folder of the first modeling set of that preprocessing set
     * @param splitDesc           split descriptor returned by {@link #prepareSplits()}
     * @param hiveDb              value returned by {@link #getHiveDB(SplitDesc)}
     */
    protected abstract TrainSparkDoctorJobBuilder getTrainSparkDoctorJobBuilder(File sessionFolder, File preprocessingFolder, File modelFolder, SplitDesc splitDesc, String hiveDb);

    /** Called after the Spark job for {@code preprocessingSet} has completed successfully. */
    protected abstract void updateTrainInfo(WorkSet.PreprocessingSet preprocessingSet) throws IOException;

    /** Prepares the data splits at the start of {@link #train()} and returns their descriptor. */
    protected abstract SplitDesc prepareSplits() throws Exception;

    /**
     * Returns the Hive database passed to the Spark job as {@code spark.dku.ml.hiveDb}.
     * A blank or null value is sent as the empty string.
     */
    protected abstract String getHiveDB(SplitDesc splitDesc) throws IOException;

    /** Returns the name of the DSS meter that is marked once per {@link #train()} call. */
    protected abstract String getDSSMetricName();

    /**
     * Builds the handler, then autowires it when its injected fields are still missing.
     * The resource gatherer is always autowired.
     */
    protected MLLibMLTaskHandler(AnalysisCoreParams acp, MLTaskLoc taskLoc, T task, String sessionId, AuthCtx user) {
        super(acp, taskLoc, task, sessionId, user);
        // NOTE(review): presumably the super constructor may already have autowired this
        // instance; only autowire again if the injected fields are still unset.
        if (this.datasetsDAO == null || this.dataService == null) {
            SpringUtils.getInstance().autowire((Object)this);
        }
        SpringUtils.getInstance().autowire((Object)this.gatherer);
    }

    /** Does nothing for this engine. */
    @Override
    public void abort() {
    }

    /**
     * Partial abort is not supported by this engine.
     *
     * @throws NotImplementedException always
     */
    @Override
    public void abort(List<FullModelId> fullModelIdSet) throws IOException {
        throw new NotImplementedException("Partial abort is not implemented for this engine");
    }

    /** Returns an empty map: this handler reports no per-model container usage. */
    @Override
    public Map<FullModelId, ContainerUsageMetrics> getContainerUsageMetricsPerModel() {
        return Collections.emptyMap();
    }

    /** Returns an empty list: this handler exposes no kernel process ids. */
    @Override
    public List<Integer> getKernelPids() {
        return Collections.emptyList();
    }

    /**
     * Trains every preprocessing set in the work queue, one Spark doctor job at a time.
     *
     * <p>If any preprocessing set fails, the error is logged and every model that is not yet done,
     * across all preprocessing sets of the work set, is marked as failed. The original throwable
     * is then rethrown.
     */
    @Override
    public void train() throws Exception {
        DSSMetrics.registry().meter(this.getDSSMetricName()).mark();
        SplitDesc splitDesc = this.prepareSplits();
        File sessionFolder = MLPaths.sessionFolder(this.taskLoc, this.sessionId);
        JSON.prettyToFile((Object)new SplitDesc.SplitRef(splitDesc.instanceId), new File(sessionFolder, "split_ref.json"));
        String hiveDb = this.writeSessionInputs(sessionFolder, splitDesc);
        this.dispatchWork();
        try {
            for (WorkSet.PreprocessingSet pps : this.workQueue) {
                File ppsFolder = new File(pps.run_folder);
                JSON.prettyToFile((Object)pps, new File(ppsFolder, "preprocessing_set.json"));
                // Only the first modeling set's folder is handed to the Spark job.
                File modelFolder = new File(pps.modelingSets.get(0).run_folder);
                this.runTrain(sessionFolder, ppsFolder, modelFolder, splitDesc, hiveDb);
                this.updateTrainInfo(pps);
            }
        }
        catch (Throwable e) {
            logger.error((Object)"Processing failed", e);
            this.markNotDoneModelsAsFailed(e);
            // Precise rethrow: only the exceptions the loop body can throw escape here.
            throw e;
        }
    }

    /**
     * Writes split_desc.json, escript.json and resource_mapping.json into the session folder.
     * It also gathers the script's resources and resolves the Hive database. All of this runs
     * inside a single read transaction.
     *
     * @return the Hive database name from {@link #getHiveDB(SplitDesc)}
     */
    private String writeSessionInputs(File sessionFolder, SplitDesc splitDesc) throws Exception {
        try (Transaction readTx = this.transactionService.beginRead();){
            // NOTE(review): despite its name this is a plain deep copy of acp.script (variables are
            // not expanded); only the gatherer receives expanded steps. Confirm escript.json is
            // meant to hold the unexpanded script.
            SerializedShakerScript expandedScript = (SerializedShakerScript)JSON.deepCopy((Object)this.acp.script);
            this.gatherer.gatherAndCompute(this.user, this.acp.projectKey,
                    this.acp.script.expandedDeepCopy((VariablesContext)this.variablesService.getContext((String)this.acp.projectKey)).steps);
            JSON.prettyToFile((Object)splitDesc, new File(sessionFolder, "split_desc.json"));
            JSON.prettyToFile((Object)expandedScript, new File(sessionFolder, "escript.json"));
            JSON.prettyToFile(this.gatherer.getResourceMapping(), new File(sessionFolder, "resource_mapping.json"));
            return this.getHiveDB(splitDesc);
        }
    }

    /** Marks every not-yet-done model of every preprocessing set as failed with {@code cause}. */
    private void markNotDoneModelsAsFailed(Throwable cause) {
        synchronized (this) {
            for (WorkSet.PreprocessingSet pps : this.ws.preprocessingSets) {
                ModelStateHelper.markAllNotDoneAsFailed(pps, cause);
            }
        }
    }

    /**
     * Runs the Spark doctor job for one preprocessing set, authenticated by an expiring API ticket.
     * The ticket is closed when the job ends.
     *
     * <p>When the Spark driver runs remotely, {@code modelFolder/trainedModel} is extracted into
     * {@code modelFolder}. After the job, an empty JSON object is written as
     * preprocessing_report.json.
     */
    private void runTrain(File sessionFolder, File preprocessingFolder, File modelFolder, SplitDesc splitDesc, String hiveDb) throws Exception {
        try (APITicketService.ExpirableTicket ticket = this.apiTicketService.createExpiringTicket(this.user, this.getAPITicketDescription(), (Object)this.task);){
            SparkBasedDoctorJob doctorJob = new SparkBasedDoctorJob(this.user, this.acp.projectKey, preprocessingFolder, this.task, ticket);
            TrainSparkDoctorJobBuilder jobBuilder = this.getTrainSparkDoctorJobBuilder(sessionFolder, preprocessingFolder, modelFolder, splitDesc, hiveDb);
            doctorJob.runSpark(jobBuilder, context -> {
                if (context.driverRunsRemotely()) {
                    ZipUnzipDir.extractFolder(new File(modelFolder, "trainedModel"), modelFolder);
                }
            });
            JSON.prettyToFile((Object)new JsonObject(), new File(preprocessingFolder, "preprocessing_report.json"));
        }
    }

    /**
     * Base Spark doctor job builder for MLLib training. It supplies the Spark conf overrides taken
     * from the task's Spark parameters, plus the folders that the job reads and writes.
     */
    protected static abstract class TrainSparkDoctorJobBuilder
    extends SparkBasedDoctorJob.SparkDoctorJobBuilder {
        protected final SplitDesc splitDesc;
        private final File sessionFolder;
        private final File preprocessingFolder;
        private final File modelFolder;
        private final String hiveDb;
        private final MLTask task;

        public TrainSparkDoctorJobBuilder(MLTask task, File sessionFolder, File preprocessingFolder, File modelFolder, SplitDesc splitDesc, String hiveDb) {
            this.sessionFolder = sessionFolder;
            this.preprocessingFolder = preprocessingFolder;
            this.modelFolder = modelFolder;
            this.splitDesc = splitDesc;
            this.hiveDb = hiveDb;
            this.task = task;
        }

        /**
         * Returns the {@code spark.dku.*} settings derived from the task.
         * A blank or null Hive database is sent as {@code ""}.
         */
        @Override
        public Map<String, String> getContextOverrideConf() {
            return CollectionUtils.appendableSSMap()
                    .put("spark.dku.ml.preparedDF.storageLevel", this.task.sparkParams.sparkPreparedDFStorageLevel)
                    .put("spark.dku.ml.repartitionNonHDFS", String.valueOf(this.task.sparkParams.sparkRepartitionNonHDFS))
                    .put("spark.dku.ml.fittedDF.checkpoint", this.task.sparkCheckpoint.name())
                    .put("spark.dku.checkpointDir", this.task.sparkCheckpointDir)
                    .put("spark.dku.ml.useGlobalMetastore", Boolean.toString(this.task.sparkParams.sparkUseGlobalMetastore))
                    .put("spark.dku.ml.hiveDb", StringUtils.defaultIfBlank(this.hiveDb, ""))
                    .get();
        }

        @Override
        public abstract List<String> getExtraRelevantProjectkeys();

        /** Returns the session, preprocessing and model folders, in that order. */
        @Override
        public List<File> getExtraRecursiveFolders() {
            // Pass elements directly: an (Object[]) cast would infer ArrayList<Object>,
            // which is not assignable to List<File>.
            return Lists.newArrayList(this.sessionFolder, this.preprocessingFolder, this.modelFolder);
        }

        /** Returns the absolute path of the model folder, the only path the job may write to. */
        @Override
        public List<String> getWritablePaths() {
            return Lists.newArrayList(this.modelFolder.getAbsolutePath());
        }
    }
}

