/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.export.input;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.PredictedDataService;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowInputStream;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datasets.StreamableDatasetSelection;
import com.dataiku.dip.datasets.UniversalSingleThreadPuller;
import com.dataiku.dip.datasets.fs.BuiltinFSDatasets;
import com.dataiku.dip.datasets.fs.FilesystemDatasetConfig;
import com.dataiku.dip.export.ExportService;
import com.dataiku.dip.export.input.ExportInput;
import com.dataiku.dip.futures.DSSFuturePayloadUtils;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.input.formats.csv.CSVFormatConfig;
import com.dataiku.dip.input.formats.csv.CSVFormatExtractor;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.shaker.TableToSchema;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import java.io.File;
import java.io.IOException;
import org.springframework.beans.factory.annotation.Autowired;

public class ExportPredictedDataInput
implements ExportInput {
    @Autowired
    private PredictedDataService predictedDataService;
    private final FullModelId fullModelId;
    private AuthCtx authCtx;
    private final ExportInput.InputDescription description;
    private Schema schema;
    private MemTable table;
    private StreamableDatasetSelection selection;
    private ColumnFactory columnFactory;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.ml.predicted.data.export");

    public ExportPredictedDataInput(FullModelId fullModelId, AuthCtx authCtx) {
        this.fullModelId = fullModelId;
        this.authCtx = authCtx;
        this.description = new ExportInput.InputDescription();
        this.description.name = "Predicted data export";
        this.description.projectKey = fullModelId.getProjectKey();
        this.description.description = "Export this model's predicted data to a new dataset.";
    }

    @Override
    public void initialize(ExportService.LocalExportJob job, StreamableDatasetSelection selection, ColumnFactory columnFactory) throws Exception {
        this.columnFactory = columnFactory;
        this.selection = selection;
        SpringUtils.getInstance().autowire((Object)this);
        SerializedShakerScript.ShakerExplorationSampleSettings explorationSampling = this.fullModelId.getHeadMLTask().predictionDisplayScript.explorationSampling;
        this.table = this.predictedDataService.getUncachedUnfiltered_NOTRANSACTION((FullModelId)this.fullModelId, (SerializedShakerScript.ShakerExplorationSampleSettings)explorationSampling).table;
        this.schema = TableToSchema.inferSchemaSimple(this.table, true);
    }

    @Override
    public long getInputSize() {
        return -1L;
    }

    @Override
    public ExportInput.InputDescription describe() {
        return this.description;
    }

    @Override
    public Schema getSchema() throws Exception {
        return this.schema;
    }

    @Override
    public void stream(ProcessorOutput output) throws Exception {
        MLTask task = this.fullModelId.getHeadMLTask();
        SplitDesc splitDesc = this.fullModelId.getSplitDesc();
        File testsetFile = this.getTestsetFile(task, splitDesc);
        Schema testsetSchema = splitDesc.schema;
        File predictionFile = this.getPredictionFile(task);
        Schema predictedSchema = this.getPredictionSchema(task);
        if (predictionFile != null) {
            if (testsetFile != null) {
                this.streamAndMergeTestsetAndPrediction(output, testsetFile, testsetSchema, predictionFile, predictedSchema);
            } else {
                this.streamPrediction(output, predictionFile, predictedSchema);
            }
        } else {
            logger.info((Object)"Model does not have test set file nor prediction file, nothing to export");
            output.lastRowEmitted();
        }
    }

    private void streamAndMergeTestsetAndPrediction(ProcessorOutput output, File testsetFile, Schema testsetSchema, File predictionFile, Schema predictedSchema) throws Exception {
        logger.info((Object)String.format("Streaming and merging %s and %s for prediction data export.", testsetFile.getAbsolutePath(), predictionFile.getAbsolutePath()));
        PredictedDataReader predictedDataReader = new PredictedDataReader(this.authCtx, this.selection, this.columnFactory, PredictedDataService.getDefaultDerivedColumnsComputer(this.fullModelId), testsetFile, predictionFile, testsetSchema, predictedSchema);
        predictedDataReader.stream(output);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void streamPrediction(ProcessorOutput output, File predictionFile, Schema predictedSchema) throws Exception {
        logger.info((Object)String.format("Streaming %s for prediction data export.", predictionFile.getAbsolutePath()));
        Dataset predictedDataset = ExportPredictedDataInput.buildEphemeralExportDataset(predictionFile, predictedSchema, true);
        RowInputStream predictedRowInputStream = UniversalSingleThreadPuller.pull(this.authCtx, predictedDataset, this.selection, this.columnFactory);
        Row predictedRow = null;
        try {
            predictedRow = predictedRowInputStream.next();
            while (predictedRow != null) {
                output.emitRow(predictedRow);
                predictedRow = predictedRowInputStream.next();
            }
        }
        finally {
            if (predictedRow != null) {
                PredictedDataReader.closeRowInputStream(predictedRowInputStream);
            }
        }
        output.lastRowEmitted();
    }

    private File getTestsetFile(MLTask task, SplitDesc splitDesc) {
        File testsetFile = null;
        if (task.taskType == MLTask.MLTaskType.CLUSTERING) {
            testsetFile = DKUFileUtils.getWithin((File)this.fullModelId.getSplitFolder(), (String[])new String[]{splitDesc.fullPath});
        } else if (task.taskType == MLTask.MLTaskType.PREDICTION && !(task instanceof PredictionMLTask.TimeseriesForecastingMLTask)) {
            String datasetFilename = splitDesc.params.kfold ? splitDesc.fullPath : splitDesc.testPath;
            testsetFile = DKUFileUtils.getWithin((File)this.fullModelId.getSplitFolder(), (String[])new String[]{datasetFilename});
        }
        return testsetFile;
    }

    private File getPredictionFile(MLTask task) {
        File predictionFile = null;
        if (task.taskType == MLTask.MLTaskType.CLUSTERING) {
            predictionFile = this.fullModelId.getModelFile("clustered.csv");
        } else if (task.taskType == MLTask.MLTaskType.PREDICTION) {
            predictionFile = this.fullModelId.getModelFile("predicted.csv");
        }
        return predictionFile;
    }

    private Schema getPredictionSchema(MLTask task) throws IOException {
        Schema predictedSchema = null;
        if (task.taskType == MLTask.MLTaskType.CLUSTERING) {
            predictedSchema = this.fullModelId.getClusteredSchema();
        } else if (task.taskType == MLTask.MLTaskType.PREDICTION) {
            predictedSchema = this.fullModelId.getPredictedSchema();
        }
        if (predictedSchema == null && this.getPredictionFile(task) != null) {
            throw new RuntimeException("Unable to define prediction schema");
        }
        return predictedSchema;
    }

    @Override
    public FuturePayload.FuturePayloadTarget getSource() {
        return DSSFuturePayloadUtils.forFMI(this.fullModelId);
    }

    @Override
    public void close() {
    }

    private static Dataset buildEphemeralExportDataset(File file, Schema schema, boolean firstLineIsHeader) {
        Dataset ephemeralDataset = new Dataset();
        ephemeralDataset.setType(BuiltinFSDatasets.FS_META.getType());
        ephemeralDataset.setFullName("export.dataset");
        FilesystemDatasetConfig fdc = new FilesystemDatasetConfig();
        fdc.path = file.getAbsolutePath();
        ephemeralDataset.setParams(fdc);
        ephemeralDataset.setSchema(schema);
        ephemeralDataset.setFormatType(CSVFormatExtractor.META.getType());
        CSVFormatConfig formatParams = CSVFormatConfig.getStandardTabUNIXFormat();
        formatParams.parseHeaderRow = firstLineIsHeader;
        ephemeralDataset.setFormatParams(formatParams);
        return ephemeralDataset;
    }

    public static class PredictedDataReader {
        private final AuthCtx authCtx;
        private final StreamableDatasetSelection selection;
        private final ColumnFactory columnFactory;
        private final PredictedDataService.DerivedColumnsComputer derivedColumnsComputer;
        private final Schema predictedSchema;
        private final Dataset testsetDataset;
        private final Dataset predictedDataset;
        private long rowsCount = 0L;

        public PredictedDataReader(AuthCtx authCtx, StreamableDatasetSelection selection, ColumnFactory columnFactory, PredictedDataService.DerivedColumnsComputer derivedColumnsComputer, File testsetFile, File predictionFile, Schema testsetSchema, Schema predictedSchema) {
            this.authCtx = authCtx;
            this.selection = selection;
            this.columnFactory = columnFactory;
            this.predictedSchema = predictedSchema;
            this.testsetDataset = ExportPredictedDataInput.buildEphemeralExportDataset(testsetFile, testsetSchema, false);
            this.predictedDataset = ExportPredictedDataInput.buildEphemeralExportDataset(predictionFile, predictedSchema, true);
            this.derivedColumnsComputer = derivedColumnsComputer;
        }

        public void stream(ProcessorOutput processorOutput) throws Exception {
            this.verifyMatchingLength_NT(this.testsetDataset, this.predictedDataset);
            this.streamAndMerge_NT(processorOutput, this.testsetDataset, this.predictedDataset);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        private void streamAndMerge_NT(ProcessorOutput processorOutput, Dataset testsetDataset, Dataset predictedDataset) throws Exception {
            RowInputStream testsetRowInputStream = UniversalSingleThreadPuller.pull(this.authCtx, testsetDataset, this.selection, this.columnFactory);
            RowInputStream predictedRowInputStream = UniversalSingleThreadPuller.pull(this.authCtx, predictedDataset, this.selection, this.columnFactory);
            Row row = null;
            Row predictedRow = null;
            this.derivedColumnsComputer.initializeComputeColumns(this.columnFactory);
            try {
                while (!this.isExportComplete(row = testsetRowInputStream.next(), predictedRow = predictedRowInputStream.next())) {
                    if (row == null || predictedRow == null) {
                        throw new IOException("Cannot combine test set and prediction files, mismatch in rows count.");
                    }
                    for (int columnIndex = 0; columnIndex < this.predictedSchema.getColumns().size(); ++columnIndex) {
                        Column predictedColumn = this.columnFactory.column(((SchemaColumn)this.predictedSchema.getColumns().get(columnIndex)).getName());
                        String predictedColumnValue = predictedRow.get(predictedColumn);
                        row.with(predictedColumn, predictedColumnValue);
                    }
                    this.derivedColumnsComputer.compute(row, this.columnFactory);
                    processorOutput.emitRow(row);
                    ++this.rowsCount;
                }
            }
            finally {
                if (row != null) {
                    PredictedDataReader.closeRowInputStream(testsetRowInputStream);
                }
                if (predictedRow != null) {
                    PredictedDataReader.closeRowInputStream(predictedRowInputStream);
                }
            }
            this.derivedColumnsComputer.cleanEphemeralComputeColumns(this.columnFactory);
            processorOutput.lastRowEmitted();
        }

        protected static void closeRowInputStream(RowInputStream rowInputStream) throws Exception {
            while (rowInputStream.next() != null) {
            }
        }

        private void verifyMatchingLength_NT(Dataset testsetDataset, Dataset predictedDataset) throws Exception {
            RowInputStream testsetRowInputStream = UniversalSingleThreadPuller.pull(this.authCtx, testsetDataset, this.selection, this.columnFactory);
            int testsetRowCount = 0;
            while (testsetRowInputStream.next() != null) {
                ++testsetRowCount;
            }
            RowInputStream predictedRowInputStream = UniversalSingleThreadPuller.pull(this.authCtx, predictedDataset, this.selection, this.columnFactory);
            int predictionRowsCount = 0;
            while (predictedRowInputStream.next() != null) {
                ++predictionRowsCount;
            }
            if (testsetRowCount != predictionRowsCount) {
                String errorMessage = String.format("Cannot combine test set and prediction files, mismatch in rows count, test set has %d, prediction has %s", testsetRowCount, predictionRowsCount);
                logger.error((Object)errorMessage);
                throw ErrorContext.iae((String)errorMessage);
            }
        }

        private boolean isExportComplete(Row testsetRow, Row predictedRow) {
            return testsetRow == null && predictedRow == null;
        }
    }
}

