/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml.helpers;

import com.dataiku.dip.scoring.exports.pmml.PMMLPreprocessing;
import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLDerivedField;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLLogitFunction;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLSubtractedDerivedFieldGenerator;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.pipelines.Normalization;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Container for the derived fields and the logit function definition of an exported PMML
 * document. Presumably serialized (via the {@link XML.Element} annotations) as the PMML
 * {@code TransformationDictionary} element — confirm against the XML writer.
 *
 * <p>Instances are normally created with {@link #fromPreprocessing}; tree-based exporters then
 * append per-tree derived fields with {@code addDerivedFieldsForTrees}.
 */
public class PMMLTransformationDictionary {
    /**
     * Derived fields, emitted as {@code DerivedField} elements. The tree helpers append to this
     * list, so it must be mutable. It is created lazily if still {@code null}.
     */
    @XML.Element
    public List<PMMLDerivedField> DerivedField;
    /** Logit function definition, emitted as the {@code DefineFunction} element. */
    @XML.Element
    PMMLLogitFunction.PMMLDefineFunctionLogit DefineFunction;

    /**
     * Builds a dictionary holding the derived fields of a preprocessing pipeline plus a logit
     * {@code DefineFunction}.
     *
     * @param pipe the preprocessing pipeline whose derived fields are exported
     * @param actions per-column normalization actions, forwarded to
     *     {@link PMMLPreprocessing#getDerivedFields}
     * @param castToFloat forwarded to {@link PMMLPreprocessing#getDerivedFields}
     * @return a new dictionary
     */
    public static PMMLTransformationDictionary fromPreprocessing(PreprocessingPipeline pipe, Map<String, Normalization.Action> actions, boolean castToFloat) {
        PMMLTransformationDictionary dict = new PMMLTransformationDictionary();
        dict.DerivedField = PMMLPreprocessing.getDerivedFields(pipe, actions, castToFloat);
        dict.DefineFunction = new PMMLLogitFunction.PMMLDefineFunctionLogit();
        return dict;
    }

    /**
     * Appends the derived fields generated for each tree of a flat ensemble. Each tree is named
     * by its index in {@code trees} ({@code "0"}, {@code "1"}, ...).
     *
     * @param trees the trees of the ensemble
     * @param colNames column names forwarded to the derived-field generator
     */
    public <T> void addDerivedFieldsForTrees(DecisionTreeModel<T>[] trees, String[] colNames) {
        for (int treeIndex = 0; treeIndex < trees.length; ++treeIndex) {
            this.addDerivedFieldsForTree(trees[treeIndex], Integer.toString(treeIndex), colNames);
        }
    }

    /**
     * Appends the derived fields generated for each tree of a per-class ensemble.
     *
     * <p>{@code trees[treeIndex]} holds the trees of one boosting/estimator step for every class.
     * When a row holds a single tree, that tree is named {@code "<treeIndex>"}. Otherwise each
     * tree is named {@code "<classIndex>:<treeIndex>"}.
     *
     * @param trees trees indexed as {@code [treeIndex][classIndex]}
     * @param colNames column names forwarded to the derived-field generator
     */
    public <T> void addDerivedFieldsForTrees(DecisionTreeModel<T>[][] trees, String[] colNames) {
        for (int treeIndex = 0; treeIndex < trees.length; ++treeIndex) {
            DecisionTreeModel<T>[] treesForAllClasses = trees[treeIndex];
            if (treesForAllClasses.length == 1) {
                this.addDerivedFieldsForTree(treesForAllClasses[0], Integer.toString(treeIndex), colNames);
                continue;
            }
            for (int classIndex = 0; classIndex < treesForAllClasses.length; ++classIndex) {
                // Plain concatenation instead of String.format("%d:%d"): String.format uses the
                // default locale, which can render non-ASCII digits in generated field names.
                String treeName = classIndex + ":" + treeIndex;
                this.addDerivedFieldsForTree(treesForAllClasses[classIndex], treeName, colNames);
            }
        }
    }

    /**
     * Generates the derived fields of one tree and appends them to {@link #DerivedField}.
     *
     * @param tree the tree to export
     * @param treeName identifier passed to the generator
     * @param colNames column names passed to the generator
     */
    private <T> void addDerivedFieldsForTree(DecisionTreeModel<T> tree, String treeName, String[] colNames) {
        PMMLSubtractedDerivedFieldGenerator<T> generator = new PMMLSubtractedDerivedFieldGenerator<T>(tree, treeName, colNames);
        List<PMMLDerivedField> derivedFields = generator.generateDerivedFields();
        // An instance built with the default constructor has no list yet; create one instead
        // of failing with a NullPointerException.
        if (this.DerivedField == null) {
            this.DerivedField = new ArrayList<PMMLDerivedField>();
        }
        this.DerivedField.addAll(derivedFields);
    }
}

