/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.split;

import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionTrainingRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.split.ForcedSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitUtils;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.datalayer.RowInputStream;
import com.dataiku.dip.datalayer.sort.RowAndSortMark;
import com.dataiku.dip.datalayer.sort.SortedRowsIterator;
import com.dataiku.dip.datalayer.sort.Sorter;
import com.dataiku.dip.datalayer.sort.SpilledRowsStorage;
import com.dataiku.dip.datalayer.streamimpl.StreamColumnFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.UniversalSingleThreadPuller;
import com.dataiku.dip.input.utils.CountingProcessorOutput;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.server.ShakerStreamService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.joda.time.Instant;
import java.io.File;
import java.util.ArrayList;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * {@link ForcedSplitGenerator} that builds a sorted train/test split ("SSD" split) of a single dataset.
 *
 * <p>Rows of {@link #dataset} (restricted to {@code params.ssdSelection}) are sorted on
 * {@code params.ssdColumn}. The sort direction is driven by {@code params.testOnLargerValues}
 * through {@link Sorter.SortSpec}. The first {@code floor(rowCount * params.ssdTrainingRatio)}
 * sorted rows go to the train set and the remaining rows to the test set. {@code rowCount} is the
 * value returned by {@link Sorter#emitAndCountOnlyValidRowsForColumn}.
 *
 * <p>Every emitted row is pushed through the preparation {@link #script} by a shaker pipeline whose
 * output is written to a single file per set. In {@code TRAIN_SPLITTED_AND_FULL} mode, every row is
 * additionally written to a "full" set.
 */
public class SSDForcedSplitGenerator
implements ForcedSplitGenerator {
    /** Injected by {@code SpringUtils.autowire(this)} in the constructor. */
    @Autowired
    private ShakerStreamService shakerStreamService;
    private final AuthCtx authCtx;
    private final Dataset dataset;
    private final SplitParams params;
    /** Folder receiving the train/test(/full) files and the temporary sort spill folder. */
    private final File targetFolder;
    /** Preparation script applied to every row before it is written. */
    private final SerializedShakerScript script;
    /** Schema used for the sort storage and for the written split files. */
    private final Schema preparationOutputSchema;
    private final AbstractPredictionTrainingRecipePayloadParams.OperationMode operationMode;
    DKULogger logger = DKULogger.getLogger((String)"dku.ml.prediction.split");

    /**
     * Creates a generator for a sorted single-dataset split.
     *
     * <p>When assertions are enabled, this checks that {@code params.ttPolicy} is
     * {@code SPLIT_SINGLE_DATASET}.
     *
     * @param authCtx authentication context used to read the dataset and run the preparation script
     * @param dataset dataset to split
     * @param params split parameters (sort column, training ratio, selection, direction)
     * @param script preparation script applied to each row before writing
     * @param preparationOutputSchema schema of the rows written to the split files
     * @param targetFolder folder in which the split files are written
     * @param operationMode expected to be {@code TRAIN_SPLITTED_ONLY} or {@code TRAIN_SPLITTED_AND_FULL}
     */
    public SSDForcedSplitGenerator(AuthCtx authCtx, Dataset dataset, SplitParams params, SerializedShakerScript script, Schema preparationOutputSchema, File targetFolder, AbstractPredictionTrainingRecipePayloadParams.OperationMode operationMode) {
        SpringUtils.getInstance().autowire((Object)this);
        assert (params.ttPolicy == SplitParams.TrainTestPolicy.SPLIT_SINGLE_DATASET);
        this.authCtx = authCtx;
        this.preparationOutputSchema = preparationOutputSchema;
        this.dataset = dataset;
        this.params = params;
        this.script = script;
        this.targetFolder = targetFolder;
        this.operationMode = operationMode;
    }

    /**
     * Sorts the dataset and writes the train/test (and optionally full) split files.
     *
     * <p>On any failure, every pipeline opened so far is cancelled and the original exception is
     * rethrown. Failures raised while cancelling are attached to it as suppressed exceptions. The
     * temporary sort folder created under {@link #targetFolder} is removed best-effort whatever the
     * outcome.
     *
     * @return a {@link SplitDesc} describing the written files (names only) and their row counts
     * @throws Exception if reading, sorting, preparing or writing rows fails
     */
    @Override
    public SplitDesc compute() throws Exception {
        SplitDesc newDesc = new SplitDesc();
        newDesc.format = "csv1";
        newDesc.generationDate = System.currentTimeMillis();
        newDesc.params = (SplitParams)JSON.deepCopy((Object)this.params);
        newDesc.params.ssdDatasetSmartName = this.dataset.getSmartName(this.script.contextProjectKey);
        newDesc.schema = this.preparationOutputSchema;
        boolean withFull = this.operationMode == AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_SPLITTED_AND_FULL;
        assert (withFull || this.operationMode == AbstractPredictionTrainingRecipePayloadParams.OperationMode.TRAIN_SPLITTED_ONLY);
        StreamColumnFactory cf = new StreamColumnFactory();
        StreamRowFactory rf = new StreamRowFactory();
        File trainPath = SplitUtils.getSavedModelTrainSetFile(this.targetFolder);
        File testPath = SplitUtils.getSavedModelTestSetFile(this.targetFolder);
        File fullPath = null;
        CountingProcessorOutput trainWriter = null;
        CountingProcessorOutput testWriter = null;
        CountingProcessorOutput fullWriter = null;
        // Declared outside the try so the catch block can cancel whatever was successfully opened.
        ProcessorOutputToSIP trainPipeline = null;
        ProcessorOutputToSIP testPipeline = null;
        ProcessorOutputToSIP fullPipeline = null;
        File tmpFolder = null;
        try {
            trainWriter = SplitUtils.getWriterToSingleFile(trainPath, newDesc.schema, (ColumnFactory)cf);
            trainPipeline = this.openPipeline(trainWriter, cf, rf);
            testWriter = SplitUtils.getWriterToSingleFile(testPath, newDesc.schema, (ColumnFactory)cf);
            testPipeline = this.openPipeline(testWriter, cf, rf);
            if (withFull) {
                fullPath = SplitUtils.getSavedModelFullSetFile(this.targetFolder);
                fullWriter = SplitUtils.getWriterToSingleFile(fullPath, newDesc.schema, (ColumnFactory)cf);
                fullPipeline = this.openPipeline(fullWriter, cf, rf);
            }
            Sorter.MergeSortParams mergeSortParams = new Sorter.MergeSortParams();
            tmpFolder = new File(this.targetFolder, "tmp-sort-" + Instant.now().getMillis());
            tmpFolder.mkdirs();
            try (SpilledRowsStorage storage = new SpilledRowsStorage(tmpFolder, SpilledRowsStorage.factoryColumnsOfSchema((ColumnFactory)cf, newDesc.schema), mergeSortParams);){
                Sorter.SortSpec spec = new Sorter.SortSpec(this.params.ssdColumn, this.params.testOnLargerValues);
                ArrayList<Sorter.SortSpec> specs = new ArrayList<Sorter.SortSpec>();
                specs.add(spec);
                Sorter sorter = new Sorter(specs, newDesc.schema, (ColumnFactory)cf, storage, (RowFactory)rf, (ColumnFactory)cf, mergeSortParams);
                Column column = cf.getColumn(this.params.ssdColumn);
                RowInputStream input = UniversalSingleThreadPuller.pull(this.authCtx, this.dataset, this.params.ssdSelection, (ColumnFactory)cf);
                long rowCount = sorter.emitAndCountOnlyValidRowsForColumn(input, this.params.ssdColumn);
                sorter.lastRowEmitted();
                storage.doneWriting();
                SortedRowsIterator iterator = sorter.read();
                // Number of leading sorted rows that go to the train set (truncated toward zero).
                long indexThreshold = (long)((double)rowCount * this.params.ssdTrainingRatio);
                RowAndSortMark row = null;
                // NOTE(review): in full mode the same row object is emitted to two pipelines; this
                // assumes the pipelines do not mutate the row they receive — confirm.
                for (long index = 0L; index < indexThreshold && iterator.hasNext(); ++index) {
                    row = iterator.next();
                    trainPipeline.emitRow(row.row.row);
                    if (fullPipeline != null) {
                        fullPipeline.emitRow(row.row.row);
                    }
                }
                // Logs the sort-column value of the last train row (empty when no row went to train).
                this.logger.info((Object)("Sorted train/test split: threshold = " + (row == null ? "" : row.row.row.get(column))));
                while (iterator.hasNext()) {
                    row = iterator.next();
                    testPipeline.emitRow(row.row.row);
                    if (fullPipeline != null) {
                        fullPipeline.emitRow(row.row.row);
                    }
                }
                // Finalize every pipeline while the sort storage is still open, and inside the guarded
                // region, so a failure here also triggers cancellation of the others.
                trainPipeline.lastRowEmitted();
                testPipeline.lastRowEmitted();
                if (fullPipeline != null) {
                    fullPipeline.lastRowEmitted();
                }
            }
        }
        catch (Exception e) {
            this.logger.error((Object)"Failed creating sorted splits: ", (Throwable)e);
            SSDForcedSplitGenerator.cancelQuietly(trainPipeline, e);
            SSDForcedSplitGenerator.cancelQuietly(testPipeline, e);
            SSDForcedSplitGenerator.cancelQuietly(fullPipeline, e);
            throw e;
        }
        finally {
            // The spill storage has been closed by the inner try-with-resources at this point.
            if (tmpFolder != null) {
                this.deleteQuietly(tmpFolder);
            }
        }
        newDesc.trainPath = trainPath.getName();
        newDesc.testPath = testPath.getName();
        newDesc.trainRows = trainWriter.getCount();
        newDesc.testRows = testWriter.getCount();
        if (withFull) {
            newDesc.fullPath = fullPath.getName();
            newDesc.fullRows = fullWriter.getCount();
        }
        return newDesc;
    }

    /** Opens a shaker pipeline that applies {@link #script} and forwards its output to {@code writer}. */
    private ProcessorOutputToSIP openPipeline(CountingProcessorOutput writer, StreamColumnFactory cf, StreamRowFactory rf) throws Exception {
        return this.shakerStreamService.getProcessorOutput(this.authCtx, this.dataset.getProjectKey(), this.script, (ProcessorOutput)writer, (ColumnFactory)cf, (RowFactory)rf);
    }

    /** Cancels {@code pipeline} if it was opened; a cancel failure is attached to {@code primary} instead of masking it. */
    private static void cancelQuietly(ProcessorOutputToSIP pipeline, Exception primary) {
        if (pipeline == null) {
            return;
        }
        try {
            pipeline.cancel();
        }
        catch (Exception cancelFailure) {
            primary.addSuppressed(cancelFailure);
        }
    }

    /** Best-effort recursive deletion of a temporary file or folder; symbolic links are removed, never followed. */
    private void deleteQuietly(File file) {
        File[] children = file.isDirectory() && !java.nio.file.Files.isSymbolicLink(file.toPath()) ? file.listFiles() : null;
        if (children != null) {
            for (File child : children) {
                this.deleteQuietly(child);
            }
        }
        if (!file.delete() && file.exists()) {
            this.logger.warn((Object)("Could not delete temporary sort file: " + file));
        }
    }
}

