/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.export.input;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.StreamableDatasetSelection;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.export.ExportService;
import com.dataiku.dip.export.input.ExportInput;
import com.dataiku.dip.futures.DSSFuturePayloadUtils;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.input.formats.csv.CSVFormatConfig;
import com.dataiku.dip.input.formats.csv.RFC4180CSVParser;
import com.dataiku.dip.input.stream.InputStreamLineReader;
import com.dataiku.dip.input.stream.LineReader;
import com.dataiku.dip.utils.DKUFileUtils;
import java.io.File;
import java.io.InputStream;
import java.util.ArrayList;

/**
 * {@link ExportInput} that exports the train and test split files of a model as a
 * single stream of rows.
 *
 * <p>Every emitted row carries the columns of the split schema plus one extra string
 * column, {@value #SPLIT_COLUMN_NAME}, whose value is {@value #SPLIT_COLUMN_TRAIN_SET_VALUE}
 * for rows read from the train file and {@value #SPLIT_COLUMN_TEST_SET_VALUE} for rows
 * read from the test file. All train rows are emitted before any test row.
 *
 * <p>{@link #initialize} must be called before {@link #stream}.
 */
public class ExportTrainTestDataInput implements ExportInput {
    private static final String SPLIT_COLUMN_NAME = "row_origin";
    private static final String SPLIT_COLUMN_TRAIN_SET_VALUE = "train";
    private static final String SPLIT_COLUMN_TEST_SET_VALUE = "test";
    private final Schema schema;
    private final File trainFile;
    private final File testFile;
    private final ExportInput.InputDescription description;
    private final FullModelId fullModelId;
    // Both factories are only available once initialize() has run.
    private StreamRowFactory rowFactory;
    private ColumnFactory columnFactory;

    /**
     * Prepares the export of the train/test split attached to the given model.
     *
     * @param fullModelId identifier of the model whose split is exported
     * @throws IllegalArgumentException if the split schema already has a column named
     *         {@value #SPLIT_COLUMN_NAME}, or if the split has no train path or no test path
     * @throws Exception if the split description or the split files cannot be resolved
     */
    public ExportTrainTestDataInput(FullModelId fullModelId) throws Exception {
        this.fullModelId = fullModelId;
        SplitDesc splitDesc = fullModelId.getSplitDesc();
        // NOTE(review): this is the schema object held by the SplitDesc, not a copy, so the
        // addColumn() call below is visible through splitDesc.schema as well. Confirm that
        // getSplitDesc() returns a fresh instance per call, or copy the schema here.
        this.schema = splitDesc.schema;
        if (this.schema.hasColumn(SPLIT_COLUMN_NAME)) {
            throw new IllegalArgumentException("Cannot export dataset. Input dataset already contains a column called '" + SPLIT_COLUMN_NAME + "'.");
        }
        if (splitDesc.testPath == null || splitDesc.trainPath == null) {
            throw new IllegalArgumentException("Unsupported split mode: can only export train/test set splits.");
        }
        this.trainFile = fullModelId.getSplitFile(splitDesc.trainPath);
        this.testFile = fullModelId.getSplitFile(splitDesc.testPath);
        // Mutate the schema only once every check and lookup above has succeeded, so a
        // failed construction does not leave a stray "row_origin" column behind.
        // The split column is appended last: streamOneFile() relies on that position.
        this.schema.addColumn(SPLIT_COLUMN_NAME, Type.STRING);

        this.description = new ExportInput.InputDescription();
        this.description.name = "Train and test sets export";
        this.description.projectKey = fullModelId.getProjectKey();
        this.description.description = "Export the train and test sets to a new dataset.";
    }

    /**
     * Creates the row factory and stores the column factory used while streaming.
     * The {@code job} and {@code selection} arguments are not used by this input.
     *
     * @param job ignored
     * @param selection ignored
     * @param cf factory used to obtain the output columns
     */
    @Override
    public void initialize(ExportService.LocalExportJob job, StreamableDatasetSelection selection, ColumnFactory cf) throws Exception {
        this.rowFactory = new StreamRowFactory();
        this.columnFactory = cf;
    }

