/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.linalg.Vector;
import java.io.Serializable;

/**
 * A dense, fully-connected feed-forward network (multi-layer perceptron) whose parameters are
 * supplied as plain arrays.
 *
 * <p>Layer {@code l} is described by {@code biases[l]} (one bias per unit of the layer) and
 * {@code weights[l]}, indexed as {@code weights[l][unit][input]}. The value of unit {@code j} of
 * layer {@code l} is {@code biases[l][j] + sum_k weights[l][j][k] * in[k]}, where {@code in} is
 * the input vector for layer 0 and the previous layer's (activated) output otherwise.
 *
 * <p>The configured {@link Activation} is applied to the output of every layer except the last;
 * the last layer's values are returned untransformed by {@link #getRawOutputs(Vector)}.
 * NOTE(review): presumably subclasses apply a task-specific output transformation on top of
 * those raw values — confirm against callers.
 *
 * <p>The arrays given to the constructor are stored, and returned by the getters, by reference;
 * they are never copied.
 */
public class GenericMLP {
    private final Activation activation;
    private final double[][] biases;
    private final double[][][] weights;

    /**
     * Creates a network from its hidden activation and per-layer parameters.
     *
     * @param activation activation applied to the output of every layer but the last
     * @param biases per-layer bias vectors; {@code biases.length} is the number of layers
     * @param weights per-layer weight matrices, indexed as {@code [layer][unit][input]}
     */
    public GenericMLP(Activation activation, double[][] biases, double[][][] weights) {
        this.activation = activation;
        this.biases = biases;
        this.weights = weights;
    }

    /** Replaces every element of {@code values} in place with its activated value. */
    private void applyActivation(double[] values) {
        for (int i = 0; i < values.length; ++i) {
            values[i] = this.activation.function.apply(values[i]);
        }
    }

    /**
     * Runs a forward pass over {@code v} and returns the last layer's values before any output
     * transformation.
     *
     * <p>Layer 0 only visits the indices returned by {@code v.activeIndices()}; indices that are
     * not reported there contribute nothing (they are treated as zero). Every layer except the
     * last is passed through the configured activation, including layer 0, which is the last
     * layer when the network has a single layer and is then left unactivated as well.
     *
     * @param v input features, addressed by index through {@code activeIndices()} and
     *     {@code get(int)}
     * @return a newly allocated array with one entry per unit of the last layer
     * @throws ArrayIndexOutOfBoundsException if {@code biases} is empty or the weight and bias
     *     dimensions are inconsistent with each other or with {@code v}'s indices
     */
    protected double[] getRawOutputs(Vector v) {
        final int lastLayer = this.biases.length - 1;

        // Layer 0: start from a copy of the biases and accumulate only the active inputs.
        double[] outputs = this.biases[0].clone();
        for (int i : v.activeIndices()) {
            final double x = v.get(i); // invariant across the inner loop
            for (int j = 0; j < outputs.length; ++j) {
                outputs[j] += this.weights[0][j][i] * x;
            }
        }
        // Same rule as the loop below: the last layer is never activated, even when it is
        // also the first one (single-layer network).
        if (lastLayer > 0) {
            this.applyActivation(outputs);
        }

        // Remaining layers: dense matrix-vector product plus bias.
        for (int layer = 1; layer <= lastLayer; ++layer) {
            final double[] layerBiases = this.biases[layer];
            final double[] next = new double[layerBiases.length];
            for (int j = 0; j < next.length; ++j) {
                final double[] row = this.weights[layer][j];
                double sum = layerBiases[j];
                for (int k = 0; k < row.length; ++k) {
                    sum += row[k] * outputs[k];
                }
                next[j] = sum;
            }
            if (layer != lastLayer) {
                this.applyActivation(next);
            }
            outputs = next;
        }
        return outputs;
    }

    /** Returns the activation applied to the output of every layer except the last. */
    public Activation getActivation() {
        return this.activation;
    }

    /** Returns the per-layer bias vectors; this is the internal array, not a copy. */
    public double[][] getBiases() {
        return this.biases;
    }

    /**
     * Returns the per-layer weight matrices, indexed as {@code [layer][unit][input]}; this is the
     * internal array, not a copy.
     */
    public double[][][] getWeights() {
        return this.weights;
    }

    /** Element-wise activation functions supported for the network's hidden layers. */
    public static enum Activation {
        /** Logistic sigmoid: {@code 1 / (1 + exp(-x))}. */
        LOGISTIC(new ActivationFunction(){
            private static final long serialVersionUID = 1L;

            @Override
            public double apply(double x) {
                return 1.0 / (1.0 + Math.exp(-x));
            }
        }),
        /** Hyperbolic tangent, as computed by {@link Math#tanh(double)}. */
        TANH(new ActivationFunction(){
            private static final long serialVersionUID = 1L;

            @Override
            public double apply(double x) {
                return Math.tanh(x);
            }
        }),
        /** Rectified linear unit: {@code x} when {@code x > 0}, otherwise {@code 0.0} (also for NaN). */
        RELU(new ActivationFunction(){
            private static final long serialVersionUID = 1L;

            @Override
            public double apply(double x) {
                return x > 0.0 ? x : 0.0;
            }
        }),
        /** Identity: returns its input unchanged. */
        IDENTITY(new ActivationFunction(){
            private static final long serialVersionUID = 1L;

            @Override
            public double apply(double x) {
                return x;
            }
        });

        /** The scalar function this constant applies to each value. */
        final ActivationFunction function;

        private Activation(ActivationFunction function) {
            this.function = function;
        }
    }

    /** A serializable scalar function {@code double -> double}. */
    protected static interface ActivationFunction
    extends Serializable {
        /**
         * Applies the function to one value.
         *
         * @param x the input value
         * @return the transformed value
         */
        public double apply(double x);
    }
}

