/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.CustomMetricResultAggregator;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.shared.ResultsReaderBase;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.CustomMetricResult;
import com.dataiku.dip.analysis.model.core.ModelCustomEvaluationMetric;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.ClassificationModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.MetricParams;
import com.dataiku.dip.analysis.model.prediction.MulticlassModelPerf;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelPerf;
import com.dataiku.dip.analysis.model.prediction.RegressionModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.RegressionModelPerf;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedTimeseriesForecastingCoreParams;
import com.dataiku.dip.analysis.model.prediction.TabularPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelDetails;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelPerf;
import com.dataiku.dip.analysis.model.prediction.assertions.MLAssertionsParams;
import com.dataiku.dip.partitioning.StratifiedModelUtils;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.ObjectUtils;

public class StratifiedMetricsAggregator {
    /** Logger for this aggregator, registered under the category {@code dku.analysis.ml.prediction.partitioned.aggregation}. */
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis.ml.prediction.partitioned.aggregation");

    /**
     * Reads the model details of each partition of a partitioned (stratified) model.
     *
     * <p>A partition is skipped when its model folder or model info file does not exist, or when
     * its {@code train_info.json} is missing a state or is not in the {@code DONE} state.
     *
     * @param fmi id of the global model; its session {@code core_params.json} decides whether
     *     partitions are read as timeseries-forecasting or classical prediction details
     * @param partitionFmis ids of the partition models to read
     * @return the details of every successfully trained partition, in iteration order of
     *     {@code partitionFmis}
     * @throws IOException if a session or model file cannot be read
     */
    public static List<ModelDetailsBase> retrievePerPartitionMetrics(FullModelId fmi, Set<FullModelId> partitionFmis) throws IOException {
        ResolvedPredictionCoreParams pcp = (ResolvedPredictionCoreParams)fmi.parseSessionFile("core_params.json", ResolvedCoreParams.class);
        boolean isTimeseriesForecast = pcp.prediction_type == PredictionMLTask.PredictionType.TIMESERIES_FORECAST;
        List<ModelDetailsBase> modelDetailsList = new ArrayList<>();
        for (FullModelId partitionFmi : partitionFmis) {
            if (!partitionFmi.getModelFolder().exists() || !partitionFmi.getModelInfoFile().exists()) {
                continue;
            }
            ModelTrainInfo mti = partitionFmi.parseModelFile("train_info.json", ModelTrainInfo.class);
            // Only partitions whose training completed have perf files worth aggregating
            if (mti == null || mti.state != ModelTrainInfo.ModelTrainState.DONE) {
                continue;
            }
            if (isTimeseriesForecast) {
                modelDetailsList.add(StratifiedMetricsAggregator.readTimeseriesForecastingPartitionDetails(partitionFmi));
            } else {
                modelDetailsList.add(StratifiedMetricsAggregator.readClassicalPredictionPartitionDetails(partitionFmi, pcp.prediction_type));
            }
        }
        return modelDetailsList;
    }

    /**
     * Reads train info, core params, modeling params, perf and iperf of one timeseries-forecasting partition.
     */
    private static TimeseriesForecastingModelDetails readTimeseriesForecastingPartitionDetails(FullModelId partitionFmi) throws IOException {
        TimeseriesForecastingModelDetails modelDetails = new TimeseriesForecastingModelDetails();
        ResultsReaderBase.readTrainInfoUserMetaAndDiagnostics(partitionFmi, modelDetails);
        modelDetails.coreParams = (ResolvedTimeseriesForecastingCoreParams)partitionFmi.getResolvedCoreParams();
        modelDetails.modeling = partitionFmi.parseModelFile("rmodeling_params.json", PreTrainPredictionModelingParams.class);
        modelDetails.perf = partitionFmi.getTimeseriesForecastingPerf();
        modelDetails.iperf = partitionFmi.getTimeseriesForecastingIPerf();
        return modelDetails;
    }

    /**
     * Reads train info, core params, modeling params, all perfs and (when the assertions file
     * exists) the assertions params of one classical prediction partition.
     */
    private static ClassicalPredictionModelDetails readClassicalPredictionPartitionDetails(FullModelId partitionFmi, PredictionMLTask.PredictionType predictionType) throws IOException {
        ClassicalPredictionModelDetails modelDetails = new ClassicalPredictionModelDetails();
        ResultsReaderBase.readTrainInfoUserMetaAndDiagnostics(partitionFmi, modelDetails);
        modelDetails.coreParams = (ResolvedClassicalPredictionCoreParams)partitionFmi.getResolvedCoreParams();
        modelDetails.modeling = partitionFmi.parseModelFile("rmodeling_params.json", PreTrainPredictionModelingParams.class);
        PredictionResultsReader.readClassicalModelAllPerfs(partitionFmi, modelDetails, predictionType);
        if (partitionFmi.getAssertionsFile().isFile()) {
            modelDetails.assertionsParams = partitionFmi.getAssertionsParams();
        }
        return modelDetails;
    }

    /**
     * Reads the model details of the partitions that {@code StratifiedModelUtils.fetchPartitionFmis}
     * lists for the given partitioned model.
     *
     * @param fmi id of the global (partitioned) model
     * @return the details of the readable, fully trained partitions
     * @throws IOException if a session or model file cannot be read
     */
    public static List<ModelDetailsBase> retrievePerPartitionMetrics(FullModelId fmi) throws IOException {
        Set<FullModelId> partitionIds = StratifiedModelUtils.fetchPartitionFmis(fmi);
        return retrievePerPartitionMetrics(fmi, partitionIds);
    }

    /**
     * Aggregates per-partition model details into the details of the global model.
     *
     * @param predictionType prediction type of the partitioned model; selects the aggregation strategy
     * @param modelDetailsList per-partition model details
     * @return the aggregated details, or {@code null} when {@code modelDetailsList} is empty
     * @throws IllegalArgumentException when no aggregation exists for {@code predictionType}
     */
    public static ModelDetailsBase computeOverallMetrics(PredictionMLTask.PredictionType predictionType, List<ModelDetailsBase> modelDetailsList) {
        if (modelDetailsList.isEmpty()) {
            return null;
        }
        switch (predictionType) {
            case BINARY_CLASSIFICATION:
                return aggregateBinaryClassificationMetrics(modelDetailsList);
            case MULTICLASS:
                return aggregateMultiClassificationMetrics(modelDetailsList);
            case REGRESSION:
                return aggregateRegressionMetrics(modelDetailsList);
            case TIMESERIES_FORECAST:
                return aggregateTimeseriesForecastingMetrics(modelDetailsList);
            default:
                throw new IllegalArgumentException("Unsupported prediction type: " + predictionType);
        }
    }

