/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml.models.regression;

import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLTransformationDictionary;
import com.dataiku.dip.scoring.exports.pmml.models.PMMLModel;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLLinearModel;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLTreeEnsembleRegressor;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLTreeRegressor;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.ForestRegressor;
import com.dataiku.scoring.models.GradientBoostingRegressor;
import com.dataiku.scoring.models.LinearRegression;
import com.dataiku.scoring.models.Regressor;
import com.dataiku.scoring.pipelines.Pipeline;
import com.dataiku.scoring.pipelines.RegressionPipeline;
import java.util.ArrayList;
import java.util.Map;

/**
 * Base class for PMML regression models exported from a scoring pipeline.
 *
 * <p>This class builds the {@code MiningSchema} from the INPUT-role features of the resolved
 * preprocessing parameters and sets a regression {@code Output}. Concrete exporters
 * (linear model, single tree, tree ensembles) supply the model-specific content;
 * {@link #fromRegressionPipeline} picks the exporter matching the pipeline's model.
 */
public abstract class PMMLRegressor
extends PMMLModel {
    /**
     * Emitted as an XML attribute via {@code @XML.Attribute}; presumably the PMML
     * {@code functionName} of the model element (TODO confirm against the XML writer).
     */
    @XML.Attribute
    final String functionName = "regression";

    /**
     * Creates a regressor whose output is {@code PMMLOutput.regression()}.
     *
     * @param pipe the scoring pipeline being exported (not read by this constructor)
     * @param rppp resolved preprocessing parameters; INPUT-role features become mining fields
     */
    public PMMLRegressor(Pipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
        this.createMiningSchema(rppp);
        this.Output = PMMLModel.PMMLOutput.regression();
    }

    /**
     * Creates a regressor whose output is {@code PMMLOutput.regression(className)}.
     *
     * @param pipe the scoring pipeline being exported (not read by this constructor)
     * @param rppp resolved preprocessing parameters; INPUT-role features become mining fields
     * @param className forwarded to {@code PMMLOutput.regression}; NOTE(review): presumably the
     *     class label this regression output refers to (e.g. inside a classifier) — confirm
     */
    public PMMLRegressor(Pipeline pipe, ResolvedPredictionPreprocessingParams rppp, String className) {
        this.createMiningSchema(rppp);
        this.Output = PMMLModel.PMMLOutput.regression(className);
    }

    /**
     * Sets {@code MiningSchema} to one mining field per feature whose role is
     * {@link FeaturePreprocessingParams.Role#INPUT}, in the iteration order of
     * {@code rppp.per_feature}. Features with any other role are skipped.
     */
    private void createMiningSchema(ResolvedPredictionPreprocessingParams rppp) {
        // Capacity hint only: at most one mining field per declared feature.
        ArrayList<PMMLModel.PMMLMiningSchema.PMMLMiningField> fields =
                new ArrayList<PMMLModel.PMMLMiningSchema.PMMLMiningField>(rppp.per_feature.size());
        for (Map.Entry<?, ?> e : rppp.per_feature.entrySet()) {
            FeaturePreprocessingParams params = (FeaturePreprocessingParams) e.getValue();
            if (params.role != FeaturePreprocessingParams.Role.INPUT) {
                continue;
            }
            fields.add(new PMMLModel.PMMLMiningSchema.PMMLMiningField((String) e.getKey()));
        }
        this.MiningSchema = PMMLModel.PMMLMiningSchema.fromFields(fields);
    }

    /**
     * Builds the PMML exporter matching the model held by {@code pipe}.
     *
     * <p>The checks run in order: {@link LinearRegression}, {@link DecisionTreeRegressor},
     * {@link ForestRegressor}, {@link GradientBoostingRegressor}. The first match wins.
     *
     * @param pipe the regression pipeline to export
     * @param rppp resolved preprocessing parameters forwarded to the exporter
     * @return the exporter for the pipeline's model
     * @throws IllegalArgumentException if the pipeline has no model, or its model type is not
     *     supported for PMML export
     */
    public static PMMLRegressor fromRegressionPipeline(RegressionPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
        Regressor model = pipe.getModel();
        if (model == null) {
            // Without this guard a null model falls through every instanceof check and NPEs
            // while building the "not supported" message below.
            throw new IllegalArgumentException("Regression pipeline has no model; cannot export to pmml.");
        }
        if (model instanceof LinearRegression) {
            return new PMMLLinearModel(pipe, rppp);
        }
        if (model instanceof DecisionTreeRegressor) {
            return new PMMLTreeRegressor(pipe, rppp);
        }
        if (model instanceof ForestRegressor) {
            return PMMLTreeEnsembleRegressor.fromForest(pipe, rppp);
        }
        if (model instanceof GradientBoostingRegressor) {
            return PMMLTreeEnsembleRegressor.fromBooster(pipe, rppp);
        }
        throw new IllegalArgumentException("Algorithm " + String.valueOf(model.getClass()) + " not supported for pmml export.");
    }

    /**
     * Model-specific hook for adding entries to a {@link PMMLTransformationDictionary}.
     *
     * <p>NOTE(review): parameter names are decompiler artifacts. Presumably {@code var1} is the
     * dictionary to enrich, {@code var2} the feature/column names, and {@code var3} the scoring
     * model to read from — confirm against the subclasses.
     */
    public abstract void enrichTransformationDictionaryFromModel(PMMLTransformationDictionary var1, String[] var2, Regressor var3);
}

