/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports;

import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.scoring.exports.Columns;
import com.dataiku.dip.scoring.exports.SQLPreprocessing;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.SelectQueryBuilder;
import com.dataiku.scoring.models.DecisionTreeClassifier;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.ForestClassifier;
import com.dataiku.scoring.models.ForestRegressor;
import com.dataiku.scoring.models.GenericMLP;
import com.dataiku.scoring.models.GradientBoostingClassifier;
import com.dataiku.scoring.models.GradientBoostingRegressor;
import com.dataiku.scoring.models.LinearRegression;
import com.dataiku.scoring.models.LogisticRegression;
import com.dataiku.scoring.models.MLPClassifier;
import com.dataiku.scoring.models.MLPRegressor;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.models.Regressor;
import com.dataiku.scoring.pipelines.BinaryProbabilisticPipeline;
import com.dataiku.scoring.pipelines.ClassificationPipeline;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import com.dataiku.scoring.pipelines.RegressionPipeline;
import com.dataiku.scoring.util.MathUtils;
import java.util.Map;
import org.apache.log4j.Logger;

public class SQLPrediction {
    private static final ExpressionBuilder.ExpressionBuilderFactory EBF = new ExpressionBuilder.ExpressionBuilderFactory();
    static Logger logger = Logger.getLogger((String)"dataiku.scoring.sql");

    private static ExpressionBuilder getRegressorExpression(Regressor model, PreprocessingPipeline preprocessing, Double missingValue) {
        Map columnMapping = preprocessing.getColumnMapping();
        String[] columnNames = SQLPrediction.getColumnNames(columnMapping);
        if (model instanceof LinearRegression) {
            return SQLPrediction.linearRegressionQuery((LinearRegression)model, columnMapping);
        }
        if (model instanceof DecisionTreeRegressor) {
            return SQLPrediction.regressionNode2sql((DecisionTreeRegressor)model, columnNames, missingValue);
        }
        if (model instanceof ForestRegressor) {
            return SQLPrediction.regressionForestQuery((ForestRegressor)model, columnNames, missingValue);
        }
        if (model instanceof GradientBoostingRegressor) {
            return SQLPrediction.boostingRegressorQuery((GradientBoostingRegressor)model, columnNames, missingValue);
        }
        throw new IllegalArgumentException("Model type unsupported for SQL export : " + model.getClass().getSimpleName());
    }

    public static boolean canOutputMulticlassProbas(PreTrainPredictionModelingParams.Algorithm algo) {
        switch (algo) {
            case DECISION_TREE_CLASSIFICATION: 
            case MLLIB_DECISION_TREE: {
                return false;
            }
        }
        return true;
    }

    private static void restoreBackupColumns(SelectQueryBuilder sqb, Columns columns) {
        for (String s : columns) {
            if (!s.startsWith("__dku_in_")) continue;
            sqb.select(s, s.split("__dku_in_")[1]);
        }
    }

    private static void keepBackupColumns(SelectQueryBuilder sqb, Columns columns) {
        for (String s : columns) {
            if (!s.startsWith("__dku_in_")) continue;
            sqb.select(s);
        }
    }

    public static SelectQueryBuilder regression(RegressionPipeline pipeline, SQLPreprocessing.SelectWithSchema prep, Double missingValue) {
        Regressor model = pipeline.getModel();
        PreprocessingPipeline preprocessing = pipeline.getPreprocessing();
        Map columnMapping = preprocessing.getColumnMapping();
        if (model instanceof MLPRegressor) {
            return SQLPrediction.regressionNetQuery((MLPRegressor)model, columnMapping, prep);
        }
        SelectQueryBuilder sqb = new SelectQueryBuilder();
        sqb.select(SQLPrediction.getRegressorExpression(model, preprocessing, missingValue), "prediction");
        SQLPrediction.restoreBackupColumns(sqb, prep.columns);
        sqb.from(prep.query, "preprocessed_data");
        return sqb;
    }

