/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.List;

/**
 * Algorithm metadata for the Python-backed deep neural network, in either its classification or
 * regression flavour.
 *
 * <p>The {@code isClassification} flag chosen at construction decides which hyperparameter space
 * every method reads and writes:
 * <ul>
 *   <li>on {@link PredictionModelingParams}: {@code deep_neural_network_classification} or
 *       {@code deep_neural_network_regression};</li>
 *   <li>on {@link PreTrainPredictionModelingParams}: {@code deep_neural_network_classification_grid}
 *       or {@code deep_neural_network_regression_grid}.</li>
 * </ul>
 * The flag also decides which {@link PreTrainPredictionModelingParams.Algorithm} value
 * {@link #expandModeling} emits.
 */
public class DeepNeuralNetworkMeta extends PyMemoryAlgorithmMeta {
    /** {@code true} for the classification variant, {@code false} for regression. */
    private final boolean isClassification;

    /**
     * @param isClassification selects the classification ({@code true}) or regression
     *     ({@code false}) hyperparameter space and algorithm identifier
     */
    public DeepNeuralNetworkMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    /** Returns the fixed display name {@code "Deep Neural Network"} for both variants. */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Deep Neural Network";
    }

    /**
     * Describes the search before training.
     *
     * <p>The description carries the grid length and the multi-valued {@code learning_rate},
     * {@code hidden_layers} and {@code units} dimensions.
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace hpSpace = this.getHyperParameterSpace(rpmp);
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, hpSpace))
                .withMVParam("learning_rate", hpSpace.learning_rate)
                .withMVParam("hidden_layers", hpSpace.hidden_layers)
                .withMVParam("units", hpSpace.units);
    }

    /**
     * Describes the retained model after training.
     *
     * <p>The description carries the single values {@code learning_rate}, {@code hidden_layers},
     * {@code units} and {@code epochs}, read from {@code after.deep_neural_network}.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.DeepNeuralNetworkParams deepNeuralNetworkParams = after.deep_neural_network;
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("learning_rate", deepNeuralNetworkParams.learning_rate)
                .withSVParam("hidden_layers", deepNeuralNetworkParams.hidden_layers)
                .withSVParam("units", deepNeuralNetworkParams.units)
                .withSVParam("epochs", deepNeuralNetworkParams.epochs);
    }

    /**
     * Validates the user-facing hyperparameters of this variant.
     *
     * <p>Validation is skipped when the space is absent or disabled. Each failing
     * {@code ErrorContext.check} receives its own message.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace hpSpace = this.getHyperParameterSpace(pmp);
        if (hpSpace == null || !hpSpace.enabled) {
            return;
        }
        // Grid-searchable dimensions.
        checks.checkNumericalDimension(hpSpace.learning_rate, "Learning rate (Deep Neural Network)");
        checks.checkNumericalDimension(hpSpace.hidden_layers, "Hidden Layers (Deep Neural Network)");
        checks.checkNumericalDimension(hpSpace.units, "Units (Deep Neural Network)");
        // Scalar training settings.
        ErrorContext.check(hpSpace.max_epochs > 0, "Max epochs (Deep Neural Network) should be positive");
        ErrorContext.check(hpSpace.batch_size > 0, "Batch size (Deep Neural Network) should be positive");
        if (hpSpace.early_stopping_enabled) {
            // Patience/threshold only matter when early stopping is turned on.
            ErrorContext.check(hpSpace.early_stopping_patience > 0, "Early stopping patience (Deep Neural Network) should be positive");
            ErrorContext.check(hpSpace.early_stopping_threshold > 0.0f, "Early stopping threshold (Deep Neural Network) should be positive");
        }
        ErrorContext.check(hpSpace.dropout >= 0.0f && hpSpace.dropout < 1.0f, "Dropout (Deep Neural Network) should be in between 0 (included) and 1 (excluded)");
        ErrorContext.check(hpSpace.reg_l2 >= 0.0f, "L2 regularization (Deep Neural Network) must be >= 0");
        ErrorContext.check(hpSpace.reg_l1 >= 0.0f, "L1 regularization (Deep Neural Network) must be >= 0");
    }

    /**
     * Expands this algorithm into at most one modeling set.
     *
     * <p>Returns an empty list when the hyperparameter space is absent or disabled. This null check
     * mirrors {@link #validateParameters}; previously a missing space raised an NPE here.
     *
     * <p>Otherwise returns a single {@link WorkSet.ModelingSet}. When the grid has more than one
     * point, its {@code estimatedTrains} is set to {@code gridLength * gsFolds + 1}.
     *
     * @param gsFolds number of grid-search folds; used only when the grid has more than one point
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace hpSpace = this.getHyperParameterSpace(pmp);
        if (hpSpace == null || !hpSpace.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams.Algorithm algorithm = this.isClassification
                ? PreTrainPredictionModelingParams.Algorithm.DEEP_NEURAL_NETWORK_CLASSIFICATION
                : PreTrainPredictionModelingParams.Algorithm.DEEP_NEURAL_NETWORK_REGRESSION;
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(algorithm, pmp);
        if (this.isClassification) {
            preTrainParams.deep_neural_network_classification_grid = hpSpace;
        } else {
            preTrainParams.deep_neural_network_regression_grid = hpSpace;
        }
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, hpSpace);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            // NOTE(review): the "+ 1" presumably accounts for a final refit on the best point — confirm.
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    /**
     * Builds pre-train params whose grid is collapsed onto the optimized values.
     *
     * <p>Starts from a copy obtained via {@code getCopyWithGridStrategy(usedToTrain)}. Only
     * {@code learning_rate}, {@code hidden_layers} and {@code units} are pinned; other settings of
     * the copied space are left as copied.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.DeepNeuralNetworkParams optimizedParams = optimized.deep_neural_network;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace hpSpace = this.getHyperParameterSpace(preTrainParams);
        hpSpace.learning_rate.setToSingleValueGrid(optimizedParams.learning_rate);
        hpSpace.hidden_layers.setToSingleValueGrid(optimizedParams.hidden_layers);
        hpSpace.units.setToSingleValueGrid(optimizedParams.units);
        return preTrainParams;
    }

    /**
     * Writes the regridified (optimized, single-point) space into {@code target} and enables it.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        this.setEnabledHyperParameterSpace(target, this.getHyperParameterSpace(preTrainParams));
    }

    /**
     * Writes the grid that was used for training back into {@code target} and enables it.
     *
     * <p>NOTE(review): the grid object is shared, not copied. As a result, setting
     * {@code enabled = true} also mutates {@code usedToTrain}'s grid. Confirm callers do not rely
     * on {@code usedToTrain} being left untouched.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        this.setEnabledHyperParameterSpace(target, this.getHyperParameterSpace(usedToTrain));
    }

    /** Only the classification variant exposes class probabilities. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return this.isClassification;
    }

    /**
     * Number of grid points: the product of the lengths of {@code learning_rate},
     * {@code hidden_layers} and {@code units}.
     *
     * <p>Assumes {@code space} is a {@code DeepNeuralNetworkHyperParameterSpace}. Any other
     * subtype raises {@link ClassCastException}.
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace dlSpace = (PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace) space;
        return dlSpace.learning_rate.getLength() * dlSpace.hidden_layers.getLength() * dlSpace.units.getLength();
    }

    /**
     * Stores {@code hpSpace} in the variant-specific field of {@code target}, then marks it enabled.
     *
     * <p>The assignment happens before the dereference. A null {@code hpSpace} therefore leaves the
     * field null before the NPE is raised.
     */
    private void setEnabledHyperParameterSpace(PredictionModelingParams target, PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace hpSpace) {
        if (this.isClassification) {
            target.deep_neural_network_classification = hpSpace;
        } else {
            target.deep_neural_network_regression = hpSpace;
        }
        hpSpace.enabled = true;
    }

    /** Variant-specific grid on pre-train params. */
    private PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace getHyperParameterSpace(PreTrainPredictionModelingParams modelingParams) {
        return this.isClassification ? modelingParams.deep_neural_network_classification_grid : modelingParams.deep_neural_network_regression_grid;
    }

    /** Variant-specific hyperparameter space on ML-task params; may be {@code null}. */
    private PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace getHyperParameterSpace(PredictionModelingParams modelingParams) {
        return this.isClassification ? modelingParams.deep_neural_network_classification : modelingParams.deep_neural_network_regression;
    }
}

