/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.spark;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.SparklingAlgorithmMeta;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Algorithm metadata for the H2O (Sparkling) deep learning model,
 * {@code PreTrainPredictionModelingParams.Algorithm.SPARKLING_DEEP_LEARNING}.
 *
 * <p>Builds human-readable train descriptions and validates the
 * {@code deep_learning_sparkling} section of the modeling params. It expands that
 * section into exactly one modeling set and maps the training params back onto an ML task.
 * It is declared incompatible with Java scoring and compatible with Python scoring.
 */
public class SparklingDeepLearningMeta
extends SparklingAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Deep Learning (H2O)";
    }

    /**
     * Describes the settings that will be used for training: hidden layer sizes, L1/L2
     * regularization, activation and the dropout flag.
     *
     * @param rpmp pre-train params; {@code rpmp.deep_learning_sparkling_grid} must be non-null
     * @return a description carrying one single-value param per setting
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.H2ODeepLearningGridParams par = rpmp.deep_learning_sparkling_grid;
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withSVParam("hidden_layers", formatHiddenLayers(par))
                .withSVParam("l1", Float.valueOf(par.l1))
                .withSVParam("l2", Float.valueOf(par.l2))
                .withSVParam("activation", par.activation)
                .withSVParam("dropout", par.dropout);
    }

    /**
     * Describes the settings the model was trained with.
     *
     * <p>The values are read from {@code before}; {@code descBefore} and {@code after} are
     * not consulted. This is consistent with {@link #regridifyToPreTrain}, which also ignores
     * the post-train params.
     *
     * @param descBefore unused
     * @param before pre-train params; {@code before.deep_learning_sparkling_grid} must be non-null
     * @param after unused
     * @return a description carrying one single-value param per setting
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PredictionModelingParams.H2ODeepLearningGridParams par = before.deep_learning_sparkling_grid;
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("hidden_layers", formatHiddenLayers(par))
                .withSVParam("l1", Float.valueOf(par.l1))
                .withSVParam("l2", Float.valueOf(par.l2))
                .withSVParam("activation", par.activation)
                .withSVParam("dropout", par.dropout);
    }

    /**
     * Validates the deep learning params when the algorithm is enabled.
     *
     * <p>Does nothing if the section is absent or disabled. Otherwise it registers a
     * sparse-input warning for this algorithm and delegates to the params' own
     * {@code validate()}, whose failure mode is defined elsewhere.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        if (pmp.deep_learning_sparkling == null || !pmp.deep_learning_sparkling.enabled) {
            return;
        }
        checks.addWarningSparse("Deep learning (H2O)");
        pmp.deep_learning_sparkling.validate();
    }

    /**
     * Expands the enabled deep learning params into the modeling sets to train.
     *
     * <p>No hyperparameter grid is expanded: at most one set is produced. Note that the
     * produced set references {@code pmp.deep_learning_sparkling} directly (no copy), and that
     * this method may mutate {@code hidden_dropout_ratios} on it.
     *
     * @return an empty mutable list if the algorithm is absent or disabled, otherwise a
     *         mutable list with a single modeling set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        List<WorkSet.ModelingSet> ret = new ArrayList<>();
        PredictionModelingParams.H2ODeepLearningGridParams dl = pmp.deep_learning_sparkling;
        if (dl == null || !dl.enabled) {
            return ret;
        }
        // With dropout disabled, the ratios must still have one entry per hidden layer.
        // Reset them to zeros when they are missing or mis-sized. The null check avoids an
        // NPE that the previous length-only comparison would raise.
        if (!dl.dropout && (dl.hidden_dropout_ratios == null || dl.hidden_dropout_ratios.length != dl.hidden.length)) {
            dl.hidden_dropout_ratios = new double[dl.hidden.length];
        }
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.SPARKLING_DEEP_LEARNING, pmp);
        rcmp.deep_learning_sparkling_grid = dl;
        ret.add(new WorkSet.ModelingSet(rcmp));
        return ret;
    }

    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return false;
    }

    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /**
     * Returns the params to reuse for retraining. The {@code optimized} post-train params are
     * ignored; the result is whatever {@code getCopyWithGridStrategy(usedToTrain)} produces.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        return this.getCopyWithGridStrategy(usedToTrain);
    }

    /**
     * Writes the regridified deep learning params into {@code target} and marks the algorithm
     * enabled. {@code target.deep_learning_sparkling} receives the grid object from
     * {@link #regridifyToPreTrain} by reference.
     * NOTE(review): whether that object is independent of {@code usedToTrain} depends on
     * {@code getCopyWithGridStrategy} — confirm it deep-copies.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.deep_learning_sparkling = preTrain.deep_learning_sparkling_grid;
        target.deep_learning_sparkling.enabled = true;
    }

    /**
     * Restores the deep learning params used for training into {@code target} and marks the
     * algorithm enabled. The grid object is shared by reference, so setting {@code enabled}
     * also sets it on {@code usedToTrain.deep_learning_sparkling_grid}.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.deep_learning_sparkling = usedToTrain.deep_learning_sparkling_grid;
        target.deep_learning_sparkling.enabled = true;
    }

    /**
     * Formats the hidden layer sizes as a bracketed, comma-separated list, e.g.
     * {@code [64, 32]} ({@code []} for no layers). A {@code null} array yields {@code "null"}.
     */
    private static String formatHiddenLayers(PredictionModelingParams.H2ODeepLearningGridParams par) {
        return Arrays.toString(par.hidden);
    }
}

