/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the K Nearest Neighbors (KNN) prediction algorithm.
 *
 * <p>Provides the display name, the pre/post-train descriptions, parameter
 * validation, expansion of the ML task settings into modeling sets, and the
 * conversions between optimized (post-train) parameters and grid (pre-train)
 * parameters.
 */
public class KNNMeta extends PyMemoryAlgorithmMeta {

    /**
     * Builds the display name of a KNN model.
     *
     * @param rpmp pre-train parameters; {@code rpmp.knn_grid} is asserted non-null
     * @return {@code "K Nearest Neighbors (grid)"} when the search size is greater
     *         than 1, otherwise a name embedding the first value of the {@code k} grid
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        assert (rpmp.knn_grid != null);
        if (this.getSearchSize(rpmp.grid_search_params, rpmp.knn_grid) > 1) {
            return "K Nearest Neighbors (grid)";
        }
        // Search size is at most 1 here: the name shows the first value of the k grid.
        return "K Nearest Neighbors (k=" + ((Long[])rpmp.knn_grid.k.values)[0] + ")";
    }

    /**
     * Describes the hyperparameter space before training.
     *
     * @param rpmp pre-train parameters whose {@code knn_grid} is described
     * @return a description carrying the grid length, the multi-valued {@code k},
     *         the single-valued {@code p}, and {@code distance_weighting=yes} when
     *         distance weighting is enabled
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.KNNHyperparametersSpace knn = rpmp.knn_grid;
        int gridLength = this.getSearchSize(rpmp.grid_search_params, knn);
        ModelTrainInfo.PreSearchDescription ret = new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(gridLength)
                .withMVParam("k", knn.k)
                .withSVParam("p", knn.p);
        if (knn.distance_weighting) {
            ret.withSVParam("distance_weighting", "yes");
        }
        return ret;
    }

    /**
     * Describes the parameters actually retained after training.
     *
     * @param descBefore pre-search description (not used by this implementation)
     * @param before pre-train parameters (not used by this implementation)
     * @param after post-train parameters whose {@code knn} values are reported
     * @return a description with {@code k}, optionally {@code distance_weighting=yes},
     *         then {@code p}
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        ModelTrainInfo.PostSearchDescription ps2 = new ModelTrainInfo.PostSearchDescription().withSVParam("k", after.knn.k);
        if (after.knn.distance_weighting) {
            ps2.withSVParam("distance_weighting", "yes");
        }
        return ps2.withSVParam("p", after.knn.p);
    }

    /**
     * Validates the KNN settings of an ML task. Does nothing when KNN is absent
     * or disabled.
     *
     * @param pmp modeling parameters holding the {@code knn} settings
     * @param task the prediction task (not used by this implementation)
     * @param checks helper used to validate the {@code k} dimension
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.KNNHyperparametersSpace knn = pmp.knn;
        if (knn == null || !knn.enabled) {
            return;
        }
        checks.checkNumericalDimension(knn.k, "Number of neighbors (KNN)");
        // Leaf size is only checked for the tree-based neighbor search algorithms.
        boolean treeBased = knn.algorithm == PredictionModelingParams.KNNAlgorithm.BALL_TREE
                || knn.algorithm == PredictionModelingParams.KNNAlgorithm.KD_TREE;
        if (treeBased) {
            ErrorContext.check(knn.leaf_size > 0, "Leaf size must be > 0");
        }
        ErrorContext.check(knn.p > 0, "Exponent of the minkowski metric must be > 0");
    }

    /**
     * Returns the grid length of a KNN hyperparameter space, which is the length
     * of its {@code k} dimension.
     *
     * @param knn hyperparameter space; must be a {@code KNNHyperparametersSpace}
     * @return the length of the {@code k} dimension
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace knn) {
        return ((PredictionModelingParams.KNNHyperparametersSpace)knn).k.getLength();
    }

    /**
     * Expands the KNN settings of an ML task into the modeling sets to train.
     *
     * @param pmp modeling parameters holding the {@code knn} settings
     * @param task the prediction task (not used by this implementation)
     * @param gsFolds number of folds used to estimate the number of trainings
     * @return an empty list when KNN is absent or disabled, otherwise a list with
     *         a single modeling set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.KNNHyperparametersSpace knn = pmp.knn;
        List<WorkSet.ModelingSet> ret = new ArrayList<WorkSet.ModelingSet>();
        if (knn == null || !knn.enabled) {
            return ret;
        }
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.KNN, pmp);
        rcmp.knn_grid = knn;
        rcmp.max_ensemble_nodes_serialized = pmp.max_ensemble_nodes_serialized;
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, knn);
        // NOTE(review): presumably one training per grid point per fold plus a
        // final refit when searching, and a single training otherwise - confirm.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        ret.add(ms);
        return ret;
    }

    /**
     * Indicates whether the model outputs probabilities.
     *
     * @param rpmp pre-train parameters (not used by this implementation)
     * @return always {@code true}
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /**
     * Builds pre-train parameters whose {@code k} grid is reduced to the single
     * optimized value.
     *
     * @param optimized post-train parameters providing the retained {@code k}
     * @param usedToTrain pre-train parameters copied via {@code getCopyWithGridStrategy}
     * @return the copy, with its {@code k} grid set to the optimized value
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        ret.knn_grid.k.setToSingleValueGrid(Long.valueOf(optimized.knn.k));
        return ret;
    }

    /**
     * Writes the optimized KNN settings back into an ML task's modeling
     * parameters and enables KNN there.
     *
     * @param target modeling parameters to update
     * @param optimized post-train parameters providing the retained {@code k}
     * @param usedToTrain pre-train parameters the optimized grid is derived from
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.knn = preTrain.knn_grid;
        target.knn.enabled = true;
    }

    /**
     * Replaces the KNN settings of an ML task with the grid that was used to
     * train, and enables KNN there.
     *
     * @param target modeling parameters to update
     * @param usedToTrain pre-train parameters providing the grid
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        // NOTE(review): target.knn references the same object as
        // usedToTrain.knn_grid, so setting enabled also changes usedToTrain's
        // grid - confirm this aliasing is intended.
        target.knn = usedToTrain.knn_grid;
        target.knn.enabled = true;
    }
}

