/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.GenericMLP;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.util.MathUtils;

/**
 * Multi-layer perceptron classifier built on top of {@link GenericMLP}.
 *
 * <p>The number of classes is derived from the length of the last bias vector (the output
 * layer), with a floor of 2. When there are exactly two classes, the first raw output is
 * mapped to the positive-class probability through the logistic function. Otherwise the raw
 * outputs are normalised into class probabilities with a softmax.
 */
public class MLPClassifier extends GenericMLP implements ProbabilisticClassifier {
    private static final long serialVersionUID = 1L;

    /** Number of classes this model predicts; always at least 2. */
    private final int numClasses;

    /**
     * Creates a classifier from already-trained network parameters.
     *
     * @param activation activation function, forwarded unchanged to {@link GenericMLP}
     * @param biases per-layer bias vectors; the length of the last one gives the number of
     *     classes (a single output unit is treated as a binary problem)
     * @param weights per-layer weight matrices, forwarded unchanged to {@link GenericMLP}
     */
    public MLPClassifier(GenericMLP.Activation activation, double[][] biases, double[][][] weights) {
        super(activation, biases, weights);
        // A single output unit encodes a binary problem, hence the minimum of 2 classes.
        this.numClasses = Math.max(biases[biases.length - 1].length, 2);
    }

    /**
     * Computes class probabilities for {@code v}.
     *
     * <p>For two classes, {@code p[1]} is the logistic of the first raw output and
     * {@code p[0] = 1 - p[1]}. For more classes, the softmax of the raw outputs is returned.
     *
     * @param v input feature vector
     * @return a successful {@code Try} holding one probability per class, or the failed
     *     {@code Try} produced by {@link #decisionFunction(Vector)}
     */
    @Override
    public Try<double[]> probabilities(Vector v) {
        Try<double[]> decision = this.decisionFunction(v);
        if (!decision.isSuccess()) {
            // Propagate the failure rather than calling get() on a failed Try.
            return decision;
        }
        double[] raw = decision.get();
        if (this.numClasses == 2) {
            double[] p = new double[2];
            p[1] = GenericMLP.Activation.LOGISTIC.function.apply(raw[0]);
            p[0] = 1.0 - p[1];
            return Try.success(p);
        }
        return Try.success(MathUtils.softmax(raw));
    }

    /**
     * Returns the raw (pre-probability) network outputs for {@code v}.
     *
     * @param v input feature vector
     * @return a successful {@code Try} wrapping {@link GenericMLP#getRawOutputs}'s result
     */
    @Override
    public Try<double[]> decisionFunction(Vector v) {
        double[] dec = this.getRawOutputs(v);
        return Try.success(dec);
    }

    /**
     * Predicts the class index with the highest probability for {@code v}.
     *
     * @param v input feature vector
     * @return a successful {@code Try} with the argmax class index, or a failure with the
     *     message {@code "prediction failed"} if the probabilities could not be computed
     */
    @Override
    public Try<Integer> predict(Vector v) {
        Try<double[]> probas = this.probabilities(v);
        if (!probas.isSuccess()) {
            return Try.failure("prediction failed");
        }
        // Reuse the probabilities already computed instead of running a second forward pass.
        return Try.success(MathUtils.argmax(probas.get()));
    }

    /**
     * Returns the number of classes this model predicts.
     *
     * @return the class count, always at least 2
     */
    @Override
    public int getNumClasses() {
        return this.numClasses;
    }
}

