/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the SGD (stochastic gradient descent) model.
 *
 * <p>One class serves both variants; the variant is fixed at construction time. Every method
 * dispatches on it to read and write either the classifier fields ({@code sgd_classifier} /
 * {@code sgd_grid}) or the regression fields ({@code sgd_regression} / {@code sgd_reg_grid}).
 */
public class SGDMeta extends PyMemoryAlgorithmMeta {
    /** {@code true} for the SGD classifier variant, {@code false} for the SGD regressor variant. */
    private final boolean isClassification;

    /**
     * @param isClassification {@code true} to describe the classifier variant, {@code false} for
     *     the regressor variant
     */
    public SGDMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    /** Returns the display name of the algorithm, which is {@code "SGD"} for both variants. */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "SGD";
    }

    /**
     * Describes the hyperparameter search before training: the search size and the candidate
     * penalty and loss values of the variant's grid.
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.SGDHyperparametersSpace space =
                this.isClassification ? rpmp.sgd_grid : rpmp.sgd_reg_grid;
        assert space != null;
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("penalty", space.penalty)
                .withMVParam("loss", space.loss);
    }

    /**
     * Describes the hyperparameters retained after the search. These are penalty, loss and alpha,
     * plus epsilon for the huber loss and l1_ratio for the elasticnet penalty.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        ModelTrainInfo.PostSearchDescription desc = new ModelTrainInfo.PostSearchDescription()
                .withSVParam("penalty", after.sgd.penalty)
                .withSVParam("loss", after.sgd.loss)
                .withSVParam("alpha", Float.valueOf(after.sgd.alpha));
        // Constant-first equals: tolerate a missing loss/penalty instead of throwing an NPE.
        if ("huber".equals(after.sgd.loss)) {
            desc.withSVParam("epsilon", after.sgd.epsilon);
        }
        if ("elasticnet".equals(after.sgd.penalty)) {
            desc.withSVParam("l1_ratio", Float.valueOf(after.sgd.l1_ratio));
        }
        return desc;
    }

    /**
     * Validates the SGD settings of the ML task. Nothing is checked when the variant's space is
     * absent or disabled.
     *
     * <p>Requires at least one loss, at least one penalty and a positive iteration count. The
     * alpha dimension (and epsilon when huber loss is enabled) goes through
     * {@code checkNumericalDimension}. When the elasticnet penalty is enabled, l1_ratio must lie
     * strictly between 0 and 1.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        if (this.isClassification) {
            PredictionModelingParams.SGDClassificationHyperparametersSpace sgd = pmp.sgd_classifier;
            if (sgd == null || !sgd.enabled) {
                return;
            }
            ErrorContext.check(sgd.loss.getLength() > 0, "SGD requires a loss function");
            ErrorContext.check(sgd.max_iter > 0, "SGD number of iterations must be > 0");
            ErrorContext.check(sgd.penalty.getLength() > 0, "SGD requires a penalty function");
            checks.checkNumericalDimension(sgd.alpha, "Alpha regularization coefficient (SGD classification)");
            // Guard with containsKey (as getGridLength does) so a missing entry is not an NPE.
            if (sgd.penalty.values.containsKey("elasticnet") && sgd.penalty.values.get("elasticnet").enabled) {
                ErrorContext.check(sgd.l1_ratio > 0.0f && sgd.l1_ratio < 1.0f, "SGD l1 ratio must be in ]0; 1[");
            }
        } else {
            PredictionModelingParams.SGDRegressionHyperparametersSpace sgd = pmp.sgd_regression;
            if (sgd == null || !sgd.enabled) {
                return;
            }
            ErrorContext.check(sgd.loss.getLength() > 0, "SGD requires a loss function");
            ErrorContext.check(sgd.max_iter > 0, "SGD number of iterations must be > 0");
            ErrorContext.check(sgd.penalty.getLength() > 0, "SGD requires a penalty function");
            checks.checkNumericalDimension(sgd.alpha, "Alpha regularization coefficient (SGD regression)");
            if (sgd.loss.values.containsKey("huber") && sgd.loss.values.get("huber").enabled) {
                checks.checkNumericalDimension(sgd.epsilon, "Epsilon coefficient (SGD regression)");
            }
            if (sgd.penalty.values.containsKey("elasticnet") && sgd.penalty.values.get("elasticnet").enabled) {
                ErrorContext.check(sgd.l1_ratio > 0.0f && sgd.l1_ratio < 1.0f, "SGD l1 ratio must be in ]0; 1[");
            }
        }
    }

    /**
     * Computes the number of hyperparameter combinations in the grid.
     *
     * <p>Classification: penalty x loss x alpha. Regression: the squared_loss combinations
     * (penalty x alpha) plus the huber combinations (epsilon x penalty x alpha). Epsilon only
     * multiplies the huber part. l1_ratio is a single float (see the scalar checks in
     * validateParameters), so it never enlarges the grid.
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        if (this.isClassification) {
            PredictionModelingParams.SGDClassificationHyperparametersSpace sgd = (PredictionModelingParams.SGDClassificationHyperparametersSpace) space;
            return sgd.penalty.getLength() * sgd.loss.getLength() * sgd.alpha.getLength();
        }
        PredictionModelingParams.SGDRegressionHyperparametersSpace sgd = (PredictionModelingParams.SGDRegressionHyperparametersSpace) space;
        int squaredLossCompatibleGridLength = 0;
        if (sgd.loss.values.containsKey("squared_loss") && sgd.loss.values.get("squared_loss").enabled) {
            squaredLossCompatibleGridLength = sgd.penalty.getLength() * sgd.alpha.getLength();
        }
        int huberCompatibleGridLength = 0;
        if (sgd.loss.values.containsKey("huber") && sgd.loss.values.get("huber").enabled) {
            huberCompatibleGridLength = sgd.epsilon.getLength() * sgd.penalty.getLength() * sgd.alpha.getLength();
        }
        return squaredLossCompatibleGridLength + huberCompatibleGridLength;
    }

    /**
     * Expands the ML task settings into modeling sets for this variant.
     *
     * @return an empty list when the variant's space is absent or disabled, otherwise a list with
     *     a single modeling set carrying the grid and its estimated number of trainings
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        List<WorkSet.ModelingSet> ret = new ArrayList<WorkSet.ModelingSet>();
        PreTrainPredictionModelingParams rcmp;
        if (this.isClassification) {
            PredictionModelingParams.SGDClassificationHyperparametersSpace sgd = pmp.sgd_classifier;
            if (sgd == null || !sgd.enabled) {
                return ret;
            }
            rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.SGD_CLASSIFICATION, pmp);
            rcmp.sgd_grid = sgd;
            rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, sgd);
        } else {
            PredictionModelingParams.SGDRegressionHyperparametersSpace sgd = pmp.sgd_regression;
            if (sgd == null || !sgd.enabled) {
                return ret;
            }
            rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.SGD_REGRESSION, pmp);
            rcmp.sgd_reg_grid = sgd;
            rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, sgd);
        }
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        // With a real search, each grid point is trained on every fold, plus one extra training.
        // NOTE(review): the extra training is presumably the final refit on the best point — confirm.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        ret.add(ms);
        return ret;
    }

    /** Returns {@code true} unconditionally. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /** Returns {@code true} unconditionally. */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Returns {@code true} unconditionally. */
    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Returns {@code true} unconditionally. */
    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Returns {@code true} unconditionally. */
    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Builds pre-train params whose grid is narrowed to the optimized values.
     *
     * <p>Starts from {@code getCopyWithGridStrategy(usedToTrain)}. Alpha is pinned to a single
     * value. Loss and penalty are rebuilt with {@code CategoricalHyperparameterDimension.create},
     * passing the optimized value followed by the variant's full list of choices.
     * NOTE(review): presumably this enables only the optimized choice — confirm. Also, epsilon is
     * not pinned for the huber regression loss, so the regridified grid may still span several
     * epsilon values — confirm whether that is intended.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        if (this.isClassification) {
            ret.sgd_grid.alpha.setToSingleValueGrid(Double.valueOf(optimized.sgd.alpha));
            ret.sgd_grid.loss = CategoricalHyperparameterDimension.create(optimized.sgd.loss, "log", "modified_huber");
            ret.sgd_grid.penalty = CategoricalHyperparameterDimension.create(optimized.sgd.penalty, "l1", "l2", "elasticnet");
        } else {
            ret.sgd_reg_grid.alpha.setToSingleValueGrid(Double.valueOf(optimized.sgd.alpha));
            ret.sgd_reg_grid.loss = CategoricalHyperparameterDimension.create(optimized.sgd.loss, "squared_loss", "huber");
            ret.sgd_reg_grid.penalty = CategoricalHyperparameterDimension.create(optimized.sgd.penalty, "l1", "l2", "elasticnet");
        }
        return ret;
    }

    /**
     * Writes the regridified space (see {@link #regridifyToPreTrain}) into {@code target} and
     * enables it.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        if (this.isClassification) {
            target.sgd_classifier = preTrain.sgd_grid;
            target.sgd_classifier.enabled = true;
        } else {
            target.sgd_regression = preTrain.sgd_reg_grid;
            target.sgd_regression.enabled = true;
        }
    }

    /**
     * Restores the space that was used for training into {@code target} and enables it.
     *
     * <p>NOTE(review): the space object is shared, not copied. Setting {@code enabled} therefore
     * also mutates {@code usedToTrain}'s grid — confirm that this aliasing is acceptable.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        if (this.isClassification) {
            target.sgd_classifier = usedToTrain.sgd_grid;
            target.sgd_classifier.enabled = true;
        } else {
            target.sgd_regression = usedToTrain.sgd_reg_grid;
            target.sgd_regression.enabled = true;
        }
    }
}

