/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml.models.classification.binary;

import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.scoring.exports.pmml.PMMLPreprocessing;
import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLTransformationDictionary;
import com.dataiku.dip.scoring.exports.pmml.models.PMMLModel;
import com.dataiku.dip.scoring.exports.pmml.models.classification.binary.PMMLBinaryClassifier;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLRegressionModel;
import com.dataiku.scoring.models.Classifier;
import com.dataiku.scoring.models.LogisticRegression;
import com.dataiku.scoring.pipelines.BinaryProbabilisticPipeline;
import com.dataiku.scoring.pipelines.ClassificationPipeline;
import java.util.ArrayList;
import java.util.List;

/**
 * PMML export of a binary logistic-regression scoring pipeline.
 *
 * <p>This class is serialized under the {@code MiningModel} element name. Its model is a
 * {@link PMMLLogisticBinaryModel}, serialized as a {@code RegressionModel} with softmax
 * normalization.
 */
@XML.Named(name="MiningModel")
public class PMMLLogisticBinaryClassifier
extends PMMLBinaryClassifier {

    /**
     * Builds the PMML representation of a binary logistic-regression pipeline.
     *
     * @param pipe the binary probabilistic pipeline; its model is cast to
     *             {@link LogisticRegression} by {@link PMMLLogisticBinaryModel}
     * @param rppp the resolved preprocessing parameters of the pipeline
     */
    public PMMLLogisticBinaryClassifier(BinaryProbabilisticPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
        // NOTE(review): the meaning of the third argument is defined by PMMLBinaryClassifier,
        // which is not visible here.
        super(pipe, rppp, false);
        this.setModel(new PMMLLogisticBinaryModel(pipe, rppp));
    }

    /**
     * Intentionally a no-op: this exporter adds no model-specific entries to the
     * transformation dictionary. Presumably the logistic model reads the preprocessed
     * columns directly. TODO confirm against the other PMMLBinaryClassifier subclasses.
     */
    @Override
    public void enrichTransformationDictionaryFromModel(PMMLTransformationDictionary dictionary, String[] colNames, Classifier model) {
    }

    /**
     * Builds the two regression tables of a binary logistic model, expressed relative to
     * {@code classes[1]}.
     *
     * <p>The table for {@code classes[0]} holds intercept {@code intercepts[0] - intercepts[1]}
     * and coefficients {@code coefs[0][i] - coefs[1][i]}. The table for {@code classes[1]} is
     * built with the single-argument constructor, presumably an all-zero reference table
     * (TODO confirm in PMMLRegressionModel).
     *
     * <p>Softmax is invariant to subtracting the same score from every class. So if the
     * reference table scores 0, softmax over these two tables gives the same probabilities as
     * softmax over the original two per-class rows.
     *
     * @param classes    class labels; at least two, and only the first two are used
     * @param intercepts per-class intercepts; at least two, and only the first two are used
     * @param coefs      per-class coefficient rows; at least two, and the first two must have equal length
     * @param colnames   column names, passed through to the {@code classes[0]} table
     * @return a mutable list holding the {@code classes[0]} table followed by the {@code classes[1]} table
     * @throws IllegalArgumentException if {@code classes}, {@code intercepts} or {@code coefs} has
     *         fewer than two entries, or if the first two coefficient rows differ in length
     */
    static List<PMMLRegressionModel.PMMLLogisticRegressionTable> binaryTables(String[] classes, double[] intercepts, double[][] coefs, String[] colnames) {
        if (classes.length < 2 || intercepts.length < 2 || coefs.length < 2) {
            throw new IllegalArgumentException("Binary logistic export needs at least 2 classes, intercepts and coefficient rows, got "
                    + classes.length + ", " + intercepts.length + " and " + coefs.length);
        }
        double[] row0 = coefs[0];
        double[] row1 = coefs[1];
        // A longer row1 would otherwise be silently truncated, producing a wrong model.
        if (row0.length != row1.length) {
            throw new IllegalArgumentException("Binary logistic coefficient rows differ in length: "
                    + row0.length + " vs " + row1.length);
        }
        double intercept = intercepts[0] - intercepts[1];
        double[] coef = new double[row0.length];
        for (int i = 0; i < coef.length; ++i) {
            coef[i] = row0[i] - row1[i];
        }
        List<PMMLRegressionModel.PMMLLogisticRegressionTable> tables = new ArrayList<>(2);
        tables.add(new PMMLRegressionModel.PMMLLogisticRegressionTable(classes[0], intercept, coef, colnames));
        // Reference table for classes[1]; see the method Javadoc.
        tables.add(new PMMLRegressionModel.PMMLLogisticRegressionTable(classes[1]));
        return tables;
    }

    /**
     * The {@code RegressionModel} element of a binary logistic export: two regression tables
     * (see {@link PMMLLogisticBinaryClassifier#binaryTables}) combined with softmax
     * normalization.
     */
    @XML.Named(name="RegressionModel")
    public static class PMMLLogisticBinaryModel
    extends PMMLBinaryClassifier.PMMLBinaryModel {
        /** The regression tables, serialized as child elements. */
        @XML.Element
        final List<PMMLRegressionModel.PMMLLogisticRegressionTable> RegressionTable;
        /** Serialized as the {@code normalizationMethod} attribute; always "softmax". */
        @XML.Attribute
        final String normalizationMethod = "softmax";

        /**
         * Builds the regression model from a pipeline whose model is a {@link LogisticRegression}.
         *
         * @param pipe the binary probabilistic pipeline
         * @param rppp the resolved preprocessing parameters of the pipeline
         * @throws ClassCastException if the pipeline's model is not a {@link LogisticRegression}
         * @throws IllegalArgumentException if the model's parameters are not a valid binary
         *         layout (see {@link PMMLLogisticBinaryClassifier#binaryTables})
         */
        public PMMLLogisticBinaryModel(BinaryProbabilisticPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
            super((ClassificationPipeline)pipe, rppp);
            this.Output = PMMLModel.PMMLOutput.outputOnlyClassOneProbability(pipe.getClasses());
            LogisticRegression lr = (LogisticRegression)pipe.getModel();
            this.RegressionTable = this.createPMMLRegressionTableFromScoringModel(pipe, lr, pipe.getClasses());
        }

        /**
         * Converts the logistic model's intercepts and coefficients into PMML regression tables.
         * Column names are the pipeline's preprocessing output columns, normalized by
         * {@link PMMLPreprocessing#normalizedOutputColumns}.
         */
        private List<PMMLRegressionModel.PMMLLogisticRegressionTable> createPMMLRegressionTableFromScoringModel(BinaryProbabilisticPipeline pipe, LogisticRegression lr, String[] classes) {
            String[] colnames = PMMLPreprocessing.normalizedOutputColumns(
                    pipe.getPreprocessing().getOutputColumns(), pipe.getPreprocessing());
            return PMMLLogisticBinaryClassifier.binaryTables(classes, lr.getBaseline(), lr.getCoefficients(), colnames);
        }
    }
}

