/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.util.MathUtils;
import java.util.Arrays;

/**
 * Linear multi-class classifier that scores a {@link Vector} against one coefficient row per class.
 *
 * <p>For class {@code k} the raw decision value is
 * {@code baseline[k] + sum_i coefficients[k][i] * v.get(i)}, summed over {@code v.activeIndices()}.
 * Decision values are converted to class probabilities according to the configured {@link Policy},
 * and {@link #predict(Vector)} returns the index of the most probable class.
 *
 * <p>The arrays passed to the constructor are stored by reference (not copied) and are returned
 * as-is by {@link #getBaseline()} and {@link #getCoefficients()}; mutating them after construction
 * changes the model.
 */
public class LogisticRegression
implements ProbabilisticClassifier {
    private static final long serialVersionUID = 0L;
    /** Failure message returned when a vector's size differs from the coefficient row length. */
    private static final String SIZE_MISMATCH = "Vector size and coefficient length do not match.";

    private final Policy policy;
    private final double[] baseline;
    private final double[][] coefficients;
    /** Expected {@code Vector.size()}; taken from the length of the first coefficient row. */
    private final int vectorSize;
    private final int nClasses;

    /**
     * Creates a model.
     *
     * @param policy how decision values are turned into probabilities
     * @param baseline per-class intercepts; must have one entry per coefficient row
     * @param coefficients per-class weight rows; {@code coefficients[k][i]} weighs feature {@code i}
     *     for class {@code k}
     * @throws IllegalArgumentException if {@code coefficients} is empty or its row count differs
     *     from {@code baseline.length}
     */
    public LogisticRegression(Policy policy, double[] baseline, double[][] coefficients) {
        if (coefficients.length == 0) {
            throw new IllegalArgumentException("At least one row of coefficients is required.");
        }
        if (baseline.length != coefficients.length) {
            // decisionFunction indexes coefficients[k] for every k < baseline.length, and
            // getNumClasses() reports coefficients.length: the two must agree.
            throw new IllegalArgumentException("Baseline length " + baseline.length
                + " does not match the number of coefficient rows " + coefficients.length + ".");
        }
        this.policy = policy;
        this.baseline = baseline;
        this.coefficients = coefficients;
        this.vectorSize = coefficients[0].length;
        this.nClasses = coefficients.length;
    }

    /** Multinomial policy: delegates to {@code MathUtils.softmax} over the decision values. */
    private double[] multinomialProbabilities(double[] dec) {
        return MathUtils.softmax(dec);
    }

    /**
     * One-versus-all policy: applies a logistic sigmoid to each decision value, then rescales the
     * results so they sum to 1.
     */
    private double[] oneVersusAllProbabilities(double[] dec) {
        int n = this.baseline.length;
        double[] p = new double[n];
        double norm = 0.0;
        for (int k = 0; k < n; ++k) {
            p[k] = 1.0 / (1.0 + Math.exp(-dec[k]));
            norm += p[k];
        }
        if (norm == 0.0) {
            // Every sigmoid underflowed to exactly 0: Math.exp(-d) overflows to +Infinity once
            // d < ~-709.78, making 1 / (1 + Infinity) == 0. Dividing by norm would give NaN for
            // every class. In this regime sigmoid(d) ~= exp(d), so the normalised result is the
            // softmax of the decision values, computed here with a max shift to stay finite.
            double max = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < n; ++k) {
                max = Math.max(max, dec[k]);
            }
            if (max == Double.NEGATIVE_INFINITY) {
                // No class carries any information; fall back to a uniform distribution.
                Arrays.fill(p, 1.0 / (double) n);
                return p;
            }
            for (int k = 0; k < n; ++k) {
                p[k] = Math.exp(dec[k] - max);
                norm += p[k];
            }
        }
        for (int k = 0; k < n; ++k) {
            p[k] /= norm;
        }
        return p;
    }

    /**
     * Modified-Huber policy: clips each decision value to [-1, 1] and maps it linearly onto
     * [0, 1], then rescales so the results sum to 1 (uniform if the total is below 1e-15).
     */
    private double[] modifiedHuberProbabilities(double[] dec) {
        int n = this.baseline.length;
        double[] p = new double[n];
        for (int k = 0; k < n; ++k) {
            p[k] = 0.5 * (1.0 + Math.min(1.0, Math.max(-1.0, dec[k])));
        }
        if (this.nClasses == 2) {
            // Binary case: only the second row's score is used; class 0 gets the complement.
            p[0] = 1.0 - p[1];
        }
        double norm = 0.0;
        for (int k = 0; k < n; ++k) {
            norm += p[k];
        }
        if (norm < 1.0E-15) {
            // All scores clipped to (almost) zero: no usable signal, return a uniform distribution.
            Arrays.fill(p, 1.0 / (double) n);
        } else {
            for (int k = 0; k < n; ++k) {
                p[k] /= norm;
            }
        }
        return p;
    }

    /**
     * Computes per-class probabilities for {@code v} according to this model's {@link Policy}.
     *
     * @return a success holding one probability per class, or a failure if the vector size does
     *     not match the model
     * @throws IllegalArgumentException if the policy is not one handled here
     */
    @Override
    public Try<double[]> probabilities(Vector v) {
        Try<double[]> decision = this.decisionFunction(v);
        if (decision.isError()) {
            return Try.failure(decision.getMessage());
        }
        double[] dec = decision.get();
        switch (this.policy) {
            case ONE_VERSUS_ALL:
                return Try.success(this.oneVersusAllProbabilities(dec));
            case MULTINOMIAL:
                return Try.success(this.multinomialProbabilities(dec));
            case MODIFIED_HUBER:
                return Try.success(this.modifiedHuberProbabilities(dec));
        }
        throw new IllegalArgumentException("Policy " + this.policy + " is unknown.");
    }

    /**
     * Computes the raw per-class decision values {@code baseline[k] + coefficients[k] . v},
     * iterating only over {@code v.activeIndices()}.
     *
     * @return a success holding a fresh array with one value per class, or a failure if
     *     {@code v.size()} differs from the coefficient row length
     */
    @Override
    public Try<double[]> decisionFunction(Vector v) {
        // Validate here too (not only in probabilities) so direct callers get a Try failure
        // rather than an ArrayIndexOutOfBoundsException on a mis-sized vector.
        if (v.size() != this.vectorSize) {
            return Try.failure(SIZE_MISMATCH);
        }
        double[] dec = this.baseline.clone();
        for (int i : v.activeIndices()) {
            double x = v.get(i);
            for (int k = 0; k < dec.length; ++k) {
                dec[k] += this.coefficients[k][i] * x;
            }
        }
        return Try.success(dec);
    }

    /**
     * Predicts the class index with the highest probability (via {@code MathUtils.argmax}).
     *
     * @return a success holding the predicted class index, or the failure from
     *     {@link #probabilities(Vector)}
     */
    @Override
    public Try<Integer> predict(Vector v) {
        Try<double[]> p = this.probabilities(v);
        if (p.isError()) {
            return Try.failure(p.getMessage());
        }
        return Try.success(MathUtils.argmax(p.get()));
    }

    /** Returns the number of classes, i.e. the number of coefficient rows. */
    @Override
    public int getNumClasses() {
        return this.nClasses;
    }

    /** Returns the per-class intercepts; this is the internal array, not a copy. */
    public double[] getBaseline() {
        return this.baseline;
    }

    /** Returns the per-class coefficient rows; this is the internal array, not a copy. */
    public double[][] getCoefficients() {
        return this.coefficients;
    }

    /** Returns the policy used to turn decision values into probabilities. */
    public Policy getPolicy() {
        return this.policy;
    }

    /** Strategy for converting raw decision values into class probabilities. */
    public static enum Policy {
        /** Independent logistic sigmoid per class, then normalised to sum to 1. */
        ONE_VERSUS_ALL,
        /** Softmax over all decision values (via {@code MathUtils.softmax}). */
        MULTINOMIAL,
        /** Clipped linear mapping of decision values onto [0, 1], then normalised. */
        MODIFIED_HUBER;

    }
}