    /**
     * Computes the global performance of a partitioned model and writes it into the global model folder.
     *
     * <p>When {@link #computeOverallMetrics} returns {@code null}, a warning is logged and the global
     * {@code perf.json} and {@code iperf.json} are deleted. Otherwise both files are written.
     * Assertions are only persisted for non-forecasting models: the assertions file is written
     * when the aggregated assertions params are non-null and deleted otherwise.
     *
     * @param predictionType prediction type of the partitioned model
     * @param modelDetailsList per-partition model details
     * @param globalModelId id of the global model whose files are written
     * @throws IOException if a file cannot be written or deleted
     */
    public static void computeAndSaveOverallMetrics(PredictionMLTask.PredictionType predictionType, List<ModelDetailsBase> modelDetailsList, FullModelId globalModelId) throws IOException {
        ModelDetailsBase overall = StratifiedMetricsAggregator.computeOverallMetrics(predictionType, modelDetailsList);
        if (overall == null) {
            logger.warn((Object)"No global performance computed.");
            DKUFileUtils.delete((File)globalModelId.getModelFile("perf.json"));
            DKUFileUtils.delete((File)globalModelId.getModelFile("iperf.json"));
            return;
        }
        ClassicalPredictionModelDetails classical = null;
        Object perf;
        Object iperf;
        if (predictionType == PredictionMLTask.PredictionType.TIMESERIES_FORECAST) {
            TimeseriesForecastingModelDetails forecasting = (TimeseriesForecastingModelDetails)overall;
            perf = forecasting.perf;
            iperf = forecasting.iperf;
        } else {
            classical = (ClassicalPredictionModelDetails)overall;
            perf = classical.perf;
            iperf = classical.iperf;
        }
        JSON.prettyToFile(perf, (File)globalModelId.getModelFile("perf.json"));
        JSON.prettyToFile(iperf, (File)globalModelId.getModelFile("iperf.json"));
        if (classical == null) {
            // Forecasting details: nothing more to persist
            return;
        }
        File assertionsFile = (File)globalModelId.getAssertionsFile();
        if (classical.assertionsParams != null) {
            JSON.prettyToFile((Object)classical.assertionsParams, assertionsFile);
        } else {
            DKUFileUtils.delete(assertionsFile);
        }
    }

    /**
     * Records the start and end timestamps of a training on {@code trainInfo}, together with
     * their difference as the training time.
     *
     * @param trainInfo train info updated in place
     * @param startTime training start timestamp
     * @param endTime training end timestamp (same unit as {@code startTime})
     */
    public static void setTrainInfo(ModelTrainInfo trainInfo, long startTime, long endTime) {
        long elapsed = endTime - startTime;
        trainInfo.startTime = startTime;
        trainInfo.endTime = endTime;
        trainInfo.trainingTime = elapsed;
    }

    /**
     * Aggregates per-partition binary-classification performances into one global performance.
     *
     * <p>For every partition that has a performance, the confusion-matrix counts at the
     * partition's used threshold are summed. The global cut-dependent metrics (precision,
     * recall, accuracy, F1, MCC, Hamming loss) are recomputed from that summed matrix and stored
     * as a single cut. AUC, average precision, log loss and the custom evaluation score are
     * means of the partition values weighted by test weight. Custom metric definitions, and the
     * metrics to aggregate, are taken from the first partition.
     *
     * @param modelDetailsList non-empty list of {@link ClassicalPredictionModelDetails} whose
     *     perfs are {@link BinaryClassificationModelPerf}
     * @return the aggregated details: common assertions params, global perf and an intrinsic
     *     perf carrying the first partition's {@code probaAware} flag
     */
    private static ClassicalPredictionModelDetails aggregateBinaryClassificationMetrics(List<ModelDetailsBase> modelDetailsList) {
        ClassicalPredictionModelDetails globalDetails = new ClassicalPredictionModelDetails();
        globalDetails.assertionsParams = StratifiedMetricsAggregator.getCommonAssertionsParams(modelDetailsList);
        BinaryClassificationModelPerf globalPerf = new BinaryClassificationModelPerf();
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        double testWeights = 0.0;
        double rocAUCs = 0.0;
        double averagePrecisions = 0.0;
        double logLosses = 0.0;
        double customScore = 0.0;
        ArrayList<PredictionModelPerf.AssertionsMetrics> assertionsMetricsToAggregate = new ArrayList<>();

        ClassicalPredictionModelDetails firstPartitionModel = (ClassicalPredictionModelDetails)modelDetailsList.get(0);
        HashMap<String, ModelCustomEvaluationMetric> customEvaluationMetrics = new HashMap<>();
        for (ModelCustomEvaluationMetric customMetric : firstPartitionModel.modeling.metrics.getCustomMetrics()) {
            customEvaluationMetrics.put(customMetric.name, customMetric);
        }
        HashMap<String, CustomMetricResultAggregator> customPerCutEvaluationMetricsScores = new HashMap<>();
        HashMap<String, CustomMetricResultAggregator> customTIEvaluationMetricsScores = new HashMap<>();
        BinaryClassificationModelPerf firstModelPerf = (BinaryClassificationModelPerf)firstPartitionModel.perf;
        // The first partition may have no perf; it is then simply ignored when seeding aggregators
        if (firstModelPerf != null) {
            if (firstModelPerf.perCutData != null && firstModelPerf.perCutData.customMetricsResults != null) {
                for (CustomMetricResult metricResult : firstModelPerf.perCutData.customMetricsResults) {
                    customPerCutEvaluationMetricsScores.put(metricResult.metric.name, new CustomMetricResultAggregator());
                }
            }
            if (firstModelPerf.tiMetrics != null && firstModelPerf.tiMetrics.customMetricsResults != null) {
                for (CustomMetricResult metricResult : firstModelPerf.tiMetrics.customMetricsResults) {
                    customTIEvaluationMetricsScores.put(metricResult.metric.name, new CustomMetricResultAggregator());
                }
            }
        }

        for (ModelDetailsBase partitionDetails : modelDetailsList) {
            ClassicalPredictionModelDetails partitionModel = (ClassicalPredictionModelDetails)partitionDetails;
            BinaryClassificationModelPerf perf = (BinaryClassificationModelPerf)partitionModel.perf;
            if (perf == null) {
                continue;
            }
            // Each partition contributes its confusion matrix at the threshold it actually uses
            int thresholdIndex = perf.thresholdIndex(perf.usedThreshold);
            confusionMatrix.tp += perf.perCutData.tp[thresholdIndex];
            confusionMatrix.tn += perf.perCutData.tn[thresholdIndex];
            confusionMatrix.fp += perf.perCutData.fp[thresholdIndex];
            confusionMatrix.fn += perf.perCutData.fn[thresholdIndex];
            if (partitionModel.modeling.metrics.evaluationMetric == MetricParams.EvaluationMetric.CUSTOM) {
                ModelCustomEvaluationMetric evaluationMetric = partitionModel.modeling.metrics.getCustomEvaluationMetric();
                // Probability-based custom scores live in tiMetrics; the others are per cut
                if (evaluationMetric.needsProbability) {
                    customScore += perf.tiMetrics.customScore * perf.globalMetrics.testWeight;
                } else {
                    customScore += perf.perCutData.customScore[thresholdIndex] * perf.globalMetrics.testWeight;
                }
            }
            if (perf.perCutData.customMetricsResults != null) {
                for (CustomMetricResult metricResult : perf.perCutData.customMetricsResults) {
                    // computeIfAbsent: a metric not reported by the first partition must not cause an NPE
                    CustomMetricResultAggregator aggregator = customPerCutEvaluationMetricsScores.computeIfAbsent(
                            metricResult.metric.name, metricName -> new CustomMetricResultAggregator());
                    aggregator.addValue(metricResult, perf.globalMetrics.testWeight, partitionModel.fullModelId, thresholdIndex);
                }
            }
            if (perf.tiMetrics.customMetricsResults != null) {
                StratifiedMetricsAggregator.aggregatePerfCustomMetricScores(customTIEvaluationMetricsScores, perf.tiMetrics.customMetricsResults, perf.globalMetrics.testWeight, partitionModel.fullModelId);
            }
            rocAUCs += perf.tiMetrics.auc * perf.globalMetrics.testWeight;
            averagePrecisions += perf.tiMetrics.averagePrecision * perf.globalMetrics.testWeight;
            logLosses += perf.tiMetrics.logLoss * perf.globalMetrics.testWeight;
            testWeights += perf.globalMetrics.testWeight;
            if (perf.perCutData.assertionsMetrics != null && perf.perCutData.assertionsMetrics.length > thresholdIndex) {
                assertionsMetricsToAggregate.add(perf.perCutData.assertionsMetrics[thresholdIndex]);
            }
        }

        ConfusionMatrix.Metrics metrics = confusionMatrix.getMetrics();
        globalPerf.perCutData = new BinaryClassificationModelPerf.CutData();
        globalPerf.perCutData.cut = new double[]{0.0};
        globalPerf.perCutData.tp = new float[]{confusionMatrix.tp};
        globalPerf.perCutData.fp = new float[]{confusionMatrix.fp};
        globalPerf.perCutData.tn = new float[]{confusionMatrix.tn};
        globalPerf.perCutData.fn = new float[]{confusionMatrix.fn};
        globalPerf.perCutData.precision = new double[]{metrics.precision};
        globalPerf.perCutData.recall = new double[]{metrics.recall};
        globalPerf.perCutData.accuracy = new double[]{metrics.accuracy};
        globalPerf.perCutData.f1 = new double[]{metrics.f1};
        globalPerf.perCutData.mcc = new double[]{metrics.mcc};
        globalPerf.perCutData.hammingLoss = new double[]{metrics.hammingLoss};
        PredictionModelPerf.AssertionsMetrics assertionsMetrics = StratifiedMetricsAggregator.aggregateAssertionsMetrics(globalDetails.assertionsParams, assertionsMetricsToAggregate);
        if (assertionsMetrics != null) {
            globalPerf.perCutData.assertionsMetrics = new PredictionModelPerf.AssertionsMetrics[]{assertionsMetrics};
        }

        ArrayList<CustomMetricResult> customTIMetricsResults = StratifiedMetricsAggregator.getWeightedCustomMetricResults(customEvaluationMetrics, customTIEvaluationMetricsScores);
        globalPerf.tiMetrics.customMetricsResults = customTIMetricsResults.toArray(new CustomMetricResult[0]);
        ArrayList<CustomMetricResult> customPerCutMetricsResults = new ArrayList<>();
        for (Map.Entry<String, CustomMetricResultAggregator> entry : customPerCutEvaluationMetricsScores.entrySet()) {
            ModelCustomEvaluationMetric sourceMetric = customEvaluationMetrics.get(entry.getKey());
            customPerCutMetricsResults.add(entry.getValue().aggregateMulticut(sourceMetric));
        }
        globalPerf.perCutData.customMetricsResults = customPerCutMetricsResults.toArray(new CustomMetricResult[0]);

        if (testWeights > 0.0) {
            if (firstPartitionModel.modeling.metrics.evaluationMetric == MetricParams.EvaluationMetric.CUSTOM) {
                double meanCustomScore = customScore / testWeights;
                ModelCustomEvaluationMetric evaluationMetric = firstPartitionModel.modeling.metrics.getCustomEvaluationMetric();
                if (evaluationMetric.needsProbability) {
                    globalPerf.tiMetrics.customScore = meanCustomScore;
                } else {
                    globalPerf.perCutData.customScore = new double[]{meanCustomScore};
                }
            }
            globalPerf.tiMetrics.logLoss = logLosses / testWeights;
            globalPerf.tiMetrics.auc = rocAUCs / testWeights;
            globalPerf.tiMetrics.averagePrecision = averagePrecisions / testWeights;
        }
        globalDetails.perf = globalPerf;
        ClassificationModelIntrinsicPerf globalIperf = new ClassificationModelIntrinsicPerf();
        globalDetails.iperf = globalIperf;
        ClassificationModelIntrinsicPerf firstIperf = (ClassificationModelIntrinsicPerf)firstPartitionModel.iperf;
        if (firstIperf != null) {
            globalIperf.probaAware = firstIperf.probaAware;
        }
        return globalDetails;
    }

