/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.Collections;
import java.util.List;

/**
 * Algorithm metadata for the Causal Forest learner.
 *
 * <p>Provides the display name, the pre- and post-train descriptions, validation of the
 * {@code causal_forest} hyperparameter space, its expansion into a modeling set, and the
 * "regridify" operations that turn an optimized parameter set back into a single-point grid.
 */
public class CausalForestMeta
extends PyMemoryAlgorithmMeta {

    /**
     * Returns the fixed display name of this algorithm.
     *
     * @param rpmp pre-train parameters (not read)
     * @return always {@code "Causal Forest"}
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Causal Forest";
    }

    /**
     * Describes the causal forest search space before training.
     *
     * <p>Reports the search size computed by {@code getSearchSize} and the candidate values of the
     * trees, criterion, depth and min_samples_leaf dimensions. The max-features dimensions and the
     * {@code honest} flag are not part of this description.
     *
     * @param rpmp pre-train parameters; its {@code causal_forest_grid} is read
     * @return the pre-search description
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.CausalForestHyperparameterSpace space = rpmp.causal_forest_grid;
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("trees", space.n_estimators)
                .withMVParam("criterion", space.criterion)
                .withMVParam("depth", space.max_depth)
                .withMVParam("min_samples_leaf", space.min_samples_leaf);
    }

    /**
     * Describes the hyperparameters actually selected after the search.
     *
     * @param descBefore the pre-search description (not read)
     * @param before the pre-train parameters (not read)
     * @param after the post-train parameters; its {@code causal_forest_params} is read
     * @return a description with the single values of trees, criterion, depth and min_samples_leaf
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.CausalForestParams params = after.causal_forest_params;
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("trees", params.n_estimators)
                .withSVParam("criterion", params.criterion)
                .withSVParam("depth", params.max_depth)
                .withSVParam("min_samples_leaf", params.min_samples_leaf);
    }

    /**
     * Validates the causal forest hyperparameter space.
     *
     * <p>Does nothing if the space is absent or disabled. Otherwise it checks the trees, depth and
     * min_samples_leaf dimensions. It checks only the feature dimension matching the selection
     * mode: the feature count in {@code NUMBER} mode, the feature ratio in {@code PROP} mode. It
     * also requires at least one criterion.
     *
     * @param pmp the modeling parameters holding {@code causal_forest}
     * @param task the ML task (not read)
     * @param checks the collector that validation problems are reported to
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.CausalForestHyperparameterSpace space = pmp.causal_forest;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkNumericalDimension(space.n_estimators, "Number of causal trees (Causal Forest)");
        checks.checkNumericalDimension(space.max_depth, "Maximum depth of trees (Causal Forest)");
        checks.checkNumericalDimension(space.min_samples_leaf, "Minimum number of samples per leaf (Causal Forest)");
        // The two selection modes are mutually exclusive, so at most one feature dimension is checked.
        if (space.selection_mode == PredictionModelingParams.TreeSelectionMode.NUMBER) {
            checks.checkNumericalDimension(space.max_features, "Maximum number of features considered per split (Causal Forest)");
        } else if (space.selection_mode == PredictionModelingParams.TreeSelectionMode.PROP) {
            checks.checkNumericalDimension(space.max_feature_prop, "Maximum ratio of features considered per split (Causal Forest)");
        }
        checks.checkPositive(space.criterion.getLength(), "At least one criterion must be selected (Causal Forest)");
    }

    /**
     * Expands the causal forest space into at most one modeling set.
     *
     * <p>Returns an empty list when the space is absent or disabled. Otherwise it builds
     * {@code CAUSAL_FOREST} pre-train parameters over the whole grid and records the grid length.
     * When more than one grid point is searched, it sets the estimated number of trainings to
     * {@code gridLength * gsFolds + 1}.
     *
     * @param pmp the modeling parameters holding {@code causal_forest}
     * @param task the ML task (not read)
     * @param gsFolds number of folds used by the hyperparameter search
     * @return an empty or singleton list of modeling sets
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.CausalForestHyperparameterSpace space = pmp.causal_forest;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.CAUSAL_FOREST, pmp);
        preTrainParams.causal_forest_grid = space;
        this.checkAndUpdateSearchStrategy(pmp, preTrainParams);
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            // One training per grid point and fold, plus one.
            // NOTE(review): presumably the final refit on the best point; confirm.
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    /**
     * Builds pre-train parameters whose causal forest grid contains only the optimized values.
     *
     * <p>Starts from {@code getCopyWithGridStrategy(usedToTrain)} and pins every numeric dimension
     * to the optimized value. Both {@code max_features} and {@code max_feature_prop} are pinned,
     * whatever the selection mode. The criterion dimension is recreated from the optimized
     * criterion with {@code "mse"} and {@code "het"} passed as the available values, and
     * {@code honest} is copied over.
     * NOTE(review): whether the copy shares its grid object with {@code usedToTrain} is not
     * visible here; confirm.
     *
     * @param optimized post-train parameters; its {@code causal_forest_params} is read
     * @param usedToTrain the pre-train parameters used for the search
     * @return the new single-point pre-train parameters
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.CausalForestParams optimizedParameters = optimized.causal_forest_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.CausalForestHyperparameterSpace space = preTrainParams.causal_forest_grid;
        space.n_estimators.setToSingleValueGrid(Long.valueOf(optimizedParameters.n_estimators));
        space.criterion = CategoricalHyperparameterDimension.create(optimizedParameters.criterion, "mse", "het");
        space.max_depth.setToSingleValueGrid(Long.valueOf(optimizedParameters.max_depth));
        space.min_samples_leaf.setToSingleValueGrid(Long.valueOf(optimizedParameters.min_samples_leaf));
        space.max_features.setToSingleValueGrid(Long.valueOf(optimizedParameters.max_features));
        space.max_feature_prop.setToSingleValueGrid(optimizedParameters.max_feature_prop);
        space.honest = optimizedParameters.honest;
        return preTrainParams;
    }

    /**
     * Stores the single-point grid from {@link #regridifyToPreTrain} into {@code target} and enables it.
     *
     * @param target the modeling parameters to update
     * @param optimized post-train parameters holding the optimized values
     * @param usedToTrain the pre-train parameters used for the search
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.causal_forest = preTrainParams.causal_forest_grid;
        target.causal_forest.enabled = true;
    }

    /**
     * Restores the grid used for training into {@code target} and enables it.
     *
     * <p>NOTE(review): the grid object is assigned by reference, not copied. Setting
     * {@code enabled} therefore also mutates {@code usedToTrain.causal_forest_grid}, and later
     * changes to {@code target.causal_forest} are visible through {@code usedToTrain}. Confirm
     * this sharing is intended.
     *
     * @param target the modeling parameters to update
     * @param usedToTrain the pre-train parameters whose grid is restored
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.causal_forest = usedToTrain.causal_forest_grid;
        target.causal_forest.enabled = true;
    }

    /**
     * Reports whether this algorithm produces probabilities.
     *
     * @param rpmp pre-train parameters (not read)
     * @return always {@code true}
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /**
     * Computes the number of points in the causal forest grid.
     *
     * <p>The result is the product of the lengths of the trees, criterion, min_samples_leaf and
     * max_depth dimensions. It is then multiplied by the length of the feature dimension selected
     * by {@code selection_mode}: {@code max_features} for {@code NUMBER}, {@code max_feature_prop}
     * for {@code PROP}, and no extra factor for any other mode, including {@code null}.
     * NOTE(review): the product is plain {@code int} arithmetic and is not overflow-checked.
     *
     * @param space the hyperparameter space; must be a {@code CausalForestHyperparameterSpace}
     * @return the grid length
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.CausalForestHyperparameterSpace causalSpace = (PredictionModelingParams.CausalForestHyperparameterSpace)space;
        int size = causalSpace.n_estimators.getLength()
                * causalSpace.criterion.getLength()
                * causalSpace.min_samples_leaf.getLength()
                * causalSpace.max_depth.getLength();
        // Enum constants are singletons, so == matches Enum.equals. Unlike .equals, it is null-safe
        // and matches the comparison style used in validateParameters.
        if (causalSpace.selection_mode == PredictionModelingParams.TreeSelectionMode.NUMBER) {
            size *= causalSpace.max_features.getLength();
        } else if (causalSpace.selection_mode == PredictionModelingParams.TreeSelectionMode.PROP) {
            size *= causalSpace.max_feature_prop.getLength();
        }
        return size;
    }

    /**
     * Returns the most advanced search strategy supported by this algorithm.
     *
     * @return {@code BAYESIAN}
     */
    @Override
    public PredictionModelingParams.GridSearchParams.Strategy getMaximumSupportedSearchStrategy() {
        return PredictionModelingParams.GridSearchParams.Strategy.BAYESIAN;
    }
}

