/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.guess.ClassicalPredictionGuesser;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import org.apache.log4j.Logger;

/**
 * Guesses default algorithm settings for a classical prediction ML task, dispatching on the
 * task's backend type (MLlib, H2O, in-memory Python, Keras) and its prediction type
 * (regression, binary classification, multiclass).
 */
public class DefaultPredictionGuesser extends ClassicalPredictionGuesser {
    // NOTE(review): not referenced by any method in this class.
    private static final Logger logger = Logger.getLogger("dku.analysis.guess");

    /** Row count above which the in-memory Python heuristics treat the table as large. */
    private static final int LARGE_DATASET_ROWS = 100000;

    /** Python SVM models are enabled only for tables with more columns than this ... */
    private static final int SVM_MIN_COLUMNS = 100;

    /** ... and fewer rows than this. */
    private static final int SVM_MAX_ROWS = 10000;

    /**
     * Creates a guesser for the given task and table; both are handed to
     * {@link ClassicalPredictionGuesser}.
     */
    public DefaultPredictionGuesser(PredictionMLTask.ClassicalPredictionMLTask task, MemTable table) {
        super(task, table);
    }

    /**
     * Builds the exception thrown when a guesser does not handle the task's prediction type.
     * Centralized so every branch reports the same message.
     */
    private static IllegalArgumentException unsupportedPredictionType(PredictionMLTask.ClassicalPredictionMLTask task) {
        return new IllegalArgumentException("Unsupported prediction type: " + task.predictionType);
    }

