/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the decision tree learner.
 *
 * <p>One implementation serves both prediction types: the {@code isClassification} flag given at
 * construction decides whether the ML task's {@code decision_tree_classification} or
 * {@code decision_tree_regression} hyperparameter space is used. In both cases the space is
 * carried on {@code PreTrainPredictionModelingParams.dtc_classifier_grid} during training.
 */
public class PyDecisionTreeMeta extends PyMemoryAlgorithmMeta {
    /** {@code true} for the classification variant, {@code false} for the regression variant. */
    private final boolean isClassification;

    /**
     * @param isClassification {@code true} to build the classification variant, {@code false} for
     *     the regression variant
     */
    public PyDecisionTreeMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Decision Tree";
    }

    /**
     * Describes the search about to run: the grid length and the candidate values for max depth,
     * min samples per leaf and criterion.
     *
     * @param rpmp pre-train params; {@code dtc_classifier_grid} is expected to be set (asserted,
     *     which only fires when assertions are enabled)
     * @return the pre-search description
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.DecisionTreeHyperparametersSpace dtp = rpmp.dtc_classifier_grid;
        assert dtp != null;
        // NOTE(review): the splitter dimension counts towards the grid length (see getGridLength)
        // but is not listed in this description - confirm whether that is intended.
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, dtp))
                .withMVParam("max_depth", dtp.max_depth)
                .withMVParam("min_samples", dtp.min_samples_leaf)
                .withMVParam("criterion", dtp.criterion);
    }

    /**
     * Describes the hyperparameter values selected by the search (criterion, max depth and min
     * samples per leaf), as read from {@code after.dt}.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("criterion", after.dt.criterion)
                .withSVParam("max_depth", after.dt.max_depth)
                .withSVParam("min_samples", after.dt.min_samples_leaf);
    }

    /**
     * Validates this variant's hyperparameter space when it is present and enabled.
     *
     * <p>Registers the sparse-input warning, checks the numerical dimensions, and fails through
     * {@code ErrorContext.check} when no criterion or no splitter value is selected.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.DecisionTreeHyperparametersSpace dtp = this.decisionTreeSpaceOf(pmp);
        if (dtp == null || !dtp.enabled) {
            return;
        }
        checks.addWarningSparse("Decision Tree");
        checks.checkNumericalDimension(dtp.max_depth, "Maximum depth of tree (Decision tree)");
        checks.checkNumericalDimension(dtp.min_samples_leaf, "Min samples per leaf (Decision tree)");
        ErrorContext.check(dtp.criterion.getLength() > 0, "Decision tree requires a criterion: Gini or Entropy");
        ErrorContext.check(dtp.splitter.getLength() > 0, "Decision tree requires a split strategy: best or random");
    }

    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /**
     * Returns the number of grid points: the product of the lengths of the max depth, min samples
     * per leaf, criterion and splitter dimensions.
     *
     * @param space must be a {@code DecisionTreeHyperparametersSpace} (cast unchecked)
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.DecisionTreeHyperparametersSpace dtp = (PredictionModelingParams.DecisionTreeHyperparametersSpace) space;
        return dtp.max_depth.getLength()
                * dtp.min_samples_leaf.getLength()
                * dtp.criterion.getLength()
                * dtp.splitter.getLength();
    }

    /**
     * Expands this variant into at most one modeling set.
     *
     * @return an empty (mutable) list when the space is absent or disabled, otherwise a list with a
     *     single modeling set carrying the space and its estimated number of trains
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.DecisionTreeHyperparametersSpace dtp = this.decisionTreeSpaceOf(pmp);
        List<WorkSet.ModelingSet> out = new ArrayList<>();
        if (dtp == null || !dtp.enabled) {
            return out;
        }
        PreTrainPredictionModelingParams.Algorithm algo = this.isClassification
                ? PreTrainPredictionModelingParams.Algorithm.DECISION_TREE_CLASSIFICATION
                : PreTrainPredictionModelingParams.Algorithm.DECISION_TREE_REGRESSION;
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(algo, pmp);
        // The same field holds the grid for both the classification and the regression variant.
        rcmp.dtc_classifier_grid = dtp;
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, dtp);
        // With more than one grid point: gridLength * gsFolds trains plus one extra
        // (presumably the final refit - confirm); a single point needs just one train.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        out.add(ms);
        return out;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Builds pre-train params whose grid is narrowed to the values selected in {@code optimized}.
     *
     * <p>Max depth and min samples per leaf become single-value grids. The splitter is rebuilt
     * from the selected value among "best"/"random". The criterion is only rebuilt (among
     * "gini"/"entropy") for the classification variant.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.DecisionTreeHyperparametersSpace grid = ret.dtc_classifier_grid;
        grid.max_depth.setToSingleValueGrid(Long.valueOf(optimized.dt.max_depth));
        grid.min_samples_leaf.setToSingleValueGrid(Long.valueOf(optimized.dt.min_samples_leaf));
        if (this.isClassification) {
            grid.criterion = CategoricalHyperparameterDimension.create(optimized.dt.criterion, "gini", "entropy");
        }
        // NOTE(review): for regression the criterion grid is left as trained, not pinned to the
        // selected value - confirm this is intended.
        grid.splitter = CategoricalHyperparameterDimension.create(optimized.dt.splitter, "best", "random");
        return ret;
    }

    /**
     * Writes the narrowed grid (see {@link #regridifyToPreTrain}) back into {@code target} for
     * this variant and enables it.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        this.installDecisionTreeSpace(target, preTrain.dtc_classifier_grid);
    }

    /**
     * Writes the grid used for training back into {@code target} for this variant and enables it.
     *
     * <p>NOTE(review): the grid object is shared, not copied, so setting {@code enabled} also
     * mutates {@code usedToTrain.dtc_classifier_grid} - confirm this aliasing is intended.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        this.installDecisionTreeSpace(target, usedToTrain.dtc_classifier_grid);
    }

    /**
     * Returns the hyperparameter space of {@code pmp} that matches this variant; it may be
     * {@code null}, so callers must check.
     */
    private PredictionModelingParams.DecisionTreeHyperparametersSpace decisionTreeSpaceOf(PredictionModelingParams pmp) {
        return this.isClassification ? pmp.decision_tree_classification : pmp.decision_tree_regression;
    }

    /**
     * Stores {@code space} in the field of {@code target} that matches this variant, then enables
     * it. The space is assigned first, so a {@code null} space still overwrites the field before
     * failing on the enable step.
     */
    private void installDecisionTreeSpace(PredictionModelingParams target, PredictionModelingParams.DecisionTreeHyperparametersSpace space) {
        if (this.isClassification) {
            target.decision_tree_classification = space;
        } else {
            target.decision_tree_regression = space;
        }
        space.enabled = true;
    }
}

