/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.builders;

import com.dataiku.scoring.builders.AlgorithmBuilding;
import com.dataiku.scoring.builders.BuildUtils;
import com.dataiku.scoring.builders.DecisionTreeModelBuilder;
import com.dataiku.scoring.models.DecisionTreeClassifier;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.ForestClassifier;
import com.dataiku.scoring.models.ForestRegressor;
import com.dataiku.scoring.models.GenericMLP;
import com.dataiku.scoring.models.GradientBoostingClassifier;
import com.dataiku.scoring.models.GradientBoostingRegressor;
import com.dataiku.scoring.models.LinearRegression;
import com.dataiku.scoring.models.LogisticRegression;
import com.dataiku.scoring.models.MLPClassifier;
import com.dataiku.scoring.models.MLPRegressor;
import com.dataiku.scoring.pipelines.AbstractCalibrator;
import com.dataiku.scoring.pipelines.DummyCalibrator;
import com.dataiku.scoring.pipelines.IsotonicCalibrator;
import com.dataiku.scoring.pipelines.SigmoidCalibrator;
import java.io.IOException;
import java.net.URL;

public abstract class AlgorithmBuilder<T> {
    public static final String MODEL_FILE = "dss_pipeline_model.gz";

    public abstract T importFrom(URL var1, AlgorithmBuilding.AlgorithmType var2, String var3) throws IOException;

    public T importFrom(URL resources, AlgorithmBuilding.AlgorithmType type) throws IOException {
        return this.importFrom(resources, type, MODEL_FILE);
    }

    private static <T> T parseModel(URL resources, Class<T> modelClass, String modelFile) throws IOException {
        return BuildUtils.parseGzippedURL(new URL(resources, modelFile), modelClass);
    }

    AbstractCalibrator parseCalibrator(URL resources) throws IOException {
        CalibratorData dat = BuildUtils.parseGzippedURL(new URL(resources, MODEL_FILE), CalibratorData.class);
        if (dat.calibrator == null || dat.calibrator.method == null) {
            return new DummyCalibrator();
        }
        return switch (dat.calibrator.method) {
            case CalibratorData.CalibrationMethod.SIGMOID -> new SigmoidCalibrator(dat.calibrator.a_array, dat.calibrator.b_array, dat.calibrator.from_proba);
            case CalibratorData.CalibrationMethod.ISOTONIC -> new IsotonicCalibrator(dat.calibrator.x_array, dat.calibrator.y_array, dat.calibrator.from_proba);
            default -> new DummyCalibrator();
        };
    }

    static class CalibratorData {
        CalibrationParams calibrator = new CalibrationParams();

        CalibratorData() {
        }

        class CalibrationParams {
            CalibrationMethod method;
            boolean from_proba;
            double[] a_array;
            double[] b_array;
            double[][] x_array;
            double[][] y_array;

            CalibrationParams() {
            }
        }

        static enum CalibrationMethod {
            SIGMOID,
            ISOTONIC;

        }
    }