    private static ExpressionBuilder linearRegressionQuery(LinearRegression model, Map<String, Integer> columnMapping) {
        double[] coefficients = model.getCoefficients();
        ExpressionBuilder eb = EBF.cst(model.getIntercept());
        for (Map.Entry<String, Integer> column : columnMapping.entrySet()) {
            eb = eb.plus(EBF.col(column.getKey()).time(EBF.cst(coefficients[column.getValue()])));
        }
        return eb;
    }

    private static ExpressionBuilder regressionForestQuery(ForestRegressor model, String[] columnNames, Double missingValue) {
        DecisionTreeRegressor[] trees = model.getTrees();
        ExpressionBuilder eb = EBF.expr();
        for (DecisionTreeRegressor tree : trees) {
            eb = eb.plus(SQLPrediction.regressionNode2sql(tree, columnNames, missingValue));
        }
        return eb.div(trees.length);
    }

    private static ExpressionBuilder boostingRegressorQuery(GradientBoostingRegressor model, String[] columnNames, Double missingValue) {
        ExpressionBuilder eb = EBF.expr();
        for (DecisionTreeRegressor tree : model.getTrees()) {
            eb = eb.plus(SQLPrediction.regressionNode2sql(tree, columnNames, missingValue));
        }
        return eb.time(model.getShrinkage()).plus(model.getBaseline());
    }

    private static String[] getColumnNames(Map<String, Integer> columnMapping) {
        String[] columnNames = new String[columnMapping.size()];
        for (Map.Entry<String, Integer> column : columnMapping.entrySet()) {
            columnNames[column.getValue().intValue()] = column.getKey();
        }
        return columnNames;
    }

    private static ExpressionBuilder regressionNode2sql(DecisionTreeRegressor tree, String[] columns, Double missingValue) {
        return new RegressionTreeToExpressionBuilder(tree, columns, missingValue).buildQuery();
    }

    public static SelectQueryBuilder classification(ClassificationPipeline pipeline, SQLDialect dialect, SQLPreprocessing.SelectWithSchema prep, Double missingValue) {
        assert (pipeline.getModel() instanceof ProbabilisticClassifier);
        ProbabilisticClassifier model = (ProbabilisticClassifier)pipeline.getModel();
        PreprocessingPipeline preprocessing = pipeline.getPreprocessing();
        Map columnMapping = preprocessing.getColumnMapping();
        String[] classes = pipeline.getClasses();
        Double threshold = 0.5;
        if (pipeline instanceof BinaryProbabilisticPipeline) {
            threshold = ((BinaryProbabilisticPipeline)pipeline).getThreshold();
        }
        if (model instanceof LogisticRegression) {
            if (!dialect.supportsGreatest()) {
                throw new UnsupportedOperationException("This SQL dialect does not support GREATEST()");
            }
            return SQLPrediction.logitProbaQuery(columnMapping, classes, prep, (LogisticRegression)model, threshold);
        }
        if (model instanceof DecisionTreeClassifier) {
            return SQLPrediction.decisionTreePredictionQuery(columnMapping, classes, threshold, prep, (DecisionTreeClassifier)model, missingValue);
        }
        if (model instanceof ForestClassifier) {
            if (model.getNumClasses() > 2) {
                throw new UnsupportedOperationException("Forest classification in SQL does not support multiclass");
            }
            return SQLPrediction.forestPredictionQuery(columnMapping, classes, threshold, prep, (ForestClassifier)model, missingValue);
        }
        if (model instanceof GradientBoostingClassifier) {
            if (model.getNumClasses() > 2) {
                throw new UnsupportedOperationException("GBT classification in SQL does not support multiclass");
            }
            return SQLPrediction.boostingBinaryClassifierQuery(columnMapping, classes, threshold, prep, (GradientBoostingClassifier)model, missingValue);
        }
        if (model instanceof MLPClassifier) {
            if (!dialect.supportsGreatest()) {
                throw new UnsupportedOperationException("This SQL dialect does not support GREATEST()");
            }
            return SQLPrediction.classificationNetQuery((MLPClassifier)model, classes, threshold, columnMapping, prep);
        }
        throw new IllegalArgumentException("Model kind " + model.getClass().getSimpleName() + " is not supported.");
    }

