/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.spark.ml.prediction;

import com.dataiku.dip.analysis.model.prediction.MetricParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.spark.ml.prediction.BinaryClassificationPerf;
import com.dataiku.dip.spark.ml.prediction.BinaryClassificationPerf$;
import com.dataiku.dip.spark.ml.prediction.Metrics;
import com.dataiku.dip.spark.ml.prediction.ModelScorer;
import com.dataiku.dip.spark.ml.prediction.ModelScorer$;
import java.io.Serializable;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.TraversableOnce;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.IndexedSeqView$;
import scala.math.Ordering;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Singleton module (Scala {@code object Metrics}) that builds {@link Metrics.Scorer}s for Spark ML
 * models. It dispatches on the prediction type (regression / binary / multiclass) and on the
 * configured evaluation metric.
 */
public final class Metrics$ {
    /** The singleton instance, assigned by the private constructor during static initialisation. */
    public static Metrics$ MODULE$;

    static {
        new Metrics$();
    }

    /**
     * Builds a scorer for the given metric and core prediction parameters.
     *
     * @param metricParams evaluation metric configuration (metric, cost-matrix weights, lift point)
     * @param coreParams   resolved core params providing the target column and prediction type
     * @return a scorer that applies the metric to a (model, dataset) pair, flagged with whether a
     *         larger metric value is better
     * @throws IllegalArgumentException if the metric is {@code CUSTOM}
     * @throws MatchError               if the metric is any other value not handled below
     */
    public Metrics.Scorer getScorer(MetricParams metricParams, ResolvedClassicalPredictionCoreParams coreParams) {
        boolean bl;
        String target = coreParams.target_variable;
        PredictionMLTask.PredictionType predictionType = coreParams.prediction_type;
        MetricParams.EvaluationMetric metric = metricParams.evaluationMetric;
        MetricParams.EvaluationMetric evaluationMetric = metric;
        // Decide the optimisation direction. Error/loss metrics (LOG_LOSS, MAPE, MAE, MSE, RMSE,
        // RMSLE) are minimised. Every other supported metric is maximised.
        if (MetricParams.EvaluationMetric.ACCURACY.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.RECALL.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.PRECISION.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.F1.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.COST_MATRIX.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.CUMULATIVE_LIFT.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.LOG_LOSS.equals(evaluationMetric)) {
            bl = false;
        } else if (MetricParams.EvaluationMetric.ROC_AUC.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.AVERAGE_PRECISION.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.EVS.equals(evaluationMetric)) {
            bl = true;
        } else if (MetricParams.EvaluationMetric.MAPE.equals(evaluationMetric)) {
            bl = false;
        } else if (MetricParams.EvaluationMetric.MAE.equals(evaluationMetric)) {
            bl = false;
        } else if (MetricParams.EvaluationMetric.MSE.equals(evaluationMetric)) {
            bl = false;
        } else if (MetricParams.EvaluationMetric.RMSE.equals(evaluationMetric)) {
            bl = false;
        } else if (MetricParams.EvaluationMetric.RMSLE.equals(evaluationMetric)) {
            bl = false;
        } else if (MetricParams.EvaluationMetric.R2.equals(evaluationMetric)) {
            bl = true;
        } else {
            if (MetricParams.EvaluationMetric.CUSTOM.equals(evaluationMetric)) {
                throw new IllegalArgumentException("unsupported");
            }
            throw new MatchError((Object)evaluationMetric);
        }
        boolean largerIsBetter = bl;
        // The lambda must be Serializable because the scorer may be shipped around by Spark.
        return new Metrics.Scorer((Function2 & Serializable & scala.Serializable)(m, d) -> BoxesRunTime.boxToDouble((double)Metrics$.$anonfun$getScorer$7(predictionType, metric, target, metricParams, m, d)), largerIsBetter);
    }

