/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.util.MathUtils;
import java.util.Arrays;

/**
 * Decision tree whose nodes carry {@code double[]} labels, exposed through the
 * {@link ProbabilisticClassifier} interface.
 *
 * <p>The label array of the terminal node reached for an input vector is returned
 * (as a copy) by {@link #probabilities(Vector)}, and the predicted class is the
 * {@code MathUtils.argmax} of that array.
 */
public class DecisionTreeClassifier
extends DecisionTreeModel<double[]>
implements ProbabilisticClassifier {
    private static final long serialVersionUID = 0L;

    /**
     * Lazily computed number of classes; any value {@code <= 0} means "not computed yet".
     * Several threads may compute it concurrently, in which case each one stores the
     * value it derived from the tree.
     */
    private volatile int numClasses = 0;

    /**
     * Builds a classifier by forwarding both arguments to the {@code DecisionTreeModel}
     * constructor.
     *
     * @param nodes   node structure handed to the superclass
     * @param variant tree variant handed to the superclass
     */
    public DecisionTreeClassifier(DecisionTreeModel.Node<double[]> nodes, DecisionTreeModel.TreeVariant variant) {
        super(nodes, variant);
    }

    /**
     * Returns the length of the label array stored on the left-most leaf of the tree,
     * computing it on first use and caching it afterwards.
     *
     * <p>NOTE(review): this presumably relies on every leaf holding a label array of
     * the same length -- confirm against the code that builds the tree.
     *
     * @return the cached or freshly computed number of classes
     */
    @Override
    public int getNumClasses() {
        // Read the volatile field once; a positive value means it was already computed.
        int known = this.numClasses;
        if (known > 0) {
            return known;
        }

        // Follow left children from the root until a leaf is reached.
        DecisionTreeModel.Node leaf = this.getRoot();
        while (!leaf.isLeaf) {
            leaf = leaf.leftSon;
        }

        double[] leafLabel = (double[]) leaf.label;
        known = leafLabel.length;
        this.numClasses = known;
        return known;
    }

    /**
     * Reports the {@code expectsProcessedFeatureAsDouble} flag of this tree's variant.
     *
     * @return the variant's flag value
     */
    @Override
    public boolean expectsProcessedFeaturesAsDoubles() {
        return this.variant.expectsProcessedFeatureAsDouble;
    }

    /**
     * Predicts the class index for {@code v} and wraps it in a successful {@link Try}.
     *
     * @param v input vector
     * @return a successful {@code Try} holding the result of {@link #predictUnsafe(Vector)}
     */
    @Override
    public Try<Integer> predict(Vector v) {
        int predicted = this.predictUnsafe(v);
        return Try.success(predicted);
    }

    /**
     * Computes the predicted class as {@code MathUtils.argmax} of the per-class values
     * returned by {@link #probasUnsafe(Vector)}, without {@link Try} wrapping.
     *
     * @param v input vector
     * @return the value returned by {@code MathUtils.argmax}
     */
    int predictUnsafe(Vector v) {
        double[] probas = this.probasUnsafe(v);
        return MathUtils.argmax(probas);
    }

    /**
     * Returns a copy of the label array held by the terminal node reached for {@code v},
     * without {@link Try} wrapping.
     *
     * @param v input vector
     * @return a fresh array with the same contents as the terminal node's label
     */
    double[] probasUnsafe(Vector v) {
        double[] stored = (double[]) this.getTerminalNode(v).label;
        // Hand out a copy so callers cannot modify the array stored in the tree.
        return stored.clone();
    }

    /**
     * Returns the per-class values for {@code v} wrapped in a successful {@link Try}.
     *
     * @param v input vector
     * @return a successful {@code Try} holding the result of {@link #probasUnsafe(Vector)}
     */
    @Override
    public Try<double[]> probabilities(Vector v) {
        double[] probas = this.probasUnsafe(v);
        return Try.success(probas);
    }

    /**
     * Always returns a failed {@link Try}: this classifier provides no decision function.
     *
     * @param v input vector (ignored)
     * @return a failed {@code Try} carrying an explanatory message
     */
    @Override
    public Try<double[]> decisionFunction(Vector v) {
        return Try.failure("No decision function in Decision Tree Classifier");
    }
}

