/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml;

import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.SavedModelCodes;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionScoringRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.clustering.ClusteringMLTask;
import com.dataiku.dip.analysis.model.clustering.ClustersFacts;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.CustomMetricResult;
import com.dataiku.dip.analysis.model.core.PostTrainModelingParams;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.core.ResolvedPreprocessingParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubPreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.datasets.DatasetInspector;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.hive.HiveSchemaHandler;
import com.dataiku.dip.shaker.model.ScriptStep;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.io.IOException;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;

/**
 * Static helpers around ML tasks: settings validation before saving/training,
 * Hive database lookup for the task's input datasets, and dumping of the
 * task parameters into a session folder.
 */
public class DKUMLUtils {
    private static final Logger logger = Logger.getLogger("dku.ml.utils");

    /**
     * Calls {@code DKUtils.forceInit} on each of the ML-related classes listed below.
     * NOTE(review): presumably this makes their static initializers (e.g. JSON adapter
     * registration) run eagerly - confirm against {@code DKUtils.forceInit}.
     */
    public static void loadClasses() {
        DKUtils.forceInit(ScriptStep.class);
        DKUtils.forceInit(FeaturePreprocessingParams.class);
        DKUtils.forceInit(AbstractPredictionTrainingRecipePayloadParams.class);
        DKUtils.forceInit(CustomMetricResult.class);
        DKUtils.forceInit(AbstractPredictionScoringRecipePayloadParams.class);
        DKUtils.forceInit(DeepHubPreTrainModelingParams.class);
        DKUtils.forceInit(MLTask.class);
        DKUtils.forceInit(PreTrainModelingParams.JsonAdapterInit.class);
        DKUtils.forceInit(PostTrainModelingParams.JsonAdapterInit.class);
        DKUtils.forceInit(ClustersFacts.ClusterFact.class);
        DKUtils.forceInit(ResolvedCoreParams.class);
        DKUtils.forceInit(ResolvedPreprocessingParams.class);
    }

    /**
     * Runs the training-only checks, then every check that is also applied before saving.
     *
     * @param task the prediction task to validate
     * @throws Exception if any check fails
     */
    public static void checkPredictionTaskBeforeTraining(PredictionMLTask task) throws Exception {
        DKUMLUtils.checkOnlyBeforeTraining(task);
        DKUMLUtils.checkPredictionTaskBeforeSaving(task);
    }

    /** Checks that only make sense right before training: validates the GPU config against the backend. */
    private static void checkOnlyBeforeTraining(PredictionMLTask task) throws IllegalArgumentException {
        task.gpuConfig.validate(task.backendType);
    }

    /**
     * Validates the settings shared by all prediction tasks (preprocessing present, split
     * params present and coherent), then dispatches to the check of the concrete task type.
     *
     * @param task the prediction task to validate
     * @throws Exception if a setting is missing/invalid or the task type is not supported
     */
    public static void checkPredictionTaskBeforeSaving(PredictionMLTask task) throws Exception {
        ErrorContext.checkNotNull(task.getPreprocessingParams(), "No preprocessing settings");
        if (task.splitParams == null) {
            throw ErrorContext.iae("No split params");
        }
        if (task.splitParams.ttPolicy == SplitParams.TrainTestPolicy.EXPLICIT_FILTERING_TWO_DATASETS
                && task.splitParams.eftdTest == null) {
            throw ErrorContext.iae("Test set is not specified, please change the settings.");
        }
        // Class rebalancing samples on a column: that column must be a known feature.
        if (task.splitParams.ssdSelection.samplingMethod.toString().startsWith("CLASS_REBALANCE")) {
            String rebalanceColumn = task.splitParams.ssdSelection.column;
            if (!task.getPreprocessingParams().per_feature.containsKey(rebalanceColumn)) {
                throw ErrorContext.iae("The column chosen for class rebalancing does not exist !");
            }
        }
        if (task instanceof PredictionMLTask.ClassicalPredictionMLTask) {
            DKUMLUtils.checkPredictionTask((PredictionMLTask.ClassicalPredictionMLTask) task);
        } else if (task instanceof PredictionMLTask.DeepHubPredictionMLTask) {
            DKUMLUtils.checkPredictionTask((PredictionMLTask.DeepHubPredictionMLTask) task);
        } else if (task instanceof PredictionMLTask.CausalPredictionMLTask) {
            DKUMLUtils.checkPredictionTask((PredictionMLTask.CausalPredictionMLTask) task);
        } else if (task instanceof PredictionMLTask.TimeseriesForecastingMLTask) {
            DKUMLUtils.checkPredictionTask((PredictionMLTask.TimeseriesForecastingMLTask) task);
        } else {
            throw ErrorContext.iae("Unsupported PredictionMLTask: " + task.getClass().getSimpleName());
        }
    }

