/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.guess.PredictionGuesser;
import com.dataiku.dip.analysis.ml.shared.FeatureGuessUtils;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.preprocessing.CatFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.ImageFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.NumFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TextFeaturePreprocessingParams;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.utils.JSON;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Base guesser for tabular prediction ML tasks.
 *
 * <p>Provides shared helpers that fill in target remapping, per-feature preprocessing
 * params (delegating to {@link FeatureGuessUtils}), the random-search iteration count and
 * the algorithm selection of the task held by the parent {@link PredictionGuesser}.
 * Subclasses decide which columns get a special role and which algorithms are enabled.
 *
 * @param <T> concrete tabular prediction task type handled by this guesser
 */
public abstract class TabularPredictionGuesser<T extends PredictionMLTask.TabularPredictionMLTask>
extends PredictionGuesser<T> {
    /**
     * NOTE(review): not referenced inside this class; presumably a threshold on the number
     * of target classes used by subclasses/callers — confirm before relying on it.
     */
    public static final int NB_LIMIT_TARGET_CLASS = 2;

    protected TabularPredictionGuesser(T task, MemTable table) {
        super(task, table);
    }

    /**
     * Returns the special role a column should get, if any.
     *
     * @param column column to inspect
     * @return the role to assign, or an empty Optional (presumably meaning "regular
     *     feature" — semantics are defined by subclasses)
     */
    protected abstract Optional<FeaturePreprocessingParams.Role> getSpecialFeatureRole(MemColumn column);

    /**
     * Builds the modeling (algorithm) params for the given task.
     *
     * @param table sample data the guess is based on
     * @param task task being guessed
     * @param flag subclass-defined switch; {@link #reguessAlgorithms()} passes {@code true}
     * @return the modeling params to install on the task
     */
    protected abstract PredictionModelingParams guessAlgorithms(MemTable table, T task, boolean flag);

    /**
     * Sets the task's target remapping.
     *
     * <p>Regression tasks get an empty remapping. Otherwise the remapping comes from the
     * inherited {@code guessClassificationTargetRemapping} on the current table and the
     * task's target variable.
     *
     * @param isRegression whether the task is a regression task
     * @param throwException forwarded unchanged to {@code guessClassificationTargetRemapping}
     */
    protected void setTargetRemapping(boolean isRegression, boolean throwException) {
        this.task.getPreprocessingParams().target_remapping = isRegression
                ? new ArrayList<>()
                : this.guessClassificationTargetRemapping(this.table, this.task.targetVariable, throwException);
    }

    /**
     * Creates bare preprocessing params for a special (non-input) feature.
     *
     * @param isNumerical {@code true} for numerical params, {@code false} for categorical
     * @param role role to set on the returned params
     * @return fresh numerical or categorical params carrying {@code role}
     */
    protected FeaturePreprocessingParams guessSpecialFeature(boolean isNumerical, FeaturePreprocessingParams.Role role) {
        if (isNumerical) {
            NumFeaturePreprocessingParams preprocessingParams = new NumFeaturePreprocessingParams();
            preprocessingParams.role = role;
            return preprocessingParams;
        }
        CatFeaturePreprocessingParams preprocessingParams = new CatFeaturePreprocessingParams();
        preprocessingParams.role = role;
        return preprocessingParams;
    }

    /** Guesses categorical preprocessing for {@code column} via {@link FeatureGuessUtils#standardNoSparseCatFeatureGuess}. */
    public CatFeaturePreprocessingParams guessCategorical(MemColumn column) {
        return FeatureGuessUtils.standardNoSparseCatFeatureGuess(this.table, column, this.task);
    }

    /** Guesses numerical preprocessing for {@code column} using the task's backend type. */
    public NumFeaturePreprocessingParams guessNumerical(MemColumn column) {
        return FeatureGuessUtils.standardNumFeatureGuess(this.table, column, this.task.backendType);
    }

    /** Guesses text preprocessing for {@code column} using the task's backend type. */
    public TextFeaturePreprocessingParams guessText(MemColumn column) {
        return FeatureGuessUtils.standardTextFeatureGuess(column, this.task.backendType);
    }

    /**
     * Guesses image preprocessing for {@code column} using the task's backend type.
     *
     * @param column image column
     * @param embeddingModelId forwarded to {@link FeatureGuessUtils#standardImageFeatureGuess}
     */
    public ImageFeaturePreprocessingParams guessImage(MemColumn column, String embeddingModelId) {
        return FeatureGuessUtils.standardImageFeatureGuess(column, this.task.backendType, embeddingModelId);
    }

    /**
     * Sets {@code gridSearchParams.nIterRandom} to the largest grid length over all algorithms.
     *
     * <p>This is a no-op unless the task, its modeling params and its grid-search params are
     * all non-null and the backend is {@link MLTask.BackendType#PY_MEMORY}. Otherwise, a deep copy
     * of the task is switched to the GRID strategy, every algorithm's modeling is expanded
     * against that copy, and the maximum {@code gridLength} seen (at least 1) is written back
     * to the original task. The copy keeps the strategy change off the real task.
     * NOTE(review): presumably this lets a random search cover as many points as the
     * exhaustive grid — confirm intent.
     */
    protected void setNIterRandom() {
        if (this.task == null
                || this.task.modeling == null
                || this.task.modeling.gridSearchParams == null
                || this.task.backendType != MLTask.BackendType.PY_MEMORY) {
            return;
        }
        PredictionMLTask.TabularPredictionMLTask taskWithGrid =
                (PredictionMLTask.TabularPredictionMLTask) JSON.deepCopy((Object) this.task);
        taskWithGrid.modeling.gridSearchParams.strategy = PredictionModelingParams.GridSearchParams.Strategy.GRID;
        int gridLength = 1;
        for (PreTrainPredictionModelingParams.Algorithm algo : PreTrainPredictionModelingParams.Algorithm.values()) {
            List<WorkSet.ModelingSet> modelingSets = algo.meta.expandModeling(taskWithGrid.modeling, this.task, 1);
            for (WorkSet.ModelingSet ms : modelingSets) {
                gridLength = Math.max(gridLength, ((PreTrainPredictionModelingParams) ms.modelingParams).gridLength);
            }
        }
        this.task.modeling.gridSearchParams.nIterRandom = gridLength;
    }

    /**
     * Disables every algorithm in the current modeling params, then replaces them with a
     * fresh {@link #guessAlgorithms(MemTable, PredictionMLTask.TabularPredictionMLTask, boolean)}
     * result (called with {@code true}).
     */
    public void reguessAlgorithms() {
        this.task.modeling.disableAll();
        // Pass the task as T directly: casting to the bound type would not satisfy the T parameter.
        this.task.modeling = this.guessAlgorithms(this.table, this.task, true);
    }
}

