/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.flow.ClassicalPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionRecipesUtils;
import com.dataiku.dip.analysis.ml.prediction.flow.TabularPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.flow.TabularPredictionTrainingRecipeRunner;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.cluster.SparkSettings;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.exec.AbstractSparkBasedRecipeRunner;
import com.dataiku.dip.dataflow.exec.SparkExecutionEnginesHelper;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.export.ZipUnzipDir;
import com.dataiku.dip.recipes.InitializableAbortableRecipeRunner;
import com.dataiku.dip.recipes.code.spark.SparkRecipeUtils;
import com.dataiku.dip.remoterun.RemoteRunsRegistry;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.resources.ResourcesGatherer;
import com.dataiku.dip.spark.SparkJob;
import com.dataiku.dip.spark.SparkJobHelper;
import com.dataiku.dip.spark.SparkOverrideConfig;
import com.dataiku.dip.utils.CollectionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dip.variables.VariablesService;
import com.google.common.collect.Lists;
import java.io.File;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Recipe runner that trains a "classical" prediction model from a flow training recipe.
 *
 * <p>The training payload is supplied as JSON through {@link #setPayload(String)}. Depending on the
 * payload's backend type, training is delegated either to the Python training sub-runner
 * (PY_MEMORY, KERAS) or to a Spark job (H2O, MLLIB). DEEP_HUB and any other backend are rejected
 * with a {@link NotImplementedException}.
 */
public class ClassicalPredictionTrainingRecipeRunner
extends TabularPredictionTrainingRecipeRunner {
    /** Provides the project variable context used to expand the preparation script (Spark path). */
    @Autowired
    private VariablesService variablesService;
    /** Training payload parsed by {@link #setPayload(String)}; {@code null} until it is called. */
    private ClassicalPredictionTrainingRecipePayloadParams desc;
    /** Gathers the resources referenced by the expanded script steps (used on the Spark path). */
    private final ResourcesGatherer gatherer = new ResourcesGatherer();

    public ClassicalPredictionTrainingRecipeRunner(JobActivity activity) {
        super(activity);
    }

    /**
     * Rejects backends this runner cannot handle up front.
     *
     * @throws NotImplementedException if the payload's backend is DEEP_HUB
     */
    @Override
    protected void checkBackendAndPredictionType() {
        if (this.desc.backendType == MLTask.BackendType.DEEP_HUB) {
            throw new NotImplementedException("Unsupported backend type: " + this.desc.backendType);
        }
    }

    /**
     * Forces the operation mode to {@code TRAIN_SPLITTED_ONLY} for models built with Keras.
     * The payload is mutated in place, and the previous mode is logged.
     *
     * @param payload the training payload; it is cast to
     *        {@link ClassicalPredictionTrainingRecipePayloadParams} (ClassCastException otherwise)
     */
    private void forceKerasOperationMode(TabularPredictionTrainingRecipePayloadParams payload) {
        ClassicalPredictionTrainingRecipePayloadParams predDesc = (ClassicalPredictionTrainingRecipePayloadParams) payload;
        // NOTE(review): this reads core.backendType, while the rest of this class dispatches on
        // desc.backendType — confirm the two are always kept in sync.
        if (predDesc.core.backendType == MLTask.BackendType.KERAS
                && predDesc.operationMode != AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_SPLITTED_ONLY) {
            logger.info("Enforce operation mode to 'TRAIN_SPLITTED_ONLY' for models built with Keras, was set to '"
                    + predDesc.operationMode + "'");
            predDesc.operationMode = AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_SPLITTED_ONLY;
        }
    }

    /**
     * Runs the parent class's pre-run checks, then applies the Keras operation-mode override.
     */
    @Override
    protected void performChecksBeforeRun(TabularPredictionTrainingRecipePayloadParams desc) throws Exception {
        super.performChecksBeforeRun(desc);
        this.forceKerasOperationMode(desc);
    }

    /** @return the payload parsed by {@link #setPayload(String)}, or {@code null} before that */
    @Override
    protected ClassicalPredictionTrainingRecipePayloadParams getDesc() {
        return this.desc;
    }

    /** @return always {@code RECIPE_PREDICTION_TRAIN_PYTHON} */
    @Override
    protected RemoteRunsRegistry.ExecutionType getRemoteExecutionType() {
        return RemoteRunsRegistry.ExecutionType.RECIPE_PREDICTION_TRAIN_PYTHON;
    }

    /**
     * Parses the JSON training payload and stores it as this runner's descriptor.
     *
     * @param payload JSON serialization of a {@link ClassicalPredictionTrainingRecipePayloadParams}
     */
    @Override
    public void setPayload(String payload) {
        this.desc = (ClassicalPredictionTrainingRecipePayloadParams) JSON.parse(payload, ClassicalPredictionTrainingRecipePayloadParams.class);
    }

    /** Injects Spring dependencies into the resource gatherer, which is created with {@code new}. */
    @Override
    public void init() {
        SpringUtils.getInstance().autowire((Object) this.gatherer);
    }

    /**
     * @return the command identifier for this recipe, {@code "dataiku.doctor.prediction.reg_train_recipe"}
     *         (presumably the Python module run by the parent class — confirm)
     */
    @Override
    protected String getCommand() {
        return "dataiku.doctor.prediction.reg_train_recipe";
    }

    /**
     * Builds the runner that performs training for a non-partitioned model, chosen by backend type.
     *
     * <ul>
     *   <li>PY_MEMORY / KERAS: delegates to {@code createPythonPredictionTrainingSubrunner}.</li>
     *   <li>H2O / MLLIB: grants filesystem read ACLs on the model folder and marks the activity engine
     *       as {@code "SPARK"}. It then returns a Spark-based runner. That runner writes
     *       {@code script.json} and {@code resource_mapping.json} into {@code outModelFolder} and
     *       launches {@code com.dataiku.dip.spark.MLLibPredictionTrainingJob}.</li>
     * </ul>
     *
     * @param fmi                identifier of the model being trained
     * @param outModelFolder     folder that receives the trained model and auxiliary files
     * @param containerConfig    container runtime configuration (Python path only)
     * @param containerSelection container selection; an explicit container is ignored (with a warning)
     *                           on the Spark path
     * @return the runner to execute
     * @throws NotImplementedException if the backend type is not one of the four handled here
     */
    @Override
    protected InitializableAbortableRecipeRunner createUnpartitionedRunner(FullModelId fmi, final File outModelFolder, ContainerExecRuntimeConfig containerConfig, ContainerExecSelection containerSelection) throws Exception {
        switch (this.desc.backendType) {
            case PY_MEMORY:
            case KERAS: {
                return this.createPythonPredictionTrainingSubrunner(fmi, outModelFolder, containerConfig);
            }
            case H2O:
            case MLLIB: {
                FilesystemACLUtils.grantFSReadACLs(this.authCtxService.getAuthCtx(), fmi.getProjectKey(), fmi.getFolderEnsuringSecurity());
                JobContext.getCurrentActivitySummary().engineType = "SPARK";
                if (containerSelection.containerMode == ContainerExecSelection.ContainerExecMode.EXPLICIT_CONTAINER) {
                    // The Spark ML engine cannot honour an explicit container selection: log it and carry on.
                    logger.warn("Ignoring container configuration " + containerSelection.containerConf + ", not compatible with Spark ML engine");
                }
                final String hiveDb = SparkRecipeUtils.getHiveMetastoreDatabase(this.activity, this.datasetsDAO);
                return new AbstractSparkBasedRecipeRunner(this.activity) {

                    @Override
                    public void run() throws Exception {
                        // Expand the preparation script with the project variables, then gather the
                        // resources referenced by its steps.
                        SerializedShakerScript expandedScript = ClassicalPredictionTrainingRecipeRunner.this.desc.script.expandedDeepCopy(
                                ClassicalPredictionTrainingRecipeRunner.this.variablesService.getContext(this.projectKey));
                        ClassicalPredictionTrainingRecipeRunner.this.gatherer.gatherAndCompute(this.authCtxService.getAuthCtx(), this.projectKey, expandedScript.steps);
                        // Write the expanded script and resource mapping into the model folder before launching Spark.
                        JSON.prettyToFile((Object) expandedScript, new File(outModelFolder, "script.json"));
                        JSON.prettyToFile(ClassicalPredictionTrainingRecipeRunner.this.gatherer.getResourceMapping(), new File(outModelFolder, "resource_mapping.json"));
                        this.runSpark("prediction", ClassicalPredictionTrainingRecipeRunner.this.desc.sparkParams.sparkExecutionEngine, new SparkExecutionEnginesHelper.SparkRecipeJobBuilder() {

                            @Override
                            public <T extends SparkJob> T buildSparkJob(SparkJobHelper<T> helper, File runDir, SparkSettings sparkSettings, List<SimpleKeyValue> effectiveConf) throws Exception {
                                // The boolean argument is true only for the H2O backend; its meaning is
                                // defined by SparkJobHelper.makeClassJobWithNonSecretGlobalFiles.
                                return helper.makeClassJobWithNonSecretGlobalFiles("DSS (train): " + activity.id(), effectiveConf, ClassicalPredictionTrainingRecipeRunner.this.gatherer.getResourceFiles(), ClassicalPredictionTrainingRecipeRunner.this.desc.backendType == MLTask.BackendType.H2O, "com.dataiku.dip.spark.MLLibPredictionTrainingJob", recipe.getProjectKey(), outModelFolder.getAbsolutePath());
                            }

                            @Override
                            public SparkOverrideConfig getRecipeOverrideConf() {
                                return ClassicalPredictionTrainingRecipeRunner.this.desc.sparkParams.sparkConf;
                            }

                            @Override
                            public Map<String, String> getContextOverrideConf() {
                                // Pass the ML-specific Spark parameters from the payload to the job as spark.dku.ml.* settings.
                                return CollectionUtils.appendableSSMap()
                                        .put("spark.dku.ml.preparedDF.storageLevel", ClassicalPredictionTrainingRecipeRunner.this.desc.sparkParams.sparkPreparedDFStorageLevel)
                                        .put("spark.dku.ml.repartitionNonHDFS", String.valueOf(ClassicalPredictionTrainingRecipeRunner.this.desc.sparkParams.sparkRepartitionNonHDFS))
                                        .put("spark.dku.ml.useGlobalMetastore", Boolean.toString(ClassicalPredictionTrainingRecipeRunner.this.desc.sparkParams.sparkUseGlobalMetastore))
                                        .put("spark.dku.ml.hiveDb", StringUtils.defaultIfBlank(hiveDb, ""))
                                        .get();
                            }

                            @Override
                            public List<File> getExtraRecursiveFolders() {
                                // Pass the element directly: an (Object[]) cast would make Guava infer
                                // ArrayList<Object>, which is not assignable to List<File>.
                                return Lists.newArrayList(outModelFolder);
                            }

                            @Override
                            public List<String> getWritablePaths() {
                                return Lists.newArrayList(outModelFolder.getAbsolutePath());
                            }
                        }, new SparkJobHelper.SparkJobPostProcessor() {

                            @Override
                            public void postProcess(SparkJobHelper.SparkJobContext context) throws Exception {
                                // When the driver ran remotely, the model comes back as a "trainedModel"
                                // archive (presumably zipped, per ZipUnzipDir) that is unpacked in place.
                                if (context.driverRunsRemotely()) {
                                    ZipUnzipDir.extractFolder(new File(outModelFolder, "trainedModel"), outModelFolder);
                                }
                            }
                        }, null);
                    }

                    @Override
                    public void init() throws Exception {
                        // Nothing to initialise: all work happens in run().
                    }
                };
            }
        }
        throw new NotImplementedException("Unsupported backend type: " + this.desc.backendType);
    }

    /**
     * Builds the resolved core parameters for this training.
     *
     * <p>It deep-copies the payload's core parameters so the payload itself is not mutated. It then
     * adds the environment parameters for the given container selection and the recipe's managed
     * folder smart name (or {@code null}), and attaches the payload's Spark parameters.
     *
     * @param containerSelection container selection forwarded to {@code addEnvParameters}
     * @return a fresh resolved copy of the core parameters
     */
    @Override
    protected ResolvedClassicalPredictionCoreParams resolveCoreParams(ContainerExecSelection containerSelection) {
        ResolvedClassicalPredictionCoreParams res = (ResolvedClassicalPredictionCoreParams) JSON.deepCopy((Object) this.desc.core);
        this.addEnvParameters(this.desc, res, containerSelection);
        res.managedFolderSmartId = PredictionRecipesUtils.getManagedFolderSmartNameOrNull(this.recipe);
        res.executionParams.sparkParams = this.desc.sparkParams;
        return res;
    }
}