    /**
     * @return always {@code -1}; presumably the "size unknown" marker of the
     *         {@link ExportInput} contract (to be confirmed against the interface)
     */
    @Override
    public long getInputSize() {
        return -1L;
    }

    /** @return the fixed name, project key and description built in the constructor */
    @Override
    public ExportInput.InputDescription describe() {
        return this.description;
    }

    /** @return the split schema extended with the {@value #SPLIT_COLUMN_NAME} column */
    @Override
    public Schema getSchema() throws Exception {
        return this.schema;
    }

    /**
     * Emits every row of the train file, then every row of the test file, and finally
     * signals the end of the stream to {@code output}.
     *
     * @param output receiver of the rows
     * @throws IllegalStateException if {@link #initialize} has not provided the factories yet
     * @throws IllegalArgumentException if a row does not have the expected number of columns
     */
    @Override
    public void stream(ProcessorOutput output) throws Exception {
        if (this.rowFactory == null || this.columnFactory == null) {
            // Fail with an explicit message instead of a NullPointerException deep in the loop.
            throw new IllegalStateException("initialize() must be called with a column factory before stream()");
        }
        this.streamOneFile(output, this.trainFile, SPLIT_COLUMN_TRAIN_SET_VALUE);
        this.streamOneFile(output, this.testFile, SPLIT_COLUMN_TEST_SET_VALUE);
        output.lastRowEmitted();
    }

    /**
     * Reads one split file as tab-separated CSV and emits each record as a row tagged
     * with {@code splitColumnValue} in the split column.
     *
     * @param output receiver of the rows
     * @param file split file to read (opened through {@code DKUFileUtils.readWithAutoDecompress})
     * @param splitColumnValue value written in the {@value #SPLIT_COLUMN_NAME} column of each row
     * @throws IllegalArgumentException if a record does not have exactly one value per
     *         schema column other than the split column
     */
    private void streamOneFile(ProcessorOutput output, File file, String splitColumnValue) throws Exception {
        try (InputStream in = DKUFileUtils.readWithAutoDecompress(file)) {
            CSVFormatConfig config = CSVFormatConfig.getStandardTabExcelFormat();
            LineReader lineReader = new InputStreamLineReader(in, config.charset);
            RFC4180CSVParser csvParser = new RFC4180CSVParser(lineReader, config.getSeparatorChar());

            ArrayList<Column> schemaColumns = new ArrayList<Column>();
            for (SchemaColumn sc : this.schema.getColumns()) {
                schemaColumns.add(this.columnFactory.column(sc.getName()));
            }
            // The constructor appended the split column last; the files hold all the others.
            int expectedCsvCols = schemaColumns.size() - 1;
            Column splitColumn = schemaColumns.get(expectedCsvCols);

            // The same list instance is handed to the parser for every record; presumably
            // next() clears and refills it (to be confirmed against RFC4180CSVParser).
            ArrayList<String> csvRow = new ArrayList<String>();
            while (csvParser.next(csvRow)) {
                int csvRowNumCols = csvRow.size();
                if (csvRowNumCols != expectedCsvCols) {
                    throw new IllegalArgumentException("Row has unexpected number of columns: " + csvRowNumCols + " (should be " + expectedCsvCols + ")");
                }
                Row row = this.rowFactory.row();
                for (int i = 0; i < csvRowNumCols; ++i) {
                    row.put(schemaColumns.get(i), csvRow.get(i));
                }
                row.put(splitColumn, splitColumnValue);
                output.emitRow(row);
            }
        }
    }

    /** @return the payload target built from the model identifier */
    @Override
    public FuturePayload.FuturePayloadTarget getSource() {
        return DSSFuturePayloadUtils.forFMI(this.fullModelId);
    }

    /** Does nothing: each split file is opened and closed inside {@code streamOneFile}. */
    @Override
    public void close() {
    }
}