    /**
     * Checks a DeepHub task: modeling settings must be present and the backend must be DEEP_HUB.
     *
     * @throws Exception if the check fails
     */
    public static void checkPredictionTask(PredictionMLTask.DeepHubPredictionMLTask task) throws Exception {
        ErrorContext.checkNotNull(task.modeling, "No modeling settings");
        if (task.backendType != MLTask.BackendType.DEEP_HUB) {
            throw ErrorContext.iae("Unsupported backend for DeepHub: " + task.backendType);
        }
    }

    /**
     * Checks a causal prediction task: per-feature preprocessing must be valid, modeling
     * settings must be present and the backend must be PY_MEMORY.
     *
     * @throws Exception if the check fails
     */
    public static void checkPredictionTask(PredictionMLTask.CausalPredictionMLTask task) throws Exception {
        // This overload is public and can be called directly: fail with a clear message
        // rather than a bare NullPointerException, as checkClusteringTask does.
        if (task.preprocessing == null) {
            throw ErrorContext.iae("No preprocessing settings");
        }
        DKUMLUtils.checkPerFeaturePreprocessing(task.preprocessing.per_feature, task);
        ErrorContext.checkNotNull(task.modeling, "No modeling settings");
        if (task.backendType != MLTask.BackendType.PY_MEMORY) {
            throw ErrorContext.iae("Unsupported backend for Causal prediction: " + task.backendType);
        }
    }

    /**
     * Checks a classical prediction task: per-feature preprocessing must be valid, modeling
     * settings must be present and the backend must not be VERTICA.
     *
     * @throws CodedException with {@code ERR_ML_VERTICA_NOT_SUPPORTED} if the backend is VERTICA
     * @throws Exception if any other check fails
     */
    public static void checkPredictionTask(PredictionMLTask.ClassicalPredictionMLTask task) throws Exception {
        // Same guard as checkClusteringTask, since this overload is public.
        if (task.preprocessing == null) {
            throw ErrorContext.iae("No preprocessing settings");
        }
        DKUMLUtils.checkPerFeaturePreprocessing(task.preprocessing.per_feature, task);
        ErrorContext.checkNotNull(task.modeling, "No modeling settings");
        if (task.backendType == MLTask.BackendType.VERTICA) {
            throw new CodedException(SavedModelCodes.ERR_ML_VERTICA_NOT_SUPPORTED,
                    "ML Task \"" + task.name + "\" (id=" + task.id + ") check failed");
        }
    }

    /** Checks a timeseries forecasting task: modeling settings present and backend PY_MEMORY. */
    private static void checkPredictionTask(PredictionMLTask.TimeseriesForecastingMLTask task) throws Exception {
        ErrorContext.checkNotNull(task.modeling, "No modeling settings");
        if (task.backendType != MLTask.BackendType.PY_MEMORY) {
            throw ErrorContext.iae("Unsupported backend for timeseries forecasting: " + task.backendType);
        }
    }

    /**
     * Checks a clustering task: preprocessing must be present and valid per feature, and the
     * backend must not be VERTICA.
     *
     * @throws CodedException with {@code ERR_ML_VERTICA_NOT_SUPPORTED} if the backend is VERTICA
     * @throws Exception if any other check fails
     */
    public static void checkClusteringTask(ClusteringMLTask task) throws Exception {
        if (task.preprocessing == null) {
            throw ErrorContext.iae("No preprocessing settings");
        }
        DKUMLUtils.checkPerFeaturePreprocessing(task.preprocessing.per_feature, task);
        if (task.backendType == MLTask.BackendType.VERTICA) {
            throw new CodedException(SavedModelCodes.ERR_ML_VERTICA_NOT_SUPPORTED,
                    "ML Task \"" + task.name + "\" (id=" + task.id + ") check failed");
        }
    }

    /**
     * Runs {@code check} on every feature's preprocessing params and requires that at least
     * one feature has an input role. On failure the feature settings are dumped to the log
     * (best effort) and the original exception is rethrown unchanged.
     */
    private static void checkPerFeaturePreprocessing(Map<String, FeaturePreprocessingParams> fpps, MLTask task) throws Exception {
        try {
            boolean hasInputs = false;
            for (Map.Entry<String, FeaturePreprocessingParams> entry : fpps.entrySet()) {
                FeaturePreprocessingParams params = entry.getValue();
                if (params.role.isInput()) {
                    hasInputs = true;
                }
                params.check(entry.getKey(), task);
            }
            if (!hasInputs) {
                throw ErrorContext.iae("No input feature is selected, please edit the settings");
            }
        } catch (Exception e) {
            logger.error("Invalid ML task (feature settings)");
            try {
                logger.info(JSON.pretty(fpps));
            } catch (Throwable t) {
                // Dumping the settings is best effort only: never let it mask the real
                // failure, but keep the cause in the log.
                logger.error("Failed to print feature settings", t);
            }
            throw e;
        }
    }