    /**
     * Convenience wrapper: builds the scorer for {@code metricParams}/{@code coreParams} and applies
     * it to {@code model} on {@code data}.
     *
     * @return the metric value of {@code model} evaluated on {@code data}
     */
    public double computeMetric(MetricParams metricParams, ResolvedClassicalPredictionCoreParams coreParams, Model<?> model, Dataset<Row> data) {
        return this.getScorer(metricParams, coreParams).apply(model, data);
    }

    /**
     * Runs {@code model} over {@code dataFrame} and keeps only (target value, {@code outColumn} value)
     * pairs for each row.
     */
    private static final RDD transform$1(String outColumn, Model model, Dataset dataFrame, String target$1) {
        return model.transform(dataFrame).rdd().map((Function1 & Serializable & scala.Serializable)row -> new Tuple2(row.getAs(target$1), row.getAs(outColumn)), ClassTag$.MODULE$.apply(Tuple2.class));
    }

    /**
     * Scores a regression model from an RDD of (target, prediction) pairs.
     *
     * @throws IllegalArgumentException if {@code metric$1} is not a regression metric
     */
    private static final double scoreRegression$1(RDD rdd, MetricParams.EvaluationMetric metric$1) {
        ModelScorer.AdvancedRegressionMetrics metrics = new ModelScorer.AdvancedRegressionMetrics((RDD<Tuple2<Object, Object>>)rdd);
        MetricParams.EvaluationMetric evaluationMetric = metric$1;
        if (MetricParams.EvaluationMetric.MSE.equals(evaluationMetric)) {
            return metrics.meanSquaredError();
        }
        if (MetricParams.EvaluationMetric.RMSE.equals(evaluationMetric)) {
            return metrics.rootMeanSquaredError();
        }
        if (MetricParams.EvaluationMetric.EVS.equals(evaluationMetric)) {
            // NOTE(review): EVS is reported as r2() here, not as a dedicated explained-variance
            // value. Confirm this is intentional.
            return metrics.r2();
        }
        if (MetricParams.EvaluationMetric.R2.equals(evaluationMetric)) {
            return metrics.r2();
        }
        if (MetricParams.EvaluationMetric.MAE.equals(evaluationMetric)) {
            return metrics.meanAbsoluteError();
        }
        if (MetricParams.EvaluationMetric.RMSLE.equals(evaluationMetric)) {
            return metrics.rmsle();
        }
        if (MetricParams.EvaluationMetric.MAPE.equals(evaluationMetric)) {
            return metrics.mape();
        }
        throw new IllegalArgumentException(new StringBuilder(34).append("Metric ").append(metric$1).append(" unsupported for regression").toString());
    }

