/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.spark.ml.prediction;

import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.spark.ml.prediction.KFoldCrossValidator;
import com.dataiku.dip.spark.ml.prediction.Metrics;
import com.dataiku.dip.spark.ml.prediction.ParameterSelection;
import com.dataiku.dip.spark.ml.prediction.ValidationSetValidator;
import com.dataiku.dip.spark.ml.prediction.Validator;
import java.io.Serializable;
import org.apache.log4j.Logger;
import org.apache.spark.mllib.util.MLUtils$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IndexedSeq;
import scala.collection.SeqLike;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;

/**
 * Singleton holding parameter-selection helpers: a validator factory, k-fold
 * splitting and a random train/validation split of Spark DataFrames.
 *
 * <p>This file was produced by CFR from compiled bytecode. The
 * {@code MODULE$} field + private constructor + static initializer is the
 * shape the Scala compiler emits for an {@code object}, so this is presumably
 * the companion of {@link ParameterSelection} (whose nested types appear in
 * the signatures below) — TODO confirm against the original Scala source,
 * which is likely the right place to make changes.
 */
public final class ParameterSelection$ {
    /** The single instance; assigned inside the private constructor during class initialization. */
    public static ParameterSelection$ MODULE$;
    /** log4j logger named {@code "dku.doctor"}; not referenced by the methods visible in this class. */
    private final Logger LOGGER;

    // Class initialization creates the singleton; the constructor itself
    // publishes the instance through MODULE$.
    static {
        new ParameterSelection$();
    }

    /**
     * Returns the {@code "dku.doctor"} logger (accessor for the {@code LOGGER} field).
     *
     * @return the logger created in the constructor
     */
    public Logger LOGGER() {
        return this.LOGGER;
    }

    /**
     * Builds the hyperparameter-search validator that matches {@code params.mode}.
     *
     * <ul>
     *   <li>{@code SHUFFLE}: a {@link ValidationSetValidator} whose second argument is
     *       {@code 1.0 - params.splitRatio} (presumably the held-out fraction, i.e.
     *       {@code splitRatio} is the train fraction — confirm in ValidationSetValidator).</li>
     *   <li>{@code KFOLD}: a {@link KFoldCrossValidator} built with {@code params.nFolds}.</li>
     * </ul>
     * Both validators also receive {@code params.cvSeed}, {@code params.randomized},
     * {@code params.timeout} and {@code params.nIter}, plus the remaining arguments unchanged.
     *
     * @param params        grid-search settings; dereferenced immediately, so must not be null
     * @param abortHook     forwarded unchanged to the validator
     * @param scorer        forwarded unchanged to the validator
     * @param startingPoint forwarded unchanged to the validator
     * @param dumper        forwarded unchanged to the validator (presumably a per-fold score
     *                      callback, judging by its {@code FoldScore} argument type — confirm)
     * @return a new validator for the requested mode
     * @throws UnsupportedOperationException if {@code params.mode} is neither {@code SHUFFLE}
     *         nor {@code KFOLD}, including when it is {@code null}
     */
    public Validator getValidator(PredictionModelingParams.GridSearchParams params, Function0<Object> abortHook, Metrics.Scorer scorer, ParameterSelection.ValidationStartingPoint startingPoint, Function1<ParameterSelection.FoldScore, BoxedUnit> dumper) {
        PredictionModelingParams.GridSearchCrossValidationMode gridSearchCrossValidationMode = params.mode;
        // Constant-first equals(): a null mode fails both checks and reaches the
        // UnsupportedOperationException below instead of throwing a NullPointerException.
        if (PredictionModelingParams.GridSearchCrossValidationMode.SHUFFLE.equals(gridSearchCrossValidationMode)) {
            return new ValidationSetValidator(abortHook, 1.0 - (double)params.splitRatio, params.cvSeed, params.randomized, params.timeout, params.nIter, scorer, startingPoint, dumper);
        }
        if (PredictionModelingParams.GridSearchCrossValidationMode.KFOLD.equals(gridSearchCrossValidationMode)) {
            return new KFoldCrossValidator(abortHook, params.nFolds, params.cvSeed, params.randomized, params.timeout, params.nIter, scorer, startingPoint, dumper);
        }
        throw new UnsupportedOperationException(new StringBuilder(39).append("Cross validation mode ").append(gridSearchCrossValidationMode).append(" is not supported").toString());
    }