    private static SelectQueryBuilder logitProbaQuery(Map<String, Integer> columnMapping, String[] classes, SQLPreprocessing.SelectWithSchema prep, LogisticRegression logit, Double threshold) {
        int i;
        int numClasses = logit.getNumClasses();
        ExpressionBuilder classExpr = EBF.caseWhen(new Object[0]);
        SelectQueryBuilder result = new SelectQueryBuilder();
        SelectQueryBuilder probaQuery = new SelectQueryBuilder();
        double[] baseline = logit.getBaseline();
        double[][] coefficients = logit.getCoefficients();
        LogisticRegression.Policy policy = logit.getPolicy();
        SelectQueryBuilder expQuery = new SelectQueryBuilder();
        for (int c2 = 0; c2 < baseline.length; ++c2) {
            ExpressionBuilder t_c = EBF.cst(baseline[c2]);
            for (Map.Entry<String, Integer> column : columnMapping.entrySet()) {
                t_c = t_c.plus(EBF.col(column.getKey()).time(EBF.cst(coefficients[c2][column.getValue()])));
            }
            expQuery.select(switch (policy) {
                case LogisticRegression.Policy.ONE_VERSUS_ALL -> EBF.cst(1.0).div(EBF.cst(1.0).plus(EBF.cst(0.0).minus(t_c).exp()));
                case LogisticRegression.Policy.MULTINOMIAL -> EBF.expr().least(EBF.cst(200.0), t_c).exp();
                case LogisticRegression.Policy.MODIFIED_HUBER -> EBF.cst(0.5).time(EBF.cst(1.0).plus(EBF.expr().greatest(EBF.cst(-1.0), EBF.expr().least(EBF.cst(1.0), t_c))));
                default -> throw new IllegalArgumentException("unknown policy");
            }, "exp_t_" + c2);
        }
        SQLPrediction.keepBackupColumns(expQuery, prep.columns);
        expQuery.from(prep.query, "with_probas");
        SelectQueryBuilder sumQuery = new SelectQueryBuilder();
        if (baseline.length == 2 && policy == LogisticRegression.Policy.MODIFIED_HUBER) {
            sumQuery.select(EBF.cst(1.0), "norm");
            sumQuery.select(EBF.cst(1.0).minus(EBF.col("exp_t_1")), "exp_t_0");
            sumQuery.select(EBF.col("exp_t_1"));
        } else {
            ExpressionBuilder eb = EBF.expr();
            for (int c3 = 0; c3 < baseline.length; ++c3) {
                ExpressionBuilder exp = EBF.col("exp_t_" + c3);
                sumQuery.select(exp);
                eb = eb.plus(exp);
            }
            sumQuery.select(eb, "norm");
        }
        SQLPrediction.keepBackupColumns(sumQuery, prep.columns);
        sumQuery.from(expQuery, "exp_t");
        for (int c4 = 0; c4 < baseline.length; ++c4) {
            probaQuery.select(EBF.col("exp_t_" + c4).div(EBF.col("norm")), "proba_" + c4);
        }
        SQLPrediction.keepBackupColumns(probaQuery, prep.columns);
        probaQuery.from(sumQuery, "exp_with_sum");
        if (numClasses == 2) {
            classExpr = classExpr.caseWhen(EBF.col("proba_1").gt(threshold), EBF.cst(classes[1]));
            classExpr = classExpr.caseWhen(EBF.cst(classes[0]));
            result.select(classExpr, "prediction");
            result.select("proba_0", "proba_" + classes[0]);
            result.select("proba_1", "proba_" + classes[1]);
            result.from(probaQuery, "probas");
            SQLPrediction.restoreBackupColumns(result, prep.columns);
            return result;
        }
        ExpressionBuilder[] classColumns = new ExpressionBuilder[numClasses];
        for (int c5 = 0; c5 < numClasses; ++c5) {
            classColumns[c5] = EBF.col("proba_" + c5);
        }
        SelectQueryBuilder greatestQuery = new SelectQueryBuilder();
        for (i = 0; i < numClasses; ++i) {
            greatestQuery.select(EBF.col("proba_" + i));
        }
        SQLPrediction.keepBackupColumns(greatestQuery, prep.columns);
        greatestQuery.select(EBF.expr().greatest(classColumns), "proba__");
        greatestQuery.from(probaQuery, "probas");
        for (int c6 = 0; c6 < numClasses; ++c6) {
            classExpr = classExpr.caseWhen(classColumns[c6].eq(EBF.col("proba__")), EBF.cst(classes[c6]));
        }
        classExpr = classExpr.caseWhen(EBF.cst(null));
        result.select(classExpr, "prediction");
        for (i = 0; i < numClasses; ++i) {
            result.select(EBF.col("proba_" + i), "proba_" + classes[i]);
        }
        SQLPrediction.restoreBackupColumns(result, prep.columns);
        result.from(greatestQuery, "highest");
        return result;
    }