    /**
     * Aggregates per-partition multiclass performances into one global performance.
     *
     * <p>Starts from {@code sumPartitionModelPerfs} and derives one-vs-all confusion matrices
     * from it. Precision and recall are macro-averaged over those matrices, with NaN counted
     * as 0. F1 is the harmonic mean of the macro precision and recall, and stays null when
     * both are 0. Accuracy is {@code tp / (tp + fp)} of the summed one-vs-all matrix, and
     * Hamming loss is {@code 1 - accuracy}.
     *
     * @param modelDetailsList non-empty list of {@link ClassicalPredictionModelDetails} whose
     *     perfs are {@link MulticlassModelPerf}
     * @return the aggregated details, with the first partition's {@code probaAware} flag
     */
    private static ClassicalPredictionModelDetails aggregateMultiClassificationMetrics(List<ModelDetailsBase> modelDetailsList) {
        ClassicalPredictionModelDetails aggregated = new ClassicalPredictionModelDetails();
        aggregated.assertionsParams = StratifiedMetricsAggregator.getCommonAssertionsParams(modelDetailsList);
        MulticlassModelPerf summedPerf = StratifiedMetricsAggregator.sumPartitionModelPerfs(modelDetailsList);
        Map<String, ConfusionMatrix> oneVsAll = StratifiedMetricsAggregator.getOneVsAllConfusionMatrices(summedPerf);

        ArrayList<PredictionModelPerf.AssertionsMetrics> partitionAssertions = new ArrayList<>();
        for (ModelDetailsBase details : modelDetailsList) {
            ClassicalPredictionModelDetails classical = (ClassicalPredictionModelDetails)details;
            if (classical != null && classical.perf != null) {
                MulticlassModelPerf partitionPerf = (MulticlassModelPerf)classical.perf;
                if (partitionPerf.metrics != null) {
                    partitionAssertions.add(partitionPerf.metrics.assertionsMetrics);
                }
            }
        }
        summedPerf.metrics.assertionsMetrics = StratifiedMetricsAggregator.aggregateAssertionsMetrics(aggregated.assertionsParams, partitionAssertions);

        double precisionSum = 0.0;
        double recallSum = 0.0;
        for (ConfusionMatrix classMatrix : oneVsAll.values()) {
            ConfusionMatrix.Metrics perClass = classMatrix.getMetrics();
            precisionSum += DKUtils.defaultIfNan((Double)perClass.precision, 0.0).doubleValue();
            recallSum += DKUtils.defaultIfNan((Double)perClass.recall, 0.0).doubleValue();
        }
        int nClasses = oneVsAll.size();
        double macroPrecision = 0.0;
        double macroRecall = 0.0;
        Double f1 = null;
        if (nClasses > 0) {
            macroPrecision = precisionSum / nClasses;
            macroRecall = recallSum / nClasses;
            double denominator = macroPrecision + macroRecall;
            if (denominator > 0.0) {
                f1 = 2.0 * (macroPrecision * macroRecall) / denominator;
            }
        }
        summedPerf.metrics.precision = macroPrecision;
        summedPerf.metrics.recall = macroRecall;
        summedPerf.metrics.f1 = f1;

        ConfusionMatrix total = ConfusionMatrix.computeSum(oneVsAll.values());
        double accuracy = (double)total.tp / (double)(total.tp + total.fp);
        summedPerf.metrics.accuracy = accuracy;
        summedPerf.metrics.hammingLoss = 1.0 - accuracy;

        aggregated.perf = summedPerf;
        ClassificationModelIntrinsicPerf aggregatedIperf = new ClassificationModelIntrinsicPerf();
        aggregated.iperf = aggregatedIperf;
        ClassificationModelIntrinsicPerf firstIperf = (ClassificationModelIntrinsicPerf)((ClassicalPredictionModelDetails)modelDetailsList.get(0)).iperf;
        if (firstIperf != null) {
            aggregatedIperf.probaAware = firstIperf.probaAware;
        }
        return aggregated;
    }

    /**
     * Returns {@code x} multiplied by itself.
     */
    private static double square(double x) {
        final double squared = x * x;
        return squared;
    }

