/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.guess.ClassicalPredictionGuesser;
import com.dataiku.dip.analysis.ml.shared.FeatureGuessUtils;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.preprocessing.CatFeaturePreprocessingParams;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemTable;

public class InterpretableGuesser
extends ClassicalPredictionGuesser {
    private static final int MAX_TO_DROP = 20;
    private static final int MAX_TO_DUMMIFY = 5;

    public InterpretableGuesser(PredictionMLTask.ClassicalPredictionMLTask task, MemTable table) {
        super(task, table);
    }

    @Override
    public CatFeaturePreprocessingParams guessCategorical(MemColumn column) {
        return new FeatureGuessUtils.SparsityLimitGuesser(5, 20).guess(this.table, column, this.task);
    }

    private PredictionModelingParams initAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params;
        if (keepExistingParams) {
            params = task.modeling;
            block0 : switch (task.backendType) {
                case PY_MEMORY: {
                    switch (task.predictionType) {
                        case REGRESSION: {
                            params.decision_tree_regression = new PredictionModelingParams.DecisionTreeHyperparametersSpace();
                            params.ridge_regression = new PredictionModelingParams.RidgeRegressionHyperparametersSpace();
                            break;
                        }
                        case MULTICLASS: 
                        case BINARY_CLASSIFICATION: {
                            params.decision_tree_classification = new PredictionModelingParams.DecisionTreeHyperparametersSpace();
                            params.logistic_regression = new PredictionModelingParams.LogisticRegressionHyperparametersSpace();
                            break;
                        }
                        default: {
                            throw new IllegalArgumentException("Unsupported prediction type:" + String.valueOf((Object)task.predictionType));
                        }
                    }
                    params.lars_params = new PredictionModelingParams.LarsHyperparametersSpace();
                    break;
                }
                case KERAS: {
                    break;
                }
                case MLLIB: {
                    params.mllib_dt = new PredictionModelingParams.MLLibDecisionTreeGridParams();
                    switch (task.predictionType) {
                        case REGRESSION: {
                            params.mllib_linreg.enabled = true;
                            params.mllib_linreg.enet_param.updateValues(0.1);
                            params.mllib_linreg.reg_param.updateValues(0.1);
                            break block0;
                        }
                        case MULTICLASS: 
                        case BINARY_CLASSIFICATION: {
                            params.mllib_logit.enabled = true;
                            params.mllib_logit.enet_param.updateValues(0.1);
                            params.mllib_logit.reg_param.updateValues(0.1);
                            break block0;
                        }
                    }
                    throw new IllegalArgumentException("Unsupported prediction type:" + String.valueOf((Object)task.predictionType));
                }
                case H2O: {
                    params.glm_sparkling = new PredictionModelingParams.H2OGLMGridParams();
                }
            }
        } else {
            params = new PredictionModelingParams(task.predictionType, task.modeling);
        }
        return params;
    }

    @Override
    public PredictionModelingParams guessAlgorithms(MemTable table, PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initAlgorithmsParams(task, keepExistingParams);
        block0 : switch (task.backendType) {
            case PY_MEMORY: {
                switch (task.predictionType) {
                    case REGRESSION: {
                        params.decision_tree_regression.enabled = true;
                        params.ridge_regression.enabled = true;
                        break;
                    }
                    case MULTICLASS: 
                    case BINARY_CLASSIFICATION: {
                        params.decision_tree_classification.enabled = true;
                        params.logistic_regression.enabled = true;
                        params.logistic_regression.penalty.withValue("l1", true).withValue("l2", false);
                        break;
                    }
                    default: {
                        throw new IllegalArgumentException("Unsupported prediction type: " + String.valueOf((Object)task.predictionType));
                    }
                }
                params.lars_params.K = 100;
                break;
            }
            case KERAS: {
                break;
            }
            case MLLIB: {
                params.mllib_dt.enabled = true;
                switch (task.predictionType) {
                    case REGRESSION: {
                        params.mllib_linreg.enabled = true;
                        params.mllib_linreg.enet_param.updateValues(0.1);
                        params.mllib_linreg.reg_param.updateValues(0.1);
                        break block0;
                    }
                    case MULTICLASS: 
                    case BINARY_CLASSIFICATION: {
                        params.mllib_logit.enabled = true;
                        params.mllib_logit.enet_param.updateValues(0.1);
                        params.mllib_logit.reg_param.updateValues(0.1);
                        break block0;
                    }
                }
                throw new IllegalArgumentException("Unsupported prediction type: " + String.valueOf((Object)task.predictionType));
            }
            case H2O: {
                params.glm_sparkling.enabled = true;
                switch (task.predictionType) {
                    case REGRESSION: {
                        params.glm_sparkling.family = "gaussian";
                        break block0;
                    }
                    case BINARY_CLASSIFICATION: {
                        params.glm_sparkling.family = "binomial";
                        break block0;
                    }
                    case MULTICLASS: {
                        params.glm_sparkling.family = "multinomial";
                        break block0;
                    }
                }
                throw new IllegalArgumentException("Unsupported prediction type: " + String.valueOf((Object)task.predictionType));
            }
        }
        return params;
    }
}

