/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml.models.classification.multiclass;

import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.NodeRescaler;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLSegmentations;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLTransformationDictionary;
import com.dataiku.dip.scoring.exports.pmml.models.PMMLModel;
import com.dataiku.dip.scoring.exports.pmml.models.classification.PMMLClassifier;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLRegressionModel;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLTreeRegressor;
import com.dataiku.scoring.models.Classifier;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.GradientBoostingClassifier;
import com.dataiku.scoring.pipelines.ClassificationPipeline;
import com.dataiku.scoring.pipelines.MulticlassProbabilisticPipeline;
import com.dataiku.scoring.pipelines.Pipeline;
import java.util.ArrayList;
import java.util.List;

/**
 * PMML export of a multiclass gradient-boosted-trees classifier, serialized as a {@code MiningModel}
 * with {@code functionName="classification"}.
 *
 * <p>The exported model is a two-segment model chain:
 * <ol>
 *   <li>a {@link PMMLMultiClassGlobalModel}, which concatenates one {@link PMMLPerClassModel} per
 *       class; each per-class model is a sum segmentation over that class's rescaled boosted trees;</li>
 *   <li>a {@link PMMLNormalizationModel}, a regression model with {@code softmax} normalization over
 *       the {@code prediction_<class>} fields.</li>
 * </ol>
 */