    static class MLPBuilder
    extends AlgorithmBuilder<GenericMLP> {
        MLPBuilder() {
        }

        @Override
        public GenericMLP importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            MLPData data = AlgorithmBuilder.parseModel(resources, MLPData.class, modelFile);
            if (type == AlgorithmBuilding.AlgorithmType.REGRESSION) {
                return new MLPRegressor(data.activation, data.biases, data.weights);
            }
            return new MLPClassifier(data.activation, data.biases, data.weights);
        }
    }

    static class MLPData {
        GenericMLP.Activation activation;
        double[][] biases;
        double[][][] weights;

        MLPData() {
        }
    }

    static class GradientBoostingClassifierBuilder
    extends AlgorithmBuilder<GradientBoostingClassifier> {
        GradientBoostingClassifierBuilder() {
        }

        @Override
        public GradientBoostingClassifier importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            GradientBoostingClassifierData data = AlgorithmBuilder.parseModel(resources, GradientBoostingClassifierData.class, modelFile);
            return new GradientBoostingClassifier(data.baseline, data.shrinkage, data.buildTrees());
        }
    }

    static class GradientBoostingClassifierData {
        double[] baseline;
        double shrinkage;
        RegressionTreeData[][] trees;

        GradientBoostingClassifierData() {
        }

        DecisionTreeRegressor[][] buildTrees() {
            DecisionTreeRegressor[][] res = new DecisionTreeRegressor[this.trees.length][this.trees[0].length];
            for (int i = 0; i < this.trees.length; ++i) {
                for (int j = 0; j < this.trees[i].length; ++j) {
                    res[i][j] = this.trees[i][j].buildModel();
                }
            }
            return res;
        }
    }

    static class LogisticRegressionBuilder
    extends AlgorithmBuilder<LogisticRegression> {
        LogisticRegressionBuilder() {
        }

        @Override
        public LogisticRegression importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            LogisticRegressionData data = AlgorithmBuilder.parseModel(resources, LogisticRegressionData.class, modelFile);
            return new LogisticRegression(data.policy, data.intercept, data.coefficients);
        }
    }

    static class LogisticRegressionData {
        LogisticRegression.Policy policy;
        double[] intercept;
        double[][] coefficients;

        LogisticRegressionData() {
        }
    }

    static class GradientBoostingRegressorBuilder
    extends AlgorithmBuilder<GradientBoostingRegressor> {
        GradientBoostingRegressorBuilder() {
        }

        @Override
        public GradientBoostingRegressor importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            GradientBoostingRegressorData data = AlgorithmBuilder.parseModel(resources, GradientBoostingRegressorData.class, modelFile);
            return new GradientBoostingRegressor(data.baseline, data.shrinkage, data.buildTrees(), data.gamma_regression, data.logistic_regression);
        }
    }

    static class GradientBoostingRegressorData
    extends ForestRegressorData {
        double baseline;
        double shrinkage;
        boolean gamma_regression = false;
        boolean logistic_regression = false;

        GradientBoostingRegressorData() {
        }
    }

    static class ForestClassifierBuilder
    extends AlgorithmBuilder<ForestClassifier> {
        ForestClassifierBuilder() {
        }

        @Override
        public ForestClassifier importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            ForestClassifierData data = AlgorithmBuilder.parseModel(resources, ForestClassifierData.class, modelFile);
            return new ForestClassifier(data.buildTrees());
        }
    }

    static class ForestClassifierData {
        ClassificationTreeData[] trees;

        ForestClassifierData() {
        }

        public DecisionTreeClassifier[] buildTrees() {
            DecisionTreeClassifier[] res = new DecisionTreeClassifier[this.trees.length];
            for (int i = 0; i < this.trees.length; ++i) {
                res[i] = this.trees[i].buildModel();
            }
            return res;
        }
    }

    static class ForestRegressorBuilder
    extends AlgorithmBuilder<ForestRegressor> {
        ForestRegressorBuilder() {
        }

        @Override
        public ForestRegressor importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            ForestRegressorData data = AlgorithmBuilder.parseModel(resources, ForestRegressorData.class, modelFile);
            return new ForestRegressor(data.buildTrees());
        }
    }

    static class ForestRegressorData {
        RegressionTreeData[] trees;

        ForestRegressorData() {
        }

        DecisionTreeRegressor[] buildTrees() {
            DecisionTreeRegressor[] res = new DecisionTreeRegressor[this.trees.length];
            for (int i = 0; i < this.trees.length; ++i) {
                res[i] = this.trees[i].buildModel();
            }
            return res;
        }
    }

    static class DecisionTreeBuilder
    extends AlgorithmBuilder<DecisionTreeModel<?>> {
        DecisionTreeBuilder() {
        }

        @Override
        public DecisionTreeModel<?> importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            if (type == AlgorithmBuilding.AlgorithmType.REGRESSION) {
                RegressionTreeData data = AlgorithmBuilder.parseModel(resources, RegressionTreeData.class, modelFile);
                return data.buildModel();
            }
            ClassificationTreeData data = AlgorithmBuilder.parseModel(resources, ClassificationTreeData.class, modelFile);
            return data.buildModel();
        }
    }

    static class ClassificationTreeData {
        long[] node_id;
        int[] feature;
        double[] threshold;
        String[] missing;
        long[] leaf_id;
        double[][] label;
        boolean xgboost;
        boolean lightgbm;

        ClassificationTreeData() {
        }

        DecisionTreeClassifier buildModel() {
            DecisionTreeModel.TreeVariant variant = this.xgboost ? DecisionTreeModel.TreeVariant.XGBOOST : (this.lightgbm ? DecisionTreeModel.TreeVariant.LIGHTGBM : DecisionTreeModel.TreeVariant.SKLEARN);
            DecisionTreeModelBuilder modelBuilder = new DecisionTreeModelBuilder(this.node_id, this.feature, this.threshold, this.missing, this.leaf_id, (T[])this.label);
            DecisionTreeModel.Node<double[]> treeRoot = modelBuilder.buildTree();
            return new DecisionTreeClassifier(treeRoot, variant);
        }
    }

    private static class RegressionTreeData {
        long[] node_id;
        int[] feature;
        double[] threshold;
        String[] missing;
        long[] leaf_id;
        Double[] label;
        boolean xgboost;
        boolean lightgbm;

        private RegressionTreeData() {
        }

        DecisionTreeRegressor buildModel() {
            DecisionTreeModel.TreeVariant variant = this.xgboost ? DecisionTreeModel.TreeVariant.XGBOOST : (this.lightgbm ? DecisionTreeModel.TreeVariant.LIGHTGBM : DecisionTreeModel.TreeVariant.SKLEARN);
            DecisionTreeModelBuilder<Double> modelBuilder = new DecisionTreeModelBuilder<Double>(this.node_id, this.feature, this.threshold, this.missing, this.leaf_id, this.label);
            DecisionTreeModel.Node<Double> treeRoot = modelBuilder.buildTree();
            return new DecisionTreeRegressor(treeRoot, variant);
        }
    }

    static class LinearBuilder
    extends AlgorithmBuilder<LinearRegression> {
        LinearBuilder() {
        }

        @Override
        public LinearRegression importFrom(URL resources, AlgorithmBuilding.AlgorithmType type, String modelFile) throws IOException {
            LinearRegressionData dat = AlgorithmBuilder.parseModel(resources, LinearRegressionData.class, modelFile);
            return new LinearRegression(dat.intercept, dat.coefficients);
        }

        static class LinearRegressionData {
            double intercept;
            double[] coefficients;

            LinearRegressionData() {
            }
        }
    }
}

