/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.PredictionAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the prediction backend that trains a user-coded Keras neural network.
 *
 * <p>The network architecture (and, in "advanced fit mode", the fit code) is user-written code
 * carried in {@link PredictionModelingParams.KerasCodeParams}. This class validates that
 * configuration, expands it into modeling sets, and copies trained parameters back into the
 * ML task. Keras has no hyperparameter grid here: at most one modeling set is produced.
 */
public class KerasMeta
extends PredictionAlgorithmMeta {
    /**
     * Returns the fixed display name of this algorithm; the parameters are not consulted.
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Neural Network built with Keras";
    }

    /**
     * Builds the pre-training description directly from the pre-train parameters.
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        return new ModelTrainInfo.PreSearchDescription(rpmp);
    }

    /**
     * Builds the post-training description, recording the trained params' epoch count under
     * the {@code "epochs"} key.
     *
     * <p>NOTE(review): {@code after.keras} is dereferenced without a null check; this assumes
     * post-train params for a Keras model always carry Keras settings — confirm with callers.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription().withSVParam("epochs", after.keras.epochs);
    }

    /**
     * Validates the Keras settings of the ML task.
     *
     * <p>Does nothing when Keras settings are absent or disabled. Otherwise:
     * <ul>
     *   <li>the architecture (build) code must be non-empty;</li>
     *   <li>in advanced fit mode, the fit code must be non-empty;</li>
     *   <li>in simple mode, epochs and batch size must be positive, and steps per epoch must be
     *       positive unless {@code trainOnAllData} is set.</li>
     * </ul>
     * Failures are reported through {@link ErrorContext} (presumably by throwing — the exact
     * exception type is defined there).
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.KerasCodeParams keras = pmp.keras;
        if (keras == null || !keras.enabled) {
            return;
        }
        ErrorContext.checkNotEmpty(keras.buildCode, "Keras model is not configured (you must write the architecture code)");
        if (keras.advancedFitMode) {
            // The user-written fit code replaces the simple-mode numeric settings, so only its
            // presence is checked here.
            ErrorContext.checkNotEmpty(keras.fitCode, "Keras Train code is not configured (you must write the fit code in Advanced mode)");
        } else {
            ErrorContext.check(keras.epochs > 0, "Number of epochs needs to be positive.");
            ErrorContext.check(keras.batchSize > 0, "Batch size needs to be positive.");
            // Steps per epoch is only validated when trainOnAllData is off.
            if (!keras.trainOnAllData) {
                ErrorContext.check(keras.stepsPerEpoch > 0, "Steps per epoch needs to be positive.");
            }
        }
    }

    /**
     * Returns {@link MLTask.BackendType#KERAS}.
     */
    @Override
    public MLTask.BackendType backendType() {
        return MLTask.BackendType.KERAS;
    }

    /**
     * Expands the task's Keras settings into modeling sets.
     *
     * <p>Returns an empty (mutable) list when Keras is absent or disabled, otherwise a list with
     * exactly one modeling set. {@code gsFolds} is unused because there is no grid search.
     * The grid length is set to the configured number of epochs (NOTE(review): presumably so
     * that per-epoch progress can be reported — confirm).
     *
     * <p>NOTE(review): the task's {@code KerasCodeParams} instance is shared (not copied) into
     * the pre-train params.
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.KerasCodeParams kp = pmp.keras;
        List<WorkSet.ModelingSet> ret = new ArrayList<>();
        if (kp == null || !kp.enabled) {
            return ret;
        }
        PreTrainPredictionModelingParams rpmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.KERAS_CODE, pmp);
        rpmp.keras = kp;
        rpmp.gridLength = kp.epochs;
        ret.add(new WorkSet.ModelingSet(rpmp));
        return ret;
    }

    /**
     * Returns the params used to train, passed through {@code getCopyWithGridStrategy}.
     * {@code optimized} is ignored: there are no searched hyperparameters to fold back in.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        return this.getCopyWithGridStrategy(usedToTrain);
    }

    /**
     * Writes the Keras settings of the regridified pre-train params into {@code target} and
     * marks them enabled.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.keras = preTrain.keras;
        target.keras.enabled = true;
    }

    /**
     * Restores the Keras settings used for training into {@code target} and marks them enabled.
     *
     * <p>NOTE(review): the {@code KerasCodeParams} instance is shared, not copied, so setting
     * {@code enabled} also mutates {@code usedToTrain.keras}. A null {@code usedToTrain.keras}
     * causes a NullPointerException.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.keras = usedToTrain.keras;
        target.keras.enabled = true;
    }

    /**
     * Always returns {@code true}; the parameters are not consulted.
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }
}