@XML.Named(name="MiningModel")
public class PMMLGradientBoostingMultiClassClassifier
extends PMMLClassifier {
    @XML.Element
    final PMMLSegmentations.PMMLModelChainSegmentation Segmentation = new PMMLSegmentations.PMMLModelChainSegmentation();
    @XML.Attribute
    final String functionName = "classification";

    /**
     * Builds the two-segment model chain for the given pipeline.
     *
     * @param pipe the multiclass pipeline; its model must be a {@link GradientBoostingClassifier}
     *             (it is cast to that type when the per-class models are built)
     * @param rppp the resolved preprocessing parameters, forwarded to every sub-model
     */
    public PMMLGradientBoostingMultiClassClassifier(MulticlassProbabilisticPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
        super((ClassificationPipeline) pipe, rppp);
        // Segment order is significant in a model chain: later segments consume earlier outputs.
        ArrayList<PMMLSegmentations.PMMLSegment> chain = new ArrayList<>(2);

        // 1) Raw per-class scores computed from the boosted trees.
        PMMLSegmentations.PMMLModelSegment globalModelSegment = new PMMLSegmentations.PMMLModelSegment();
        globalModelSegment.MiningModel = new PMMLMultiClassGlobalModel(pipe, rppp);
        chain.add(globalModelSegment);

        // 2) Softmax normalization of those scores.
        PMMLSegmentations.PMMLRegressionSegment normalizationSegment = new PMMLSegmentations.PMMLRegressionSegment();
        normalizationSegment.RegressionModel = new PMMLNormalizationModel(pipe);
        chain.add(normalizationSegment);

        this.Segmentation.Segment = chain;
    }

    /**
     * Registers the derived fields needed by every tree of the gradient boosting model.
     *
     * @param dictionary the transformation dictionary to enrich
     * @param colNames   the preprocessed column names the trees split on
     * @param model      the model; must be a {@link GradientBoostingClassifier}
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void enrichTransformationDictionaryFromModel(PMMLTransformationDictionary dictionary, String[] colNames, Classifier model) {
        GradientBoostingClassifier gbt = (GradientBoostingClassifier) model;
        // The decompiled cast referenced an undeclared type variable `T` and could not compile.
        // A raw array cast is accepted (unchecked conversion) by the generic tree-array parameter.
        dictionary.addDerivedFieldsForTrees((DecisionTreeModel[][]) gbt.getTrees(), colNames);
    }

    /**
     * First segment of the chain: a {@code regression} mining model that concatenates one
     * {@link PMMLPerClassModel} segment per class, in the order returned by {@code pipe.getClasses()}.
     */
    public static class PMMLMultiClassGlobalModel
    extends PMMLClassifier {
        @XML.Element
        final PMMLSegmentations.PMMLModelConcatSegmentation Segmentation;
        @XML.Attribute
        final String functionName = "regression";

        /**
         * @param pipe the multiclass pipeline whose model is a {@link GradientBoostingClassifier}
         * @param rppp the resolved preprocessing parameters, forwarded to the per-class models
         */
        public PMMLMultiClassGlobalModel(MulticlassProbabilisticPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
            super((ClassificationPipeline) pipe, rppp);
            String[] classes = pipe.getClasses();
            this.Output = PMMLModel.PMMLOutput.pseudoProbabilisticWithSegmentId(classes);
            PMMLSegmentations.PMMLModelConcatSegmentation segmentation = new PMMLSegmentations.PMMLModelConcatSegmentation();
            segmentation.Segment = new ArrayList<>(classes.length);
            for (int classIndex = 0; classIndex < classes.length; ++classIndex) {
                // Per-class sub-models are built without their own Output (withOutput = false).
                PMMLPerClassModel perClassModel = new PMMLPerClassModel(pipe, rppp, classIndex, false);
                segmentation.Segment.add(new PMMLSegmentations.PMMLPerClassSegment(classes[classIndex], perClassModel));
            }
            this.Segmentation = segmentation;
        }

        @Override
        public void enrichTransformationDictionaryFromModel(PMMLTransformationDictionary dictionary, String[] colNames, Classifier model) {
            // Intentionally empty: the enclosing classifier registers derived fields for all trees.
        }
    }

    /**
     * Second segment of the chain: a {@code classification} regression model with {@code softmax}
     * normalization. For each class it has one logistic regression table with intercept {@code 0.0}
     * and coefficient {@code 1.0} on the field {@code prediction_<class>}. Those fields are presumably
     * the per-class scores emitted by the first segment (NOTE(review): confirm field naming there).
     */
    public static class PMMLNormalizationModel
    extends PMMLRegressionModel {
        @XML.Attribute
        final String functionName = "classification";
        @XML.Attribute
        final String normalizationMethod = "softmax";

        /**
         * @param pipe the multiclass pipeline providing the class labels
         */
        public PMMLNormalizationModel(MulticlassProbabilisticPipeline pipe) {
            String[] classes = pipe.getClasses();
            ArrayList<PMMLModel.PMMLMiningSchema.PMMLMiningField> fields = new ArrayList<>(classes.length);
            ArrayList<PMMLRegressionModel.PMMLLogisticRegressionTable> regressionTables = new ArrayList<>(classes.length);
            for (String clazz : classes) {
                String scoreField = "prediction_" + clazz;
                fields.add(new PMMLModel.PMMLMiningSchema.PMMLMiningField(scoreField));
                // Identity table: softmax is applied directly to the raw per-class score.
                regressionTables.add(new PMMLRegressionModel.PMMLLogisticRegressionTable(clazz, 0.0, new double[]{1.0}, new String[]{scoreField}));
            }
            this.MiningSchema = PMMLModel.PMMLMiningSchema.fromFields(fields);
            this.Output = PMMLModel.PMMLOutput.multiclassClassification(classes);
            this.RegressionTable = regressionTables;
        }
    }

    /**
     * Raw score for a single class: a sum segmentation over the trees
     * {@code trees[j][classIndex]} for every boosting iteration {@code j}. Each tree is rescaled
     * with {@link NodeRescaler} using the model's shrinkage, and only the first tree also receives
     * the class baseline.
     */
    public static class PMMLPerClassModel
    extends PMMLClassifier {
        @XML.Element
        final List<PMMLSegmentations.PMMLSumSegmentation> Segmentation;
        @XML.Attribute
        final String functionName = "regression";

        /**
         * @param pipe       the multiclass pipeline; its model must be a {@link GradientBoostingClassifier}
         * @param rppp       the resolved preprocessing parameters, forwarded to each tree regressor
         * @param classIndex index into {@code pipe.getClasses()} of the class this model scores
         * @param withOutput whether to attach an Output element to this model
         */
        public PMMLPerClassModel(MulticlassProbabilisticPipeline pipe, ResolvedPredictionPreprocessingParams rppp, int classIndex, boolean withOutput) {
            super((ClassificationPipeline) pipe, rppp);
            String[] classes = pipe.getClasses();
            if (withOutput) {
                this.Output = PMMLModel.PMMLOutput.outputOnlyClassOnePredictedValue(classes);
            }
            String className = classes[classIndex];

            GradientBoostingClassifier gbt = (GradientBoostingClassifier) pipe.getModel();
            DecisionTreeRegressor[][] trees = gbt.getTrees();
            double shrinkage = gbt.getShrinkage();

            PMMLSegmentations.PMMLSumSegmentation treeSum = new PMMLSegmentations.PMMLSumSegmentation();
            treeSum.Segment = new ArrayList<>(trees.length);
            // NOTE(review): if trees.length == 0 the baseline is never emitted; confirm this cannot happen.
            for (int iteration = 0; iteration < trees.length; ++iteration) {
                // The baseline is folded into the first tree only, so the sum counts it exactly once.
                double baseline = iteration == 0 ? gbt.getBaseline()[classIndex] : 0.0;
                NodeRescaler nodeRescaler = new NodeRescaler(baseline, shrinkage);
                DecisionTreeRegressor originalRegressor = trees[iteration][classIndex];
                DecisionTreeModel.Node<Double> rescaledTree = nodeRescaler.rescale((DecisionTreeModel.Node<Double>) originalRegressor.getRoot());
                DecisionTreeRegressor rescaledRegressor = new DecisionTreeRegressor(rescaledTree, originalRegressor.variant);
                PMMLTreeRegressor pmmlTreeRegressor = new PMMLTreeRegressor((Pipeline) pipe, rppp, rescaledRegressor, className, classIndex + ":" + iteration);
                treeSum.Segment.add(new PMMLSegmentations.PMMLTreeSegment(pmmlTreeRegressor));
            }

            ArrayList<PMMLSegmentations.PMMLSumSegmentation> sumSegmentations = new ArrayList<>(1);
            sumSegmentations.add(treeSum);
            this.Segmentation = sumSegmentations;
        }

        @Override
        public void enrichTransformationDictionaryFromModel(PMMLTransformationDictionary dictionary, String[] colNames, Classifier model) {
            // Intentionally empty: the enclosing classifier registers derived fields for all trees.
        }
    }
}

