/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.pipelines;

import com.dataiku.scoring.pipelines.AbstractCalibrator;
import com.dataiku.scoring.util.MathUtils;
import java.util.Arrays;

/**
 * Calibrator that maps uncalibrated scores to probabilities with a per-class sigmoid of the form
 * {@code sigmoid(-(a * x + b))}.
 *
 * <p>With one or two coefficient pairs the model is treated as binary: only the first pair is
 * used, it is applied to {@code input[1]}, and class 0 receives the complement. With more than two
 * pairs each class is calibrated independently and the results are normalized to sum to one.
 *
 * <p>NOTE(review): {@code MathUtils.sigmoid} is assumed to be the logistic function
 * {@code 1 / (1 + exp(-t))}, which would make this Platt-style calibration as in scikit-learn's
 * sigmoid calibration — confirm against the training side.
 */
public class SigmoidCalibrator
extends AbstractCalibrator {
    /** Per-class slope coefficients (a single entry in the binary case). */
    private final double[] aArray;
    /** Per-class intercept coefficients; always the same length as {@link #aArray}. */
    private final double[] bArray;
    /** Number of output probabilities: {@code aArray.length}, but never less than 2. */
    private final int nClasses;

    /**
     * Creates a sigmoid calibrator from fitted coefficients.
     *
     * @param aArray slope coefficients, one per calibrated class (a single entry for binary)
     * @param bArray intercept coefficients, same length as {@code aArray}
     * @param fromProba forwarded unchanged to the {@link AbstractCalibrator} constructor
     * @throws IllegalArgumentException if the arrays have different lengths or are empty
     */
    public SigmoidCalibrator(double[] aArray, double[] bArray, boolean fromProba) {
        super(fromProba);
        if (aArray.length != bArray.length) {
            throw new IllegalArgumentException("a and b arrays have inconsistent lengths (a: " + aArray.length + ", b: " + bArray.length + ")");
        }
        if (aArray.length == 0) {
            // Fail fast: scoring would otherwise throw ArrayIndexOutOfBoundsException on aArray[0].
            throw new IllegalArgumentException("a and b arrays must not be empty");
        }
        // Defensive copies so that later mutation by the caller cannot alter the calibration.
        this.aArray = aArray.clone();
        this.bArray = bArray.clone();
        this.nClasses = Math.max(aArray.length, 2);
    }

    /**
     * Converts uncalibrated per-class scores into calibrated probabilities.
     *
     * @param input uncalibrated scores; in the binary case only {@code input[1]} is read, otherwise
     *     the first {@code nClasses} entries are read
     * @return a new array of {@code nClasses} calibrated probabilities summing to one
     */
    @Override
    public double[] getCalibratedProbabilities(double[] input) {
        double[] calibratedProbas = new double[this.nClasses];
        if (this.nClasses == 2) {
            calibratedProbas[1] = this.applySigmoid(this.aArray[0], this.bArray[0], input[1]);
            calibratedProbas[0] = 1.0 - calibratedProbas[1];
            return calibratedProbas;
        }

        double norm = 0.0;
        for (int classId = 0; classId < this.nClasses; ++classId) {
            calibratedProbas[classId] = this.applySigmoid(this.aArray[classId], this.bArray[classId], input[classId]);
            norm += calibratedProbas[classId];
        }

        if (norm == 0.0) {
            // Every per-class sigmoid underflowed to 0, so dividing would turn every entry into NaN.
            // Fall back to a uniform distribution, as scikit-learn's calibrated classifier does.
            // NOTE(review): confirm this matches the Python scoring path.
            Arrays.fill(calibratedProbas, 1.0 / this.nClasses);
            return calibratedProbas;
        }

        for (int classId = 0; classId < this.nClasses; ++classId) {
            calibratedProbas[classId] /= norm;
        }
        return calibratedProbas;
    }

    /**
     * Evaluates {@code sigmoid(-(a * x + b))}. When {@code expects32BitFloat} is set, the affine
     * term is computed in float arithmetic and the result is rounded to float; otherwise everything
     * is computed in double precision.
     */
    private double applySigmoid(double a, double b, double x) {
        if (this.expects32BitFloat) {
            return (double) ((float) MathUtils.sigmoid(-((float) a * (float) x + (float) b)));
        }
        return MathUtils.sigmoid(-(a * x + b));
    }
}