    /**
     * Aggregates per-partition regression performances into one global performance.
     *
     * <p>Partitions without perf, metrics or global metrics are ignored. Each metric is combined
     * using only the partitions that report it, weighted by test weight:
     * <ul>
     *   <li>MSE, MAE, MAPE and the custom score are weighted means;</li>
     *   <li>RMSE and RMSLE are the square root of the weighted mean of their squares;</li>
     *   <li>R2 and EVS are weighted by test weight times the partition target variance;</li>
     *   <li>Pearson is recomputed from pooled moments (E[y], E[p], E[y^2], E[p^2], E[y*p])
     *       rebuilt from each partition's target/prediction means, standard deviations and
     *       Pearson coefficient;</li>
     *   <li>min/max error are the extremes over partitions, and average error is weighted over
     *       the partitions that report error statistics;</li>
     *   <li>the scatter plot is sampled from the partitions in proportion to their test weights.</li>
     * </ul>
     *
     * @param modelDetailsList non-empty list of {@link ClassicalPredictionModelDetails} whose
     *     perfs are {@link RegressionModelPerf}
     * @return the aggregated details
     */
    private static ClassicalPredictionModelDetails aggregateRegressionMetrics(List<ModelDetailsBase> modelDetailsList) {
        ClassicalPredictionModelDetails globalDetails = new ClassicalPredictionModelDetails();
        globalDetails.assertionsParams = StratifiedMetricsAggregator.getCommonAssertionsParams(modelDetailsList);
        RegressionModelPerf globalPerf = new RegressionModelPerf();
        globalPerf.metrics = new RegressionModelPerf.RegressionMetrics();
        ClassicalPredictionModelDetails firstPartitionModel = (ClassicalPredictionModelDetails)modelDetailsList.get(0);

        // Weighted sums of per-partition metrics and the weights they were accumulated with
        double totalTestWeight = 0.0;
        double totalTestWeightMSE = 0.0;
        double totalTestWeightMAE = 0.0;
        double totalTestWeightMAPE = 0.0;
        double totalTestWeightRMSE = 0.0;
        double totalTestWeightRMSLE = 0.0;
        double totalTestWeightCustomScore = 0.0;
        double totalTestWeightErrors = 0.0;
        double mse = 0.0;
        double mae = 0.0;
        double mape = 0.0;
        double rmseSquares = 0.0;
        double rmsleSquares = 0.0;
        double customScore = 0.0;
        double r2Numerator = 0.0;
        double r2Denominator = 0.0;
        double evsNumerator = 0.0;
        double evsDenominator = 0.0;
        // Weighted first and second moments of target y and prediction p, for the pooled Pearson
        double totalTestWeightCov = 0.0;
        double sumY = 0.0;
        double sumYPred = 0.0;
        double sumYSquared = 0.0;
        double sumYPredSquared = 0.0;
        double sumYYPred = 0.0;
        double minError = Double.MAX_VALUE;
        double maxError = -Double.MAX_VALUE;
        double avgError = 0.0;

        ArrayList<PredictionModelPerf.AssertionsMetrics> assertionsMetricsToAggregate = new ArrayList<>();
        HashMap<String, ModelCustomEvaluationMetric> customEvaluationMetrics = new HashMap<>();
        HashMap<String, CustomMetricResultAggregator> customEvaluationMetricsScores = new HashMap<>();
        for (ModelCustomEvaluationMetric customMetric : firstPartitionModel.modeling.metrics.getCustomMetrics()) {
            customEvaluationMetrics.put(customMetric.name, customMetric);
            customEvaluationMetricsScores.put(customMetric.name, new CustomMetricResultAggregator());
        }

        for (ModelDetailsBase partitionDetails : modelDetailsList) {
            ClassicalPredictionModelDetails partitionModel = (ClassicalPredictionModelDetails)partitionDetails;
            RegressionModelPerf perf = (RegressionModelPerf)partitionModel.perf;
            if (perf == null || perf.metrics == null || perf.globalMetrics == null) {
                continue;
            }
            double testWeight = perf.globalMetrics.testWeight;
            totalTestWeight += testWeight;
            if (perf.metrics.mse != null) {
                mse += perf.metrics.mse * testWeight;
                totalTestWeightMSE += testWeight;
            }
            if (perf.metrics.mae != null) {
                mae += perf.metrics.mae * testWeight;
                totalTestWeightMAE += testWeight;
            }
            if (perf.metrics.mape != null) {
                mape += perf.metrics.mape * testWeight;
                totalTestWeightMAPE += testWeight;
            }
            if (perf.metrics.rmse != null) {
                rmseSquares += square(perf.metrics.rmse) * testWeight;
                totalTestWeightRMSE += testWeight;
            }
            if (perf.metrics.rmsle != null) {
                rmsleSquares += square(perf.metrics.rmsle) * testWeight;
                totalTestWeightRMSLE += testWeight;
            }
            if (partitionModel.modeling.metrics.evaluationMetric == MetricParams.EvaluationMetric.CUSTOM && perf.metrics.customScore != null) {
                customScore += perf.metrics.customScore * testWeight;
                totalTestWeightCustomScore += testWeight;
            }
            if (perf.metrics.customMetricsResults != null) {
                StratifiedMetricsAggregator.aggregatePerfCustomMetricScores(customEvaluationMetricsScores, perf.metrics.customMetricsResults, perf.globalMetrics.testWeight, partitionModel.fullModelId);
            }
            assertionsMetricsToAggregate.add(perf.metrics.assertionsMetrics);

            double targetAvg = perf.globalMetrics.targetAvg[0];
            double targetStd = perf.globalMetrics.targetStd[0];
            double predictionAvg = perf.globalMetrics.predictionAvg[0];
            double predictionStd = perf.globalMetrics.predictionStd[0];
            double targetVariance = square(targetStd);
            if (perf.metrics.pearson != null) {
                // Partition moments: E[y*p] = cov + E[y]E[p], E[y^2] = std^2 + E[y]^2.
                // NOTE(review): exact if the stored stds are population stds; sample stds add a small bias.
                double covariance = targetStd * predictionStd * perf.metrics.pearson;
                sumYYPred += testWeight * (covariance + targetAvg * predictionAvg);
                sumY += testWeight * targetAvg;
                sumYPred += testWeight * predictionAvg;
                sumYSquared += testWeight * (targetVariance + square(targetAvg));
                sumYPredSquared += testWeight * (square(predictionStd) + square(predictionAvg));
                totalTestWeightCov += testWeight;
            }
            if (perf.metrics.r2 != null) {
                r2Denominator += testWeight * targetVariance;
                r2Numerator += perf.metrics.r2 * testWeight * targetVariance;
            }
            if (perf.metrics.evs != null) {
                evsDenominator += testWeight * targetVariance;
                evsNumerator += perf.metrics.evs * testWeight * targetVariance;
            }
            if (perf.regression_performance != null) {
                if (perf.regression_performance.min_error < minError) {
                    minError = perf.regression_performance.min_error;
                }
                if (perf.regression_performance.max_error > maxError) {
                    maxError = perf.regression_performance.max_error;
                }
                avgError += perf.regression_performance.average_error * testWeight;
                totalTestWeightErrors += testWeight;
            }
        }

        ArrayList<CustomMetricResult> customMetricsResults = StratifiedMetricsAggregator.getWeightedCustomMetricResults(customEvaluationMetrics, customEvaluationMetricsScores);
        globalPerf.metrics.customMetricsResults = customMetricsResults.toArray(new CustomMetricResult[0]);
        if (totalTestWeightMSE > 0.0) {
            globalPerf.metrics.mse = mse / totalTestWeightMSE;
        }
        if (totalTestWeightMAE > 0.0) {
            globalPerf.metrics.mae = mae / totalTestWeightMAE;
        }
        if (totalTestWeightMAPE > 0.0) {
            globalPerf.metrics.mape = mape / totalTestWeightMAPE;
        }
        if (totalTestWeightRMSE > 0.0) {
            globalPerf.metrics.rmse = Math.sqrt(rmseSquares / totalTestWeightRMSE);
        }
        if (totalTestWeightRMSLE > 0.0) {
            globalPerf.metrics.rmsle = Math.sqrt(rmsleSquares / totalTestWeightRMSLE);
        }
        if (firstPartitionModel.modeling.metrics.evaluationMetric == MetricParams.EvaluationMetric.CUSTOM && totalTestWeightCustomScore > 0.0) {
            globalPerf.metrics.customScore = customScore / totalTestWeightCustomScore;
        }
        globalPerf.metrics.assertionsMetrics = StratifiedMetricsAggregator.aggregateAssertionsMetrics(globalDetails.assertionsParams, assertionsMetricsToAggregate);

        if (totalTestWeight > 0.0) {
            if (r2Denominator > 0.0) {
                globalPerf.metrics.r2 = r2Numerator / r2Denominator;
            }
            if (evsDenominator > 0.0) {
                globalPerf.metrics.evs = evsNumerator / evsDenominator;
            }
            if (totalTestWeightCov > 0.0) {
                // Pooled covariance and variances: Cov = E[y*p] - E[y]E[p], Var = E[y^2] - E[y]^2
                double meanY = sumY / totalTestWeightCov;
                double meanYPred = sumYPred / totalTestWeightCov;
                double pooledCovariance = sumYYPred / totalTestWeightCov - meanY * meanYPred;
                double pooledTargetVariance = sumYSquared / totalTestWeightCov - square(meanY);
                double pooledPredictionVariance = sumYPredSquared / totalTestWeightCov - square(meanYPred);
                if (pooledTargetVariance > 0.0 && pooledPredictionVariance > 0.0) {
                    globalPerf.metrics.pearson = pooledCovariance / (Math.sqrt(pooledTargetVariance) * Math.sqrt(pooledPredictionVariance));
                }
            }
            if (totalTestWeightErrors > 0.0 && minError < Double.MAX_VALUE && maxError > -Double.MAX_VALUE) {
                globalPerf.regression_performance = new RegressionModelPerf.RegressionPerformance();
                globalPerf.regression_performance.min_error = minError;
                globalPerf.regression_performance.max_error = maxError;
                globalPerf.regression_performance.average_error = avgError / totalTestWeightErrors;
            }
            RegressionModelPerf.ScatterplotData scatterPlot = StratifiedMetricsAggregator.sampleGlobalScatterPlot(modelDetailsList, totalTestWeight);
            if (scatterPlot != null) {
                globalPerf.scatterPlotData = scatterPlot;
            }
        }
        globalDetails.perf = globalPerf;
        globalDetails.iperf = new RegressionModelIntrinsicPerf();
        return globalDetails;
    }