    private static SelectQueryBuilder decisionTreePredictionQuery(Map<String, Integer> columnMapping, String[] classes, Double threshold, SQLPreprocessing.SelectWithSchema prep, DecisionTreeClassifier dtc, Double missingValue) {
        String[] columnNames = SQLPrediction.getColumnNames(columnMapping);
        SelectQueryBuilder predQuery = new SelectQueryBuilder();
        if (dtc.getNumClasses() == 2) {
            SelectQueryBuilder probaQuery = new SelectQueryBuilder();
            ExpressionBuilder treeExpr = SQLPrediction.classificationNode2sql(dtc, columnNames, classes, true, missingValue);
            probaQuery.select(treeExpr, "proba_1");
            SQLPrediction.keepBackupColumns(probaQuery, prep.columns);
            probaQuery.from(prep.query, "with_probas");
            ExpressionBuilder proba_1 = EBF.col("proba_1");
            predQuery.select(EBF.cst(1.0).minus(proba_1), "proba_" + classes[0]);
            predQuery.select(proba_1, "proba_" + classes[1]);
            predQuery.select(EBF.caseWhen(new Object[0]).caseWhen(proba_1.gt(threshold), classes[1]).caseWhen(classes[0]), "prediction");
            SQLPrediction.restoreBackupColumns(predQuery, prep.columns);
            predQuery.from(probaQuery, "probas");
        } else {
            predQuery.select(SQLPrediction.classificationNode2sql(dtc, columnNames, classes, false, missingValue), "prediction");
            SQLPrediction.restoreBackupColumns(predQuery, prep.columns);
            predQuery.from(prep.query, "with_predict");
        }
        return predQuery;
    }

    private static SelectQueryBuilder forestPredictionQuery(Map<String, Integer> columnMapping, String[] classes, double threshold, SQLPreprocessing.SelectWithSchema prep, ForestClassifier fc, Double missingValue) {
        assert (fc.getNumClasses() == 2);
        String[] columnNames = SQLPrediction.getColumnNames(columnMapping);
        SelectQueryBuilder predQuery = new SelectQueryBuilder();
        ExpressionBuilder res = EBF.expr();
        boolean useProbas = fc.getPolicy() == ForestClassifier.EnsemblingPolicy.AVERAGE;
        for (DecisionTreeClassifier dtc : fc.getTrees()) {
            res = res.plus(SQLPrediction.classificationNode2sql(dtc, columnNames, classes, useProbas, missingValue));
        }
        if (!useProbas) {
            predQuery.select(EBF.caseWhen(new Object[0]).caseWhen(res.gt(fc.getTrees().length / 2), classes[1]).caseWhen(classes[0]), "prediction");
            SQLPrediction.restoreBackupColumns(predQuery, prep.columns);
            predQuery.from(prep.query, "with_predict");
        } else {
            SelectQueryBuilder probaQuery = new SelectQueryBuilder();
            probaQuery.select(res.div(EBF.cst(fc.getTrees().length)), "proba_1");
            SQLPrediction.keepBackupColumns(probaQuery, prep.columns);
            probaQuery.from(prep.query, "with_probas");
            ExpressionBuilder proba_1 = EBF.col("proba_1");
            predQuery.select(EBF.cst(1.0).minus(proba_1), "proba_" + classes[0]);
            predQuery.select(proba_1, "proba_" + classes[1]);
            predQuery.select(EBF.caseWhen(new Object[0]).caseWhen(proba_1.gt(threshold), classes[1]).caseWhen(classes[0]), "prediction");
            SQLPrediction.restoreBackupColumns(predQuery, prep.columns);
            predQuery.from(probaQuery, "probas");
        }
        return predQuery;
    }

