/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.guess.ClassicalPredictionGuesser;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import java.util.ArrayList;
import org.apache.log4j.Logger;

public class CustomPredictionGuesser
extends ClassicalPredictionGuesser {
    private static final Logger logger = Logger.getLogger((String)"dku.analysis.guess");

    public CustomPredictionGuesser(PredictionMLTask.ClassicalPredictionMLTask task, MemTable table) {
        super(task, table);
    }

    protected PredictionModelingParams.CustomPythonParams buildDefaultCustomPythonParams() {
        PredictionModelingParams.CustomPythonParams cpp = new PredictionModelingParams.CustomPythonParams();
        cpp.enabled = true;
        switch (((PredictionMLTask.ClassicalPredictionMLTask)this.task).predictionType) {
            case REGRESSION: {
                cpp.code = "# This sample code uses a standard scikit-learn algorithm, the Adaboost regressor.\n\n# Your code must create a 'clf' variable. This clf must be a scikit-learn compatible\n# model, ie, it should:\n#  1. have at least fit(X,y) and predict(X) methods\n#  2. inherit sklearn.base.BaseEstimator\n#  3. handle the attributes in the __init__ function\n#     See: https://doc.dataiku.com/dss/latest/machine-learning/custom-models.html\n\nfrom sklearn.ensemble import AdaBoostRegressor\n\nclf = AdaBoostRegressor(n_estimators=20)\n";
                break;
            }
            case BINARY_CLASSIFICATION: 
            case MULTICLASS: {
                cpp.code = "# This sample code uses a standard scikit-learn algorithm, the Adaboost classifier.\n\n# Your code must create a 'clf' variable. This clf must be a scikit-learn compatible\n# classifier, ie, it should:\n#  1. have at least fit(X,y) and predict(X) methods\n#  2. inherit sklearn.base.BaseEstimator\n#  3. handle the attributes in the __init__ function\n#  4. have a classes_ attribute\n#  5. have a predict_proba method (optional)\n#     See: https://doc.dataiku.com/dss/latest/machine-learning/custom-models.html\n\nfrom sklearn.ensemble import AdaBoostClassifier\n\nclf = AdaBoostClassifier(n_estimators=20)\n";
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported prediction type:" + String.valueOf((Object)((PredictionMLTask.ClassicalPredictionMLTask)this.task).predictionType));
            }
        }
        return cpp;
    }

    private PredictionModelingParams guessPythonAlgorithms(PredictionMLTask.ClassicalPredictionMLTask task, Boolean keepExistingParameters) {
        PredictionModelingParams params;
        PredictionModelingParams predictionModelingParams = params = keepExistingParameters != false ? task.modeling : new PredictionModelingParams(task.predictionType, task.modeling);
        if (keepExistingParameters.booleanValue() && !params.custom_python.isEmpty()) {
            for (PredictionModelingParams.CustomPythonParams customPythonParams : params.custom_python) {
                customPythonParams.enabled = true;
            }
        } else {
            params.custom_python = new ArrayList<PredictionModelingParams.CustomPythonParams>();
            params.custom_python.add(this.buildDefaultCustomPythonParams());
        }
        return params;
    }

    private PredictionModelingParams.MLLibCustomGridParams buildDefaultMLLibCustomGridParams() {
        PredictionModelingParams.MLLibCustomGridParams mlcgp = new PredictionModelingParams.MLLibCustomGridParams();
        mlcgp.initializationCode = "Custom MLlib model";
        mlcgp.enabled = true;
        switch (((PredictionMLTask.ClassicalPredictionMLTask)this.task).predictionType) {
            case REGRESSION: {
                mlcgp.initializationCode = "// This sample code uses a standard MLlib algorithm, the RandomForestRegressor.\n\n// import the Estimator from spark.ml\nimport org.apache.spark.ml.regression.RandomForestRegressor\n\n// instantiate the Estimator\nnew RandomForestRegressor()\n   .setLabelCol(\"" + ((PredictionMLTask.ClassicalPredictionMLTask)this.task).targetVariable + "\")  // Must be the target column\n   .setFeaturesCol(\"__dku_features\")  // Must always be __dku_features\n   .setPredictionCol(\"prediction\")  // Must always be prediction\n   .setNumTrees(50)\n   .setMaxDepth(8)\n";
                break;
            }
            case BINARY_CLASSIFICATION: 
            case MULTICLASS: {
                mlcgp.initializationCode = "// This sample code uses a standard MLlib algorithm, the RandomForestClassifier.\n\n// import the Estimator from spark.ml\nimport org.apache.spark.ml.classification.RandomForestClassifier\n\n// instantiate the Estimator\nnew RandomForestClassifier()\n   .setLabelCol(\"" + ((PredictionMLTask.ClassicalPredictionMLTask)this.task).targetVariable + "\")  // Must be the target column\n   .setFeaturesCol(\"__dku_features\")  // Must always be __dku_features\n   .setPredictionCol(\"prediction\")    // Must always be prediction\n   .setNumTrees(50)\n   .setMaxDepth(8)\n";
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported prediction type:" + String.valueOf((Object)((PredictionMLTask.ClassicalPredictionMLTask)this.task).predictionType));
            }
        }
        return mlcgp;
    }

    private PredictionModelingParams guessMLLibAlgorithms(PredictionMLTask.ClassicalPredictionMLTask task, Boolean keepExistingParameters) {
        PredictionModelingParams params;
        PredictionModelingParams predictionModelingParams = params = keepExistingParameters != false ? task.modeling : new PredictionModelingParams(task.predictionType, task.modeling);
        if (keepExistingParameters.booleanValue() && !params.custom_mllib.isEmpty()) {
            for (PredictionModelingParams.MLLibCustomGridParams customMLLib : params.custom_mllib) {
                customMLLib.enabled = true;
            }
        } else {
            params.custom_mllib = new ArrayList<PredictionModelingParams.MLLibCustomGridParams>();
            params.custom_mllib.add(this.buildDefaultMLLibCustomGridParams());
        }
        return params;
    }

    @Override
    public PredictionModelingParams guessAlgorithms(MemTable table, PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        switch (task.backendType) {
            case MLLIB: {
                return this.guessMLLibAlgorithms(task, keepExistingParams);
            }
            case PY_MEMORY: {
                return this.guessPythonAlgorithms(task, keepExistingParams);
            }
        }
        throw new IllegalArgumentException(String.valueOf((Object)task.backendType) + " not supported.");
    }
}