    /**
     * Scores a binary classifier from an RDD of (target, probability of class 1) pairs.
     * For threshold-dependent metrics (accuracy, recall, precision, F1, cost matrix) the value
     * returned is the maximum over the per-threshold values produced by
     * {@link BinaryClassificationPerf}.
     *
     * @throws IllegalArgumentException if {@code metric$1} is not a binary-classification metric
     */
    private static final double scoreBinaryClassification$1(RDD rdd, MetricParams metricParams$1, MetricParams.EvaluationMetric metric$1) {
        BinaryClassificationPerf perf = new BinaryClassificationPerf((RDD<Tuple2<Object, Object>>)rdd, (Option<Object>)None$.MODULE$, ModelScorer$.MODULE$.gainMatrix(metricParams$1.costMatrixWeights), metricParams$1.liftPoint, (Option<MetricParams.ThresholdOptimizationMetric>)None$.MODULE$, BinaryClassificationPerf$.MODULE$.$lessinit$greater$default$6());
        MetricParams.EvaluationMetric evaluationMetric = metric$1;
        if (MetricParams.EvaluationMetric.ACCURACY.equals(evaluationMetric)) {
            return BoxesRunTime.unboxToDouble((Object)((TraversableOnce)perf.accuracy().map((Function1 & Serializable & scala.Serializable)x$1 -> BoxesRunTime.boxToDouble((double)x$1._2$mcD$sp()), IndexedSeq$.MODULE$.canBuildFrom())).max((Ordering)Ordering.Double$.MODULE$));
        }
        if (MetricParams.EvaluationMetric.RECALL.equals(evaluationMetric)) {
            return BoxesRunTime.unboxToDouble((Object)((TraversableOnce)perf.recall().map((Function1 & Serializable & scala.Serializable)x$2 -> BoxesRunTime.boxToDouble((double)x$2._2$mcD$sp()), IndexedSeq$.MODULE$.canBuildFrom())).max((Ordering)Ordering.Double$.MODULE$));
        }
        if (MetricParams.EvaluationMetric.PRECISION.equals(evaluationMetric)) {
            return BoxesRunTime.unboxToDouble((Object)((TraversableOnce)perf.precision().map((Function1 & Serializable & scala.Serializable)x$3 -> BoxesRunTime.boxToDouble((double)x$3._2$mcD$sp()), IndexedSeq$.MODULE$.canBuildFrom())).max((Ordering)Ordering.Double$.MODULE$));
        }
        if (MetricParams.EvaluationMetric.F1.equals(evaluationMetric)) {
            return BoxesRunTime.unboxToDouble((Object)((TraversableOnce)perf.f1().map((Function1 & Serializable & scala.Serializable)x$4 -> BoxesRunTime.boxToDouble((double)x$4._2$mcD$sp()), IndexedSeq$.MODULE$.canBuildFrom())).max((Ordering)Ordering.Double$.MODULE$));
        }
        if (MetricParams.EvaluationMetric.COST_MATRIX.equals(evaluationMetric)) {
            return BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps((double[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])perf.costMatrixGain())).map((Function1 & Serializable & scala.Serializable)x$5 -> BoxesRunTime.boxToDouble((double)x$5._2$mcD$sp()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))).max((Ordering)Ordering.Double$.MODULE$));
        }
        if (MetricParams.EvaluationMetric.CUMULATIVE_LIFT.equals(evaluationMetric)) {
            return perf.lift();
        }
        if (MetricParams.EvaluationMetric.LOG_LOSS.equals(evaluationMetric)) {
            return perf.logLoss();
        }
        if (MetricParams.EvaluationMetric.ROC_AUC.equals(evaluationMetric)) {
            return perf.areaUnderROC();
        }
        if (MetricParams.EvaluationMetric.AVERAGE_PRECISION.equals(evaluationMetric)) {
            return ModelScorer$.MODULE$.averagePrecisionScore((RDD<Tuple2<Object, Object>>)rdd);
        }
        throw new IllegalArgumentException(new StringBuilder(45).append("Metric ").append(metric$1).append(" unsupported for binary classification").toString());
    }

    /**
     * Scores a multiclass classifier from an RDD of (target, predicted class index, probability
     * vector) triples.
     *
     * @throws IllegalArgumentException if {@code metric$1} is not a multiclass metric
     */
    private static final double scoreMulticlass$1(RDD rdd, MetricParams.EvaluationMetric metric$1) {
        ModelScorer.ProbabilisticMulticlassMetrics perf = new ModelScorer.ProbabilisticMulticlassMetrics(42L, (RDD<Tuple3<Object, Object, double[]>>)rdd);
        MetricParams.EvaluationMetric evaluationMetric = metric$1;
        if (MetricParams.EvaluationMetric.ACCURACY.equals(evaluationMetric)) {
            return perf.accuracy_();
        }
        if (MetricParams.EvaluationMetric.RECALL.equals(evaluationMetric)) {
            return perf.averageRecall();
        }
        if (MetricParams.EvaluationMetric.PRECISION.equals(evaluationMetric)) {
            return perf.averagePrecision();
        }
        if (MetricParams.EvaluationMetric.F1.equals(evaluationMetric)) {
            return perf.f1();
        }
        if (MetricParams.EvaluationMetric.LOG_LOSS.equals(evaluationMetric)) {
            return perf.logLoss();
        }
        if (MetricParams.EvaluationMetric.ROC_AUC.equals(evaluationMetric)) {
            return perf.mrocAUC();
        }
        if (MetricParams.EvaluationMetric.AVERAGE_PRECISION.equals(evaluationMetric)) {
            return perf.maveragePrecision();
        }
        // Report the correct task type: this path handles multiclass, not binary, classification.
        throw new IllegalArgumentException(new StringBuilder(49).append("Metric ").append(metric$1).append(" unsupported for multiclass classification").toString());
    }

    /**
     * Body of the scorer lambda built in {@link #getScorer}. Transforms the dataset with the model
     * and dispatches to the scorer matching the prediction type:
     * <ul>
     *   <li>REGRESSION: uses the {@code prediction} column.</li>
     *   <li>BINARY_CLASSIFICATION: uses element 1 of the {@code probability} vector.</li>
     *   <li>MULTICLASS: uses the {@code probability} vector; the predicted class is the index of
     *       its largest entry.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other prediction type
     */
    public static final /* synthetic */ double $anonfun$getScorer$7(PredictionMLTask.PredictionType predictionType$1, MetricParams.EvaluationMetric metric$1, String target$1, MetricParams metricParams$1, Model m, Dataset d) {
        PredictionMLTask.PredictionType predictionType = predictionType$1;
        if (PredictionMLTask.PredictionType.REGRESSION.equals(predictionType)) {
            return Metrics$.scoreRegression$1(Metrics$.transform$1("prediction", m, d, target$1), metric$1);
        }
        if (PredictionMLTask.PredictionType.BINARY_CLASSIFICATION.equals(predictionType)) {
            return Metrics$.scoreBinaryClassification$1(Metrics$.transform$1("probability", m, d, target$1).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
                Tuple2 tuple2 = x0$1;
                if (tuple2 != null) {
                    double t = tuple2._1$mcD$sp();
                    Vector vec = (Vector)tuple2._2();
                    // Keep only the probability of the class at index 1.
                    return new Tuple2.mcDD.sp(t, vec.apply(1));
                }
                throw new MatchError((Object)tuple2);
            }, ClassTag$.MODULE$.apply(Tuple2.class)), metricParams$1, metric$1);
        }
        if (PredictionMLTask.PredictionType.MULTICLASS.equals(predictionType)) {
            return Metrics$.scoreMulticlass$1(Metrics$.transform$1("probability", m, d, target$1).map((Function1 & Serializable & scala.Serializable)x0$2 -> {
                Tuple2 tuple2 = x0$2;
                if (tuple2 != null) {
                    double t = tuple2._1$mcD$sp();
                    Vector vec = (Vector)tuple2._2();
                    double[] arr = vec.toArray();
                    // (target, argmax index of the probability vector, full probability vector)
                    return new Tuple3((Object)BoxesRunTime.boxToDouble((double)t), (Object)BoxesRunTime.boxToDouble((double)((Tuple2)((TraversableOnce)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(arr)).view().zipWithIndex(IndexedSeqView$.MODULE$.arrCanBuildFrom())).maxBy((Function1 & Serializable & scala.Serializable)x$6 -> BoxesRunTime.boxToDouble((double)x$6._1$mcD$sp()), (Ordering)Ordering.Double$.MODULE$))._2$mcI$sp()), (Object)arr);
                }
                throw new MatchError((Object)tuple2);
            }, ClassTag$.MODULE$.apply(Tuple3.class)), metric$1);
        }
        throw new IllegalArgumentException(new StringBuilder(29).append("Unsupported prediction type: ").append(predictionType$1).toString());
    }

    private Metrics$() {
        MODULE$ = this;
    }
}

