/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.Collections;
import java.util.List;

public class GluonTSMXNetDeepARMeta
extends PyMemoryAlgorithmMeta {
    /**
     * Returns the display name of this algorithm.
     *
     * @param rpmp pre-train parameters (unused; the name is constant)
     * @return the fixed name {@code "DeepAR - MXNet"}
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "DeepAR - MXNet";
    }

    /**
     * Builds the human-readable description of the hyperparameter search space before training.
     * Every searched dimension is listed; the context length is only listed when the model is
     * not configured to use the full context.
     *
     * @param rpmp pre-train parameters holding the DeepAR grid
     * @return the pre-search description, including the grid length
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.GluonTSMXNetDeepARSpace space = rpmp.gluonts_deepar_timeseries_grid;
        ModelTrainInfo.PreSearchDescription description = new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("Learning rate", space.learning_rate)
                .withMVParam("Output distribution", space.distr_output)
                .withMVParam("Nb. RNN layers", space.num_layers)
                .withMVParam("Nb. cells per layer", space.num_cells)
                .withMVParam("Cell type", space.cell_type)
                .withMVParam("Dropout cell type", space.dropoutcell_type)
                .withMVParam("Dropout rate", space.dropout_rate)
                .withMVParam("Alpha", space.alpha)
                .withMVParam("Beta", space.beta);
        if (!space.full_context) {
            // Context length is only a hyperparameter when the full context is not used.
            description = description.withMVParam("Context length", space.context_length);
        }
        return description;
    }

    /**
     * Builds the human-readable description of the hyperparameter values selected after training,
     * including the number of epochs.
     *
     * @param descBefore the pre-search description (unused here)
     * @param before     pre-train parameters; only consulted for the {@code full_context} flag
     * @param after      post-train parameters holding the selected DeepAR values
     * @return the post-search description
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PredictionModelingParams.GluonTSMXNetDeepARSpace space = before.gluonts_deepar_timeseries_grid;
        PostTrainPredictionModelingParams.GluonTSMXNetDeepARParams params = after.gluonts_deepar_timeseries_params;
        ModelTrainInfo.PostSearchDescription description = new ModelTrainInfo.PostSearchDescription()
                .withSVParam("Learning rate", params.learning_rate)
                .withSVParam("Output distribution", params.distr_output)
                .withSVParam("Nb. RNN layers", params.num_layers)
                .withSVParam("Nb. cells per layer", params.num_cells)
                .withSVParam("Cell type", params.cell_type)
                .withSVParam("Dropout cell type", params.dropoutcell_type)
                .withSVParam("Dropout rate", params.dropout_rate)
                .withSVParam("Alpha", params.alpha)
                .withSVParam("Beta", params.beta)
                .withSVParam("Epochs", params.epochs);
        if (!space.full_context) {
            description = description.withSVParam("Context length", params.context_length);
        }
        return description;
    }

    /**
     * Validates the DeepAR hyperparameter space configured in the ML task. Nothing is checked when
     * the space is absent or disabled.
     *
     * @param pmp    modeling parameters holding the DeepAR space
     * @param task   the ML task (unused here)
     * @param checks collector for validation failures
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.GluonTSMXNetDeepARSpace space = pmp.gluonts_deepar_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        // Categorical dimensions: at least one value must be selected.
        checks.checkPositive(space.distr_output.getLength(), "At least one output distribution must be selected (DeepAR)");
        checks.checkPositive(space.cell_type.getLength(), "At least one cell type must be selected (DeepAR)");
        checks.checkPositive(space.dropoutcell_type.getLength(), "At least one dropout cell type must be selected (DeepAR)");
        // Numerical dimensions.
        if (!space.full_context) {
            checks.checkNumericalDimension(space.context_length, "Context length (DeepAR)");
        }
        checks.checkNumericalDimension(space.learning_rate, "Learning rate (DeepAR)");
        checks.checkNumericalDimension(space.num_layers, "Number of RNN layers (DeepAR)");
        checks.checkNumericalDimension(space.num_cells, "Number of RNN cells for each layer (DeepAR)");
        checks.checkNumericalDimension(space.dropout_rate, "Dropout regularization parameter (DeepAR)");
        checks.checkNumericalDimension(space.alpha, "Scaling coefficient of the activation regularization (DeepAR)");
        checks.checkNumericalDimension(space.beta, "Scaling coefficient of the temporal activation regularization (DeepAR)");
    }

    /**
     * Expands the DeepAR space into the modeling sets to train.
     *
     * @param pmp     modeling parameters holding the DeepAR space
     * @param task    the ML task (unused here)
     * @param gsFolds number of grid-search folds
     * @return an empty list when the space is absent or disabled, otherwise a single modeling set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.GluonTSMXNetDeepARSpace space = pmp.gluonts_deepar_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.GLUONTS_DEEPAR, pmp);
        preTrainParams.gluonts_deepar_timeseries_grid = space;
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            // NOTE(review): presumably one train per grid point per fold, plus a final refit — TODO confirm.
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    /**
     * Produces pre-train parameters whose DeepAR grid is collapsed onto the optimized values.
     * Each dimension becomes a single value, and the categorical dimensions keep the full list of
     * allowed choices.
     *
     * @param optimized   post-train parameters holding the selected values
     * @param usedToTrain pre-train parameters used for the original training
     * @return the copied pre-train parameters with a single-point grid
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.GluonTSMXNetDeepARParams optimizedParameters = optimized.gluonts_deepar_timeseries_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.GluonTSMXNetDeepARSpace space = preTrainParams.gluonts_deepar_timeseries_grid;
        if (!space.full_context) {
            space.context_length.setToSingleValueGrid(optimizedParameters.context_length);
        }
        space.learning_rate.setToSingleValueGrid(optimizedParameters.learning_rate);
        space.distr_output = CategoricalHyperparameterDimension.create(optimizedParameters.distr_output, "StudentTOutput", "GaussianOutput", "NegativeBinomialOutput");
        space.num_layers.setToSingleValueGrid(optimizedParameters.num_layers);
        space.num_cells.setToSingleValueGrid(optimizedParameters.num_cells);
        space.cell_type = CategoricalHyperparameterDimension.create(optimizedParameters.cell_type, "lstm", "gru");
        space.dropoutcell_type = CategoricalHyperparameterDimension.create(optimizedParameters.dropoutcell_type, "ZoneoutCell", "RNNZoneoutCell", "VariationalDropoutCell", "VariationalZoneoutCell");
        space.dropout_rate.setToSingleValueGrid(optimizedParameters.dropout_rate);
        space.alpha.setToSingleValueGrid(optimizedParameters.alpha);
        space.beta.setToSingleValueGrid(optimizedParameters.beta);
        return preTrainParams;
    }

    /**
     * Writes the optimized single-point DeepAR grid back into the ML task and enables it.
     *
     * @param target      ML task modeling parameters to update
     * @param optimized   post-train parameters holding the selected values
     * @param usedToTrain pre-train parameters used for the original training
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.gluonts_deepar_timeseries = preTrainParams.gluonts_deepar_timeseries_grid;
        target.gluonts_deepar_timeseries.enabled = true;
    }

    /**
     * Restores the DeepAR grid used for training into the ML task and enables it.
     *
     * @param target      ML task modeling parameters to update
     * @param usedToTrain pre-train parameters whose grid is restored
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        // NOTE(review): the grid object is shared, not copied, so setting `enabled` below also
        // mutates usedToTrain's grid — confirm this aliasing is intended.
        target.gluonts_deepar_timeseries = usedToTrain.gluonts_deepar_timeseries_grid;
        target.gluonts_deepar_timeseries.enabled = true;
    }

    /**
     * Indicates whether this algorithm produces probabilities.
     *
     * @param rpmp pre-train parameters (unused)
     * @return always {@code true}
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /**
     * Computes the number of points in the DeepAR hyperparameter grid as the product of each
     * dimension's length. Each dimension is counted exactly once. The context length only counts
     * when the full context is not used.
     *
     * @param space the hyperparameter space, expected to be a {@code GluonTSMXNetDeepARSpace}
     * @return the grid length
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.GluonTSMXNetDeepARSpace deepARSpace = (PredictionModelingParams.GluonTSMXNetDeepARSpace)space;
        // Previously num_layers was multiplied in twice, which inflated the grid size.
        int gridLength = deepARSpace.learning_rate.getLength()
                * deepARSpace.distr_output.getLength()
                * deepARSpace.num_layers.getLength()
                * deepARSpace.num_cells.getLength()
                * deepARSpace.cell_type.getLength()
                * deepARSpace.dropoutcell_type.getLength()
                * deepARSpace.dropout_rate.getLength()
                * deepARSpace.alpha.getLength()
                * deepARSpace.beta.getLength();
        if (!deepARSpace.full_context) {
            gridLength *= deepARSpace.context_length.getLength();
        }
        return gridLength;
    }
}

