/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.eda;

import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.streamimpl.StreamColumnFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.eda.compute.computations.Computation;
import com.dataiku.dip.eda.compute.computations.ComputationResult;
import com.dataiku.dip.eda.compute.computations.multivariate.FetchCSV;
import com.dataiku.dip.eda.compute.computations.multivariate.PCA;
import com.dataiku.dip.eda.compute.engine.ComputationResultDataStreamer;
import com.dataiku.dip.eda.compute.engine.ComputationResultSession;
import com.dataiku.dip.eda.worksheets.cards.PCACard;
import com.dataiku.dip.input.formats.csv.RFC4180CSVParser;
import com.dataiku.dip.input.stream.InputStreamLineReader;
import com.dataiku.dip.input.stream.LineReader;
import com.dataiku.dip.recipes.eda.EDARecipeRunner;
import com.dataiku.dip.recipes.eda.EDASchemaColumns;
import com.dataiku.dip.recipes.eda.PCARecipeMeta;
import com.dataiku.dip.recipes.eda.PCARecipeOutputRole;
import com.dataiku.dip.recipes.eda.PCARecipePayloadParams;
import com.dataiku.dip.utils.ErrorContext;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;

public class PCARecipeRunner
extends EDARecipeRunner<PCARecipePayloadParams> {
    public PCARecipeRunner(JobActivity activity) {
        super(activity);
    }

    @Override
    protected PCARecipePayloadParams parsePayload(String payload) {
        return PCARecipeMeta.META.parsePayload(payload);
    }

    @Override
    protected void validateRecipe() {
        super.validateRecipe();
        if (this.recipe.getSuccessors().size() == 0) {
            throw ErrorContext.iae((String)"The recipe requires at least one target, but none was found");
        }
        if (this.recipe.getSuccessors().size() > 3) {
            throw ErrorContext.iaef((String)"The recipe can have at most 3 targets, but %s were found", (Object)this.recipe.getSuccessors().size(), (Object[])new Object[0]);
        }
    }

    @Override
    protected Computation getComputationPlan() {
        int projectionDimension = ((PCARecipePayloadParams)this.payloadParams).columns.size();
        PCARecipePayloadParams.ProjectionsSchemaDetails schemaDetails = ((PCARecipePayloadParams)this.payloadParams).computeProjectionsSchema(this.sourceDs.getSchema());
        return new PCA(((PCARecipePayloadParams)this.payloadParams).columns, new FetchCSV(schemaDetails.fetchableColumnNames), projectionDimension, "input_");
    }

    @Override
    protected void writeOutputs(ComputationResultSession session) throws Exception {
        ComputationResult rawResult = session.getComputationResult();
        if (rawResult.isAtLeastPartiallyFailed()) {
            throw new IllegalArgumentException(String.format("Computation error: %s", rawResult.getFirstFailedResult().message));
        }
        PCA.PCAResult result = rawResult.as(PCA.PCAResult.class);
        int nComponents = result.eigenvalues.length;
        if (nComponents != result.eigenvectors.length) {
            throw new IllegalArgumentException(String.format("Expected the same number of eigenvalues as eigenvectors, but got %d eigenvalues and %d eigenvectors", nComponents, result.eigenvectors.length));
        }
        if (nComponents > ((PCARecipePayloadParams)this.payloadParams).columns.size()) {
            throw new IllegalArgumentException(String.format("Expected at most %d components, got %d", ((PCARecipePayloadParams)this.payloadParams).columns.size(), nComponents));
        }
        if (this.hasRecipeOutput(PCARecipeOutputRole.PROJECTIONS.name)) {
            this.writeProjections(result, session.getDataStreamer());
        }
        if (this.hasRecipeOutput(PCARecipeOutputRole.EIGENVECTORS.name)) {
            this.writeEigenvectors(result);
        }
        if (this.hasRecipeOutput(PCARecipeOutputRole.EIGENVALUES.name)) {
            this.writeEigenvalues(result);
        }
    }

    private void writeProjections(PCA.PCAResult result, ComputationResultDataStreamer dataStreamer) throws Exception {
        ComputationResultDataStreamer.DataStreamId projectionStreamId = result.projectionComputationResult.as(FetchCSV.FetchCSVResult.class).dataStreamId;
        StreamRowFactory rf = new StreamRowFactory();
        StreamColumnFactory cf = new StreamColumnFactory();
        ProcessorOutput writer = this.buildOutputDatasetWriter(PCARecipeOutputRole.PROJECTIONS.name, (ColumnFactory)cf);
        PCARecipePayloadParams.ProjectionsSchemaDetails schemaDetails = ((PCARecipePayloadParams)this.payloadParams).computeProjectionsSchema(this.sourceDs.getSchema());
        int expectedColumnDimension = schemaDetails.fetchableColumnNames.size();
        int actualColumnDimension = expectedColumnDimension - (((PCARecipePayloadParams)this.payloadParams).columns.size() - result.eigenvalues.length);
        try (InputStream inputStream = dataStreamer.streamData(projectionStreamId);){
            InputStreamLineReader islr = new InputStreamLineReader(inputStream, StandardCharsets.UTF_8);
            RFC4180CSVParser csvParser = new RFC4180CSVParser((LineReader)islr, ',');
            ArrayList<String> rowInput = new ArrayList<String>();
            int rowIndex = 0;
            while (csvParser.next(rowInput)) {
                if (rowInput.size() != expectedColumnDimension) {
                    throw new IllegalArgumentException(String.format("Expected exactly %d columns for projections, got %d for row %d", expectedColumnDimension, rowInput.size(), rowIndex));
                }
                Row row = rf.row();
                int colIndex = 0;
                for (String rawCellValue : rowInput) {
                    if (rawCellValue.length() > 0) {
                        SchemaColumn sc = schemaDetails.schemaColumns.get(colIndex);
                        if (colIndex + 1 > actualColumnDimension) {
                            throw new IllegalArgumentException(String.format("Expected projections on only %d components, but there are non-empty projections on component '%s' (column %d) for row %d", result.eigenvalues.length, sc.getName(), colIndex, rowIndex));
                        }
                        if (sc.getType().isInteger()) {
                            row.put(cf.column(sc.getName()), this.normalizeInteger(rawCellValue));
                        } else {
                            row.put(cf.column(sc.getName()), rawCellValue);
                        }
                    }
                    ++colIndex;
                }
                if (((PCARecipePayloadParams)this.payloadParams).addComputationTimestamp) {
                    row.put(cf.column("computation_timestamp"), this.computationTimestamp);
                }
                writer.emitRow(row);
                ++rowIndex;
            }
        }
        writer.lastRowEmitted();
    }

    private String normalizeInteger(String rawCellValue) {
        if (rawCellValue.length() >= 3 && rawCellValue.endsWith(".0")) {
            return rawCellValue.substring(0, rawCellValue.length() - 2);
        }
        return rawCellValue;
    }

    private void writeEigenvectors(PCA.PCAResult result) throws Exception {
        int nComponents = result.eigenvectors.length;
        int nRows = ((PCARecipePayloadParams)this.payloadParams).columns.size();
        for (int i = 0; i < nComponents; ++i) {
            if (result.eigenvectors[i].length == nRows) continue;
            throw new IllegalArgumentException(String.format("Eigenvectors should have a length of %d, got %d", nRows, result.eigenvectors[i].length));
        }
        StreamRowFactory rf = new StreamRowFactory();
        StreamColumnFactory cf = new StreamColumnFactory();
        ProcessorOutput writer = this.buildOutputDatasetWriter(PCARecipeOutputRole.EIGENVECTORS.name, (ColumnFactory)cf);
        for (int i = 0; i < nRows; ++i) {
            Row row = rf.row();
            row.put(cf.column(EDASchemaColumns.inputColumn()), ((PCARecipePayloadParams)this.payloadParams).columns.get(i));
            for (int j = 0; j < nComponents; ++j) {
                row.put(cf.column(EDASchemaColumns.principalComponent(j)), result.eigenvectors[j][i]);
            }
            if (((PCARecipePayloadParams)this.payloadParams).addComputationTimestamp) {
                row.put(cf.column("computation_timestamp"), this.computationTimestamp);
            }
            writer.emitRow(row);
        }
        writer.lastRowEmitted();
    }

    private void writeEigenvalues(PCA.PCAResult result) throws Exception {
        Row row;
        int i;
        int nComponents = result.eigenvalues.length;
        int nRows = ((PCARecipePayloadParams)this.payloadParams).columns.size();
        StreamRowFactory rf = new StreamRowFactory();
        StreamColumnFactory cf = new StreamColumnFactory();
        ProcessorOutput writer = this.buildOutputDatasetWriter(PCARecipeOutputRole.EIGENVALUES.name, (ColumnFactory)cf);
        double[] ratios = PCACard.explainedVarianceRatio(result.eigenvalues);
        double cumulativeRatio = 0.0;
        for (i = 0; i < nComponents; ++i) {
            row = rf.row();
            row.put(cf.column(EDASchemaColumns.principalComponent()), EDASchemaColumns.principalComponent(i));
            row.put(cf.column(EDASchemaColumns.variance()), result.eigenvalues[i]);
            row.put(cf.column(EDASchemaColumns.varianceRatio()), ratios[i]);
            row.put(cf.column(EDASchemaColumns.cumulativeVarianceRatio()), cumulativeRatio += ratios[i]);
            if (((PCARecipePayloadParams)this.payloadParams).addComputationTimestamp) {
                row.put(cf.column("computation_timestamp"), this.computationTimestamp);
            }
            writer.emitRow(row);
        }
        for (i = nComponents; i < nRows; ++i) {
            row = rf.row();
            row.put(cf.column(EDASchemaColumns.principalComponent()), EDASchemaColumns.principalComponent(i));
            if (((PCARecipePayloadParams)this.payloadParams).addComputationTimestamp) {
                row.put(cf.column("computation_timestamp"), this.computationTimestamp);
            }
            writer.emitRow(row);
        }
        writer.lastRowEmitted();
    }
}