    /**
     * Builds the global scatter plot. Each partition contributes its leading points, in number
     * proportional to its share of {@code totalTestWeight}: about 1000 points in total, at least
     * one per partition, and never more than the partition has.
     *
     * @return the sampled scatter plot, or {@code null} when no partition has scatter plot data
     */
    private static RegressionModelPerf.ScatterplotData sampleGlobalScatterPlot(List<ModelDetailsBase> modelDetailsList, double totalTestWeight) {
        final int targetPoints = 1000;
        // Proportional shares sum to at most targetPoints; the "at least one" rule adds at most one per partition
        double[] x = new double[targetPoints + modelDetailsList.size()];
        double[] y = new double[x.length];
        int rows = 0;
        for (ModelDetailsBase partitionDetails : modelDetailsList) {
            RegressionModelPerf perf = (RegressionModelPerf)((ClassicalPredictionModelDetails)partitionDetails).perf;
            if (perf == null || perf.globalMetrics == null || perf.scatterPlotData == null
                    || perf.scatterPlotData.x == null || perf.scatterPlotData.y == null) {
                continue;
            }
            int available = Math.min(perf.scatterPlotData.x.length, perf.scatterPlotData.y.length);
            int capacity = x.length - rows;
            if (available == 0 || capacity == 0) {
                continue;
            }
            int sample = (int)((double)targetPoints * perf.globalMetrics.testWeight / totalTestWeight);
            // At least one point per partition, never more than it has nor than the buffer can hold
            sample = Math.max(1, Math.min(sample, Math.min(available, capacity)));
            System.arraycopy(perf.scatterPlotData.x, 0, x, rows, sample);
            System.arraycopy(perf.scatterPlotData.y, 0, y, rows, sample);
            rows += sample;
        }
        if (rows == 0) {
            return null;
        }
        RegressionModelPerf.ScatterplotData scatterPlot = new RegressionModelPerf.ScatterplotData();
        scatterPlot.x = Arrays.copyOf(x, rows);
        scatterPlot.y = Arrays.copyOf(y, rows);
        return scatterPlot;
    }

