/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.docgen.extractor;

import com.dataiku.dip.analysis.docgen.extractor.ExtractToPrecision;
import com.dataiku.dip.analysis.docgen.extractor.ModelExtractor;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.model.ClusteringModelingParams;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.clustering.ClusteringModelDetails;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.MetricParams;
import com.dataiku.dip.analysis.model.prediction.PredictionModelSnippetData;
import com.dataiku.dip.analysis.model.prediction.TabularPredictionModelDetails;
import com.google.common.collect.ImmutableSet;
import com.jayway.jsonpath.DocumentContext;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;

/**
 * Extracts the value of a model's configured evaluation metric and formats it as a display
 * string for document generation.
 *
 * <p>Supports {@link TabularPredictionModelDetails} and {@link ClusteringModelDetails}; any other
 * model type is rejected with an {@link IOException}.
 */
public class TestMetricValueExtractor
implements ModelExtractor<String> {
    /** Number of decimal places used for non-count metrics before the formatting rules adjust it. */
    private static final int DEFAULT_PRECISION = 3;
    /** Metrics (by enum name) rendered as a percentage: value multiplied by 100 and suffixed with "%". */
    private static final Set<String> USE_PERCENTAGE_FOR_METRICS = ImmutableSet.of(
            MetricParams.EvaluationMetric.MAPE.name());
    /**
     * Metrics (by enum name) for which an exact 0 is rendered as "-" instead of a number.
     * NOTE(review): presumably 0 means "not computed" for these metrics — confirm.
     */
    private static final Set<String> IGNORE_ZEROS_FOR_METRICS = ImmutableSet.of(
            ClusteringModelingParams.EvaluationMetric.INERTIA.name(),
            MetricParams.EvaluationMetric.RMSLE.name());

    /**
     * Returns the formatted value of the model's evaluation metric.
     *
     * <p>Metrics whose enum name starts with {@code NB_} (e.g. {@code NB_CLUSTERS}) are formatted
     * without decimals; all others start from {@link #DEFAULT_PRECISION} decimals.
     *
     * <p>For classical prediction models this sets {@code headTaskCMW} on {@code model} as a side
     * effect (see {@link #getMetricValueFromModel(TabularPredictionModelDetails)}).
     *
     * @param documentContext not used by this extractor
     * @param model the model details to read the metric from
     * @return the formatted metric value, or {@code "-"} when the value is missing (or is an
     *     ignored zero)
     * @throws IOException if {@code model} is neither a tabular prediction model nor a clustering model
     * @throws IllegalStateException if the model has no evaluation metric configured, or the metric
     *     is not handled for prediction models
     */
    @Override
    public String extract(DocumentContext documentContext, ModelDetailsBase model) throws IOException {
        final String metricName;
        final Double value;
        if (model instanceof TabularPredictionModelDetails) {
            TabularPredictionModelDetails predictionModel = (TabularPredictionModelDetails) model;
            value = getMetricValueFromModel(predictionModel);
            // Safe: getMetricValueFromModel throws if the metric is null.
            metricName = predictionModel.modeling.metrics.evaluationMetric.name();
        } else if (model instanceof ClusteringModelDetails) {
            ClusteringModelDetails clusteringModel = (ClusteringModelDetails) model;
            value = getMetricValueFromModel(clusteringModel);
            // Safe: getMetricValueFromModel throws if the metric is null.
            metricName = clusteringModel.modeling.metrics.evaluationMetric.name();
        } else {
            throw new IOException(String.format("Extractor '%s' does not support model details of type '%s'", this.getClass().getSimpleName(), model.getClass().getSimpleName()));
        }
        int precision = metricName.startsWith("NB_") ? 0 : DEFAULT_PRECISION;
        return formatMetricValue(value, metricName, precision);
    }

    /**
     * Reads the value of the configured evaluation metric from a prediction model's results snippet.
     *
     * <p>For {@link ClassicalPredictionModelDetails}, copies the cost-matrix weights from the
     * modeling parameters onto {@code headTaskCMW} before building the snippet. This mutates
     * {@code model}. NOTE(review): presumably required so the COST_MATRIX gain reflects those
     * weights — confirm.
     *
     * @param model the prediction model details
     * @return the metric value from the snippet; may be {@code null} if the snippet lacks it
     * @throws IllegalStateException if no evaluation metric is configured, or it is not handled here
     */
    private static Double getMetricValueFromModel(TabularPredictionModelDetails model) {
        // Validate before mutating the model or doing the (potentially costly) snippet build.
        if (model.modeling.metrics.evaluationMetric == null) {
            throw new IllegalStateException("No evaluation metric");
        }
        if (model instanceof ClassicalPredictionModelDetails) {
            ((ClassicalPredictionModelDetails) model).headTaskCMW = model.modeling.metrics.costMatrixWeights;
        }
        PredictionModelSnippetData snippetData = PredictionResultsReader.makeSnippet(model);
        switch (model.modeling.metrics.evaluationMetric) {
            case ACCURACY:
                return snippetData.accuracy;
            case RECALL:
                return snippetData.recall;
            case PRECISION:
                return snippetData.precision;
            case F1:
                return snippetData.f1;
            case COST_MATRIX:
                return snippetData.costMatrixGain;
            case CUMULATIVE_LIFT:
                return snippetData.lift;
            case LOG_LOSS:
                return snippetData.logLoss;
            case ROC_AUC:
                return snippetData.auc;
            case AVERAGE_PRECISION:
                return snippetData.averagePrecision;
            case EVS:
                return snippetData.evs;
            case MAPE:
                return snippetData.mape;
            case MAE:
                return snippetData.mae;
            case MSE:
                return snippetData.mse;
            case RMSE:
                return snippetData.rmse;
            case RMSLE:
                return snippetData.rmsle;
            case R2:
                return snippetData.r2;
            case MASE:
                return snippetData.mase;
            case MEAN_ABSOLUTE_QUANTILE_LOSS:
                return snippetData.meanAbsoluteQuantileLoss;
            case MEAN_WEIGHTED_QUANTILE_LOSS:
                return snippetData.meanWeightedQuantileLoss;
            case MSIS:
                return snippetData.msis;
            case ND:
                return snippetData.nd;
            case SMAPE:
                return snippetData.smape;
            case CUSTOM:
                return snippetData.customScore;
            default:
                throw new IllegalStateException("Unknown metric");
        }
    }

    /**
     * Reads the value of the configured evaluation metric from a clustering model's performance data.
     *
     * @param model the clustering model details
     * @return the metric value, or {@code null} when no performance data is available or the
     *     metric is not one handled here
     * @throws IllegalStateException if no evaluation metric is configured (consistent with the
     *     prediction-model variant)
     */
    private static Double getMetricValueFromModel(ClusteringModelDetails model) {
        if (model.modeling.metrics.evaluationMetric == null) {
            throw new IllegalStateException("No evaluation metric");
        }
        if (model.perf == null || model.perf.metrics == null) {
            // No performance results: report the value as missing ("-") instead of failing with an NPE.
            return null;
        }
        switch (model.modeling.metrics.evaluationMetric) {
            case SILHOUETTE:
                return model.perf.metrics.silhouette;
            case INERTIA:
                return model.perf.metrics.inertia;
            case NB_CLUSTERS:
                return model.perf.metrics.nbClusters;
            default:
                return null;
        }
    }

    /**
     * Formats a metric value for display.
     *
     * <p>Rules, in order:
     * <ul>
     *   <li>{@code null}, or exactly 0 for a metric in {@link #IGNORE_ZEROS_FOR_METRICS}, gives {@code "-"};</li>
     *   <li>percentage metrics are multiplied by 100, suffixed with {@code "%"}, and lose two
     *       decimals (keeping at least one);</li>
     *   <li>non-percentage values with magnitude {@code >= 10000} go through {@link #toPrecision};</li>
     *   <li>otherwise values with magnitude {@code >= 100} get no decimals;</li>
     *   <li>everything else is formatted with {@link #toFixed} at the resulting precision.</li>
     * </ul>
     *
     * @param metricValue the raw value, possibly {@code null}
     * @param metricName the evaluation metric's enum name
     * @param precision the starting number of decimals
     * @return the display string
     */
    private static String formatMetricValue(Double metricValue, String metricName, int precision) {
        if (metricValue == null || (IGNORE_ZEROS_FOR_METRICS.contains(metricName) && metricValue == 0.0)) {
            return "-";
        }
        double value = metricValue;
        String suffix = "";
        boolean percent = USE_PERCENTAGE_FOR_METRICS.contains(metricName);
        if (percent) {
            value = value * 100.0;
            suffix = "%";
            // Scaling by 100 shifts two digits left of the decimal point; keep at least one decimal.
            precision = Math.max(1, precision - 2);
        }
        double magnitude = Math.abs(value);
        if (!percent && magnitude >= 10000.0) {
            return toPrecision(value, precision) + suffix;
        }
        if (magnitude >= 100.0) {
            precision = 0;
        }
        return toFixed(value, precision) + suffix;
    }

    /**
     * Formats {@code value} with exactly {@code precision} digits after the decimal point and no
     * grouping separators. A {@code precision <= 0} yields an integer with no decimal separator.
     *
     * <p>Rounding follows {@link DecimalFormat}'s default mode ({@code HALF_EVEN}).
     * NOTE(review): {@code new DecimalFormat(String)} uses the JVM default locale's symbols, so the
     * decimal separator may be ',' on some servers — confirm this is intended.
     *
     * @param value the number to format
     * @param precision the number of fraction digits; negative values are treated as 0
     * @return the formatted number
     */
    public static String toFixed(double value, int precision) {
        int fractionDigits = Math.max(0, precision);
        // The pattern "0." (no fraction digits) would force a trailing separator, so use "0" instead.
        String format = fractionDigits == 0 ? "0" : "0." + StringUtils.repeat('0', fractionDigits);
        return new DecimalFormat(format).format(value);
    }

    /**
     * Delegates to {@link ExtractToPrecision#toPrecision(double, int)} for large magnitudes.
     * NOTE(review): presumably formats to {@code precision} significant digits (possibly in
     * exponential notation) — confirm against ExtractToPrecision.
     */
    private static String toPrecision(double value, int precision) {
        return ExtractToPrecision.toPrecision(value, precision);
    }
}

