/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.DKUMLUtils;
import com.dataiku.dip.analysis.ml.prediction.CausalPredictionParamsExpander;
import com.dataiku.dip.analysis.ml.prediction.ClassicalPredictionParamsExpander;
import com.dataiku.dip.analysis.ml.prediction.TimeseriesForecastingParamsExpander;
import com.dataiku.dip.analysis.ml.shared.ParamsExpander;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TabularPredictionPreprocessingParams;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Base class that expands a tabular prediction ML task into a {@link WorkSet}: the preprocessing
 * and modeling configurations to train for one session.
 *
 * <p>Subclasses supply the task-specific resolution of preprocessing parameters. Use
 * {@link #createFromMLTask} to obtain the expander matching a concrete task type.
 */
public abstract class TabularPredictionParamsExpander extends ParamsExpander {
    private static final DKULogger logger = DKULogger.getLogger("dku.models");

    /** The ML task whose settings are expanded. */
    protected final PredictionMLTask.TabularPredictionMLTask task;
    /** Grid-search modes for which the task's configured number of folds is honored. */
    protected final Set<PredictionModelingParams.GridSearchCrossValidationMode> supportedGsKFoldModes;
    /** Session id passed to every {@link WorkSet} built by {@link #expand()}. */
    private final String sessionId;

    protected TabularPredictionParamsExpander(PredictionMLTask.TabularPredictionMLTask task, String sessionId, Set<PredictionModelingParams.GridSearchCrossValidationMode> supportedGsKFoldModes) {
        this.task = task;
        this.sessionId = sessionId;
        this.supportedGsKFoldModes = supportedGsKFoldModes;
    }

    /**
     * Returns the number of grid-search folds to use when expanding modeling sets.
     *
     * @return the task's {@code gridSearchParams.nFolds} when grid-search params are set and their
     *     mode is one of {@link #supportedGsKFoldModes}; {@code 1} otherwise
     */
    protected int getGsFolds() {
        if (this.task.modeling.gridSearchParams != null
                && this.supportedGsKFoldModes.contains(this.task.modeling.gridSearchParams.mode)) {
            return this.task.modeling.gridSearchParams.nFolds;
        }
        return 1;
    }

    /**
     * Resolves the task's preprocessing parameters into the form stored on a preprocessing set.
     *
     * @param var1 the task's preprocessing parameters
     * @return the resolved preprocessing parameters
     */
    protected abstract ResolvedPredictionPreprocessingParams expandResolvedPredictionPreprocessingParams(TabularPredictionPreprocessingParams var1);

    /**
     * Expands the modeling sets of every known algorithm and registers them on {@code ws}.
     *
     * <p>For algorithms whose meta reports {@code oneModelPerPreprocessingSet()}, each modeling set
     * is registered on its own (as a one-element list). Otherwise, all of the algorithm's modeling
     * sets are registered together.
     *
     * @param ws the work set being populated
     */
    protected void addModelingSets(WorkSet ws) {
        for (PreTrainPredictionModelingParams.Algorithm algo : PreTrainPredictionModelingParams.Algorithm.values()) {
            List<WorkSet.ModelingSet> modelingSets = algo.meta.expandModeling(this.task.modeling, this.task, this.getGsFolds());
            if (algo.meta.oneModelPerPreprocessingSet()) {
                for (WorkSet.ModelingSet modelingSet : modelingSets) {
                    // Typed, mutable single-element list; the decompiled (Object[]) varargs call
                    // inferred ArrayList<Object>, which is not a List<WorkSet.ModelingSet>.
                    List<WorkSet.ModelingSet> single = new ArrayList<>(1);
                    single.add(modelingSet);
                    this.addModelingSets(single, ws, algo.meta.autoCompleter());
                }
            } else {
                this.addModelingSets(modelingSets, ws, algo.meta.autoCompleter());
            }
        }
    }

    /**
     * Builds the {@link WorkSet} for this task.
     *
     * <p>If {@code DKUMLUtils.checkPredictionTaskBeforeTraining} throws, the failure is logged and
     * a work set is returned whose only content is the error in {@code messages}. Otherwise:
     * <ul>
     *   <li>the preprocessing params and every algorithm's params are validated into the same
     *       checks object (these findings do not stop expansion);</li>
     *   <li>modeling sets are added;</li>
     *   <li>the task's assertions/overrides params, when present, are copied onto each
     *       preprocessing set;</li>
     *   <li>each modeling set receives the task's metrics, and automatic threshold optimization is
     *       disabled for causal tasks.</li>
     * </ul>
     *
     * @return the populated work set, with {@code messages} set from the accumulated checks
     * @throws Exception propagated from the validation and expansion calls made after the initial
     *     task check (failures of that initial check are reported via {@code messages} instead)
     */
    public WorkSet expand() throws Exception {
        assert this.task.predictionType != null;
        WorkSet ws = new WorkSet(this.sessionId);
        PredictionParameterChecks checks = new PredictionParameterChecks(this.task.modeling, this.task.getPreprocessingParams());
        try {
            DKUMLUtils.checkPredictionTaskBeforeTraining(this.task);
        } catch (Exception e) {
            logger.warn("Check task failed", e);
            // getMessage() may be null (e.g. a bare NullPointerException); never report a null error.
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            checks.addError(message, null);
            ws.messages = checks.getMessages();
            return ws;
        }
        this.task.getPreprocessingParams().validate(checks);
        for (PreTrainPredictionModelingParams.Algorithm algo : PreTrainPredictionModelingParams.Algorithm.values()) {
            algo.meta.validateParameters(this.task.modeling, this.task, checks);
        }
        this.addModelingSets(ws);

        // Causal tasks do not get automatic threshold optimization; this is the same for every set.
        boolean autoOptimizeThreshold = !(this.task instanceof PredictionMLTask.CausalPredictionMLTask);
        for (WorkSet.PreprocessingSet ps : ws.preprocessingSets) {
            this.task.getAssertionsParams().ifPresent(mlAssertionsParams -> ps.assertionsParams = mlAssertionsParams);
            this.task.getOverridesParams().ifPresent(mlOverridesParams -> ps.overridesParams = mlOverridesParams);
            for (WorkSet.ModelingSet ms : ps.modelingSets) {
                PreTrainPredictionModelingParams rpmp = (PreTrainPredictionModelingParams) ms.modelingParams;
                rpmp.metrics = this.task.modeling.metrics;
                rpmp.autoOptimizeThreshold = autoOptimizeThreshold;
            }
        }
        logger.info("At the end, have " + ws.preprocessingSets.size() + " preprocessing sets");
        ws.messages = checks.getMessages();
        return ws;
    }

    /**
     * Returns a single preprocessing set built from the task's resolved preprocessing params.
     *
     * @return a mutable list holding exactly one preprocessing set
     */
    @Override
    protected List<WorkSet.PreprocessingSet> expandPreprocessing() {
        List<WorkSet.PreprocessingSet> ret = new ArrayList<>();
        ret.add(new WorkSet.PreprocessingSet(this.expandResolvedPredictionPreprocessingParams(this.task.getPreprocessingParams()), ""));
        return ret;
    }

    /**
     * Creates the expander matching the concrete type of {@code task}.
     *
     * @param task the tabular prediction task to expand
     * @param sessionId the session id for the produced work sets
     * @return a classical, time-series forecasting, or causal expander
     * @throws IllegalArgumentException if the task is of none of those types
     */
    public static TabularPredictionParamsExpander createFromMLTask(PredictionMLTask.TabularPredictionMLTask task, String sessionId) {
        if (task instanceof PredictionMLTask.ClassicalPredictionMLTask) {
            return new ClassicalPredictionParamsExpander((PredictionMLTask.ClassicalPredictionMLTask) task, sessionId);
        }
        if (task instanceof PredictionMLTask.TimeseriesForecastingMLTask) {
            return new TimeseriesForecastingParamsExpander((PredictionMLTask.TimeseriesForecastingMLTask) task, sessionId);
        }
        if (task instanceof PredictionMLTask.CausalPredictionMLTask) {
            return new CausalPredictionParamsExpander((PredictionMLTask.CausalPredictionMLTask) task, sessionId);
        }
        throw new IllegalArgumentException("Unsupported ML task: " + task.getClass().getSimpleName());
    }
}