    /**
     * Builds the global details of a partitioned time series forecasting model from its partitions.
     *
     * <p>Each aggregated metric (MASE, MAPE, mean absolute / weighted quantile loss, MSE, MSIS, ND, RMSE, sMAPE,
     * MAE) is the unweighted mean of the partitions' non-null values. A metric that no partition reports is not
     * assigned on the global metrics object. Partitions without {@code perf} or without aggregated metrics are
     * ignored. Custom metrics declared on the first partition are aggregated with a weight of 1.0 per partition.
     *
     * @param modelDetailsList per-partition details, each a {@link TimeseriesForecastingModelDetails}; must not be
     *                         empty since the custom metric definitions are read from its first element
     * @return global details carrying the aggregated perf and a freshly constructed intrinsic perf
     */
    private static TimeseriesForecastingModelDetails aggregateTimeseriesForecastingMetrics(List<ModelDetailsBase> modelDetailsList) {
        /* Running sum and count of the non-null values reported for one metric. */
        final class RunningMean {
            private double sum;
            private int count;

            void add(Number value) {
                if (value != null) {
                    sum += value.doubleValue();
                    count++;
                }
            }

            boolean isEmpty() {
                return count == 0;
            }

            double mean() {
                return sum / count;
            }
        }

        TimeseriesForecastingModelDetails globalDetails = new TimeseriesForecastingModelDetails();
        TimeseriesForecastingModelPerf globalPerf = new TimeseriesForecastingModelPerf();
        globalPerf.aggregatedMetrics = new TimeseriesForecastingModelPerf.TimeseriesForecastingMetrics();

        RunningMean mase = new RunningMean();
        RunningMean mape = new RunningMean();
        RunningMean absoluteQuantileLoss = new RunningMean();
        RunningMean weightedQuantileLoss = new RunningMean();
        RunningMean mse = new RunningMean();
        RunningMean msis = new RunningMean();
        RunningMean nd = new RunningMean();
        RunningMean rmse = new RunningMean();
        RunningMean smape = new RunningMean();
        RunningMean mae = new RunningMean();

        // Custom metric definitions come from the first partition only.
        HashMap<String, ModelCustomEvaluationMetric> customEvaluationMetrics = new HashMap<String, ModelCustomEvaluationMetric>();
        HashMap<String, CustomMetricResultAggregator> customEvaluationMetricsScores = new HashMap<String, CustomMetricResultAggregator>();
        TimeseriesForecastingModelDetails firstDetails = (TimeseriesForecastingModelDetails) modelDetailsList.get(0);
        for (ModelCustomEvaluationMetric definition : firstDetails.modeling.metrics.getCustomMetrics()) {
            customEvaluationMetrics.put(definition.name, definition);
            customEvaluationMetricsScores.put(definition.name, new CustomMetricResultAggregator());
        }

        for (ModelDetailsBase partitionDetails : modelDetailsList) {
            TimeseriesForecastingModelDetails details = (TimeseriesForecastingModelDetails) partitionDetails;
            TimeseriesForecastingModelPerf perf = details.perf;
            if (perf != null && perf.aggregatedMetrics != null) {
                mase.add(perf.aggregatedMetrics.mase);
                mape.add(perf.aggregatedMetrics.mape);
                absoluteQuantileLoss.add(perf.aggregatedMetrics.meanAbsoluteQuantileLoss);
                weightedQuantileLoss.add(perf.aggregatedMetrics.meanWeightedQuantileLoss);
                mse.add(perf.aggregatedMetrics.mse);
                msis.add(perf.aggregatedMetrics.msis);
                nd.add(perf.aggregatedMetrics.nd);
                rmse.add(perf.aggregatedMetrics.rmse);
                smape.add(perf.aggregatedMetrics.smape);
                mae.add(perf.aggregatedMetrics.mae);
                if (perf.aggregatedMetrics.customMetricsResults != null) {
                    // Every partition weighs the same (1.0) in the custom metric aggregation.
                    StratifiedMetricsAggregator.aggregatePerfCustomMetricScores(customEvaluationMetricsScores, perf.aggregatedMetrics.customMetricsResults, 1.0, details.fullModelId);
                }
            }
        }

        // Only assign metrics that at least one partition reported.
        if (!mase.isEmpty()) {
            globalPerf.aggregatedMetrics.mase = mase.mean();
        }
        if (!mape.isEmpty()) {
            globalPerf.aggregatedMetrics.mape = mape.mean();
        }
        if (!absoluteQuantileLoss.isEmpty()) {
            globalPerf.aggregatedMetrics.meanAbsoluteQuantileLoss = absoluteQuantileLoss.mean();
        }
        if (!weightedQuantileLoss.isEmpty()) {
            globalPerf.aggregatedMetrics.meanWeightedQuantileLoss = weightedQuantileLoss.mean();
        }
        if (!mse.isEmpty()) {
            globalPerf.aggregatedMetrics.mse = mse.mean();
        }
        if (!msis.isEmpty()) {
            globalPerf.aggregatedMetrics.msis = msis.mean();
        }
        if (!nd.isEmpty()) {
            globalPerf.aggregatedMetrics.nd = nd.mean();
        }
        if (!rmse.isEmpty()) {
            globalPerf.aggregatedMetrics.rmse = rmse.mean();
        }
        if (!smape.isEmpty()) {
            globalPerf.aggregatedMetrics.smape = smape.mean();
        }
        if (!mae.isEmpty()) {
            globalPerf.aggregatedMetrics.mae = mae.mean();
        }

        ArrayList<CustomMetricResult> customResults = StratifiedMetricsAggregator.getWeightedCustomMetricResults(customEvaluationMetrics, customEvaluationMetricsScores);
        globalPerf.aggregatedMetrics.customMetricsResults = customResults.toArray(new CustomMetricResult[0]);
        globalDetails.perf = globalPerf;
        globalDetails.iperf = new TimeseriesForecastingModelIntrinsicPerf();
        return globalDetails;
    }

    /**
     * Feeds one partition's custom metric results into the aggregators registered under each metric's name.
     *
     * <p>Every result's metric name must already have an aggregator in {@code customEvaluationMetricsScores};
     * otherwise a {@link NullPointerException} is thrown.
     *
     * @param customEvaluationMetricsScores aggregators keyed by custom metric name; they are updated in place
     * @param customMetricsResults          the partition's custom metric results
     * @param testWeight                    weight given to this partition's values
     * @param partitionModelId              id of the partition model the results come from
     */
    private static void aggregatePerfCustomMetricScores(HashMap<String, CustomMetricResultAggregator> customEvaluationMetricsScores, CustomMetricResult[] customMetricsResults, Double testWeight, String partitionModelId) {
        for (int i = 0; i < customMetricsResults.length; i++) {
            CustomMetricResult partitionResult = customMetricsResults[i];
            // The aggregator is stored in the map by reference, so updating it is enough.
            CustomMetricResultAggregator target = customEvaluationMetricsScores.get(partitionResult.metric.name);
            target.addValue(partitionResult, testWeight, partitionModelId);
        }
    }

    /**
     * Produces one aggregated custom metric result per aggregator.
     *
     * <p>Each aggregator is finalized with the metric definition registered under the same name. Results are
     * listed in the iteration order of {@code customEvaluationMetricsScores}.
     *
     * @param customEvaluationMetrics       custom metric definitions keyed by name
     * @param customEvaluationMetricsScores accumulated aggregators keyed by the same names
     * @return the aggregated results, one per aggregator
     */
    private static ArrayList<CustomMetricResult> getWeightedCustomMetricResults(HashMap<String, ModelCustomEvaluationMetric> customEvaluationMetrics, HashMap<String, CustomMetricResultAggregator> customEvaluationMetricsScores) {
        final ArrayList<CustomMetricResult> aggregatedResults = new ArrayList<CustomMetricResult>();
        customEvaluationMetricsScores.forEach((metricName, aggregator) -> {
            ModelCustomEvaluationMetric definition = customEvaluationMetrics.get(metricName);
            aggregatedResults.add(aggregator.aggregate(definition));
        });
        return aggregatedResults;
    }

    /**
     * Returns the ML assertions shared by every partition, so their metrics can be aggregated.
     *
     * <p>Assertions are intersected with {@code retainAll}, i.e. compared with {@code equals}, keeping the order of
     * the first partition. {@code null} is returned, and the reason logged, when:
     * <ul>
     *   <li>some partition is not a {@link ClassicalPredictionModelDetails};</li>
     *   <li>some partition has no assertion;</li>
     *   <li>no assertion is common to all partitions, which includes an empty input list.</li>
     * </ul>
     *
     * @param modelDetailsList per-partition model details
     * @return the common assertions, or {@code null} when they cannot be aggregated
     */
    public static MLAssertionsParams getCommonAssertionsParams(List<ModelDetailsBase> modelDetailsList) {
        // Stays null until the first partition has been inspected.
        ArrayList<MLAssertionsParams.MLAssertion> common = null;
        for (ModelDetailsBase partition : modelDetailsList) {
            if (!(partition instanceof ClassicalPredictionModelDetails)) {
                logger.info((Object)("ML assertions not supported by: " + partition.getClass().getSimpleName()));
                return null;
            }
            MLAssertionsParams partitionAssertions = ((ClassicalPredictionModelDetails)partition).assertionsParams;
            if (partitionAssertions == null || !partitionAssertions.hasAssertions()) {
                logger.info((Object)"Some partition does not have any assertion defined, skipping aggregation of assertions metrics");
                return null;
            }
            if (common == null) {
                common = new ArrayList<MLAssertionsParams.MLAssertion>(partitionAssertions.assertions);
            } else {
                common.retainAll(partitionAssertions.assertions);
            }
        }
        boolean nothingInCommon = common == null || common.isEmpty();
        if (nothingInCommon) {
            logger.info((Object)"Mismatch in assertions parameters between partitions, skipping aggregation of assertions metrics");
            return null;
        }
        logger.info((Object)("Aggregating assertion metrics for following assertions: " + StringUtils.join(common, (String)", ")));
        return new MLAssertionsParams(common);
    }