    private static SelectQueryBuilder boostingBinaryClassifierQuery(Map<String, Integer> columnMapping, String[] classes, double threshold, SQLPreprocessing.SelectWithSchema prep, GradientBoostingClassifier model, Double missingValue) {
        assert (model.getNumClasses() == 2);
        ExpressionBuilder res = EBF.cst(0.0);
        String[] columnNames = SQLPrediction.getColumnNames(columnMapping);
        for (DecisionTreeRegressor[] tree : model.getTrees()) {
            res = res.plus(SQLPrediction.regressionNode2sql(tree[0], columnNames, missingValue));
        }
        res = res.time(model.getShrinkage()).plus(model.getBaseline()[0]);
        res = EBF.cst(1.0).div(EBF.cst(1.0).plus(EBF.cst(0.0).minus(res).exp()));
        SelectQueryBuilder probaQuery = new SelectQueryBuilder();
        probaQuery.select(res, "proba_1");
        SQLPrediction.keepBackupColumns(probaQuery, prep.columns);
        probaQuery.from(prep.query, "with_probas");
        ExpressionBuilder proba_1 = EBF.col("proba_1");
        SelectQueryBuilder predQuery = new SelectQueryBuilder();
        predQuery.select(EBF.cst(1.0).minus(proba_1), "proba_" + classes[0]);
        predQuery.select(proba_1, "proba_" + classes[1]);
        predQuery.select(EBF.caseWhen(new Object[0]).caseWhen(proba_1.gt(threshold), classes[1]).caseWhen(classes[0]), "prediction");
        SQLPrediction.restoreBackupColumns(predQuery, prep.columns);
        predQuery.from(probaQuery, "probas");
        return predQuery;
    }

    private static ExpressionBuilder classificationNode2sql(DecisionTreeClassifier tree, String[] columns, String[] classes, boolean proba1, Double missingValue) {
        return new ClassificationTreeToExpressionBuilder(tree, columns, classes, proba1, missingValue).buildQuery();
    }

    private static ExpressionBuilder applyActivation(GenericMLP.Activation activation, ExpressionBuilder dotExpr) {
        switch (activation) {
            case LOGISTIC: {
                return EBF.cst(1.0).div(EBF.cst(1.0).plus(EBF.cst(0.0).minus(dotExpr).exp()));
            }
            case TANH: {
                ExpressionBuilder exp = EBF.cst(2.0).time(dotExpr).exp();
                return exp.minus(EBF.cst(1.0)).div(exp.plus(EBF.cst(1.0)));
            }
            case RELU: {
                return EBF.caseWhen(dotExpr.gt(0.0), dotExpr, EBF.cst(0.0));
            }
            case IDENTITY: {
                return dotExpr;
            }
        }
        throw new IllegalArgumentException("unknown activation");
    }

    private static SQLPreprocessing.SelectWithSchema genericNetQuery(GenericMLP mlp, Map<String, Integer> columnMapping, SQLPreprocessing.SelectWithSchema prepped) {
        String[] columns = SQLPrediction.getColumnNames(columnMapping);
        GenericMLP.Activation activation = mlp.getActivation();
        double[][] biases = mlp.getBiases();
        double[][][] weights = mlp.getWeights();
        Columns toKeep = prepped.columns;
        SelectQueryBuilder prep = prepped.query;
        SelectQueryBuilder afterInput = new SelectQueryBuilder();
        for (int i = 0; i < biases[0].length; ++i) {
            ExpressionBuilder dotExpr = EBF.cst(biases[0][i]);
            for (int j = 0; j < columns.length; ++j) {
                dotExpr = dotExpr.plus(EBF.cst(weights[0][i][j]).time(EBF.col(columns[j])));
            }
            afterInput.select(SQLPrediction.applyActivation(activation, dotExpr), "__dku_output_" + i);
        }
        afterInput.from(prep, "after_input");
        SQLPrediction.keepBackupColumns(afterInput, toKeep);
        SelectQueryBuilder afterHidden = afterInput;
        for (int i = 1; i < biases.length; ++i) {
            SelectQueryBuilder newQuery = new SelectQueryBuilder();
            for (int j = 0; j < biases[i].length; ++j) {
                ExpressionBuilder dotExpr = EBF.cst(biases[i][j]);
                for (int k = 0; k < weights[i][j].length; ++k) {
                    dotExpr = dotExpr.plus(EBF.cst(weights[i][j][k]).time(EBF.col("__dku_output_" + k)));
                }
                if (i != biases.length - 1) {
                    dotExpr = SQLPrediction.applyActivation(activation, dotExpr);
                }
                newQuery.select(dotExpr, "__dku_output_" + j);
            }
            newQuery.from(afterHidden, "hidden_" + i);
            afterHidden = newQuery;
        }
        SQLPrediction.keepBackupColumns(afterHidden, prepped.columns);
        return new SQLPreprocessing.SelectWithSchema(afterHidden, toKeep);
    }

