/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.linalg.Vector;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Objects;

/**
 * A binary decision tree that routes a feature {@link Vector} from its root to a leaf.
 *
 * <p>The tree structure is validated once, at construction time. Every internal (split) node must
 * have two non-null children and a non-NaN threshold, and every leaf must carry a non-null label.
 * NOTE(review): Java deserialization does not invoke this constructor, so deserialized instances
 * are not re-validated.
 *
 * @param <T> type of the label stored in each leaf
 */
public class DecisionTreeModel<T>
implements Serializable {
    private static final long serialVersionUID = 0L;
    private final Node<T> root;
    /** Numeric conventions used when comparing feature values to thresholds. */
    public final TreeVariant variant;

    /**
     * Creates a model from an already-built tree.
     *
     * @param root root node of the tree; must not be {@code null}
     * @param variant comparison conventions used when routing vectors. NOTE(review): {@code null}
     *     is accepted here, but scoring any tree that has at least one split node will then throw
     *     a {@link NullPointerException}.
     * @throws NullPointerException if {@code root} is {@code null}
     * @throws IllegalArgumentException if some node is malformed (split node without two children
     *     or with a NaN threshold, or leaf without a label)
     */
    public DecisionTreeModel(Node<T> root, TreeVariant variant) {
        Objects.requireNonNull(root, "root");
        if (!DecisionTreeModel.validateNode(root)) {
            throw new IllegalArgumentException("Nodes for decision tree model are incorrectly specified, tree is not binary or labels missing.");
        }
        this.root = root;
        this.variant = variant;
    }

    /**
     * Returns the root node of the tree.
     *
     * @return the root node, never {@code null}
     */
    public Node<T> getRoot() {
        return this.root;
    }

    /**
     * Checks that every node reachable from {@code node} is well formed. Leaves must carry a
     * non-null label. Split nodes must have two non-null children and a non-NaN threshold.
     *
     * <p>Uses an explicit stack instead of recursion, so very deep (e.g. highly unbalanced) trees
     * cannot overflow the call stack.
     *
     * @param node subtree root to check; must not be {@code null}
     * @return {@code true} if the whole subtree is well formed
     */
    private static <T> boolean validateNode(Node<T> node) {
        Deque<Node<T>> pending = new ArrayDeque<Node<T>>();
        pending.push(node);
        while (!pending.isEmpty()) {
            Node<T> current = pending.pop();
            if (current.isLeaf) {
                if (current.label == null) {
                    return false;
                }
                continue;
            }
            if (current.leftSon == null || current.rightSon == null || Double.isNaN(current.threshold)) {
                return false;
            }
            // Children are known non-null here; ArrayDeque rejects null elements.
            pending.push(current.rightSon);
            pending.push(current.leftSon);
        }
        return true;
    }

    /**
     * Routes {@code v} from the root down to a leaf.
     *
     * @param v feature vector to score
     * @return the leaf reached by {@code v}
     */
    Node<T> getTerminalNode(Vector v) {
        Node<T> current = this.root;
        while (!current.isLeaf) {
            current = this.goesLeft(current, v) ? current.leftSon : current.rightSon;
        }
        return current;
    }

    /**
     * Decides whether {@code v} takes the left branch at split node {@code node}.
     *
     * <p>When the node has an explicit missing-value direction, it is used if {@code v} reports
     * the feature as missing. For {@link TreeVariant#XGBOOST} only, it is also used if the feature
     * is not active in {@code v} (presumably an absent entry of a sparse vector; confirm against
     * {@code Vector}). Otherwise the feature value is compared with the threshold, after each is
     * rounded to {@code float} unless the variant asks for double precision. A NaN value that is
     * not routed as missing compares false and therefore goes right.
     *
     * @param node split node being evaluated
     * @param v feature vector
     * @return {@code true} to go to {@code node.leftSon}, {@code false} for {@code node.rightSon}
     */
    private boolean goesLeft(Node<T> node, Vector v) {
        double variable = v.get(node.variable);
        if (v.isMissing(node.variable) && node.missingGoesLeft != null) {
            return node.missingGoesLeft;
        }
        if (!v.isActive(node.variable) && node.missingGoesLeft != null && this.variant == TreeVariant.XGBOOST) {
            return node.missingGoesLeft;
        }
        // Widening float -> double is exact and order-preserving, so comparing the widened
        // values gives the same result as comparing at float precision.
        double feature = DecisionTreeModel.atPrecision(variable, this.variant.expectsProcessedFeatureAsDouble);
        double threshold = DecisionTreeModel.atPrecision(node.threshold, this.variant.expectsThresholdAsDouble);
        return this.variant.expectsStrictComparison ? feature < threshold : feature <= threshold;
    }

    /**
     * Returns {@code value} unchanged when {@code keepDouble} is set; otherwise returns it rounded
     * to the nearest {@code float} (widened back to {@code double}).
     */
    private static double atPrecision(double value, boolean keepDouble) {
        return keepDouble ? value : (double) (float) value;
    }

    /**
     * A node of a {@link DecisionTreeModel}. It is either a leaf carrying a label, or a split node
     * with a feature index, a threshold and two children. Instances are created through
     * {@link #leaf} and the {@code node} factories.
     *
     * @param <T> type of the leaf label
     */
    public static class Node<T>
    implements Serializable {
        private static final long serialVersionUID = 0L;
        /** {@code true} for leaves, {@code false} for split nodes. */
        public final boolean isLeaf;
        /** Leaf label; {@code null} for split nodes. */
        public final T label;
        /** Index of the feature tested by a split node; {@code -1} for leaves. */
        public final int variable;
        /** Split threshold; {@link Double#NaN} for leaves. */
        public final double threshold;
        /** Child taken when the split test succeeds; {@code null} for leaves. */
        public final Node<T> leftSon;
        /** Child taken when the split test fails; {@code null} for leaves. */
        public final Node<T> rightSon;
        /**
         * Branch for missing values ({@code true} = left), or {@code null} when the split has no
         * explicit missing-value direction.
         */
        public final Boolean missingGoesLeft;
        /** Caller-supplied identifier; not interpreted by this class. */
        public final long id;

        private Node(long id, int variable, double threshold, Node<T> leftSon, Node<T> rightSon, Boolean missingGoesLeft, boolean isLeaf, T label) {
            this.id = id;
            this.variable = variable;
            this.threshold = threshold;
            this.leftSon = leftSon;
            this.rightSon = rightSon;
            this.missingGoesLeft = missingGoesLeft;
            this.isLeaf = isLeaf;
            this.label = label;
        }

        /**
         * Creates a leaf.
         *
         * @param leafId identifier of the leaf
         * @param label value returned for vectors reaching this leaf; must be non-null for the
         *     tree to pass {@link DecisionTreeModel} validation
         * @return a new leaf node
         */
        public static <T> Node<T> leaf(long leafId, T label) {
            return new Node<T>(leafId, -1, Double.NaN, null, null, null, true, label);
        }

        /**
         * Creates a split node with an explicit missing-value direction.
         *
         * @param nodeId identifier of the node
         * @param feature index of the tested feature
         * @param threshold split threshold; must not be NaN for the tree to pass validation
         * @param leftSon child taken when the split test succeeds
         * @param rightSon child taken when the split test fails
         * @param missingGoesLeft branch for missing values, or {@code null} for none
         * @return a new split node
         */
        public static <T> Node<T> node(long nodeId, int feature, double threshold, Node<T> leftSon, Node<T> rightSon, Boolean missingGoesLeft) {
            return new Node<T>(nodeId, feature, threshold, leftSon, rightSon, missingGoesLeft, false, null);
        }

        /**
         * Creates a split node without a missing-value direction.
         *
         * @param nodeId identifier of the node
         * @param feature index of the tested feature
         * @param threshold split threshold; must not be NaN for the tree to pass validation
         * @param leftSon child taken when the split test succeeds
         * @param rightSon child taken when the split test fails
         * @return a new split node
         */
        public static <T> Node<T> node(long nodeId, int feature, double threshold, Node<T> leftSon, Node<T> rightSon) {
            return new Node<T>(nodeId, feature, threshold, leftSon, rightSon, null, false, null);
        }
    }

    /**
     * Numeric conventions used when comparing a feature value against a split threshold.
     *
     * <p>Each operand whose "as double" flag is {@code false} is rounded to {@code float} before the
     * comparison. A vector goes left when {@code feature < threshold} for strict variants, and when
     * {@code feature <= threshold} otherwise. NOTE(review): the per-library settings presumably
     * mirror each library's own scoring code; confirm against those implementations.
     */
    public enum TreeVariant {
        SKLEARN(false, true, false),
        XGBOOST(false, false, true),
        LIGHTGBM(true, true, false);

        /** Whether the feature value is compared at double precision (else rounded to float). */
        public final boolean expectsProcessedFeatureAsDouble;
        /** Whether the threshold is compared at double precision (else rounded to float). */
        public final boolean expectsThresholdAsDouble;
        /** Whether "goes left" uses {@code <} rather than {@code <=}. */
        public final boolean expectsStrictComparison;

        TreeVariant(boolean expectsProcessedFeatureAsDouble, boolean expectsThresholdAsDouble, boolean expectsStrictComparison) {
            this.expectsProcessedFeatureAsDouble = expectsProcessedFeatureAsDouble;
            this.expectsThresholdAsDouble = expectsThresholdAsDouble;
            this.expectsStrictComparison = expectsStrictComparison;
        }
    }
}