    /**
     * Resolves the Hive database from the dataset(s) referenced by the given split params,
     * depending on the train/test policy.
     *
     * @return the Hive database found by the varargs overload, or {@code null} if none
     * @throws IOException declared for the underlying dataset lookup
     */
    public static String getHiveDb(AnalysisCoreParams acp, MLTask task, DatasetsDAO datasetsDAO, SplitParams params) throws IOException {
        switch (params.ttPolicy) {
            case EXPLICIT_FILTERING_SINGLE_DATASET: {
                return DKUMLUtils.getHiveDb(acp, task, datasetsDAO, params.efsdDatasetSmartName);
            }
            case EXPLICIT_FILTERING_TWO_DATASETS: {
                return DKUMLUtils.getHiveDb(acp, task, datasetsDAO, params.eftdTrain.datasetSmartName, params.eftdTest.datasetSmartName);
            }
            case SPLIT_SINGLE_DATASET: {
                return DKUMLUtils.getHiveDb(acp, task, datasetsDAO, params.ssdDatasetSmartName);
            }
        }
        // Throwable type kept as-is so the exception contract seen by callers does not change.
        throw new Error("Unreachable (unknown split policy)");
    }

    /** Resolves the Hive database from the dataset referenced by the clustering sampling params. */
    public static String getHiveDb(AnalysisCoreParams acp, MLTask task, DatasetsDAO datasetsDAO, ClusteringMLTask.ClusterSampling params) throws IOException {
        return DKUMLUtils.getHiveDb(acp, task, datasetsDAO, params.datasetSmartName);
    }

    /**
     * Returns the Hive database of the first dataset (in argument order) that is
     * Hive-capable and yields a non-blank database name. Datasets are only inspected when
     * {@code task.sparkParams.sparkUseGlobalMetastore} is set.
     *
     * @param smartNames dataset smart names, resolved against {@code acp.projectKey}
     * @return the Hive database name, or {@code null} if the global metastore is not used
     *         or no dataset provides one
     * @throws IOException declared for the underlying dataset lookup
     */
    public static String getHiveDb(AnalysisCoreParams acp, MLTask task, DatasetsDAO datasetsDAO, String ... smartNames) throws IOException {
        String hiveDb = null;
        if (task.sparkParams.sparkUseGlobalMetastore) {
            for (String smartName : smartNames) {
                SerializedDataset sd = (SerializedDataset) datasetsDAO.getMandatory(AnyLoc.resolveSmart(acp.projectKey, smartName));
                if (DatasetInspector.canHive(sd)) {
                    hiveDb = HiveSchemaHandler.getResolvedHiveDatabaseFromDataset(Dataset.fromSerialized(sd));
                }
                if (StringUtils.isNotBlank(hiveDb)) {
                    break;
                }
            }
        }
        return hiveDb;
    }

    /**
     * Builds the resolved core params of the task for {@code acp.projectKey} and dumps
     * everything into the session folder (see the five-argument overload).
     */
    public static void dumpParamsOnDisk(AnalysisCoreParams acp, MLTaskLoc taskLoc, MLTask task, String sessionId) throws IOException {
        ResolvedCoreParams rccp = task.buildResolvedCoreParams(acp.projectKey);
        DKUMLUtils.dumpParamsOnDisk(acp, taskLoc, task, sessionId, rccp);
    }

    /**
     * Writes {@code mltask.json} (the task), {@code core_params.json} (the resolved core
     * params) and {@code script.json} (the analysis script) into the session folder
     * returned by {@code MLPaths.sessionFolder(taskLoc, sessionId)}.
     *
     * @throws IOException if writing one of the files fails
     */
    public static void dumpParamsOnDisk(AnalysisCoreParams acp, MLTaskLoc taskLoc, MLTask task, String sessionId, ResolvedCoreParams rccp) throws IOException {
        File sessionFolder = MLPaths.sessionFolder(taskLoc, sessionId);
        JSON.prettyToFile(task, new File(sessionFolder, "mltask.json"));
        JSON.prettyToFile(rccp, new File(sessionFolder, "core_params.json"));
        JSON.prettyToFile(acp.script, new File(sessionFolder, "script.json"));
    }
}

