/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml.models.regression;

import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.NodeRescaler;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLTransformationDictionary;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLRegressor;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLTreeRegressor;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.ForestRegressor;
import com.dataiku.scoring.models.GradientBoostingRegressor;
import com.dataiku.scoring.models.Regressor;
import com.dataiku.scoring.pipelines.Pipeline;
import com.dataiku.scoring.pipelines.RegressionPipeline;
import java.util.ArrayList;
import java.util.List;

/**
 * PMML {@code MiningModel} export for tree-ensemble regressors (random forests and gradient-boosted trees).
 *
 * <p>Each tree of the ensemble is exported as a {@code TreeModel} wrapped in its own {@code Segment};
 * the per-tree outputs are combined according to the segmentation's {@code multipleModelMethod}
 * ({@code "average"} for forests, {@code "sum"} for boosted trees).
 */
@XML.Named(name="MiningModel")
public class PMMLTreeEnsembleRegressor
extends PMMLRegressor {
    /** The {@code Segmentation} child element, holding one segment per tree. */
    @XML.Element
    PMMLSegmentation Segmentation;

    /**
     * Builds the mining model with one segment per tree, in array order.
     *
     * @param pipe   the regression pipeline being exported
     * @param rppp   resolved preprocessing parameters, forwarded to every per-tree exporter
     * @param trees  the trees to export; each tree's array index is passed to its {@link PMMLTreeRegressor}
     * @param method the PMML {@code multipleModelMethod} used to combine the tree outputs
     */
    private PMMLTreeEnsembleRegressor(RegressionPipeline pipe, ResolvedPredictionPreprocessingParams rppp, DecisionTreeRegressor[] trees, String method) {
        super((Pipeline)pipe, rppp);
        this.Segmentation = new PMMLSegmentation(method);
        for (int treeIndex = 0; treeIndex < trees.length; ++treeIndex) {
            PMMLTreeRegressor treeModel = new PMMLTreeRegressor((Pipeline)pipe, rppp, trees[treeIndex], treeIndex);
            this.Segmentation.Segment.add(new PMMLSegmentation.PMMLSegment(treeModel));
        }
    }

    /**
     * Exports a random forest: every tree becomes a segment and the segment outputs are averaged.
     *
     * @param pipe the regression pipeline; its model must be a {@link ForestRegressor}
     * @param rppp resolved preprocessing parameters
     * @return the PMML mining model
     * @throws ClassCastException if the pipeline's model is not a {@link ForestRegressor}
     */
    public static PMMLTreeEnsembleRegressor fromForest(RegressionPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
        return new PMMLTreeEnsembleRegressor(pipe, rppp, ((ForestRegressor)pipe.getModel()).getTrees(), "average");
    }

    /**
     * Exports a gradient-boosted ensemble: each tree is rescaled, then the segment outputs are summed.
     *
     * <p>Each tree's root is passed through a {@link NodeRescaler} built from the booster's shrinkage.
     * The first tree's rescaler also receives the booster baseline; the others receive 0.0. Presumably
     * the rescaler folds baseline and shrinkage into the leaf values (TODO confirm against NodeRescaler),
     * so that a plain "sum" over the segments reproduces the booster's prediction.
     *
     * @param pipe the regression pipeline; its model must be a {@link GradientBoostingRegressor}
     * @param rppp resolved preprocessing parameters
     * @return the PMML mining model
     * @throws ClassCastException if the pipeline's model is not a {@link GradientBoostingRegressor}
     */
    public static PMMLTreeEnsembleRegressor fromBooster(RegressionPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
        GradientBoostingRegressor gbt = (GradientBoostingRegressor)pipe.getModel();
        DecisionTreeRegressor[] trees = gbt.getTrees();
        // Write the rescaled trees into a fresh array. Reusing the array returned by getTrees() could
        // overwrite the model's own trees, mutating it and compounding the rescaling on a second export.
        DecisionTreeRegressor[] rescaledTrees = new DecisionTreeRegressor[trees.length];
        for (int i = 0; i < trees.length; ++i) {
            // Only the first tree carries the baseline, so the "sum" over segments includes it exactly once.
            double baseline = i == 0 ? gbt.getBaseline() : 0.0;
            NodeRescaler nodeRescaler = new NodeRescaler(baseline, gbt.getShrinkage());
            DecisionTreeRegressor originalRegressor = trees[i];
            DecisionTreeModel.Node<Double> rescaledTree = nodeRescaler.rescale((DecisionTreeModel.Node<Double>)originalRegressor.getRoot());
            rescaledTrees[i] = new DecisionTreeRegressor(rescaledTree, originalRegressor.variant);
        }
        return new PMMLTreeEnsembleRegressor(pipe, rppp, rescaledTrees, "sum");
    }

    /**
     * Registers, in the transformation dictionary, the derived fields needed by the trees of a forest
     * or gradient-boosted model. Models of any other type are ignored.
     *
     * @param dictionary the transformation dictionary to enrich
     * @param colNames   the input column names
     * @param model      the regressor whose trees are inspected
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void enrichTransformationDictionaryFromModel(PMMLTransformationDictionary dictionary, String[] colNames, Regressor model) {
        DecisionTreeRegressor[] trees;
        if (model instanceof ForestRegressor) {
            trees = ((ForestRegressor)model).getTrees();
        } else if (model instanceof GradientBoostingRegressor) {
            trees = ((GradientBoostingRegressor)model).getTrees();
        } else {
            return;
        }
        // The decompiled code cast to DecisionTreeModel<T>[] with no T in scope (does not compile);
        // the raw array cast below is the erasure of that cast.
        dictionary.addDerivedFieldsForTrees((DecisionTreeModel[])trees, colNames);
    }

    /**
     * The PMML {@code Segmentation} element: a list of segments plus the method used to combine them.
     * Field names are kept as-is because they name the serialized XML elements and attributes.
     */
    public static class PMMLSegmentation {
        /** The segments, one per exported tree. */
        @XML.Element
        List<PMMLSegment> Segment;
        /** The PMML {@code multipleModelMethod} attribute (e.g. "average" or "sum"). */
        @XML.Attribute
        String multipleModelMethod;

        /**
         * Creates an empty segmentation.
         *
         * @param method the value of the {@code multipleModelMethod} attribute
         */
        public PMMLSegmentation(String method) {
            this.multipleModelMethod = method;
            this.Segment = new ArrayList<PMMLSegment>();
        }

        /** A single PMML {@code Segment} wrapping one tree model. */
        public static class PMMLSegment {
            // Presumably serialized as an empty <True/> predicate, so the segment always applies — TODO confirm with the XML writer.
            @XML.Element
            Object True = new Object();
            /** The tree model evaluated by this segment. */
            @XML.Element
            PMMLTreeRegressor TreeModel;

            /**
             * @param treeModel the tree model for this segment
             */
            public PMMLSegment(PMMLTreeRegressor treeModel) {
                this.TreeModel = treeModel;
            }
        }
    }
}

