/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the SVM model of the Python in-memory backend.
 *
 * <p>A single class serves both variants; the constructor flag selects which settings are used:
 * classification reads {@code svc_classifier} / {@code svc_grid}, regression reads
 * {@code svm_regression} / {@code svr_grid}.
 */
public class SVMMeta extends PyMemoryAlgorithmMeta {
    /** Gamma category that activates the explicit numeric gamma grid ({@code custom_gamma}). */
    private static final String CUSTOM_GAMMA = "custom";
    /** Kernel category that is counted with the C dimension only (no gamma) in the grid size. */
    private static final String LINEAR_KERNEL = "linear";
    /** Value written into the custom gamma grid when the optimized gamma is not {@code "custom"}. */
    private static final double DEFAULT_CUSTOM_GAMMA = 0.001;

    private final boolean isClassification;

    /**
     * @param isClassification {@code true} to operate on the classification (SVC) settings,
     *                         {@code false} to operate on the regression (SVR) settings
     */
    public SVMMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    /** Returns the fixed display name {@code "SVM"} (same for both variants). */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "SVM";
    }

    /**
     * Describes the search space before training: grid length, kernels and C values, plus gamma
     * (and the explicit gamma values when the "custom" gamma is enabled) whenever at least one
     * gamma-compatible kernel is selected.
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.SVMHyperparametersSpace svm = this.isClassification ? rpmp.svc_grid : rpmp.svr_grid;
        ModelTrainInfo.PreSearchDescription ret = new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, svm))
                .withMVParam("kernel", svm.kernel)
                .withMVParam("C", svm.C);
        if (svm.countGammaCompatibleKernels() > 0) {
            ret.withMVParam("Gamma", svm.gamma);
            if (isCustomGammaEnabled(svm)) {
                ret.withMVParam("Gamma values", svm.custom_gamma);
            }
        }
        return ret;
    }

    /**
     * Describes the chosen hyperparameters after training: kernel and C, plus gamma when the chosen
     * kernel is gamma-compatible. When the chosen gamma is "custom", the numeric custom gamma is
     * reported instead of the category name.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        ModelTrainInfo.PostSearchDescription postSearchDescription = new ModelTrainInfo.PostSearchDescription()
                .withSVParam("kernel", after.svm.kernel)
                .withSVParam("C", Float.valueOf(after.svm.C));
        if (PredictionModelingParams.SVMHyperparametersSpace.isGammaCompatibleKernel(after.svm.kernel)) {
            Object usedGamma = CUSTOM_GAMMA.equals(after.svm.gamma) ? Float.valueOf(after.svm.custom_gamma) : after.svm.gamma;
            postSearchDescription.withSVParam("gamma", usedGamma);
        }
        return postSearchDescription;
    }

    /**
     * Validates the SVM settings of {@code pmp}. Does nothing when the variant's settings are
     * absent or disabled.
     *
     * <p>Checks: the C dimension, at least one kernel, at least one gamma when a gamma-compatible
     * kernel is selected, the custom gamma values when "custom" gamma is enabled, and a strictly
     * positive tolerance.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.SVMHyperparametersSpace svm = this.isClassification ? pmp.svc_classifier : pmp.svm_regression;
        if (svm == null || !svm.enabled) {
            return;
        }
        // NOTE(review): semantics of addWarningSparse are defined in PredictionParameterChecks.
        checks.addWarningSparse("SVM");
        checks.checkNumericalDimension(svm.C, "C regularization parameter (SVM)");
        ErrorContext.check(svm.kernel.getLength() > 0, "SVM requires at least one kernel");
        if (svm.countGammaCompatibleKernels() > 0) {
            ErrorContext.check(svm.gamma.getLength() > 0, "SVM requires at least one gamma");
            if (isCustomGammaEnabled(svm)) {
                checks.checkNumericalDimension(svm.custom_gamma, "Custom gamma values (SVM)");
            }
        }
        ErrorContext.check(svm.tol > 0.0f, "SVM tolerance must be positive");
    }

    /**
     * Computes the number of grid points.
     *
     * <p>The linear kernel contributes |C| points. Every gamma-compatible kernel contributes
     * |C| x |gamma| points, where the "custom" gamma category is expanded into its explicit values.
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.SVMHyperparametersSpace svm = (PredictionModelingParams.SVMHyperparametersSpace) space;
        int cLength = svm.C.getLength();
        int linearKernelGridLength = 0;
        if (svm.kernel.values.containsKey(LINEAR_KERNEL) && svm.kernel.values.get(LINEAR_KERNEL).enabled) {
            linearKernelGridLength = cLength;
        }
        int gammaCompatibleKernelGridLength = 0;
        int gammaCompatibleKernels = svm.countGammaCompatibleKernels();
        if (gammaCompatibleKernels > 0) {
            int gammaLength = svm.gamma.getLength();
            if (isCustomGammaEnabled(svm)) {
                // The single "custom" category is replaced by each of its explicit numeric values.
                gammaLength += svm.custom_gamma.getLength() - 1;
            }
            gammaCompatibleKernelGridLength = gammaCompatibleKernels * cLength * gammaLength;
        }
        return linearKernelGridLength + gammaCompatibleKernelGridLength;
    }

    /**
     * Expands the SVM settings of {@code pmp} into at most one modeling set.
     *
     * <p>Returns an empty list when the variant's settings are absent or disabled. An empty C grid
     * is defaulted in place to {@code [1.0]}. The estimated train count is
     * {@code gridLength * gsFolds + 1} when there is more than one grid point, and 1 otherwise.
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        List<WorkSet.ModelingSet> ret = new ArrayList<>();
        PreTrainPredictionModelingParams rcmp;
        PredictionModelingParams.SVMHyperparametersSpace svm;
        if (this.isClassification) {
            rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.SVC_CLASSIFICATION, pmp);
            rcmp.svc_grid = svm = pmp.svc_classifier;
        } else {
            rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.SVM_REGRESSION, pmp);
            rcmp.svr_grid = svm = pmp.svm_regression;
        }
        // Guard before the first dereference: the settings may be absent (validateParameters
        // already tolerates null), and touching svm.C first would throw an NPE.
        if (svm == null) {
            return ret;
        }
        // Kept ahead of the enabled check so the in-place default matches previous behavior.
        if (svm.C.getLength() == 0) {
            svm.C.updateValues(1.0);
        }
        if (!svm.enabled) {
            return ret;
        }
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, svm);
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        ret.add(ms);
        return ret;
    }

    /**
     * Always returns {@code true}.
     *
     * <p>NOTE(review): also returned for the regression variant; confirm this is intended.
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /**
     * Builds pre-train params whose SVM grid is pinned to the optimized values.
     *
     * <p>C is set to a single value. Gamma and kernel become categorical dimensions with the
     * optimized choice selected. The custom gamma grid is set to the optimized custom gamma when
     * the optimized gamma is "custom", and to {@link #DEFAULT_CUSTOM_GAMMA} otherwise.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        // Same object as ret.svc_grid / ret.svr_grid, so the mutations below apply to ret.
        PredictionModelingParams.SVMHyperparametersSpace grid = this.isClassification ? ret.svc_grid : ret.svr_grid;
        grid.C.setToSingleValueGrid(Double.valueOf(optimized.svm.C));
        grid.gamma = CategoricalHyperparameterDimension.create(optimized.svm.gamma, CUSTOM_GAMMA, "scale", "auto");
        double optimizedGammaValue = CUSTOM_GAMMA.equals(optimized.svm.gamma) ? (double) optimized.svm.custom_gamma : DEFAULT_CUSTOM_GAMMA;
        grid.custom_gamma.setToSingleValueGrid(optimizedGammaValue);
        grid.kernel = CategoricalHyperparameterDimension.create(optimized.svm.kernel, "rbf", LINEAR_KERNEL, "poly", "sigmoid");
        return ret;
    }

    /** Writes the regridified (optimized) SVM grid into {@code target} and enables it. */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        this.installGrid(target, this.isClassification ? preTrain.svc_grid : preTrain.svr_grid);
    }

    /** Writes the SVM grid used for training back into {@code target} and enables it. */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        this.installGrid(target, this.isClassification ? usedToTrain.svc_grid : usedToTrain.svr_grid);
    }

    /** Stores {@code grid} in the variant's slot of {@code target}, then marks it enabled. */
    private void installGrid(PredictionModelingParams target, PredictionModelingParams.SVMHyperparametersSpace grid) {
        if (this.isClassification) {
            target.svc_classifier = grid;
        } else {
            target.svm_regression = grid;
        }
        grid.enabled = true;
    }

    /** True when the gamma dimension contains an enabled "custom" category. */
    private static boolean isCustomGammaEnabled(PredictionModelingParams.SVMHyperparametersSpace svm) {
        return svm.gamma.values.containsKey(CUSTOM_GAMMA) && svm.gamma.values.get(CUSTOM_GAMMA).enabled;
    }
}

