/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.split;

import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.prediction.split.AbstractSingleDatasetSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitUtils;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.datalayer.RowInputStream;
import com.dataiku.dip.datalayer.sort.RowAndSortMark;
import com.dataiku.dip.datalayer.sort.SortedRowsIterator;
import com.dataiku.dip.datalayer.sort.Sorter;
import com.dataiku.dip.datalayer.sort.SpilledRowsStorage;
import com.dataiku.dip.datalayer.streamimpl.StreamColumnFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.UniversalSingleThreadPuller;
import com.dataiku.dip.input.utils.CountingProcessorOutput;
import com.dataiku.dip.mec.KernelsModelEvaluationStoresService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.datasets.DatasetAccessService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.shaker.server.ShakerStreamService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.org.joda.time.Instant;
import java.io.File;
import java.util.ArrayList;
import org.apache.commons.codec.digest.DigestUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Split generator that builds the train and test sets of an ML task from a single
 * dataset by sorting its rows on one column ({@code params.ssdColumn}) and cutting
 * the sorted sequence at the position given by {@code params.ssdTrainingRatio}.
 *
 * <p>Rows before the cut are sent to the train pipeline, the remaining rows to the
 * test pipeline.
 */
public class SortedSingleDatasetSplitGenerator
extends AbstractSingleDatasetSplitGenerator {
    @Autowired
    private ShakerStreamService shakerStreamService;
    @Autowired
    private DatasetAccessService datasetAccessService;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private KernelsModelEvaluationStoresService kernelsModelEvaluationStoresService;
    DKULogger logger = DKULogger.getLogger("dku.ml.prediction.split");

    /**
     * Creates a generator for the given ML task.
     *
     * @param taskLoc    location of the ML task the split is generated for
     * @param coreParams analysis parameters (the script of which is applied to emitted rows)
     * @param params     split parameters; {@code ssdSplitMode} is expected to be
     *                   {@code SORTED} (only checked when assertions are enabled)
     * @param authCtx    authentication context used to read the dataset and run the script
     */
    public SortedSingleDatasetSplitGenerator(MLTaskLoc taskLoc, AnalysisCoreParams coreParams, SplitParams params, AuthCtx authCtx) {
        super(taskLoc, coreParams, params, authCtx);
        assert (params.ssdSplitMode == SplitParams.SplitMode.SORTED);
    }

    /**
     * Builds the identifier of the split policy from the split parameters.
     *
     * <p>The identifier concatenates the policy type, the split kind (with the number
     * of folds when k-fold is enabled), the {@code splitBeforePrepare} flag, the dataset
     * name, the selection identifier, the training ratio, the sort column and the
     * {@code testOnLargerValues} flag. When the split is not done before preparation,
     * the MD5 hex digest of {@code scriptStepsPrettyStr} is appended as well.
     *
     * @return the policy identifier string
     */
    @Override
    public String getPolicyId() {
        // The textual layout below is kept exactly as-is: the whole string is the identifier.
        String policyId = "type=" + String.valueOf((Object)this.params.ttPolicy)
                + (this.params.kfold ? ",split=SORTED_KFOLD,folds=" + this.params.nFolds : ",split=SORTED")
                + ",splitBeforePrepare=" + this.params.splitBeforePrepare
                + ",ds=" + (this.params.ssdDatasetSmartName == null ? this.getDatasetLoc().getSmartName(this.taskLoc.analysisProjectKey) : this.params.ssdDatasetSmartName)
                + ",sel=(" + this.params.ssdSelection.getIdentifier()
                + "),r=" + this.params.ssdTrainingRatio;
        policyId = policyId + ",c=" + this.params.ssdColumn + ",ascending=" + this.params.testOnLargerValues;
        if (this.params.splitBeforePrepare) {
            return policyId;
        }
        return policyId + ",script=" + DigestUtils.md5Hex((String)this.scriptStepsPrettyStr);
    }

    /**
     * Sorts the selected rows of the dataset on {@code params.ssdColumn}, writes the
     * train and test files for {@code expectedInstanceId}, and records the resulting
     * file names and row counts in {@code splitDesc}.
     *
     * <p>The first {@code (long) (rowCount * params.ssdTrainingRatio)} rows of the sorted
     * sequence go to the train pipeline; every remaining row goes to the test pipeline.
     * {@code rowCount} is the value returned by
     * {@code Sorter.emitAndCountOnlyValidRowsForColumn}.
     *
     * <p>On failure, both pipelines are cancelled and the original exception is rethrown;
     * a failure raised while cancelling is attached to it as a suppressed exception.
     *
     * @param splitDesc          split description to update (its {@code schema} is read)
     * @param expectedInstanceId instance id used to name the train/test files and the
     *                           temporary sort folder
     * @return the same {@code splitDesc} instance, updated
     * @throws Exception if reading, sorting or writing fails
     */
    @Override
    public SplitDesc updateSplitAndSplitDesc(SplitDesc splitDesc, String expectedInstanceId) throws Exception {
        Dataset dataset;
        try (Transaction t = this.transactionService.beginRead();){
            dataset = this.datasetAccessService.getMandatory(this.getDatasetLoc());
        }
        StreamRowFactory rf = new StreamRowFactory();
        StreamColumnFactory cf = new StreamColumnFactory();

        // One writer plus one script pipeline per output file.
        File trainPath = SplitUtils.getMlTaskTrainSetFile(this.taskLoc, expectedInstanceId);
        CountingProcessorOutput trainWriter = SplitUtils.getWriterToSingleFile(trainPath, splitDesc.schema, (ColumnFactory)cf);
        ProcessorOutputToSIP trainPipeline = this.shakerStreamService.getProcessorOutput(this.authCtx, dataset.getProjectKey(), this.coreParams.script, (ProcessorOutput)trainWriter, (ColumnFactory)cf, (RowFactory)rf);
        File testPath = SplitUtils.getMlTaskTestSetFile(this.taskLoc, expectedInstanceId);
        CountingProcessorOutput testWriter = SplitUtils.getWriterToSingleFile(testPath, splitDesc.schema, (ColumnFactory)cf);
        ProcessorOutputToSIP testPipeline = this.shakerStreamService.getProcessorOutput(this.authCtx, dataset.getProjectKey(), this.coreParams.script, (ProcessorOutput)testWriter, (ColumnFactory)cf, (RowFactory)rf);

        Sorter.MergeSortParams mergeSortParams = new Sorter.MergeSortParams();
        // The timestamp suffix keeps the folder name distinct between runs for the same instance id.
        File tmpFolder = new File(this.taskLoc.getSplitsFolder(), "tmp-sort-" + expectedInstanceId + "-" + Instant.now().getMillis());
        if (!tmpFolder.mkdirs() && !tmpFolder.isDirectory()) {
            // Do not fail here: only report it, so that a later failure of the sort can be diagnosed.
            this.logger.info((Object)("Could not create temporary sort folder " + tmpFolder.getAbsolutePath() + ", sorting may fail"));
        }
        // NOTE(review): nothing visible here removes tmpFolder afterwards; confirm that
        // SpilledRowsStorage.close() takes care of it.
        try (SpilledRowsStorage storage = new SpilledRowsStorage(tmpFolder, SpilledRowsStorage.factoryColumnsOfSchema((ColumnFactory)cf, splitDesc.schema), mergeSortParams);){
            // Single sort key; the second argument is labelled "ascending" in getPolicyId()
            // (presumably the sort direction, to be confirmed against Sorter.SortSpec).
            Sorter.SortSpec spec = new Sorter.SortSpec(this.params.ssdColumn, this.params.testOnLargerValues);
            ArrayList<Sorter.SortSpec> specs = new ArrayList<Sorter.SortSpec>();
            specs.add(spec);
            Sorter sorter = new Sorter(specs, splitDesc.schema, (ColumnFactory)cf, storage, (RowFactory)rf, (ColumnFactory)cf, mergeSortParams);
            Column column = cf.getColumn(this.params.ssdColumn);

            // Feed the sorter with the selected rows of the dataset.
            // NOTE(review): the input stream is never explicitly closed here; check whether it needs to be.
            RowInputStream input = UniversalSingleThreadPuller.pull(this.authCtx, dataset, this.params.ssdSelection, (ColumnFactory)cf);
            long rowCount = sorter.emitAndCountOnlyValidRowsForColumn(input, this.params.ssdColumn);
            sorter.lastRowEmitted();
            storage.doneWriting();

            // Read the rows back in sorted order and dispatch them.
            SortedRowsIterator iterator = sorter.read();
            long indexThreshold = (long)((double)rowCount * this.params.ssdTrainingRatio);
            RowAndSortMark row = null;
            for (long index = 0L; index < indexThreshold && iterator.hasNext(); ++index) {
                row = iterator.next();
                trainPipeline.emitRow(row.row.row);
            }
            if (row != null) {
                // Value of the sort column for the last row sent to the train set.
                this.logger.info((Object)("Sorted train/test split: threshold = " + row.row.row.get(column)));
            }
            while (iterator.hasNext()) {
                row = iterator.next();
                testPipeline.emitRow(row.row.row);
            }
            trainPipeline.lastRowEmitted();
            testPipeline.lastRowEmitted();
        }
        catch (Exception e) {
            this.logger.error((Object)"Failed creating sorted splits: ", (Throwable)e);
            // Cancel both pipelines independently so that a failing cancel neither
            // skips the other pipeline nor hides the original failure.
            SortedSingleDatasetSplitGenerator.cancelQuietly(trainPipeline, e);
            SortedSingleDatasetSplitGenerator.cancelQuietly(testPipeline, e);
            throw e;
        }
        splitDesc.trainPath = trainPath.getName();
        splitDesc.testPath = testPath.getName();
        splitDesc.trainRows = trainWriter.getCount();
        splitDesc.testRows = testWriter.getCount();
        return splitDesc;
    }

    /**
     * Cancels {@code pipeline}; any exception raised by the cancellation is attached to
     * {@code primary} as a suppressed exception instead of being propagated.
     */
    private static void cancelQuietly(ProcessorOutputToSIP pipeline, Exception primary) {
        try {
            pipeline.cancel();
        }
        catch (Exception cancelFailure) {
            // addSuppressed rejects self-suppression, hence the identity check.
            if (cancelFailure != primary) {
                primary.addSuppressed(cancelFailure);
            }
        }
    }
}

