/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.mec;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.ml.prediction.PredictionPostComputationHandler;
import com.dataiku.dip.analysis.ml.prediction.PythonPostComputationHandler;
import com.dataiku.dip.analysis.model.prediction.ClassificationModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.mec.FullModelEvaluationId;
import com.dataiku.dip.mec.TabularModelEvaluation;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.NotImplementedException;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;

public class PythonPostModelEvaluationComputationHandler
extends PythonPostComputationHandler {
    private ResolvedClassicalPredictionCoreParams coreParams;
    private static Logger logger = Logger.getLogger((String)"dku.modelevaluation.postcomputation");

    public PythonPostModelEvaluationComputationHandler(AuthCtx authCtx, String jobId, ModelLikeId mle, PredictionPostComputationHandler.PostComputationCommand computationCommand, JsonObject computationParameters) {
        super(authCtx, jobId, mle, computationCommand, computationParameters);
    }

    @Override
    protected void abort() {
        if (this.kernel != null) {
            this.kernel.killNoWaitNoException(true);
        }
    }

    @Override
    protected Map<String, Object> prepareParams() throws IOException {
        HashMap<String, Object> params = new HashMap<String, Object>();
        TabularModelEvaluation me = ((FullModelEvaluationId)this.mle).getTabularModelEvaluation();
        PredictionModelIntrinsicPerf iperf = switch (me.predictionType) {
            case PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.MULTICLASS -> ((FullModelEvaluationId)this.mle).getIPerf(ClassificationModelIntrinsicPerf.class);
            case PredictionMLTask.PredictionType.REGRESSION -> ((FullModelEvaluationId)this.mle).getIPerf(PredictionModelIntrinsicPerf.class);
            default -> throw new NotImplementedException("Unknown model type : " + String.valueOf((Object)me.modelType));
        };
        ResolvedClassicalPredictionPreprocessingParams rppp = this.mle.getResolvedPredictionPreprocessingParams();
        if (StringUtils.isEmpty((String)me.targetVariable)) {
            me.targetVariable = rppp.getTarget();
        }
        if (TabularModelEvaluation.ModelType.SAVED_MODEL == me.modelType) {
            switch (me.predictionType) {
                case BINARY_CLASSIFICATION: 
                case MULTICLASS: {
                    if (!iperf.probaAware) break;
                    me.probaColumns = new ArrayList<TabularModelEvaluation.EvaluationModelInfoProba>();
                    for (PredictionPreprocessingParams.MappingValue mv : rppp.target_remapping) {
                        me.probaColumns.add(new TabularModelEvaluation.EvaluationModelInfoProba(mv.mappedValue, "proba_" + mv.sourceValue));
                    }
                    break;
                }
                case REGRESSION: {
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unknown prediction type: " + String.valueOf((Object)me.predictionType));
                }
            }
        }
        params.put("job_id", this.jobId);
        params.put("model_evaluation", me);
        params.put("features", this.mle.getFeatures());
        params.put("resolved_preprocessing_params", rppp);
        params.put("iperf", iperf);
        params.put("modelevaluation_folder", this.mle.getMainFolder().getAbsolutePath());
        params.put("model_folder", me.modelRef != null ? me.modelRef.getModelFolder().getAbsolutePath() : null);
        params.put("computation_parameters", this.computationParameters);
        return params;
    }

    @Override
    protected File getExecutionDirForContainerExec() {
        return this.mle.getMainFolder();
    }

    @Override
    protected FullModelId.Type getFMIType() {
        return null;
    }

    @Override
    protected String envName() {
        return this.coreParams != null ? this.coreParams.executionParams.envName : null;
    }

    @Override
    protected ContainerExecRuntimeConfig getContainerConfig() throws IOException, DKUSecurityException {
        ContainerExecSelection containerExecSelection = new ContainerExecSelection();
        containerExecSelection.containerMode = ContainerExecSelection.ContainerExecMode.INHERIT;
        return new ContainerExecConfigSelector().selectForML_autoTXN(this.authCtx, this.mle.getProjectKey(), containerExecSelection, this.mle.getBackendType());
    }

    @Override
    protected void init() throws IOException {
        if (this.mle.getUnderlyingModel() != null && this.mle.getUnderlyingModel().exists()) {
            this.coreParams = (ResolvedClassicalPredictionCoreParams)this.mle.getUnderlyingModel().getResolvedCoreParams();
        }
    }

    @Override
    protected File getContextDirForContainerExec() throws IOException {
        FullModelId underlyingModel = this.mle.getUnderlyingModel();
        if (underlyingModel != null) {
            return underlyingModel.getModelFolder();
        }
        return this.mle.getMainFolder();
    }
}