    private static SelectQueryBuilder regressionNetQuery(MLPRegressor mlp, Map<String, Integer> columnMapping, SQLPreprocessing.SelectWithSchema prepped) {
        SQLPreprocessing.SelectWithSchema gen = SQLPrediction.genericNetQuery((GenericMLP)mlp, columnMapping, prepped);
        SelectQueryBuilder res = new SelectQueryBuilder();
        res.select("__dku_output_0", "prediction");
        SQLPrediction.restoreBackupColumns(res, gen.columns);
        res.from(gen.query, "result");
        return res;
    }

    private static SelectQueryBuilder classificationNetQuery(MLPClassifier mlp, String[] classes, double threshold, Map<String, Integer> columnMapping, SQLPreprocessing.SelectWithSchema prepped) {
        SQLPreprocessing.SelectWithSchema gen = SQLPrediction.genericNetQuery((GenericMLP)mlp, columnMapping, prepped);
        if (classes.length == 2) {
            SelectQueryBuilder res = new SelectQueryBuilder();
            ExpressionBuilder positiveProba = SQLPrediction.applyActivation(GenericMLP.Activation.LOGISTIC, EBF.col("__dku_output_0"));
            ExpressionBuilder negativeProba = EBF.cst(1.0).minus(positiveProba);
            double logThreshold = Math.log(threshold / (1.0 - threshold));
            ExpressionBuilder prediction = EBF.caseWhen(EBF.col("__dku_output_0").gt(logThreshold), classes[1], classes[0]);
            res.select(prediction, "prediction");
            res.select(negativeProba, "proba_" + classes[0]);
            res.select(positiveProba, "proba_" + classes[1]);
            SQLPrediction.restoreBackupColumns(res, gen.columns);
            res.from(gen.query, "result");
            return res;
        }
        SelectQueryBuilder withExp = new SelectQueryBuilder();
        for (int i = 0; i < classes.length; ++i) {
            withExp.select(EBF.col("__dku_output_" + i).exp(), "__exp_" + i);
        }
        SQLPrediction.keepBackupColumns(withExp, gen.columns);
        withExp.from(gen.query, "with_exp");
        SelectQueryBuilder withNormAndGreatest = new SelectQueryBuilder();
        ExpressionBuilder norm = EBF.col("__exp_0");
        withNormAndGreatest.select(EBF.col("__exp_0"), "__exp_0");
        for (int i = 1; i < classes.length; ++i) {
            norm = norm.plus(EBF.col("__exp_" + i));
            withNormAndGreatest.select(EBF.col("__exp_" + i), "__exp_" + i);
        }
        Object[] expColumns = new Object[classes.length];
        for (int i = 0; i < classes.length; ++i) {
            expColumns[i] = EBF.col("__exp_" + i);
        }
        withNormAndGreatest.select(EBF.expr().greatest(expColumns), "__greatest");
        withNormAndGreatest.select(norm, "__norm");
        SQLPrediction.keepBackupColumns(withNormAndGreatest, gen.columns);
        withNormAndGreatest.from(withExp, "with_norm_greatest");
        SelectQueryBuilder withProbas = new SelectQueryBuilder();
        for (int i = 0; i < classes.length; ++i) {
            withProbas.select(EBF.col("__exp_" + i).div(EBF.col("__norm")), "proba_" + classes[i]);
        }
        ExpressionBuilder prediction = EBF.caseWhen(new Object[0]);
        for (int i = 0; i < classes.length - 1; ++i) {
            prediction = prediction.caseWhen(EBF.col("__exp_" + i).eq(EBF.col("__greatest")), classes[i]);
        }
        prediction = prediction.caseWhen(classes[classes.length - 1]);
        withProbas.select(prediction, "prediction");
        SQLPrediction.restoreBackupColumns(withProbas, gen.columns);
        withProbas.from(withNormAndGreatest, "result");
        return withProbas;
    }

