/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.split;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamColumnFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.StreamableDatasetSelection;
import com.dataiku.dip.datasets.UniversalSingleThreadPusher;
import com.dataiku.dip.datasets.fs.LocalFSProvider;
import com.dataiku.dip.fs.FSProvider;
import com.dataiku.dip.input.formats.csv.CSVFormatConfig;
import com.dataiku.dip.input.utils.CountingProcessorOutput;
import com.dataiku.dip.output.CSVOutputFormatter;
import com.dataiku.dip.output.OutputFormatter;
import com.dataiku.dip.output.SingleFileOutputWriter;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.server.ShakerStreamService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.util.Random;

/**
 * Static helpers around the train / test / full split files used by ML tasks and saved models:
 * file-name resolution, creation of a counting CSV writer on a local file, pushing an
 * explicitly-defined split of a dataset into such a file, and row splitters that dispatch rows
 * between a "train" and a "test" output.
 */
public class SplitUtils {
    /**
     * Value of the {@code dku.ml.splits.gzip} parameter (defaults to {@code true}), read once when
     * this class is initialized.
     */
    static boolean gzipSplits = ApplicationConfigurator.getParams().getBoolParam("dku.ml.splits.gzip", true);
    /**
     * Suffix appended to every split file name: {@code ".gz"} when {@link #gzipSplits} was true at
     * class initialization, empty otherwise. Computed once; later changes to {@link #gzipSplits}
     * do not affect it.
     */
    private static final String gzExt = gzipSplits ? ".gz" : "";
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.ml.splits.utils");

    /**
     * @return the file {@code train-<instanceId>.csv[.gz]} in the splits folder of {@code taskLoc}
     */
    public static File getMlTaskTrainSetFile(MLTaskLoc taskLoc, String instanceId) {
        return SplitUtils.mlTaskSplitFile(taskLoc, "train", instanceId);
    }

    /**
     * @return the file {@code test-<instanceId>.csv[.gz]} in the splits folder of {@code taskLoc}
     */
    public static File getMlTaskTestSetFile(MLTaskLoc taskLoc, String instanceId) {
        return SplitUtils.mlTaskSplitFile(taskLoc, "test", instanceId);
    }

    /**
     * @return the file {@code full-<instanceId>.csv[.gz]} in the splits folder of {@code taskLoc}
     */
    public static File getMlTaskFullSetFile(MLTaskLoc taskLoc, String instanceId) {
        return SplitUtils.mlTaskSplitFile(taskLoc, "full", instanceId);
    }

    /**
     * @return the file {@code train.csv[.gz]} inside {@code splitFolder}
     */
    public static File getSavedModelTrainSetFile(File splitFolder) {
        return SplitUtils.savedModelSplitFile(splitFolder, "train");
    }

    /**
     * @return the file {@code test.csv[.gz]} inside {@code splitFolder}
     */
    public static File getSavedModelTestSetFile(File splitFolder) {
        return SplitUtils.savedModelSplitFile(splitFolder, "test");
    }

    /**
     * @return the file {@code full.csv[.gz]} inside {@code splitFolder}
     */
    public static File getSavedModelFullSetFile(File splitFolder) {
        return SplitUtils.savedModelSplitFile(splitFolder, "full");
    }

    /** Builds {@code <kind>-<instanceId>.csv[.gz]} in the splits folder of the given ML task. */
    private static File mlTaskSplitFile(MLTaskLoc taskLoc, String kind, String instanceId) {
        return new File(taskLoc.getSplitsFolder(), kind + "-" + instanceId + ".csv" + gzExt);
    }

    /** Builds {@code <kind>.csv[.gz]} inside the given saved-model split folder. */
    private static File savedModelSplitFile(File splitFolder, String kind) {
        return new File(splitFolder, kind + ".csv" + gzExt);
    }

    /**
     * Creates an initialized CSV writer targeting a single local file, wrapped in a
     * {@link CountingProcessorOutput}.
     *
     * <p>The formatter uses {@code CSVFormatConfig.getStandardTabExcelFormat()} and the given
     * schema; the file is addressed by its absolute path through a local FS provider rooted at
     * {@code "/"}.
     *
     * @param outputFile destination file
     * @param schema     output schema set on the CSV formatter
     * @param factory    column factory passed to the writer's {@code init}
     * @return a counting wrapper around the initialized writer
     * @throws Exception if creating or initializing the writer fails
     */
    public static CountingProcessorOutput getWriterToSingleFile(File outputFile, Schema schema, ColumnFactory factory) throws Exception {
        CSVFormatConfig csvConfig = CSVFormatConfig.getStandardTabExcelFormat();
        CSVOutputFormatter formatter = new CSVOutputFormatter(csvConfig);
        formatter.setOutputSchema(schema);
        SingleFileOutputWriter writer = new SingleFileOutputWriter((FSProvider)LocalFSProvider.makeFSProviderOnRoot("/"), outputFile.getAbsolutePath(), (OutputFormatter)formatter);
        writer.init(factory);
        return new CountingProcessorOutput((ProcessorOutput)writer);
    }