    /*
     * WARNING - void declaration
     */
    static PredictionModelPerf.AssertionsMetrics aggregateAssertionsMetrics(MLAssertionsParams assertionsParams, List<PredictionModelPerf.AssertionsMetrics> assertionsMetricsToAggregate) {
        void var4_9;
        if (assertionsParams == null || !assertionsParams.hasAssertions() || assertionsMetricsToAggregate == null || assertionsMetricsToAggregate.isEmpty()) {
            return null;
        }
        ArrayList<PredictionModelPerf.AssertionMetrics> aggregatedPerAssertion = new ArrayList<PredictionModelPerf.AssertionMetrics>();
        for (MLAssertionsParams.MLAssertion mLAssertion : assertionsParams.assertions) {
            PredictionModelPerf.AssertionMetrics assertionMetrics = new PredictionModelPerf.AssertionMetrics();
            assertionMetrics.name = mLAssertion.name;
            aggregatedPerAssertion.add(assertionMetrics);
        }
        int index = 0;
        for (PredictionModelPerf.AssertionMetrics aggregatedAssertionMetrics : aggregatedPerAssertion) {
            double reconstructedNbValidRows = 0.0;
            int nbMatchingRows = 0;
            int nbDroppedRows = 0;
            for (PredictionModelPerf.AssertionsMetrics assertionsMetrics : assertionsMetricsToAggregate) {
                PredictionModelPerf.AssertionMetrics am;
                if (assertionsMetrics == null || (am = assertionsMetrics.getAssertionMetrics(aggregatedAssertionMetrics.name)) == null) continue;
                nbDroppedRows += am.nbDroppedRows;
                nbMatchingRows += am.nbMatchingRows;
                if (am.validRatio == null) continue;
                reconstructedNbValidRows += (double)(am.nbMatchingRows - am.nbDroppedRows) * am.validRatio;
            }
            aggregatedAssertionMetrics.nbMatchingRows = nbMatchingRows;
            aggregatedAssertionMetrics.nbDroppedRows = nbDroppedRows;
            Double d = aggregatedAssertionMetrics.validRatio = nbMatchingRows - nbDroppedRows == 0 ? null : Double.valueOf(reconstructedNbValidRows / (double)(nbMatchingRows - nbDroppedRows));
            aggregatedAssertionMetrics.result = aggregatedAssertionMetrics.validRatio == null ? null : Boolean.valueOf(aggregatedAssertionMetrics.validRatio >= assertionsParams.assertions.get((int)index).assertionCondition.expectedValidRatio);
            ++index;
        }
        Object var4_7 = null;
        List nonNullAssertionsResults = aggregatedPerAssertion.stream().map(assertion -> assertion.result).filter(Objects::nonNull).collect(Collectors.toList());
        if (!nonNullAssertionsResults.isEmpty()) {
            Double d = (double)nonNullAssertionsResults.stream().filter(r -> r).count() / (double)nonNullAssertionsResults.size();
        }
        PredictionModelPerf.AssertionsMetrics aggregatedAssertionsMetrics = new PredictionModelPerf.AssertionsMetrics();
        aggregatedAssertionsMetrics.passingAssertionsRatio = var4_9;
        aggregatedAssertionsMetrics.perAssertion = aggregatedPerAssertion;
        return aggregatedAssertionsMetrics;
    }

    /**
     * Derives one one-vs-all binary confusion matrix per class from a global multiclass confusion.
     *
     * <p>For a class C:
     * <ul>
     *   <li>tp is the weight of rows of actual class C predicted as C.</li>
     *   <li>fn is the rest of actual class C's weight.</li>
     *   <li>fp is the weight of rows of other actual classes predicted as C.</li>
     *   <li>tn is the global test weight minus the three others.</li>
     * </ul>
     *
     * @param globalPerf summed multiclass perf; every actual and predicted class of its confusion must be listed in
     *                   {@code classes}
     * @return matrices keyed by class name
     */
    private static Map<String, ConfusionMatrix> getOneVsAllConfusionMatrices(MulticlassModelPerf globalPerf) {
        HashMap<String, ConfusionMatrix> matrixByClass = new HashMap<String, ConfusionMatrix>();
        for (String className : globalPerf.classes) {
            matrixByClass.put(className, new ConfusionMatrix());
        }
        for (Map.Entry<String, MulticlassModelPerf.MulticlassActualClassConfusion> row : globalPerf.confusion.perActual.entrySet()) {
            String actualClass = row.getKey();
            MulticlassModelPerf.MulticlassActualClassConfusion rowConfusion = row.getValue();
            Map<String, Double> predictedCounts = rowConfusion.perPredicted;
            ConfusionMatrix actualMatrix = matrixByClass.get(actualClass);
            double actualCount = rowConfusion.actualClassCount;
            double hits = predictedCounts.containsKey(actualClass) ? predictedCounts.get(actualClass) : 0.0;
            // Compound assignment on a float field adds in double, then narrows back to float.
            actualMatrix.tp += hits;
            actualMatrix.fn += actualCount - hits;
            for (Map.Entry<String, Double> cell : predictedCounts.entrySet()) {
                String predictedClass = cell.getKey();
                if (!predictedClass.equals(actualClass)) {
                    matrixByClass.get(predictedClass).fp += cell.getValue();
                }
            }
        }
        // True negatives are whatever weight remains, computed in float arithmetic.
        float totalWeight = (float)globalPerf.globalMetrics.testWeight;
        for (ConfusionMatrix matrix : matrixByClass.values()) {
            matrix.tn = totalWeight - matrix.tp - matrix.fp - matrix.fn;
        }
        return matrixByClass;
    }