    /**
     * Splits {@code data} into {@code k} folds with Spark's {@code MLUtils.kFold}.
     * Each resulting pair of RDDs is turned back into a pair of DataFrames that carry
     * {@code data}'s original schema.
     *
     * <p>Within each returned pair, the element order is the order {@code kFold} produces.
     * The Spark documentation describes that order as (training, validation) — confirm
     * against the Spark version in use.
     *
     * @param k    number of folds, passed straight to {@code kFold}
     * @param seed random seed, passed straight to {@code kFold}
     * @param data rows to split; its schema and SQLContext are captured before splitting
     * @return one pair of DataFrames per fold
     */
    public IndexedSeq<Tuple2<Dataset<Row>, Dataset<Row>>> folds(int k, int seed, Dataset<Row> data) {
        StructType schema = data.schema();
        SQLContext sql = data.sqlContext();
        return (IndexedSeq)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])MLUtils$.MODULE$.kFold(data.rdd(), k, seed, ClassTag$.MODULE$.apply(Row.class)))).map((Function1 & Serializable & scala.Serializable)x0$1 -> {
            Tuple2 tuple2 = x0$1;
            // Scala pattern-match residue: a null (u, v) pair is rejected with MatchError.
            if (tuple2 != null) {
                RDD u = (RDD)tuple2._1();
                RDD v = (RDD)tuple2._2();
                return new Tuple2((Object)ParameterSelection$.toDF$1(u, sql, schema), (Object)ParameterSelection$.toDF$1(v, sql, schema));
            }
            throw new MatchError((Object)tuple2);
        }, Array$.MODULE$.fallbackCanBuildFrom(Predef.DummyImplicit$.MODULE$.dummyImplicit()));
    }

    /**
     * Randomly splits {@code data} into two DataFrames using the underlying RDD's
     * {@code randomSplit} with weights {@code [1.0 - valFraction, valFraction]}.
     *
     * @param valFraction weight given to the second split; the first split gets
     *                    {@code 1.0 - valFraction}
     * @param seed        random seed, widened to {@code long} for {@code randomSplit}
     * @param data        rows to split; its schema and SQLContext are captured before splitting
     * @return (split for weight {@code 1.0 - valFraction}, split for weight {@code valFraction}),
     *         both carrying {@code data}'s original schema
     * @throws MatchError if the split does not yield exactly two datasets (a Scala
     *         destructuring check; presumably unreachable since two weights are supplied)
     */
    public Tuple2<Dataset<Row>, Dataset<Row>> trainVal(double valFraction, int seed, Dataset<Row> data) {
        StructType schema = data.schema();
        SQLContext sql = data.sqlContext();
        Dataset[] datasetArray = (Dataset[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])data.rdd().randomSplit(new double[]{1.0 - valFraction, valFraction}, (long)seed))).map((Function1 & Serializable & scala.Serializable)rdd -> ParameterSelection$.toDF$2(rdd, sql, schema), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Dataset.class)));
        // Destructure the result into exactly two elements, otherwise MatchError.
        Option option = Array$.MODULE$.unapplySeq((Object)datasetArray);
        if (option.isEmpty() || option.get() == null || ((SeqLike)option.get()).lengthCompare(2) != 0) {
            throw new MatchError((Object)datasetArray);
        }
        Dataset x = (Dataset)((SeqLike)option.get()).apply(0);
        Dataset y = (Dataset)((SeqLike)option.get()).apply(1);
        // The tuple is built and immediately unpacked again. This is likely a decompilation
        // artifact of a Scala destructuring assignment and does not affect the result.
        Tuple2 tuple2 = new Tuple2((Object)x, (Object)y);
        Dataset x2 = (Dataset)tuple2._1();
        Dataset y2 = (Dataset)tuple2._2();
        return new Tuple2((Object)x2, (Object)y2);
    }

    /**
     * Rebuilds a DataFrame from {@code rdd} with the given SQLContext and schema; used by
     * {@link #folds}. Likely a compiler-lifted local function (hence the {@code $1} suffix).
     */
    private static final Dataset toDF$1(RDD rdd, SQLContext sql$1, StructType schema$1) {
        return sql$1.createDataFrame(rdd, schema$1);
    }

    /**
     * Rebuilds a DataFrame from {@code rdd} with the given SQLContext and schema; used by
     * {@link #trainVal}. It has the same body as {@code toDF$1}. It is probably a separate
     * lifted copy of the local helper defined in that method.
     */
    private static final Dataset toDF$2(RDD rdd, SQLContext sql$2, StructType schema$2) {
        return sql$2.createDataFrame(rdd, schema$2);
    }

    /**
     * Publishes this instance as {@link #MODULE$}, then creates the {@code "dku.doctor"}
     * logger. Private: within this class only the static initializer instantiates it.
     */
    private ParameterSelection$() {
        // MODULE$ is assigned before LOGGER is initialized, presumably mirroring the
        // initialization order the Scala compiler emits for objects.
        MODULE$ = this;
        this.LOGGER = Logger.getLogger((String)"dku.doctor");
    }
}