    /**
     * Prepares the params object that MLlib guessing will fill in.
     *
     * <p>When {@code keepExistingParams} is false, a new params object is built from the task's
     * prediction type and current modeling params. Otherwise the task's own params object is
     * reused and only the MLlib sections relevant to the prediction type are replaced with fresh
     * instances.
     *
     * @throws IllegalArgumentException if {@code keepExistingParams} is true and the prediction
     *     type is not regression, binary classification or multiclass
     */
    private PredictionModelingParams initMLLibAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            return new PredictionModelingParams(task.predictionType, task.modeling);
        }
        PredictionModelingParams params = task.modeling;
        switch (task.predictionType) {
            case REGRESSION:
                params.mllib_linreg = new PredictionModelingParams.MLLibLinearRegressionGridParams();
                params.mllib_rf = new PredictionModelingParams.MLLibTreesEnsembleGridParams();
                params.mllib_gbt = new PredictionModelingParams.MLLibTreesEnsembleGridParams();
                break;
            case BINARY_CLASSIFICATION:
            case MULTICLASS:
                params.mllib_logit = new PredictionModelingParams.MLLibLogisticRegressionGridParams();
                break;
            default:
                throw unsupportedPredictionType(task);
        }
        return params;
    }

    /**
     * Guesses MLlib settings.
     *
     * <p>Regression enables linear regression and sets the random forest and GBT impurity to
     * variance. Classification enables logistic regression.
     *
     * @throws IllegalArgumentException for an unsupported prediction type
     */
    private PredictionModelingParams guessMLLibAlgorithms(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initMLLibAlgorithmsParams(task, keepExistingParams);
        switch (task.predictionType) {
            case REGRESSION:
                params.mllib_linreg.enabled = true;
                params.mllib_rf.impurity = PredictionModelingParams.MLLibRfImpurity.variance;
                params.mllib_gbt.impurity = PredictionModelingParams.MLLibRfImpurity.variance;
                break;
            case BINARY_CLASSIFICATION:
            case MULTICLASS:
                params.mllib_logit.enabled = true;
                break;
            default:
                throw unsupportedPredictionType(task);
        }
        return params;
    }

    /**
     * Prepares the params object for Keras guessing.
     *
     * <p>It either reuses the task's params with a fresh Keras section, or builds a new params
     * object from the prediction type and current modeling params.
     */
    private PredictionModelingParams initKerasAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            return new PredictionModelingParams(task.predictionType, task.modeling);
        }
        PredictionModelingParams params = task.modeling;
        params.keras = new PredictionModelingParams.KerasCodeParams();
        return params;
    }

    /** Guesses Keras settings: the Keras code model is simply enabled. */
    private PredictionModelingParams guessKerasAlgorithms(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initKerasAlgorithmsParams(task, keepExistingParams);
        params.keras.enabled = true;
        return params;
    }

    /**
     * Prepares the params object for H2O (Sparkling Water) guessing.
     *
     * <p>It either reuses the task's params with fresh deep-learning, GLM and GBM sections, or
     * builds a new params object from the prediction type and current modeling params.
     */
    private PredictionModelingParams initH2OAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            return new PredictionModelingParams(task.predictionType, task.modeling);
        }
        PredictionModelingParams params = task.modeling;
        params.deep_learning_sparkling = new PredictionModelingParams.H2ODeepLearningGridParams();
        params.glm_sparkling = new PredictionModelingParams.H2OGLMGridParams();
        params.gbm_sparkling = new PredictionModelingParams.H2OGBMGridParams();
        return params;
    }

    /**
     * Guesses H2O settings.
     *
     * <p>Deep learning is always enabled. The GLM and GBM distribution families are then chosen
     * from the prediction type.
     *
     * @throws IllegalArgumentException for an unsupported prediction type; deep learning has
     *     already been enabled on the params by then
     */
    private PredictionModelingParams guessH2OAlgorithms(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initH2OAlgorithmsParams(task, keepExistingParams);
        params.deep_learning_sparkling.enabled = true;
        switch (task.predictionType) {
            case REGRESSION:
                params.glm_sparkling.family = "gaussian";
                params.gbm_sparkling.family = "gaussian";
                break;
            case BINARY_CLASSIFICATION:
                params.glm_sparkling.family = "binomial";
                params.gbm_sparkling.family = "bernoulli";
                break;
            case MULTICLASS:
                params.glm_sparkling.family = "multinomial";
                params.gbm_sparkling.family = "multinomial";
                break;
            default:
                throw unsupportedPredictionType(task);
        }
        return params;
    }

    /**
     * Prepares the params object that in-memory Python guessing will fill in.
     *
     * <p>When {@code keepExistingParams} is false, a new params object is built from the task's
     * prediction type and current modeling params. Otherwise the task's own params object is
     * reused, and the Python model sections used for the prediction type are replaced with fresh
     * instances.
     *
     * @throws IllegalArgumentException if {@code keepExistingParams} is true and the prediction
     *     type is unsupported
     */
    private PredictionModelingParams initPythonAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            return new PredictionModelingParams(task.predictionType, task.modeling);
        }
        PredictionModelingParams params = task.modeling;
        switch (task.predictionType) {
            case REGRESSION:
                params.random_forest_regression = new PredictionModelingParams.RandomForestHyperparametersSpace();
                params.ridge_regression = new PredictionModelingParams.RidgeRegressionHyperparametersSpace();
                params.lasso_regression = new PredictionModelingParams.LassoHyperparametersSpace();
                params.svm_regression = new PredictionModelingParams.SVMHyperparametersSpace();
                break;
            case BINARY_CLASSIFICATION:
            case MULTICLASS:
                params.random_forest_classification = new PredictionModelingParams.RandomForestHyperparametersSpace();
                // NOTE(review): the regression RF section is also reset for classification tasks;
                // kept as-is to preserve existing behavior, but it looks unnecessary - confirm.
                params.random_forest_regression = new PredictionModelingParams.RandomForestHyperparametersSpace();
                params.logistic_regression = new PredictionModelingParams.LogisticRegressionHyperparametersSpace();
                params.svc_classifier = new PredictionModelingParams.SVMHyperparametersSpace();
                break;
            default:
                throw unsupportedPredictionType(task);
        }
        return params;
    }

    /**
     * Guesses in-memory Python (scikit-learn style) settings from the table's dimensions.
     *
     * <ul>
     *   <li>Random forest is always enabled. Its max depth grid is {6, 10 + sqrt(ncols)}. It uses
     *       2 jobs when the table has more than {@value #LARGE_DATASET_ROWS} rows, otherwise 1.</li>
     *   <li>SVM models are enabled only for wide, small tables: more than
     *       {@value #SVM_MIN_COLUMNS} columns and fewer than {@value #SVM_MAX_ROWS} rows.</li>
     * </ul>
     *
     * @param table table whose row and column counts drive the heuristics
     * @throws IllegalArgumentException for an unsupported prediction type
     */
    private PredictionModelingParams guessPythonAlgorithms(MemTable table, PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initPythonAlgorithmsParams(task, keepExistingParams);
        int ncols = table.ncols();
        int nrows = table.nrows();
        boolean largeDataset = nrows > LARGE_DATASET_ROWS;
        int rfNJobs = largeDataset ? 2 : 1;
        long maxTreeDepthUpper = 10L + (long) Math.sqrt(ncols);
        boolean enableSvm = ncols > SVM_MIN_COLUMNS && nrows < SVM_MAX_ROWS;
        switch (task.predictionType) {
            case REGRESSION:
                params.random_forest_regression.enabled = true;
                params.random_forest_regression.max_tree_depth.updateValues(6L, maxTreeDepthUpper);
                params.random_forest_regression.min_samples_leaf.updateValues(Math.max(1L, (long) Math.sqrt(nrows) / 10L));
                params.random_forest_regression.n_jobs = rfNJobs;
                params.ridge_regression.enabled = true;
                params.lasso_regression.enabled = false;
                params.svm_regression.enabled = enableSvm;
                // NOTE(review): extra_trees is not reset by initPythonAlgorithmsParams; this assumes
                // it is already non-null on reused params - confirm.
                params.extra_trees.selection_mode = PredictionModelingParams.TreeSelectionMode.PROP;
                params.extra_trees.max_feature_prop.setToSingleValueGrid(1.0);
                break;
            case BINARY_CLASSIFICATION:
            case MULTICLASS:
                params.random_forest_classification.enabled = true;
                params.random_forest_classification.max_tree_depth.updateValues(6L, maxTreeDepthUpper);
                params.random_forest_classification.min_samples_leaf.updateValues(largeDataset ? 10L : 1L);
                // Previously written to random_forest_regression by mistake; the classification
                // forest is the one enabled above, so it must receive the job count.
                params.random_forest_classification.n_jobs = rfNJobs;
                params.logistic_regression.multi_class = task.predictionType == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION
                        ? PredictionModelingParams.LogisticRegressionClassifierMultiClass.OVR
                        : PredictionModelingParams.LogisticRegressionClassifierMultiClass.MULTINOMIAL;
                params.logistic_regression.enabled = true;
                params.svc_classifier.enabled = enableSvm;
                break;
            default:
                throw unsupportedPredictionType(task);
        }
        return params;
    }

    /**
     * Guesses algorithm settings for {@code task}, dispatching on its backend type.
     *
     * @param table the data table; only read by the in-memory Python backend
     * @param task the prediction task whose backend and prediction type select the heuristics
     * @param keepExistingParams if true, the task's existing modeling params object is updated in
     *     place (with the relevant algorithm sections reset); otherwise a new params object is built
     * @return the guessed modeling params
     * @throws IllegalArgumentException if the backend type or the prediction type is not supported
     */
    @Override
    public PredictionModelingParams guessAlgorithms(MemTable table, PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        switch (task.backendType) {
            case MLLIB:
                return this.guessMLLibAlgorithms(task, keepExistingParams);
            case H2O:
                return this.guessH2OAlgorithms(task, keepExistingParams);
            case PY_MEMORY:
                return this.guessPythonAlgorithms(table, task, keepExistingParams);
            case KERAS:
                return this.guessKerasAlgorithms(task, keepExistingParams);
            default:
                break;
        }
        throw new IllegalArgumentException(task.backendType + " not supported.");
    }
}