    /**
     * Streams one explicitly-defined split of {@code dataset} through the given shaker script and
     * writes the result to {@code splitPath}.
     *
     * <p>The dataset selection is a deep copy of {@code split.selection} whose filter is replaced
     * by {@code split.filter}, so the caller's {@code split} object is not modified.
     *
     * @param authCtx             authentication context used for the pipeline and the pusher
     * @param splitPath           destination file of the split
     * @param split               split definition providing the selection and the filter
     * @param dataset             dataset to read; its project key is used to build the pipeline
     * @param outputSchema        schema of the written file
     * @param script              shaker script applied to the rows
     * @param shakerStreamService service used to build the processing pipeline
     * @return the count reported by the counting writer (presumably the number of rows written —
     *         TODO confirm against {@code CountingProcessorOutput})
     * @throws Exception if building the pipeline, pushing the data or finalizing the output fails
     */
    public static long pushEFD(AuthCtx authCtx, File splitPath, SplitParams.EFDSplit split, Dataset dataset, Schema outputSchema, SerializedShakerScript script, ShakerStreamService shakerStreamService) throws Exception {
        StreamColumnFactory cf = new StreamColumnFactory();
        StreamRowFactory rf = new StreamRowFactory();
        CountingProcessorOutput splitWriter = SplitUtils.getWriterToSingleFile(splitPath, outputSchema, (ColumnFactory)cf);
        ProcessorOutputToSIP pipeline = shakerStreamService.getProcessorOutput(authCtx, dataset.getProjectKey(), script, (ProcessorOutput)splitWriter, (ColumnFactory)cf, (RowFactory)rf);

        // Work on a copy so that overriding the filter does not leak into split.selection.
        StreamableDatasetSelection sel = (StreamableDatasetSelection)((Object)JSON.deepCopy((Object)((Object)split.selection)));
        sel.filter = split.filter;

        UniversalSingleThreadPusher ustp = new UniversalSingleThreadPusher(authCtx, dataset, (ProcessorOutput)pipeline, (ColumnFactory)cf, (RowFactory)rf);
        ustp.setDatasetSelection(sel);
        // NOTE(review): if push() or lastRowEmitted() throws, nothing here cancels splitWriter;
        // confirm whether the pusher/pipeline already cancels downstream outputs on failure.
        ustp.push();
        pipeline.lastRowEmitted();
        return splitWriter.getCount();
    }

    /**
     * Splitter that sends each row to the train output when a seeded pseudo-random draw in
     * {@code [0, 1)} is strictly below {@code ratio}, and to the test output otherwise.
     */
    public static class RandomSplitter
    extends AbstractSplitter {
        Random random = new Random();

        /**
         * @param train output receiving rows whose draw is below {@code ratio}
         * @param test  output receiving all other rows
         * @param seed  seed of the random generator, making the dispatch reproducible
         * @param ratio threshold compared against each draw
         */
        RandomSplitter(ProcessorOutput train, ProcessorOutput test, long seed, double ratio) {
            this.train = train;
            this.test = test;
            this.ratio = ratio;
            this.random.setSeed(seed);
        }

        /** Draws one random double per row and forwards the row to exactly one of the outputs. */
        @Override
        public void emitRow(Row row) throws Exception {
            ProcessorOutput target = this.random.nextDouble() < this.ratio ? this.train : this.test;
            target.emitRow(row);
        }
    }

    /**
     * Base class for outputs that dispatch rows between a train and a test output. Lifecycle
     * calls are forwarded to both outputs, train first; subclasses implement the row dispatch.
     */
    public static abstract class AbstractSplitter
    implements ProcessorOutput {
        protected double ratio;
        ProcessorOutput train;
        ProcessorOutput test;

        /**
         * Signals end of data to the train output, then to the test output. If the train output
         * throws, the test output is not notified and the exception propagates.
         */
        @Override
        public void lastRowEmitted() throws Exception {
            this.train.lastRowEmitted();
            this.test.lastRowEmitted();
        }

        /**
         * Cancels both outputs. The test output is cancelled even when cancelling the train
         * output throws, so that a failure on one side does not leave the other side uncancelled.
         *
         * @throws Exception the first exception raised; an exception from the second cancel is
         *                   attached to it as suppressed
         */
        @Override
        public void cancel() throws Exception {
            Exception failure = null;
            try {
                this.train.cancel();
            } catch (Exception e) {
                failure = e;
            }
            try {
                this.test.cancel();
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
            if (failure != null) {
                throw failure;
            }
        }

        /** Forwards the same memory limit to both outputs. */
        @Override
        public void setMaxMemoryUsed(long size) {
            this.train.setMaxMemoryUsed(size);
            this.test.setMaxMemoryUsed(size);
        }
    }
}

