/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.pipelines;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.GradientBoostingClassifier;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.pipelines.AbstractCalibrator;
import com.dataiku.scoring.pipelines.AbstractPipeline;
import com.dataiku.scoring.pipelines.AbstractProbabilisticClassificationPipeline;
import com.dataiku.scoring.pipelines.BinaryProbabilisticPipeline;
import com.dataiku.scoring.pipelines.ClassificationResult;
import com.dataiku.scoring.pipelines.IsotonicCalibrator;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import com.dataiku.scoring.pipelines.overrides.OverridesLayerBase;
import com.dataiku.scoring.util.RawObservation;
import java.util.Arrays;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Binary classification pipeline that preprocesses a raw observation, scores it with a
 * {@link ProbabilisticClassifier}, calibrates the resulting scores and finally turns the
 * calibrated probability of the second class into a predicted label using a threshold.
 *
 * <p>The pipeline also registers a {@link ProbaPercentileComputer} output column that ranks
 * the calibrated probability of the second class against a table of percentile boundaries.
 */
public class BinaryProbabilisticPipelineImpl
extends AbstractProbabilisticClassificationPipeline
implements BinaryProbabilisticPipeline {
    private static final long serialVersionUID = 0L;

    /** Lowest threshold kept when the compatibility clipping in the constructor applies. */
    private static final double COMPAT_THRESHOLD_EPS = 5.0E-7;
    /** Highest threshold kept when the compatibility clipping in the constructor applies. */
    private static final double COMPAT_THRESHOLD_MAX = 0.9999995;
    private static final Logger logger = Logger.getLogger("dku.scoring");

    /** Decision threshold: a calibrated probability strictly above it selects {@code classes[1]}. */
    private final double threshold;
    /** Calibrator applied to the raw model output (probabilities or decision function). */
    private final AbstractCalibrator calibrator;
    /** Partition value attached to every {@link ClassificationResult} produced here. */
    private final String partition;
    /**
     * Percentile boundaries used by {@link #getProbaPercentile(ClassificationResult)};
     * may be {@code null} or empty. NOTE(review): assumed sorted ascending, as required by
     * {@link Arrays#binarySearch(double[], double)} -- confirm against the producer.
     */
    private final double[] probaPercentiles;

    /**
     * Builds the pipeline and validates that both the class names and the model are binary.
     *
     * <p>Side effect: the {@code expects32BitFloat} flag of the supplied {@code calibrator} is
     * overwritten so that it is {@code true} only when {@code model} is a
     * {@link GradientBoostingClassifier} that itself expects 32-bit floats.
     *
     * @param preprocessing preprocessing applied to each raw observation
     * @param model the classifier; must report exactly 2 classes
     * @param classes class labels; must contain exactly 2 entries
     * @param threshold decision threshold on the calibrated probability of {@code classes[1]}
     * @param probaPercentiles percentile boundaries, or {@code null}/empty when unavailable
     * @param calibrator calibrator applied to the raw model output
     * @param partition partition value forwarded to every result
     * @param overridesLayer overrides layer forwarded to the parent pipeline
     * @throws IllegalArgumentException if {@code classes} or {@code model} is not binary
     */
    public BinaryProbabilisticPipelineImpl(PreprocessingPipeline preprocessing, ProbabilisticClassifier model, String[] classes, double threshold, double[] probaPercentiles, AbstractCalibrator calibrator, String partition, OverridesLayerBase<ClassificationResult> overridesLayer) {
        super(preprocessing, model, classes, overridesLayer);
        this.calibrator = calibrator;
        this.calibrator.expects32BitFloat = model instanceof GradientBoostingClassifier && ((GradientBoostingClassifier)model).expects32BitFloat;
        this.partition = partition;
        if (classes.length != 2) {
            throw new IllegalArgumentException("Failed to build binary probabilistic pipeline, found " + classes.length + " class names instead of 2.");
        }
        if (model.getNumClasses() != 2) {
            // Report the model's own class count: classes.length is known to be 2 at this point.
            throw new IllegalArgumentException("Failed to build binary probabilistic pipeline, model required " + model.getNumClasses() + " classes instead of 2.");
        }

        // Extreme thresholds are clipped only for an isotonic-calibrated gradient boosting model
        // whose first tree node is of the XGBOOST variant; every other configuration keeps the
        // threshold exactly as supplied.
        boolean outOfCompatRange = threshold < COMPAT_THRESHOLD_EPS || threshold > COMPAT_THRESHOLD_MAX;
        boolean clipThreshold = false;
        if (outOfCompatRange && calibrator instanceof IsotonicCalibrator && model instanceof GradientBoostingClassifier) {
            try {
                clipThreshold = ((GradientBoostingClassifier)model).getTrees()[0][0].variant == DecisionTreeModel.TreeVariant.XGBOOST;
            }
            catch (Exception e) {
                // Best effort: if the tree structure cannot be inspected, keep the threshold as is.
                logger.log(Level.WARNING, "Could not check first node of GBT model", e);
            }
        }
        if (clipThreshold) {
            this.threshold = threshold < COMPAT_THRESHOLD_EPS ? COMPAT_THRESHOLD_EPS : COMPAT_THRESHOLD_MAX;
        } else {
            this.threshold = threshold;
        }
        this.probaPercentiles = probaPercentiles;
        this.columnComputers.add(new ProbaPercentileComputer());
    }

    /**
     * Returns the decision threshold actually used by this pipeline, i.e. after the optional
     * clipping performed by the constructor.
     */
    @Override
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Maps a calibrated probability of the second class to a label.
     *
     * @param d calibrated probability of {@code classes[1]}
     * @return {@code classes[1]} when {@code d} is strictly greater than the threshold,
     *         {@code classes[0]} otherwise
     */
    private String predictFromProba(double d) {
        return d > this.threshold ? this.classes[1] : this.classes[0];
    }

    /**
     * Preprocesses the observation and returns the model's raw (uncalibrated) probabilities.
     *
     * @param r the raw observation to score
     * @return the two raw probabilities, or a failure carrying the message of the preprocessing
     *         or model error, or a failure when the model did not return exactly 2 values
     */
    protected Try<double[]> probabilityNotCalibrated(RawObservation r) {
        Try<Vector> v = this.preprocessing.process(r);
        if (v.isError()) {
            return Try.failure(v.getMessage());
        }
        Try<double[]> probas = ((ProbabilisticClassifier)this.getModel()).probabilities(v.get());
        if (probas.isError()) {
            return Try.failure(probas.getMessage());
        }
        double[] p = probas.get();
        if (p.length != 2) {
            return Try.failure("Had " + p.length + " classes instead of 2");
        }
        return Try.success(p);
    }

    /**
     * Preprocesses the observation and returns the model's raw (uncalibrated) decision
     * function value for the second class.
     *
     * @param r the raw observation to score
     * @return element 1 of the decision function output, or a failure carrying the message of
     *         the preprocessing or model error, or a failure when the model did not return
     *         exactly 2 values
     */
    protected Try<Double> decisionFunctionNotCalibrated(RawObservation r) {
        Try<Vector> v = this.preprocessing.process(r);
        if (v.isError()) {
            return Try.failure(v.getMessage());
        }
        Try<double[]> dec = ((ProbabilisticClassifier)this.getModel()).decisionFunction(v.get());
        if (dec.isError()) {
            return Try.failure(dec.getMessage());
        }
        double[] d = dec.get();
        if (d.length != 2) {
            return Try.failure("Had " + d.length + " classes instead of 2");
        }
        return Try.success(d[1]);
    }

    /**
     * Scores one observation end to end: raw model output, calibration, thresholding and
     * the parent pipeline's post-prediction step.
     *
     * <p>Depending on {@code calibrator.isFromProba()}, the calibrator is fed either the raw
     * probabilities or the pair {@code {0.0, decisionFunction}}.
     *
     * @param r the raw observation to score
     * @return the post-processed classification result, or a failure carrying the message of
     *         the first step that failed
     */
    @Override
    public Try<ClassificationResult> getPredictionResults(RawObservation r) {
        this.checkInitialized();
        // Taken before scoring, exactly as the parent helper decides (it may return null).
        RawObservation originObservation = this.copyRawObservationForPostPredictIfNeededOrNull(r);

        double[] calibratorInput;
        if (this.calibrator.isFromProba()) {
            Try<double[]> probas = this.probabilityNotCalibrated(r);
            if (probas.isError()) {
                return Try.failure(probas.getMessage());
            }
            calibratorInput = probas.get();
        } else {
            Try<Double> dec = this.decisionFunctionNotCalibrated(r);
            if (dec.isError()) {
                return Try.failure(dec.getMessage());
            }
            double d = dec.get();
            calibratorInput = new double[]{0.0, d};
        }

        double[] calibratedProbas = this.calibrator.getCalibratedProbabilities(calibratorInput);
        String prediction = this.predictFromProba(calibratedProbas[1]);
        return this.postPredict(originObservation, new ClassificationResult(prediction, calibratedProbas, this.partition));
    }

    /**
     * Ranks the probability of the second class against the percentile boundaries.
     *
     * <p>The returned value is one plus the number of boundaries that are less than or equal
     * to the probability.
     *
     * @param result a classification result whose probabilities array has an element at index 1
     * @return the percentile rank, or a failure when no percentile boundaries are available
     */
    @Override
    public Try<Short> getProbaPercentile(ClassificationResult result) {
        if (this.probaPercentiles == null || this.probaPercentiles.length == 0) {
            return Try.failure("No proba percentiles for this model");
        }
        double p = result.getProbabilities()[1];
        int found = Arrays.binarySearch(this.probaPercentiles, p);
        int rank;
        if (found >= 0) {
            // Exact hit: binarySearch may land on any of several equal boundaries, so walk
            // forward past every boundary <= p before converting to a 1-based rank.
            int next = found;
            while (next < this.probaPercentiles.length && this.probaPercentiles[next] <= p) {
                ++next;
            }
            rank = next + 1;
        } else {
            // Miss: binarySearch returned -(insertionPoint) - 1, so negating yields
            // insertionPoint + 1, the same 1-based rank as in the exact-hit branch.
            rank = -found;
        }
        return Try.success((short)rank);
    }

    /**
     * Output column {@value #COLUMN_NAME}: exposes
     * {@link BinaryProbabilisticPipelineImpl#getProbaPercentile(ClassificationResult)} of the
     * enclosing pipeline as a column value.
     */
    public class ProbaPercentileComputer
    extends AbstractPipeline.AbstractColumnComputer<Short, String, ClassificationResult> {
        public static final String COLUMN_NAME = "proba_percentile";

        public ProbaPercentileComputer() {
            super(COLUMN_NAME, Short.class);
        }

        /**
         * @param result the classification result to derive the column value from
         * @return the percentile rank, or empty when the result is declined or when the
         *         percentile cannot be computed
         */
        @Override
        public Optional<Short> getOutputValue(ClassificationResult result) {
            if (result.isDeclined()) {
                return Optional.empty();
            }
            Try<Short> probaPercentile = BinaryProbabilisticPipelineImpl.this.getProbaPercentile(result);
            if (!probaPercentile.isSuccess()) {
                return Optional.empty();
            }
            return Optional.of(probaPercentile.get());
        }
    }
}

