/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.pipelines.overrides;

import com.dataiku.scoring.models.overrides.MLOverridesParamsBase;
import com.dataiku.scoring.pipelines.ClassificationResult;
import com.dataiku.scoring.pipelines.OverrideInfo;
import com.dataiku.scoring.pipelines.overrides.OverridesLayerBase;
import com.dataiku.scoring.pipelines.overrides.OverridesOutcomeComputer;
import com.dataiku.scoring.util.RawObservation;
import java.util.HashMap;
import java.util.Map;

/**
 * Overrides layer for {@link ClassificationResult}s that carry one probability per class.
 *
 * <p>The layer keeps the ordered class labels it was built with, plus a lookup from each
 * label to its position in that array. The lookup is computed once, in the constructor.
 */
public class ProbabilisticClassificationOverridesLayer
extends OverridesLayerBase<ClassificationResult> {
    /** Row key under which {@code 1 - max(probabilities)} is written. */
    public static final String UNCERTAINTY_COL = "prediction_uncertainty";
    /** Row key under which the raw prediction is written. */
    public static final String PREDICTION_COL = "prediction";
    /** Prefix of the per-class probability row keys; the class label is appended to it. */
    private static final String PROBA_COL_PREFIX = "proba_";

    /** Class labels, in the same order as the probability array of a raw result. */
    private final String[] classes;
    /** Position of each class label inside {@link #classes}. */
    private final Map<String, Integer> classToIndexMap;

    /**
     * Creates the layer.
     *
     * @param outcomeComputer forwarded unchanged to the superclass constructor
     * @param classes         class labels; the array reference is stored as-is (no copy)
     */
    public ProbabilisticClassificationOverridesLayer(OverridesOutcomeComputer<RawObservation> outcomeComputer, String[] classes) {
        super(outcomeComputer);
        this.classes = classes;
        this.classToIndexMap = indexByLabel(classes);
    }

    /**
     * Builds the label-to-position lookup. If a label occurs more than once, the
     * last position wins.
     */
    private static Map<String, Integer> indexByLabel(String[] labels) {
        Map<String, Integer> positions = new HashMap<>();
        int position = 0;
        for (String label : labels) {
            positions.put(label, position++);
        }
        return positions;
    }

    /**
     * Writes the raw prediction, one {@code proba_<class>} entry per class and the
     * prediction uncertainty into the given row.
     *
     * @param originalRow row that receives the extra entries (mutated in place)
     * @param rawResult   result whose prediction and probabilities are read
     */
    @Override
    void prepareRowForOverride(RawObservation originalRow, ClassificationResult rawResult) {
        originalRow.put(PREDICTION_COL, rawResult.getPrediction());

        double[] probabilities = rawResult.getProbabilities();
        double highest = 0.0;
        int classCount = this.classes.length;
        for (int classIdx = 0; classIdx < classCount; classIdx++) {
            double probability = probabilities[classIdx];
            originalRow.put(PROBA_COL_PREFIX + this.classes[classIdx], probability);
            // Math.max (rather than a plain comparison) so that a NaN probability propagates.
            highest = Math.max(highest, probability);
        }

        double uncertainty = 1.0 - highest;
        originalRow.put(UNCERTAINTY_COL, uncertainty);
    }

    /**
     * Builds the result dictated by the outcome of the given candidate.
     *
     * @param candidate candidate whose {@code outcome} and {@code overrideName} are used
     * @param rawResult result being overridden
     * @return a new result with its override info set
     * @throws IllegalArgumentException if the outcome type is neither CATEGORY nor DECLINED,
     *                                  or if the outcome category is not one of the known classes
     */
    @Override
    ClassificationResult applyOverride(OverridesOutcomeComputer.OutcomeCandidate<RawObservation> candidate, ClassificationResult rawResult) {
        MLOverridesParamsBase.MLOverride.Outcome outcome = candidate.outcome;
        // Built up front, before the outcome type is inspected.
        ClassificationResult.RawResult rawResultInfo = new ClassificationResult.RawResult(rawResult, this.classes);
        switch (outcome.type) {
            case CATEGORY:
                return this.overrideWithCategory(candidate, outcome, rawResult, rawResultInfo);
            case DECLINED:
                return this.declinePrediction(candidate, rawResult, rawResultInfo);
            default:
                throw new IllegalArgumentException("Unsupported Outcome Type (" + outcome.type + "). Classification only supports category or declined override");
        }
    }

    /**
     * Produces a result predicting the outcome category, with probability 1.0 on that
     * class and 0.0 on every other class.
     */
    private ClassificationResult overrideWithCategory(OverridesOutcomeComputer.OutcomeCandidate<RawObservation> candidate, MLOverridesParamsBase.MLOverride.Outcome outcome, ClassificationResult rawResult, ClassificationResult.RawResult rawResultInfo) {
        Integer targetIndex = this.classToIndexMap.get(outcome.category);
        if (targetIndex == null) {
            throw new IllegalArgumentException("Predicted category ('" + outcome.category + "') cannot be found in available categories");
        }

        // A freshly allocated double[] is all zeros; only the target class is raised to 1.0.
        double[] oneHotProbas = new double[this.classes.length];
        oneHotProbas[targetIndex] = 1.0;

        ClassificationResult overridden = new ClassificationResult(outcome.category, oneHotProbas, rawResult.getPartition());
        // True when the raw prediction is not equal to the prediction of the new result.
        boolean predictionChanged = !((String) rawResult.getPrediction()).equals(overridden.getPrediction());
        overridden.setOverrideInfo(new OverrideInfo(candidate.overrideName, predictionChanged, rawResultInfo));
        return overridden;
    }

    /**
     * Produces an empty result for the partition of the raw result, tagged with a
     * "declined" override info.
     */
    private ClassificationResult declinePrediction(OverridesOutcomeComputer.OutcomeCandidate<RawObservation> candidate, ClassificationResult rawResult, ClassificationResult.RawResult rawResultInfo) {
        ClassificationResult declined = ClassificationResult.empty(rawResult.getPartition());
        declined.setOverrideInfo(OverrideInfo.declined(candidate.overrideName, rawResultInfo));
        return declined;
    }
}

