/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.pipelines;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.GradientBoostingClassifier;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.pipelines.AbstractCalibrator;
import com.dataiku.scoring.pipelines.AbstractProbabilisticClassificationPipeline;
import com.dataiku.scoring.pipelines.ClassificationResult;
import com.dataiku.scoring.pipelines.MulticlassProbabilisticPipeline;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import com.dataiku.scoring.pipelines.overrides.OverridesLayerBase;
import com.dataiku.scoring.util.MathUtils;
import com.dataiku.scoring.util.RawObservation;

/**
 * Multiclass classification pipeline: preprocesses a raw observation, scores it with a
 * {@link ProbabilisticClassifier}, calibrates those scores with an {@link AbstractCalibrator},
 * and predicts the class with the highest calibrated probability.
 */
public class MulticlassProbabilisticPipelineImpl
extends AbstractProbabilisticClassificationPipeline
implements MulticlassProbabilisticPipeline {
    private static final long serialVersionUID = 0L;
    /** Calibrator applied to the model's probabilities or decision-function values. */
    protected final AbstractCalibrator calibrator;
    /** Partition identifier attached to every {@link ClassificationResult} this pipeline produces. */
    protected final String partition;

    /**
     * Creates the pipeline.
     *
     * @param preprocessing  pipeline turning a {@link RawObservation} into a feature vector
     * @param model          classifier producing per-class probabilities or decision values
     * @param classes        class labels, indexed consistently with the model's output vector
     * @param calibrator     calibrator applied to the model output; must not be {@code null}
     *                       (it is dereferenced here)
     * @param partition      partition identifier copied into each result
     * @param overridesLayer overrides layer forwarded to the parent pipeline
     */
    public MulticlassProbabilisticPipelineImpl(PreprocessingPipeline preprocessing, ProbabilisticClassifier model, String[] classes, AbstractCalibrator calibrator, String partition, OverridesLayerBase<ClassificationResult> overridesLayer) {
        super(preprocessing, model, classes, overridesLayer);
        this.calibrator = calibrator;
        // Forward the gradient-boosting model's expects32BitFloat flag to the calibrator;
        // for any other model type the flag is false.
        this.calibrator.expects32BitFloat = model instanceof GradientBoostingClassifier && ((GradientBoostingClassifier)model).expects32BitFloat;
        this.partition = partition;
    }

    /**
     * Preprocesses {@code r}, scores it with the model and calibrates the scores.
     *
     * @param r raw observation to score
     * @return calibrated per-class probabilities, or a failure if preprocessing or scoring fails,
     *         or if the model output does not have one entry per class
     */
    private Try<double[]> getProbabilities(RawObservation r) {
        Try<Vector> v = this.preprocessing.process(r);
        if (v.isError()) {
            return Try.failure(v.getMessage());
        }
        Try<double[]> scores = this.rawScores(v.get());
        // Check the scoring result itself (previously the already-checked preprocessing result
        // was re-tested, so scoring failures were silently unwrapped).
        if (scores.isError()) {
            return Try.failure(scores.getMessage());
        }
        double[] X = scores.get();
        if (X.length != this.classes.length) {
            return Try.failure("Probability vector had length different from number of classes.");
        }
        double[] calibratedProbas = this.calibrator.getCalibratedProbabilities(X);
        return Try.success(calibratedProbas);
    }

    /**
     * Selects the model output the calibrator expects.
     *
     * @param features preprocessed feature vector
     * @return model probabilities when the calibrator works from probabilities, otherwise the
     *         model's decision-function values
     */
    private Try<double[]> rawScores(Vector features) {
        ProbabilisticClassifier classifier = (ProbabilisticClassifier)this.model;
        return this.calibrator.isFromProba()
                ? classifier.probabilities(features)
                : classifier.decisionFunction(features);
    }

    /** Returns the index of the highest calibrated probability. */
    private int predictFromProba(double[] probas) {
        return MathUtils.argmax(probas);
    }

    /**
     * Scores {@code r} and returns the predicted class, the calibrated probabilities and the
     * partition, after applying the parent pipeline's post-prediction step.
     *
     * @param r raw observation to classify
     * @return the classification result, or a failure if scoring fails or the predicted index
     *         falls outside the class array
     */
    @Override
    public Try<ClassificationResult> getPredictionResults(RawObservation r) {
        this.checkInitialized();
        // Taken before scoring so post-prediction sees the observation as it was received.
        RawObservation originObservation = this.copyRawObservationForPostPredictIfNeededOrNull(r);
        Try<double[]> v = this.getProbabilities(r);
        if (v.isError()) {
            return Try.failure(v.getMessage());
        }
        double[] calibratedProbas = v.get();
        int predIndex = this.predictFromProba(calibratedProbas);
        if (predIndex >= this.classes.length || predIndex < 0) {
            return Try.failure("Classification yielded an invalid (null or out of bounds) index.");
        }
        return this.postPredict(originObservation, new ClassificationResult(this.classes[predIndex], calibratedProbas, this.partition));
    }
}