    /**
     * Sums the multiclass performances of all partitions into one global {@link MulticlassModelPerf}.
     *
     * <p>The global class list is the union of the partitions' classes. Confusion counts, row counts and test
     * weights are summed. mROC AUC, average precision and log loss are averaged over the partitions reporting
     * them, weighted by each partition's test weight. The custom score is averaged the same way over partitions
     * whose evaluation metric is CUSTOM, and is assigned only when the first partition's evaluation metric is
     * CUSTOM. Custom metrics declared on the first partition are aggregated through
     * {@link CustomMetricResultAggregator}.
     *
     * <p>Partitions whose {@code perf} is {@code null} are skipped entirely. Partitions whose {@code perf.metrics}
     * is {@code null} contribute their classes but nothing else.
     *
     * @param modelDetailsList per-partition details, each a {@link ClassicalPredictionModelDetails} whose perf (when
     *                         present) is a {@link MulticlassModelPerf}; must not be empty
     * @return the summed global performance
     */
    private static MulticlassModelPerf sumPartitionModelPerfs(List<ModelDetailsBase> modelDetailsList) {
        MulticlassModelPerf globalPerf = new MulticlassModelPerf();
        HashSet<String> classes = new HashSet<String>();
        globalPerf.confusion = new MulticlassModelPerf.MulticlassConfusion();
        globalPerf.globalMetrics = new PredictionModelPerf.GlobalMetrics();
        double testWeightCustomScore = 0.0;
        double testWeightsRocAUCs = 0.0;
        double testWeightsAPs = 0.0;
        double testWeightLogLoss = 0.0;
        double rocAUCs = 0.0;
        double logLoss = 0.0;
        double averagePrecisions = 0.0;
        double customScore = 0.0;
        HashMap<String, ModelCustomEvaluationMetric> customEvaluationMetrics = new HashMap<String, ModelCustomEvaluationMetric>();
        HashMap<String, CustomMetricResultAggregator> customEvaluationMetricsScores = new HashMap<String, CustomMetricResultAggregator>();

        // Union of classes. Partitions without perf are skipped here, consistently with the summing loop below.
        for (ModelDetailsBase partitionDetails : modelDetailsList) {
            MulticlassModelPerf perf = (MulticlassModelPerf)((ClassicalPredictionModelDetails)partitionDetails).perf;
            if (perf == null || perf.classes == null) {
                continue;
            }
            classes.addAll(Arrays.asList(perf.classes));
        }
        for (String predClass : classes) {
            globalPerf.confusion.perActual.put(predClass, new MulticlassModelPerf.MulticlassActualClassConfusion());
        }

        // Custom metric definitions come from the first partition only.
        ClassicalPredictionModelDetails firstDetails = (ClassicalPredictionModelDetails)modelDetailsList.get(0);
        for (ModelCustomEvaluationMetric customMetric : firstDetails.modeling.metrics.getCustomMetrics()) {
            customEvaluationMetrics.put(customMetric.name, customMetric);
            customEvaluationMetricsScores.put(customMetric.name, new CustomMetricResultAggregator());
        }

        for (ModelDetailsBase partitionDetails : modelDetailsList) {
            ClassicalPredictionModelDetails details = (ClassicalPredictionModelDetails)partitionDetails;
            MulticlassModelPerf perf = (MulticlassModelPerf)details.perf;
            if (perf == null || perf.metrics == null) {
                continue;
            }
            globalPerf.confusion.totalRows += perf.confusion.totalRows;
            double testWeight = perf.globalMetrics.testWeight;
            globalPerf.globalMetrics.testWeight += testWeight;
            if (perf.metrics.mrocAUC != null) {
                testWeightsRocAUCs += testWeight;
                rocAUCs += perf.metrics.mrocAUC * testWeight;
            }
            if (perf.metrics.averagePrecision != null) {
                testWeightsAPs += testWeight;
                averagePrecisions += perf.metrics.averagePrecision * testWeight;
            }
            if (perf.metrics.logLoss != null) {
                testWeightLogLoss += testWeight;
                logLoss += perf.metrics.logLoss * testWeight;
            }
            if (details.modeling.metrics.evaluationMetric == MetricParams.EvaluationMetric.CUSTOM && perf.metrics.customScore != null) {
                testWeightCustomScore += testWeight;
                customScore += perf.metrics.customScore * testWeight;
            }
            if (perf.metrics.customMetricsResults != null) {
                StratifiedMetricsAggregator.aggregatePerfCustomMetricScores(customEvaluationMetricsScores, perf.metrics.customMetricsResults, testWeight, partitionDetails.fullModelId);
            }
            for (Map.Entry<String, MulticlassModelPerf.MulticlassActualClassConfusion> classConfusionEntry : perf.confusion.perActual.entrySet()) {
                MulticlassModelPerf.MulticlassActualClassConfusion currentConfusion = classConfusionEntry.getValue();
                // The global entry was created above from the partitions' classes and is updated in place.
                // NOTE(review): assumes every actual class of the confusion also appears in the partition's classes.
                MulticlassModelPerf.MulticlassActualClassConfusion globalConfusion = globalPerf.confusion.perActual.get(classConfusionEntry.getKey());
                globalConfusion.actualClassCount += currentConfusion.actualClassCount;
                for (Map.Entry<String, Double> predictedEntry : currentConfusion.perPredicted.entrySet()) {
                    globalConfusion.perPredicted.merge(predictedEntry.getKey(), predictedEntry.getValue(), Double::sum);
                }
            }
        }

        // NOTE(review): assumes the MulticlassModelPerf constructor initializes `metrics`; it is not set here.
        if (firstDetails.modeling.metrics.evaluationMetric == MetricParams.EvaluationMetric.CUSTOM && testWeightCustomScore > 0.0) {
            globalPerf.metrics.customScore = customScore / testWeightCustomScore;
        }
        ArrayList<CustomMetricResult> customMetricsResults = StratifiedMetricsAggregator.getWeightedCustomMetricResults(customEvaluationMetrics, customEvaluationMetricsScores);
        globalPerf.metrics.customMetricsResults = customMetricsResults.toArray(new CustomMetricResult[0]);
        if (testWeightsRocAUCs > 0.0) {
            globalPerf.metrics.mrocAUC = rocAUCs / testWeightsRocAUCs;
        }
        if (testWeightsAPs > 0.0) {
            globalPerf.metrics.averagePrecision = averagePrecisions / testWeightsAPs;
        }
        if (testWeightLogLoss > 0.0) {
            globalPerf.metrics.logLoss = logLoss / testWeightLogLoss;
        }
        globalPerf.classes = classes.toArray(new String[0]);
        return globalPerf;
    }

    /**
     * Binary (one-vs-all) confusion matrix whose cells hold possibly weighted counts as {@code float}s.
     */
    public static class ConfusionMatrix {
        public float tp;
        public float tn;
        public float fp;
        public float fn;

        /**
         * Returns a new matrix whose cells are the cell-wise sums of {@code matrices}; all zeros when empty.
         */
        public static ConfusionMatrix computeSum(Collection<ConfusionMatrix> matrices) {
            ConfusionMatrix result = new ConfusionMatrix();
            for (ConfusionMatrix matrix : matrices) {
                result.tp += matrix.tp;
                result.tn += matrix.tn;
                result.fp += matrix.fp;
                result.fn += matrix.fn;
            }
            return result;
        }

        @Override
        public String toString() {
            return String.format("ConfusionMatrix{tp=%s, tn=%s, fp=%s, fn=%s}", Float.valueOf(this.tp), Float.valueOf(this.tn), Float.valueOf(this.fp), Float.valueOf(this.fn));
        }

        /**
         * Computes the binary classification metrics of this matrix.
         *
         * <p>Degenerate denominators use fixed conventions:
         * <ul>
         *   <li>precision is 1 when nothing is predicted positive;</li>
         *   <li>recall, accuracy, F1 and MCC are 0 when their denominator is 0.</li>
         * </ul>
         * Hamming loss is {@code 1 - accuracy}.
         */
        public Metrics getMetrics() {
            // Widen every cell to double before any arithmetic. In float, the MCC denominator (a product of four
            // cell sums) overflows to Infinity once the sums reach about 4.3e9, the fourth root of
            // Float.MAX_VALUE, which would silently make MCC 0. Float's 24-bit mantissa also loses precision in
            // tp * tn - fp * fn.
            double truePositives = this.tp;
            double trueNegatives = this.tn;
            double falsePositives = this.fp;
            double falseNegatives = this.fn;

            Metrics metrics = new Metrics();
            double predictedPositives = truePositives + falsePositives;
            metrics.precision = predictedPositives == 0.0 ? 1.0 : truePositives / predictedPositives;
            double actualPositives = truePositives + falseNegatives;
            metrics.recall = actualPositives == 0.0 ? 0.0 : truePositives / actualPositives;
            double total = truePositives + trueNegatives + falsePositives + falseNegatives;
            metrics.accuracy = total == 0.0 ? 0.0 : (truePositives + trueNegatives) / total;
            metrics.f1 = metrics.precision + metrics.recall == 0.0 ? 0.0 : 2.0 * (metrics.precision * metrics.recall) / (metrics.precision + metrics.recall);
            double mccDenominator = predictedPositives * actualPositives * (trueNegatives + falsePositives) * (trueNegatives + falseNegatives);
            metrics.mcc = mccDenominator == 0.0 ? 0.0 : (truePositives * trueNegatives - falsePositives * falseNegatives) / Math.sqrt(mccDenominator);
            metrics.hammingLoss = 1.0 - metrics.accuracy;
            return metrics;
        }

        /**
         * Binary classification metrics derived from a {@link ConfusionMatrix}.
         */
        public static class Metrics {
            public double precision;
            public double recall;
            public double accuracy;
            public double f1;
            public double mcc;
            public double hammingLoss;
        }
    }
}