    static class RegressionTreeToExpressionBuilder
    extends TreeToExpressionBuilder<Double> {
        RegressionTreeToExpressionBuilder(DecisionTreeRegressor tree, String[] columns, Double missingValue) {
            super(tree, columns, missingValue);
        }

        @Override
        protected ExpressionBuilder getLeafNodeExpressionBuilder(DecisionTreeModel.Node<Double> node) {
            return EBF.cst(node.label);
        }
    }

    static class ClassificationTreeToExpressionBuilder
    extends TreeToExpressionBuilder<double[]> {
        private final String[] classes;
        private final boolean proba_1;

        ClassificationTreeToExpressionBuilder(DecisionTreeClassifier tree, String[] columns, String[] classes, boolean proba1, Double missingValue) {
            super(tree, columns, missingValue);
            this.classes = classes;
            this.proba_1 = proba1;
        }

        @Override
        protected ExpressionBuilder getLeafNodeExpressionBuilder(DecisionTreeModel.Node<double[]> node) {
            if (node.label == null) {
                throw new IllegalArgumentException("Cannot build SQL query on leaf without label");
            }
            return EBF.cst(this.proba_1 ? Double.valueOf(((double[])node.label)[1]) : this.classes[MathUtils.argmax((double[])((double[])node.label))]);
        }
    }

    static abstract class TreeToExpressionBuilder<T> {
        private final DecisionTreeModel<T> tree;
        private final String[] columns;
        private final Double missingValue;

        TreeToExpressionBuilder(DecisionTreeModel<T> tree, String[] columns, Double missingValue) {
            this.tree = tree;
            this.columns = columns;
            this.missingValue = missingValue;
        }

        public ExpressionBuilder buildQuery() {
            return this.buildQuery(this.tree.getRoot());
        }

        private ExpressionBuilder buildQuery(DecisionTreeModel.Node<T> node) {
            ExpressionBuilder condition;
            if (node.isLeaf) {
                return this.getLeafNodeExpressionBuilder(node);
            }
            ExpressionBuilder variable = this.tree.variant.expectsProcessedFeatureAsDouble ? EBF.col(this.columns[node.variable]) : EBF.col(this.columns[node.variable]).castToFloat();
            ExpressionBuilder threshold = this.tree.variant.expectsThresholdAsDouble ? EBF.cst(node.threshold) : EBF.cst(node.threshold).castToFloat();
            ExpressionBuilder thresholdComparison = this.tree.variant.expectsStrictComparison ? variable.lt(threshold) : variable.lte(threshold);
            if (node.missingGoesLeft == null) {
                condition = thresholdComparison;
            } else if (node.missingGoesLeft.booleanValue()) {
                condition = this.missingValue == null ? EBF.col(this.columns[node.variable]).isnull() : variable.eq(this.missingValue);
                condition = condition.or(thresholdComparison);
            } else {
                condition = this.missingValue == null ? thresholdComparison : variable.ne(this.missingValue).and(thresholdComparison);
            }
            ExpressionBuilder leftSubQuery = this.buildQuery(node.leftSon);
            ExpressionBuilder rightSubQuery = this.buildQuery(node.rightSon);
            return EBF.caseWhen(new Object[0]).caseWhen(condition, leftSubQuery).caseWhen(rightSubQuery);
        }

        protected abstract ExpressionBuilder getLeafNodeExpressionBuilder(DecisionTreeModel.Node<T> var1);
    }
}

