/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for "Gradient Boosted Trees".
 *
 * <p>A single instance describes either the classification or the regression variant,
 * selected by the constructor flag. That flag decides which hyperparameter spaces are
 * read and written:
 * <ul>
 *   <li>classification: {@code gbt_classification} and {@code gbt_classifier_grid};</li>
 *   <li>regression: {@code gbt_regression} and {@code gbt_regressor_grid}.</li>
 * </ul>
 */
public class PyGradientBoostingMeta
extends PyMemoryAlgorithmMeta {
    private static final DKULogger logger = DKULogger.getLogger("dku.ml.prediction.algorithm.gbt");

    /** {@code true} for the classifier variant, {@code false} for the regressor variant. */
    private final boolean isClassification;

    /**
     * @param isClassification {@code true} to describe the classification variant,
     *                         {@code false} to describe the regression variant
     */
    public PyGradientBoostingMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    /** Returns the fixed display name of this algorithm; {@code rpmp} is not consulted. */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Gradient Boosted Trees";
    }

    /**
     * Builds the pre-search description.
     *
     * <p>The description contains the search size and the candidate values for the
     * number of trees, the learning rate and the max depth.
     *
     * @param rpmp pre-train params; the classifier or regressor grid must be set according
     *             to the variant (checked with {@code assert} only)
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.GradientBoostedTreeHyperparametersSpace gbt =
                this.isClassification ? rpmp.gbt_classifier_grid : rpmp.gbt_regressor_grid;
        assert gbt != null;
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, gbt))
                .withMVParam("trees", gbt.n_estimators)
                .withMVParam("learning_rate", gbt.learning_rate)
                .withMVParam("max_depth", gbt.max_depth);
    }

    /**
     * Builds the post-search description.
     *
     * <p>The description reports the single values chosen for the number of trees, the
     * learning rate and the max depth.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("trees", after.gbt.n_estimators)
                .withSVParam("learning_rate", Float.valueOf(after.gbt.learning_rate))
                .withSVParam("max_depth", after.gbt.max_depth);
    }

    /**
     * Validates the user-facing hyperparameter space of this variant.
     *
     * <p>Does nothing if the space is absent or disabled. Otherwise it:
     * <ul>
     *   <li>checks every numerical dimension;</li>
     *   <li>checks {@code max_features} only in {@code NUMBER} selection mode and
     *       {@code max_feature_prop} only in {@code PROP} selection mode;</li>
     *   <li>requires at least one loss function, reported through {@link ErrorContext#check}.</li>
     * </ul>
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.GradientBoostedTreeHyperparametersSpace gbt =
                this.isClassification ? pmp.gbt_classification : pmp.gbt_regression;
        if (gbt == null || !gbt.enabled) {
            return;
        }
        checks.checkNumericalDimension(gbt.n_estimators, "Number of estimators (Gradient Boosted Trees)");
        checks.checkNumericalDimension(gbt.learning_rate, "Learning rate (Gradient Boosted Trees)");
        checks.checkNumericalDimension(gbt.max_depth, "Max depth of trees (Gradient Boosted Trees)");
        checks.checkNumericalDimension(gbt.min_samples_leaf, "Min samples per leaf (Gradient Boosted Trees)");
        if (gbt.selection_mode == PredictionModelingParams.TreeSelectionMode.NUMBER) {
            checks.checkNumericalDimension(gbt.max_features, "Max number of features (Gradient Boosted Trees)");
        }
        if (gbt.selection_mode == PredictionModelingParams.TreeSelectionMode.PROP) {
            checks.checkNumericalDimension(gbt.max_feature_prop, "Max proportion of features (Gradient Boosted Trees)");
        }

        boolean hasLosses;
        if (this.isClassification) {
            hasLosses = gbt.loss.getLength() > 0;
        } else {
            // NOTE(review): loss is read through the regression-specific type here; keep this
            // cast unless GBTRegressionHyperparametersSpace is confirmed not to redeclare `loss`.
            PredictionModelingParams.GBTRegressionHyperparametersSpace gbtr =
                    (PredictionModelingParams.GBTRegressionHyperparametersSpace) gbt;
            hasLosses = gbtr.loss.getLength() > 0;
        }
        ErrorContext.check(hasLosses, "Gradient boosted trees require at least one loss function");
    }

    /**
     * Returns the number of grid points in {@code space}.
     *
     * <p>The count is the product of the lengths of the max depth, learning rate, number
     * of estimators, min samples per leaf and loss dimensions. It is further multiplied by:
     * <ul>
     *   <li>the {@code max_features} length, in {@code NUMBER} selection mode;</li>
     *   <li>the {@code max_feature_prop} length, in {@code PROP} selection mode.</li>
     * </ul>
     *
     * @param space must be a {@code GradientBoostedTreeHyperparametersSpace}
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.GradientBoostedTreeHyperparametersSpace gbt =
                (PredictionModelingParams.GradientBoostedTreeHyperparametersSpace) space;
        int baseNb = gbt.max_depth.getLength()
                * gbt.learning_rate.getLength()
                * gbt.n_estimators.getLength()
                * gbt.min_samples_leaf.getLength()
                * gbt.loss.getLength();
        if (gbt.selection_mode == PredictionModelingParams.TreeSelectionMode.NUMBER) {
            baseNb *= gbt.max_features.getLength();
        }
        if (gbt.selection_mode == PredictionModelingParams.TreeSelectionMode.PROP) {
            baseNb *= gbt.max_feature_prop.getLength();
        }
        return baseNb;
    }

    /**
     * Expands the modeling params of this variant into modeling sets.
     *
     * <p>Returns an empty list if the space is absent or disabled, and a single modeling
     * set otherwise.
     *
     * <p>For multiclass tasks with the exponential loss enabled, the space is deep-copied.
     * On the copy, deviance is enabled, exponential is disabled, and a warning is logged.
     * The caller's params are therefore not modified.
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.GradientBoostedTreeHyperparametersSpace gbt =
                this.isClassification ? pmp.gbt_classification : pmp.gbt_regression;
        if (gbt == null || !gbt.enabled) {
            return new ArrayList<>(0);
        }
        // Guard the lookup: Map.get returns null when the loss map has no "exponential" entry.
        if (task.predictionType == PredictionMLTask.PredictionType.MULTICLASS
                && gbt.loss.values.get("exponential") != null
                && gbt.loss.values.get("exponential").enabled) {
            gbt = (PredictionModelingParams.GradientBoostedTreeHyperparametersSpace) JSON.deepCopy(gbt);
            gbt.loss.withValue("deviance", true).withValue("exponential", false);
            logger.warn("Disabled exponential loss as it is not available on multiclass gradient boosting classification");
        }

        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(
                this.isClassification
                        ? PreTrainPredictionModelingParams.Algorithm.GBT_CLASSIFICATION
                        : PreTrainPredictionModelingParams.Algorithm.GBT_REGRESSION,
                pmp);
        rcmp.max_ensemble_nodes_serialized = pmp.max_ensemble_nodes_serialized;
        if (this.isClassification) {
            rcmp.gbt_classifier_grid = gbt;
        } else {
            rcmp.gbt_regressor_grid = (PredictionModelingParams.GBTRegressionHyperparametersSpace) gbt;
        }

        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, gbt);
        // NOTE(review): presumably one train per (grid point, fold) plus one final train
        // when a search happens, and a single train otherwise — confirm.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;

        List<WorkSet.ModelingSet> out = new ArrayList<>();
        out.add(ms);
        return out;
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Returns {@code true} only for regression and binary classification prediction types. */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        switch (coreParams.prediction_type) {
            case REGRESSION:
            case BINARY_CLASSIFICATION:
                return true;
            default:
                return false;
        }
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Returns pre-train params whose grid is pinned to the optimized hyperparameters.
     *
     * <p>Starts from {@code getCopyWithGridStrategy(usedToTrain)}. Every numerical dimension
     * of the variant's grid is set to the single value found in {@code optimized}. The loss
     * dimension is rebuilt over the loss names of the variant:
     * <ul>
     *   <li>classification: {@code deviance}, {@code exponential};</li>
     *   <li>regression: {@code ls}, {@code lad}, {@code huber}.</li>
     * </ul>
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        if (this.isClassification) {
            pinNumericalDimensions(ret.gbt_classifier_grid, optimized);
            ret.gbt_classifier_grid.loss = CategoricalHyperparameterDimension.create(optimized.gbt.loss, "deviance", "exponential");
        } else {
            pinNumericalDimensions(ret.gbt_regressor_grid, optimized);
            ret.gbt_regressor_grid.loss = CategoricalHyperparameterDimension.create(optimized.gbt.loss, "ls", "lad", "huber");
        }
        return ret;
    }

    /**
     * Sets each numerical dimension of {@code grid} to the single value chosen in {@code optimized}.
     *
     * <p>NOTE(review): {@code learning_rate} is widened to {@code double} via
     * {@code Double.valueOf}. If the post-train field is a {@code float}, this can surface
     * representation noise (e.g. 0.1f becomes 0.10000000149011612) — confirm whether that
     * is acceptable.
     */
    private static void pinNumericalDimensions(PredictionModelingParams.GradientBoostedTreeHyperparametersSpace grid, PostTrainPredictionModelingParams optimized) {
        grid.n_estimators.setToSingleValueGrid(Long.valueOf(optimized.gbt.n_estimators));
        grid.max_depth.setToSingleValueGrid(Long.valueOf(optimized.gbt.max_depth));
        grid.min_samples_leaf.setToSingleValueGrid(Long.valueOf(optimized.gbt.min_samples_leaf));
        grid.max_features.setToSingleValueGrid(Long.valueOf(optimized.gbt.max_features));
        grid.learning_rate.setToSingleValueGrid(Double.valueOf(optimized.gbt.learning_rate));
        grid.max_feature_prop.setToSingleValueGrid(optimized.gbt.max_feature_prop);
    }

    /**
     * Writes the grid pinned by {@link #regridifyToPreTrain} into {@code target}.
     *
     * <p>The grid goes into {@code gbt_classification} or {@code gbt_regression} according
     * to the variant, and is marked enabled.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        if (this.isClassification) {
            target.gbt_classification = preTrain.gbt_classifier_grid;
            target.gbt_classification.enabled = true;
        } else {
            target.gbt_regression = preTrain.gbt_regressor_grid;
            target.gbt_regression.enabled = true;
        }
    }

    /**
     * Copies the grid used for training back into {@code target} and marks it enabled.
     *
     * <p>The grid object is shared by reference (not copied). Setting {@code enabled}
     * therefore also affects {@code usedToTrain}'s grid.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        if (this.isClassification) {
            target.gbt_classification = usedToTrain.gbt_classifier_grid;
            target.gbt_classification.enabled = true;
        } else {
            target.gbt_regression = usedToTrain.gbt_regressor_grid;
            target.gbt_regression.enabled = true;
        }
    }
}

